Sophia Tang committed on
Commit ·
40e7e76
0
Parent(s):
initial commit
Browse files- .gitattributes +2 -0
- .gitignore +1 -0
- README.md +193 -0
- assets/mcts.png +3 -0
- assets/mdlm.png +3 -0
- assets/peptune.png +3 -0
- assets/poster.png +3 -0
- data/dataloading_for_dynamic_batching.py +156 -0
- data/dataset.py +207 -0
- scripts/generate_mcts.sh +57 -0
- scripts/generate_unconditional.sh +16 -0
- scripts/train.sh +18 -0
- src/config.py +319 -0
- src/config.yaml +164 -0
- src/diffusion.py +1015 -0
- src/environment.yml +40 -0
- src/generate_mcts.py +365 -0
- src/generate_unconditional.py +111 -0
- src/metrics.py +72 -0
- src/noise_schedule.py +152 -0
- src/pareto_mcts.py +492 -0
- src/roformer.py +74 -0
- src/scoring/functions/binding.py +178 -0
- src/scoring/functions/binding_utils.py +290 -0
- src/scoring/functions/classifiers/hemolysis-xgboost.json +0 -0
- src/scoring/functions/classifiers/nonfouling-xgboost.json +0 -0
- src/scoring/functions/classifiers/permeability-xgboost.json +3 -0
- src/scoring/functions/classifiers/solubility-xgboost.json +0 -0
- src/scoring/functions/hemolysis.py +63 -0
- src/scoring/functions/nonfouling.py +66 -0
- src/scoring/functions/permeability.py +171 -0
- src/scoring/functions/scoring_utils.py +94 -0
- src/scoring/functions/solubility.py +63 -0
- src/scoring/scoring_functions.py +75 -0
- src/scoring/tokenizer/my_tokenizers.py +424 -0
- src/scoring/tokenizer/new_splits.txt +159 -0
- src/scoring/tokenizer/new_vocab.txt +587 -0
- src/tokenizer/__init__.py +0 -0
- src/tokenizer/my_tokenizers.py +441 -0
- src/tokenizer/new_splits.txt +159 -0
- src/tokenizer/new_vocab.txt +587 -0
- src/train.py +133 -0
- src/train_peptune.py +226 -0
- src/utils/app.py +1255 -0
- src/utils/generate_utils.py +77 -0
- src/utils/utils.py +256 -0
.gitattributes
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src/scoring/functions/classifiers/permeability-xgboost.json filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
assets/*.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
logs/
|
README.md
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# [PepTune: De Novo Generation of Therapeutic Peptides with Multi-Objective-Guided Discrete Diffusion](https://arxiv.org/abs/2412.17780) 🧬🔮 (ICML 2025)
|
| 2 |
+
|
| 3 |
+
[**Sophia Tang**](https://sophtang.github.io/)\*, [**Yinuo Zhang**](https://www.linkedin.com/in/yinuozhang98/)\* and [**Pranam Chatterjee**](https://www.chatterjeelab.com/)
|
| 4 |
+
|
| 5 |
+

|
| 6 |
+
|
| 7 |
+
This is the repository for **[PepTune: De Novo Generation of Therapeutic Peptides with Multi-Objective-Guided Discrete Diffusion](https://arxiv.org/abs/2412.17780)** 🧬🔮 published at **ICML 2025**. It is partially built on the **[MDLM repo](https://github.com/kuleshov-group/mdlm)** ([Sahoo et al. 2024](https://arxiv.org/abs/2406.07524)).
|
| 8 |
+
|
| 9 |
+
PepTune leverages **Monte-Carlo Tree Search (MCTS)** to guide a generative masked discrete diffusion model which iteratively refines a set of Pareto non-dominated sequences optimized across a set of therapeutic properties, including binding affinity, cell membrane permeability, solubility, non-fouling, and non-hemolysis.
|
| 10 |
+
|
| 11 |
+
## Environment Installation
|
| 12 |
+
|
| 13 |
+
```bash
|
| 14 |
+
conda env create -f src/environment.yml
|
| 15 |
+
|
| 16 |
+
conda activate peptune
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
## Model Pretrained Weights Download
|
| 20 |
+
|
| 21 |
+
Follow the steps below to download the model weights required for this experiment.
|
| 22 |
+
|
| 23 |
+
1. Download the PepTune pre-trained MDLM checkpoint and place in `checkpoints/`: https://drive.google.com/file/d/1oXGDpKLNF0KX0ZdOcl1NZj5Czk2lSFUn/view?usp=sharing
|
| 24 |
+
2. Download the pre-trained binding affinity Transformer model and place in `src/scoring/functions/classifiers/`: https://drive.google.com/file/d/128shlEP_-rYAxPgZRCk_n0HBWVbOYSva/view?usp=sharing
|
| 25 |
+
|
| 26 |
+
## Training Data Download
|
| 27 |
+
|
| 28 |
+
Download the peptide training dataset from https://drive.google.com/file/d/1yCDr641WVjCtECg3nbG0nsMNu8j7d7gp/view?usp=drive_link and unzip it into the `data/` directory:
|
| 29 |
+
|
| 30 |
+
```bash
|
| 31 |
+
# Download peptide_data.zip into the data/ directory
|
| 32 |
+
cd data/
|
| 33 |
+
|
| 34 |
+
# Unzip the training data
|
| 35 |
+
unzip peptide_data.zip
|
| 36 |
+
|
| 37 |
+
cd ..
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
After unzipping, the data should be located at `data/peptide_data/`.
|
| 41 |
+
|
| 42 |
+
## Repository Structure
|
| 43 |
+
|
| 44 |
+
```
|
| 45 |
+
PepTune/
|
| 46 |
+
├── src/
|
| 47 |
+
│ ├── train_peptune.py # Main training script
|
| 48 |
+
│ ├── generate_mcts.py # MCTS-guided peptide generation
|
| 49 |
+
│ ├── generate_unconditional.py # Unconditional generation
|
| 50 |
+
│ ├── diffusion.py # Core masked discrete diffusion model
|
| 51 |
+
│ ├── pareto_mcts.py # Pareto-front MCTS implementation
|
| 52 |
+
│ ├── roformer.py # RoFormer backbone
|
| 53 |
+
│ ├── noise_schedule.py # Noise scheduling (loglinear, logpoly)
|
| 54 |
+
│ ├── config.yaml # Hydra configuration
|
| 55 |
+
│ ├── config.py # Argparse configuration
|
| 56 |
+
│ ├── environment.yml # Conda environment
|
| 57 |
+
│ ├── scoring/ # Therapeutic property scoring
|
| 58 |
+
│ │ ├── scoring_functions.py # Unified scoring interface
|
| 59 |
+
│ │ └── functions/ # Individual property predictors
|
| 60 |
+
│ │ ├── binding.py
|
| 61 |
+
│ │ ├── hemolysis.py
|
| 62 |
+
│ │ ├── nonfouling.py
|
| 63 |
+
│ │ ├── permeability.py
|
| 64 |
+
│ │ ├── solubility.py
|
| 65 |
+
│ │ └── classifiers/ # Pre-trained scoring model weights
|
| 66 |
+
│ ├── tokenizer/ # SMILES SPE tokenizer
|
| 67 |
+
│ │ ├── my_tokenizers.py
|
| 68 |
+
│ │ ├── new_vocab.txt
|
| 69 |
+
│ │ └── new_splits.txt
|
| 70 |
+
│ └── utils/ # Utilities & PeptideAnalyzer
|
| 71 |
+
│ ├── app.py
|
| 72 |
+
│ ├── generate_utils.py
|
| 73 |
+
│ └── utils.py
|
| 74 |
+
├── scripts/ # Shell scripts for running experiments
|
| 75 |
+
│ ├── train.sh # Pre-training
|
| 76 |
+
│ ├── generate_mcts.sh # MCTS-guided generation
|
| 77 |
+
│ └── generate_unconditional.sh # Unconditional generation
|
| 78 |
+
├── data/ # Training data
|
| 79 |
+
│ ├── dataloading_for_dynamic_batching.py
|
| 80 |
+
│ └── dataset.py
|
| 81 |
+
├── checkpoints/ # Model checkpoints
|
| 82 |
+
└── assets/ # Figures
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
## Pre-training
|
| 86 |
+
|
| 87 |
+
Before running, fill in `HOME_LOC` and `ENV_LOC` in `scripts/train.sh` and `base_path` in `src/config.yaml` to match your paths.
|
| 88 |
+
|
| 89 |
+
```bash
|
| 90 |
+
chmod +x scripts/train.sh
|
| 91 |
+
|
| 92 |
+
nohup ./scripts/train.sh > train.log 2>&1 &
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
Training uses Hydra configuration from `src/config.yaml`. Key settings:
|
| 96 |
+
- **Backbone:** RoFormer (768 hidden, 8 layers, 12 heads)
|
| 97 |
+
- **Optimizer:** AdamW (lr=3e-4, weight_decay=0.075)
|
| 98 |
+
- **Data:** 11M SMILES peptide dataset with dynamic batching by length
|
| 99 |
+
- **Precision:** fp64
|
| 100 |
+
- Checkpoints saved to `checkpoints/` (monitors `val/nll`, saves top 10)
|
| 101 |
+
|
| 102 |
+
## MCTS-Guided Peptide Generation
|
| 103 |
+
|
| 104 |
+
Generate therapeutic peptides optimized across multiple objectives using Monte-Carlo Tree Search.
|
| 105 |
+
|
| 106 |
+
1. Fill in `base_path` in `src/config.yaml` and `src/scoring/scoring_functions.py`.
|
| 107 |
+
2. Fill in `HOME_LOC` in `scripts/generate_mcts.sh`.
|
| 108 |
+
3. Create output directories: `mkdir -p results logs`
|
| 109 |
+
|
| 110 |
+
```bash
|
| 111 |
+
chmod +x scripts/generate_mcts.sh
|
| 112 |
+
|
| 113 |
+
# Usage: ./scripts/generate_mcts.sh [PROT_NAME] [PROT_NAME2] [MODE] [MODEL] [LENGTH] [EPOCH]
|
| 114 |
+
# Example: Generate peptides targeting GFAP with length 100
|
| 115 |
+
nohup ./scripts/generate_mcts.sh gfap "" 2 mcts 100 7 > generate.log 2>&1 &
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
### Available Target Proteins
|
| 119 |
+
|
| 120 |
+
| Name | Target |
|
| 121 |
+
|------|--------|
|
| 122 |
+
| `amhr` | AMH Receptor |
|
| 123 |
+
| `tfr` | Transferrin Receptor |
|
| 124 |
+
| `gfap` | Glial Fibrillary Acidic Protein |
|
| 125 |
+
| `glp1` | GLP-1 Receptor |
|
| 126 |
+
| `glast` | Excitatory Amino Acid Transporter |
|
| 127 |
+
| `ncam` | Neural Cell Adhesion Molecule |
|
| 128 |
+
| `cereblon` | Cereblon (CRBN) |
|
| 129 |
+
| `ligase` | E3 Ubiquitin Ligase |
|
| 130 |
+
| `skp2` | S-Phase Kinase-Associated Protein 2 |
|
| 131 |
+
| `p53` | Tumor Suppressor p53 |
|
| 132 |
+
| `egfp` | Enhanced Green Fluorescent Protein |
|
| 133 |
+
|
| 134 |
+
To specify a custom target protein, override `+prot_seq=<amino acid sequence>` and `+prot_name=<name>` as Hydra arguments in the generation script.
|
| 135 |
+
|
| 136 |
+
### Scoring Objectives
|
| 137 |
+
|
| 138 |
+
PepTune jointly optimizes across five therapeutic properties via the integrated scoring suite:
|
| 139 |
+
|
| 140 |
+
| Objective | Property | Model |
|
| 141 |
+
|-----------|----------|-------|
|
| 142 |
+
| `binding_affinity1` | Binding affinity to target protein | Cross-attention Transformer |
|
| 143 |
+
| `solubility` | Aqueous solubility | XGBoost on SMILES CNN embeddings |
|
| 144 |
+
| `hemolysis` | Non-hemolytic | SMILES binary classifier |
|
| 145 |
+
| `nonfouling` | Non-fouling | SMILES binary classifier |
|
| 146 |
+
| `permeability` | Cell membrane permeability | PAMPA CNN |
|
| 147 |
+
|
| 148 |
+
### Default MCTS Hyperparameters
|
| 149 |
+
|
| 150 |
+
These can be overridden via Hydra config overrides:
|
| 151 |
+
|
| 152 |
+
| Parameter | Default | Description |
|
| 153 |
+
|-----------|---------|-------------|
|
| 154 |
+
| `mcts.num_children` | 50 | Branching factor per MCTS node |
|
| 155 |
+
| `mcts.num_iter` | 128 | Number of MCTS iterations |
|
| 156 |
+
| `mcts.num_objectives` | 5 | Number of optimization objectives |
|
| 157 |
+
| `sampling.steps` | 128 | Diffusion denoising steps |
|
| 158 |
+
| `sampling.seq_length` | 200 | Generated peptide length |
|
| 159 |
+
|
| 160 |
+
## Unconditional Generation
|
| 161 |
+
|
| 162 |
+
Generate peptides without property guidance:
|
| 163 |
+
|
| 164 |
+
```bash
|
| 165 |
+
chmod +x scripts/generate_unconditional.sh
|
| 166 |
+
|
| 167 |
+
nohup ./scripts/generate_unconditional.sh > generate_unconditional.log 2>&1 &
|
| 168 |
+
```
|
| 169 |
+
|
| 170 |
+
## Evaluation
|
| 171 |
+
|
| 172 |
+
To summarize metrics after generation, fill in `path` and `prot_name` in `src/metrics.py` and run:
|
| 173 |
+
|
| 174 |
+
```bash
|
| 175 |
+
python src/metrics.py
|
| 176 |
+
```
|
| 177 |
+
|
| 178 |
+
## Citation
|
| 179 |
+
|
| 180 |
+
If you find this repository helpful for your publications, please consider citing our paper:
|
| 181 |
+
|
| 182 |
+
```bibtex
|
| 183 |
+
@article{tang2025peptune,
|
| 184 |
+
title={Peptune: De novo generation of therapeutic peptides with multi-objective-guided discrete diffusion},
|
| 185 |
+
author={Tang, Sophia and Zhang, Yinuo and Chatterjee, Pranam},
|
| 186 |
+
journal={42nd International Conference on Machine Learning},
|
| 187 |
+
year={2025}
|
| 188 |
+
}
|
| 189 |
+
```
|
| 190 |
+
|
| 191 |
+
## License
|
| 192 |
+
|
| 193 |
+
To use this repository, you agree to abide by the MIT License.
|
assets/mcts.png
ADDED
|
Git LFS Details
|
assets/mdlm.png
ADDED
|
Git LFS Details
|
assets/peptune.png
ADDED
|
Git LFS Details
|
assets/poster.png
ADDED
|
Git LFS Details
|
data/dataloading_for_dynamic_batching.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.data import Dataset, DataLoader
|
| 4 |
+
from datasets import Dataset,load_from_disk
|
| 5 |
+
import sys
|
| 6 |
+
import lightning.pytorch as pl
|
| 7 |
+
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 8 |
+
from functools import partial
|
| 9 |
+
import re
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DynamicBatchingDataset(Dataset):
    """Dataset over pre-batched rows, tensorizing ids/masks up front.

    Each element of ``dataset_dict`` is already a full (length-grouped)
    batch; ``labels`` are kept in their raw form.
    """

    def __init__(self, dataset_dict, tokenizer):
        print('Initializing dataset...')
        # Eagerly convert token ids and attention masks to tensors.
        self.dataset_dict = {
            'attention_mask': [torch.tensor(row) for row in dataset_dict['attention_mask']],
            'input_ids': [torch.tensor(row) for row in dataset_dict['input_ids']],
            'labels': dataset_dict['labels'],
        }
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.dataset_dict['attention_mask'])

    def __getitem__(self, idx):
        # Supports a single integer index or a list of indices.
        keys = ('input_ids', 'attention_mask', 'labels')
        if isinstance(idx, int):
            return {key: self.dataset_dict[key][idx] for key in keys}
        if isinstance(idx, list):
            return {key: [self.dataset_dict[key][i] for i in idx] for key in keys}
        raise ValueError(f"Expected idx to be int or list, but got {type(idx)}")
|
| 40 |
+
|
| 41 |
+
class CustomDataModule(pl.LightningDataModule):
    """Lightning data module for a pre-batched dataset stored on disk.

    Rows are pre-grouped by length, so the DataLoaders use ``batch_size=1``
    and each "item" handed to ``collate_fn`` is already a full batch.
    """

    def __init__(self, dataset_path, tokenizer):
        super().__init__()
        self.dataset = load_from_disk(dataset_path)
        self.tokenizer = tokenizer

    def peptide_bond_mask(self, smiles_list):
        """
        Build a character-level bond mask for a batch of SMILES strings.

        Args:
            smiles_list: List of peptide SMILES strings (batch of SMILES strings).

        Returns:
            torch.Tensor: int mask of shape (batch_size, 1035) with 1 at
            character positions covered by a recognized bond, 0 elsewhere.
        """
        n_rows = len(smiles_list)
        # Fixed width matching the tokenizer's max_length rather than the
        # batch maximum (matches sequences truncated/padded to 1035).
        width = 1035
        mask = torch.zeros((n_rows, width), dtype=torch.int)

        # More specific patterns come first so an overlapping region is
        # claimed by the most specific bond type.
        bond_patterns = [
            (r'OC\(=O\)', 'ester'),
            (r'N\(C\)C\(=O\)', 'n_methyl'),
            (r'N[12]C\(=O\)', 'peptide'),  # Pro peptide bonds
            (r'NC\(=O\)', 'peptide'),      # Regular peptide bonds
            (r'C\(=O\)N\(C\)', 'n_methyl'),
            (r'C\(=O\)N[12]?', 'peptide')
        ]

        for row, smiles in enumerate(smiles_list):
            claimed = set()  # character positions already assigned to a bond
            for pattern, _bond_type in bond_patterns:
                for match in re.finditer(pattern, smiles):
                    span = range(match.start(), match.end())
                    if claimed.isdisjoint(span):
                        mask[row, match.start():match.end()] = 1
                        claimed.update(span)

        return mask

    def peptide_token_mask(self, smiles_list, token_lists):
        """
        Lift the character-level bond mask to token granularity.

        A token is flagged (1) when any character it covers lies on a bond.
        The first and last tokens are treated as special tokens: they are
        never flagged and do not consume SMILES characters.

        Args:
            smiles_list: List of peptide SMILES strings (batch of SMILES strings).
            token_lists: List of tokenized SMILES strings (split into tokens).

        Returns:
            torch.Tensor: int mask of shape (batch_size, num_tokens).
        """
        n_rows = len(smiles_list)
        n_cols = max(len(tokens) for tokens in token_lists)
        token_masks = torch.zeros((n_rows, n_cols), dtype=torch.int)
        char_masks = self.peptide_bond_mask(smiles_list)

        for row, char_mask in enumerate(char_masks):
            tokens = token_lists[row]
            last = len(tokens) - 1
            cursor = 0  # position within the SMILES string
            for col, token in enumerate(tokens):
                if 0 < col < last:
                    if torch.sum(char_mask[cursor:cursor + len(token)]) >= 1:
                        token_masks[row][col] = 1
                    cursor += len(token)

        return token_masks

    def collate_fn(self, batch):
        """Unwrap the single pre-built batch and attach its bond mask."""
        item = batch[0]  # batch_size=1: the lone item is already a full batch

        token_array = self.tokenizer.get_token_split(item['input_ids'])
        bond_mask = self.peptide_token_mask(item['labels'], token_array)

        return {
            'input_ids': item['input_ids'],
            'attention_mask': item['attention_mask'],
            'bond_mask': bond_mask
        }

    def train_dataloader(self):
        return DataLoader(
            DynamicBatchingDataset(self.dataset['train'], tokenizer=self.tokenizer),
            batch_size=1,
            collate_fn=self.collate_fn,
            shuffle=True,
            num_workers=12,
            pin_memory=True
        )

    def val_dataloader(self):
        return DataLoader(
            DynamicBatchingDataset(self.dataset['val'], tokenizer=self.tokenizer),
            batch_size=1,
            collate_fn=self.collate_fn,
            num_workers=8,
            pin_memory=True
        )
|
data/dataset.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import re
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
import utils
|
| 6 |
+
|
| 7 |
+
from torch.utils.data import Dataset, DataLoader
|
| 8 |
+
import lightning.pytorch as pl
|
| 9 |
+
from functools import partial
|
| 10 |
+
import sys
|
| 11 |
+
|
| 12 |
+
class CustomDataset(Dataset):
    """View over ``dataset`` restricted to (and reordered by) ``indices``."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # indices may hold numpy/tensor scalars; coerce to a plain int key.
        return self.dataset[int(self.indices[idx])]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# for weighting losses of peptide bonds
def peptide_bond_mask(smiles_list):
    """
    Return a character-level mask marking recognized backbone bonds.

    Args:
        smiles_list: List of peptide SMILES strings (batch of SMILES strings).

    Returns:
        torch.Tensor: int mask of shape (batch_size, max_seq_length) with 1
        at bond character positions and 0 elsewhere. An empty batch yields a
        (0, 0) tensor instead of raising (the original ``max`` call raised
        ValueError on an empty list).
    """
    batch_size = len(smiles_list)
    # Pad every row to the longest SMILES in the batch; default=0 handles
    # an empty batch gracefully.
    max_seq_length = max((len(smiles) for smiles in smiles_list), default=0)
    mask = torch.zeros((batch_size, max_seq_length), dtype=torch.int)

    # More specific patterns come first so an overlapping region is claimed
    # by the most specific bond type.
    bond_patterns = [
        (r'OC\(=O\)', 'ester'),
        (r'N\(C\)C\(=O\)', 'n_methyl'),
        (r'N[12]C\(=O\)', 'peptide'),  # Pro peptide bonds
        (r'NC\(=O\)', 'peptide'),      # Regular peptide bonds
        (r'C\(=O\)N\(C\)', 'n_methyl'),
        (r'C\(=O\)N[12]?', 'peptide')
    ]

    for batch_idx, smiles in enumerate(smiles_list):
        used = set()  # character positions already claimed by a match
        for pattern, _bond_type in bond_patterns:
            for match in re.finditer(pattern, smiles):
                span = range(match.start(), match.end())
                # O(1)-per-element disjointness check instead of scanning
                # the whole `used` set for every match.
                if used.isdisjoint(span):
                    mask[batch_idx, match.start():match.end()] = 1
                    used.update(span)

    return mask
|
| 73 |
+
|
| 74 |
+
def peptide_token_mask(smiles_list, token_lists):
    """
    Lift the character-level peptide-bond mask to token granularity.

    A token gets a 1 when any character it spans lies on a recognized bond.
    The first and last tokens are treated as special tokens: they are never
    flagged and do not consume characters of the SMILES string.

    Args:
        smiles_list: List of peptide SMILES strings (batch of SMILES strings).
        token_lists: List of tokenized SMILES strings (split into tokens).

    Returns:
        torch.Tensor: int mask of shape (batch_size, num_tokens).
    """
    n_rows = len(smiles_list)
    n_cols = max(len(tokens) for tokens in token_lists)
    token_masks = torch.zeros((n_rows, n_cols), dtype=torch.int)
    char_masks = peptide_bond_mask(smiles_list)

    for row, char_mask in enumerate(char_masks):
        tokens = token_lists[row]
        last = len(tokens) - 1
        cursor = 0  # index into the SMILES character mask
        for col, token in enumerate(tokens):
            if 0 < col < last:
                if torch.sum(char_mask[cursor:cursor + len(token)]) >= 1:
                    token_masks[row][col] = 1
                cursor += len(token)

    return token_masks
|
| 104 |
+
|
| 105 |
+
def extract_amino_acid_sequence(helm_string):
    """
    Extract the amino acid sequence from a HELM peptide notation.

    Brackets around non-canonical residues are stripped, and residues from
    multiple PEPTIDE chains are concatenated in order.

    Args:
        helm_string (str): The HELM notation string for a peptide.

    Returns:
        list: residues in sequence, without brackets; or an error string
        when no ``PEPTIDE<n>{...}`` segment is found (kept for backward
        compatibility with existing callers).
    """
    # Capture the `{...}` payload after each "PEPTIDE<number>" chain id.
    chains = re.findall(r'PEPTIDE\d+\{([^}]+)\}', helm_string)
    if not chains:
        return "Invalid HELM notation or no peptide sequence found."

    residues = []
    for chain in chains:
        residues.extend(chain.replace('[', '').replace(']', '').split('.'))
    return residues
|
| 128 |
+
|
| 129 |
+
def helm_collate_fn(batch, tokenizer):
    """Collate a batch of HELM records into padded token tensors.

    The original version also computed the maximum residue count via
    ``extract_amino_acid_sequence`` but never used it; that dead regex pass
    over every sequence has been removed.

    Args:
        batch: list of dicts, each carrying a 'HELM' string.
        tokenizer: callable tokenizer returning 'input_ids'/'attention_mask'.

    Returns:
        dict with padded 'input_ids' and 'attention_mask'.
    """
    sequences = [item['HELM'] for item in batch]

    # NOTE: sequences longer than 1024 tokens are truncated.
    tokens = tokenizer(sequences, return_tensors='pt', padding=True, truncation=True, max_length=1024)

    return {
        'input_ids': tokens['input_ids'],
        'attention_mask': tokens['attention_mask']
    }
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def collate_fn(batch, tokenizer):
    """Standard data collator that truncates/pads sequences based on max_length.

    Sequences the tokenizer cannot encode are dropped from the batch
    (deliberate best-effort) rather than failing the whole step. Cleaned up
    relative to the original: the unused ``valid_items`` list, the unused
    ``test_tokens`` binding, and commented-out dead code are removed.

    Args:
        batch: list of dicts, each carrying a 'SMILES' string.
        tokenizer: SMILES tokenizer; must also provide ``get_token_split``.

    Returns:
        dict with padded 'input_ids', 'attention_mask', and a 'bond_mask'
        flagging tokens that overlap peptide/ester/N-methyl backbone bonds.
    """
    valid_sequences = []
    for item in batch:
        try:
            # Probe-encode each sequence alone so one bad SMILES cannot
            # poison the batched tokenizer call below.
            tokenizer([item['SMILES']], return_tensors='pt', padding=False, truncation=True, max_length=1035)
            valid_sequences.append(item['SMILES'])
        except Exception as e:  # best-effort: skip unencodable rows, keep going
            print(f"Skipping sequence due to: {str(e)}")

    tokens = tokenizer(valid_sequences, return_tensors='pt', padding=True, truncation=True, max_length=1035)

    token_array = tokenizer.get_token_split(tokens['input_ids'])
    bond_mask = peptide_token_mask(valid_sequences, token_array)

    return {
        'input_ids': tokens['input_ids'],
        'attention_mask': tokens['attention_mask'],
        'bond_mask': bond_mask
    }
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class CustomDataModule(pl.LightningDataModule):
    """Lightning data module wrapping train/val datasets with a SMILES collator."""

    def __init__(self, train_dataset, val_dataset, test_dataset, tokenizer, batch_size, collate_fn=collate_fn):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        #self.test_dataset = test_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.collate_fn = collate_fn

    def _build_loader(self, dataset):
        # Shared DataLoader construction for the train and val splits.
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
            num_workers=8,
            pin_memory=True
        )

    def train_dataloader(self):
        return self._build_loader(self.train_dataset)

    def val_dataloader(self):
        return self._build_loader(self.val_dataset)

    """def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                          collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
                          num_workers=8, pin_memory=True)"""
|
scripts/generate_mcts.sh
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# MCTS-guided peptide generation driver.
# Usage: ./scripts/generate_mcts.sh [PROT_NAME] [PROT_NAME2] [MODE] [MODEL] [LENGTH] [EPOCH]

HOME_LOC=/path/to/your/home/PepTune
SCRIPT_LOC=$HOME_LOC/src
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='mcts'
PYTHON_EXECUTABLE=python

# ===================================================================
# Default parameters (can be overridden by command line arguments)
# Available proteins: amhr, tfr, gfap, glp1, glast, ncam, cereblon, ligase, skp2, p53, egfp
PROT_NAME1=${1:-"gfap"}
PROT_NAME2=${2:-""}
MODE=${3:-"2"}
MODEL=${4:-"mcts"}
LENGTH=${5:-"100"}
EPOCH=${6:-"7"}
CKPT=$HOME_LOC/checkpoints/epoch13-new-tokenizer.ckpt

# ===================================================================
echo "Activating conda environment..."
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate peptune

# Create output and log directories once up front (the original ran
# `mkdir -p "${LOG_LOC}"` twice).
mkdir -p "${HOME_LOC}/${PROT_NAME1}"
mkdir -p "${LOG_LOC}"

echo "Running MCTS generation with parameters:"
echo " Protein Name 1: $PROT_NAME1"
echo " Protein Name 2: $PROT_NAME2"
echo " Mode: $MODE"
echo " Model: $MODEL"
echo " Length: $LENGTH"
echo " Epoch: $EPOCH"

# Build Hydra override arguments; prot_name2 is optional.
HYDRA_ARGS="+prot_name1=$PROT_NAME1 ++mode=$MODE +model_type=$MODEL +length=$LENGTH +epoch=$EPOCH"
if [ -n "$PROT_NAME2" ]; then
    HYDRA_ARGS="$HYDRA_ARGS +prot_name2=$PROT_NAME2"
fi

cd "$SCRIPT_LOC"

# Run the MCTS generation script with Hydra overrides
$PYTHON_EXECUTABLE "$SCRIPT_LOC/generate_mcts.py" \
    --config-path "$SCRIPT_LOC" \
    --config-name config \
    base_path="$HOME_LOC" \
    eval.checkpoint_path="$CKPT" \
    $HYDRA_ARGS >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_generate.log" 2>&1

echo "Generation complete. Check logs at: ${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_generate.log"

conda deactivate
|
scripts/generate_unconditional.sh
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Run unconditional PepTune generation inside the `peptune` conda env,
# appending stdout/stderr to a dated log file under $LOG_LOC.

HOME_LOC=/path/to/your/home/PepTune
SCRIPT_LOC=$HOME_LOC/src
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='unconditional'
PYTHON_EXECUTABLE=python

# ===================================================================
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate peptune

# Fix: LOG_LOC was defined but never used — the log previously landed in
# whatever directory the script was launched from. Write it under $LOG_LOC,
# consistent with generate_mcts.sh.
mkdir -p "${LOG_LOC}"
"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/generate_unconditional.py" >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_generate.log" 2>&1

conda deactivate
|
scripts/train.sh
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Launch PepTune training inside the project conda env, logging to $LOG_LOC.

HOME_LOC=/path/to/your/home/PepTune
ENV_LOC=/path/to/your/envs/peptune
SCRIPT_LOC=$HOME_LOC/src
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='11M-ablation-all-losses'
# set 3 have skip connection
PYTHON_EXECUTABLE=$ENV_LOC/bin/python

# ===================================================================
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate "$ENV_LOC"

# Fix: LOG_LOC was defined but never used — the log previously landed in the
# current working directory. Create the log dir and write there, consistent
# with generate_mcts.sh.
mkdir -p "${LOG_LOC}"
"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/train_peptune.py" >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log" 2>&1

conda deactivate
|
src/config.py
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def get_parser():
    """Build the argparse parser for PepTune training and evaluation.

    Fix: switches that default to ``True`` previously used
    ``action='store_true'`` with ``default=True``, which makes them a no-op —
    there was no way to turn them off from the command line. They now use
    ``argparse.BooleanOptionalAction`` (Python 3.9+), which keeps ``--flag``
    working and additionally accepts ``--no-flag`` to disable.

    Returns:
        argparse.ArgumentParser with all PepTune option groups registered.
    """
    parser = argparse.ArgumentParser(description='PepTune Training and Evaluation')
    # Negatable on/off switch for options whose default is True.
    on_off = argparse.BooleanOptionalAction

    # Noise parameters
    noise_group = parser.add_argument_group('noise')
    noise_group.add_argument('--noise-type', type=str, default='loglinear',
                             help='Type of noise schedule')
    noise_group.add_argument('--sigma-min', type=float, default=1e-4,
                             help='Minimum sigma value')
    noise_group.add_argument('--sigma-max', type=float, default=20,
                             help='Maximum sigma value')
    noise_group.add_argument('--state-dependent', action=on_off, default=True,
                             help='Use state-dependent noise')

    # Base parameters
    parser.add_argument('--base-path', type=str, default='/path/to/PepTune',
                        help='Base path to PepTune')
    parser.add_argument('--mode', type=str, default='ppl_eval',
                        choices=['train', 'ppl_eval', 'sample_eval'],
                        help='Running mode')
    parser.add_argument('--diffusion', type=str, default='absorbing_state',
                        help='Diffusion type')
    parser.add_argument('--vocab', type=str, default='old_smiles',
                        choices=['old_smiles', 'new_smiles', 'selfies', 'helm'],
                        help='Vocabulary type')
    parser.add_argument('--backbone', type=str, default='roformer',
                        choices=['peptideclm', 'helmgpt', 'dit', 'roformer', 'finetune_roformer'],
                        help='Model backbone')
    parser.add_argument('--parameterization', type=str, default='subs',
                        help='Parameterization type')
    parser.add_argument('--time-conditioning', action='store_true', default=False,
                        help='Use time conditioning')
    parser.add_argument('--T', type=int, default=0,
                        help='Number of diffusion steps (0 for continuous time, 1000 for discrete)')
    parser.add_argument('--subs-masking', action='store_true', default=False,
                        help='Use substitution masking')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')

    # MCTS parameters
    mcts_group = parser.add_argument_group('mcts')
    mcts_group.add_argument('--mcts-num-children', type=int, default=50,
                            help='Number of children in MCTS')
    mcts_group.add_argument('--mcts-num-objectives', type=int, default=5,
                            help='Number of objectives in MCTS')
    mcts_group.add_argument('--mcts-topk', type=int, default=100,
                            help='Top-k for MCTS')
    mcts_group.add_argument('--mcts-mask-token', type=int, default=4,
                            help='Mask token ID')
    mcts_group.add_argument('--mcts-num-iter', type=int, default=128,
                            help='Number of MCTS iterations')
    mcts_group.add_argument('--mcts-sampling', type=int, default=0,
                            help='Sampling strategy (0 for gumbel, >0 for top-k)')
    mcts_group.add_argument('--mcts-invalid-penalty', type=float, default=0.5,
                            help='Penalty for invalid sequences')
    mcts_group.add_argument('--mcts-sample-prob', type=float, default=1.0,
                            help='Sampling probability')
    mcts_group.add_argument('--mcts-perm', action=on_off, default=True,
                            help='Use permutation in MCTS')
    mcts_group.add_argument('--mcts-dual', action='store_true', default=False,
                            help='Use dual mode')
    mcts_group.add_argument('--mcts-single', action='store_true', default=False,
                            help='Use single mode')
    mcts_group.add_argument('--mcts-time-dependent', action=on_off, default=True,
                            help='Use time-dependent MCTS')

    # Data parameters
    data_group = parser.add_argument_group('data')
    data_group.add_argument('--train-data', type=str,
                            default='/path/to/your/home/PepTune/data/peptide_data',
                            help='Path to training data')
    data_group.add_argument('--valid-data', type=str,
                            default='/path/to/your/home/PepTune/data/peptide_data',
                            help='Path to validation data')
    data_group.add_argument('--data-batching', type=str, default='wrapping',
                            choices=['padding', 'wrapping'],
                            help='Batching strategy')

    # Loader parameters
    loader_group = parser.add_argument_group('loader')
    loader_group.add_argument('--global-batch-size', type=int, default=64,
                              help='Global batch size')
    loader_group.add_argument('--eval-global-batch-size', type=int, default=None,
                              help='Evaluation global batch size (defaults to global-batch-size)')
    loader_group.add_argument('--num-workers', type=int, default=None,
                              help='Number of dataloader workers (defaults to available CPUs)')
    loader_group.add_argument('--pin-memory', action=on_off, default=True,
                              help='Pin memory for dataloaders')

    # Sampling parameters
    sampling_group = parser.add_argument_group('sampling')
    sampling_group.add_argument('--predictor', type=str, default='ddpm_cache',
                                choices=['analytic', 'ddpm', 'ddpm_cache'],
                                help='Predictor type for sampling')
    sampling_group.add_argument('--num-sequences', type=int, default=100,
                                help='Number of sequences to generate')
    sampling_group.add_argument('--sampling-eps', type=float, default=1e-3,
                                help='Sampling epsilon')
    sampling_group.add_argument('--steps', type=int, default=128,
                                help='Number of sampling steps')
    sampling_group.add_argument('--seq-length', type=int, default=100,
                                help='Sequence length')
    sampling_group.add_argument('--noise-removal', action=on_off, default=True,
                                help='Use noise removal')
    sampling_group.add_argument('--num-sample-batches', type=int, default=2,
                                help='Number of sample batches')
    sampling_group.add_argument('--num-sample-log', type=int, default=2,
                                help='Number of samples to log')
    sampling_group.add_argument('--stride-length', type=int, default=1,
                                help='Stride length for sampling')
    sampling_group.add_argument('--num-strides', type=int, default=1,
                                help='Number of strides')

    # Training parameters
    training_group = parser.add_argument_group('training')
    training_group.add_argument('--antithetic-sampling', action=on_off, default=True,
                                help='Use antithetic sampling')
    training_group.add_argument('--training-sampling-eps', type=float, default=1e-3,
                                help='Training sampling epsilon')
    training_group.add_argument('--focus-mask', action='store_true', default=False,
                                help='Use focus mask')
    training_group.add_argument('--accumulator', action='store_true', default=False,
                                help='Use accumulator')

    # Evaluation parameters
    eval_group = parser.add_argument_group('eval')
    eval_group.add_argument('--checkpoint-path', type=str, default=None,
                            help='Path to checkpoint for evaluation')
    eval_group.add_argument('--disable-ema', action='store_true', default=False,
                            help='Disable EMA')
    eval_group.add_argument('--compute-generative-perplexity', action='store_true', default=False,
                            help='Compute generative perplexity')
    eval_group.add_argument('--perplexity-batch-size', type=int, default=8,
                            help='Batch size for perplexity computation')
    eval_group.add_argument('--compute-perplexity-on-sanity', action='store_true', default=False,
                            help='Compute perplexity on sanity check')
    eval_group.add_argument('--gen-ppl-eval-model', type=str, default='gpt2-large',
                            help='Model for generative perplexity evaluation')
    eval_group.add_argument('--generate-samples', action=on_off, default=True,
                            help='Generate samples during evaluation')
    eval_group.add_argument('--generation-model', type=str, default=None,
                            help='Model for generation')

    # Optimizer parameters
    optim_group = parser.add_argument_group('optim')
    optim_group.add_argument('--weight-decay', type=float, default=0.075,
                             help='Weight decay')
    optim_group.add_argument('--lr', type=float, default=3e-4,
                             help='Learning rate')
    optim_group.add_argument('--beta1', type=float, default=0.9,
                             help='Adam beta1')
    optim_group.add_argument('--beta2', type=float, default=0.999,
                             help='Adam beta2')
    optim_group.add_argument('--eps', type=float, default=1e-8,
                             help='Adam epsilon')

    # PepCLM model parameters
    pepclm_group = parser.add_argument_group('pepclm')
    pepclm_group.add_argument('--pepclm-hidden-size', type=int, default=768,
                              help='PepCLM hidden size')
    pepclm_group.add_argument('--pepclm-cond-dim', type=int, default=256,
                              help='PepCLM conditioning dimension')
    pepclm_group.add_argument('--pepclm-n-heads', type=int, default=20,
                              help='PepCLM number of attention heads')
    pepclm_group.add_argument('--pepclm-n-blocks', type=int, default=4,
                              help='PepCLM number of blocks')
    pepclm_group.add_argument('--pepclm-dropout', type=float, default=0.5,
                              help='PepCLM dropout rate')
    pepclm_group.add_argument('--pepclm-length', type=int, default=512,
                              help='PepCLM sequence length')

    # General model parameters
    model_group = parser.add_argument_group('model')
    model_group.add_argument('--model-type', type=str, default='ddit',
                             help='Model type')
    model_group.add_argument('--hidden-size', type=int, default=768,
                             help='Model hidden size')
    model_group.add_argument('--cond-dim', type=int, default=128,
                             help='Conditioning dimension')
    model_group.add_argument('--length', type=int, default=512,
                             help='Sequence length')
    model_group.add_argument('--n-blocks', type=int, default=12,
                             help='Number of blocks')
    model_group.add_argument('--n-heads', type=int, default=12,
                             help='Number of attention heads')
    model_group.add_argument('--scale-by-sigma', action=on_off, default=True,
                             help='Scale by sigma')
    model_group.add_argument('--dropout', type=float, default=0.1,
                             help='Dropout rate')

    # RoFormer parameters
    roformer_group = parser.add_argument_group('roformer')
    roformer_group.add_argument('--roformer-hidden-size', type=int, default=768,
                                help='RoFormer hidden size')
    roformer_group.add_argument('--roformer-n-layers', type=int, default=8,
                                help='RoFormer number of layers')
    roformer_group.add_argument('--roformer-n-heads', type=int, default=8,
                                help='RoFormer number of attention heads')
    roformer_group.add_argument('--roformer-max-position-embeddings', type=int, default=1035,
                                help='RoFormer max position embeddings')

    # HelmGPT parameters
    helmgpt_group = parser.add_argument_group('helmgpt')
    helmgpt_group.add_argument('--helmgpt-hidden-size', type=int, default=256,
                               help='HelmGPT hidden size')
    helmgpt_group.add_argument('--helmgpt-embd-pdrop', type=float, default=0.1,
                               help='HelmGPT embedding dropout')
    helmgpt_group.add_argument('--helmgpt-resid-pdrop', type=float, default=0.1,
                               help='HelmGPT residual dropout')
    helmgpt_group.add_argument('--helmgpt-attn-pdrop', type=float, default=0.1,
                               help='HelmGPT attention dropout')
    helmgpt_group.add_argument('--helmgpt-ff-dropout', type=float, default=0.0,
                               help='HelmGPT feedforward dropout')
    helmgpt_group.add_argument('--helmgpt-block-size', type=int, default=140,
                               help='HelmGPT block size')
    helmgpt_group.add_argument('--helmgpt-n-layer', type=int, default=8,
                               help='HelmGPT number of layers')
    helmgpt_group.add_argument('--helmgpt-n-heads', type=int, default=8,
                               help='HelmGPT number of attention heads')

    # Trainer parameters
    trainer_group = parser.add_argument_group('trainer')
    trainer_group.add_argument('--accelerator', type=str, default='cuda',
                               help='Accelerator type')
    trainer_group.add_argument('--num-nodes', type=int, default=1,
                               help='Number of nodes')
    trainer_group.add_argument('--devices', type=int, default=1,
                               help='Number of devices')
    trainer_group.add_argument('--gradient-clip-val', type=float, default=1.0,
                               help='Gradient clipping value')
    trainer_group.add_argument('--precision', type=str, default='64-true',
                               help='Training precision')
    trainer_group.add_argument('--num-sanity-val-steps', type=int, default=2,
                               help='Number of sanity validation steps')
    trainer_group.add_argument('--max-epochs', type=int, default=100,
                               help='Maximum number of epochs')
    trainer_group.add_argument('--max-steps', type=int, default=1_000_000,
                               help='Maximum number of steps')
    trainer_group.add_argument('--log-every-n-steps', type=int, default=10,
                               help='Log every n steps')
    trainer_group.add_argument('--limit-train-batches', type=float, default=1.0,
                               help='Limit training batches')
    trainer_group.add_argument('--limit-val-batches', type=float, default=1.0,
                               help='Limit validation batches')
    trainer_group.add_argument('--check-val-every-n-epoch', type=int, default=1,
                               help='Check validation every n epochs')

    # WandB parameters
    wandb_group = parser.add_argument_group('wandb')
    wandb_group.add_argument('--wandb-project', type=str, default='peptune',
                             help='WandB project name')
    wandb_group.add_argument('--wandb-notes', type=str, default=None,
                             help='WandB notes')
    wandb_group.add_argument('--wandb-group', type=str, default=None,
                             help='WandB group')
    wandb_group.add_argument('--wandb-job-type', type=str, default=None,
                             help='WandB job type')
    wandb_group.add_argument('--wandb-name', type=str, default='sophia-tang',
                             help='WandB run name')
    wandb_group.add_argument('--wandb-id', type=str, default=None,
                             help='WandB run ID')

    # Checkpointing parameters
    checkpoint_group = parser.add_argument_group('checkpointing')
    checkpoint_group.add_argument('--save-dir', type=str, default=None,
                                  help='Directory to save checkpoints')
    checkpoint_group.add_argument('--resume-from-ckpt', action=on_off, default=True,
                                  help='Resume from checkpoint')
    checkpoint_group.add_argument('--resume-ckpt-path', type=str, default=None,
                                  help='Path to checkpoint to resume from')
    checkpoint_group.add_argument('--checkpoint-every-n-epochs', type=int, default=1,
                                  help='Save checkpoint every n epochs')
    checkpoint_group.add_argument('--checkpoint-monitor', type=str, default='val/nll',
                                  help='Metric to monitor for checkpointing')
    checkpoint_group.add_argument('--checkpoint-save-top-k', type=int, default=10,
                                  help='Save top k checkpoints')
    checkpoint_group.add_argument('--checkpoint-mode', type=str, default='min',
                                  choices=['min', 'max'],
                                  help='Mode for checkpoint monitoring')
    checkpoint_group.add_argument('--checkpoint-dirpath', type=str,
                                  default='./checkpoints/11M-old-tokenizer',
                                  help='Directory path for checkpoints')

    # LR Scheduler parameters
    scheduler_group = parser.add_argument_group('lr_scheduler')
    scheduler_group.add_argument('--lr-warmup-steps', type=int, default=2500,
                                 help='Number of warmup steps for learning rate')

    return parser
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def get_args():
    """Parse command-line arguments and fill in derived defaults.

    Post-processing applied after parsing:
    - ``eval_global_batch_size`` falls back to ``global_batch_size``
    - ``num_workers`` falls back to the number of usable CPUs
    - ``wandb_id`` is derived from ``wandb_name`` when unset
    - ``save_dir`` defaults to the current working directory

    Returns:
        argparse.Namespace with the resolved configuration.
    """
    parser = get_parser()
    args = parser.parse_args()

    # Post-process arguments
    if args.eval_global_batch_size is None:
        args.eval_global_batch_size = args.global_batch_size

    if args.num_workers is None:
        # Fix: os.sched_getaffinity is Linux-only; fall back to cpu_count()
        # on platforms (macOS, Windows) where it does not exist.
        if hasattr(os, 'sched_getaffinity'):
            args.num_workers = len(os.sched_getaffinity(0))
        else:
            args.num_workers = os.cpu_count() or 1

    if args.wandb_id is None:
        args.wandb_id = f"{args.wandb_name}_nov12_set2"

    if args.save_dir is None:
        args.save_dir = os.getcwd()

    return args
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
# Allow quick inspection of the fully resolved configuration from the CLI.
if __name__ == '__main__':
    args = get_args()
    print(args)
|
src/config.yaml
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
noise:
|
| 2 |
+
type: loglinear
|
| 3 |
+
sigma_min: 1e-4
|
| 4 |
+
sigma_max: 20
|
| 5 |
+
state_dependent: True
|
| 6 |
+
|
| 7 |
+
base_path: /path/to/your/home/PepTune
|
| 8 |
+
mode: train # train / ppl_eval / sample_eval
|
| 9 |
+
diffusion: absorbing_state
|
| 10 |
+
vocab: old_smiles # old_smiles / new_smiles / selfies / helm
|
| 11 |
+
backbone: roformer # peptideclm / helmgpt / dit / roformer / finetune_roformer
|
| 12 |
+
parameterization: subs # subs
|
| 13 |
+
time_conditioning: False
|
| 14 |
+
T: 0 # 0 (continuous time) / 1000
|
| 15 |
+
subs_masking: False
|
| 16 |
+
|
| 17 |
+
seed: 42
|
| 18 |
+
|
| 19 |
+
mcts:
|
| 20 |
+
num_children: 50
|
| 21 |
+
num_objectives: 5
|
| 22 |
+
topk: 100
|
| 23 |
+
mask_token: 4
|
| 24 |
+
num_iter: 128
|
| 25 |
+
sampling: 0 # 0 is gumbel sampling / > 0 samples children from top k probs
|
| 26 |
+
invalid_penalty: 0.5
|
| 27 |
+
sample_prob: 1.0
|
| 28 |
+
perm: True
|
| 29 |
+
dual: False
|
| 30 |
+
single: False
|
| 31 |
+
time_dependent: True
|
| 32 |
+
|
| 33 |
+
lr_scheduler:
|
| 34 |
+
_target_: transformers.get_constant_schedule_with_warmup
|
| 35 |
+
num_warmup_steps: 2500
|
| 36 |
+
|
| 37 |
+
loader:
|
| 38 |
+
global_batch_size: 64
|
| 39 |
+
eval_global_batch_size: ${.global_batch_size}
|
| 40 |
+
# Note: batch_size and eval_batch_size are **per machine**
|
| 41 |
+
batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
|
| 42 |
+
eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
|
| 43 |
+
num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
|
| 44 |
+
pin_memory: True
|
| 45 |
+
|
| 46 |
+
sampling:
|
| 47 |
+
predictor: ddpm_cache # analytic, ddpm, ddpm_cache
|
| 48 |
+
num_sequences: 100
|
| 49 |
+
sampling_eps: 1e-3
|
| 50 |
+
steps: 128
|
| 51 |
+
seq_length: 200
|
| 52 |
+
noise_removal: True
|
| 53 |
+
num_sample_batches: 2 # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
|
| 54 |
+
num_sample_log: 2
|
| 55 |
+
stride_length: 1
|
| 56 |
+
num_strides: 1
|
| 57 |
+
|
| 58 |
+
training:
|
| 59 |
+
antithetic_sampling: True
|
| 60 |
+
sampling_eps: 1e-3
|
| 61 |
+
focus_mask: False
|
| 62 |
+
#dynamic_batching: True
|
| 63 |
+
accumulator: False
|
| 64 |
+
|
| 65 |
+
eval:
|
| 66 |
+
checkpoint_path: None  # NOTE: unquoted None is the YAML *string* "None", not null; the generate scripts override this via eval.checkpoint_path
|
| 67 |
+
disable_ema: False
|
| 68 |
+
compute_generative_perplexity: False
|
| 69 |
+
perplexity_batch_size: 8
|
| 70 |
+
compute_perplexity_on_sanity: False
|
| 71 |
+
gen_ppl_eval_model_name_or_path: gpt2-large # gpt2-large, meta-llama/Llama-2-7b-hf
|
| 72 |
+
generate_samples: True
|
| 73 |
+
generation_model: None
|
| 74 |
+
|
| 75 |
+
optim:
|
| 76 |
+
weight_decay: 0.075
|
| 77 |
+
lr: 3e-4
|
| 78 |
+
beta1: 0.9
|
| 79 |
+
beta2: 0.999
|
| 80 |
+
eps: 1e-8
|
| 81 |
+
|
| 82 |
+
pepclm:
|
| 83 |
+
hidden_size: 768
|
| 84 |
+
cond_dim: 256
|
| 85 |
+
n_heads: 20
|
| 86 |
+
n_blocks: 4
|
| 87 |
+
dropout: 0.5
|
| 88 |
+
length: 512
|
| 89 |
+
#scale_by_sigma: True
|
| 90 |
+
|
| 91 |
+
model:
|
| 92 |
+
type: ddit
|
| 93 |
+
hidden_size: 768
|
| 94 |
+
cond_dim: 128
|
| 95 |
+
length: 512
|
| 96 |
+
n_blocks: 12
|
| 97 |
+
n_heads: 12
|
| 98 |
+
scale_by_sigma: True
|
| 99 |
+
dropout: 0.1
|
| 100 |
+
|
| 101 |
+
roformer:
|
| 102 |
+
hidden_size: 768
|
| 103 |
+
n_layers: 8
|
| 104 |
+
n_heads: 8
|
| 105 |
+
max_position_embeddings: 1035
|
| 106 |
+
|
| 107 |
+
helmgpt:
|
| 108 |
+
hidden_size: 256
|
| 109 |
+
embd_pdrop: 0.1
|
| 110 |
+
resid_pdrop: 0.1
|
| 111 |
+
attn_pdrop: 0.1
|
| 112 |
+
ff_dropout: 0.
|
| 113 |
+
block_size: 140
|
| 114 |
+
n_layer: 8
|
| 115 |
+
n_heads: 8
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
trainer:
|
| 119 |
+
_target_: lightning.Trainer
|
| 120 |
+
accelerator: cuda
|
| 121 |
+
num_nodes: 1
|
| 122 |
+
devices: ${device_count:}
|
| 123 |
+
accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
|
| 124 |
+
gradient_clip_val: 1.0
|
| 125 |
+
precision: 64-true
|
| 126 |
+
num_sanity_val_steps: 2
|
| 127 |
+
max_epochs: 100
|
| 128 |
+
max_steps: 1_000_000
|
| 129 |
+
log_every_n_steps: 10
|
| 130 |
+
limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run
|
| 131 |
+
limit_val_batches: 1.0 # validate on full dataset, can be used to toggle quick run
|
| 132 |
+
#val_check_interval: 40 #954
|
| 133 |
+
check_val_every_n_epoch: 1
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
wandb:
|
| 137 |
+
project: peptune
|
| 138 |
+
notes: null
|
| 139 |
+
group: null
|
| 140 |
+
job_type: null
|
| 141 |
+
name: sophia-tang
|
| 142 |
+
id: ${.name}_nov12_set2
|
| 143 |
+
|
| 144 |
+
hydra:
|
| 145 |
+
run:
|
| 146 |
+
dir: ./${now:%Y.%m.%d}/
|
| 147 |
+
job:
|
| 148 |
+
chdir: True
|
| 149 |
+
|
| 150 |
+
checkpointing:
|
| 151 |
+
# Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
|
| 152 |
+
save_dir: ${cwd:}
|
| 153 |
+
# Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
|
| 154 |
+
resume_from_ckpt: True
|
| 155 |
+
resume_ckpt_path: None  # NOTE: unquoted None is the YAML *string* "None", not null — verify consumers expect the literal string
|
| 156 |
+
|
| 157 |
+
callbacks:
|
| 158 |
+
model_checkpoint:
|
| 159 |
+
_target_: pytorch_lightning.callbacks.ModelCheckpoint
|
| 160 |
+
every_n_epochs: 1
|
| 161 |
+
monitor: "val/nll"
|
| 162 |
+
save_top_k: 10
|
| 163 |
+
mode: "min"
|
| 164 |
+
dirpath: './checkpoints/'
|
src/diffusion.py
ADDED
|
@@ -0,0 +1,1015 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Adapted from MDLM: https://github.com/kuleshov-group/mdlm
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import sys
|
| 5 |
+
import itertools
|
| 6 |
+
import time
|
| 7 |
+
import torch
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
import math
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
import numpy as np
|
| 12 |
+
import random as rd
|
| 13 |
+
import lightning as L
|
| 14 |
+
import torchmetrics
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
import gc
|
| 17 |
+
import pickle
|
| 18 |
+
import utils.utils as utils
|
| 19 |
+
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 20 |
+
import noise_schedule
|
| 21 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
| 22 |
+
import roformer as roformer
|
| 23 |
+
from utils.app import PeptideAnalyzer
|
| 24 |
+
|
| 25 |
+
@dataclass
class Loss:
    """Container for the per-batch diffusion training loss."""
    # scalar token-averaged negative log-likelihood for the batch
    loss: torch.FloatTensor
    # per-token negative log-likelihoods (already multiplied by the attention mask)
    nlls: torch.FloatTensor
    # attention mask marking real (non-padding) token positions
    attn_mask: torch.FloatTensor
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class NLL(torchmetrics.aggregation.MeanMetric):
    """Running mean of per-token negative log-likelihood (nats per token)."""
    pass
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class BPD(NLL):
    def compute(self) -> Tensor:
        """Computes the bits per dimension.

        Returns:
          bpd
        """
        # mean NLL in nats, converted to bits by dividing by ln(2)
        nats_per_token = self.mean_value / self.weight
        return nats_per_token / math.log(2)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class Perplexity(NLL):
    def compute(self) -> Tensor:
        """Computes the Perplexity.

        Returns:
          Perplexity
        """
        # perplexity = exp(mean NLL per token)
        mean_nll = self.mean_value / self.weight
        return torch.exp(mean_nll)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class Diffusion(L.LightningModule):
|
| 57 |
+
def __init__(self, config, tokenizer):
|
| 58 |
+
|
| 59 |
+
super().__init__()
|
| 60 |
+
self.config = config
|
| 61 |
+
#self.save_hyperparameters()
|
| 62 |
+
|
| 63 |
+
# PeptideCLM tokenizer
|
| 64 |
+
self.tokenizer = tokenizer
|
| 65 |
+
self.vocab_size = self.tokenizer.vocab_size
|
| 66 |
+
self.mask_token_id = self.tokenizer.mask_token_id
|
| 67 |
+
self.sampler = self.config.sampling.predictor
|
| 68 |
+
self.analyzer = PeptideAnalyzer()
|
| 69 |
+
|
| 70 |
+
# backbone LM PeptideCLM model
|
| 71 |
+
if self.config.backbone == 'roformer':
|
| 72 |
+
self.backbone = roformer.Roformer(self.config, self.tokenizer)
|
| 73 |
+
self.backbone.unfreeze_all_layers()
|
| 74 |
+
elif self.config.backbone == 'finetune_roformer':
|
| 75 |
+
self.backbone = roformer.Roformer(self.config, self.tokenizer)
|
| 76 |
+
self.backbone.freeze_model()
|
| 77 |
+
self.backbone.unfreeze_n_layers(n=8)
|
| 78 |
+
else:
|
| 79 |
+
Exception('invalid backbone config')
|
| 80 |
+
|
| 81 |
+
self.neg_infinity = -1000000.0
|
| 82 |
+
self.T = config.T
|
| 83 |
+
# noise schedule for non-peptide bond tokens (default to log-linear)
|
| 84 |
+
self.noise = noise_schedule.get_noise(config)
|
| 85 |
+
# noise schedule for peptide bonds (log-polynomial)
|
| 86 |
+
self.bond_noise = noise_schedule.LogPolyNoise()
|
| 87 |
+
self.time_conditioning = self.config.time_conditioning
|
| 88 |
+
self.fast_forward_epochs = None
|
| 89 |
+
self.fast_forward_batches = None
|
| 90 |
+
|
| 91 |
+
self.gen_ppl_eval_model_name_or_path = self.config.eval.gen_ppl_eval_model_name_or_path
|
| 92 |
+
self.gen_ppl_metric = Perplexity()
|
| 93 |
+
|
| 94 |
+
self.lr = self.config.optim.lr
|
| 95 |
+
self.sampling_eps = self.config.training.sampling_eps
|
| 96 |
+
|
| 97 |
+
metrics = torchmetrics.MetricCollection({
|
| 98 |
+
'nll': NLL(),
|
| 99 |
+
'bpd': BPD(),
|
| 100 |
+
'ppl': Perplexity(),
|
| 101 |
+
})
|
| 102 |
+
metrics.set_dtype(torch.float64)
|
| 103 |
+
self.train_metrics = metrics.clone(prefix='trainer/')
|
| 104 |
+
self.valid_metrics = metrics.clone(prefix='val/')
|
| 105 |
+
self.test_metrics = metrics.clone(prefix='test/')
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
"""LOSS FOR INVALID PEPTIDES"""
|
| 109 |
+
|
| 110 |
+
    @torch.no_grad()
    def conditional_gumbel(self, logits, D, k):
        """
        Outputs k samples of Q = StandardGumbel(), such that argmax(logits
        + Q) is given by D (one-hot vector).

        Input:
        - logits: Tensor of shape (batch_size, seq_len, vocab_size)
        - D: One-hot tensor of shape (batch_size, seq_len, vocab_size)
        - k: Number of Gumbel samples

        Output:
        - Adjusted logits with shape (k, batch_size, seq_len, vocab_size)
        """
        # iid. exponential samples of shape (k, batch_size, seq_len, vocab_size)
        E = torch.distributions.exponential.Exponential(rate=torch.ones_like(logits)).sample([k])

        # E of the chosen class, shape (k, batch_size, seq_len, 1)
        Ei = (D * E).sum(dim=-1, keepdim=True)

        # Partition function (normalization constant), shape (batch_size, seq_len, 1)
        Z = logits.exp().sum(dim=-1, keepdim=True)

        # Gumbel values consistent with the observed argmax D:
        # the chosen class gets -log(Ei) + log(Z); all others get
        # -log(E / exp(logits) + Ei / Z) (top-down Gumbel construction).
        adjusted = (
            D * (-torch.log(Ei) + torch.log(Z)) +
            (1 - D) * -torch.log(E / logits.exp() + Ei / Z)
        )

        # Return the Gumbel noise itself, shape (k, batch_size, seq_len, vocab_size)
        return adjusted - logits
|
| 142 |
+
|
| 143 |
+
def replace_gradient(self, value, surrogate):
|
| 144 |
+
"""
|
| 145 |
+
Returns `value` but backpropagates gradients through `surrogate`.
|
| 146 |
+
"""
|
| 147 |
+
return surrogate + (value - surrogate).detach()
|
| 148 |
+
|
| 149 |
+
    def gumbel_rao(self, logits, k, temp=1.0, I=None):
        """
        Returns a categorical sample from logits (over axis=-1) as a
        one-hot vector, with gumbel-rao gradient.

        Input:
        - logits: Tensor of shape (batch_size, seq_len, vocab_size)
        - k: Number of Gumbel samples for Rao-Blackwellization
        - temp: Temperature for softmax
        - I: Optional, precomputed categorical sample tensor of shape (batch_size, seq_len)

        Output:
        - One-hot tensor of shape (batch_size, seq_len, vocab_size)
          with Gumbel-Rao gradient.
        """
        assert logits.shape[-1] == self.tokenizer.vocab_size
        vocab_size = logits.shape[-1]

        if I is None:
            # Sample indices for each token in the batch
            I = torch.distributions.categorical.Categorical(logits=logits).sample()  # (batch_size, seq_len)

        # Convert indices to one-hot encodings, shape (batch_size, seq_len, vocab_size)
        D = torch.nn.functional.one_hot(I, num_classes=vocab_size).float()

        # Generate k different adjusted logits that all argmax to the same sequence
        adjusted = logits + self.conditional_gumbel(logits, D, k=k)  # (k, batch_size, seq_len, vocab_size)

        # Rao-Blackwellized surrogate: average the softmax across the k samples
        surrogate = torch.nn.functional.softmax(adjusted / temp, dim=-1).mean(dim=0)  # (batch_size, seq_len, vocab_size)

        # Forward pass returns the hard one-hot D; gradients flow via surrogate
        return self.replace_gradient(D, surrogate)
|
| 182 |
+
|
| 183 |
+
    def compute_invalid_loss(self, logits, k=None, temp=None):
        """
        Penalizes logits that produce invalid sequences using the `is_peptide` function,
        scaling penalties inversely with token probabilities.

        NOTE: `k` and `temp` are currently unused — they belong to the
        commented-out Gumbel-Rao path below. The argmax decoding used instead
        is NOT differentiable; the penalty reaches the graph only through
        `sampled_probs`.

        Args:
            logits: Tensor of shape [batch_size, seq_len, vocab_size].
            k: Number of samples for Gumbel-Rao (unused).
            temp: Temperature for softmax (unused).

        Returns:
            Tensor of shape (batch_size, seq_len): per-token penalty, zero for
            rows whose decoded sequence passes `is_peptide`.
        """
        #samples = self.gumbel_rao(logits, k=k, temp=temp)  # (batch_size, seq_len, vocab_size)

        # Greedily decode logits to token ids, then to strings
        batch_token_ids = logits.argmax(dim=-1).to(self.device)  # (batch_size, seq_len)
        sampled_sequences = self.tokenizer.batch_decode(batch_token_ids)

        # Check validity of each sampled sequence (not differentiable):
        # 1 marks an invalid (non-peptide) sequence
        penalties = torch.tensor(
            [1 if not self.analyzer.is_peptide(seq) else 0 for seq in sampled_sequences],
            dtype=torch.float32,
            device=self.device
        )

        # Probability assigned to each greedily-chosen token (batch_size, seq_len)
        sampled_probs = torch.softmax(logits, dim=-1).gather(dim=-1, index=batch_token_ids.unsqueeze(-1)).squeeze(-1).to(self.device)

        # Scale the row penalty by the model's confidence in each sampled token
        scaled_penalty = penalties[:, None] * sampled_probs  # (batch_size, seq_len)

        return scaled_penalty.to(self.device)
|
| 218 |
+
|
| 219 |
+
"""DIFFUSION LOSS"""
|
| 220 |
+
|
| 221 |
+
def sample_t(self, n, device):
|
| 222 |
+
"""
|
| 223 |
+
Sample random time steps for batch training
|
| 224 |
+
"""
|
| 225 |
+
# sample values uniformly at random from [0, 1)
|
| 226 |
+
eps_t = torch.rand(n, device=device)
|
| 227 |
+
# antithetic sampling: reduce variance by pairing each sample with complementary sample
|
| 228 |
+
if self.config.training.antithetic_sampling:
|
| 229 |
+
# compute interval between sampled time steps
|
| 230 |
+
offset = torch.arange(n, device=device) / n
|
| 231 |
+
# ensure that each eps value is evenly spaced between [0, 1)
|
| 232 |
+
eps_t = ((eps_t / n) + offset) % 1
|
| 233 |
+
|
| 234 |
+
# ensures values are not exactly 0 or 1
|
| 235 |
+
t = (1 - self.config.training.sampling_eps) * eps_t + self.config.training.sampling_eps
|
| 236 |
+
|
| 237 |
+
return t
|
| 238 |
+
|
| 239 |
+
    def q_xt(self, x, mask_prob):
        """Computes the noisy sample xt.

        At most 75% of each row's non-zero (non-padding) tokens are masked;
        when the Bernoulli draw selects more positions than that cap, only the
        first `max_mask_length` selected positions (in index order) are kept.

        Args:
          x: int torch.Tensor with shape (batch_size,
            diffusion_model_input_length), input.
          mask_prob: float torch.Tensor with shape (batch_size, 1) (or
            broadcastable to x) — per-position masking probability.

        Returns:
          xt: x with selected positions replaced by the mask token id.
        """
        # count non-padding tokens per row (assumes pad id is 0 — TODO confirm)
        actual_seq_length = (x != 0).sum(dim=-1, keepdim=True)

        # cap: mask at most 75% of the real tokens in each row
        max_mask_length = (actual_seq_length * 0.75).long()

        # independent Bernoulli(mask_prob) draw per position
        mask_indices = torch.rand(*x.shape, device=x.device) < mask_prob

        restricted_move_indices = torch.zeros_like(mask_indices, dtype=torch.bool)

        # enforce the per-row cap
        for i in range(x.shape[0]):
            true_positions = torch.where(mask_indices[i])[0]
            if len(true_positions) > max_mask_length[i]:
                # keep only the first max_mask_length selected positions
                selected_positions = true_positions[:max_mask_length[i].item()]
                restricted_move_indices[i, selected_positions] = True
            else:
                restricted_move_indices[i] = mask_indices[i]

        xt = torch.where(restricted_move_indices, self.tokenizer.mask_token_id, x)

        return xt
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def sample_prior(self, *batch_dims):
|
| 271 |
+
"""
|
| 272 |
+
Returns array of fully masked sequences with same shape as input
|
| 273 |
+
"""
|
| 274 |
+
return self.mask_token_id * torch.ones(* batch_dims, dtype=torch.int64)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
"""COMPUTING LOSS"""
|
| 278 |
+
|
| 279 |
+
    def compute_diffusion_loss(self, model_output, xt, x0, t):
        """
        Computes diffusion loss term in ELBO
        (evaluates how accurately the model predicts the token probabilities at each time step)

        Inputs:
        - model_output: (batch, seq_len, vocab_size) log-probabilities per position
        - xt: corrupted version of original input x0 at timestep t
        - x0: original input sequence, (batch, seq_len)
        - t: timestep
        """
        # interval between discrete timesteps
        dt = 1 / self.T

        # vectorized alpha (survival) terms at timestep t and s = t - dt
        alpha_t = 1 - t + torch.zeros_like(x0)
        alpha_s = 1 - (t - dt) + torch.zeros_like(x0)

        # log-probability the model assigns to the true token x0: log<x_theta, x>
        log_x_theta_at_x0 = torch.gather(model_output, -1, x0[:, :, None])
        # log-probability of the mask token at each position: log<x_theta, m>
        log_x_theta_at_m = model_output[:, :, self.mask_token_id]
        # non-log probability of the mask token: <x_theta, m>
        x_theta_at_m = log_x_theta_at_m.exp()

        # first term of diffusion loss
        term_1_coef = dt / t
        term_1_log_numerator = torch.log((alpha_t * x_theta_at_m) / t + 1)
        term_1_log_denom = log_x_theta_at_x0

        # second term of diffusion loss
        term_2_coef = 1 - (dt / t)
        term_2_log_numerator = term_1_log_numerator
        term_2_log_denom = torch.log((alpha_s * x_theta_at_m) / (t - dt) + 1)

        L_vb_masked = (term_1_coef * (term_1_log_numerator - term_1_log_denom) +
                       term_2_coef * (term_2_log_numerator - term_2_log_denom))

        # only positions that are masked in xt contribute (<zt, m> indicator)
        L_vb = L_vb_masked * (xt == self.mask_token_id)

        # scale by total number of timesteps T
        return self.T * L_vb
|
| 326 |
+
|
| 327 |
+
    def _forward_pass_diffusion(self, x0, attn_mask, bond_mask=None, mask=None):
        """
        Training reverse diffusion model x_theta to reconstruct samples x0.

        Args:
            x0: clean token ids, (batch, seq_len)
            attn_mask: attention mask, (batch, seq_len)
            bond_mask: (batch, seq_len) binary mask of peptide-bond tokens;
                when present (and state-dependent noise is enabled) those
                positions use the log-polynomial bond schedule
            mask: optional explicit corruption mask; when given, positions with
                mask != 1 are replaced by the mask token instead of sampling q_xt

        Returns:
            Per-token loss, (batch, seq_len): the discrete-T diffusion loss if
            self.T > 0, otherwise the continuous-time NLL term plus the
            invalid-peptide penalty.
        """
        # sample a random timestep per batch element
        t = self.sample_t(x0.shape[0], self.device)

        if self.T > 0:
            # discretize t onto the grid {1/T, 2/T, ..., 1}
            t = (t * self.T).to(torch.int)
            t = t / self.T
            # shift by 1/T so t is never exactly 0
            t += (1 / self.T)

        # noise level and its rate at t (log-linear: sigma = -log(1-t))
        sigma, dsigma = self.noise(t)
        time_conditioning = sigma[:, None]

        # base masking probability per position: 1 - exp(-sigma) (= t for log-linear)
        base_mask_prob = 1 - torch.exp(-sigma[:, None])  # (batch, seq_len)

        if self.config.noise.state_dependent and (bond_mask is not None):
            # log-polynomial schedule for peptide bonds: alpha = 1 - t^w
            bond_sigma, bond_dsigma = self.bond_noise(t)
            # expand everything for broadcasting to (batch, seq_len)
            bond_sigma = bond_sigma[:, None]
            bond_dsigma = bond_dsigma[:, None]
            sigma = sigma[:, None]
            dsigma = dsigma[:, None]

            # masking probability for peptide bonds: 1 - bond_alpha = t^w
            bond_mask_prob = 1 - torch.exp(-bond_sigma).to(self.device)
            # splice bond-specific values in at peptide-bond positions
            mask_prob = torch.where(bond_mask == 1, bond_mask_prob, base_mask_prob).to(self.device)
            dsigma = torch.where(bond_mask == 1, bond_dsigma, dsigma).to(self.device)
            sigma = torch.where(bond_mask == 1, bond_sigma, sigma).to(self.device)
        else:
            mask_prob = base_mask_prob.to(self.device)

        # corrupt x0: sample a masking pattern, or apply the explicit one
        if mask is None:
            zt = self.q_xt(x0, mask_prob).to(self.device)
        else:
            zt = x0.where(mask==1, torch.full_like(x0, self.mask_token_id)).to(self.device)

        model_output = self.forward(zt, attn_mask=attn_mask.to(self.device), sigma=time_conditioning).to(self.device)

        # debugging
        assert not torch.isnan(model_output).any()
        assert model_output.is_cuda
        utils.print_nans(model_output, 'model_output')

        # penalty for decoding to non-peptide sequences, (batch, seq_len)
        invalid_loss = self.compute_invalid_loss(logits=model_output).to(self.device)

        if self.T > 0:
            # discrete-time ELBO term (invalid_loss is not added on this path)
            diffusion_loss = self.compute_diffusion_loss(model_output, zt, x0, t)
            return diffusion_loss

        # continuous-time limit: -log p_theta(x0) weighted by dsigma/expm1(sigma)
        log_p_theta = torch.gather(input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1).to(self.device)  # (batch, seq_len)

        if self.config.noise.state_dependent and (bond_mask is not None):
            # sigma/dsigma were already expanded to (batch, seq_len) above
            return (-log_p_theta * (dsigma / torch.expm1(sigma)) + invalid_loss).to(self.device)
        else:
            # here sigma/dsigma are still (batch,), so broadcast with [:, None]
            return ((-log_p_theta * (dsigma / torch.expm1(sigma))[:, None]) + invalid_loss).to(self.device)
|
| 406 |
+
|
| 407 |
+
def _loss(self, x0, attn_mask, bond_mask=None, mask=None):
|
| 408 |
+
loss = self._forward_pass_diffusion(x0, attn_mask, bond_mask, mask)
|
| 409 |
+
|
| 410 |
+
# negative log loss
|
| 411 |
+
nlls = loss * attn_mask
|
| 412 |
+
|
| 413 |
+
# count number of tokens
|
| 414 |
+
num_tokens = attn_mask.sum()
|
| 415 |
+
|
| 416 |
+
# compute batch loss
|
| 417 |
+
batch_nll = nlls.sum()
|
| 418 |
+
# compute per token loss
|
| 419 |
+
token_nll = batch_nll / num_tokens
|
| 420 |
+
# return losses
|
| 421 |
+
return Loss(loss = token_nll.to(self.device), nlls = nlls.to(self.device), attn_mask = attn_mask.to(self.device))
|
| 422 |
+
|
| 423 |
+
    def _compute_loss(self, batch, prefix, bond_mask=None):
        """
        Compute the loss for one batch and update the split's metrics.

        NOTE(review): the `bond_mask` parameter is always overwritten below —
        the bond mask is taken from `batch['bond_mask']` (or set to None), so
        any value passed by the caller is ignored.

        Args:
            batch: dict with 'input_ids', 'attention_mask', and optionally
                'mask' and 'bond_mask'.
            prefix: one of 'train' / 'val' / 'test' — selects the metric set.

        Returns:
            Scalar token-averaged loss.

        Raises:
            ValueError: for any other prefix.
        """
        attn_mask = batch['attention_mask'].to(self.device)

        if 'mask' in batch:
            mask = batch['mask'].to(self.device)
        else:
            mask = None

        # overrides the bond_mask argument (see NOTE above)
        if 'bond_mask' in batch:
            bond_mask = batch['bond_mask'].to(self.device)
        else:
            bond_mask = None

        losses = self._loss(batch['input_ids'].to(self.device), attn_mask, bond_mask, mask)
        loss = losses.loss

        if prefix == 'train':
            self.train_metrics.update(
                losses.nlls.to(self.device),
                losses.attn_mask.to(self.device)
            )
            metrics = self.train_metrics
        elif prefix == 'val':
            self.valid_metrics.update(
                losses.nlls.to(self.device),
                losses.attn_mask.to(self.device)
            )
            metrics = self.valid_metrics
        elif prefix == 'test':
            self.test_metrics.update(losses.nlls, losses.attn_mask)
            metrics = self.test_metrics
        else:
            raise ValueError(f'Invalid prefix: {prefix}')

        self.log_dict(metrics,
                      on_step=False,
                      on_epoch=True,
                      sync_dist=True)

        return loss
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
"""SAMPLING"""
|
| 467 |
+
|
| 468 |
+
def generate_from_masked(self, num_samples=None, seq_length=None, sample_steps=128, eps=1e-5):
|
| 469 |
+
# get number of timesteps
|
| 470 |
+
if sample_steps is None:
|
| 471 |
+
sample_steps = self.config.sampling.steps
|
| 472 |
+
|
| 473 |
+
if seq_length is None:
|
| 474 |
+
seq_length = self.config.sampling.seq_length
|
| 475 |
+
|
| 476 |
+
# sample fully masked sequences
|
| 477 |
+
z = self.sample_prior(num_samples, seq_length).to(self.device)
|
| 478 |
+
|
| 479 |
+
# create vector of sample_steps timesteps
|
| 480 |
+
timesteps = torch.linspace(1, eps, sample_steps + 1, device=self.device)
|
| 481 |
+
|
| 482 |
+
# compute interval between timesteps
|
| 483 |
+
dt = (1 - eps) / sample_steps
|
| 484 |
+
|
| 485 |
+
for i in range(sample_steps):
|
| 486 |
+
t = timesteps[i] * torch.ones(z.shape[0], 1, device=self.device)
|
| 487 |
+
|
| 488 |
+
z = self.single_reverse_step(z, t, dt)
|
| 489 |
+
|
| 490 |
+
return z
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
"""SAMPLING STEP"""
|
| 494 |
+
|
| 495 |
+
    def single_reverse_step(self, zt, t, dt, attn_mask=None):
        """
        Take a single reverse diffusion step (used e.g. for MCTS expansion):
        unmask some of zt's masked positions by sampling from the model's
        posterior at time s = t - dt. Already-unmasked tokens are carried over.
        """
        # noise levels at t and s = t - dt (determine masking probabilities)
        sigma_t, _ = self.noise(t)
        sigma_s, _ = self.noise(t - dt)

        # flatten to 1-D per-batch sigmas
        if sigma_t.ndim > 1:
            sigma_t = sigma_t.squeeze(-1)
        if sigma_s.ndim > 1:
            sigma_s = sigma_s.squeeze(-1)
        assert sigma_t.ndim == 1, sigma_t.shape
        assert sigma_s.ndim == 1, sigma_s.shape

        # masking probabilities at each timestep: 1 - exp(-sigma)
        change_prob_t = 1 - torch.exp(-sigma_t)
        change_prob_s = 1 - torch.exp(-sigma_s)

        # expand for broadcasting against (batch, seq_len, vocab)
        change_prob_t = change_prob_t[:, None, None]
        change_prob_s = change_prob_s[:, None, None]

        # model's log-probabilities over x0
        log_p_x0 = self.forward(zt, attn_mask=attn_mask, sigma=sigma_t)

        assert change_prob_t.ndim == log_p_x0.ndim

        # probability of unmasking to each token at timestep s:
        # (prob_t - prob_s) * x_theta
        q_zs = log_p_x0.exp() * (change_prob_t - change_prob_s)

        # probability of remaining masked at timestep s
        q_zs[:, :, self.mask_token_id] = change_prob_s[:, :, 0]

        # sample the sequence at timestep s
        z_changed = sample_categorical(q_zs)

        # carry-over: positions already unmasked in zt stay fixed
        copy_flag = (zt != self.mask_token_id).to(zt.dtype)
        return (copy_flag * zt) + ((1 - copy_flag) * z_changed)
|
| 538 |
+
|
| 539 |
+
    def cached_reverse_step(self, x, t, dt, p_x0=None, attn_mask=None):
        """
        Reverse step specialized to the log-linear schedule, reusing a cached
        model output `p_x0` when provided (ddpm_cache sampler).

        Returns (p_x0, x_next) so the caller can cache the probabilities for
        the next step when x did not change.
        """
        # the change_prob = t shortcut below is only valid for log-linear noise
        assert self.config.noise.type == 'loglinear'
        sigma_t, _ = self.noise(t)

        if t.ndim > 1:
            t = t.squeeze(-1)
        assert t.ndim == 1

        # for log-linear noise, masking prob at time t is just t
        change_prob_t = t[:, None, None]
        change_prob_s = (t - dt)[:, None, None]

        assert change_prob_t.ndim == 3, change_prob_t.shape

        # recompute model probabilities only when the cache is empty
        if p_x0 is None:
            p_x0 = self.forward(x, attn_mask=attn_mask, sigma=sigma_t).exp()

        assert change_prob_t.ndim == p_x0.ndim

        # unmasking probabilities at timestep s
        q_xs = p_x0 * (change_prob_t - change_prob_s)

        # probability of staying masked
        q_xs[:, :, self.mask_token_id] = change_prob_s[:, :, 0]

        x_changed = sample_categorical(q_xs)

        # carry over tokens that are already unmasked
        copy_flag = (x != self.mask_token_id).to(x.dtype)

        return p_x0, copy_flag * x + (1 - copy_flag) * x_changed
|
| 566 |
+
|
| 567 |
+
# first step in expansion
|
| 568 |
+
    # first step in expansion
    def batch_cached_reverse_step(self, token_array, t, dt, batch_size, p_x0=None, attn_mask=None):
        """
        Expand a single parent sequence into `batch_size` children by taking
        one cached reverse step and sampling `batch_size` distinct unmaskings
        (categorical or top-k, per config.mcts.sampling).

        NOTE(review): the `attn_mask` parameter is unconditionally overwritten
        with an all-ones mask below, so caller-supplied masks are ignored.

        Returns (p_x0, children) where children has shape (batch_size, seq_len).
        """
        # change_prob = t shortcut requires the log-linear schedule
        assert self.config.noise.type == 'loglinear'
        sigma_t, _ = self.noise(t)

        if t.ndim > 1:
            t = t.squeeze(-1)
        assert t.ndim == 1

        # log-linear: masking probability at time t is t itself
        change_prob_t = t[:, None, None]
        change_prob_s = (t - dt)[:, None, None]

        assert change_prob_t.ndim == 3, change_prob_t.shape

        # promote a single sequence to a batch of one
        if token_array.dim() == 1:
            token_array = token_array.unsqueeze(0)
            #token_array = token_array.repeat(batch_size, 1)

        # overwrites any caller-supplied attn_mask (see NOTE above)
        attn_mask = torch.ones_like(token_array)

        if p_x0 is None:
            p_x0 = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t).exp()

        assert change_prob_t.ndim == p_x0.ndim

        q_xs = p_x0 * (change_prob_t - change_prob_s)

        # zero-masking probability
        q_xs[:, :, self.mask_token_id] = change_prob_s[:, :, 0]

        # repeat the parent token along the first dimension which will be unmasked into distinct sequences
        token_array = token_array.repeat(batch_size, 1)

        # sampling == 0 -> plain categorical; otherwise top-k with k = sampling
        if self.config.mcts.sampling == 0:
            x_changed = sample_batched_categorical(q_xs.to(self.device), batch_size)
        else:
            x_changed = sample_batched_top_k(q_xs.to(self.device), batch_size, self.config.mcts.sampling)

        # carry over tokens already unmasked in the parent
        copy_flag = (token_array != self.mask_token_id).to(token_array.dtype)

        return p_x0, copy_flag * token_array + (1 - copy_flag) * x_changed
|
| 609 |
+
|
| 610 |
+
def _process_sigma(self, sigma):
|
| 611 |
+
if sigma.ndim > 1:
|
| 612 |
+
sigma = sigma.squeeze(-1)
|
| 613 |
+
if not self.time_conditioning:
|
| 614 |
+
sigma = torch.zeros_like(sigma)
|
| 615 |
+
assert sigma.ndim == 1, sigma.shape
|
| 616 |
+
return sigma
|
| 617 |
+
|
| 618 |
+
def forward(self, zt, attn_mask, sigma):
|
| 619 |
+
"""
|
| 620 |
+
Predicts the token log-probabilities from zt at time t with noise schedule sigma
|
| 621 |
+
"""
|
| 622 |
+
sigma = self._process_sigma(sigma)
|
| 623 |
+
|
| 624 |
+
with torch.amp.autocast("cuda", enabled=True, dtype=torch.float32, cache_enabled=True):
|
| 625 |
+
logits = self.backbone(zt, attn_mask).to(self.device)
|
| 626 |
+
|
| 627 |
+
return self.subs_parameterization(logits, zt)
|
| 628 |
+
|
| 629 |
+
    def subs_parameterization(self, logits, zt):
        """
        Updates reverse diffusion logits based on SUBS parameterization:
        - zero masking probabilities: -infinity probability of being masked during reverse diffusion
        - carry-over unmasking: unmasked input tokens remain unchanged during reverse diffusion

        Args:
            logits: (batch, seq_len, vocab) token logits for unmasking
            zt: partially unmasked sequence at current timestep, (batch, seq_len)

        Returns:
            Log-normalized logits with the SUBS constraints applied.
        """
        # forbid predicting the mask token itself
        logits[:, :, self.mask_token_id] += self.neg_infinity  # [sequence index, current token, next token]

        # log-softmax normalization over the vocabulary
        logits = (logits - torch.logsumexp(logits, dim=-1, keepdim=True)).to(self.device)

        # locate already-unmasked positions; their tokens must be carried over
        unmasked_indices = (zt != self.mask_token_id).to(self.device)  # (batch, seq_len)
        batch_idx, seq_idx = torch.where(unmasked_indices)
        batch_idx = batch_idx.to(self.device)
        seq_idx = seq_idx.to(self.device)
        tokens = zt[batch_idx, seq_idx].to(self.device)  # tokens at those positions

        assert logits.is_contiguous(), "logits tensor is not contiguous"
        assert unmasked_indices.shape == zt.shape, "same shape"
        assert not torch.isnan(logits).any(), "NaN values found in logits"
        assert tokens.max() < logits.shape[-1], "token indices out of bounds"
        assert batch_idx.max() < logits.shape[0], "batch index out of bounds"
        assert seq_idx.max() < logits.shape[1], "seq index out of bounds"
        assert batch_idx.device == seq_idx.device == logits.device == tokens.device, "device inconsistent"

        # carry-over unmasking: log-prob 0 (prob 1) for the existing token,
        # -inf for everything else — order of these two writes matters
        logits[batch_idx, seq_idx] = self.neg_infinity  # Set everything to -inf first
        logits[batch_idx, seq_idx, tokens] = 0  # Set only the specific token positions to 0

        # return logits with SUBS parameterization
        return logits.to(self.device)
|
| 663 |
+
|
| 664 |
+
"""SAMPLING"""
|
| 665 |
+
    @torch.no_grad()
    def _sample(self, num_steps=None, eps=1e-5, x_input=None):
        """
        Generate samples by running the reverse diffusion for num_steps steps.

        Args:
            num_steps: number of reverse steps (defaults to config.sampling.steps).
            eps: smallest timestep; the time grid runs from 1 down to eps.
            x_input: optional dict with 'input_ids'/'attention_mask' to start
                from a partially masked sequence instead of the all-mask prior.

        Returns:
            Sampled token ids, (batch, seq_len).
        """
        batch_size_per_gpu = self.config.eval.perplexity_batch_size

        if num_steps is None:
            num_steps = self.config.sampling.steps

        # start from the given partial sequence, or from the all-mask prior
        if x_input is not None:
            x = x_input['input_ids'].to(self.device)
            attn_mask = x_input['attention_mask'].to(self.device)
        else:
            x = self.sample_prior(batch_size_per_gpu, self.config.model.length).to(self.device)
            attn_mask = torch.ones_like(x).to(self.device)

        timesteps = torch.linspace(1, eps, num_steps+1, device=self.device)
        dt = (1 - eps) / num_steps
        p_x0_cache = None
        generation_history = []  # used to track which tokens are unmasked (currently never appended to)

        for i in range(num_steps):
            t = timesteps[i] * torch.ones(x.shape[0], 1, device=self.device)
            if self.sampler == 'ddpm':
                x = self.single_reverse_step(x, t, dt).to(self.device)
            elif self.sampler == 'ddpm_cache':
                p_x0_cache, x_next = self.cached_reverse_step(x, t, dt, p_x0=p_x0_cache, attn_mask=attn_mask)
                # invalidate the cache once tokens changed (or when time-conditioned)
                if (not torch.allclose(x_next, x) or self.time_conditioning):
                    p_x0_cache = None
                x = x_next.to(self.device)
            else:
                # analytic (score-based) update
                x = self._analytic_update(x, t, dt, attn_mask).to(self.device)

        # final denoising pass: force any remaining masks to concrete tokens
        if self.config.sampling.noise_removal:
            t = timesteps[-1] * torch.ones(x.shape[0], 1, device=self.device)
            if self.sampler == 'analytic':
                x = self._denoiser_update(x, t).to(self.device)
            else:
                time_conditioning = self.noise(t)[0].to(self.device)
                x = self.forward(x, attn_mask=attn_mask, sigma=time_conditioning).argmax(dim=-1).to(self.device)

        return x.to(self.device)
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
def restore_model_and_sample(self, num_steps, eps=1e-5):
|
| 715 |
+
"""Generate samples from the model."""
|
| 716 |
+
self.backbone.eval()
|
| 717 |
+
self.noise.eval()
|
| 718 |
+
samples = self._sample(num_steps=num_steps, eps=eps)
|
| 719 |
+
self.backbone.train()
|
| 720 |
+
self.noise.train()
|
| 721 |
+
return samples
|
| 722 |
+
|
| 723 |
+
    def get_score(self, zt, sigma, attn_mask=None):
        """
        Compute the (exponentiated) score ratios p_t(y)/p_t(x) per position.

        Case analysis (log scores), with k = exp(-sigma) / (1 - exp(-sigma)):
        - x masked:
            y unmasked: log score = log p_theta(x)|_y + log k
            y masked:   log score = 0
        - x unmasked:
            y != x, y unmasked: log score = -inf
            y == x:             log score = 0
            y masked:           log score = -log k
        """
        model_output = self.forward(zt, attn_mask=attn_mask, sigma=sigma)

        # log k = -log(exp(sigma) - 1)
        log_k = -torch.log(torch.expm1(sigma)).squeeze(-1)
        assert log_k.ndim == 1

        # scores for positions that are currently masked
        masked_score = model_output + log_k[:, None, None]
        masked_score[:, :, self.mask_token_id] = 0

        # scores for positions that are already unmasked:
        # -inf everywhere except the current token (0) and the mask token (-log k)
        unmasked_score = self.neg_infinity * torch.ones_like(model_output)
        unmasked_score = torch.scatter(
            unmasked_score, -1,
            zt[..., None],
            torch.zeros_like(unmasked_score[..., :1]))
        unmasked_score[:, :, self.mask_token_id] = - (log_k[:, None] * torch.ones_like(zt))

        # select the right branch per position
        masked_indices = (zt == self.mask_token_id).to(model_output.dtype)[:, :, None]
        model_output = (masked_score * masked_indices + unmasked_score * (1 - masked_indices))

        # return in probability (non-log) space
        return model_output.exp()
|
| 765 |
+
|
| 766 |
+
def _staggered_score(self, score, dsigma):
|
| 767 |
+
score = score.clone()
|
| 768 |
+
extra_const = (1 - dsigma.exp()) * score.sum(dim=-1)
|
| 769 |
+
score *= dsigma.exp()[:, None]
|
| 770 |
+
score[..., self.mask_token_id] += extra_const
|
| 771 |
+
return score
|
| 772 |
+
|
| 773 |
+
def _analytic_update(self, x, t, step_size, attn_mask=None):
    """One analytic reverse-diffusion step from time t to t - step_size.

    Args:
        x: (batch, seq) current token ids.
        t: current diffusion time.
        step_size: how far to step backwards in time.
        attn_mask: optional attention mask for the score network.

    Returns:
        (batch, seq) token ids sampled from the one-step posterior.
    """
    curr_sigma, _ = self.noise(t)
    next_sigma, _ = self.noise(t - step_size)
    dsigma = curr_sigma - next_sigma
    # BUG FIX: get_score's signature is (zt, sigma, attn_mask=None); the old
    # call `self.get_score(x, attn_mask, curr_sigma)` passed the attention
    # mask where sigma was expected (and vice versa).
    score = self.get_score(x, curr_sigma, attn_mask=attn_mask)
    stag_score = self._staggered_score(score, dsigma)
    probs = stag_score * self._transp_transition(x, dsigma)
    return sample_categorical(probs)
|
| 781 |
+
|
| 782 |
+
def _denoiser_update(self, x, t):
    """Terminal denoising step: step over the full remaining noise sigma(t).

    Unlike `_analytic_update`, which uses the sigma increment between two
    timesteps, the whole cumulative sigma is used here, and re-masking is
    forbidden by zeroing the mask column before sampling.
    """
    sigma, _ = self.noise(t)
    score = self.get_score(x, sigma)
    stag_score = self._staggered_score(score, sigma)
    probs = stag_score * self._transp_transition(x, sigma)
    # Never sample the mask token in the final step.
    probs[..., self.mask_token_id] = 0
    samples = sample_categorical(probs)
    return samples
|
| 790 |
+
|
| 791 |
+
def _transp_transition(self, i, sigma):
    """Transposed transition kernel for the absorbing (mask) diffusion.

    Each token keeps probability exp(-sigma) on itself; rows that currently
    hold the mask token additionally receive the absorbed mass 1 - exp(-sigma)
    across the whole vocabulary.
    """
    sigma = unsqueeze(sigma, reference=i[..., None])
    survival = torch.exp(-sigma)
    # Diagonal term: mass that stays on the current token id.
    edge = survival * F.one_hot(i, num_classes=self.vocab_size)
    # Mask positions pick up the complementary (absorbed) mass.
    absorbed = torch.where(i == self.mask_token_id,
                           1 - survival.squeeze(-1),
                           0)
    return edge + absorbed[..., None]
|
| 799 |
+
|
| 800 |
+
|
| 801 |
+
def on_train_epoch_start(self):
    """Reclaim cached GPU memory and switch both trainable parts to train mode."""
    torch.cuda.empty_cache()
    for module in (self.backbone, self.noise):
        module.train()
|
| 805 |
+
|
| 806 |
+
|
| 807 |
+
def training_step(self, batch, batch_idx):
    """Run one training step, logging the loss and token throughput.

    SMILES vocabularies carry an extra `bond_mask` entry in the batch that
    the loss consumes; other vocabularies use the plain loss.
    """
    # Start the throughput clock.
    tic = time.time()

    if self.config.vocab in ('old_smiles', 'new_smiles'):
        loss = self._compute_loss(batch, prefix='train', bond_mask=batch['bond_mask'])
    else:
        loss = self._compute_loss(batch, prefix='train')

    self.log(name='trainer/loss',
             value=loss.item(),
             on_step=True,
             on_epoch=False,
             sync_dist=True)

    # Tokens processed per wall-clock second for this step.
    tokens_per_sec = batch['input_ids'].numel() / (time.time() - tic)

    self.log(name='trainer/throughput',
             value=tokens_per_sec,
             on_step=True,
             on_epoch=False,
             sync_dist=True)

    return loss
|
| 834 |
+
|
| 835 |
+
|
| 836 |
+
def on_load_checkpoint(self, checkpoint):
    """Record resume positions (completed epochs/batches) from a Lightning
    checkpoint so dataloading can fast-forward to where training stopped."""
    fit_loop = checkpoint['loops']['fit_loop']
    self.fast_forward_epochs = fit_loop['epoch_progress']['current']['completed']
    self.fast_forward_batches = fit_loop['epoch_loop.batch_progress']['current']['completed']
|
| 839 |
+
|
| 840 |
+
"""VALIDATION"""
|
| 841 |
+
def on_validation_epoch_start(self):
    """Prepare for validation: reclaim memory, switch to eval mode, and
    check that the NLL accumulator starts from a clean slate."""
    gc.collect()
    torch.cuda.empty_cache()
    for module in (self.backbone, self.noise):
        module.eval()
    # A non-zero accumulator here means a previous epoch was not reset.
    assert self.valid_metrics.nll.mean_value == 0
    assert self.valid_metrics.nll.weight == 0
|
| 848 |
+
|
| 849 |
+
def validation_step(self, batch, batch_idx):
    """Compute and log the validation loss for one batch (bond-masked for
    SMILES vocabularies, plain otherwise)."""
    if self.config.vocab in ('old_smiles', 'new_smiles'):
        loss = self._compute_loss(batch, prefix='val', bond_mask=batch['bond_mask'])
    else:
        loss = self._compute_loss(batch, prefix='val')

    self.log(name='trainer/val_loss',
             value=loss.item(),
             on_step=True,
             on_epoch=False,
             prog_bar=True,
             sync_dist=True)
    return loss
|
| 862 |
+
|
| 863 |
+
def on_validation_epoch_end(self):
    """Release Python-side garbage and cached CUDA blocks after validation."""
    gc.collect()
    torch.cuda.empty_cache()
|
| 866 |
+
|
| 867 |
+
"""OPTIMIZATION"""
|
| 868 |
+
|
| 869 |
+
def optimizer_step(self, *args, **kwargs):
    """Delegate to Lightning's optimizer step, then reclaim memory.

    The explicit gc pass + CUDA cache flush after every step trades a small
    amount of speed for a lower peak-memory footprint on long runs.
    """
    super().optimizer_step(*args, **kwargs)

    gc.collect()
    torch.cuda.empty_cache()
|
| 874 |
+
|
| 875 |
+
def configure_optimizers(self):
    """Build AdamW over the backbone and noise-schedule parameters, paired
    with a per-step cosine-warmup LR schedule.

    Returns:
        ([optimizer], [scheduler_dict]) in the format Lightning expects.
    """
    trainable = itertools.chain(self.backbone.parameters(), self.noise.parameters())
    optimizer = torch.optim.AdamW(
        trainable,
        lr=self.config.optim.lr,
        betas=(self.config.optim.beta1, self.config.optim.beta2),
        eps=self.config.optim.eps,
        weight_decay=self.config.optim.weight_decay,
    )

    self.total_steps = self.config.trainer.max_steps
    scheduler = CosineWarmup(
        optimizer,
        warmup_steps=self.config.lr_scheduler.num_warmup_steps,
        total_steps=self.total_steps,
    )

    scheduler_dict = {
        'scheduler': scheduler,
        'interval': 'step',      # advance the LR every optimizer step, not per epoch
        'frequency': 1,
        'monitor': 'val/loss',
        'name': 'trainer/lr',    # key under which the LR appears in logs
    }

    return [optimizer], [scheduler_dict]
|
| 898 |
+
|
| 899 |
+
@torch.no_grad()
def compute_masked_perplexity(self, generated_ids, input_ids):
    """
    Compute a pseudo-perplexity over masked positions only.

    For every generated sequence, the backbone is run on `input_ids` (the
    masked template) and the cross-entropy of the generated tokens is summed
    at positions where `input_ids` equals the mask token; all other positions
    are ignored via label -100 (F.cross_entropy's default ignore_index).

    Args:
        generated_ids: iterable of token-id sequences to score.
        input_ids: the masked template ids shared by all sequences.
            NOTE(review): assumed to have the same shape as each element of
            `generated_ids` so the `.where` broadcast works — confirm callers.

    Returns:
        float pseudo-perplexity; also pushed into self.gen_ppl_metric.
    """

    total_nll = 0
    total_tokens = 0

    input_ids = torch.tensor(input_ids).to(self.device)
    #print(input_ids)

    for sequence in generated_ids:
        # ground-truth ids for this generated sequence

        gt_ids = torch.tensor(sequence).to(self.device)
        #print(gt_ids)

        sys.stdout.flush()

        # forward pass through the backbone peptideclm model
        attn_mask = torch.ones_like(input_ids).to(self.device)

        # compute logits using backbone

        outputs = self.backbone.forward(input_ids=input_ids, attn_mask=attn_mask)


        # get logits for each position in sequence across all tokens in vocab
        #logits = outputs[-1] # (batch_size, seq_length, vocab_size)

        # NOTE(review): assumes `outputs` is already the logits tensor, not a
        # ModelOutput wrapper — confirm against the backbone's forward.
        logits = outputs.view(-1, outputs.size(-1))
        gt_ids = gt_ids.view(-1)

        #print(logits.shape)
        #print(gt_ids.shape)

        # compute loss
        # shift_logits = logits[:, :-1, :].contiguous() # remove eos
        # shift_labels = input_ids[:, 1:].contiguous()
        # print(masked)

        # Non-masked positions get label -100 and are skipped by cross_entropy.
        loss = F.cross_entropy(logits,
                        gt_ids.where(input_ids==self.mask_token_id, torch.full_like(gt_ids, -100)).view(-1),
                        reduction='sum')

        total_nll += loss.item()
        # count all non-padding tokens (bos/eos included, though the NLL sum
        # above only covers masked positions — intentional? TODO confirm)
        total_tokens += input_ids.ne(self.tokenizer.pad_token_id).sum().item() # count in bos and eos

    # compute pseudo-perplexity
    # print(total_nll, ",;,", total_tokens)
    pseudo_perplexity = torch.exp(torch.tensor(total_nll / total_tokens))
    self.gen_ppl_metric.update(pseudo_perplexity)

    return pseudo_perplexity.item()
|
| 955 |
+
|
| 956 |
+
|
| 957 |
+
def sample_categorical(categorical_probs):
    """Draw one index per distribution along the last axis via the
    Gumbel-max trick (dividing probabilities by exponential noise is
    equivalent to adding Gumbel noise to the log-probabilities)."""
    uniform = torch.rand_like(categorical_probs)
    exponential_noise = 1e-10 - (uniform + 1e-10).log()
    return torch.argmax(categorical_probs / exponential_noise, dim=-1)
|
| 962 |
+
|
| 963 |
+
def sample_batched_categorical(categorical_probs, batch_size):
    """Sample `batch_size` independent sequences from one set of per-position
    distributions using the Gumbel-max trick.

    Args:
        categorical_probs: (1-or-broadcastable, seq_len, vocab) probabilities.
        batch_size: number of sequences to draw.

    Returns:
        (batch_size, seq_len) tensor of sampled token indices.
    """
    _, seq_len, vocab = categorical_probs.shape

    # Perturb the log-probabilities with independent Gumbel noise per sample.
    uniform = torch.rand(batch_size, seq_len, vocab)
    gumbel = (-torch.log(-torch.log(uniform + 1e-10) + 1e-10)).to(categorical_probs.device)
    perturbed = torch.log(categorical_probs) + gumbel

    # The per-position argmax of the perturbed scores is a categorical sample.
    return perturbed.argmax(dim=-1)
|
| 974 |
+
|
| 975 |
+
def sample_batched_top_k(categorical_probs, batch_size, k):
    """Sample `batch_size` sequences, restricting each position to its top-k
    Gumbel-perturbed candidates.

    Args:
        categorical_probs: (1-or-broadcastable, seq_len, vocab) probabilities.
        batch_size: number of sequences to draw.
        k: number of top candidates kept per position.

    Returns:
        (batch_size, seq_len) tensor of sampled vocabulary indices.
    """
    _, sequence_length, vocab_length = categorical_probs.shape

    # Add Gumbel noise to the log probabilities.
    gumbel_noise = -torch.log(-torch.log(torch.rand(batch_size, sequence_length, vocab_length) + 1e-10) + 1e-10).to(categorical_probs.device)
    # BUG FIX: the original used `categorical_probs[None, :, :]`, turning the
    # 3-D input into a 4-D tensor; every downstream shape then gained a leading
    # singleton dim and torch.gather raised a rank mismatch against the 3-D
    # index. Plain broadcasting of (·, L, V) against (m, L, V) is what's wanted.
    noisy_scores = torch.log(categorical_probs) + gumbel_noise  # (m, L, V)

    # Keep only the top-k categories per position.
    top_k_scores, top_k_indices = torch.topk(noisy_scores, k, dim=-1)  # (m, L, k)

    # Renormalize within the top-k and sample one candidate per position.
    top_k_probs = torch.softmax(top_k_scores, dim=-1)
    sampled_in_top_k = torch.multinomial(top_k_probs.reshape(-1, k), num_samples=1).squeeze(-1)
    sampled_in_top_k = sampled_in_top_k.view(batch_size, sequence_length)

    # Map the within-top-k choices back to vocabulary indices.
    return torch.gather(top_k_indices, -1, sampled_in_top_k.unsqueeze(-1)).squeeze(-1)
|
| 996 |
+
|
| 997 |
+
def unsqueeze(x, reference):
    """Append trailing singleton dimensions to `x` until it has as many
    dimensions as `reference`, enabling broadcasting against it."""
    trailing = (1,) * (reference.dim() - x.dim())
    return x.view(*x.shape, *trailing)
|
| 999 |
+
|
| 1000 |
+
class CosineWarmup(_LRScheduler):
    """Linear LR warmup followed by cosine decay down to a fixed floor.

    During the first `warmup_steps` steps the LR ramps linearly from 0 to the
    base LR; afterwards it follows a cosine curve from the base LR down to
    `eta_ratio * base_lr` at `total_steps`.
    """

    def __init__(self, optimizer, warmup_steps, total_steps, eta_ratio=0.1, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        # Ratio of the minimum to the maximum learning rate.
        self.eta_ratio = eta_ratio
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch
        if step < self.warmup_steps:
            # Linear ramp: 0 -> base_lr over the warmup window.
            scale = step / self.warmup_steps
        else:
            # Cosine decay from 1 down to eta_ratio over the remaining steps.
            progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
            cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
            scale = (1 - self.eta_ratio) * cosine_decay + self.eta_ratio
        return [scale * base_lr for base_lr in self.base_lrs]
|
src/environment.yml
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Conda environment spec for PepTune.
# CUDA-enabled PyTorch is resolved from the pytorch/nvidia channels;
# version-pinned Python packages are installed via pip after the conda solve.
name: peptune
channels:
  - pytorch
  - nvidia
  - conda-forge
dependencies:
  - python=3.10
  - pip
  - pytorch
  - torchvision
  - pytorch-cuda=12.1
  - rdkit
  - numpy
  - pandas
  - scikit-learn
  - jupyterlab
  - matplotlib-base
  - seaborn
  - tqdm
  - pyyaml
  - pip:
    - pytorch-lightning==2.5.5
    - lightning==2.5.5
    - fair-esm==2.0.0
    - transformers==4.56.2
    - SmilesPE==0.0.3
    - scipy==1.13.1
    - wandb==0.22.0
    - hydra-core==1.3.2
    - hydra-submitit-launcher==1.2.0
    - pathos==0.3.4
    - matplotlib==3.10.1
    - pandas==2.2.2
    - seaborn==0.13.2
    - timm==1.0.20
    - xgboost==3.0.5
    - loguru==0.7.3
    - peft==0.17.1
    - accelerate==1.11.0
    - datasets  # deliberately unpinned: latest compatible release
|
src/generate_mcts.py
ADDED
|
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env
|
| 2 |
+
import time
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import math
|
| 6 |
+
import random
|
| 7 |
+
import sys
|
| 8 |
+
import pandas as pd
|
| 9 |
+
from utils.generate_utils import mask_for_de_novo
|
| 10 |
+
from diffusion import Diffusion
|
| 11 |
+
from pareto_mcts import Node, MCTS
|
| 12 |
+
import hydra
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
from transformers import AutoTokenizer, AutoModel, pipeline
|
| 15 |
+
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 16 |
+
from utils.app import PeptideAnalyzer
|
| 17 |
+
import matplotlib.pyplot as plt
|
| 18 |
+
import os
|
| 19 |
+
import seaborn as sns
|
| 20 |
+
import pandas as pd
|
| 21 |
+
import numpy as np
|
| 22 |
+
|
| 23 |
+
# Protein sequence dictionary
# Target proteins for binding-affinity guidance, keyed by short lowercase
# name; main() looks entries up via PROTEIN_SEQUENCES.get(name.lower()).
# Values are one-letter amino-acid sequences.
PROTEIN_SEQUENCES = {
    'amhr': 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV',
    'tfr': 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF',
    'gfap': 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM',
    'glp1': 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS',
    'glast': 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM',
    'ncam': 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF',
    'cereblon': 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL',
    'ligase': 'MASQPPEDTAESQASDELECKICYNRYNLKQRKPKVLECCHRVCAKCLYKIIDFGDSPQGVIVCPFCRFETCLPDDEVSSLPDDNNILVNLTCGGKGKKCLPENPTELLLTPKRLASLVSPSHTSSNCLVITIMEVQRESSPSLSSTPVVEFYRPASFDSVTTVSHNWTVWNCTSLLFQTSIRVLVWLLGLLYFSSLPLGIYLLVSKKVTLGVVFVSLVPSSLVILMVYGFCQCVCHEFLDCMAPPS',
    'skp2': 'MHRKHLQEIPDLSSNVATSFTWGWDSSKTSELLSGMGVSALEKEEPDSENIPQELLSNLGHPESPPRKRLKSKGSDKDFVIVRRPKLNRENFPGVSWDSLPDELLLGIFSCLCLPELLKVSGVCKRWYRLASDESLWQTLDLTGKNLHPDVTGRLLSQGVIAFRCPRSFMDQPLAEHFSPFRVQHMDLSNSVIEVSTLHGILSQCSKLQNLSLEGLRLSDPIVNTLAKNSNLVRLNLSGCSGFSEFALQTLLSSCSRLDELNLSWCFDFTEKHVQVAVAHVSETITQLNLSGYRKNLQKSDLSTLVRRCPNLVHLDLSDSVMLKNDCFQEFFQLNYLQHLSLSRCYDIIPETLLELGEIPTLKTLQVFGIVPDGTLQLLKEALPHLQINCSHFTTIARPTIGNKKNQEIWGIKCRLTLQKPSCL',
    'p53': 'MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPPVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD',
    'egfp': 'VSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTLTYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITLGMDELYK'
}
|
| 37 |
+
|
| 38 |
+
def save_logs_to_file(config, valid_fraction_log, score_logs, output_path):
    """
    Write MCTS iteration logs to a CSV file.

    Parameters:
        config: run configuration (currently unused; kept for call-site symmetry).
        valid_fraction_log (list): valid-peptide fraction per iteration.
        score_logs (dict): score-function name -> per-iteration scores.
        output_path (str): destination CSV path; parent dirs are created.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    columns = {
        "Iteration": list(range(1, len(valid_fraction_log) + 1)),
        "Valid Fraction": valid_fraction_log,
    }
    # One extra column per score function, in dict order.
    columns.update(score_logs)

    pd.DataFrame(columns).to_csv(output_path, index=False)
|
| 60 |
+
|
| 61 |
+
def plot_data(log1, log2=None,
              save_path=None,
              label1="Log 1",
              label2=None,
              title="Fraction of Valid Peptides Over Iterations",
              palette=None):
    """
    Plot one or two per-iteration series as seaborn line plots.

    Parameters:
        log1 (list): The first list of mean values for each iteration.
        log2 (list, optional): The second list of mean values for each iteration. Defaults to None.
        save_path (str): Path to save the plot. Defaults to None.
        label1 (str): Label for the first dataset. Defaults to "Log 1".
        label2 (str, optional): Label for the second dataset. Defaults to None.
        title (str): Title of the plot. Defaults to "Fraction of Valid Peptides Over Iterations".
        palette (dict, optional): A dictionary defining custom colors for datasets.
            NOTE(review): this argument is currently overwritten below and has no effect.
    """
    # Prepare data for log1
    data1 = pd.DataFrame({
        "Iteration": range(1, len(log1) + 1),
        "Fraction of Valid Peptides": log1,
        "Dataset": label1
    })

    # Prepare data for log2 if provided
    if log2 is not None:
        data2 = pd.DataFrame({
            "Iteration": range(1, len(log2) + 1),
            "Fraction of Valid Peptides": log2,
            "Dataset": label2
        })
        data = pd.concat([data1, data2], ignore_index=True)
    else:
        data = data1

    # NOTE(review): when log2 is None this dict carries a None key; harmless,
    # since seaborn only looks up hue levels that are present in the data.
    palette = {
        label1: "#8181ED",  # Default color for log1
        label2: "#D577FF"   # Default color for log2 (if provided)
    }

    # Set Seaborn theme
    sns.set_theme()
    sns.set_context("paper")

    # Create the plot
    sns.lineplot(
        data=data,
        x="Iteration",
        y="Fraction of Valid Peptides",
        hue="Dataset",
        style="Dataset",
        markers=True,
        dashes=False,
        palette=palette
    )

    # Titles and labels
    plt.title(title)
    plt.xlabel("Iteration")
    plt.ylabel("Fraction of Valid Peptides")

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to {save_path}")
    plt.show()
|
| 127 |
+
|
| 128 |
+
def plot_data_with_distribution_seaborn(log1, log2=None,
                                        save_path=None,
                                        label1=None,
                                        label2=None,
                                        title=None):
    """
    Plot one or two datasets with the average values and score distributions
    (mean line with a standard-deviation band) over iterations using Seaborn.

    Parameters:
        log1 (list of lists): The first list of scores (each element is a list of scores for an iteration).
        log2 (list of lists, optional): The second list of scores (each element is a list of scores for an iteration). Defaults to None.
        save_path (str): Path to save the plot. Defaults to None.
        label1 (str): Label for the first dataset.
        label2 (str, optional): Label for the second dataset. Defaults to None.
        title (str): Title of the plot.
    """
    # Prepare data for log1: one row per individual score, tagged with its iteration.
    data1 = pd.DataFrame({
        "Iteration": np.repeat(range(1, len(log1) + 1), [len(scores) for scores in log1]),
        "Fraction of Valid Peptides": [float(score) for scores in log1 for score in scores],
        "Dataset": label1,
        "Style": "Log1"
    })

    # Prepare data for log2 if provided
    if log2 is not None:
        data2 = pd.DataFrame({
            "Iteration": np.repeat(range(1, len(log2) + 1), [len(scores) for scores in log2]),
            "Fraction of Valid Peptides": [float(score) for scores in log2 for score in scores],
            "Dataset": label2,
            "Style": "Log2"
        })
        data = pd.concat([data1, data2], ignore_index=True)
    else:
        data = data1

    # NOTE(review): when log2 is None this dict carries a None key; harmless,
    # since seaborn only looks up hue levels present in the data.
    palette = {
        label1: "#8181ED",  # Default color for log1
        label2: "#D577FF"   # Default color for log2 (if provided)
    }

    # Set Seaborn theme
    sns.set_theme()
    sns.set_context("paper")

    # Create the plot
    sns.relplot(
        data=data,
        kind="line",
        x="Iteration",
        y="Fraction of Valid Peptides",
        hue="Dataset",
        style="Style",
        markers=True,
        dashes=True,
        ci="sd",  # Show standard deviation
                  # NOTE(review): `ci=` is deprecated in seaborn >= 0.12 in
                  # favor of errorbar="sd" — confirm the pinned seaborn version.
        height=5,
        aspect=1.5,
        palette=palette
    )

    # Titles and labels
    plt.title(title)
    plt.xlabel("Iteration")
    plt.ylabel("Fraction of Valid Peptides")

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to {save_path}")
    plt.show()
|
| 198 |
+
|
| 199 |
+
@torch.no_grad()
def generate_valid_mcts(config, mdlm, prot1=None, prot2=None, filename=None, prot_name1=None, prot_name2=None):
    """Run Pareto-MCTS guided generation from a fully masked sequence.

    Args:
        config: hydra run configuration (mcts.* flags select the objectives).
        mdlm: loaded Diffusion model (provides tokenizer and device).
        prot1, prot2: optional target protein sequences for binding objectives.
        filename: run identifier used in all output file names.
        prot_name1, prot_name2: short protein names (prot_name1 names the
            output subdirectory).

    Returns:
        (paretoFront, inputs): the Pareto front found by MCTS and the encoded
        masked input it started from.
    """
    tokenizer = mdlm.tokenizer
    max_sequence_length = config.sampling.seq_length

    # Start from an all-[MASK] template of the requested length.
    masked_array = mask_for_de_novo(config, max_sequence_length)
    inputs = tokenizer.encode(masked_array)
    inputs = {key: value.to(mdlm.device) for key, value in inputs.items()}

    # Root of the search tree at diffusion timestep 0.
    rootNode = Node(config=config, tokens=inputs, timestep=0)

    # Pick the objective set for this run mode.
    if config.mcts.perm:
        score_func_names = ['permeability', 'binding_affinity1', 'solubility', 'hemolysis', 'nonfouling']
    elif config.mcts.dual:
        score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'binding_affinity2']
    elif config.mcts.single:
        score_func_names = ['binding_affinity1'] if config.mode == 'binding' else ['permeability']
    else:
        score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling']
    # All branches of the original initialized this to zeros (the
    # time_dependent re-initialization was redundant).
    num_func = [0] * len(score_func_names)

    # FIX: the original condition `prot1 and prot2 is not None` relied on
    # precedence and treated an empty prot1 string inconsistently; branch
    # explicitly on None-ness, and collapse the two identical no-protein
    # constructor calls into one.
    if prot1 is not None and prot2 is not None:
        prot_seqs = [prot1, prot2]
    elif prot1 is not None:
        prot_seqs = [prot1]
    else:
        prot_seqs = None

    if prot_seqs is not None:
        mcts = MCTS(config=config, max_sequence_length=max_sequence_length, mdlm=mdlm,
                    score_func_names=score_func_names, prot_seqs=prot_seqs, num_func=num_func)
    else:
        mcts = MCTS(config=config, max_sequence_length=max_sequence_length, mdlm=mdlm,
                    score_func_names=score_func_names, num_func=num_func)

    paretoFront = mcts.forward(rootNode)

    # FIX: write outputs under the per-run `filename` identifier; the
    # parameter was previously unused and the paths contained a literal
    # placeholder instead.
    output_log_path = f'{config.base_path}/{prot_name1}/log_{filename}.csv'
    save_logs_to_file(config, mcts.valid_fraction_log, mcts.score_logs, output_log_path)

    plot_data(mcts.valid_fraction_log,
              save_path=f'{config.base_path}/{prot_name1}/valid_{filename}.png')

    for name in mcts.score_func_names:
        plot_data_with_distribution_seaborn(
            log1=mcts.score_logs[name],
            save_path=f'{config.base_path}/{prot_name1}/{name}_{filename}.png',
            label1=f"Average {name}",
            title=f"Average {name} Over Iterations")

    return paretoFront, inputs
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
@hydra.main(version_base=None, config_path='.', config_name='config')
def main(config):
    """Run Pareto-MCTS guided peptide generation and save scored results.

    Reads run parameters from the hydra config, loads the pretrained
    diffusion model, looks up the target protein sequence(s), runs MCTS,
    and writes the Pareto front with per-objective scores to a CSV under
    ``config.base_path``.
    """
    # Get parameters from config with defaults
    prot_name1 = config.get('prot_name1', 'gfap')
    prot_name2 = config.get('prot_name2', None)
    mode = config.get('mode', '2')
    model = config.get('model_type', 'mcts')
    length = config.get('length', '100')
    epoch = config.get('epoch', '7')

    # Run identifier used in output file names.
    filename = f'{mode}_{model}_length_{length}_epoch_{epoch}'

    tokenizer = SMILES_SPE_Tokenizer(f'{config.base_path}/src/tokenizer/new_vocab.txt',
                                     f'{config.base_path}/src/tokenizer/new_splits.txt')

    mdlm = Diffusion.load_from_checkpoint(config.eval.checkpoint_path, config=config, tokenizer=tokenizer, strict=False)

    mdlm.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
    mdlm.to(device)

    print("loaded models...")
    analyzer = PeptideAnalyzer()

    # Look up protein sequences from names
    prot_seq1 = PROTEIN_SEQUENCES.get(prot_name1.lower())
    prot_seq2 = PROTEIN_SEQUENCES.get(prot_name2.lower()) if prot_name2 else None

    if prot_seq1 is None:
        raise ValueError(f"Protein '{prot_name1}' not found in PROTEIN_SEQUENCES dictionary. Available proteins: {list(PROTEIN_SEQUENCES.keys())}")

    if prot_name2 and prot_seq2 is None:
        raise ValueError(f"Protein '{prot_name2}' not found in PROTEIN_SEQUENCES dictionary. Available proteins: {list(PROTEIN_SEQUENCES.keys())}")

    print(f"Using protein 1: {prot_name1}")
    if prot_name2:
        print(f"Using protein 2: {prot_name2}")

    t_start = time.time()
    paretoFront, input_array = generate_valid_mcts(config, mdlm, prot_seq1, prot_seq2, filename, prot_name1, prot_name2)
    generation_results = []

    for sequence, v in paretoFront.items():
        generated_array = v['token_ids'].to(mdlm.device)

        # compute perplexity
        perplexity = mdlm.compute_masked_perplexity(generated_array, input_array['input_ids'])
        perplexity = round(perplexity, 4)

        aa_seq, seq_length = analyzer.analyze_structure(sequence)
        scores = v['scores']

        # The first four objectives are shared by every multi-objective mode.
        if not config.mcts.single:
            binding1 = scores[0]
            solubility = scores[1]
            hemo = scores[2]
            nonfouling = scores[3]

        if config.mcts.perm:
            permeability = scores[4]
            generation_results.append([sequence, perplexity, aa_seq, binding1, solubility, hemo, nonfouling, permeability])
            print(f"perplexity: {perplexity} | length: {seq_length} | smiles sequence: {sequence} | amino acid sequence: {aa_seq} | Binding Affinity: {binding1} | Solubility: {solubility} | Hemolysis: {hemo} | Nonfouling: {nonfouling} | Permeability: {permeability}")
        elif config.mcts.dual:
            binding2 = scores[4]
            generation_results.append([sequence, perplexity, aa_seq, binding1, binding2, solubility, hemo, nonfouling])
            print(f"perplexity: {perplexity} | length: {seq_length} | smiles sequence: {sequence} | amino acid sequence: {aa_seq} | Binding Affinity 1: {binding1} | Binding Affinity 2: {binding2} | Solubility: {solubility} | Hemolysis: {hemo} | Nonfouling: {nonfouling}")
        elif config.mcts.single:
            permeability = scores[0]
            # BUGFIX: this branch previously computed the score but never
            # recorded or printed it, so the single-objective result CSV
            # (built with a 'Permeability' column below) was always empty.
            generation_results.append([sequence, perplexity, aa_seq, permeability])
            print(f"perplexity: {perplexity} | length: {seq_length} | smiles sequence: {sequence} | amino acid sequence: {aa_seq} | Permeability: {permeability}")
        else:
            generation_results.append([sequence, perplexity, aa_seq, binding1, solubility, hemo, nonfouling])
            print(f"perplexity: {perplexity} | length: {seq_length} | smiles sequence: {sequence} | amino acid sequence: {aa_seq} | Binding Affinity: {binding1} | Solubility: {solubility} | Hemolysis: {hemo} | Nonfouling: {nonfouling}")

        sys.stdout.flush()

    # Column layout must match whichever branch filled generation_results.
    if config.mcts.perm:
        df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Perplexity', 'Peptide Sequence', 'Binding Affinity', 'Solubility', 'Hemolysis', 'Nonfouling', 'Permeability'])
    elif config.mcts.dual:
        df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Perplexity', 'Peptide Sequence', 'Binding Affinity 1', 'Binding Affinity 2', 'Solubility', 'Hemolysis', 'Nonfouling'])
    elif config.mcts.single:
        df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Perplexity', 'Peptide Sequence', 'Permeability'])
    else:
        df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Perplexity', 'Peptide Sequence', 'Binding Affinity', 'Solubility', 'Hemolysis', 'Nonfouling'])

    # BUGFIX(review): the output path lost the computed run identifier
    # (`filename` was otherwise unused here); restored so separate runs do
    # not overwrite the same file — confirm desired naming.
    df.to_csv(f'{config.base_path}/{prot_name1}/{filename}.csv', index=False)

    # ── timing ──
    elapsed = time.time() - t_start
    print(f"\n{'='*60}")
    print(f"Generation complete in {elapsed:.1f}s ({elapsed/60:.1f} min)")
    print(f"Pareto front size: {len(df)}")

    # ── score statistics ──
    score_cols = [c for c in df.columns if c not in ('Generated SMILES', 'Peptide Sequence')]
    print(f"\n{'Score':<22} {'Mean':>8} {'Std':>8} {'Min':>8} {'Max':>8}")
    print('-' * 58)
    for col in score_cols:
        vals = pd.to_numeric(df[col], errors='coerce').dropna()
        if len(vals) == 0:
            continue
        print(f"{col:<22} {vals.mean():8.4f} {vals.std():8.4f} {vals.min():8.4f} {vals.max():8.4f}")
    print('=' * 60)


if __name__ == "__main__":
    main()
|
src/generate_unconditional.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import sys
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import omegaconf
|
| 7 |
+
from utils.generate_utils import mask_for_de_novo, calculate_cosine_sim, calculate_hamming_dist
|
| 8 |
+
from diffusion import Diffusion
|
| 9 |
+
import hydra
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 12 |
+
from utils.app import PeptideAnalyzer
|
| 13 |
+
from scoring.scoring_functions import ScoringFunctions
|
| 14 |
+
|
| 15 |
+
# Register custom OmegaConf resolvers required by config.yaml
omegaconf.OmegaConf.register_new_resolver('cwd', os.getcwd, replace=True)
omegaconf.OmegaConf.register_new_resolver('device_count', torch.cuda.device_count, replace=True)
# NOTE(review): the 'eval' resolver executes arbitrary Python embedded in the
# config — acceptable only for trusted, locally-authored config files.
omegaconf.OmegaConf.register_new_resolver('eval', eval, replace=True)
# Integer ceiling division, used for derived sizes in config.yaml.
omegaconf.OmegaConf.register_new_resolver('div_up', lambda x, y: (x + y - 1) // y, replace=True)

# Repository root and pretrained checkpoint location; edit for your install.
base_path = '/path/to/your/home/PepTune'
ckpt_path = base_path + '/checkpoints/peptune-pretrained.ckpt'
|
| 24 |
+
@torch.no_grad()
def generate_sequence_unconditional(config, sequence_length: int, mdlm: Diffusion):
    """Sample one unconditional token sequence of the given length.

    Builds a fully-masked [MASK] input of ``sequence_length`` tokens, moves
    the tokenized tensors onto the model device, and runs the reverse
    diffusion sampler (``config.sampling.steps`` controls robustness).

    Returns:
        (sampled token tensor, tokenized masked input dict on device)
    """
    masked = mask_for_de_novo(config, sequence_length)
    encoded = mdlm.tokenizer.encode(masked)

    # Move every tensor of the tokenized input onto the model device.
    batch = {name: tensor.to(mdlm.device) for name, tensor in encoded.items()}

    # Sample an unconditional array of tokens from the diffusion model.
    sampled = mdlm._sample(x_input=batch)

    return sampled, batch
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@hydra.main(version_base=None, config_path='.', config_name='config')
def main(config):
    """Sample peptides unconditionally, score the valid ones, and save a CSV.

    Loads the pretrained PepTune diffusion model from ``ckpt_path``, draws
    ``config.sampling.num_sequences`` sequences of length
    ``config.sampling.seq_length``, scores chemically valid peptides with
    the five scoring functions, and writes the results to
    ``<base_path>/results/test_generate.csv``.
    """

    tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/src/tokenizer/new_vocab.txt',
                                     f'{base_path}/src/tokenizer/new_splits.txt')

    # Build model with current config, then load weights manually
    # (load_from_checkpoint overrides config with saved hparams)
    mdlm_model = Diffusion(config=config, tokenizer=tokenizer)
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    mdlm_model.load_state_dict(ckpt["state_dict"], strict=False)

    mdlm_model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
    mdlm_model.to(device)

    print("loaded models...")
    analyzer = PeptideAnalyzer()

    # GFAP target protein sequence, used by the binding-affinity scorer.
    gfap = 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM'

    # scoring functions
    score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'permeability']
    score_functions = ScoringFunctions(score_func_names, [gfap])

    max_seq_length = config.sampling.seq_length
    num_sequences = config.sampling.num_sequences
    generation_results = []
    # Counters track validity fraction; only valid peptides get CSV rows.
    num_valid = 0.
    num_total = 0.
    while num_total < num_sequences:
        num_total += 1
        generated_array, input_array = generate_sequence_unconditional(config, max_seq_length, mdlm_model)

        # store in device
        generated_array = generated_array.to(mdlm_model.device)
        print(generated_array)

        # compute masked perplexity
        perplexity = mdlm_model.compute_masked_perplexity(generated_array, input_array['input_ids'])
        perplexity = round(perplexity, 4)

        smiles_seq = tokenizer.decode(generated_array)
        if analyzer.is_peptide(smiles_seq):
            aa_seq, seq_length = analyzer.analyze_structure(smiles_seq)
            num_valid += 1
            scores = score_functions(input_seqs=[smiles_seq])

            # scores holds one row per input sequence; unpack the five
            # objectives in score_func_names order.
            binding = scores[0][0]
            sol = scores[0][1]
            hemo = scores[0][2]
            nf = scores[0][3]
            perm = scores[0][4]

            generation_results.append([smiles_seq, perplexity, aa_seq, binding, sol, hemo, nf, perm])
        else:
            # Invalid peptides are reported on stdout but not written to CSV.
            aa_seq = "not valid peptide"
            seq_length = '-'
            scores = "not valid peptide"


        print(f"perplexity: {perplexity} | length: {seq_length} | smiles sequence: {smiles_seq} | amino acid sequence: {aa_seq} | scores: {scores}")
        sys.stdout.flush()

    valid_frac = num_valid / num_total
    print(f"fraction of synthesizable peptides: {valid_frac}")
    df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Perplexity', 'Peptide Sequence', 'Binding Affinity', 'Solubility', 'Hemolysis', 'Nonfouling', 'Permeability'])
    df.to_csv(base_path + f'/results/test_generate.csv', index=False)

if __name__ == "__main__":
    main()
|
src/metrics.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
from math import sqrt
|
| 3 |
+
|
| 4 |
+
def summarize_metrics(skip, csv_path: str, save_path: str | None = None) -> pd.DataFrame:
|
| 5 |
+
"""
|
| 6 |
+
Compute mean and standard deviation for all columns except the first
|
| 7 |
+
(assumed non-numeric identifier like 'Peptide Sequence').
|
| 8 |
+
|
| 9 |
+
Returns a DataFrame with rows = column names and columns = ['mean','std','count'].
|
| 10 |
+
Uses sample std (ddof=1). Non-numeric cells are coerced to NaN.
|
| 11 |
+
"""
|
| 12 |
+
df = pd.read_csv(csv_path)
|
| 13 |
+
vals = df.iloc[:, skip:].apply(pd.to_numeric, errors='coerce') # columns 2..end
|
| 14 |
+
stats = vals.agg(['mean', 'std', 'count']).T # shape: (num_metrics, 3)
|
| 15 |
+
if save_path:
|
| 16 |
+
stats.to_csv(save_path, index=True)
|
| 17 |
+
return stats
|
| 18 |
+
|
| 19 |
+
def summarize_list(xs, ddof = 1):
    """Return {'mean', 'std', 'count'} for the numeric entries of ``xs``.

    Entries that are None, empty strings, or not coercible to float are
    silently skipped. Uses Welford's numerically stable single-pass
    algorithm; ``ddof=1`` gives the sample standard deviation.

    Raises:
        ValueError: if no numeric values remain, or too few for ``ddof``.
    """
    cleaned = []
    for item in xs:
        if item is None or item == "":
            continue
        try:
            cleaned.append(float(item))
        except (TypeError, ValueError):
            continue

    total = len(cleaned)
    if total == 0:
        raise ValueError("No numeric values found.")
    if total <= ddof:
        raise ValueError(f"Need at least {ddof + 1} numeric values; got {total}.")

    # Welford's online update: one pass, stable against cancellation.
    running_mean = 0.0
    m2 = 0.0
    seen = 0
    for value in cleaned:
        seen += 1
        delta = value - running_mean
        running_mean += delta / seen
        m2 += delta * (value - running_mean)

    return {"mean": running_mean, "std": sqrt(m2 / (seen - ddof)), "count": seen}
|
| 52 |
+
|
| 53 |
+
def csv_column_to_list(path: str, column: str, *, dropna: bool = True):
    """Read one named column of a CSV into a Python list.

    Raises:
        KeyError: if ``column`` is not present in the file.
    NaN entries are removed unless ``dropna=False``.
    """
    frame = pd.read_csv(path)
    if column not in frame.columns:
        raise KeyError(f"Column '{column}' not found. Available: {list(frame.columns)}")
    series = frame[column]
    return series.dropna().tolist() if dropna else series.tolist()
|
| 61 |
+
|
| 62 |
+
def main():
    """Summarize one run's generation-results CSV and save the statistics.

    NOTE(review): paths are machine-specific; parameterize before reuse.
    """
    # Removed an unused `csv_path` local that pointed at an unrelated
    # plots directory and was never read.
    path = "/scratch/pranamlab/sophtang/home/TR2-D2/tr2d2-pep/results/tfr_resample10_buffer20_numiter10_children50_20260326_183626"
    prot_name = "tfr"
    # skip=1: the first column is the peptide identifier, not a metric.
    stats = summarize_metrics(skip=1, csv_path=f"{path}/{prot_name}_generation_results.csv",
                              save_path=f"{path}/results_summary.csv")

    print(stats)

if __name__ == '__main__':
    main()
|
src/noise_schedule.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Adapted from MDLM: https://github.com/kuleshov-group/mdlm
|
| 2 |
+
|
| 3 |
+
import abc
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
# Disable the TorchScript profiling mode/executor and allow operator fusion
# on both CPU and GPU. These mirror the JIT settings used by the upstream
# MDLM codebase this module is adapted from.
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
|
| 12 |
+
|
| 13 |
+
def get_noise(config, dtype=torch.float32):
    """Instantiate the noise schedule named by ``config.noise.type``.

    Supported types: 'geometric', 'loglinear', 'cosine', 'cosinesqr',
    'linear'. Raises ValueError for anything else.
    """
    noise_type = config.noise.type
    if noise_type == 'geometric':
        return GeometricNoise(config.noise.sigma_min, config.noise.sigma_max)
    if noise_type == 'loglinear':
        return LogLinearNoise()
    if noise_type == 'cosine':
        return CosineNoise()
    if noise_type == 'cosinesqr':
        return CosineSqrNoise()
    if noise_type == 'linear':
        return Linear(config.noise.sigma_min, config.noise.sigma_max, dtype)
    raise ValueError(f'{config.noise.type} is not a valid noise')
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def binary_discretization(z):
    """Straight-through sign discretization of ``z``.

    Forward value is ``sign(z)``; gradients flow through the L2-normalized
    soft value because the hard-soft difference is detached.
    """
    hard = torch.sign(z)
    soft = z / torch.norm(z, dim=-1, keepdim=True)
    return soft + (hard - soft).detach()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class Noise(abc.ABC, nn.Module):
    """
    Abstract base for noise schedules.

    Subclasses provide ``total_noise(t)`` (cumulative noise sigma(t)) and
    ``rate_noise(t)`` (its time derivative); ``forward`` returns both for a
    timestep t in [0, 1].
    """
    def forward(self, t):
        # Assume time goes from 0 to 1
        return self.total_noise(t), self.rate_noise(t)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class CosineNoise(Noise):
    """Cosine noise schedule: sigma(t) = -log(eps + (1 - eps) * cos(pi*t/2))."""

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def rate_noise(self, t):
        # Time derivative of total_noise(t).
        numerator = (1 - self.eps) * torch.sin(t * torch.pi / 2)
        denominator = (1 - self.eps) * torch.cos(t * torch.pi / 2) + self.eps
        half_pi = torch.pi / 2
        return half_pi * numerator / denominator

    def total_noise(self, t):
        cosine = torch.cos(t * torch.pi / 2)
        return - torch.log(self.eps + (1 - self.eps) * cosine)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class CosineSqrNoise(Noise):
    """Squared-cosine noise schedule: sigma(t) = -log(eps + (1-eps)*cos(pi*t/2)^2)."""

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def rate_noise(self, t):
        # Time derivative of total_noise(t); sin(pi*t) = 2 sin(pi*t/2) cos(pi*t/2).
        sq_cos = (1 - self.eps) * (
            torch.cos(t * torch.pi / 2) ** 2)
        full_sin = (1 - self.eps) * torch.sin(t * torch.pi)
        half_pi = torch.pi / 2
        return half_pi * full_sin / (sq_cos + self.eps)

    def total_noise(self, t):
        sq_cos = torch.cos(t * torch.pi / 2) ** 2
        return - torch.log(self.eps + (1 - self.eps) * sq_cos)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class Linear(Noise):
    """Linear noise schedule: sigma(t) = sigma_min + t * (sigma_max - sigma_min)."""

    def __init__(self, sigma_min=0, sigma_max=10, dtype=torch.float32):
        super().__init__()
        self.sigma_min = torch.tensor(sigma_min, dtype=dtype)
        self.sigma_max = torch.tensor(sigma_max, dtype=dtype)

    def rate_noise(self, t=None):
        # BUGFIX: the `t` parameter was missing, so Noise.forward — which
        # calls self.rate_noise(t) — raised a TypeError for this schedule.
        # The rate is constant in t (derivative of total_noise); `t=None`
        # keeps existing zero-argument callers working.
        return self.sigma_max - self.sigma_min

    def total_noise(self, t):
        return self.sigma_min + t * (self.sigma_max - self.sigma_min)

    def importance_sampling_transformation(self, t):
        # Warp uniform t so sigma levels are weighted per the MDLM
        # importance-sampling scheme, then map back to [0, 1].
        f_T = torch.log1p(- torch.exp(- self.sigma_max))
        f_0 = torch.log1p(- torch.exp(- self.sigma_min))
        sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
        return (sigma_t - self.sigma_min) / (
            self.sigma_max - self.sigma_min)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class GeometricNoise(Noise):
    """Geometric interpolation between sigma_min and sigma_max."""

    def __init__(self, sigma_min=1e-3, sigma_max=1):
        super().__init__()
        # Endpoints stored as a single float tensor [sigma_min, sigma_max].
        self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])

    def rate_noise(self, t):
        # d/dt of total_noise: geometric mean times log-ratio of endpoints.
        lo = self.sigmas[0]
        hi = self.sigmas[1]
        return lo ** (1 - t) * hi ** t * (
            hi.log() - lo.log())

    def total_noise(self, t):
        lo = self.sigmas[0]
        hi = self.sigmas[1]
        return lo ** (1 - t) * hi ** t
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class LogLinearNoise(Noise):
    """Log Linear noise schedule.

    Built such that 1 - 1/e^(n(t)) interpolates between 0 and
    ~1 when t varies from 0 to 1. Total noise is
    -log(1 - (1 - eps) * t), so the sigma will be
    (1 - eps) * t.
    """
    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps
        # Schedule endpoints, used by the importance-sampling warp below.
        self.sigma_max = self.total_noise(torch.tensor(1.0))
        # total_noise(0) is 0, so sigma_min reduces to eps.
        self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))

    def rate_noise(self, t):
        # Exact time derivative of total_noise(t).
        return (1 - self.eps) / (1 - (1 - self.eps) * t)

    def total_noise(self, t):
        return -torch.log1p(-(1 - self.eps) * t)

    def importance_sampling_transformation(self, t):
        # Warp uniform t so that sigma levels between sigma_min and
        # sigma_max are sampled with equal weight (importance sampling),
        # then invert total_noise to recover the warped timestep.
        f_T = torch.log1p(- torch.exp(- self.sigma_max))
        f_0 = torch.log1p(- torch.exp(- self.sigma_min))
        sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
        t = - torch.expm1(- sigma_t) / (1 - self.eps)
        return t
|
| 135 |
+
|
| 136 |
+
class LogPolyNoise(Noise):
    """
    Log polynomial noise schedule for slower masking of peptide bond tokens.

    Cubic analogue of LogLinearNoise: total noise is
    -log(1 - (1 - eps) * t^3).
    """
    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps
        self.sigma_max = self.total_noise(torch.tensor(1.0))
        self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))

    def rate_noise(self, t):
        # Exact derivative of -log(1 - (1 - eps) * t^3).
        # BUGFIX: the numerator was previously (3*t**2 - self.eps), which is
        # not the derivative of total_noise and goes negative at t = 0; the
        # correct numerator is 3 * (1 - eps) * t**2 (cf. LogLinearNoise).
        return (3 * (1 - self.eps) * (t**2)) / (1 - (1 - self.eps) * (t**3))

    def total_noise(self, t):
        # -log(1 - t^3), regularized by eps
        return -torch.log1p(-(1 - self.eps) * (t**3))
|
src/pareto_mcts.py
ADDED
|
@@ -0,0 +1,492 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numpy as np
|
| 5 |
+
import random as rd
|
| 6 |
+
|
| 7 |
+
from diffusion import Diffusion
|
| 8 |
+
from scoring.scoring_functions import ScoringFunctions
|
| 9 |
+
from utils.app import PeptideAnalyzer
|
| 10 |
+
import noise_schedule
|
| 11 |
+
|
| 12 |
+
""""
|
| 13 |
+
Notes: store rolled out sequence?
|
| 14 |
+
path of node objects or strings?
|
| 15 |
+
should we only select valid expandable leaf nodes?
|
| 16 |
+
calculate similarity between sibling nodes?
|
| 17 |
+
should we evaluate generated sequences?
|
| 18 |
+
"""
|
| 19 |
+
class Node:
    """
    Node in the Pareto-MCTS search tree: a partially unmasked SMILES string.

    Attributes:
        parentNode: Node at the previous diffusion timestep (None for root).
        childNodes: Nodes generated from distinct unmasking schemes.
        totalReward: cumulative reward vector over the K objectives.
        visits: number of MCTS iterations that have visited this node.
        timestep: diffusion timestep in [0, config.sampling.steps].
        sampleProb: probability of sampling this sequence from the reverse
            posterior (None when unknown).
        tokens: dict with 'input_ids' token array and 'attention_mask'.
    """
    def __init__(self, config, tokens=None, parentNode=None, childNodes=None, scoreVector=None, totalReward=None, timestep=None, sampleProb=None):
        self.config = config
        self.parentNode = parentNode
        # BUGFIX: the original default `childNodes=[]` was a mutable default
        # argument, silently shared by every Node constructed without an
        # explicit list; use a None sentinel instead.
        self.childNodes = [] if childNodes is None else childNodes
        self.scoreVector = scoreVector

        # initialize total rewards to the reward of the rolled-out unmasked
        # sequence, or zeros when no rollout reward is available yet
        if totalReward is not None:
            self.totalReward = totalReward
        else:
            self.totalReward = np.zeros(self.config.mcts.num_objectives)

        # every node starts with one visit (its own creation)
        self.visits = 1
        # set timestep (value between 0 and num_steps)
        self.timestep = timestep
        # sampling probability from the reverse posterior
        self.sampleProb = sampleProb

        # dict with 'input_ids' as token array and 'attention_mask'
        self.tokens = tokens

    def selectNode(self, num_func):
        """
        Select a child node to descend into.

        Builds the Pareto front over the select scores of all expandable
        children and returns a uniformly random member of the front plus
        its expand status. Returns (self, status) when this node is not a
        legal non-leaf node, or (self, 1) when no child is selectable.
        """
        nodeStatus = self.getExpandStatus()

        # only a legal non-leaf node (status 3) can delegate to a child
        if (nodeStatus == 3):
            paretoFront = {}
            for childNode in self.childNodes:
                childStatus = childNode.getExpandStatus()
                # only consider children that are expandable leaves (2) or
                # legal non-leaf nodes (3)
                if childStatus == 2 or childStatus == 3:
                    selectScore = childNode.calcSelectScore()
                    paretoFront = updateParetoFront(paretoFront, childNode, selectScore, num_func)

            # if no selectable children (all terminal), return self as a leaf
            if len(paretoFront) == 0:
                return self, 1

            # randomly select a node on the Pareto front
            selected = rd.choice(list(paretoFront.keys()))
            return selected, selected.getExpandStatus()

        # if node is not a valid non-leaf node
        return self, nodeStatus

    def addChildNode(self, tokens, totalReward, prob=None):
        """
        Create a child Node one timestep further along the unmasking
        trajectory, append it to childNodes, and return it.
        """
        child = Node(config=self.config,
                     tokens=tokens,
                     parentNode=self,
                     childNodes=[],
                     totalReward=totalReward,
                     timestep=self.timestep+1,
                     sampleProb=prob)

        self.childNodes.append(child)
        return child

    def updateNode(self, rewards):
        """
        Backpropagation step: add a descendant leaf's reward vector to the
        cumulative rewards and increment the visit count.
        """
        self.visits += 1
        self.totalReward += rewards

    def calcSelectScore(self):
        """
        Select score: the visit-normalized reward vector, plus — when a
        sampling probability is available — a pUCT-style exploration bonus
        weighted by config.mcts.sample_prob.
        """
        # K-dimensional vector of normalized rewards for each objective
        normRewards = self.totalReward / self.visits
        if self.sampleProb is not None:
            # BUGFIX: the original referenced `self.root`, which is never
            # assigned anywhere and raised AttributeError; walk the parent
            # chain to the tree root instead. TODO(review): confirm the
            # bonus should scale with the root's visit count rather than
            # the parent's.
            root = self
            while root.parentNode is not None:
                root = root.parentNode
            return normRewards + (self.config.mcts.sample_prob * self.sampleProb * np.sqrt(root.visits) / self.visits)
        return normRewards

    def getExpandStatus(self):
        """
        Returns an integer indicating whether the node is a:
        1. terminal node (sequence is fully unmasked)
        2. legal leaf node (partially unmasked sequence that can be expanded)
        3. legal non-leaf node (already expanded sequence with child nodes)
        """
        if self.timestep == self.config.sampling.steps:
            return 1
        elif (self.timestep < self.config.sampling.steps) and (len(self.childNodes) == 0):
            return 2
        return 3
| 140 |
+
|
| 141 |
+
"""END OF NODE CLASS"""
|
| 142 |
+
|
| 143 |
+
def updateParetoFront(paretoFront, node, scoreVector, num_func):
    """
    Update a Pareto front dict in place with a candidate sequence.

    Removes front members that are dominated by ``scoreVector`` (on the first
    ``num_func`` objectives) and inserts ``node`` only when its score vector is
    greater-or-equal to EVERY current member on those objectives.

    Args:
        paretoFront: dict mapping sequence -> score vector (array-like).
        node: key (sequence) for the candidate.
        scoreVector: numpy array of objective scores for the candidate.
        num_func: number of leading objectives currently being optimized.

    Returns:
        The (mutated) paretoFront dict.
    """
    front_size = len(paretoFront)

    # empty front: the candidate is trivially Pareto-optimal
    if front_size == 0:
        paretoFront[node] = scoreVector
        return paretoFront

    # per-member flag: True when the candidate is not dominated by that member
    nondominated_flags = []
    # members dominated by the candidate, queued for removal
    to_remove = []
    for member, member_scores in paretoFront.items():
        geq = scoreVector >= np.asarray(member_scores)
        gt = scoreVector > np.asarray(member_scores)

        if num_func <= len(geq):
            geq_active = geq[:num_func]
            gt_active = gt[:num_func]
            if geq_active.all():
                # candidate is at least as good everywhere
                nondominated_flags.append(True)
                if gt_active.any():
                    # ...and strictly better somewhere: member is dominated
                    to_remove.append(member)
            else:
                # candidate is beaten on at least one active objective
                nondominated_flags.append(False)

    # insert only when the candidate is non-dominated w.r.t. every member
    if np.asarray(nondominated_flags).all():
        paretoFront[node] = scoreVector

    # evict all dominated members
    while front_size > 0 and to_remove:
        del paretoFront[to_remove.pop(0)]
        front_size -= 1

    return paretoFront
|
| 192 |
+
|
| 193 |
+
###BEGINNING OF MCTS CLASS###
|
| 194 |
+
|
| 195 |
+
class MCTS:
|
| 196 |
+
def __init__(self, config, max_sequence_length=None, mdlm=None, score_func_names=[], prot_seqs=None, num_func = []):
|
| 197 |
+
self.config = config
|
| 198 |
+
self.noise = noise_schedule.get_noise(config)
|
| 199 |
+
self.time_conditioning = self.config.time_conditioning
|
| 200 |
+
# dictionary of k (SMILES string) and v (score vector) of Pareto-optimal sequences
|
| 201 |
+
self.peptideParetoFront = {}
|
| 202 |
+
self.num_steps = config.sampling.steps
|
| 203 |
+
self.num_sequences = config.sampling.num_sequences
|
| 204 |
+
|
| 205 |
+
# mdlm model
|
| 206 |
+
self.mdlm = mdlm
|
| 207 |
+
self.tokenizer = mdlm.tokenizer
|
| 208 |
+
self.device = mdlm.device
|
| 209 |
+
|
| 210 |
+
if max_sequence_length is None:
|
| 211 |
+
self.sequence_length = self.config.sampling.seq_length
|
| 212 |
+
else:
|
| 213 |
+
self.sequence_length = max_sequence_length
|
| 214 |
+
|
| 215 |
+
self.num_iter = config.mcts.num_iter
|
| 216 |
+
|
| 217 |
+
self.num_child = config.mcts.num_children
|
| 218 |
+
|
| 219 |
+
# score functions
|
| 220 |
+
self.score_functions = ScoringFunctions(score_func_names, prot_seqs)
|
| 221 |
+
self.score_func_names = score_func_names
|
| 222 |
+
self.num_func = num_func # K-dimensional vector with the iteration number to start conditioning on each of the objectives in increasng order
|
| 223 |
+
self.iter_num = 0
|
| 224 |
+
self.curr_num_func = 1
|
| 225 |
+
self.analyzer = PeptideAnalyzer()
|
| 226 |
+
|
| 227 |
+
# track fraction of valid peptides
|
| 228 |
+
self.valid_fraction_log = []
|
| 229 |
+
self.score_logs = {name: [] for name in score_func_names}
|
| 230 |
+
|
| 231 |
+
def reset(self):
|
| 232 |
+
self.iter_num = 0
|
| 233 |
+
self.valid_fraction_log = []
|
| 234 |
+
self.score_logs = {name: [] for name in self.score_func_names}
|
| 235 |
+
self.peptideParetoFront = {}
|
| 236 |
+
|
| 237 |
+
    def forward(self, rootNode):
        """
        Run the MCTS loop for ``self.num_iter`` iterations starting at ``rootNode``.

        Each iteration selects a leaf node and expands it; the Pareto front of
        generated peptides is maintained as a side effect of expand().

        Returns:
            dict mapping peptide sequence -> {'scores': ..., 'token_ids': ...}
            for the Pareto-optimal sequences found.
        """
        self.reset()

        while (self.iter_num < self.num_iter):
            self.iter_num += 1

            # traverse the tree from the root node until a leaf node
            leafNode, _ = self.select(rootNode)
            #print(leafNode.tokens['input_ids'])

            # expand leaf node into num_children partially unmasked sequences at the next timestep
            self.expand(leafNode)

        # return dictionary of pareto front peptides and their score vectors
        return self.peptideParetoFront
|
| 252 |
+
|
| 253 |
+
# change to include more even if dominated? since there is error in the scores
|
| 254 |
+
    def updateParetoFront(self, sequence, scoreVector, tokens):
        """
        Update the peptide Pareto front with a newly scored sequence.

        Removes sequences that are dominated by scoreVector and adds the SMILES
        sequence (with its score vector and token ids) if it is non-dominated.
        Dominance is only evaluated on the first ``curr_num_func`` objectives,
        which grows over iterations according to ``self.num_func``.

        Args:
            sequence: decoded SMILES string (dict key).
            scoreVector: numpy array of objective scores for the sequence.
            tokens: fully-unmasked token ids stored alongside the scores.

        Returns:
            rewardVector: per-objective reward — all ones when the front was
            empty, otherwise the fraction of front members the sequence is
            (weakly) better than on each objective.
        """
        paretoSize = len(self.peptideParetoFront)

        # number of objectives currently active: the i-th objective switches on
        # once iter_num reaches num_func[i]
        self.curr_num_func = 1

        for i in range(len(self.num_func)):
            if self.iter_num >= self.num_func[i]:
                self.curr_num_func = i+1

        if paretoSize == 0:
            # if pareto front is empty, add sequence and scoreVector
            self.peptideParetoFront[sequence] = {'scores': scoreVector, 'token_ids': tokens}
            # if pareto front is empty, set reward vector to 1s
            rewardVector = np.ones(len(scoreVector))
        else:
            # vector of boolean
            # true: sequence is non-dominated by the pareto-optimal sequence
            # false: sequence is completely dominated by the pareto-optimal sequence
            nondominate = []
            # sequences to be deleted
            delete = []
            # initialize reward vector with zeros
            rewardVector = np.zeros(len(scoreVector))
            for k, v in self.peptideParetoFront.items():
                # boolean vector
                # true: if all metrics are equal or larger
                # false: if the pareto front sequence dominates scoreVector
                nondominated = scoreVector >= np.asarray(v['scores'])  # [num_objectives]
                dominant = scoreVector > np.asarray(v['scores'])
                # add to reward vector (True counts as 1 per objective beaten)
                rewardVector += nondominated  # [num_objectives]

                if self.curr_num_func <= len(nondominated):
                    # restrict dominance test to the active objectives
                    attn_nondominated = nondominated[:self.curr_num_func]
                    attn_dominant = dominant[:self.curr_num_func]

                    # only delete pareto-optimal sequence if
                    # all scores are greater than or equal to v and at least one score is strictly greater than v
                    if attn_nondominated.all() and attn_dominant.any():
                        # add the dominated sequence to be deleted
                        delete.append(k)
                        # sequence is dominant
                        nondominate.append(True)
                    elif attn_nondominated.all():
                        # sequence is non-dominated
                        nondominate.append(True)
                    else:
                        # sequence is completely dominated
                        nondominate.append(False)

            assert len(nondominate) == paretoSize
            nondominate = np.asarray(nondominate)
            # if sequence is either dominant or non-dominated by all sequences in pareto-front -> add to pareto front
            # or if the pareto front does not have enough sequences
            if nondominate.all() or paretoSize < self.num_sequences:
                self.peptideParetoFront[sequence] = {'scores': scoreVector, 'token_ids': tokens}

            # normalize by the pre-insertion front size
            rewardVector = rewardVector / paretoSize

            # delete all dominated sequences if pareto front is larger than num_sequences
            while (paretoSize > self.num_sequences) and (len(delete) > 0):
                #for k in delete:
                del self.peptideParetoFront[delete[0]]
                del delete[0]
                paretoSize -= 1

        return rewardVector
|
| 327 |
+
|
| 328 |
+
def isPathEnd(self, path, maxDepth):
|
| 329 |
+
"""
|
| 330 |
+
Checks if the node is completely unmasked (ie. end of path)
|
| 331 |
+
or if the path is at the max depth
|
| 332 |
+
"""
|
| 333 |
+
if (path[-1] != self.config.mcts.mask_token).all():
|
| 334 |
+
return True
|
| 335 |
+
elif len(path) >= maxDepth:
|
| 336 |
+
return True
|
| 337 |
+
return False
|
| 338 |
+
|
| 339 |
+
def select(self, currNode):
|
| 340 |
+
"""
|
| 341 |
+
Traverse the tree from the root node until reaching a legal leaf node
|
| 342 |
+
"""
|
| 343 |
+
while True:
|
| 344 |
+
currNode, nodeStatus = currNode.selectNode(self.curr_num_func)
|
| 345 |
+
if nodeStatus != 3:
|
| 346 |
+
return currNode, nodeStatus
|
| 347 |
+
|
| 348 |
+
    def expand(self, parentNode, eps=1e-5, checkSimilarity = True):
        """
        Sample unmasking steps from the pre-trained MDLM and add num_children
        partially unmasked sequences to the children of the parentNode.

        For each child: one reverse-diffusion step from the parent tokens, then
        a full rollout to a completely unmasked sequence. Valid peptides are
        scored and pushed into the Pareto front; the resulting rewards (minus a
        penalty proportional to the fraction of invalid children) are
        backpropagated from parentNode to the root.

        Args:
            parentNode: leaf node whose tokens are expanded.
            eps: smallest diffusion time used for the rollout schedule.
            checkSimilarity: currently unused in this body.
        """

        num_children = self.config.mcts.num_children
        # initialize child rewards that will be added to total rewards
        allChildReward = np.zeros_like(parentNode.totalReward) # (n_objectives)


        # compute number of rollout steps
        # if parentNode.timestep = self.num_steps then num_rollout_steps = 1
        num_rollout_steps = self.num_steps - parentNode.timestep
        # array of rollout timesteps from the timestep of parent node to 0
        rollout_t = torch.linspace(1, eps, num_rollout_steps, device=self.device)
        dt = (1 - eps) / self.num_steps
        p_x0_cache = None

        # initialize x and attn_mask
        x = parentNode.tokens['input_ids'].to(self.device)
        attn_mask = parentNode.tokens['attention_mask'].to(self.device)

        t = rollout_t[0] * torch.ones(num_children, 1, device = self.device)
        # generate (n_children, seq_length) array of sampled children nodes
        print("token array:")
        print(x)
        # single reverse step batched over num_children copies of the parent;
        # x_children are the one-step (still partially masked) child states
        p_x0_cache, x_children = self.mdlm.batch_cached_reverse_step(token_array=x,
                                                                    t=t, dt=dt,
                                                                    batch_size=num_children,
                                                                    attn_mask=attn_mask)
        x_rollout = x_children

        # roll each child forward to a fully unmasked sequence
        for i in range(1, num_rollout_steps):
            t = rollout_t[i] * torch.ones(num_children, 1, device = self.device)

            p_x0_cache, x_next = self.mdlm.cached_reverse_step(x=x_rollout,
                                                               t=t, dt=dt, p_x0=p_x0_cache,
                                                               attn_mask=attn_mask)

            if (not torch.allclose(x_next, x) or self.time_conditioning):
                # Disable caching
                p_x0_cache = None

            x_rollout = x_next

        if self.config.sampling.noise_removal:
            # final denoising pass: greedily replace any remaining noise tokens
            t = rollout_t[-1] * torch.ones(x.shape[0], 1, device=self.device)

            time_cond = self.noise(t)[0]
            x_rollout = self.mdlm.forward(x_rollout, attn_mask, time_cond).argmax(dim=-1) # (n_children, seq_length)

        childSequences = self.tokenizer.batch_decode(x_rollout)

        validSequences = []
        maskedTokens = []    # one-step child states for valid peptides
        unmaskedTokens = []  # fully rolled-out states for valid peptides
        for i in range(num_children):
            childSeq = childSequences[i]
            #scoreVector = scoreVectors[i]
            rewardVector = np.zeros(self.config.mcts.num_objectives)

            # check if the peptide is valid
            if self.analyzer.is_peptide(childSeq):
                validSequences.append(childSeq)
                maskedTokens.append(x_children[i])
                unmaskedTokens.append(x_rollout[i])
            else:
                # invalid peptide: still added as a child, but with zero reward
                childTokens = {'input_ids': x_children[i], 'attention_mask': attn_mask}
                parentNode.addChildNode(tokens=childTokens,
                                        totalReward=rewardVector)

        if (len(validSequences) != 0):
            # score all valid peptides in one batch; scoreVectors is indexed
            # [sequence, objective]
            scoreVectors = self.score_functions(input_seqs=validSequences)
            average_scores = scoreVectors.T
            for i, name in enumerate(self.score_func_names):
                self.score_logs[name].append(average_scores[i])
        else:
            for name in self.score_func_names:
                self.score_logs[name].append(np.zeros(0))

        for i, validSeq in enumerate(validSequences):
            #tokens = validTokens[i]
            scoreVector = scoreVectors[i]

            # update pareto front
            rewardVector = self.updateParetoFront(validSeq, scoreVector, unmaskedTokens[i])
            print(scoreVector)
            print(rewardVector)

            # add to all child reward vector for backprop
            allChildReward += rewardVector

            # create node for sequence and add to the children node of parent
            childTokens = {'input_ids': maskedTokens[i], 'attention_mask': attn_mask}
            parentNode.addChildNode(tokens=childTokens,
                                    totalReward=rewardVector)

        # compute fraction of invalid child sequences
        # NOTE(review): assumes num_children > 0 — division by zero otherwise
        invalid = (num_children - len(validSequences)) / num_children

        valid_fraction = len(validSequences) / num_children
        print(f"Valid fraction: {valid_fraction}")
        self.valid_fraction_log.append(valid_fraction)

        print(self.config.mcts.invalid_penalty)
        # subtract score using fraction of invalid sequences from reward
        allChildReward = allChildReward - (self.config.mcts.invalid_penalty * invalid)
        # backpropogate all child rewards
        self.backprop(parentNode, allChildReward)
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def backprop(self, node, reward_vector):
|
| 461 |
+
# backpropogate rewards through the path leading to the leaf node from the root
|
| 462 |
+
while node:
|
| 463 |
+
node.updateNode(reward_vector)
|
| 464 |
+
node = node.parentNode
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
def getSequenceForObjective(self, objective_index, k):
|
| 468 |
+
"""
|
| 469 |
+
Returns the top-k sequences in the pareto front that has the best score for
|
| 470 |
+
a given objective and their score vectors for all objectives
|
| 471 |
+
"""
|
| 472 |
+
|
| 473 |
+
# dictionary of top-k peptides for the objective
|
| 474 |
+
topk = {}
|
| 475 |
+
|
| 476 |
+
peptides = []
|
| 477 |
+
objectiveScores = []
|
| 478 |
+
for k, v in self.peptideParetoFront.items():
|
| 479 |
+
# store peptides in list
|
| 480 |
+
peptides.append(k)
|
| 481 |
+
# store score for objective
|
| 482 |
+
objectiveScores.append(v['token_ids'][objective_index])
|
| 483 |
+
|
| 484 |
+
objectiveScores = torch.tensor(objectiveScores)
|
| 485 |
+
topKScores = torch.topk(objectiveScores, k)
|
| 486 |
+
for (_, index) in topKScores.items():
|
| 487 |
+
seq = peptides[index]
|
| 488 |
+
|
| 489 |
+
topk[seq] = self.peptideParetoFront.get(seq)
|
| 490 |
+
|
| 491 |
+
return topk
|
| 492 |
+
|
src/roformer.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import RoFormerConfig, RoFormerForMaskedLM
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
class Roformer(nn.Module):
    """
    Thin wrapper around HuggingFace RoFormerForMaskedLM configured from the
    project config, with helpers to freeze/unfreeze layers and save/load.
    """

    def __init__(self, config, tokenizer):
        super().__init__()

        self.tokenizer = tokenizer
        self.vocab_size = self.tokenizer.vocab_size
        # NOTE(review): device is chosen here regardless of config — confirm
        # multi-GPU setups wrap this module externally (see DDP import).
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        roformer_config = RoFormerConfig(
            vocab_size=self.tokenizer.vocab_size,
            embedding_size=config.roformer.hidden_size,
            hidden_size=config.roformer.hidden_size,
            num_hidden_layers=config.roformer.n_layers,
            num_attention_heads=config.roformer.n_heads,
            intermediate_size=config.roformer.hidden_size * 4,
            max_position_embeddings=config.roformer.max_position_embeddings,
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            pad_token_id=0,
            rotary_value=False
        )

        self.model = RoFormerForMaskedLM(roformer_config).to(self.device)

    def freeze_model(self):
        """Disable gradients for every parameter of the underlying model."""
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze_all_layers(self):
        """Re-enable gradients for every parameter of the underlying model."""
        for param in self.model.parameters():
            param.requires_grad = True

    def unfreeze_n_layers(self, n):
        """
        Unfreeze the attention query/key weights of the final ``n`` encoder layers.

        BUG FIX: the layer count was hard-coded to 8; derive it from the actual
        encoder so this works for any configured depth.
        """
        num_layers = len(self.model.roformer.encoder.layer)

        for i, layer in enumerate(self.model.roformer.encoder.layer):
            # finetune final n layers
            if i >= num_layers - n:
                # unfreeze query weights
                for module in layer.attention.self.query.modules():
                    for param in module.parameters():
                        param.requires_grad = True
                # unfreeze key weights
                for module in layer.attention.self.key.modules():
                    for param in module.parameters():
                        param.requires_grad = True

    def forward(self, input_ids, attn_mask):
        """Run a masked-LM forward pass and return the raw logits tensor."""
        input_ids = input_ids.to(self.device)
        attn_mask = attn_mask.to(self.device)

        # get logits embeddings
        logits = self.model(input_ids=input_ids, attention_mask=attn_mask)
        return logits.logits

    def save_model(self, save_dir):
        """Persist both the model weights and the tokenizer to ``save_dir``."""
        self.model.save_pretrained(save_dir)
        self.tokenizer.save_pretrained(save_dir)

    @classmethod
    def load_model(cls, save_dir, config, tokenizer):
        """Alternate constructor: build a wrapper and load pretrained weights from ``save_dir``."""
        roformer = cls(config, tokenizer)
        roformer.model = RoFormerForMaskedLM.from_pretrained(save_dir)
        return roformer
|
src/scoring/functions/binding.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os, torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import esm
|
| 8 |
+
from transformers import AutoModelForMaskedLM
|
| 9 |
+
|
| 10 |
+
class ImprovedBindingPredictor(nn.Module):
    """
    Protein–peptide binding predictor with bidirectional cross-attention.

    Projects ESM protein embeddings and SMILES (PeptideCLM) embeddings into a
    shared space, applies n_layers of reciprocal cross-attention, mean-pools
    both sides, and emits a regression score plus 3-way binding-class logits.
    """

    def __init__(self,
                 esm_dim=1280,      # ESM-2 650M hidden size
                 smiles_dim=768,    # PeptideCLM hidden size
                 hidden_dim=512,
                 n_heads=8,
                 n_layers=3,
                 dropout=0.1):
        super().__init__()

        # Define binding thresholds (in -log10 affinity units)
        self.tight_threshold = 7.5  # Kd/Ki/IC50 ≤ ~30nM
        self.weak_threshold = 6.0   # Kd/Ki/IC50 > 1μM

        # Project to same dimension
        self.smiles_projection = nn.Linear(smiles_dim, hidden_dim)
        self.protein_projection = nn.Linear(esm_dim, hidden_dim)
        self.protein_norm = nn.LayerNorm(hidden_dim)
        self.smiles_norm = nn.LayerNorm(hidden_dim)

        # Cross attention blocks with layer norm
        self.cross_attention_layers = nn.ModuleList([
            nn.ModuleDict({
                'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
                'norm1': nn.LayerNorm(hidden_dim),
                'ffn': nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim * 4),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim * 4, hidden_dim)
                ),
                'norm2': nn.LayerNorm(hidden_dim)
            }) for _ in range(n_layers)
        ])

        # Prediction heads
        self.shared_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Regression head
        self.regression_head = nn.Linear(hidden_dim, 1)

        # Classification head (3 classes: tight, medium, loose binding)
        self.classification_head = nn.Linear(hidden_dim, 3)

    def get_binding_class(self, affinity):
        """Convert affinity values to class indices
        0: tight binding (>= 7.5)
        1: medium binding (6.0-7.5)
        2: weak binding (< 6.0)

        Accepts either a torch tensor (vectorized) or a scalar.
        """
        if isinstance(affinity, torch.Tensor):
            tight_mask = affinity >= self.tight_threshold
            weak_mask = affinity < self.weak_threshold
            medium_mask = ~(tight_mask | weak_mask)

            classes = torch.zeros_like(affinity, dtype=torch.long)
            classes[medium_mask] = 1
            classes[weak_mask] = 2
            return classes
        else:
            if affinity >= self.tight_threshold:
                return 0  # tight binding
            elif affinity < self.weak_threshold:
                return 2  # weak binding
            else:
                return 1  # medium binding

    def forward(self, protein_emb, smiles_emb):
        """
        Args:
            protein_emb: ESM embeddings; smiles_emb: PeptideCLM embeddings.
            NOTE(review): pooling below is over dim=0 and the transposes are
            commented out, so inputs are assumed (seq, dim) without an explicit
            batch axis — confirm against the caller (BindingAffinity passes
            mean-pooled (1, dim) tensors).

        Returns:
            (regression_output, classification_logits)
        """
        protein = self.protein_norm(self.protein_projection(protein_emb))
        smiles = self.smiles_norm(self.smiles_projection(smiles_emb))

        #protein = protein.transpose(0, 1)
        #smiles = smiles.transpose(0, 1)

        # Cross attention layers
        for layer in self.cross_attention_layers:
            # Protein attending to SMILES
            attended_protein = layer['attention'](
                protein, smiles, smiles
            )[0]
            protein = layer['norm1'](protein + attended_protein)
            protein = layer['norm2'](protein + layer['ffn'](protein))

            # SMILES attending to protein
            attended_smiles = layer['attention'](
                smiles, protein, protein
            )[0]
            smiles = layer['norm1'](smiles + attended_smiles)
            smiles = layer['norm2'](smiles + layer['ffn'](smiles))

        # Get sequence-level representations
        protein_pool = torch.mean(protein, dim=0)
        smiles_pool = torch.mean(smiles, dim=0)

        # Concatenate both representations
        combined = torch.cat([protein_pool, smiles_pool], dim=-1)

        # Shared features
        shared_features = self.shared_head(combined)

        regression_output = self.regression_head(shared_features)
        classification_logits = self.classification_head(shared_features)

        return regression_output, classification_logits
|
| 118 |
+
|
| 119 |
+
class BindingAffinity:
    """
    Scoring function: predicted binding affinity of peptide SMILES sequences
    against a fixed protein target.

    The target protein is embedded once with ESM-2 at construction time; each
    call embeds the peptides with PeptideCLM and runs ImprovedBindingPredictor.
    """

    def __init__(self, prot_seq, tokenizer, base_path, device=None, emb_model=None):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device

        # peptide embeddings: use the provided encoder or load PeptideCLM
        if emb_model is not None:
            self.pep_model = emb_model.to(self.device).eval()
        else:
            self.pep_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()

        self.pep_tokenizer = tokenizer

        # binding predictor with pretrained weights
        self.model = ImprovedBindingPredictor().to(self.device)
        checkpoint = torch.load(f'{base_path}/src/scoring/functions/classifiers/binding-affinity.pt',
                                map_location=self.device,
                                weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])

        self.model.eval()

        self.esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()  # load ESM-2 model
        self.esm_model = self.esm_model.to(self.device).eval()
        self.prot_tokenizer = alphabet.get_batch_converter()  # load esm tokenizer

        data = [("target", prot_seq)]
        # get tokenized protein
        _, _, prot_tokens = self.prot_tokenizer(data)
        prot_tokens = prot_tokens.to(self.device)
        with torch.no_grad():
            results = self.esm_model.forward(prot_tokens, repr_layers=[33])  # Example with ESM-2
            prot_emb = results["representations"][33]

        # cache the mean-pooled target embedding: (1, esm_dim)
        self.prot_emb = prot_emb[0].to(self.device)
        self.prot_emb = torch.mean(self.prot_emb, dim=0, keepdim=True)


    def forward(self, input_seqs):
        """
        Score each SMILES sequence in ``input_seqs``.

        Returns:
            list of float regression scores, one per input sequence.
        """
        with torch.no_grad():
            scores = []
            for seq in input_seqs:
                pep_tokens = self.pep_tokenizer(seq, return_tensors='pt', padding=True)

                pep_tokens = {k: v.to(self.device) for k, v in pep_tokens.items()}

                # NOTE(review): nested no_grad is redundant (already inside the
                # outer no_grad block)
                with torch.no_grad():
                    emb = self.pep_model(input_ids=pep_tokens['input_ids'],
                                         attention_mask=pep_tokens['attention_mask'],
                                         output_hidden_states=True)

                #emb = self.pep_model(input_ids=pep_tokens['input_ids'], attention_mask=pep_tokens['attention_mask'])
                # mean-pool token embeddings to (1, smiles_dim)
                pep_emb = emb.last_hidden_state.squeeze(0)
                pep_emb = torch.mean(pep_emb, dim=0, keepdim=True)

                score, logits = self.model.forward(self.prot_emb, pep_emb)
                scores.append(score.item())
        return scores

    def __call__(self, input_seqs: list):
        return self.forward(input_seqs)
|
src/scoring/functions/binding_utils.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
import torch
# BUG FIX: forward() in both attention classes calls F.softmax, but F was
# never imported in this file (NameError at runtime)
import torch.nn.functional as F
from torch import nn
|
| 4 |
+
|
| 5 |
+
def to_var(x):
    """Move *x* onto the GPU when CUDA is available; otherwise return it unchanged."""
    return x.cuda() if torch.cuda.is_available() else x
|
| 9 |
+
|
| 10 |
+
class MultiHeadAttentionSequence(nn.Module):
    """
    Multi-head scaled dot-product attention over a sequence, with dropout,
    residual connection to the query input, and layer normalization.

    NOTE(review): forward() uses ``F.softmax`` — requires
    ``import torch.nn.functional as F`` at the top of this file.
    """

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):

        super().__init__()

        self.n_head = n_head    # number of attention heads
        self.d_model = d_model  # model (input/output) width
        self.d_k = d_k          # per-head key/query width
        self.d_v = d_v          # per-head value width

        # projections into concatenated head spaces
        self.W_Q = nn.Linear(d_model, n_head*d_k)
        self.W_K = nn.Linear(d_model, n_head*d_k)
        self.W_V = nn.Linear(d_model, n_head*d_v)
        # output projection back to d_model
        self.W_O = nn.Linear(n_head*d_v, d_model)

        self.layer_norm = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v):
        """
        Args:
            q, k, v: (batch, len, d_model) tensors.

        Returns:
            output: (batch, len_q, d_model) after attention + residual + norm.
            attention: (batch, n_head, len_q, len_k) softmaxed attention weights.
        """
        batch, len_q, _ = q.size()
        batch, len_k, _ = k.size()
        batch, len_v, _ = v.size()

        # project and split into heads: (batch, len, n_head, d_head)
        Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k])
        K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k])
        V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v])

        # rearrange for batched matmul: Q (b, h, lq, dk), K (b, h, dk, lk)
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2).transpose(2, 3)
        V = V.transpose(1, 2)

        attention = torch.matmul(Q, K)

        # scale by sqrt(d_k)
        attention = attention / np.sqrt(self.d_k)

        attention = F.softmax(attention, dim=-1)

        output = torch.matmul(attention, V)

        # merge heads: (batch, len_q, n_head * d_v)
        output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head])

        output = self.W_O(output)

        output = self.dropout(output)

        # residual connection to the query input, then layer norm
        output = self.layer_norm(output + q)

        return output, attention
|
| 61 |
+
|
| 62 |
+
class MultiHeadAttentionReciprocal(nn.Module):
    """
    Reciprocal multi-head attention: computes attention scores once between q
    and k, then uses the softmax of the scores to attend over ``v`` (q-side
    output) and the softmax of the transposed scores to attend over ``v_2``
    (k-side output).

    NOTE(review): forward() uses ``F.softmax`` — requires
    ``import torch.nn.functional as F`` at the top of this file.
    NOTE(review): ``self.dropout_2`` and ``self.layer_norm_2`` are created but
    never used — forward() reuses ``self.dropout``/``self.layer_norm`` for the
    second output. Possibly a bug, but changing it would alter checkpoint
    behavior; confirm intent before fixing.
    """

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):

        super().__init__()

        self.n_head = n_head    # number of attention heads
        self.d_model = d_model  # model (input/output) width
        self.d_k = d_k          # per-head key/query width
        self.d_v = d_v          # per-head value width

        self.W_Q = nn.Linear(d_model, n_head*d_k)
        self.W_K = nn.Linear(d_model, n_head*d_k)
        self.W_V = nn.Linear(d_model, n_head*d_v)
        self.W_O = nn.Linear(n_head*d_v, d_model)
        # second value/output projections for the reciprocal (k-side) path
        self.W_V_2 = nn.Linear(d_model, n_head*d_v)
        self.W_O_2 = nn.Linear(n_head*d_v, d_model)

        self.layer_norm = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

        self.layer_norm_2 = nn.LayerNorm(d_model)

        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, q, k, v, v_2):
        """
        Args:
            q, k: (batch, len, d_model) query/key sequences.
            v: values for the q-side output (same length as k).
            v_2: values for the k-side output (same length as q).

        Returns:
            output: q-side attended representation (residual to q).
            output_2: k-side attended representation (residual to k).
            attention, attention_2: the two softmaxed score tensors.
        """
        batch, len_q, _ = q.size()
        batch, len_k, _ = k.size()
        batch, len_v, _ = v.size()
        batch, len_v_2, _ = v_2.size()

        # project and split into heads
        Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k])
        K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k])
        V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v])
        V_2 = self.W_V_2(v_2).view([batch, len_v_2, self.n_head, self.d_v])

        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2).transpose(2, 3)
        V = V.transpose(1, 2)
        V_2 = V_2.transpose(1,2)

        # shared raw scores: (batch, n_head, len_q, len_k)
        attention = torch.matmul(Q, K)

        attention = attention /np.sqrt(self.d_k)

        # transposed scores drive the reciprocal (k-side) attention
        attention_2 = attention.transpose(-2, -1)

        attention = F.softmax(attention, dim=-1)

        attention_2 = F.softmax(attention_2, dim=-1)

        output = torch.matmul(attention, V)

        output_2 = torch.matmul(attention_2, V_2)

        # merge heads back
        output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head])

        output_2 = output_2.transpose(1, 2).reshape([batch, len_k, self.d_v*self.n_head])

        output = self.W_O(output)

        output_2 = self.W_O_2(output_2)

        output = self.dropout(output)

        # residual to q, then norm
        output = self.layer_norm(output + q)

        output_2 = self.dropout(output_2)

        # residual to k, then norm (reuses the first norm — see class NOTE)
        output_2 = self.layer_norm(output_2 + k)

        return output, output_2, attention, attention_2
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class FFN(nn.Module):
    """Position-wise feed-forward network (1x1 convolutions) with
    dropout, residual connection and LayerNorm."""

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()

        # 1x1 convolutions act as per-position linear layers.
        self.layer_1 = nn.Conv1d(d_in, d_hid, 1)
        self.layer_2 = nn.Conv1d(d_hid, d_in, 1)
        self.relu = nn.ReLU()
        self.layer_norm = nn.LayerNorm(d_in)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply FFN to x of shape (batch, len, d_in); same shape out."""
        # Conv1d expects channels first: (batch, d_in, len).
        hidden = self.relu(self.layer_1(x.transpose(1, 2)))
        projected = self.dropout(self.layer_2(hidden))
        # Back to (batch, len, d_in), then residual + LayerNorm.
        return self.layer_norm(projected.transpose(1, 2) + x)
|
| 170 |
+
|
| 171 |
+
class ConvLayer(nn.Module):
    """A single Conv1d followed by ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size, padding, dilation):
        super(ConvLayer, self).__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              padding=padding, dilation=dilation)
        self.relu = nn.ReLU()

    def forward(self, x):
        """x: (batch, in_channels, len) -> (batch, out_channels, len')."""
        return self.relu(self.conv(x))
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class DilatedCNN(nn.Module):
    """Three parallel stacks of dilated Conv1d layers with kernel sizes
    3, 5 and 7.

    Each stack applies three ConvLayers with dilations 1, 2, 3 and
    "same" padding (padding = (kernel_size - 1) // 2 * dilation), and the
    three stack outputs are summed.
    """

    def __init__(self, d_model, d_hidden):
        super(DilatedCNN, self).__init__()

        dilation_tuple = (1, 2, 3)
        dim_in_tuple = (d_model, d_hidden, d_hidden)
        dim_out_tuple = (d_hidden, d_hidden, d_hidden)

        def build_stack(kernel_size, pad_scale):
            # pad_scale = (kernel_size - 1) // 2 keeps sequence length fixed.
            return nn.ModuleList(
                ConvLayer(dim_in_tuple[i], dim_out_tuple[i],
                          kernel_size=kernel_size,
                          padding=pad_scale * dilation_rate,
                          dilation=dilation_rate)
                for i, dilation_rate in enumerate(dilation_tuple)
            )

        # NOTE: stacks are created in this order so parameter initialization
        # matches the original sequential construction.
        self.first_ = build_stack(3, 1)
        self.second_ = build_stack(5, 2)
        self.third_ = build_stack(7, 3)

    def forward(self, protein_seq_enc):
        """protein_seq_enc: (B, L, d_model) -> (B, L, d_hidden)."""
        # Conv1d wants channels first: (B, d_model, L).
        x = protein_seq_enc.transpose(1, 2)

        def run_stack(stack, inp):
            for layer in stack:
                inp = layer(inp)
            return inp

        combined = (run_stack(self.first_, x)
                    + run_stack(self.second_, x)
                    + run_stack(self.third_, x))

        # Back to (B, L, d_hidden).
        return combined.transpose(1, 2)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class ReciprocalLayerwithCNN(nn.Module):
    """Reciprocal attention block where the protein stream is first passed
    through a dilated CNN encoder.

    Pipeline: CNN(protein) -> self-attention on each stream ->
    reciprocal cross-attention between streams -> per-stream FFN.
    """

    def __init__(self, d_model, d_inner, d_hidden, n_head, d_k, d_v):
        super().__init__()

        self.cnn = DilatedCNN(d_model, d_hidden)
        self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, d_k, d_v)
        self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, d_k, d_v)
        self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_hidden, d_k, d_v)
        self.ffn_seq = FFN(d_hidden, d_inner)
        self.ffn_protein = FFN(d_hidden, d_inner)

    def forward(self, sequence_enc, protein_seq_enc):
        """Returns (prot_enc, seq_enc, prot_attention, sequence_attention,
        prot_seq_attention, seq_prot_attention)."""
        # Encode the protein stream with the dilated CNN first.
        prot_hidden = self.cnn(protein_seq_enc)

        # Independent self-attention on each stream.
        prot_enc, prot_attention = self.protein_attention_layer(
            prot_hidden, prot_hidden, prot_hidden)
        seq_enc, sequence_attention = self.sequence_attention_layer(
            sequence_enc, sequence_enc, sequence_enc)

        # Reciprocal cross-attention: protein queries peptide and vice versa.
        prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = \
            self.reciprocal_attention_layer(prot_enc, seq_enc, seq_enc, prot_enc)

        # Position-wise feed-forward on each stream.
        return (self.ffn_protein(prot_enc), self.ffn_seq(seq_enc),
                prot_attention, sequence_attention,
                prot_seq_attention, seq_prot_attention)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class ReciprocalLayer(nn.Module):
    """Reciprocal attention block without the CNN protein encoder:
    self-attention on each stream, reciprocal cross-attention, then a
    per-stream FFN."""

    def __init__(self, d_model, d_inner, n_head, d_k, d_v):
        super().__init__()

        self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_model, d_k, d_v)
        self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_model, d_k, d_v)
        self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_model, d_k, d_v)
        self.ffn_seq = FFN(d_model, d_inner)
        self.ffn_protein = FFN(d_model, d_inner)

    def forward(self, sequence_enc, protein_seq_enc):
        """Returns (prot_enc, seq_enc, prot_attention, sequence_attention,
        prot_seq_attention, seq_prot_attention)."""
        # Self-attention within each stream.
        prot_enc, prot_attention = self.protein_attention_layer(
            protein_seq_enc, protein_seq_enc, protein_seq_enc)
        seq_enc, sequence_attention = self.sequence_attention_layer(
            sequence_enc, sequence_enc, sequence_enc)

        # Cross-attention in both directions at once.
        prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = \
            self.reciprocal_attention_layer(prot_enc, seq_enc, seq_enc, prot_enc)

        # Per-stream feed-forward projection.
        return (self.ffn_protein(prot_enc), self.ffn_seq(seq_enc),
                prot_attention, sequence_attention,
                prot_seq_attention, seq_prot_attention)
|
src/scoring/functions/classifiers/hemolysis-xgboost.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
src/scoring/functions/classifiers/nonfouling-xgboost.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
src/scoring/functions/classifiers/permeability-xgboost.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7e5d8c84bdad75f7091b5b3963133d4b0ebd180ae45654618ca6c090eee0bc06
|
| 3 |
+
size 45249160
|
src/scoring/functions/classifiers/solubility-xgboost.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
src/scoring/functions/hemolysis.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import xgboost as xgb
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from transformers import AutoModelForMaskedLM
|
| 5 |
+
import warnings
|
| 6 |
+
import numpy as np
|
| 7 |
+
from rdkit import rdBase
|
| 8 |
+
|
| 9 |
+
rdBase.DisableLog('rdApp.error')
|
| 10 |
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
| 11 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 12 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 13 |
+
|
| 14 |
+
class Hemolysis:
    """Scores peptide SMILES strings by probability of being non-hemolytic.

    Sequences are embedded with a pretrained PeptideCLM RoFormer encoder
    (mean-pooled over tokens) and classified with a pretrained XGBoost model.
    """

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        """
        Args:
            tokenizer: callable tokenizer producing PyTorch tensors.
            base_path: repository root containing the classifier json file.
            device: torch device; defaults to CUDA when available.
            emb_model: optional pre-loaded embedding model shared across scorers.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.predictor = xgb.Booster(model_file=f'{base_path}/src/scoring/functions/classifiers/hemolysis-xgboost.json')
        # BUG FIX: was `.to(device)`, which is a no-op when device is None,
        # leaving the model on CPU while inputs are moved to self.device.
        self.emb_model = emb_model if emb_model is not None else AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()
        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Return mean-pooled encoder embeddings, shape (len(sequences), hidden)."""
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            with torch.no_grad():
                output = self.emb_model(**tokenized)
            # Mean pooling across sequence length.
            embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_scores(self, input_seqs: list):
        """Return 1 - P(hemolytic) per sequence (higher is safer)."""
        scores = np.ones(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        # Sanitize before handing to XGBoost (NaN -> 0, clamp to float32 range).
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        probs = self.predictor.predict(xgb.DMatrix(features))
        # Return the probability of being NOT hemolytic.
        return scores - probs

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
|
| 53 |
+
|
| 54 |
+
def unittest():
    """Manual smoke test for the Hemolysis scorer."""
    # NOTE(review): Hemolysis.__init__ requires `tokenizer` and `base_path`,
    # so this zero-argument call raises TypeError as written — confirm the
    # intended construction before relying on this smoke test.
    scorer = Hemolysis()
    smiles = ["[te]NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
    print(scorer.tokenizer.vocab_size)
    print(scorer(input_seqs=smiles))


if __name__ == '__main__':
    unittest()
|
src/scoring/functions/nonfouling.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import xgboost as xgb
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
from transformers import AutoModelForMaskedLM
|
| 7 |
+
import warnings
|
| 8 |
+
import numpy as np
|
| 9 |
+
from rdkit import Chem, rdBase, DataStructs
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
rdBase.DisableLog('rdApp.error')
|
| 13 |
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
| 14 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 15 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 16 |
+
|
| 17 |
+
class Nonfouling:
    """Scores peptide SMILES strings by predicted non-fouling probability.

    Sequences are embedded with a pretrained PeptideCLM RoFormer encoder
    (mean-pooled over tokens) and classified with a pretrained XGBoost model.
    """

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        """
        Args:
            tokenizer: callable tokenizer producing PyTorch tensors.
            base_path: repository root containing the classifier json file.
            device: torch device; defaults to CUDA when available.
            emb_model: optional pre-loaded embedding model shared across scorers.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.predictor = xgb.Booster(model_file=f'{base_path}/src/scoring/functions/classifiers/nonfouling-xgboost.json')
        # BUG FIX: was `.to(device)`, a no-op when device is None, leaving the
        # model on CPU while inputs are moved to self.device.
        self.emb_model = emb_model if emb_model is not None else AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()
        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Return mean-pooled encoder embeddings, shape (len(sequences), hidden)."""
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            with torch.no_grad():
                output = self.emb_model(**tokenized)
            # Mean pooling across sequence length.
            embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_scores(self, input_seqs: list):
        """Return the classifier's predicted non-fouling probability per sequence."""
        scores = np.zeros(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        # Sanitize before handing to XGBoost (NaN -> 0, clamp to float32 range).
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        # FIXED COMMENT: this is P(non-fouling), not a hemolysis score (the
        # previous comment was copy-pasted from the Hemolysis scorer).
        scores = self.predictor.predict(xgb.DMatrix(features))
        return scores

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
|
| 56 |
+
|
| 57 |
+
def unittest():
    """Manual smoke test for the Nonfouling scorer."""
    # NOTE(review): Nonfouling.__init__ requires `tokenizer` and `base_path`,
    # so this zero-argument call raises TypeError as written — confirm the
    # intended construction before relying on this smoke test.
    scorer = Nonfouling()
    smiles = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
    print(scorer(input_seqs=smiles))


if __name__ == '__main__':
    unittest()
|
src/scoring/functions/permeability.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import xgboost as xgb
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
from transformers import AutoModelForMaskedLM
|
| 7 |
+
import warnings
|
| 8 |
+
import numpy as np
|
| 9 |
+
from rdkit.Chem import Descriptors, rdMolDescriptors
|
| 10 |
+
from rdkit import Chem, rdBase, DataStructs
|
| 11 |
+
from rdkit.Chem import AllChem
|
| 12 |
+
from typing import List
|
| 13 |
+
from transformers import AutoModelForMaskedLM
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
rdBase.DisableLog('rdApp.error')
|
| 17 |
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
| 18 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 19 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 20 |
+
|
| 21 |
+
def fingerprints_from_smiles(smiles: List, size=2048):
    """Create ECFP fingerprints for a list of SMILES, with validity mask.

    Invalid SMILES contribute an all-zero fingerprint and a 0 in the mask.

    Returns:
        (fps, valid_mask): fps has shape (len(smiles), size); valid_mask is
        a parallel list of 0/1 ints.
    """
    fps = []
    valid_mask = []
    for smile in smiles:
        mol = Chem.MolFromSmiles(smile)
        valid_mask.append(int(mol is not None))
        fp = fingerprints_from_mol(mol, size=size) if mol else np.zeros((1, size))
        fps.append(fp)

    # BUG FIX: np.concatenate raises ValueError on an empty list; return an
    # empty (0, size) matrix instead (matches the scoring_utils.py copy).
    fps = np.concatenate(fps, axis=0) if fps else np.zeros((0, size))
    return fps, valid_mask
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def fingerprints_from_mol(molecule, radius=3, size=2048, hashed=False):
    """Create an ECFP (Morgan) fingerprint of a molecule as a (1, size) array."""
    if hashed:
        bit_vec = AllChem.GetHashedMorganFingerprint(molecule, radius, nBits=size)
    else:
        bit_vec = AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=size)
    arr = np.zeros((1,))
    DataStructs.ConvertToNumpyArray(bit_vec, arr)
    return arr.reshape(1, -1)
|
| 44 |
+
|
| 45 |
+
def getMolDescriptors(mol, missingVal=0):
    """Calculate the full RDKit descriptor list plus three Lipinski counts.

    Descriptors that raise for this molecule are replaced by ``missingVal``.

    Returns:
        (values, names): parallel lists of descriptor values and names.
    """
    custom_descriptors = {'hydrogen-bond donors': rdMolDescriptors.CalcNumLipinskiHBD,
                          'hydrogen-bond acceptors': rdMolDescriptors.CalcNumLipinskiHBA,
                          'rotatable bonds': rdMolDescriptors.CalcNumRotatableBonds,}

    values, names = [], []
    # Standard descriptors first, then the custom Lipinski counts, preserving
    # the original column order expected by the trained models.
    for nm, fn in list(Descriptors._descList) + list(custom_descriptors.items()):
        try:
            val = fn(mol)
        except Exception:  # narrowed from bare `except:` (kept best-effort)
            val = missingVal
        values.append(val)
        names.append(nm)
    return values, names
|
| 69 |
+
|
| 70 |
+
def get_pep_dps_from_smi(smi):
    """Compute the descriptor vector for one SMILES string.

    An unparseable SMILES yields mol=None; getMolDescriptors then falls
    back to missing values for every descriptor.
    """
    try:
        mol = Chem.MolFromSmiles(smi)
    except Exception:  # narrowed from bare `except:`; MolFromSmiles mostly returns None on failure
        print(f"convert smi {smi} to molecule failed!")
        mol = None

    dps, _ = getMolDescriptors(mol)
    return np.array(dps)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def get_pep_dps(smi_list):
    """Compute the descriptor matrix (len(smi_list), n_descriptors) for SMILES."""
    # BUG FIX: the descriptor width was hard-coded as 213 here and 211 in the
    # scoring_utils.py copy; it actually depends on the installed RDKit
    # version, so compute it: all standard descriptors + 3 Lipinski counts.
    n_features = len(Descriptors._descList) + 3
    if len(smi_list) == 0:
        return np.zeros((0, n_features))
    return np.array([get_pep_dps_from_smi(smi) for smi in smi_list])
|
| 85 |
+
|
| 86 |
+
def check_smi_validity(smiles: list):
    """Split SMILES into those RDKit can parse.

    Returns:
        (valid_smi, valid_idx): the parseable SMILES and their indices in
        the input list.
    """
    valid_smi, valid_idx = [], []
    for idx, smi in enumerate(smiles):
        try:
            mol = Chem.MolFromSmiles(smi) if smi else None
        except Exception:
            mol = None
        if mol:
            valid_smi.append(smi)
            valid_idx.append(idx)
    return valid_smi, valid_idx
|
| 98 |
+
|
| 99 |
+
class Permeability:
    """Predicts membrane-permeability scores for peptide SMILES strings.

    Features are mean-pooled PeptideCLM embeddings, optionally augmented
    with ECFP fingerprints and/or RDKit descriptors, fed to a pretrained
    XGBoost regressor.
    """

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        """
        Args:
            tokenizer: callable tokenizer producing PyTorch tensors.
            base_path: repository root containing the classifier json file.
            device: torch device; defaults to CUDA when available.
            emb_model: optional pre-loaded embedding model shared across scorers.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.predictor = xgb.Booster(model_file=f'{base_path}/src/scoring/functions/classifiers/permeability-xgboost.json')
        if emb_model is not None:
            self.emb_model = emb_model.to(self.device).eval()
        else:
            # BUG FIX: was `.to(device)`, a no-op when device is None, leaving
            # the model on CPU while inputs are moved to self.device.
            self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()

        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Return mean-pooled encoder embeddings, shape (len(sequences), hidden)."""
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            with torch.no_grad():
                output = self.emb_model(**tokenized)
            # Mean pooling across sequence length.
            embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_features(self, input_seqs: list, dps=False, fps=False):
        """Build the feature matrix: [fingerprints | descriptors | embeddings].

        Args:
            dps: include RDKit molecular descriptors.
            fps: include ECFP fingerprints.
        """
        if fps:
            fingerprints = fingerprints_from_smiles(input_seqs)[0]
        else:
            # FIX: use numpy (not torch) zero-width placeholders so that
            # np.concatenate works on a homogeneous list of arrays.
            fingerprints = np.zeros((len(input_seqs), 0))

        if dps:
            descriptors = get_pep_dps(input_seqs)
        else:
            descriptors = np.zeros((len(input_seqs), 0))

        embeddings = self.generate_embeddings(input_seqs)
        return np.concatenate([fingerprints, descriptors, embeddings], axis=1)

    def get_scores(self, input_seqs: list):
        """Return predicted permeability per sequence (-10 fallback when empty)."""
        scores = -10 * np.ones(len(input_seqs))
        features = self.get_features(input_seqs)

        if len(features) == 0:
            return scores

        # Sanitize before handing to XGBoost (NaN -> 0, clamp to float32 range).
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        scores = self.predictor.predict(xgb.DMatrix(features))
        return scores

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
|
| 162 |
+
|
| 163 |
+
def unittest():
    """Manual smoke test for the Permeability scorer."""
    # NOTE(review): Permeability.__init__ requires `tokenizer` and `base_path`,
    # so this zero-argument call raises TypeError as written — confirm the
    # intended construction before relying on this smoke test.
    scorer = Permeability()
    smiles = ['N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1cNc2c1cc(O)cc2)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1ccccc1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H]([C@@H](O)C(C)C)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@H](CC(=CN2)C1=C2C=CC=C1)C(=O)O']
    print(scorer(input_seqs=smiles))


if __name__ == '__main__':
    unittest()
|
src/scoring/functions/scoring_utils.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
import numpy as np
|
| 3 |
+
from loguru import logger
|
| 4 |
+
from sklearn.ensemble import RandomForestRegressor
|
| 5 |
+
from rdkit.Chem import Descriptors, rdMolDescriptors
|
| 6 |
+
import joblib
|
| 7 |
+
from rdkit import Chem, rdBase, DataStructs
|
| 8 |
+
from rdkit.Chem import AllChem
|
| 9 |
+
from typing import List
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def fingerprints_from_mol(molecule, radius=3, size=2048, hashed=False):
    """
    Create ECFP fingerprint of a molecule, returned as a (1, size) array.
    """
    make_fp = AllChem.GetHashedMorganFingerprint if hashed else AllChem.GetMorganFingerprintAsBitVect
    fp_bits = make_fp(molecule, radius, nBits=size)
    fp_np = np.zeros((1,))
    DataStructs.ConvertToNumpyArray(fp_bits, fp_np)
    return fp_np.reshape(1, -1)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def fingerprints_from_smiles(smiles: List, size=2048):
    """Create ECFP fingerprints of smiles, with validity check.

    Invalid SMILES yield an all-zero row and a 0 in the validity mask.
    Returns (fps, valid_mask) with fps of shape (len(smiles), size).
    """
    fps = []
    valid_mask = []
    for smile in smiles:
        mol = Chem.MolFromSmiles(smile)
        valid_mask.append(int(mol is not None))
        row = fingerprints_from_mol(mol, size=size) if mol else np.zeros((1, size))
        fps.append(row)

    # Empty input -> empty (0, size) matrix; np.concatenate would raise.
    if fps:
        return np.concatenate(fps, axis=0), valid_mask
    return np.zeros((0, size)), valid_mask
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def getMolDescriptors(mol, missingVal=0):
    """Calculate the full RDKit descriptor list plus three Lipinski counts.

    Descriptors that raise for this molecule are replaced by ``missingVal``.

    Returns:
        (values, names): parallel lists of descriptor values and names.
    """
    custom_descriptors = {'hydrogen-bond donors': rdMolDescriptors.CalcNumLipinskiHBD,
                          'hydrogen-bond acceptors': rdMolDescriptors.CalcNumLipinskiHBA,
                          'rotatable bonds': rdMolDescriptors.CalcNumRotatableBonds,}

    values, names = [], []
    # Standard descriptors first, then the custom Lipinski counts, preserving
    # the original column order expected by the trained models.
    for nm, fn in list(Descriptors._descList) + list(custom_descriptors.items()):
        try:
            val = fn(mol)
        except Exception:  # narrowed from bare `except:` (kept best-effort)
            val = missingVal
        values.append(val)
        names.append(nm)
    return values, names
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def get_pep_dps_from_smi(smi):
    """Compute the descriptor vector for one SMILES string.

    An unparseable SMILES yields mol=None; getMolDescriptors then falls
    back to missing values for every descriptor.
    """
    try:
        mol = Chem.MolFromSmiles(smi)
    except Exception:  # narrowed from bare `except:`; MolFromSmiles mostly returns None on failure
        print(f"convert smi {smi} to molecule failed!")
        mol = None

    dps, _ = getMolDescriptors(mol)
    return np.array(dps)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def get_pep_dps(smi_list):
    """Compute the descriptor matrix (len(smi_list), n_descriptors) for SMILES."""
    # BUG FIX: the descriptor width was hard-coded as 211 here and 213 in the
    # permeability.py copy; it actually depends on the installed RDKit
    # version, so compute it: all standard descriptors + 3 Lipinski counts.
    n_features = len(Descriptors._descList) + 3
    if len(smi_list) == 0:
        return np.zeros((0, n_features))
    return np.array([get_pep_dps_from_smi(smi) for smi in smi_list])
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def check_smi_validity(smiles: list):
    """Split SMILES into those RDKit can parse.

    Returns:
        (valid_smi, valid_idx): the parseable SMILES and their indices in
        the input list.
    """
    valid_smi, valid_idx = [], []
    for idx, smi in enumerate(smiles):
        try:
            mol = Chem.MolFromSmiles(smi) if smi else None
        except Exception:
            mol = None
        if mol:
            valid_smi.append(smi)
            valid_idx.append(idx)
    return valid_smi, valid_idx
|
src/scoring/functions/solubility.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import xgboost as xgb
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from transformers import AutoModelForMaskedLM
|
| 5 |
+
import warnings
|
| 6 |
+
import numpy as np
|
| 7 |
+
from rdkit import rdBase
|
| 8 |
+
|
| 9 |
+
rdBase.DisableLog('rdApp.error')
|
| 10 |
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
| 11 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 12 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 13 |
+
|
| 14 |
+
class Solubility:
    """Scores peptide SMILES strings for predicted solubility.

    Sequences are embedded with a pretrained PeptideCLM RoFormer encoder
    (mean-pooled over tokens) and scored with a pretrained XGBoost model.
    """

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        """
        Args:
            tokenizer: callable tokenizer producing PyTorch tensors.
            base_path: repository root containing the classifier json file.
            device: torch device; defaults to CUDA when available.
            emb_model: optional pre-loaded embedding model shared across scorers.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.predictor = xgb.Booster(model_file=f'{base_path}/src/scoring/functions/classifiers/solubility-xgboost.json')
        if emb_model is not None:
            self.emb_model = emb_model.to(self.device).eval()
        else:
            self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()

        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Return mean-pooled encoder embeddings, shape (len(sequences), hidden)."""
        pooled = []
        for seq in sequences:
            batch = self.tokenizer(seq, return_tensors='pt')
            batch = {key: tensor.to(self.device) for key, tensor in batch.items()}
            with torch.no_grad():
                output = self.emb_model(**batch)
            # Mean pooling across sequence length.
            pooled.append(output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy())
        return np.array(pooled)

    def get_scores(self, input_seqs: list):
        """Predict solubility scores; zeros when no embeddings are produced."""
        scores = np.zeros(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        # Sanitize before handing to XGBoost (NaN -> 0, clamp to float32 range).
        cleaned = np.clip(np.nan_to_num(features, nan=0.),
                          np.finfo(np.float32).min, np.finfo(np.float32).max)

        return self.predictor.predict(xgb.DMatrix(cleaned))

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
|
| 55 |
+
|
| 56 |
+
def unittest():
    """Manual smoke test for the Solubility scorer."""
    # NOTE(review): Solubility.__init__ requires `tokenizer` and `base_path`,
    # so this zero-argument call raises TypeError as written — confirm the
    # intended construction before relying on this smoke test.
    scorer = Solubility()
    smiles = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
    print(scorer(input_seqs=smiles))

if __name__ == '__main__':
    unittest()
|
src/scoring/scoring_functions.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 2 |
+
from transformers import AutoModelForMaskedLM
|
| 3 |
+
import numpy as np
|
| 4 |
+
from scoring.functions.binding import BindingAffinity
|
| 5 |
+
from scoring.functions.permeability import Permeability
|
| 6 |
+
from scoring.functions.solubility import Solubility
|
| 7 |
+
from scoring.functions.hemolysis import Hemolysis
|
| 8 |
+
from scoring.functions.nonfouling import Nonfouling
|
| 9 |
+
|
| 10 |
+
base_path = '/path/to/your/home'
|
| 11 |
+
|
| 12 |
+
class ScoringFunctions:
    """Evaluates a panel of peptide property scorers over generated SMILES.

    Builds one shared embedding model plus the individual scoring functions
    (binding affinity, permeability, solubility, nonfouling, hemolysis) and
    returns a score matrix when called on a batch of sequences.
    """

    def __init__(self, score_func_names=None, prot_seqs=None, device=None):
        """
        Class for generating score vectors given generated sequences.

        Args:
            score_func_names: list of scoring function names to be evaluated;
                None means no property scoring (unmasking based on peptide-bond
                validity only).
            prot_seqs: sequences of target protein binders (0, 1 or 2 entries);
                controls how many BindingAffinity scorers are built.
            device: torch device for the embedding model and classifiers.
        """
        # Shared sequence encoder used by every property classifier below.
        emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(device).eval()
        # NOTE(review): the repo layout ships these files under
        # src/scoring/tokenizer/, not src/scoring/functions/tokenizer/ —
        # confirm these paths against the checkout.
        tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/src/scoring/functions/tokenizer/new_vocab.txt',
                                         f'{base_path}/src/scoring/functions/tokenizer/new_splits.txt')
        prot_seqs = prot_seqs if prot_seqs is not None else []

        # Normalize once; with no names given we just do unmasking based on
        # validity of peptide bonds downstream.
        self.score_func_names = list(score_func_names) if score_func_names is not None else []

        # binding affinities
        self.target_protein = prot_seqs

        # BUGFIX: the original tested membership on the raw `score_func_names`
        # argument, which raised TypeError when it was None; use the sanitized
        # list instead.
        names = self.score_func_names
        if ('binding_affinity1' in names) and (len(prot_seqs) == 1):
            binding_affinity1 = BindingAffinity(prot_seqs[0], tokenizer=tokenizer, base_path=base_path, device=device)
            binding_affinity2 = None
        elif ('binding_affinity1' in names) and ('binding_affinity2' in names) and (len(prot_seqs) == 2):
            binding_affinity1 = BindingAffinity(prot_seqs[0], tokenizer=tokenizer, base_path=base_path, device=device)
            binding_affinity2 = BindingAffinity(prot_seqs[1], tokenizer=tokenizer, base_path=base_path, device=device)
        else:
            binding_affinity1 = None
            binding_affinity2 = None

        permeability = Permeability(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
        sol = Solubility(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
        nonfouling = Nonfouling(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
        hemo = Hemolysis(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)

        # Registry of every available scorer; only those listed in
        # self.score_func_names are evaluated in forward().
        self.all_funcs = {'binding_affinity1': binding_affinity1,
                          'binding_affinity2': binding_affinity2,
                          'permeability': permeability,
                          'nonfouling': nonfouling,
                          'solubility': sol,
                          'hemolysis': hemo
                          }

    def forward(self, input_seqs):
        """Evaluate every configured scoring function on the input batch.

        Args:
            input_seqs: list of SMILES strings.

        Returns:
            numpy array of shape (num_sequences, num_functions), float32.
        """
        scores = [self.all_funcs[name](input_seqs=input_seqs)
                  for name in self.score_func_names]
        # convert to numpy array with shape (num_sequences, num_functions)
        return np.asarray(scores, dtype=np.float32).T

    def __call__(self, input_seqs: list):
        return self.forward(input_seqs)
|
src/scoring/tokenizer/my_tokenizers.py
ADDED
|
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
from typing import List, Optional
|
| 5 |
+
from transformers import PreTrainedTokenizer
|
| 6 |
+
from SmilesPE.tokenizer import SPE_Tokenizer
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
def load_vocab(vocab_file):
    """Read a vocabulary file (one token per line) into an ordered token->index map."""
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        for index, line in enumerate(reader):
            # Strip only the trailing newline so tokens keep internal whitespace.
            vocab[line.rstrip("\n")] = index
    return vocab
|
| 18 |
+
|
| 19 |
+
class Atomwise_Tokenizer(object):
    """Atom-level SMILES tokenizer driven by a single regular expression."""

    def __init__(self):
        """Compile the atom-level SMILES pattern.

        The pattern recognizes, in priority order: short parenthesized groups
        (up to 4 inner characters), bracket atoms, two-letter halogens,
        single-letter atoms, and bond/branch/ring symbols.
        """
        self.regex_pattern = r"(\([^\(\)]{0,4}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/\/?|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
        self.regex = re.compile(self.regex_pattern)

    def tokenize(self, text):
        """Split a SMILES string into atom-level tokens."""
        return list(self.regex.findall(text))
|
| 35 |
+
|
| 36 |
+
class SMILES_SPE_Tokenizer(PreTrainedTokenizer):
    r"""
    Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
    should refer to the superclass for more information regarding methods.
    Args:
        vocab_file (:obj:`string`):
            File containing the vocabulary.
        spe_file (:obj:`string`):
            File containing the trained SMILES Pair Encoding vocabulary.
        unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
            for sequence classification or for a text and a question for question answering.
            It is also used as the last token of a sequence built with special tokens.
        pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
            The classifier token which is used when doing sequence classification (classification of the whole
            sequence instead of per-token classification). It is the first token of the sequence when built with
            special tokens.
        mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
    """

    def __init__(self, vocab_file, spe_file,
                 unk_token="[UNK]",
                 sep_token="[SEP]",
                 pad_token="[PAD]",
                 cls_token="[CLS]",
                 mask_token="[MASK]",
                 **kwargs):
        if not os.path.isfile(vocab_file):
            raise ValueError("Can't find a vocabulary file at path '{}'.".format(vocab_file))
        if not os.path.isfile(spe_file):
            raise ValueError("Can't find a SPE vocabulary file at path '{}'.".format(spe_file))

        # Token->id map and its inverse must exist before super().__init__,
        # which may query vocab_size / get_vocab.
        self.vocab = load_vocab(vocab_file)
        # NOTE(review): this handle stays open for the tokenizer's lifetime and
        # is never closed — confirm SPE_Tokenizer consumes it eagerly so it
        # could be closed after construction.
        self.spe_vocab = open(spe_file, 'r', encoding='utf-8')
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        self.spe_tokenizer = SPE_Tokenizer(self.spe_vocab)

        super().__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs)

    @property
    def vocab_size(self):
        """Number of tokens in the base vocabulary (excluding added tokens)."""
        return len(self.vocab)

    def get_vocab(self):
        """Return the full token->id mapping including added tokens."""
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        # SPE_Tokenizer returns a space-separated string of merged tokens.
        return self.spe_tokenizer.tokenize(text).split(' ')

    def _convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab. """
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    # changed encode and decode functions
    def encode(self, token_array):
        """Encode a pre-tokenized sequence to ids wrapped in [CLS] ... [SEP].

        Args:
            token_array: iterable of string tokens (already SPE-split).

        Returns:
            dict with 'input_ids' and 'attention_mask' tensors of shape (1, L).
        """
        # For the shipped vocab, cls_token_id == 2 and sep_token_id == 3
        # (the values the original hard-coded).
        token_ids = [self.cls_token_id]
        for token in token_array:
            token_ids.append(self._convert_token_to_id(token))
        token_ids.append(self.sep_token_id)
        token_ids = torch.tensor([token_ids])
        attn_mask = torch.ones_like(token_ids)
        return {'input_ids': token_ids, 'attention_mask': attn_mask}

    def decode(self, token_ids, skip_special_tokens=True):
        """Decode a (1, L) or (L,) id tensor back to a SMILES string.

        Decoding truncates at the first [SEP] id; special tokens are dropped
        when ``skip_special_tokens`` is set.
        """
        # assumes token_ids is a torch tensor — TODO confirm callers never pass lists
        token_ids = token_ids.squeeze(0).cpu().tolist()
        token_array = []
        for idx in token_ids:
            if idx == self.sep_token_id:  # Stop decoding at the first [SEP] (id 3 in the shipped vocab)
                break
            if skip_special_tokens and idx in self.all_special_ids:
                continue
            token_array.append(self._convert_id_to_token(idx))
        return "".join(token_array)

    def batch_decode(self, batch_token_ids, skip_special_tokens=True):
        """Decode a batch of id tensors.

        BUGFIX: forwards ``skip_special_tokens`` to decode() — the original
        silently dropped it and always used the default.
        """
        return [self.decode(token_ids, skip_special_tokens=skip_special_tokens)
                for token_ids in batch_token_ids]

    def get_token_split(self, token_ids):
        """Map a batch of id sequences to their string tokens (no joining).

        Args:
            token_ids: tensor or nested list of shape (batch, length).

        Returns:
            list of lists of token strings, same shape as the input.
        """
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.cpu().tolist()

        token_array = []
        for seq_ids in token_ids:
            seq_array = [self._convert_id_to_token(token_id) for token_id in seq_ids]
            token_array.append(seq_array)

        return token_array

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string. """
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks
        by concatenating and adding special tokens.
        A BERT sequence has the following format:
        - single sequence: ``[CLS] X [SEP]``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` method.
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Set to True if the token list is already formatted with special tokens for the model
        Returns:
            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formated with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
        A BERT sequence pair mask has the following format:
        ::
            0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
            | first sequence | second sequence |
        if token_ids_1 is None, only returns the first portion of the mask (0's).
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
            sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, vocab_path):
        """
        Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
        Args:
            vocab_path (:obj:`str`):
                The directory in which to save the vocabulary.
        Returns:
            :obj:`Tuple(str)`: Paths to the files saved.
        """
        index = 0
        vocab_file = vocab_path
        with open(vocab_file, "w", encoding="utf-8") as writer:
            # Write tokens ordered by id; the index check mirrors the BERT
            # writer and silently tolerates non-consecutive indices.
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)
|
| 252 |
+
|
| 253 |
+
class SMILES_Atomwise_Tokenizer(PreTrainedTokenizer):
    r"""
    Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
    should refer to the superclass for more information regarding methods.
    Args:
        vocab_file (:obj:`string`):
            File containing the vocabulary.
        unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
            for sequence classification or for a text and a question for question answering.
            It is also used as the last token of a sequence built with special tokens.
        pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
            The classifier token which is used when doing sequence classification (classification of the whole
            sequence instead of per-token classification). It is the first token of the sequence when built with
            special tokens.
        mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
    """

    def __init__(
        self,
        vocab_file,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        **kwargs
    ):
        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'.".format(vocab_file)
            )
        # BUGFIX/consistency: set up the vocab BEFORE calling super().__init__,
        # matching SMILES_SPE_Tokenizer above — PreTrainedTokenizer.__init__
        # queries get_vocab()/vocab_size, which read self.vocab.
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        self.tokenizer = Atomwise_Tokenizer()

        super().__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Number of tokens in the base vocabulary (excluding added tokens)."""
        return len(self.vocab)

    def get_vocab(self):
        """Return the full token->id mapping including added tokens."""
        return dict(self.vocab, **self.added_tokens_encoder)


    def _tokenize(self, text):
        # Atom-level regex split (no pair-encoding merges).
        return self.tokenizer.tokenize(text)

    def _convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab. """
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string. """
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks
        by concatenating and adding special tokens.
        A BERT sequence has the following format:
        - single sequence: ``[CLS] X [SEP]``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` method.
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Set to True if the token list is already formatted with special tokens for the model
        Returns:
            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formated with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
        A BERT sequence pair mask has the following format:
        ::
            0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
            | first sequence | second sequence |
        if token_ids_1 is None, only returns the first portion of the mask (0's).
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
            sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, vocab_path):
        """
        Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
        Args:
            vocab_path (:obj:`str`):
                The directory in which to save the vocabulary.
        Returns:
            :obj:`Tuple(str)`: Paths to the files saved.
        """
        index = 0
        vocab_file = vocab_path
        with open(vocab_file, "w", encoding="utf-8") as writer:
            # Write tokens ordered by id; the index check mirrors the BERT
            # writer and silently tolerates non-consecutive indices.
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)
|
src/scoring/tokenizer/new_splits.txt
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
c 1
|
| 2 |
+
c 2
|
| 3 |
+
c 3
|
| 4 |
+
c 4
|
| 5 |
+
c 5
|
| 6 |
+
c 6
|
| 7 |
+
c 7
|
| 8 |
+
c 8
|
| 9 |
+
c 9
|
| 10 |
+
( c1
|
| 11 |
+
( c2
|
| 12 |
+
c1 )
|
| 13 |
+
c2 )
|
| 14 |
+
n 1
|
| 15 |
+
n 2
|
| 16 |
+
n 3
|
| 17 |
+
n 4
|
| 18 |
+
n 5
|
| 19 |
+
n 6
|
| 20 |
+
n 7
|
| 21 |
+
n 8
|
| 22 |
+
n 9
|
| 23 |
+
( n1
|
| 24 |
+
( n2
|
| 25 |
+
n1 )
|
| 26 |
+
n2 )
|
| 27 |
+
O 1
|
| 28 |
+
O 2
|
| 29 |
+
O 3
|
| 30 |
+
O 4
|
| 31 |
+
O 5
|
| 32 |
+
O 6
|
| 33 |
+
O 7
|
| 34 |
+
O 8
|
| 35 |
+
O 9
|
| 36 |
+
( O1
|
| 37 |
+
( O2
|
| 38 |
+
O2 )
|
| 39 |
+
O2 )
|
| 40 |
+
= O
|
| 41 |
+
= C
|
| 42 |
+
= c
|
| 43 |
+
= N
|
| 44 |
+
= n
|
| 45 |
+
=C C
|
| 46 |
+
=C N
|
| 47 |
+
=C c
|
| 48 |
+
=c c
|
| 49 |
+
=N C
|
| 50 |
+
=N c
|
| 51 |
+
=n C
|
| 52 |
+
=n c
|
| 53 |
+
# N
|
| 54 |
+
# C
|
| 55 |
+
#N C
|
| 56 |
+
#C C
|
| 57 |
+
#C N
|
| 58 |
+
#N N
|
| 59 |
+
( C
|
| 60 |
+
C )
|
| 61 |
+
( O
|
| 62 |
+
O )
|
| 63 |
+
( N
|
| 64 |
+
N )
|
| 65 |
+
Br c
|
| 66 |
+
( =O
|
| 67 |
+
(=O )
|
| 68 |
+
C (=O)
|
| 69 |
+
C =O
|
| 70 |
+
C =N
|
| 71 |
+
C #N
|
| 72 |
+
C #C
|
| 73 |
+
C C
|
| 74 |
+
CC C
|
| 75 |
+
CC N
|
| 76 |
+
CC O
|
| 77 |
+
CC S
|
| 78 |
+
CC c
|
| 79 |
+
CC n
|
| 80 |
+
C N
|
| 81 |
+
CN C
|
| 82 |
+
CN c
|
| 83 |
+
C O
|
| 84 |
+
CO C
|
| 85 |
+
CO N
|
| 86 |
+
CO c
|
| 87 |
+
C S
|
| 88 |
+
CS C
|
| 89 |
+
CS S
|
| 90 |
+
CS c
|
| 91 |
+
C c
|
| 92 |
+
Cl c
|
| 93 |
+
C n
|
| 94 |
+
F c
|
| 95 |
+
N C
|
| 96 |
+
NC C
|
| 97 |
+
NC c
|
| 98 |
+
N N
|
| 99 |
+
N O
|
| 100 |
+
N c
|
| 101 |
+
N n
|
| 102 |
+
O C
|
| 103 |
+
OC C
|
| 104 |
+
OC O
|
| 105 |
+
OC c
|
| 106 |
+
O N
|
| 107 |
+
O O
|
| 108 |
+
O c
|
| 109 |
+
S C
|
| 110 |
+
SC C
|
| 111 |
+
SC c
|
| 112 |
+
S S
|
| 113 |
+
S c
|
| 114 |
+
c c
|
| 115 |
+
cc c
|
| 116 |
+
cc n
|
| 117 |
+
cc o
|
| 118 |
+
cc s
|
| 119 |
+
cc cc
|
| 120 |
+
c n
|
| 121 |
+
cn c
|
| 122 |
+
cn n
|
| 123 |
+
c o
|
| 124 |
+
co c
|
| 125 |
+
c s
|
| 126 |
+
cs c
|
| 127 |
+
cs n
|
| 128 |
+
n c
|
| 129 |
+
nc c
|
| 130 |
+
nc n
|
| 131 |
+
nc o
|
| 132 |
+
nc s
|
| 133 |
+
n n
|
| 134 |
+
nn c
|
| 135 |
+
nn n
|
| 136 |
+
n o
|
| 137 |
+
no c
|
| 138 |
+
no n
|
| 139 |
+
n s
|
| 140 |
+
ns c
|
| 141 |
+
ns n
|
| 142 |
+
o c
|
| 143 |
+
oc c
|
| 144 |
+
o n
|
| 145 |
+
s c
|
| 146 |
+
sc c
|
| 147 |
+
sc n
|
| 148 |
+
s n
|
| 149 |
+
N P
|
| 150 |
+
P N
|
| 151 |
+
C P
|
| 152 |
+
P C
|
| 153 |
+
N S
|
| 154 |
+
S N
|
| 155 |
+
C S
|
| 156 |
+
S C
|
| 157 |
+
S P
|
| 158 |
+
P S
|
| 159 |
+
C I
|
src/scoring/tokenizer/new_vocab.txt
ADDED
|
@@ -0,0 +1,587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[PAD]
|
| 2 |
+
[UNK]
|
| 3 |
+
[CLS]
|
| 4 |
+
[SEP]
|
| 5 |
+
[MASK]
|
| 6 |
+
#
|
| 7 |
+
%
|
| 8 |
+
(
|
| 9 |
+
)
|
| 10 |
+
+
|
| 11 |
+
-
|
| 12 |
+
/
|
| 13 |
+
0
|
| 14 |
+
1
|
| 15 |
+
2
|
| 16 |
+
3
|
| 17 |
+
4
|
| 18 |
+
5
|
| 19 |
+
6
|
| 20 |
+
7
|
| 21 |
+
8
|
| 22 |
+
9
|
| 23 |
+
=
|
| 24 |
+
@
|
| 25 |
+
A
|
| 26 |
+
B
|
| 27 |
+
Br
|
| 28 |
+
Brc
|
| 29 |
+
C
|
| 30 |
+
CC
|
| 31 |
+
CCC
|
| 32 |
+
CCN
|
| 33 |
+
CCO
|
| 34 |
+
CCS
|
| 35 |
+
CCc
|
| 36 |
+
CCn
|
| 37 |
+
CN
|
| 38 |
+
CNC
|
| 39 |
+
CNc
|
| 40 |
+
CO
|
| 41 |
+
COC
|
| 42 |
+
CON
|
| 43 |
+
COc
|
| 44 |
+
CS
|
| 45 |
+
CSC
|
| 46 |
+
CSS
|
| 47 |
+
CSc
|
| 48 |
+
Cc
|
| 49 |
+
Cl
|
| 50 |
+
Clc
|
| 51 |
+
Cn
|
| 52 |
+
F
|
| 53 |
+
Fc
|
| 54 |
+
H
|
| 55 |
+
I
|
| 56 |
+
K
|
| 57 |
+
L
|
| 58 |
+
M
|
| 59 |
+
N
|
| 60 |
+
NC
|
| 61 |
+
NCC
|
| 62 |
+
NCc
|
| 63 |
+
NN
|
| 64 |
+
NO
|
| 65 |
+
Nc
|
| 66 |
+
Nn
|
| 67 |
+
O
|
| 68 |
+
OC
|
| 69 |
+
OCC
|
| 70 |
+
OCO
|
| 71 |
+
OCc
|
| 72 |
+
ON
|
| 73 |
+
OO
|
| 74 |
+
Oc
|
| 75 |
+
P
|
| 76 |
+
R
|
| 77 |
+
S
|
| 78 |
+
SC
|
| 79 |
+
SCC
|
| 80 |
+
SCc
|
| 81 |
+
SS
|
| 82 |
+
Sc
|
| 83 |
+
T
|
| 84 |
+
X
|
| 85 |
+
Z
|
| 86 |
+
[
|
| 87 |
+
\\
|
| 88 |
+
(/
|
| 89 |
+
]
|
| 90 |
+
a
|
| 91 |
+
b
|
| 92 |
+
c
|
| 93 |
+
cc
|
| 94 |
+
ccc
|
| 95 |
+
cccc
|
| 96 |
+
ccn
|
| 97 |
+
cco
|
| 98 |
+
ccs
|
| 99 |
+
cn
|
| 100 |
+
cnc
|
| 101 |
+
cnn
|
| 102 |
+
co
|
| 103 |
+
coc
|
| 104 |
+
cs
|
| 105 |
+
csc
|
| 106 |
+
csn
|
| 107 |
+
e
|
| 108 |
+
g
|
| 109 |
+
i
|
| 110 |
+
l
|
| 111 |
+
n
|
| 112 |
+
nc
|
| 113 |
+
ncc
|
| 114 |
+
ncn
|
| 115 |
+
nco
|
| 116 |
+
ncs
|
| 117 |
+
nn
|
| 118 |
+
nnc
|
| 119 |
+
nnn
|
| 120 |
+
no
|
| 121 |
+
noc
|
| 122 |
+
non
|
| 123 |
+
ns
|
| 124 |
+
nsc
|
| 125 |
+
nsn
|
| 126 |
+
o
|
| 127 |
+
oc
|
| 128 |
+
occ
|
| 129 |
+
on
|
| 130 |
+
p
|
| 131 |
+
r
|
| 132 |
+
s
|
| 133 |
+
sc
|
| 134 |
+
scc
|
| 135 |
+
scn
|
| 136 |
+
sn
|
| 137 |
+
t
|
| 138 |
+
c1
|
| 139 |
+
c2
|
| 140 |
+
c3
|
| 141 |
+
c4
|
| 142 |
+
c5
|
| 143 |
+
c6
|
| 144 |
+
c7
|
| 145 |
+
c8
|
| 146 |
+
c9
|
| 147 |
+
n1
|
| 148 |
+
n2
|
| 149 |
+
n3
|
| 150 |
+
n4
|
| 151 |
+
n5
|
| 152 |
+
n6
|
| 153 |
+
n7
|
| 154 |
+
n8
|
| 155 |
+
n9
|
| 156 |
+
O1
|
| 157 |
+
O2
|
| 158 |
+
O3
|
| 159 |
+
O4
|
| 160 |
+
O5
|
| 161 |
+
O6
|
| 162 |
+
O7
|
| 163 |
+
O8
|
| 164 |
+
O9
|
| 165 |
+
(c1
|
| 166 |
+
(c2
|
| 167 |
+
c1)
|
| 168 |
+
c2)
|
| 169 |
+
(n1
|
| 170 |
+
(n2
|
| 171 |
+
n1)
|
| 172 |
+
n2)
|
| 173 |
+
(O1
|
| 174 |
+
(O2
|
| 175 |
+
O2)
|
| 176 |
+
=O
|
| 177 |
+
=C
|
| 178 |
+
=c
|
| 179 |
+
=N
|
| 180 |
+
=n
|
| 181 |
+
=CC
|
| 182 |
+
=CN
|
| 183 |
+
=Cc
|
| 184 |
+
=cc
|
| 185 |
+
=NC
|
| 186 |
+
=Nc
|
| 187 |
+
=nC
|
| 188 |
+
=nc
|
| 189 |
+
#C
|
| 190 |
+
#CC
|
| 191 |
+
#CN
|
| 192 |
+
#N
|
| 193 |
+
#NC
|
| 194 |
+
#NN
|
| 195 |
+
(C
|
| 196 |
+
C)
|
| 197 |
+
(O
|
| 198 |
+
O)
|
| 199 |
+
(N
|
| 200 |
+
N)
|
| 201 |
+
NP
|
| 202 |
+
PN
|
| 203 |
+
CP
|
| 204 |
+
PC
|
| 205 |
+
NS
|
| 206 |
+
SN
|
| 207 |
+
SP
|
| 208 |
+
PS
|
| 209 |
+
C(=O)
|
| 210 |
+
(/Br)
|
| 211 |
+
(/C#N)
|
| 212 |
+
(/C)
|
| 213 |
+
(/C=N)
|
| 214 |
+
(/C=O)
|
| 215 |
+
(/CBr)
|
| 216 |
+
(/CC)
|
| 217 |
+
(/CCC)
|
| 218 |
+
(/CCF)
|
| 219 |
+
(/CCN)
|
| 220 |
+
(/CCO)
|
| 221 |
+
(/CCl)
|
| 222 |
+
(/CI)
|
| 223 |
+
(/CN)
|
| 224 |
+
(/CO)
|
| 225 |
+
(/CS)
|
| 226 |
+
(/Cl)
|
| 227 |
+
(/F)
|
| 228 |
+
(/I)
|
| 229 |
+
(/N)
|
| 230 |
+
(/NC)
|
| 231 |
+
(/NCC)
|
| 232 |
+
(/NO)
|
| 233 |
+
(/O)
|
| 234 |
+
(/OC)
|
| 235 |
+
(/OCC)
|
| 236 |
+
(/S)
|
| 237 |
+
(/SC)
|
| 238 |
+
(=C)
|
| 239 |
+
(=C/C)
|
| 240 |
+
(=C/F)
|
| 241 |
+
(=C/I)
|
| 242 |
+
(=C/N)
|
| 243 |
+
(=C/O)
|
| 244 |
+
(=CBr)
|
| 245 |
+
(=CC)
|
| 246 |
+
(=CCF)
|
| 247 |
+
(=CCN)
|
| 248 |
+
(=CCO)
|
| 249 |
+
(=CCl)
|
| 250 |
+
(=CF)
|
| 251 |
+
(=CI)
|
| 252 |
+
(=CN)
|
| 253 |
+
(=CO)
|
| 254 |
+
(=C\\C)
|
| 255 |
+
(=C\\F)
|
| 256 |
+
(=C\\I)
|
| 257 |
+
(=C\\N)
|
| 258 |
+
(=C\\O)
|
| 259 |
+
(=N)
|
| 260 |
+
(=N/C)
|
| 261 |
+
(=N/N)
|
| 262 |
+
(=N/O)
|
| 263 |
+
(=NBr)
|
| 264 |
+
(=NC)
|
| 265 |
+
(=NCC)
|
| 266 |
+
(=NCl)
|
| 267 |
+
(=NN)
|
| 268 |
+
(=NO)
|
| 269 |
+
(=NOC)
|
| 270 |
+
(=N\\C)
|
| 271 |
+
(=N\\N)
|
| 272 |
+
(=N\\O)
|
| 273 |
+
(=O)
|
| 274 |
+
(=S)
|
| 275 |
+
(B)
|
| 276 |
+
(Br)
|
| 277 |
+
(C#C)
|
| 278 |
+
(C#CC)
|
| 279 |
+
(C#CI)
|
| 280 |
+
(C#CO)
|
| 281 |
+
(C#N)
|
| 282 |
+
(C#SN)
|
| 283 |
+
(C)
|
| 284 |
+
(C=C)
|
| 285 |
+
(C=CF)
|
| 286 |
+
(C=CI)
|
| 287 |
+
(C=N)
|
| 288 |
+
(C=NN)
|
| 289 |
+
(C=NO)
|
| 290 |
+
(C=O)
|
| 291 |
+
(C=S)
|
| 292 |
+
(CBr)
|
| 293 |
+
(CC#C)
|
| 294 |
+
(CC#N)
|
| 295 |
+
(CC)
|
| 296 |
+
(CC=C)
|
| 297 |
+
(CC=O)
|
| 298 |
+
(CCBr)
|
| 299 |
+
(CCC)
|
| 300 |
+
(CCCC)
|
| 301 |
+
(CCCF)
|
| 302 |
+
(CCCI)
|
| 303 |
+
(CCCN)
|
| 304 |
+
(CCCO)
|
| 305 |
+
(CCCS)
|
| 306 |
+
(CCCl)
|
| 307 |
+
(CCF)
|
| 308 |
+
(CCI)
|
| 309 |
+
(CCN)
|
| 310 |
+
(CCNC)
|
| 311 |
+
(CCNN)
|
| 312 |
+
(CCNO)
|
| 313 |
+
(CCO)
|
| 314 |
+
(CCOC)
|
| 315 |
+
(CCON)
|
| 316 |
+
(CCS)
|
| 317 |
+
(CCSC)
|
| 318 |
+
(CCl)
|
| 319 |
+
(CF)
|
| 320 |
+
(CI)
|
| 321 |
+
(CN)
|
| 322 |
+
(CN=O)
|
| 323 |
+
(CNC)
|
| 324 |
+
(CNCC)
|
| 325 |
+
(CNCO)
|
| 326 |
+
(CNN)
|
| 327 |
+
(CNNC)
|
| 328 |
+
(CNO)
|
| 329 |
+
(CNOC)
|
| 330 |
+
(CO)
|
| 331 |
+
(COC)
|
| 332 |
+
(COCC)
|
| 333 |
+
(COCI)
|
| 334 |
+
(COCN)
|
| 335 |
+
(COCO)
|
| 336 |
+
(COF)
|
| 337 |
+
(CON)
|
| 338 |
+
(COO)
|
| 339 |
+
(CS)
|
| 340 |
+
(CSC)
|
| 341 |
+
(CSCC)
|
| 342 |
+
(CSCF)
|
| 343 |
+
(CSO)
|
| 344 |
+
(Cl)
|
| 345 |
+
(F)
|
| 346 |
+
(I)
|
| 347 |
+
(N)
|
| 348 |
+
(N=N)
|
| 349 |
+
(N=NO)
|
| 350 |
+
(N=O)
|
| 351 |
+
(N=S)
|
| 352 |
+
(NBr)
|
| 353 |
+
(NC#N)
|
| 354 |
+
(NC)
|
| 355 |
+
(NC=N)
|
| 356 |
+
(NC=O)
|
| 357 |
+
(NC=S)
|
| 358 |
+
(NCBr)
|
| 359 |
+
(NCC)
|
| 360 |
+
(NCCC)
|
| 361 |
+
(NCCF)
|
| 362 |
+
(NCCN)
|
| 363 |
+
(NCCO)
|
| 364 |
+
(NCCS)
|
| 365 |
+
(NCCl)
|
| 366 |
+
(NCNC)
|
| 367 |
+
(NCO)
|
| 368 |
+
(NCS)
|
| 369 |
+
(NCl)
|
| 370 |
+
(NN)
|
| 371 |
+
(NN=O)
|
| 372 |
+
(NNC)
|
| 373 |
+
(NO)
|
| 374 |
+
(NOC)
|
| 375 |
+
(O)
|
| 376 |
+
(OC#N)
|
| 377 |
+
(OC)
|
| 378 |
+
(OC=C)
|
| 379 |
+
(OC=O)
|
| 380 |
+
(OC=S)
|
| 381 |
+
(OCBr)
|
| 382 |
+
(OCC)
|
| 383 |
+
(OCCC)
|
| 384 |
+
(OCCF)
|
| 385 |
+
(OCCI)
|
| 386 |
+
(OCCN)
|
| 387 |
+
(OCCO)
|
| 388 |
+
(OCCS)
|
| 389 |
+
(OCCl)
|
| 390 |
+
(OCF)
|
| 391 |
+
(OCI)
|
| 392 |
+
(OCO)
|
| 393 |
+
(OCOC)
|
| 394 |
+
(OCON)
|
| 395 |
+
(OCSC)
|
| 396 |
+
(OCl)
|
| 397 |
+
(OI)
|
| 398 |
+
(ON)
|
| 399 |
+
(OO)
|
| 400 |
+
(OOC)
|
| 401 |
+
(OOCC)
|
| 402 |
+
(OOSN)
|
| 403 |
+
(OSC)
|
| 404 |
+
(P)
|
| 405 |
+
(S)
|
| 406 |
+
(SC#N)
|
| 407 |
+
(SC)
|
| 408 |
+
(SCC)
|
| 409 |
+
(SCCC)
|
| 410 |
+
(SCCF)
|
| 411 |
+
(SCCN)
|
| 412 |
+
(SCCO)
|
| 413 |
+
(SCCS)
|
| 414 |
+
(SCCl)
|
| 415 |
+
(SCF)
|
| 416 |
+
(SCN)
|
| 417 |
+
(SCOC)
|
| 418 |
+
(SCSC)
|
| 419 |
+
(SCl)
|
| 420 |
+
(SI)
|
| 421 |
+
(SN)
|
| 422 |
+
(SN=O)
|
| 423 |
+
(SO)
|
| 424 |
+
(SOC)
|
| 425 |
+
(SOOO)
|
| 426 |
+
(SS)
|
| 427 |
+
(SSC)
|
| 428 |
+
(SSCC)
|
| 429 |
+
([At])
|
| 430 |
+
([O-])
|
| 431 |
+
([O])
|
| 432 |
+
([S-])
|
| 433 |
+
(\\Br)
|
| 434 |
+
(\\C#N)
|
| 435 |
+
(\\C)
|
| 436 |
+
(\\C=N)
|
| 437 |
+
(\\C=O)
|
| 438 |
+
(\\CBr)
|
| 439 |
+
(\\CC)
|
| 440 |
+
(\\CCC)
|
| 441 |
+
(\\CCO)
|
| 442 |
+
(\\CCl)
|
| 443 |
+
(\\CF)
|
| 444 |
+
(\\CN)
|
| 445 |
+
(\\CNC)
|
| 446 |
+
(\\CO)
|
| 447 |
+
(\\COC)
|
| 448 |
+
(\\Cl)
|
| 449 |
+
(\\F)
|
| 450 |
+
(\\I)
|
| 451 |
+
(\\N)
|
| 452 |
+
(\\NC)
|
| 453 |
+
(\\NCC)
|
| 454 |
+
(\\NN)
|
| 455 |
+
(\\NO)
|
| 456 |
+
(\\NOC)
|
| 457 |
+
(\\O)
|
| 458 |
+
(\\OC)
|
| 459 |
+
(\\OCC)
|
| 460 |
+
(\\ON)
|
| 461 |
+
(\\S)
|
| 462 |
+
(\\SC)
|
| 463 |
+
(\\SCC)
|
| 464 |
+
[Ag+]
|
| 465 |
+
[Ag-4]
|
| 466 |
+
[Ag]
|
| 467 |
+
[Al-3]
|
| 468 |
+
[Al]
|
| 469 |
+
[As+]
|
| 470 |
+
[AsH3]
|
| 471 |
+
[AsH]
|
| 472 |
+
[As]
|
| 473 |
+
[At]
|
| 474 |
+
[B-]
|
| 475 |
+
[B@-]
|
| 476 |
+
[B@@-]
|
| 477 |
+
[BH-]
|
| 478 |
+
[BH2-]
|
| 479 |
+
[BH3-]
|
| 480 |
+
[B]
|
| 481 |
+
[Ba]
|
| 482 |
+
[Br+2]
|
| 483 |
+
[BrH]
|
| 484 |
+
[Br]
|
| 485 |
+
[C+]
|
| 486 |
+
[C-]
|
| 487 |
+
[C@@H]
|
| 488 |
+
[C@@]
|
| 489 |
+
[C@H]
|
| 490 |
+
[C@]
|
| 491 |
+
[CH-]
|
| 492 |
+
[CH2]
|
| 493 |
+
[CH3]
|
| 494 |
+
[CH]
|
| 495 |
+
[C]
|
| 496 |
+
[CaH2]
|
| 497 |
+
[Ca]
|
| 498 |
+
[Cl+2]
|
| 499 |
+
[Cl+3]
|
| 500 |
+
[Cl+]
|
| 501 |
+
[Cs]
|
| 502 |
+
[FH]
|
| 503 |
+
[F]
|
| 504 |
+
[H]
|
| 505 |
+
[He]
|
| 506 |
+
[I+2]
|
| 507 |
+
[I+3]
|
| 508 |
+
[I+]
|
| 509 |
+
[IH]
|
| 510 |
+
[I]
|
| 511 |
+
[K]
|
| 512 |
+
[Kr]
|
| 513 |
+
[Li+]
|
| 514 |
+
[LiH]
|
| 515 |
+
[MgH2]
|
| 516 |
+
[Mg]
|
| 517 |
+
[N+]
|
| 518 |
+
[N-]
|
| 519 |
+
[N@+]
|
| 520 |
+
[N@@+]
|
| 521 |
+
[N@@]
|
| 522 |
+
[N@]
|
| 523 |
+
[NH+]
|
| 524 |
+
[NH-]
|
| 525 |
+
[NH2+]
|
| 526 |
+
[NH3]
|
| 527 |
+
[NH]
|
| 528 |
+
[N]
|
| 529 |
+
[Na]
|
| 530 |
+
[O+]
|
| 531 |
+
[O-]
|
| 532 |
+
[OH+]
|
| 533 |
+
[OH2]
|
| 534 |
+
[OH]
|
| 535 |
+
[O]
|
| 536 |
+
[P+]
|
| 537 |
+
[P@+]
|
| 538 |
+
[P@@+]
|
| 539 |
+
[P@@]
|
| 540 |
+
[P@]
|
| 541 |
+
[PH2]
|
| 542 |
+
[PH]
|
| 543 |
+
[P]
|
| 544 |
+
[Ra]
|
| 545 |
+
[Rb]
|
| 546 |
+
[S+]
|
| 547 |
+
[S-]
|
| 548 |
+
[S@+]
|
| 549 |
+
[S@@+]
|
| 550 |
+
[S@@]
|
| 551 |
+
[S@]
|
| 552 |
+
[SH+]
|
| 553 |
+
[SH2]
|
| 554 |
+
[SH]
|
| 555 |
+
[S]
|
| 556 |
+
[Se+]
|
| 557 |
+
[Se-2]
|
| 558 |
+
[SeH2]
|
| 559 |
+
[SeH]
|
| 560 |
+
[Se]
|
| 561 |
+
[Si@]
|
| 562 |
+
[SiH2]
|
| 563 |
+
[SiH]
|
| 564 |
+
[Si]
|
| 565 |
+
[SrH2]
|
| 566 |
+
[TeH]
|
| 567 |
+
[Te]
|
| 568 |
+
[Xe]
|
| 569 |
+
[Zn+2]
|
| 570 |
+
[Zn-2]
|
| 571 |
+
[Zn]
|
| 572 |
+
[b-]
|
| 573 |
+
[c+]
|
| 574 |
+
[c-]
|
| 575 |
+
[cH-]
|
| 576 |
+
[cH]
|
| 577 |
+
[c]
|
| 578 |
+
[n+]
|
| 579 |
+
[n-]
|
| 580 |
+
[nH]
|
| 581 |
+
[n]
|
| 582 |
+
[o+]
|
| 583 |
+
[s+]
|
| 584 |
+
[se+]
|
| 585 |
+
[se]
|
| 586 |
+
[te+]
|
| 587 |
+
[te]
|
src/tokenizer/__init__.py
ADDED
|
File without changes
|
src/tokenizer/my_tokenizers.py
ADDED
|
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
import codecs
|
| 6 |
+
import unicodedata
|
| 7 |
+
from typing import List, Optional
|
| 8 |
+
from transformers import PreTrainedTokenizer
|
| 9 |
+
from SmilesPE.tokenizer import SPE_Tokenizer
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
def load_vocab(vocab_file):
|
| 13 |
+
"""Loads a vocabulary file into a dictionary."""
|
| 14 |
+
vocab = collections.OrderedDict()
|
| 15 |
+
with open(vocab_file, "r", encoding="utf-8") as reader:
|
| 16 |
+
tokens = reader.readlines()
|
| 17 |
+
for index, token in enumerate(tokens):
|
| 18 |
+
token = token.rstrip("\n")
|
| 19 |
+
vocab[token] = index
|
| 20 |
+
return vocab
|
| 21 |
+
|
| 22 |
+
class Atomwise_Tokenizer(object):
|
| 23 |
+
"""Run atom-level SMILES tokenization"""
|
| 24 |
+
|
| 25 |
+
def __init__(self):
|
| 26 |
+
""" Constructs a atom-level Tokenizer.
|
| 27 |
+
"""
|
| 28 |
+
# self.regex_pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
|
| 29 |
+
self.regex_pattern = r"(\([^\(\)]{0,4}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/\/?|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
|
| 30 |
+
|
| 31 |
+
self.regex = re.compile(self.regex_pattern)
|
| 32 |
+
|
| 33 |
+
def tokenize(self, text):
|
| 34 |
+
""" Basic Tokenization of a SMILES.
|
| 35 |
+
"""
|
| 36 |
+
tokens = [token for token in self.regex.findall(text)]
|
| 37 |
+
return tokens
|
| 38 |
+
|
| 39 |
+
class SMILES_SPE_Tokenizer(PreTrainedTokenizer):
|
| 40 |
+
r"""
|
| 41 |
+
Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
|
| 42 |
+
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
|
| 43 |
+
should refer to the superclass for more information regarding methods.
|
| 44 |
+
Args:
|
| 45 |
+
vocab_file (:obj:`string`):
|
| 46 |
+
File containing the vocabulary.
|
| 47 |
+
spe_file (:obj:`string`):
|
| 48 |
+
File containing the trained SMILES Pair Encoding vocabulary.
|
| 49 |
+
unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
|
| 50 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
| 51 |
+
token instead.
|
| 52 |
+
sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
|
| 53 |
+
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
|
| 54 |
+
for sequence classification or for a text and a question for question answering.
|
| 55 |
+
It is also used as the last token of a sequence built with special tokens.
|
| 56 |
+
pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
|
| 57 |
+
The token used for padding, for example when batching sequences of different lengths.
|
| 58 |
+
cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
|
| 59 |
+
The classifier token which is used when doing sequence classification (classification of the whole
|
| 60 |
+
sequence instead of per-token classification). It is the first token of the sequence when built with
|
| 61 |
+
special tokens.
|
| 62 |
+
mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
|
| 63 |
+
The token used for masking values. This is the token used when training this model with masked language
|
| 64 |
+
modeling. This is the token which the model will try to predict.
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
def __init__(self, vocab_file, spe_file,
|
| 68 |
+
unk_token="[UNK]",
|
| 69 |
+
sep_token="[SEP]",
|
| 70 |
+
pad_token="[PAD]",
|
| 71 |
+
cls_token="[CLS]",
|
| 72 |
+
mask_token="[MASK]",
|
| 73 |
+
**kwargs):
|
| 74 |
+
if not os.path.isfile(vocab_file):
|
| 75 |
+
raise ValueError("Can't find a vocabulary file at path '{}'.".format(vocab_file))
|
| 76 |
+
if not os.path.isfile(spe_file):
|
| 77 |
+
raise ValueError("Can't find a SPE vocabulary file at path '{}'.".format(spe_file))
|
| 78 |
+
|
| 79 |
+
self.vocab = load_vocab(vocab_file)
|
| 80 |
+
self.spe_vocab = open(spe_file, 'r', encoding='utf-8')
|
| 81 |
+
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
|
| 82 |
+
self.spe_tokenizer = SPE_Tokenizer(self.spe_vocab)
|
| 83 |
+
|
| 84 |
+
super().__init__(
|
| 85 |
+
unk_token=unk_token,
|
| 86 |
+
sep_token=sep_token,
|
| 87 |
+
pad_token=pad_token,
|
| 88 |
+
cls_token=cls_token,
|
| 89 |
+
mask_token=mask_token,
|
| 90 |
+
**kwargs)
|
| 91 |
+
|
| 92 |
+
@property
|
| 93 |
+
def vocab_size(self):
|
| 94 |
+
return len(self.vocab)
|
| 95 |
+
|
| 96 |
+
def get_vocab(self):
|
| 97 |
+
return dict(self.vocab, **self.added_tokens_encoder)
|
| 98 |
+
|
| 99 |
+
def _tokenize(self, text):
|
| 100 |
+
return self.spe_tokenizer.tokenize(text).split(' ')
|
| 101 |
+
|
| 102 |
+
def _convert_token_to_id(self, token):
|
| 103 |
+
""" Converts a token (str) in an id using the vocab. """
|
| 104 |
+
return self.vocab.get(token, self.vocab.get(self.unk_token))
|
| 105 |
+
|
| 106 |
+
# changed encode and decode functions
|
| 107 |
+
def encode(self, token_array):
|
| 108 |
+
token_ids = []
|
| 109 |
+
token_ids.append(2)
|
| 110 |
+
for token in token_array:
|
| 111 |
+
id = self._convert_token_to_id(token)
|
| 112 |
+
token_ids.append(id)
|
| 113 |
+
token_ids.append(3)
|
| 114 |
+
token_ids = torch.tensor([token_ids])
|
| 115 |
+
attn_mask = torch.ones_like(token_ids)
|
| 116 |
+
return {'input_ids': token_ids, 'attention_mask': attn_mask}
|
| 117 |
+
|
| 118 |
+
def decode(self, token_ids, skip_special_tokens=True):
|
| 119 |
+
token_ids = token_ids.squeeze(0).cpu().tolist()
|
| 120 |
+
token_array = []
|
| 121 |
+
for idx in token_ids:
|
| 122 |
+
if idx == 3: # Stop decoding when token ID 3 is encountered
|
| 123 |
+
break
|
| 124 |
+
if skip_special_tokens and idx in self.all_special_ids:
|
| 125 |
+
continue
|
| 126 |
+
token = self._convert_id_to_token(idx)
|
| 127 |
+
token_array.append(token)
|
| 128 |
+
sequence = "".join(token_array)
|
| 129 |
+
return sequence
|
| 130 |
+
|
| 131 |
+
def batch_decode(self, batch_token_ids, skip_special_tokens=True):
|
| 132 |
+
sequences = []
|
| 133 |
+
for token_ids in batch_token_ids:
|
| 134 |
+
sequences.append(self.decode(token_ids))
|
| 135 |
+
return sequences
|
| 136 |
+
|
| 137 |
+
def get_token_split(self, token_ids):
|
| 138 |
+
if isinstance(token_ids, torch.Tensor):
|
| 139 |
+
token_ids = token_ids.cpu().tolist()
|
| 140 |
+
|
| 141 |
+
token_array = []
|
| 142 |
+
for seq_ids in token_ids:
|
| 143 |
+
seq_array = []
|
| 144 |
+
for id in seq_ids:
|
| 145 |
+
token = self._convert_id_to_token(id)
|
| 146 |
+
seq_array.append(token)
|
| 147 |
+
token_array.append(seq_array)
|
| 148 |
+
|
| 149 |
+
return token_array
|
| 150 |
+
|
| 151 |
+
def _convert_id_to_token(self, index):
|
| 152 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
| 153 |
+
return self.ids_to_tokens.get(index, self.unk_token)
|
| 154 |
+
|
| 155 |
+
def convert_tokens_to_string(self, tokens):
|
| 156 |
+
""" Converts a sequence of tokens (string) in a single string. """
|
| 157 |
+
out_string = " ".join(tokens).replace(" ##", "").strip()
|
| 158 |
+
return out_string
|
| 159 |
+
|
| 160 |
+
def build_inputs_with_special_tokens(
|
| 161 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 162 |
+
) -> List[int]:
|
| 163 |
+
"""
|
| 164 |
+
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
|
| 165 |
+
by concatenating and adding special tokens.
|
| 166 |
+
A BERT sequence has the following format:
|
| 167 |
+
- single sequence: ``[CLS] X [SEP]``
|
| 168 |
+
- pair of sequences: ``[CLS] A [SEP] B [SEP]``
|
| 169 |
+
Args:
|
| 170 |
+
token_ids_0 (:obj:`List[int]`):
|
| 171 |
+
List of IDs to which the special tokens will be added
|
| 172 |
+
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
| 173 |
+
Optional second list of IDs for sequence pairs.
|
| 174 |
+
Returns:
|
| 175 |
+
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
| 176 |
+
"""
|
| 177 |
+
if token_ids_1 is None:
|
| 178 |
+
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
| 179 |
+
cls = [self.cls_token_id]
|
| 180 |
+
sep = [self.sep_token_id]
|
| 181 |
+
return cls + token_ids_0 + sep + token_ids_1 + sep
|
| 182 |
+
|
| 183 |
+
def get_special_tokens_mask(
|
| 184 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
| 185 |
+
) -> List[int]:
|
| 186 |
+
"""
|
| 187 |
+
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 188 |
+
special tokens using the tokenizer ``prepare_for_model`` method.
|
| 189 |
+
Args:
|
| 190 |
+
token_ids_0 (:obj:`List[int]`):
|
| 191 |
+
List of ids.
|
| 192 |
+
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
| 193 |
+
Optional second list of IDs for sequence pairs.
|
| 194 |
+
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
| 195 |
+
Set to True if the token list is already formatted with special tokens for the model
|
| 196 |
+
Returns:
|
| 197 |
+
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 198 |
+
"""
|
| 199 |
+
|
| 200 |
+
if already_has_special_tokens:
|
| 201 |
+
if token_ids_1 is not None:
|
| 202 |
+
raise ValueError(
|
| 203 |
+
"You should not supply a second sequence if the provided sequence of "
|
| 204 |
+
"ids is already formated with special tokens for the model."
|
| 205 |
+
)
|
| 206 |
+
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
|
| 207 |
+
|
| 208 |
+
if token_ids_1 is not None:
|
| 209 |
+
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
| 210 |
+
return [1] + ([0] * len(token_ids_0)) + [1]
|
| 211 |
+
|
| 212 |
+
def create_token_type_ids_from_sequences(
|
| 213 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 214 |
+
) -> List[int]:
|
| 215 |
+
"""
|
| 216 |
+
Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
|
| 217 |
+
A BERT sequence pair mask has the following format:
|
| 218 |
+
::
|
| 219 |
+
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
| 220 |
+
| first sequence | second sequence |
|
| 221 |
+
if token_ids_1 is None, only returns the first portion of the mask (0's).
|
| 222 |
+
Args:
|
| 223 |
+
token_ids_0 (:obj:`List[int]`):
|
| 224 |
+
List of ids.
|
| 225 |
+
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
| 226 |
+
Optional second list of IDs for sequence pairs.
|
| 227 |
+
Returns:
|
| 228 |
+
:obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
|
| 229 |
+
sequence(s).
|
| 230 |
+
"""
|
| 231 |
+
sep = [self.sep_token_id]
|
| 232 |
+
cls = [self.cls_token_id]
|
| 233 |
+
if token_ids_1 is None:
|
| 234 |
+
return len(cls + token_ids_0 + sep) * [0]
|
| 235 |
+
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
| 236 |
+
|
| 237 |
+
def save_vocabulary(self, vocab_path):
|
| 238 |
+
"""
|
| 239 |
+
Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
|
| 240 |
+
Args:
|
| 241 |
+
vocab_path (:obj:`str`):
|
| 242 |
+
The directory in which to save the vocabulary.
|
| 243 |
+
Returns:
|
| 244 |
+
:obj:`Tuple(str)`: Paths to the files saved.
|
| 245 |
+
"""
|
| 246 |
+
index = 0
|
| 247 |
+
if os.path.isdir(vocab_path):
|
| 248 |
+
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
|
| 249 |
+
else:
|
| 250 |
+
vocab_file = vocab_path
|
| 251 |
+
with open(vocab_file, "w", encoding="utf-8") as writer:
|
| 252 |
+
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
| 253 |
+
if index != token_index:
|
| 254 |
+
logger.warning(
|
| 255 |
+
"Saving vocabulary to {}: vocabulary indices are not consecutive."
|
| 256 |
+
" Please check that the vocabulary is not corrupted!".format(vocab_file)
|
| 257 |
+
)
|
| 258 |
+
index = token_index
|
| 259 |
+
writer.write(token + "\n")
|
| 260 |
+
index += 1
|
| 261 |
+
return (vocab_file,)
|
| 262 |
+
|
| 263 |
+
class SMILES_Atomwise_Tokenizer(PreTrainedTokenizer):
|
| 264 |
+
r"""
|
| 265 |
+
Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
|
| 266 |
+
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
|
| 267 |
+
should refer to the superclass for more information regarding methods.
|
| 268 |
+
Args:
|
| 269 |
+
vocab_file (:obj:`string`):
|
| 270 |
+
File containing the vocabulary.
|
| 271 |
+
unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
|
| 272 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
| 273 |
+
token instead.
|
| 274 |
+
sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
|
| 275 |
+
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
|
| 276 |
+
for sequence classification or for a text and a question for question answering.
|
| 277 |
+
It is also used as the last token of a sequence built with special tokens.
|
| 278 |
+
pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
|
| 279 |
+
The token used for padding, for example when batching sequences of different lengths.
|
| 280 |
+
cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
|
| 281 |
+
The classifier token which is used when doing sequence classification (classification of the whole
|
| 282 |
+
sequence instead of per-token classification). It is the first token of the sequence when built with
|
| 283 |
+
special tokens.
|
| 284 |
+
mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
|
| 285 |
+
The token used for masking values. This is the token used when training this model with masked language
|
| 286 |
+
modeling. This is the token which the model will try to predict.
|
| 287 |
+
"""
|
| 288 |
+
|
| 289 |
+
def __init__(
|
| 290 |
+
self,
|
| 291 |
+
vocab_file,
|
| 292 |
+
unk_token="[UNK]",
|
| 293 |
+
sep_token="[SEP]",
|
| 294 |
+
pad_token="[PAD]",
|
| 295 |
+
cls_token="[CLS]",
|
| 296 |
+
mask_token="[MASK]",
|
| 297 |
+
**kwargs
|
| 298 |
+
):
|
| 299 |
+
super().__init__(
|
| 300 |
+
unk_token=unk_token,
|
| 301 |
+
sep_token=sep_token,
|
| 302 |
+
pad_token=pad_token,
|
| 303 |
+
cls_token=cls_token,
|
| 304 |
+
mask_token=mask_token,
|
| 305 |
+
**kwargs,
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
if not os.path.isfile(vocab_file):
|
| 309 |
+
raise ValueError(
|
| 310 |
+
"Can't find a vocabulary file at path '{}'.".format(vocab_file)
|
| 311 |
+
)
|
| 312 |
+
self.vocab = load_vocab(vocab_file)
|
| 313 |
+
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
|
| 314 |
+
self.tokenizer = Atomwise_Tokenizer()
|
| 315 |
+
|
| 316 |
+
@property
|
| 317 |
+
def vocab_size(self):
|
| 318 |
+
return len(self.vocab)
|
| 319 |
+
|
| 320 |
+
def get_vocab(self):
|
| 321 |
+
return dict(self.vocab, **self.added_tokens_encoder)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def _tokenize(self, text):
|
| 325 |
+
return self.tokenizer.tokenize(text)
|
| 326 |
+
|
| 327 |
+
def _convert_token_to_id(self, token):
|
| 328 |
+
""" Converts a token (str) in an id using the vocab. """
|
| 329 |
+
return self.vocab.get(token, self.vocab.get(self.unk_token))
|
| 330 |
+
|
| 331 |
+
def _convert_id_to_token(self, index):
|
| 332 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
| 333 |
+
return self.ids_to_tokens.get(index, self.unk_token)
|
| 334 |
+
|
| 335 |
+
def convert_tokens_to_string(self, tokens):
|
| 336 |
+
""" Converts a sequence of tokens (string) in a single string. """
|
| 337 |
+
out_string = " ".join(tokens).replace(" ##", "").strip()
|
| 338 |
+
return out_string
|
| 339 |
+
|
| 340 |
+
def build_inputs_with_special_tokens(
|
| 341 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 342 |
+
) -> List[int]:
|
| 343 |
+
"""
|
| 344 |
+
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
|
| 345 |
+
by concatenating and adding special tokens.
|
| 346 |
+
A BERT sequence has the following format:
|
| 347 |
+
- single sequence: ``[CLS] X [SEP]``
|
| 348 |
+
- pair of sequences: ``[CLS] A [SEP] B [SEP]``
|
| 349 |
+
Args:
|
| 350 |
+
token_ids_0 (:obj:`List[int]`):
|
| 351 |
+
List of IDs to which the special tokens will be added
|
| 352 |
+
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
| 353 |
+
Optional second list of IDs for sequence pairs.
|
| 354 |
+
Returns:
|
| 355 |
+
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
| 356 |
+
"""
|
| 357 |
+
if token_ids_1 is None:
|
| 358 |
+
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
| 359 |
+
cls = [self.cls_token_id]
|
| 360 |
+
sep = [self.sep_token_id]
|
| 361 |
+
return cls + token_ids_0 + sep + token_ids_1 + sep
|
| 362 |
+
|
| 363 |
+
def get_special_tokens_mask(
|
| 364 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
| 365 |
+
) -> List[int]:
|
| 366 |
+
"""
|
| 367 |
+
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 368 |
+
special tokens using the tokenizer ``prepare_for_model`` method.
|
| 369 |
+
Args:
|
| 370 |
+
token_ids_0 (:obj:`List[int]`):
|
| 371 |
+
List of ids.
|
| 372 |
+
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
| 373 |
+
Optional second list of IDs for sequence pairs.
|
| 374 |
+
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
| 375 |
+
Set to True if the token list is already formatted with special tokens for the model
|
| 376 |
+
Returns:
|
| 377 |
+
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 378 |
+
"""
|
| 379 |
+
|
| 380 |
+
if already_has_special_tokens:
|
| 381 |
+
if token_ids_1 is not None:
|
| 382 |
+
raise ValueError(
|
| 383 |
+
"You should not supply a second sequence if the provided sequence of "
|
| 384 |
+
"ids is already formated with special tokens for the model."
|
| 385 |
+
)
|
| 386 |
+
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
|
| 387 |
+
|
| 388 |
+
if token_ids_1 is not None:
|
| 389 |
+
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
| 390 |
+
return [1] + ([0] * len(token_ids_0)) + [1]
|
| 391 |
+
|
| 392 |
+
def create_token_type_ids_from_sequences(
|
| 393 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 394 |
+
) -> List[int]:
|
| 395 |
+
"""
|
| 396 |
+
Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
|
| 397 |
+
A BERT sequence pair mask has the following format:
|
| 398 |
+
::
|
| 399 |
+
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
| 400 |
+
| first sequence | second sequence |
|
| 401 |
+
if token_ids_1 is None, only returns the first portion of the mask (0's).
|
| 402 |
+
Args:
|
| 403 |
+
token_ids_0 (:obj:`List[int]`):
|
| 404 |
+
List of ids.
|
| 405 |
+
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
| 406 |
+
Optional second list of IDs for sequence pairs.
|
| 407 |
+
Returns:
|
| 408 |
+
:obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
|
| 409 |
+
sequence(s).
|
| 410 |
+
"""
|
| 411 |
+
sep = [self.sep_token_id]
|
| 412 |
+
cls = [self.cls_token_id]
|
| 413 |
+
if token_ids_1 is None:
|
| 414 |
+
return len(cls + token_ids_0 + sep) * [0]
|
| 415 |
+
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
| 416 |
+
|
| 417 |
+
def save_vocabulary(self, vocab_path):
    """
    Write the vocabulary to disk, one token per line in index order.

    Args:
        vocab_path (:obj:`str`):
            Either a directory (the standard vocab filename is appended) or a
            full file path.

    Returns:
        :obj:`Tuple(str)`: One-element tuple with the path of the saved file.
    """
    if os.path.isdir(vocab_path):
        out_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
    else:
        out_file = vocab_path
    expected_index = 0
    with open(out_file, "w", encoding="utf-8") as writer:
        # Emit tokens sorted by their index; warn once per gap so a corrupted
        # (non-consecutive) vocabulary is noticed rather than silently shifted.
        for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
            if expected_index != token_index:
                logger.warning(
                    "Saving vocabulary to {}: vocabulary indices are not consecutive."
                    " Please check that the vocabulary is not corrupted!".format(out_file)
                )
                expected_index = token_index
            writer.write(token + "\n")
            expected_index += 1
    return (out_file,)
|
src/tokenizer/new_splits.txt
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
c 1
|
| 2 |
+
c 2
|
| 3 |
+
c 3
|
| 4 |
+
c 4
|
| 5 |
+
c 5
|
| 6 |
+
c 6
|
| 7 |
+
c 7
|
| 8 |
+
c 8
|
| 9 |
+
c 9
|
| 10 |
+
( c1
|
| 11 |
+
( c2
|
| 12 |
+
c1 )
|
| 13 |
+
c2 )
|
| 14 |
+
n 1
|
| 15 |
+
n 2
|
| 16 |
+
n 3
|
| 17 |
+
n 4
|
| 18 |
+
n 5
|
| 19 |
+
n 6
|
| 20 |
+
n 7
|
| 21 |
+
n 8
|
| 22 |
+
n 9
|
| 23 |
+
( n1
|
| 24 |
+
( n2
|
| 25 |
+
n1 )
|
| 26 |
+
n2 )
|
| 27 |
+
O 1
|
| 28 |
+
O 2
|
| 29 |
+
O 3
|
| 30 |
+
O 4
|
| 31 |
+
O 5
|
| 32 |
+
O 6
|
| 33 |
+
O 7
|
| 34 |
+
O 8
|
| 35 |
+
O 9
|
| 36 |
+
( O1
|
| 37 |
+
( O2
|
| 38 |
+
O2 )
|
| 39 |
+
O2 )
|
| 40 |
+
= O
|
| 41 |
+
= C
|
| 42 |
+
= c
|
| 43 |
+
= N
|
| 44 |
+
= n
|
| 45 |
+
=C C
|
| 46 |
+
=C N
|
| 47 |
+
=C c
|
| 48 |
+
=c c
|
| 49 |
+
=N C
|
| 50 |
+
=N c
|
| 51 |
+
=n C
|
| 52 |
+
=n c
|
| 53 |
+
# N
|
| 54 |
+
# C
|
| 55 |
+
#N C
|
| 56 |
+
#C C
|
| 57 |
+
#C N
|
| 58 |
+
#N N
|
| 59 |
+
( C
|
| 60 |
+
C )
|
| 61 |
+
( O
|
| 62 |
+
O )
|
| 63 |
+
( N
|
| 64 |
+
N )
|
| 65 |
+
Br c
|
| 66 |
+
( =O
|
| 67 |
+
(=O )
|
| 68 |
+
C (=O)
|
| 69 |
+
C =O
|
| 70 |
+
C =N
|
| 71 |
+
C #N
|
| 72 |
+
C #C
|
| 73 |
+
C C
|
| 74 |
+
CC C
|
| 75 |
+
CC N
|
| 76 |
+
CC O
|
| 77 |
+
CC S
|
| 78 |
+
CC c
|
| 79 |
+
CC n
|
| 80 |
+
C N
|
| 81 |
+
CN C
|
| 82 |
+
CN c
|
| 83 |
+
C O
|
| 84 |
+
CO C
|
| 85 |
+
CO N
|
| 86 |
+
CO c
|
| 87 |
+
C S
|
| 88 |
+
CS C
|
| 89 |
+
CS S
|
| 90 |
+
CS c
|
| 91 |
+
C c
|
| 92 |
+
Cl c
|
| 93 |
+
C n
|
| 94 |
+
F c
|
| 95 |
+
N C
|
| 96 |
+
NC C
|
| 97 |
+
NC c
|
| 98 |
+
N N
|
| 99 |
+
N O
|
| 100 |
+
N c
|
| 101 |
+
N n
|
| 102 |
+
O C
|
| 103 |
+
OC C
|
| 104 |
+
OC O
|
| 105 |
+
OC c
|
| 106 |
+
O N
|
| 107 |
+
O O
|
| 108 |
+
O c
|
| 109 |
+
S C
|
| 110 |
+
SC C
|
| 111 |
+
SC c
|
| 112 |
+
S S
|
| 113 |
+
S c
|
| 114 |
+
c c
|
| 115 |
+
cc c
|
| 116 |
+
cc n
|
| 117 |
+
cc o
|
| 118 |
+
cc s
|
| 119 |
+
cc cc
|
| 120 |
+
c n
|
| 121 |
+
cn c
|
| 122 |
+
cn n
|
| 123 |
+
c o
|
| 124 |
+
co c
|
| 125 |
+
c s
|
| 126 |
+
cs c
|
| 127 |
+
cs n
|
| 128 |
+
n c
|
| 129 |
+
nc c
|
| 130 |
+
nc n
|
| 131 |
+
nc o
|
| 132 |
+
nc s
|
| 133 |
+
n n
|
| 134 |
+
nn c
|
| 135 |
+
nn n
|
| 136 |
+
n o
|
| 137 |
+
no c
|
| 138 |
+
no n
|
| 139 |
+
n s
|
| 140 |
+
ns c
|
| 141 |
+
ns n
|
| 142 |
+
o c
|
| 143 |
+
oc c
|
| 144 |
+
o n
|
| 145 |
+
s c
|
| 146 |
+
sc c
|
| 147 |
+
sc n
|
| 148 |
+
s n
|
| 149 |
+
N P
|
| 150 |
+
P N
|
| 151 |
+
C P
|
| 152 |
+
P C
|
| 153 |
+
N S
|
| 154 |
+
S N
|
| 155 |
+
C S
|
| 156 |
+
S C
|
| 157 |
+
S P
|
| 158 |
+
P S
|
| 159 |
+
C I
|
src/tokenizer/new_vocab.txt
ADDED
|
@@ -0,0 +1,587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[PAD]
|
| 2 |
+
[UNK]
|
| 3 |
+
[CLS]
|
| 4 |
+
[SEP]
|
| 5 |
+
[MASK]
|
| 6 |
+
#
|
| 7 |
+
%
|
| 8 |
+
(
|
| 9 |
+
)
|
| 10 |
+
+
|
| 11 |
+
-
|
| 12 |
+
/
|
| 13 |
+
0
|
| 14 |
+
1
|
| 15 |
+
2
|
| 16 |
+
3
|
| 17 |
+
4
|
| 18 |
+
5
|
| 19 |
+
6
|
| 20 |
+
7
|
| 21 |
+
8
|
| 22 |
+
9
|
| 23 |
+
=
|
| 24 |
+
@
|
| 25 |
+
A
|
| 26 |
+
B
|
| 27 |
+
Br
|
| 28 |
+
Brc
|
| 29 |
+
C
|
| 30 |
+
CC
|
| 31 |
+
CCC
|
| 32 |
+
CCN
|
| 33 |
+
CCO
|
| 34 |
+
CCS
|
| 35 |
+
CCc
|
| 36 |
+
CCn
|
| 37 |
+
CN
|
| 38 |
+
CNC
|
| 39 |
+
CNc
|
| 40 |
+
CO
|
| 41 |
+
COC
|
| 42 |
+
CON
|
| 43 |
+
COc
|
| 44 |
+
CS
|
| 45 |
+
CSC
|
| 46 |
+
CSS
|
| 47 |
+
CSc
|
| 48 |
+
Cc
|
| 49 |
+
Cl
|
| 50 |
+
Clc
|
| 51 |
+
Cn
|
| 52 |
+
F
|
| 53 |
+
Fc
|
| 54 |
+
H
|
| 55 |
+
I
|
| 56 |
+
K
|
| 57 |
+
L
|
| 58 |
+
M
|
| 59 |
+
N
|
| 60 |
+
NC
|
| 61 |
+
NCC
|
| 62 |
+
NCc
|
| 63 |
+
NN
|
| 64 |
+
NO
|
| 65 |
+
Nc
|
| 66 |
+
Nn
|
| 67 |
+
O
|
| 68 |
+
OC
|
| 69 |
+
OCC
|
| 70 |
+
OCO
|
| 71 |
+
OCc
|
| 72 |
+
ON
|
| 73 |
+
OO
|
| 74 |
+
Oc
|
| 75 |
+
P
|
| 76 |
+
R
|
| 77 |
+
S
|
| 78 |
+
SC
|
| 79 |
+
SCC
|
| 80 |
+
SCc
|
| 81 |
+
SS
|
| 82 |
+
Sc
|
| 83 |
+
T
|
| 84 |
+
X
|
| 85 |
+
Z
|
| 86 |
+
[
|
| 87 |
+
\\
|
| 88 |
+
(/
|
| 89 |
+
]
|
| 90 |
+
a
|
| 91 |
+
b
|
| 92 |
+
c
|
| 93 |
+
cc
|
| 94 |
+
ccc
|
| 95 |
+
cccc
|
| 96 |
+
ccn
|
| 97 |
+
cco
|
| 98 |
+
ccs
|
| 99 |
+
cn
|
| 100 |
+
cnc
|
| 101 |
+
cnn
|
| 102 |
+
co
|
| 103 |
+
coc
|
| 104 |
+
cs
|
| 105 |
+
csc
|
| 106 |
+
csn
|
| 107 |
+
e
|
| 108 |
+
g
|
| 109 |
+
i
|
| 110 |
+
l
|
| 111 |
+
n
|
| 112 |
+
nc
|
| 113 |
+
ncc
|
| 114 |
+
ncn
|
| 115 |
+
nco
|
| 116 |
+
ncs
|
| 117 |
+
nn
|
| 118 |
+
nnc
|
| 119 |
+
nnn
|
| 120 |
+
no
|
| 121 |
+
noc
|
| 122 |
+
non
|
| 123 |
+
ns
|
| 124 |
+
nsc
|
| 125 |
+
nsn
|
| 126 |
+
o
|
| 127 |
+
oc
|
| 128 |
+
occ
|
| 129 |
+
on
|
| 130 |
+
p
|
| 131 |
+
r
|
| 132 |
+
s
|
| 133 |
+
sc
|
| 134 |
+
scc
|
| 135 |
+
scn
|
| 136 |
+
sn
|
| 137 |
+
t
|
| 138 |
+
c1
|
| 139 |
+
c2
|
| 140 |
+
c3
|
| 141 |
+
c4
|
| 142 |
+
c5
|
| 143 |
+
c6
|
| 144 |
+
c7
|
| 145 |
+
c8
|
| 146 |
+
c9
|
| 147 |
+
n1
|
| 148 |
+
n2
|
| 149 |
+
n3
|
| 150 |
+
n4
|
| 151 |
+
n5
|
| 152 |
+
n6
|
| 153 |
+
n7
|
| 154 |
+
n8
|
| 155 |
+
n9
|
| 156 |
+
O1
|
| 157 |
+
O2
|
| 158 |
+
O3
|
| 159 |
+
O4
|
| 160 |
+
O5
|
| 161 |
+
O6
|
| 162 |
+
O7
|
| 163 |
+
O8
|
| 164 |
+
O9
|
| 165 |
+
(c1
|
| 166 |
+
(c2
|
| 167 |
+
c1)
|
| 168 |
+
c2)
|
| 169 |
+
(n1
|
| 170 |
+
(n2
|
| 171 |
+
n1)
|
| 172 |
+
n2)
|
| 173 |
+
(O1
|
| 174 |
+
(O2
|
| 175 |
+
O2)
|
| 176 |
+
=O
|
| 177 |
+
=C
|
| 178 |
+
=c
|
| 179 |
+
=N
|
| 180 |
+
=n
|
| 181 |
+
=CC
|
| 182 |
+
=CN
|
| 183 |
+
=Cc
|
| 184 |
+
=cc
|
| 185 |
+
=NC
|
| 186 |
+
=Nc
|
| 187 |
+
=nC
|
| 188 |
+
=nc
|
| 189 |
+
#C
|
| 190 |
+
#CC
|
| 191 |
+
#CN
|
| 192 |
+
#N
|
| 193 |
+
#NC
|
| 194 |
+
#NN
|
| 195 |
+
(C
|
| 196 |
+
C)
|
| 197 |
+
(O
|
| 198 |
+
O)
|
| 199 |
+
(N
|
| 200 |
+
N)
|
| 201 |
+
NP
|
| 202 |
+
PN
|
| 203 |
+
CP
|
| 204 |
+
PC
|
| 205 |
+
NS
|
| 206 |
+
SN
|
| 207 |
+
SP
|
| 208 |
+
PS
|
| 209 |
+
C(=O)
|
| 210 |
+
(/Br)
|
| 211 |
+
(/C#N)
|
| 212 |
+
(/C)
|
| 213 |
+
(/C=N)
|
| 214 |
+
(/C=O)
|
| 215 |
+
(/CBr)
|
| 216 |
+
(/CC)
|
| 217 |
+
(/CCC)
|
| 218 |
+
(/CCF)
|
| 219 |
+
(/CCN)
|
| 220 |
+
(/CCO)
|
| 221 |
+
(/CCl)
|
| 222 |
+
(/CI)
|
| 223 |
+
(/CN)
|
| 224 |
+
(/CO)
|
| 225 |
+
(/CS)
|
| 226 |
+
(/Cl)
|
| 227 |
+
(/F)
|
| 228 |
+
(/I)
|
| 229 |
+
(/N)
|
| 230 |
+
(/NC)
|
| 231 |
+
(/NCC)
|
| 232 |
+
(/NO)
|
| 233 |
+
(/O)
|
| 234 |
+
(/OC)
|
| 235 |
+
(/OCC)
|
| 236 |
+
(/S)
|
| 237 |
+
(/SC)
|
| 238 |
+
(=C)
|
| 239 |
+
(=C/C)
|
| 240 |
+
(=C/F)
|
| 241 |
+
(=C/I)
|
| 242 |
+
(=C/N)
|
| 243 |
+
(=C/O)
|
| 244 |
+
(=CBr)
|
| 245 |
+
(=CC)
|
| 246 |
+
(=CCF)
|
| 247 |
+
(=CCN)
|
| 248 |
+
(=CCO)
|
| 249 |
+
(=CCl)
|
| 250 |
+
(=CF)
|
| 251 |
+
(=CI)
|
| 252 |
+
(=CN)
|
| 253 |
+
(=CO)
|
| 254 |
+
(=C\\C)
|
| 255 |
+
(=C\\F)
|
| 256 |
+
(=C\\I)
|
| 257 |
+
(=C\\N)
|
| 258 |
+
(=C\\O)
|
| 259 |
+
(=N)
|
| 260 |
+
(=N/C)
|
| 261 |
+
(=N/N)
|
| 262 |
+
(=N/O)
|
| 263 |
+
(=NBr)
|
| 264 |
+
(=NC)
|
| 265 |
+
(=NCC)
|
| 266 |
+
(=NCl)
|
| 267 |
+
(=NN)
|
| 268 |
+
(=NO)
|
| 269 |
+
(=NOC)
|
| 270 |
+
(=N\\C)
|
| 271 |
+
(=N\\N)
|
| 272 |
+
(=N\\O)
|
| 273 |
+
(=O)
|
| 274 |
+
(=S)
|
| 275 |
+
(B)
|
| 276 |
+
(Br)
|
| 277 |
+
(C#C)
|
| 278 |
+
(C#CC)
|
| 279 |
+
(C#CI)
|
| 280 |
+
(C#CO)
|
| 281 |
+
(C#N)
|
| 282 |
+
(C#SN)
|
| 283 |
+
(C)
|
| 284 |
+
(C=C)
|
| 285 |
+
(C=CF)
|
| 286 |
+
(C=CI)
|
| 287 |
+
(C=N)
|
| 288 |
+
(C=NN)
|
| 289 |
+
(C=NO)
|
| 290 |
+
(C=O)
|
| 291 |
+
(C=S)
|
| 292 |
+
(CBr)
|
| 293 |
+
(CC#C)
|
| 294 |
+
(CC#N)
|
| 295 |
+
(CC)
|
| 296 |
+
(CC=C)
|
| 297 |
+
(CC=O)
|
| 298 |
+
(CCBr)
|
| 299 |
+
(CCC)
|
| 300 |
+
(CCCC)
|
| 301 |
+
(CCCF)
|
| 302 |
+
(CCCI)
|
| 303 |
+
(CCCN)
|
| 304 |
+
(CCCO)
|
| 305 |
+
(CCCS)
|
| 306 |
+
(CCCl)
|
| 307 |
+
(CCF)
|
| 308 |
+
(CCI)
|
| 309 |
+
(CCN)
|
| 310 |
+
(CCNC)
|
| 311 |
+
(CCNN)
|
| 312 |
+
(CCNO)
|
| 313 |
+
(CCO)
|
| 314 |
+
(CCOC)
|
| 315 |
+
(CCON)
|
| 316 |
+
(CCS)
|
| 317 |
+
(CCSC)
|
| 318 |
+
(CCl)
|
| 319 |
+
(CF)
|
| 320 |
+
(CI)
|
| 321 |
+
(CN)
|
| 322 |
+
(CN=O)
|
| 323 |
+
(CNC)
|
| 324 |
+
(CNCC)
|
| 325 |
+
(CNCO)
|
| 326 |
+
(CNN)
|
| 327 |
+
(CNNC)
|
| 328 |
+
(CNO)
|
| 329 |
+
(CNOC)
|
| 330 |
+
(CO)
|
| 331 |
+
(COC)
|
| 332 |
+
(COCC)
|
| 333 |
+
(COCI)
|
| 334 |
+
(COCN)
|
| 335 |
+
(COCO)
|
| 336 |
+
(COF)
|
| 337 |
+
(CON)
|
| 338 |
+
(COO)
|
| 339 |
+
(CS)
|
| 340 |
+
(CSC)
|
| 341 |
+
(CSCC)
|
| 342 |
+
(CSCF)
|
| 343 |
+
(CSO)
|
| 344 |
+
(Cl)
|
| 345 |
+
(F)
|
| 346 |
+
(I)
|
| 347 |
+
(N)
|
| 348 |
+
(N=N)
|
| 349 |
+
(N=NO)
|
| 350 |
+
(N=O)
|
| 351 |
+
(N=S)
|
| 352 |
+
(NBr)
|
| 353 |
+
(NC#N)
|
| 354 |
+
(NC)
|
| 355 |
+
(NC=N)
|
| 356 |
+
(NC=O)
|
| 357 |
+
(NC=S)
|
| 358 |
+
(NCBr)
|
| 359 |
+
(NCC)
|
| 360 |
+
(NCCC)
|
| 361 |
+
(NCCF)
|
| 362 |
+
(NCCN)
|
| 363 |
+
(NCCO)
|
| 364 |
+
(NCCS)
|
| 365 |
+
(NCCl)
|
| 366 |
+
(NCNC)
|
| 367 |
+
(NCO)
|
| 368 |
+
(NCS)
|
| 369 |
+
(NCl)
|
| 370 |
+
(NN)
|
| 371 |
+
(NN=O)
|
| 372 |
+
(NNC)
|
| 373 |
+
(NO)
|
| 374 |
+
(NOC)
|
| 375 |
+
(O)
|
| 376 |
+
(OC#N)
|
| 377 |
+
(OC)
|
| 378 |
+
(OC=C)
|
| 379 |
+
(OC=O)
|
| 380 |
+
(OC=S)
|
| 381 |
+
(OCBr)
|
| 382 |
+
(OCC)
|
| 383 |
+
(OCCC)
|
| 384 |
+
(OCCF)
|
| 385 |
+
(OCCI)
|
| 386 |
+
(OCCN)
|
| 387 |
+
(OCCO)
|
| 388 |
+
(OCCS)
|
| 389 |
+
(OCCl)
|
| 390 |
+
(OCF)
|
| 391 |
+
(OCI)
|
| 392 |
+
(OCO)
|
| 393 |
+
(OCOC)
|
| 394 |
+
(OCON)
|
| 395 |
+
(OCSC)
|
| 396 |
+
(OCl)
|
| 397 |
+
(OI)
|
| 398 |
+
(ON)
|
| 399 |
+
(OO)
|
| 400 |
+
(OOC)
|
| 401 |
+
(OOCC)
|
| 402 |
+
(OOSN)
|
| 403 |
+
(OSC)
|
| 404 |
+
(P)
|
| 405 |
+
(S)
|
| 406 |
+
(SC#N)
|
| 407 |
+
(SC)
|
| 408 |
+
(SCC)
|
| 409 |
+
(SCCC)
|
| 410 |
+
(SCCF)
|
| 411 |
+
(SCCN)
|
| 412 |
+
(SCCO)
|
| 413 |
+
(SCCS)
|
| 414 |
+
(SCCl)
|
| 415 |
+
(SCF)
|
| 416 |
+
(SCN)
|
| 417 |
+
(SCOC)
|
| 418 |
+
(SCSC)
|
| 419 |
+
(SCl)
|
| 420 |
+
(SI)
|
| 421 |
+
(SN)
|
| 422 |
+
(SN=O)
|
| 423 |
+
(SO)
|
| 424 |
+
(SOC)
|
| 425 |
+
(SOOO)
|
| 426 |
+
(SS)
|
| 427 |
+
(SSC)
|
| 428 |
+
(SSCC)
|
| 429 |
+
([At])
|
| 430 |
+
([O-])
|
| 431 |
+
([O])
|
| 432 |
+
([S-])
|
| 433 |
+
(\\Br)
|
| 434 |
+
(\\C#N)
|
| 435 |
+
(\\C)
|
| 436 |
+
(\\C=N)
|
| 437 |
+
(\\C=O)
|
| 438 |
+
(\\CBr)
|
| 439 |
+
(\\CC)
|
| 440 |
+
(\\CCC)
|
| 441 |
+
(\\CCO)
|
| 442 |
+
(\\CCl)
|
| 443 |
+
(\\CF)
|
| 444 |
+
(\\CN)
|
| 445 |
+
(\\CNC)
|
| 446 |
+
(\\CO)
|
| 447 |
+
(\\COC)
|
| 448 |
+
(\\Cl)
|
| 449 |
+
(\\F)
|
| 450 |
+
(\\I)
|
| 451 |
+
(\\N)
|
| 452 |
+
(\\NC)
|
| 453 |
+
(\\NCC)
|
| 454 |
+
(\\NN)
|
| 455 |
+
(\\NO)
|
| 456 |
+
(\\NOC)
|
| 457 |
+
(\\O)
|
| 458 |
+
(\\OC)
|
| 459 |
+
(\\OCC)
|
| 460 |
+
(\\ON)
|
| 461 |
+
(\\S)
|
| 462 |
+
(\\SC)
|
| 463 |
+
(\\SCC)
|
| 464 |
+
[Ag+]
|
| 465 |
+
[Ag-4]
|
| 466 |
+
[Ag]
|
| 467 |
+
[Al-3]
|
| 468 |
+
[Al]
|
| 469 |
+
[As+]
|
| 470 |
+
[AsH3]
|
| 471 |
+
[AsH]
|
| 472 |
+
[As]
|
| 473 |
+
[At]
|
| 474 |
+
[B-]
|
| 475 |
+
[B@-]
|
| 476 |
+
[B@@-]
|
| 477 |
+
[BH-]
|
| 478 |
+
[BH2-]
|
| 479 |
+
[BH3-]
|
| 480 |
+
[B]
|
| 481 |
+
[Ba]
|
| 482 |
+
[Br+2]
|
| 483 |
+
[BrH]
|
| 484 |
+
[Br]
|
| 485 |
+
[C+]
|
| 486 |
+
[C-]
|
| 487 |
+
[C@@H]
|
| 488 |
+
[C@@]
|
| 489 |
+
[C@H]
|
| 490 |
+
[C@]
|
| 491 |
+
[CH-]
|
| 492 |
+
[CH2]
|
| 493 |
+
[CH3]
|
| 494 |
+
[CH]
|
| 495 |
+
[C]
|
| 496 |
+
[CaH2]
|
| 497 |
+
[Ca]
|
| 498 |
+
[Cl+2]
|
| 499 |
+
[Cl+3]
|
| 500 |
+
[Cl+]
|
| 501 |
+
[Cs]
|
| 502 |
+
[FH]
|
| 503 |
+
[F]
|
| 504 |
+
[H]
|
| 505 |
+
[He]
|
| 506 |
+
[I+2]
|
| 507 |
+
[I+3]
|
| 508 |
+
[I+]
|
| 509 |
+
[IH]
|
| 510 |
+
[I]
|
| 511 |
+
[K]
|
| 512 |
+
[Kr]
|
| 513 |
+
[Li+]
|
| 514 |
+
[LiH]
|
| 515 |
+
[MgH2]
|
| 516 |
+
[Mg]
|
| 517 |
+
[N+]
|
| 518 |
+
[N-]
|
| 519 |
+
[N@+]
|
| 520 |
+
[N@@+]
|
| 521 |
+
[N@@]
|
| 522 |
+
[N@]
|
| 523 |
+
[NH+]
|
| 524 |
+
[NH-]
|
| 525 |
+
[NH2+]
|
| 526 |
+
[NH3]
|
| 527 |
+
[NH]
|
| 528 |
+
[N]
|
| 529 |
+
[Na]
|
| 530 |
+
[O+]
|
| 531 |
+
[O-]
|
| 532 |
+
[OH+]
|
| 533 |
+
[OH2]
|
| 534 |
+
[OH]
|
| 535 |
+
[O]
|
| 536 |
+
[P+]
|
| 537 |
+
[P@+]
|
| 538 |
+
[P@@+]
|
| 539 |
+
[P@@]
|
| 540 |
+
[P@]
|
| 541 |
+
[PH2]
|
| 542 |
+
[PH]
|
| 543 |
+
[P]
|
| 544 |
+
[Ra]
|
| 545 |
+
[Rb]
|
| 546 |
+
[S+]
|
| 547 |
+
[S-]
|
| 548 |
+
[S@+]
|
| 549 |
+
[S@@+]
|
| 550 |
+
[S@@]
|
| 551 |
+
[S@]
|
| 552 |
+
[SH+]
|
| 553 |
+
[SH2]
|
| 554 |
+
[SH]
|
| 555 |
+
[S]
|
| 556 |
+
[Se+]
|
| 557 |
+
[Se-2]
|
| 558 |
+
[SeH2]
|
| 559 |
+
[SeH]
|
| 560 |
+
[Se]
|
| 561 |
+
[Si@]
|
| 562 |
+
[SiH2]
|
| 563 |
+
[SiH]
|
| 564 |
+
[Si]
|
| 565 |
+
[SrH2]
|
| 566 |
+
[TeH]
|
| 567 |
+
[Te]
|
| 568 |
+
[Xe]
|
| 569 |
+
[Zn+2]
|
| 570 |
+
[Zn-2]
|
| 571 |
+
[Zn]
|
| 572 |
+
[b-]
|
| 573 |
+
[c+]
|
| 574 |
+
[c-]
|
| 575 |
+
[cH-]
|
| 576 |
+
[cH]
|
| 577 |
+
[c]
|
| 578 |
+
[n+]
|
| 579 |
+
[n-]
|
| 580 |
+
[nH]
|
| 581 |
+
[n]
|
| 582 |
+
[o+]
|
| 583 |
+
[s+]
|
| 584 |
+
[se+]
|
| 585 |
+
[se]
|
| 586 |
+
[te+]
|
| 587 |
+
[te]
|
src/train.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Direct reward backpropagation: command-line configuration for fine-tuning.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# Optimization / training loop.
parser.add_argument('--base_path', type=str, default='')
parser.add_argument('--learning_rate', type=float, default=1e-4)
parser.add_argument('--num_epochs', type=int, default=100)
parser.add_argument('--num_accum_steps', type=int, default=4)
parser.add_argument('--truncate_steps', type=int, default=50)
parser.add_argument("--truncate_kl", type=str2bool, default=False)
parser.add_argument('--gumbel_temp', type=float, default=1.0)
parser.add_argument('--gradnorm_clip', type=float, default=1.0)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--name', type=str, default='debug')
parser.add_argument('--total_num_steps', type=int, default=128)
parser.add_argument('--copy_flag_temp', type=float, default=None)
parser.add_argument('--save_every_n_epochs', type=int, default=10)
parser.add_argument('--alpha_schedule_warmup', type=int, default=0)
parser.add_argument("--seed", type=int, default=0)

# Run identification and output placement.
parser.add_argument('--run_name', type=str, default='peptides')
parser.add_argument("--device", default="cuda:0", type=str)
parser.add_argument("--save_path_dir", default="/path/to/your/home/PepTune/checkpoints/", type=str)

# MCTS search hyperparameters.
parser.add_argument('--num_sequences', type=int, default=10)
parser.add_argument('--num_children', type=int, default=50)
parser.add_argument('--num_iter', type=int, default=30)  # iterations of mcts
parser.add_argument('--seq_length', type=int, default=200)
parser.add_argument('--time_conditioning', action='store_true', default=False)
parser.add_argument('--mcts_sampling', type=int, default=0)  # for batched categorical sampling: '0' means gumbel noise
parser.add_argument('--buffer_size', type=int, default=100)
parser.add_argument('--wdce_num_replicates', type=int, default=16)
parser.add_argument('--noise_removal', action='store_true', default=False)
parser.add_argument('--grad_clip', action='store_true', default=False)
parser.add_argument('--resample_every_n_step', type=int, default=10)
parser.add_argument('--exploration', type=float, default=0.1)
parser.add_argument('--reset_every_n_step', type=int, default=100)
parser.add_argument('--alpha', type=float, default=0.01)
parser.add_argument('--scalarization', type=str, default='sum')
parser.add_argument('--no_mcts', action='store_true', default=False)
parser.add_argument("--centering", action='store_true', default=False)

# Objectives / binding target.
parser.add_argument('--num_obj', type=int, default=5)
parser.add_argument('--prot_seq', type=str, default=None)
parser.add_argument('--prot_name', type=str, default=None)

args = parser.parse_args()
print(args)
|
| 65 |
+
|
| 66 |
+
# Path to the pretrained PepTune checkpoint used to seed both models below.
ckpt_path = f'{args.base_path}/checkpoints/peptune-pretrained.ckpt'

# Hydra keeps global state; clear any earlier initialization before composing.
GlobalHydra.instance().clear()
initialize(config_path="configs", job_name="load_model")
cfg = compose(config_name="peptune_config.yaml")
curr_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

# Candidate binding targets (full amino-acid sequences).
amhr = 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV'
tfr = 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF'
gfap = 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM'
glp1 = 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS'
glast = 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM'
ncam = 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF'
cereblon = 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL'
ligase = 'MASQPPEDTAESQASDELECKICYNRYNLKQRKPKVLECCHRVCAKCLYKIIDFGDSPQGVIVCPFCRFETCLPDDEVSSLPDDNNILVNLTCGGKGKKCLPENPTELLLTPKRLASLVSPSHTSSNCLVITIMEVQRESSPSLSSTPVVEFYRPASFDSVTTVSHNWTVWNCTSLLFQTSIRVLVWLLGLLYFSSLPLGIYLLVSKKVTLGVVFVSLVPSSLVILMVYGFCQCVCHEFLDCMAPPS'
skp2 = 'MHRKHLQEIPDLSSNVATSFTWGWDSSKTSELLSGMGVSALEKEEPDSENIPQELLSNLGHPESPPRKRLKSKGSDKDFVIVRRPKLNRENFPGVSWDSLPDELLLGIFSCLCLPELLKVSGVCKRWYRLASDESLWQTLDLTGKNLHPDVTGRLLSQGVIAFRCPRSFMDQPLAEHFSPFRVQHMDLSNSVIEVSTLHGILSQCSKLQNLSLEGLRLSDPIVNTLAKNSNLVRLNLSGCSGFSEFALQTLLSSCSRLDELNLSWCFDFTEKHVQVAVAHVSETITQLNLSGYRKNLQKSDLSTLVRRCPNLVHLDLSDSVMLKNDCFQEFFQLNYLQHLSLSRCYDIIPETLLELGEIPTLKTLQVFGIVPDGTLQLLKEALPHLQINCSHFTTIARPTIGNKKNQEIWGIKCRLTLQKPSCL'

# Select the target: a user-supplied sequence, or transferrin receptor (tfr) by default.
use_custom_target = args.prot_seq is not None
prot = args.prot_seq if use_custom_target else tfr
prot_name = args.prot_name if use_custom_target else "tfr"
filename = prot_name

# Encode the key hyperparameters into the run name for traceability.
if args.no_mcts:
    args.run_name = f'{prot_name}_resample{args.resample_every_n_step}_no-mcts'
else:
    args.run_name = f'{prot_name}_resample{args.resample_every_n_step}_buffer{args.buffer_size}_numiter{args.num_iter}_children{args.num_children}_{curr_time}'

args.save_path = os.path.join(args.save_path_dir, args.run_name)
os.makedirs(args.save_path, exist_ok=True)
# Experiment tracking.
wandb.init(project='tree-multi', name=args.run_name, config=args, dir=args.save_path)

log_path = os.path.join(args.save_path, 'log.txt')

set_seed(args.seed, use_cuda=True)

# Two copies of the pretrained diffusion model: one fine-tuned (policy), one
# frozen in eval mode as the reference.
policy_model = Diffusion.load_from_checkpoint(ckpt_path,
                                              config=cfg,
                                              mode="train",
                                              device=args.device,
                                              map_location=args.device)
pretrained = Diffusion.load_from_checkpoint(ckpt_path,
                                            config=cfg,
                                            mode="eval",
                                            device=args.device,
                                            map_location=args.device)
|
| 122 |
+
|
| 123 |
+
# Multi-objective scoring functions guiding generation.
score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'permeability']

if args.no_mcts:
    # Ablation path: no tree search; score sampled sequences directly.
    reward_model = ScoringFunctions(score_func_names, prot_seqs=[prot], device=args.device)
    finetune(args, cfg, policy_model, reward_model=reward_model, mcts=None, pretrained=pretrained, filename=filename, prot_name=prot_name)
else:
    # Construct the MCTS only on the branch that uses it. The original code
    # built it unconditionally and then rebuilt it here, doing the expensive
    # construction twice (or needlessly, in the no-mcts case).
    mcts = MCTS(args, cfg, policy_model, pretrained, score_func_names, prot_seqs=[prot])
    finetune(args, cfg, policy_model, reward_model=mcts.rewardFunc, mcts=mcts, pretrained=None, filename=filename, prot_name=prot_name)
|
src/train_peptune.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import uuid
|
| 3 |
+
import sys
|
| 4 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 5 |
+
|
| 6 |
+
import wandb
|
| 7 |
+
import fsspec
|
| 8 |
+
import hydra
|
| 9 |
+
import lightning as L
|
| 10 |
+
from lightning.pytorch import Trainer
|
| 11 |
+
from lightning.pytorch.callbacks import ModelCheckpoint, GradientAccumulationScheduler
|
| 12 |
+
import omegaconf
|
| 13 |
+
import rich.syntax
|
| 14 |
+
import rich.tree
|
| 15 |
+
import torch
|
| 16 |
+
import sys
|
| 17 |
+
import torch.distributed as dist
|
| 18 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 19 |
+
import dataloading_for_dynamic_batching as dynamic_dataloader
|
| 20 |
+
from diffusion import Diffusion
|
| 21 |
+
import utils.utils as utils
|
| 22 |
+
|
| 23 |
+
from lightning.pytorch.strategies import DDPStrategy
|
| 24 |
+
from datasets import load_dataset
|
| 25 |
+
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 26 |
+
|
| 27 |
+
# Custom OmegaConf resolvers referenced from config.yaml:
#   ${cwd:}          -> current working directory at import time
#   ${device_count:} -> number of visible CUDA devices
#   ${eval:...}      -> evaluate a Python expression embedded in the config
#   ${div_up:x,y}    -> ceiling integer division
# SECURITY NOTE: the 'eval' resolver executes arbitrary Python taken from the
# config file; only load configuration files from trusted sources.
omegaconf.OmegaConf.register_new_resolver('cwd', os.getcwd)
omegaconf.OmegaConf.register_new_resolver('device_count', torch.cuda.device_count)
omegaconf.OmegaConf.register_new_resolver('eval', eval)
omegaconf.OmegaConf.register_new_resolver('div_up', lambda x, y: (x + y - 1) // y)
|
| 31 |
+
|
| 32 |
+
def _load_from_checkpoint(config, tokenizer):
    """Build a Diffusion model, restoring weights unless the backbone is HF.

    Args:
        config: Hydra/OmegaConf run configuration; reads ``config.backbone``
            and ``config.eval.checkpoint_path``.
        tokenizer: tokenizer instance handed through to the model.

    Returns:
        A ``Diffusion`` instance (moved to CUDA for the HF branch, since HF
        backbones ship their own pretrained weights).
    """
    if 'hf' in config.backbone:
        # HuggingFace backbones are constructed fresh and moved to the GPU.
        return Diffusion(config, tokenizer=tokenizer).to('cuda')

    # Otherwise restore the full Lightning checkpoint from disk.
    return Diffusion.load_from_checkpoint(
        config.eval.checkpoint_path,
        tokenizer=tokenizer,
        config=config,
    )
|
| 43 |
+
|
| 44 |
+
@L.pytorch.utilities.rank_zero_only
def print_config(
    config: omegaconf.DictConfig,
    resolve: bool = True,
    save_cfg: bool = True) -> None:
    """
    Prints content of DictConfig using Rich library and its tree structure.

    Args:
        config (DictConfig): Configuration composed by Hydra.
        resolve (bool): Whether to resolve reference fields of DictConfig.
        save_cfg (bool): Whether to save the configuration tree to a file.
    """
    style = 'dim'
    tree = rich.tree.Tree('CONFIG', style=style, guide_style=style)

    for key in config.keys():
        branch = tree.add(key, style=style, guide_style=style)
        section = config.get(key)
        # Sub-configs render as YAML; scalar values fall back to str().
        if isinstance(section, omegaconf.DictConfig):
            rendered = omegaconf.OmegaConf.to_yaml(section, resolve=resolve)
        else:
            rendered = str(section)
        branch.add(rich.syntax.Syntax(rendered, 'yaml'))

    rich.print(tree)
    if save_cfg:
        # Persist the rendered tree next to the checkpoints for reproducibility.
        out_path = '{}/config_tree.txt'.format(config.checkpointing.save_dir)
        with fsspec.open(out_path, 'w') as fp:
            rich.print(tree, file=fp)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@L.pytorch.utilities.rank_zero_only
def print_batch(train_ds, valid_ds, tokenizer, k=64):
    """Log the first/last *k* token ids of one batch (rank-zero only).

    ``valid_ds`` is accepted for interface compatibility, but the validation
    loader was deliberately disabled here; only the training split is sampled.
    """
    for dl_type, loader in [('train', train_ds)]:
        print(f'Printing {dl_type} dataloader batch.')
        sample = next(iter(loader))
        ids = sample['input_ids']
        print('Batch input_ids.shape', ids.shape)
        head = ids[0, :k]
        tail = ids[0, -k:]
        print(f'First {k} tokens:', tokenizer.decode(head))
        print('ids:', head)
        print(f'Last {k} tokens:', tokenizer.decode(tail))
        print('ids:', tail)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def generate_samples(config, logger, tokenizer):
    """Sample peptide sequences from a checkpointed model and report perplexity.

    Runs ``config.sampling.num_sample_batches`` rounds of sampling; returns the
    decoded sequences from the last round.
    """
    logger.info('Generating samples.')
    model = _load_from_checkpoint(config=config, tokenizer=tokenizer)
    # NOTE: stride_length / num_strides sampling options are currently unused.

    for _ in range(config.sampling.num_sample_batches):
        token_ids = model.restore_model_and_sample(
            num_steps=config.sampling.steps)
        peptide_sequences = model.tokenizer.batch_decode(token_ids)
        model.compute_generative_perplexity(peptide_sequences)

        print('Peptide samples:', peptide_sequences)

        print('Generative perplexity:', model.compute_masked_perplexity())

    return peptide_sequences
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def ppl_eval(config, logger, tokenizer, data_module):
    """Evaluate held-out perplexity of a checkpointed model via Trainer.test."""
    logger.info('Starting Zero Shot Eval.')

    model = _load_from_checkpoint(config=config, tokenizer=tokenizer)

    # Optional Weights & Biases logging.
    wandb_logger = None
    if config.get('wandb', None) is not None:
        wandb_logger = L.pytorch.loggers.WandbLogger(
            config=omegaconf.OmegaConf.to_object(config),
            **config.wandb)

    # Instantiate any Lightning callbacks declared in the config.
    callbacks = []
    if 'callbacks' in config:
        callbacks.extend(
            hydra.utils.instantiate(cb) for cb in config.callbacks.values())

    trainer = hydra.utils.instantiate(
        config.trainer,
        default_root_dir=os.getcwd(),
        callbacks=callbacks,
        strategy=DDPStrategy(find_unused_parameters=True),
        logger=wandb_logger)

    trainer.test(model, data_module)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def _train(config, logger, tokenizer, data_module):
    """Train (or finetune) the Diffusion model with a Lightning Trainer.

    Args:
        config: Hydra/OmegaConf run configuration (wandb, checkpointing,
            callbacks, trainer, training, backbone sections are read).
        logger: project logger from ``utils.get_logger``.
        tokenizer: tokenizer passed to the ``Diffusion`` model.
        data_module: Lightning data module supplying train/val loaders.

    Raises:
        ValueError: when ``backbone == 'finetune_roformer'`` but no resume
            checkpoint is configured/exists.
    """
    logger.info('Starting Training.')

    # Optional W&B logging; a fresh UUID suffix keeps run ids unique across
    # restarts of the same configured id.
    wandb_logger = None
    if config.get('wandb', None) is not None:
        unique_id = str(uuid.uuid4())
        config.wandb.id = f"{config.wandb.id}_{unique_id}"
        wandb_logger = L.pytorch.loggers.WandbLogger(
            config=omegaconf.OmegaConf.to_object(config),
            **config.wandb)

    # Resolve the resume checkpoint; None means start from scratch.
    if (config.checkpointing.resume_from_ckpt
        and config.checkpointing.resume_ckpt_path is not None
        and utils.fsspec_exists(config.checkpointing.resume_ckpt_path)):
        ckpt_path = config.checkpointing.resume_ckpt_path
    else:
        ckpt_path = None

    # Lightning callbacks
    callbacks = []
    if 'callbacks' in config:
        for callback_name, callback_config in config.callbacks.items():
            if callback_name == 'model_checkpoint':
                # Build ModelCheckpoint directly so keys beyond _target_
                # are forwarded as keyword arguments.
                model_checkpoint_config = {
                    k: v for k, v in callback_config.items() if k != '_target_'}
                callbacks.append(ModelCheckpoint(**model_checkpoint_config))
            else:
                callbacks.append(hydra.utils.instantiate(callback_config))

    if config.training.accumulator:
        # Anneal gradient accumulation: 5 batches at epoch 1 down to 1 at epoch 4.
        accumulator = GradientAccumulationScheduler(
            scheduling={1: 5, 2: 4, 3: 3, 4: 1})
        callbacks.append(accumulator)

    # BUGFIX: devices were hard-coded to [2, 3, 4, 5, 6, 7], which breaks on
    # any machine without those exact GPU indices. Device selection now comes
    # from config.trainer (Lightning defaults to all visible devices).
    trainer = hydra.utils.instantiate(
        config.trainer,
        default_root_dir=os.getcwd(),
        callbacks=callbacks,
        accelerator='cuda',
        strategy=DDPStrategy(find_unused_parameters=True),
        logger=wandb_logger)

    model = Diffusion(config, tokenizer=tokenizer)

    if config.backbone == 'finetune_roformer':
        # BUGFIX: torch.load(None) previously raised an opaque TypeError when
        # no resume checkpoint was configured; fail with a clear message.
        if ckpt_path is None:
            raise ValueError(
                "backbone='finetune_roformer' requires "
                "checkpointing.resume_ckpt_path to point at an existing "
                "checkpoint.")
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])

    trainer.fit(model, datamodule=data_module, ckpt_path=ckpt_path)
|
| 197 |
+
|
| 198 |
+
@hydra.main(version_base=None, config_path=f'{os.getcwd()}/src', config_name='config')
def main(config):
    """
    Main entry point for training
    """
    wandb.init(project="peptune")
    L.seed_everything(config.seed)

    logger = utils.get_logger(__name__)

    # Load the PeptideCLM SMILES tokenizer from the bundled vocab/split files.
    tokenizer = SMILES_SPE_Tokenizer(
        f'{config.base_path}/src/tokenizer/new_vocab.txt',
        f'{config.base_path}/src/tokenizer/new_splits.txt')

    # Dynamic-batching data module over the peptide dataset.
    data_module = dynamic_dataloader.CustomDataModule(
        f'{config.base_path}/data/peptide_data', tokenizer)

    # Dispatch on run mode: sampling, perplexity evaluation, or training.
    if config.mode == 'sample_eval':
        generate_samples(config, logger, tokenizer)
    elif config.mode == 'ppl_eval':
        ppl_eval(config, logger, tokenizer, data_module)
    else:
        _train(config, logger, tokenizer, data_module)


if __name__ == '__main__':
    main()
|
src/utils/app.py
ADDED
|
@@ -0,0 +1,1255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from io import StringIO
|
| 5 |
+
import rdkit
|
| 6 |
+
from rdkit import Chem
|
| 7 |
+
from rdkit.Chem import AllChem, Draw
|
| 8 |
+
import numpy as np
|
| 9 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import matplotlib.patches as patches
|
| 12 |
+
from io import BytesIO
|
| 13 |
+
import tempfile
|
| 14 |
+
from rdkit import Chem
|
| 15 |
+
|
| 16 |
+
class PeptideAnalyzer:
|
| 17 |
+
def __init__(self):
|
| 18 |
+
self.bond_patterns = [
|
| 19 |
+
(r'OC\(=O\)', 'ester'), # Ester bond
|
| 20 |
+
(r'N\(C\)C\(=O\)', 'n_methyl'), # N-methylated peptide bond
|
| 21 |
+
(r'N[0-9]C\(=O\)', 'proline'), # Proline peptide bond
|
| 22 |
+
(r'NC\(=O\)', 'peptide'), # Standard peptide bond
|
| 23 |
+
(r'C\(=O\)N\(C\)', 'n_methyl_reverse'), # Reverse N-methylated
|
| 24 |
+
(r'C\(=O\)N[12]?', 'peptide_reverse') # Reverse peptide bond
|
| 25 |
+
]
|
| 26 |
+
# Three to one letter code mapping
|
| 27 |
+
self.three_to_one = {
|
| 28 |
+
'Ala': 'A', 'Cys': 'C', 'Asp': 'D', 'Glu': 'E',
|
| 29 |
+
'Phe': 'F', 'Gly': 'G', 'His': 'H', 'Ile': 'I',
|
| 30 |
+
'Lys': 'K', 'Leu': 'L', 'Met': 'M', 'Asn': 'N',
|
| 31 |
+
'Pro': 'P', 'Gln': 'Q', 'Arg': 'R', 'Ser': 'S',
|
| 32 |
+
'Thr': 'T', 'Val': 'V', 'Trp': 'W', 'Tyr': 'Y'
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
def is_peptide(self, smiles):
|
| 36 |
+
"""Check if the SMILES represents a peptide structure"""
|
| 37 |
+
mol = Chem.MolFromSmiles(smiles)
|
| 38 |
+
if mol is None:
|
| 39 |
+
return False
|
| 40 |
+
|
| 41 |
+
# Look for peptide bonds: NC(=O) pattern
|
| 42 |
+
peptide_bond_pattern = Chem.MolFromSmarts('[NH][C](=O)')
|
| 43 |
+
if mol.HasSubstructMatch(peptide_bond_pattern):
|
| 44 |
+
return True
|
| 45 |
+
|
| 46 |
+
# Look for N-methylated peptide bonds: N(C)C(=O) pattern
|
| 47 |
+
n_methyl_pattern = Chem.MolFromSmarts('[N;H0;$(NC)](C)[C](=O)')
|
| 48 |
+
if mol.HasSubstructMatch(n_methyl_pattern):
|
| 49 |
+
return True
|
| 50 |
+
|
| 51 |
+
return False
|
| 52 |
+
|
| 53 |
+
def is_cyclic(self, smiles):
|
| 54 |
+
"""Improved cyclic peptide detection"""
|
| 55 |
+
# Check for C-terminal carboxyl
|
| 56 |
+
if smiles.endswith('C(=O)O'):
|
| 57 |
+
return False, [], []
|
| 58 |
+
|
| 59 |
+
# Find all numbers used in ring closures
|
| 60 |
+
ring_numbers = re.findall(r'(?:^|[^c])[0-9](?=[A-Z@\(\)])', smiles)
|
| 61 |
+
|
| 62 |
+
# Find aromatic ring numbers
|
| 63 |
+
aromatic_matches = re.findall(r'c[0-9](?:ccccc|c\[nH\]c)[0-9]', smiles)
|
| 64 |
+
aromatic_cycles = []
|
| 65 |
+
for match in aromatic_matches:
|
| 66 |
+
numbers = re.findall(r'[0-9]', match)
|
| 67 |
+
aromatic_cycles.extend(numbers)
|
| 68 |
+
|
| 69 |
+
# Numbers that aren't part of aromatic rings are peptide cycles
|
| 70 |
+
peptide_cycles = [n for n in ring_numbers if n not in aromatic_cycles]
|
| 71 |
+
|
| 72 |
+
is_cyclic = len(peptide_cycles) > 0 and not smiles.endswith('C(=O)O')
|
| 73 |
+
return is_cyclic, peptide_cycles, aromatic_cycles
|
| 74 |
+
|
| 75 |
+
    def split_on_bonds(self, smiles):
        """Split SMILES into segments with simplified Pro handling

        Scans the string for backbone-bond motifs, records their spans
        (first-match-wins: overlapping later matches are skipped via `used`),
        then cuts the string into residue segments between consecutive bonds.

        Returns:
            list of dicts with 'content' plus optional 'bond_before' /
            'bond_after' keys holding the matched bond text.
        """
        positions = []
        # Character indices already claimed by an earlier match; prevents
        # overlapping bond patterns from double-counting.
        used = set()

        # Find Gly pattern first
        # Glycine ('NCC(=O)') must be matched before generic bond patterns,
        # otherwise its 'NC(=O)'-like core would be eaten by them.
        gly_pattern = r'NCC\(=O\)'
        for match in re.finditer(gly_pattern, smiles):
            if not any(p in range(match.start(), match.end()) for p in used):
                positions.append({
                    'start': match.start(),
                    'end': match.end(),
                    'type': 'gly',
                    'pattern': match.group()
                })
                used.update(range(match.start(), match.end()))

        # Then the ordered bond patterns (most specific first, per __init__).
        for pattern, bond_type in self.bond_patterns:
            for match in re.finditer(pattern, smiles):
                if not any(p in range(match.start(), match.end()) for p in used):
                    positions.append({
                        'start': match.start(),
                        'end': match.end(),
                        'type': bond_type,
                        'pattern': match.group()
                    })
                    used.update(range(match.start(), match.end()))

        # Sort by position
        positions.sort(key=lambda x: x['start'])

        # Create segments
        segments = []

        if positions:
            # First segment
            # Anything before the first bond is an N-terminal fragment.
            if positions[0]['start'] > 0:
                segments.append({
                    'content': smiles[0:positions[0]['start']],
                    'bond_after': positions[0]['pattern']
                })

            # Process segments
            for i in range(len(positions)-1):
                current = positions[i]
                next_pos = positions[i+1]

                if current['type'] == 'gly':
                    # Glycine's match IS the residue, so emit it verbatim
                    # rather than slicing between bonds.
                    segments.append({
                        'content': 'NCC(=O)',
                        'bond_before': positions[i-1]['pattern'] if i > 0 else None,
                        'bond_after': next_pos['pattern']
                    })
                else:
                    # Residue text lies strictly between this bond's end and
                    # the next bond's start; empty slices are dropped.
                    content = smiles[current['end']:next_pos['start']]
                    if content:
                        segments.append({
                            'content': content,
                            'bond_before': current['pattern'],
                            'bond_after': next_pos['pattern']
                        })

            # Last segment
            # Trailing text after the final bond is the C-terminal fragment.
            if positions[-1]['end'] < len(smiles):
                segments.append({
                    'content': smiles[positions[-1]['end']:],
                    'bond_before': positions[-1]['pattern']
                })

        return segments
|
| 145 |
+
|
| 146 |
+
def clean_terminal_carboxyl(self, segment):
|
| 147 |
+
"""Remove C-terminal carboxyl only if it's the true terminus"""
|
| 148 |
+
content = segment['content']
|
| 149 |
+
|
| 150 |
+
# Only clean if:
|
| 151 |
+
# 1. Contains C(=O)O
|
| 152 |
+
# 2. No bond_after exists (meaning it's the last segment)
|
| 153 |
+
# 3. C(=O)O is at the end of the content
|
| 154 |
+
if 'C(=O)O' in content and not segment.get('bond_after'):
|
| 155 |
+
print('recognized?')
|
| 156 |
+
# Remove C(=O)O pattern regardless of position
|
| 157 |
+
cleaned = re.sub(r'\(C\(=O\)O\)', '', content)
|
| 158 |
+
# Remove any leftover empty parentheses
|
| 159 |
+
cleaned = re.sub(r'\(\)', '', cleaned)
|
| 160 |
+
print(cleaned)
|
| 161 |
+
return cleaned
|
| 162 |
+
return content
|
| 163 |
+
|
| 164 |
+
def identify_residue(self, segment):
    """Identify residue with Pro reconstruction.

    Maps one backbone segment (from split_on_bonds) to a residue code by
    substring-matching side-chain SMILES patterns against the segment
    content. Unusual amino acids are tested before the canonical residues;
    the FIRST match wins, so the order of the checks below is load-bearing
    and must not be rearranged.

    Args:
        segment: dict with 'content' (SMILES fragment) and optional
            'bond_before'/'bond_after' backbone-bond patterns.

    Returns:
        Tuple (residue_code, mods): residue_code is a short code string, or
        None when nothing matches; mods is the modification list from
        get_modifications().
    """
    # Only clean terminal carboxyl if this is the last segment
    content = self.clean_terminal_carboxyl(segment)
    mods = self.get_modifications(segment)

    # UAA pattern matching section - before regular residues
    # Phenylglycine and derivatives
    if 'c1ccccc1' in content:
        if '[C@@H](c1ccccc1)' in content or '[C@H](c1ccccc1)' in content:
            return '4', mods  # Base phenylglycine

    # 4-substituted phenylalanines
    if 'Cc1ccc' in content:
        if 'OMe' in content or 'OCc1ccc' in content:
            return '0A1', mods  # 4-methoxy-Phenylalanine
        elif 'Clc1ccc' in content:
            return '200', mods  # 4-chloro-Phenylalanine
        elif 'Brc1ccc' in content:
            return '4BF', mods  # 4-Bromo-phenylalanine
        elif 'C#Nc1ccc' in content:
            return '4CF', mods  # 4-cyano-phenylalanine
        elif 'Ic1ccc' in content:
            return 'PHI', mods  # 4-Iodo-phenylalanine
        elif 'Fc1ccc' in content:
            return 'PFF', mods  # 4-Fluoro-phenylalanine

    # Modified tryptophans
    if 'c[nH]c2' in content:
        if 'Oc2cccc2' in content:
            return '0AF', mods  # 7-hydroxy-tryptophan
        elif 'Fc2cccc2' in content:
            return '4FW', mods  # 4-fluoro-tryptophan
        elif 'Clc2cccc2' in content:
            return '6CW', mods  # 6-chloro-tryptophan
        elif 'Brc2cccc2' in content:
            return 'BTR', mods  # 6-bromo-tryptophan
        elif 'COc2cccc2' in content:
            return 'MOT5', mods  # 5-Methoxy-tryptophan
        elif 'Cc2cccc2' in content:
            return 'MTR5', mods  # 5-Methyl-tryptophan

    # Special amino acids
    if 'CC(C)(C)[C@@H]' in content or 'CC(C)(C)[C@H]' in content:
        return 'BUG', mods  # Tertleucine

    if 'CCCNC(=N)N' in content:
        return 'CIR', mods  # Citrulline

    if '[SeH]' in content:
        return 'CSE', mods  # Selenocysteine

    if '[NH3]CC[C@@H]' in content or '[NH3]CC[C@H]' in content:
        return 'DAB', mods  # Diaminobutyric acid

    if 'C1CCCCC1' in content:
        if 'C1CCCCC1[C@@H]' in content or 'C1CCCCC1[C@H]' in content:
            return 'CHG', mods  # Cyclohexylglycine
        elif 'C1CCCCC1C[C@@H]' in content or 'C1CCCCC1C[C@H]' in content:
            return 'ALC', mods  # 3-cyclohexyl-alanine

    # Naphthalene derivatives
    if 'c1cccc2c1cccc2' in content:
        if 'c1cccc2c1cccc2[C@@H]' in content or 'c1cccc2c1cccc2[C@H]' in content:
            return 'NAL', mods  # 2-Naphthyl-alanine

    # Heteroaromatic derivatives
    if 'c1cncc' in content:
        return 'PYR4', mods  # 3-(4-Pyridyl)-alanine
    if 'c1cscc' in content:
        return 'THA3', mods  # 3-(3-thienyl)-alanine
    if 'c1nnc' in content:
        return 'TRZ4', mods  # 3-(1,2,4-Triazol-1-yl)-alanine

    # Modified serines and threonines
    if 'OP(O)(O)O' in content:
        if '[C@@H](COP' in content or '[C@H](COP' in content:
            return 'SEP', mods  # phosphoserine
        elif '[C@@H](OP' in content or '[C@H](OP' in content:
            return 'TPO', mods  # phosphothreonine

    # Specialized ring systems
    if 'c1c2ccccc2cc2c1cccc2' in content:
        return 'ANTH', mods  # 3-(9-anthryl)-alanine
    if 'c1csc2c1cccc2' in content:
        return 'BTH3', mods  # 3-(3-benzothienyl)-alanine
    if '[C@]12C[C@H]3C[C@@H](C2)C[C@@H](C1)C3' in content:
        return 'ADAM', mods  # Adamanthane

    # Fluorinated derivatives
    if 'FC(F)(F)' in content:
        if 'CC(F)(F)F' in content:
            return 'FLA', mods  # Trifluoro-alanine
        if 'C(F)(F)F)c1' in content:
            if 'c1ccccc1C(F)(F)F' in content:
                return 'TFG2', mods  # 2-(Trifluoromethyl)-phenylglycine
            if 'c1cccc(c1)C(F)(F)F' in content:
                return 'TFG3', mods  # 3-(Trifluoromethyl)-phenylglycine
            if 'c1ccc(cc1)C(F)(F)F' in content:
                return 'TFG4', mods  # 4-(Trifluoromethyl)-phenylglycine

    # Multiple halogen patterns
    if 'F' in content and 'c1' in content:
        if 'c1ccc(c(c1)F)F' in content:
            return 'F2F', mods  # 3,4-Difluoro-phenylalanine
        if 'cc(F)cc(c1)F' in content:
            return 'WFP', mods  # 3,5-Difluoro-phenylalanine
    if 'Cl' in content and 'c1' in content:
        if 'c1ccc(cc1Cl)Cl' in content:
            return 'CP24', mods  # 2,4-dichloro-phenylalanine
        if 'c1ccc(c(c1)Cl)Cl' in content:
            return 'CP34', mods  # 3,4-dichloro-phenylalanine

    # Hydroxy and amino derivatives
    if 'O' in content and 'c1' in content:
        if 'c1cc(O)cc(c1)O' in content:
            return '3FG', mods  # (2s)-amino(3,5-dihydroxyphenyl)-ethanoic acid
        if 'c1ccc(c(c1)O)O' in content:
            return 'DAH', mods  # 3,4-Dihydroxy-phenylalanine

    # Cyclic amino acids
    if 'C1CCCC1' in content:
        return 'CPA3', mods  # 3-Cyclopentyl-alanine
    if 'C1CCCCC1' in content:
        if 'CC1CCCCC1' in content:
            return 'ALC', mods  # 3-cyclohexyl-alanine
        else:
            return 'CHG', mods  # Cyclohexylglycine

    # Chain-length variants
    if 'CCC[C@@H]' in content or 'CCC[C@H]' in content:
        return 'NLE', mods  # Norleucine
    if 'CC[C@@H]' in content or 'CC[C@H]' in content:
        if not any(x in content for x in ['CC(C)', 'COC', 'CN(']):
            return 'ABA', mods  # 2-Aminobutyric acid

    # Modified histidines
    if 'c1cnc' in content:
        if '[C@@H]1CN[C@@H](N1)F' in content:
            return '2HF', mods  # 2-fluoro-l-histidine
        if 'c1cnc([nH]1)F' in content:
            return '2HF1', mods  # 2-fluoro-l-histidine variant
        if 'c1c[nH]c(n1)F' in content:
            return '2HF2', mods  # 2-fluoro-l-histidine variant

    # Sulfur and selenium containing
    # NOTE(review): '[SeH]' was already tested above, so this branch can
    # never fire — kept byte-identical; confirm before deduplicating.
    if '[SeH]' in content:
        return 'CSE', mods  # Selenocysteine
    if 'S' in content:
        if 'CSCc1ccccc1' in content:
            return 'BCS', mods  # benzylcysteine
        if 'CCSC' in content:
            return 'ESC', mods  # Ethionine
        if 'CCS' in content:
            return 'HCS', mods  # homocysteine

    # Additional modifications
    if 'CN=[N]=N' in content:
        return 'AZDA', mods  # azido-alanine
    if '[NH]=[C](=[NH2])=[NH2]' in content:
        if 'CCC[NH]=' in content:
            return 'AGM', mods  # 5-methyl-arginine
        if 'CC[NH]=' in content:
            return 'GDPR', mods  # 2-Amino-3-guanidinopropionic acid

    if 'CCON' in content:
        return 'CAN', mods  # canaline
    if '[C@@H]1C=C[C@@H](C=C1)' in content:
        return 'ACZ', mods  # cis-amiclenomycin
    if 'CCC(=O)[NH3]' in content:
        return 'ONL', mods  # 5-oxo-l-norleucine
    # NOTE(review): PYR4 is also returned for 'c1cncc' above — presumably
    # alternative SMILES spellings of the same pyridyl ring; confirm.
    if 'c1ccncc1' in content:
        return 'PYR4', mods  # 3-(4-Pyridyl)-alanine
    if 'c1ccco1' in content:
        return 'FUA2', mods  # (2-furyl)-alanine

    if 'c1ccc' in content:
        if 'c1ccc(cc1)c1ccccc1' in content:
            return 'BIF', mods  # 4,4-biphenylalanine
        if 'c1ccc(cc1)C(=O)c1ccccc1' in content:
            return 'PBF', mods  # 4-benzoyl-phenylalanine
        if 'c1ccc(cc1)C(C)(C)C' in content:
            return 'TBP4', mods  # 4-tert-butyl-phenylalanine
        if 'c1ccc(cc1)[C](=[NH2])=[NH2]' in content:
            return '0BN', mods  # 4-carbamimidoyl-l-phenylalanine
        if 'c1cccc(c1)[C](=[NH2])=[NH2]' in content:
            return 'APM', mods  # m-amidinophenyl-3-alanine

    # Multiple hydroxy patterns
    if 'O' in content:
        if '[C@H]([C@H](C)O)O' in content:
            return 'ILX', mods  # 4,5-dihydroxy-isoleucine
        if '[C@H]([C@@H](C)O)O' in content:
            return 'ALO', mods  # Allo-threonine
        if '[C@H](COP(O)(O)O)' in content:
            return 'SEP', mods  # phosphoserine
        if '[C@H]([C@@H](C)OP(O)(O)O)' in content:
            return 'TPO', mods  # phosphothreonine
        if '[C@H](c1ccc(O)cc1)O' in content:
            return 'OMX', mods  # (betar)-beta-hydroxy-l-tyrosine
        if '[C@H](c1ccc(c(Cl)c1)O)O' in content:
            return 'OMY', mods  # (betar)-3-chloro-beta-hydroxy-l-tyrosine

    # Heterocyclic patterns
    if 'n1' in content:
        if 'n1cccn1' in content:
            return 'PYZ1', mods  # 3-(1-Pyrazolyl)-alanine
        if 'n1nncn1' in content:
            return 'TEZA', mods  # 3-(2-Tetrazolyl)-alanine
        if 'c2c(n1)cccc2' in content:
            return 'QU32', mods  # 3-(2-Quinolyl)-alanine
        if 'c1cnc2c(c1)cccc2' in content:
            return 'QU33', mods  # 3-(3-quinolyl)-alanine
        if 'c1ccnc2c1cccc2' in content:
            return 'QU34', mods  # 3-(4-quinolyl)-alanine
        if 'c1ccc2c(c1)nccc2' in content:
            return 'QU35', mods  # 3-(5-Quinolyl)-alanine
        if 'c1ccc2c(c1)cncc2' in content:
            return 'QU36', mods  # 3-(6-Quinolyl)-alanine
        if 'c1cnc2c(n1)cccc2' in content:
            return 'QX32', mods  # 3-(2-quinoxalyl)-alanine

    # Multiple nitrogen patterns
    if 'N' in content:
        if '[NH3]CC[C@@H]' in content:
            return 'DAB', mods  # Diaminobutyric acid
        if '[NH3]C[C@@H]' in content:
            return 'DPP', mods  # 2,3-Diaminopropanoic acid
        if '[NH3]CCCCCC[C@@H]' in content:
            return 'HHK', mods  # (2s)-2,8-diaminooctanoic acid
        if 'CCC[NH]=[C](=[NH2])=[NH2]' in content:
            return 'GBUT', mods  # 2-Amino-4-guanidinobutryric acid
        if '[NH]=[C](=S)=[NH2]' in content:
            return 'THIC', mods  # Thio-citrulline

    # Chain modified amino acids
    if 'CC' in content:
        if 'CCCC[C@@H]' in content:
            return 'AHP', mods  # 2-Aminoheptanoic acid
        if 'CCC([C@@H])(C)C' in content:
            return 'I2M', mods  # 3-methyl-l-alloisoleucine
        if 'CC[C@H]([C@@H])C' in content:
            return 'IIL', mods  # Allo-Isoleucine
        if '[C@H](CCC(C)C)' in content:
            return 'HLEU', mods  # Homoleucine
        if '[C@@H]([C@@H](C)O)C' in content:
            return 'HLU', mods  # beta-hydroxyleucine

    # Modified glutamate/aspartate patterns
    if '[C@@H]' in content:
        if '[C@@H](C[C@@H](F))' in content:
            return 'FGA4', mods  # 4-Fluoro-glutamic acid
        if '[C@@H](C[C@@H](O))' in content:
            return '3GL', mods  # 4-hydroxy-glutamic-acid
        if '[C@@H](C[C@H](C))' in content:
            return 'LME', mods  # (3r)-3-methyl-l-glutamic acid
        if '[C@@H](CC[C@H](C))' in content:
            return 'MEG', mods  # (3s)-3-methyl-l-glutamic acid

    # Sulfur and selenium modifications
    if 'S' in content:
        if 'SCC[C@@H]' in content:
            return 'HSER', mods  # homoserine
        if 'SCCN' in content:
            return 'SLZ', mods  # thialysine
        if 'SC(=O)' in content:
            return 'CSA', mods  # s-acetonylcysteine
        if '[S@@](=O)' in content:
            return 'SME', mods  # Methionine sulfoxide
        if 'S(=O)(=O)' in content:
            return 'OMT', mods  # Methionine sulfone

    # Double bond containing
    if 'C=' in content:
        if 'C=C[C@@H]' in content:
            return '2AG', mods  # 2-Allyl-glycine
        # NOTE(review): identical condition to the branch above — 'LVG'
        # (vinylglycine) is unreachable as written; confirm intended pattern.
        if 'C=C[C@@H]' in content:
            return 'LVG', mods  # vinylglycine
        if 'C=Cc1ccccc1' in content:
            return 'STYA', mods  # Styrylalanine

    # Special cases
    if '[C@@H]1Cc2c(C1)cccc2' in content:
        return 'IGL', mods  # alpha-amino-2-indanacetic acid
    if '[C](=[C](=O)=O)=O' in content:
        return '26P', mods  # 2-amino-6-oxopimelic acid
    if '[C](=[C](=O)=O)=C' in content:
        return '2NP', mods  # l-2-amino-6-methylene-pimelic acid
    if 'c2cnc[nH]2' in content:
        return 'HIS', mods  # histidine core
    if 'c1cccc2c1cc(O)cc2' in content:
        return 'NAO1', mods  # 5-hydroxy-1-naphthalene
    if 'c1ccc2c(c1)cc(O)cc2' in content:
        return 'NAO2', mods  # 6-hydroxy-2-naphthalene

    # Proline (P) - flexible ring numbers
    # NOTE(review): the inner any(... for n in ...) generators shadow the
    # outer comprehension variable n, so the stereocenter ring number is
    # not forced to match the bond's ring number — verify this is intended.
    if any([
        # Check for any ring number in bond patterns
        (segment.get('bond_after', '').startswith(f'N{n}C(=O)') and 'CCC' in content and
         any(f'[C@@H]{n}' in content or f'[C@H]{n}' in content for n in '123456789'))
        for n in '123456789'
    ]) or any([
        # Check ending patterns with any ring number
        (f'CCCN{n}' in content and content.endswith('=O') and
         any(f'[C@@H]{n}' in content or f'[C@H]{n}' in content for n in '123456789'))
        for n in '123456789'
    ]) or any([
        # Handle CCC[C@H]n patterns
        (content == f'CCC[C@H]{n}' and segment.get('bond_before', '').startswith(f'C(=O)N{n}')) or
        (content == f'CCC[C@@H]{n}' and segment.get('bond_before', '').startswith(f'C(=O)N{n}')) or
        # N-terminal Pro with any ring number
        (f'N{n}CCC[C@H]{n}' in content) or
        (f'N{n}CCC[C@@H]{n}' in content)
        for n in '123456789'
    ]):
        return 'Pro', mods

    # Tryptophan (W) - more specific indole pattern
    if re.search(r'c[0-9]c\[nH\]c[0-9]ccccc[0-9][0-9]', content) and \
       'c[nH]c' in content.replace(' ', ''):
        return 'Trp', mods

    # Lysine (K) - both patterns
    if '[C@@H](CCCCN)' in content or '[C@H](CCCCN)' in content:
        return 'Lys', mods

    # Arginine (R) - both patterns
    if '[C@@H](CCCNC(=N)N)' in content or '[C@H](CCCNC(=N)N)' in content:
        return 'Arg', mods

    if ('C[C@H](CCCC)' in content or 'C[C@@H](CCCC)' in content) and 'CC(C)' not in content:
        return 'Nle', mods

    # Ornithine (Orn) - 3-carbon chain with NH2
    if ('C[C@H](CCCN)' in content or 'C[C@@H](CCCN)' in content) and 'CC(C)' not in content:
        return 'Orn', mods

    # 2-Naphthylalanine (2Nal) - distinct from Phe pattern
    if ('Cc3cc2ccccc2c3' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return '2Nal', mods

    # Cyclohexylalanine (Cha) - already in your code but moved here for clarity
    if 'N2CCCCC2' in content or 'CCCCC2' in content:
        return 'Cha', mods

    # Aminobutyric acid (Abu) - 2-carbon chain
    if ('C[C@H](CC)' in content or 'C[C@@H](CC)' in content) and not any(p in content for p in ['CC(C)', 'CCCC', 'CCC(C)']):
        return 'Abu', mods

    # Pipecolic acid (Pip) - 6-membered ring like Pro
    if ('N3CCCCC3' in content or 'CCCCC3' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'Pip', mods

    # Cyclohexylglycine (Chg) - direct cyclohexyl without CH2
    if ('C[C@H](C1CCCCC1)' in content or 'C[C@@H](C1CCCCC1)' in content):
        return 'Chg', mods

    # 4-Fluorophenylalanine (4F-Phe)
    if ('Cc2ccc(F)cc2' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return '4F-Phe', mods

    # Regular residue identification
    if ('NCC(=O)' in content) or (content == 'C'):
        # Middle case - between bonds
        if segment.get('bond_before') and segment.get('bond_after'):
            if ('C(=O)N' in segment['bond_before'] or 'C(=O)N(C)' in segment['bond_before']):
                return 'Gly', mods
        # Terminal case - at the end
        elif segment.get('bond_before') and segment.get('bond_before').startswith('C(=O)N'):
            return 'Gly', mods

    if 'CC(C)C[C@H]' in content or 'CC(C)C[C@@H]' in content:
        return 'Leu', mods
    if '[C@@H](CC(C)C)' in content or '[C@H](CC(C)C)' in content:
        return 'Leu', mods

    if '[C@@H]([C@@H](C)O)' in content or '[C@H]([C@H](C)O)' in content:
        return 'Thr', mods

    if '[C@H](Cc2ccccc2)' in content or '[C@@H](Cc2ccccc2)' in content:
        return 'Phe', mods

    if ('[C@H](C(C)C)' in content or   # With outer parentheses
        '[C@@H](C(C)C)' in content or  # With outer parentheses
        '[C@H]C(C)C' in content or     # Without outer parentheses
        '[C@@H]C(C)C' in content):     # Without outer parentheses
        if not any(p in content for p in ['CC(C)C[C@H]', 'CC(C)C[C@@H]']):  # Still check not Leu
            return 'Val', mods

    if '[C@H](COC(C)(C)C)' in content or '[C@@H](COC(C)(C)C)' in content:
        return 'O-tBu', mods

    if any([
        'CC[C@H](C)' in content,
        'CC[C@@H](C)' in content,
        'C(C)C[C@H]' in content and 'CC(C)C' not in content,
        'C(C)C[C@@H]' in content and 'CC(C)C' not in content
    ]):
        return 'Ile', mods

    if ('[C@H](C)' in content or '[C@@H](C)' in content):
        if not any(p in content for p in ['C(C)C', 'COC', 'CN(', 'C(C)O', 'CC[C@H]', 'CC[C@@H]']):
            return 'Ala', mods

    # Tyrosine (Tyr) - 4-hydroxybenzyl side chain
    if re.search(r'Cc[0-9]ccc\(O\)cc[0-9]', content):
        return 'Tyr', mods

    # Serine (Ser) - Hydroxymethyl side chain
    if '[C@H](CO)' in content or '[C@@H](CO)' in content:
        if not ('C(C)O' in content or 'COC' in content):
            return 'Ser', mods

    # Threonine (Thr) - 1-hydroxyethyl side chain
    if '[C@@H]([C@@H](C)O)' in content or '[C@H]([C@H](C)O)' in content or '[C@@H](C)O' in content or '[C@H](C)O' in content:
        return 'Thr', mods

    # Cysteine (Cys) - Thiol side chain
    if '[C@H](CS)' in content or '[C@@H](CS)' in content:
        return 'Cys', mods

    # Methionine (Met) - Methylthioethyl side chain
    if ('C[C@H](CCSC)' in content or 'C[C@@H](CCSC)' in content):
        return 'Met', mods

    # Asparagine (Asn) - Carbamoylmethyl side chain
    if ('CC(=O)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'Asn', mods

    # Glutamine (Gln) - Carbamoylethyl side chain
    if ('CCC(=O)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'Gln', mods

    # Aspartic acid (Asp) - Carboxymethyl side chain
    if ('CC(=O)O' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'Asp', mods

    # Glutamic acid (Glu) - Carboxyethyl side chain
    if ('CCC(=O)O' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'Glu', mods

    # Arginine (Arg) - 3-guanidinopropyl side chain
    if ('CCCNC(=N)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'Arg', mods

    # Histidine (His) - Imidazole side chain
    if ('Cc2cnc[nH]2' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'His', mods

    # No pattern matched: caller treats None as an unidentified segment.
    return None, mods
|
| 615 |
+
|
| 616 |
+
def get_modifications(self, segment):
    """Derive modification labels for a segment from its trailing bond.

    Args:
        segment: dict that may carry a 'bond_after' backbone-bond pattern.

    Returns:
        List of modification strings: 'N-Me' for an N-methylated amide bond,
        'O-linked' for an ester bond; empty when no trailing bond exists.
    """
    bond = segment.get('bond_after')
    if not bond:
        return []
    labels = []
    # N-methylation shows up either anywhere in the bond or as its prefix.
    if 'N(C)' in bond or bond.startswith('C(=O)N(C)'):
        labels.append('N-Me')
    # An ester linkage marks the residue as O-linked.
    if 'OC(=O)' in bond:
        labels.append('O-linked')
    return labels
|
| 625 |
+
|
| 626 |
+
def analyze_structure(self, smiles):
    """Main analysis function with debug output.

    Splits a peptide SMILES into backbone segments, identifies each
    segment's residue, and assembles the three-letter sequence string
    (wrapped in 'cyclo(...)' when the structure is cyclic).

    Args:
        smiles: peptide SMILES string.

    Returns:
        Tuple (three_letter, n_segments): the hyphen-joined residue
        sequence and the number of segments found.
    """
    print("\nAnalyzing structure:", smiles)

    # Split into segments
    segments = self.split_on_bonds(smiles)

    print("\nSegment Analysis:")
    sequence = []
    for i, segment in enumerate(segments):
        print(f"\nSegment {i}:")
        print(f"Content: {segment['content']}")
        print(f"Bond before: {segment.get('bond_before', 'None')}")
        print(f"Bond after: {segment.get('bond_after', 'None')}")

        residue, mods = self.identify_residue(segment)
        if residue:
            # Modifications are appended in parentheses, e.g. 'Ala(N-Me)'.
            if mods:
                sequence.append(f"{residue}({','.join(mods)})")
            else:
                sequence.append(residue)
            print(f"Identified as: {residue}")
            print(f"Modifications: {mods}")
        else:
            # Unidentified segments are reported but skipped, so the
            # returned sequence may be shorter than the segment count.
            print(f"Warning: Could not identify residue in segment: {segment['content']}")

    # Check if cyclic
    is_cyclic, peptide_cycles, aromatic_cycles = self.is_cyclic(smiles)
    three_letter = '-'.join(sequence)
    # One-letter code: residue label with any '(mods)' suffix stripped,
    # mapped through self.three_to_one; unknown codes become 'X'.
    one_letter = ''.join(self.three_to_one.get(aa.split('(')[0], 'X') for aa in sequence)

    if is_cyclic:
        three_letter = f"cyclo({three_letter})"
        one_letter = f"cyclo({one_letter})"

    print(f"\nFinal sequence: {three_letter}")
    print(f"One-letter code: {one_letter}")
    print(f"Is cyclic: {is_cyclic}")
    #print(f"Peptide cycles: {peptide_cycles}")
    #print(f"Aromatic cycles: {aromatic_cycles}")

    return three_letter, len(segments)
    # NOTE(review): the string below is unreachable (after the return) and
    # has no effect; kept byte-identical as it documents an earlier API.
    """return {
        'three_letter': three_letter,
        #'one_letter': one_letter,
        'is_cyclic': is_cyclic
    }"""
|
| 673 |
+
|
| 674 |
+
def return_sequence(self, smiles):
    """Parse a peptide SMILES into an ordered list of residue labels.

    Like analyze_structure, but returns the raw residue list instead of a
    joined sequence string. Unidentified segments are reported and skipped.

    Args:
        smiles: peptide SMILES string.

    Returns:
        List of residue label strings, each optionally suffixed with its
        modifications in parentheses (e.g. 'Ala(N-Me)').
    """
    print("\nAnalyzing structure:", smiles)

    # Break the SMILES into backbone-bond-delimited segments.
    parsed = self.split_on_bonds(smiles)

    print("\nSegment Analysis:")
    residues = []
    for idx, seg in enumerate(parsed):
        print(f"\nSegment {idx}:")
        print(f"Content: {seg['content']}")
        print(f"Bond before: {seg.get('bond_before', 'None')}")
        print(f"Bond after: {seg.get('bond_after', 'None')}")

        name, mods = self.identify_residue(seg)
        if not name:
            print(f"Warning: Could not identify residue in segment: {seg['content']}")
            continue

        label = f"{name}({','.join(mods)})" if mods else name
        residues.append(label)
        print(f"Identified as: {name}")
        print(f"Modifications: {mods}")

    return residues
|
| 701 |
+
|
| 702 |
+
"""
|
| 703 |
+
def annotate_cyclic_structure(mol, sequence):
|
| 704 |
+
'''Create annotated 2D structure with clear, non-overlapping residue labels'''
|
| 705 |
+
# Generate 2D coordinates
|
| 706 |
+
# Generate 2D coordinates
|
| 707 |
+
AllChem.Compute2DCoords(mol)
|
| 708 |
+
|
| 709 |
+
# Create drawer with larger size for annotations
|
| 710 |
+
drawer = Draw.rdMolDraw2D.MolDraw2DCairo(2000, 2000) # Even larger size
|
| 711 |
+
|
| 712 |
+
# Get residue list and reverse it to match structural representation
|
| 713 |
+
if sequence.startswith('cyclo('):
|
| 714 |
+
residues = sequence[6:-1].split('-')
|
| 715 |
+
else:
|
| 716 |
+
residues = sequence.split('-')
|
| 717 |
+
residues = list(reversed(residues)) # Reverse the sequence
|
| 718 |
+
|
| 719 |
+
# Draw molecule first to get its bounds
|
| 720 |
+
drawer.drawOptions().addAtomIndices = False
|
| 721 |
+
drawer.DrawMolecule(mol)
|
| 722 |
+
drawer.FinishDrawing()
|
| 723 |
+
|
| 724 |
+
# Convert to PIL Image
|
| 725 |
+
img = Image.open(BytesIO(drawer.GetDrawingText()))
|
| 726 |
+
draw = ImageDraw.Draw(img)
|
| 727 |
+
|
| 728 |
+
try:
|
| 729 |
+
# Try to use DejaVuSans as it's commonly available on Linux systems
|
| 730 |
+
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
|
| 731 |
+
small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
|
| 732 |
+
except OSError:
|
| 733 |
+
try:
|
| 734 |
+
# Fallback to Arial if available (common on Windows)
|
| 735 |
+
font = ImageFont.truetype("arial.ttf", 60)
|
| 736 |
+
small_font = ImageFont.truetype("arial.ttf", 60)
|
| 737 |
+
except OSError:
|
| 738 |
+
# If no TrueType fonts are available, fall back to default
|
| 739 |
+
print("Warning: TrueType fonts not available, using default font")
|
| 740 |
+
font = ImageFont.load_default()
|
| 741 |
+
small_font = ImageFont.load_default()
|
| 742 |
+
# Get molecule bounds
|
| 743 |
+
conf = mol.GetConformer()
|
| 744 |
+
positions = []
|
| 745 |
+
for i in range(mol.GetNumAtoms()):
|
| 746 |
+
pos = conf.GetAtomPosition(i)
|
| 747 |
+
positions.append((pos.x, pos.y))
|
| 748 |
+
|
| 749 |
+
x_coords = [p[0] for p in positions]
|
| 750 |
+
y_coords = [p[1] for p in positions]
|
| 751 |
+
min_x, max_x = min(x_coords), max(x_coords)
|
| 752 |
+
min_y, max_y = min(y_coords), max(y_coords)
|
| 753 |
+
|
| 754 |
+
# Calculate scaling factors
|
| 755 |
+
scale = 150 # Increased scale factor
|
| 756 |
+
center_x = 1000 # Image center
|
| 757 |
+
center_y = 1000
|
| 758 |
+
|
| 759 |
+
# Add residue labels in a circular arrangement around the structure
|
| 760 |
+
n_residues = len(residues)
|
| 761 |
+
radius = 700 # Distance of labels from center
|
| 762 |
+
|
| 763 |
+
# Start from the rightmost point (3 o'clock position) and go counterclockwise
|
| 764 |
+
# Offset by -3 positions to align with structure
|
| 765 |
+
offset = 0 # Adjust this value to match the structure alignment
|
| 766 |
+
for i, residue in enumerate(residues):
|
| 767 |
+
# Calculate position in a circle around the structure
|
| 768 |
+
# Start from 0 (3 o'clock) and go counterclockwise
|
| 769 |
+
angle = -(2 * np.pi * ((i + offset) % n_residues) / n_residues)
|
| 770 |
+
|
| 771 |
+
# Calculate label position
|
| 772 |
+
label_x = center_x + radius * np.cos(angle)
|
| 773 |
+
label_y = center_y + radius * np.sin(angle)
|
| 774 |
+
|
| 775 |
+
# Draw residue label
|
| 776 |
+
text = f"{i+1}. {residue}"
|
| 777 |
+
bbox = draw.textbbox((label_x, label_y), text, font=font)
|
| 778 |
+
padding = 10
|
| 779 |
+
draw.rectangle([bbox[0]-padding, bbox[1]-padding,
|
| 780 |
+
bbox[2]+padding, bbox[3]+padding],
|
| 781 |
+
fill='white', outline='white')
|
| 782 |
+
draw.text((label_x, label_y), text,
|
| 783 |
+
font=font, fill='black', anchor="mm")
|
| 784 |
+
|
| 785 |
+
# Add sequence at the top with white background
|
| 786 |
+
seq_text = f"Sequence: {sequence}"
|
| 787 |
+
bbox = draw.textbbox((center_x, 100), seq_text, font=small_font)
|
| 788 |
+
padding = 10
|
| 789 |
+
draw.rectangle([bbox[0]-padding, bbox[1]-padding,
|
| 790 |
+
bbox[2]+padding, bbox[3]+padding],
|
| 791 |
+
fill='white', outline='white')
|
| 792 |
+
draw.text((center_x, 100), seq_text,
|
| 793 |
+
font=small_font, fill='black', anchor="mm")
|
| 794 |
+
|
| 795 |
+
return img
|
| 796 |
+
|
| 797 |
+
"""
|
| 798 |
+
def annotate_cyclic_structure(mol, sequence):
    """Render a 2D depiction of *mol* with the sequence as a header banner.

    Args:
        mol: RDKit molecule to draw.
        sequence: sequence string shown at the top of the image.

    Returns:
        PIL Image of the 2000x2000 rendering with the header overlaid.
    """
    # Lay the molecule out in 2D before drawing.
    AllChem.Compute2DCoords(mol)

    # Render at a generous size so the header text has room.
    canvas = Draw.rdMolDraw2D.MolDraw2DCairo(2000, 2000)
    canvas.drawOptions().addAtomIndices = False
    canvas.DrawMolecule(mol)
    canvas.FinishDrawing()

    # Move the raster into PIL so text can be overlaid.
    image = Image.open(BytesIO(canvas.GetDrawingText()))
    overlay = ImageDraw.Draw(image)

    # Prefer DejaVuSans (Linux), then Arial (Windows), else PIL's default.
    header_font = None
    for candidate in ("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
                      "arial.ttf"):
        try:
            header_font = ImageFont.truetype(candidate, 60)
            break
        except OSError:
            continue
    if header_font is None:
        print("Warning: TrueType fonts not available, using default font")
        header_font = ImageFont.load_default()

    # Paint a white box behind the header so it stays legible, then the text.
    seq_text = f"Sequence: {sequence}"
    box = overlay.textbbox((1000, 100), seq_text, font=header_font)
    pad = 10
    overlay.rectangle([box[0] - pad, box[1] - pad,
                       box[2] + pad, box[3] + pad],
                      fill='white', outline='white')
    overlay.text((1000, 100), seq_text,
                 font=header_font, fill='black', anchor="mm")

    return image
|
| 834 |
+
|
| 835 |
+
def create_enhanced_linear_viz(sequence, smiles):
    """Create an enhanced linear representation using PeptideAnalyzer.

    Builds a two-panel matplotlib figure: the top panel draws the residues as
    boxes joined by bond lines (ester bonds dashed red, peptide bonds solid
    black), the bottom panel lists a per-segment textual breakdown.

    Args:
        sequence: three-letter sequence string, optionally wrapped in
            ``cyclo(...)``.
        smiles: the peptide SMILES that ``sequence`` was derived from.

    Returns:
        The matplotlib Figure (caller is responsible for closing it).
    """
    analyzer = PeptideAnalyzer()  # Create analyzer instance

    # Create figure with two subplots (top overview 1/3, bottom detail 2/3).
    fig = plt.figure(figsize=(15, 10))
    gs = fig.add_gridspec(2, 1, height_ratios=[1, 2])
    ax_struct = fig.add_subplot(gs[0])
    ax_detail = fig.add_subplot(gs[1])

    # Parse sequence and get residues; strip the "cyclo(" prefix and ")" suffix.
    if sequence.startswith('cyclo('):
        residues = sequence[6:-1].split('-')
    else:
        residues = sequence.split('-')

    # Get segments using analyzer.
    # NOTE(review): segments are assumed to align index-for-index with
    # residues below — verify against PeptideAnalyzer.split_on_bonds.
    segments = analyzer.split_on_bonds(smiles)

    # Debug print
    print(f"Number of residues: {len(residues)}")
    print(f"Number of segments: {len(segments)}")

    # Top subplot - Basic structure
    ax_struct.set_xlim(0, 10)
    ax_struct.set_ylim(0, 2)

    num_residues = len(residues)
    # Spread residues across 9 data-units; guard the single-residue case
    # against division by zero.
    spacing = 9.0 / (num_residues - 1) if num_residues > 1 else 9.0

    # Draw basic structure
    y_pos = 1.5
    for i in range(num_residues):
        x_pos = 0.5 + i * spacing

        # Draw amino acid box
        rect = patches.Rectangle((x_pos-0.3, y_pos-0.2), 0.6, 0.4,
                               facecolor='lightblue', edgecolor='black')
        ax_struct.add_patch(rect)

        # Draw connecting bonds if not the last residue
        if i < num_residues - 1:
            # Guard against fewer segments than residues.
            segment = segments[i] if i < len(segments) else None
            if segment:
                # Determine bond type from segment info ('O-linked' marks an
                # ester; 'N-Me' marks an N-methylated peptide bond).
                bond_type = 'ester' if 'O-linked' in segment.get('bond_after', '') else 'peptide'
                is_n_methylated = 'N-Me' in segment.get('bond_after', '')

                bond_color = 'red' if bond_type == 'ester' else 'black'
                linestyle = '--' if bond_type == 'ester' else '-'

                # Draw bond line between the edges of adjacent boxes.
                ax_struct.plot([x_pos+0.3, x_pos+spacing-0.3], [y_pos, y_pos],
                             color=bond_color, linestyle=linestyle, linewidth=2)

                # Add bond type label above the midpoint of the bond line.
                mid_x = x_pos + spacing/2
                bond_label = f"{bond_type}"
                if is_n_methylated:
                    bond_label += "\n(N-Me)"
                ax_struct.text(mid_x, y_pos+0.1, bond_label,
                             ha='center', va='bottom', fontsize=10,
                             color=bond_color)

        # Add residue label below its box.
        ax_struct.text(x_pos, y_pos-0.5, residues[i],
                      ha='center', va='top', fontsize=14)

    # Bottom subplot - Detailed breakdown
    ax_detail.set_ylim(0, len(segments)+1)
    ax_detail.set_xlim(0, 1)

    # Create detailed breakdown, one text row per segment, top to bottom.
    segment_y = len(segments)  # Start from top
    for i, segment in enumerate(segments):
        y = segment_y - i

        # Check if this is a bond or residue: identify_residue returning a
        # residue name means it is a residue segment; otherwise treat it as
        # a bond segment.
        residue, mods = analyzer.identify_residue(segment)
        if residue:
            text = f"Residue {i+1}: {residue}"
            if mods:
                text += f" ({', '.join(mods)})"
            color = 'blue'
        else:
            # Must be a bond
            text = f"Bond {i}: "
            if 'O-linked' in segment.get('bond_after', ''):
                text += "ester"
            elif 'N-Me' in segment.get('bond_after', ''):
                text += "peptide (N-methylated)"
            else:
                text += "peptide"
            color = 'red'

        # Add segment analysis row: classification on the left, raw SMILES
        # fragment on the right.
        ax_detail.text(0.05, y, text, fontsize=12, color=color)
        ax_detail.text(0.5, y, f"SMILES: {segment.get('content', '')}", fontsize=10, color='gray')

    # If cyclic, add connection indicator spanning the full residue row.
    if sequence.startswith('cyclo('):
        ax_struct.annotate('', xy=(9.5, y_pos), xytext=(0.5, y_pos),
                          arrowprops=dict(arrowstyle='<->', color='red', lw=2))
        ax_struct.text(5, y_pos+0.3, 'Cyclic Connection',
                      ha='center', color='red', fontsize=14)

    # Add titles and adjust layout
    ax_struct.set_title("Peptide Structure Overview", pad=20)
    ax_detail.set_title("Segment Analysis Breakdown", pad=20)

    # Remove axes — these panels are pure annotation, not plots.
    for ax in [ax_struct, ax_detail]:
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')

    plt.tight_layout()
    return fig
|
| 953 |
+
|
| 954 |
+
class PeptideStructureGenerator:
    """A class to generate 3D structures of peptides using different embedding methods.

    Two pipelines are offered: plain ETKDG embedding
    (:meth:`generate_structure_etkdg`) and ETKDG followed by UFF force-field
    optimization (:meth:`generate_structure_uff`). Both retry embedding with
    progressively relaxed parameters.
    """

    @staticmethod
    def prepare_molecule(smiles):
        """Prepare molecule with proper hydrogen handling.

        Parses without immediate sanitization so partially problematic
        peptide SMILES still load, then runs a reduced sanitization pass
        and adds explicit hydrogens.

        Raises:
            ValueError: if the SMILES cannot be parsed at all.
        """
        mol = Chem.MolFromSmiles(smiles, sanitize=False)
        if mol is None:
            raise ValueError("Failed to create molecule from SMILES")

        # Calculate valence for each atom (strict=False tolerates unusual
        # valences instead of raising).
        for atom in mol.GetAtoms():
            atom.UpdatePropertyCache(strict=False)

        # Sanitize with reduced requirements — note SANITIZE_PROPERTIES is
        # deliberately omitted from this op mask.
        Chem.SanitizeMol(mol,
                        sanitizeOps=Chem.SANITIZE_FINDRADICALS|
                        Chem.SANITIZE_KEKULIZE|
                        Chem.SANITIZE_SETAROMATICITY|
                        Chem.SANITIZE_SETCONJUGATION|
                        Chem.SANITIZE_SETHYBRIDIZATION|
                        Chem.SANITIZE_CLEANUPCHIRALITY)

        mol = Chem.AddHs(mol)
        return mol

    @staticmethod
    def get_etkdg_params(attempt=0):
        """Get ETKDG parameters with optional modifications based on attempt number.

        Later attempts (``attempt > 10``) progressively stretch the target
        bond length and disable experimental torsion-angle preferences to
        help stubborn macrocycles embed.
        """
        params = AllChem.ETKDGv3()
        params.randomSeed = -1  # fresh random seed each call
        params.maxIterations = 200
        params.numThreads = 4  # Reduced for web interface
        params.useBasicKnowledge = True
        params.enforceChirality = True
        params.useExpTorsionAnglePrefs = True
        params.useSmallRingTorsions = True
        params.useMacrocycleTorsions = True  # important for cyclic peptides
        params.ETversion = 2
        params.pruneRmsThresh = -1  # keep all conformers (no RMS pruning)
        params.embedRmsThresh = 0.5

        if attempt > 10:
            # Relax geometry constraints once the first 10 attempts failed.
            params.bondLength = 1.5 + (attempt - 10) * 0.02
            params.useExpTorsionAnglePrefs = False

        return params

    def generate_structure_etkdg(self, smiles, max_attempts=20):
        """Generate 3D structure using ETKDG without UFF optimization.

        Retries embedding up to ``max_attempts`` times; returns the first
        molecule that embeds successfully.

        Raises:
            ValueError: if no attempt produces an embedding.
        """
        success = False
        mol = None

        for attempt in range(max_attempts):
            try:
                mol = self.prepare_molecule(smiles)
                params = self.get_etkdg_params(attempt)

                # EmbedMolecule returns 0 on success.
                if AllChem.EmbedMolecule(mol, params) == 0:
                    success = True
                    break
            except Exception as e:
                # Best-effort retry loop: any per-attempt failure just moves
                # on to the next attempt with adjusted parameters.
                continue

        if not success:
            raise ValueError("Failed to generate structure with ETKDG")

        return mol

    def generate_structure_uff(self, smiles, max_attempts=20):
        """Generate 3D structure using ETKDG followed by UFF optimization.

        Keeps the lowest-UFF-energy conformer seen across all attempts.

        Raises:
            ValueError: if no attempt yields a converged, optimized structure.
        """
        best_mol = None
        lowest_energy = float('inf')

        for attempt in range(max_attempts):
            try:
                test_mol = self.prepare_molecule(smiles)
                params = self.get_etkdg_params(attempt)

                if AllChem.EmbedMolecule(test_mol, params) == 0:
                    # UFFOptimizeMolecule returns 0 when the optimization
                    # converged.
                    res = AllChem.UFFOptimizeMolecule(test_mol, maxIters=2000,
                                                    vdwThresh=10.0, confId=0,
                                                    ignoreInterfragInteractions=True)

                    if res == 0:
                        ff = AllChem.UFFGetMoleculeForceField(test_mol)
                        if ff:
                            current_energy = ff.CalcEnergy()
                            if current_energy < lowest_energy:
                                lowest_energy = current_energy
                                # Copy so later attempts cannot mutate the best.
                                best_mol = Chem.Mol(test_mol)
            except Exception:
                continue

        if best_mol is None:
            raise ValueError("Failed to generate optimized structure")

        return best_mol

    @staticmethod
    def mol_to_sdf_bytes(mol):
        """Convert RDKit molecule to SDF file bytes."""
        # First write to StringIO in text mode
        sio = StringIO()
        writer = Chem.SDWriter(sio)
        writer.write(mol)
        writer.close()

        # Convert the string to bytes
        return sio.getvalue().encode('utf-8')
|
| 1064 |
+
|
| 1065 |
+
def process_input(smiles_input=None, file_obj=None, show_linear=False,
                  show_segment_details=False, generate_3d=False, use_uff=False):
    """Process input and create visualizations using PeptideAnalyzer.

    Args:
        smiles_input: raw SMILES string for a single peptide (takes priority).
        file_obj: optional uploaded file (one SMILES per line), used only when
            ``smiles_input`` is falsy.
        show_linear: also build the enhanced linear visualization.
        show_segment_details: include per-segment analysis in the text output.
        generate_3d: write an ETKDG SDF (and optionally a UFF one) to a temp dir.
        use_uff: additionally produce a UFF-optimized structure (only with
            ``generate_3d``).

    Returns:
        ``(text, cyclic_image, linear_image, structure_file_paths)``.
        BUG FIX: every path now returns a 4-tuple — the original returned
        3-tuples on several error/file paths while the success path returned
        4 values, breaking callers that unpack four outputs.
    """
    analyzer = PeptideAnalyzer()
    # NOTE(review): the temp dir is never cleaned up here — the SDF paths are
    # handed back to the caller (e.g. a download widget); confirm cleanup policy.
    temp_dir = tempfile.mkdtemp() if generate_3d else None
    structure_files = []

    # Handle direct SMILES input
    if smiles_input:
        smiles = smiles_input.strip()

        # First check if it's a peptide using analyzer's method
        if not analyzer.is_peptide(smiles):
            return "Error: Input SMILES does not appear to be a peptide structure.", None, None, None

        try:
            # Create molecule
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                return "Error: Invalid SMILES notation.", None, None, None

            # Generate 3D structures if requested
            if generate_3d:
                generator = PeptideStructureGenerator()

                try:
                    # Generate ETKDG structure
                    mol_etkdg = generator.generate_structure_etkdg(smiles)
                    etkdg_path = os.path.join(temp_dir, "structure_etkdg.sdf")
                    writer = Chem.SDWriter(etkdg_path)
                    writer.write(mol_etkdg)
                    writer.close()
                    structure_files.append(etkdg_path)

                    # Generate UFF structure if requested
                    if use_uff:
                        mol_uff = generator.generate_structure_uff(smiles)
                        uff_path = os.path.join(temp_dir, "structure_uff.sdf")
                        writer = Chem.SDWriter(uff_path)
                        writer.write(mol_uff)
                        writer.close()
                        structure_files.append(uff_path)

                except Exception as e:
                    return f"Error generating 3D structures: {str(e)}", None, None, None

            # Use analyzer to get sequence
            segments = analyzer.split_on_bonds(smiles)

            # Process segments and build sequence
            sequence_parts = []
            output_text = ""

            # Only include segment analysis in output if requested
            if show_segment_details:
                output_text += "Segment Analysis:\n"
                for i, segment in enumerate(segments):
                    output_text += f"\nSegment {i}:\n"
                    output_text += f"Content: {segment['content']}\n"
                    output_text += f"Bond before: {segment.get('bond_before', 'None')}\n"
                    output_text += f"Bond after: {segment.get('bond_after', 'None')}\n"

                    residue, mods = analyzer.identify_residue(segment)
                    if residue:
                        if mods:
                            sequence_parts.append(f"{residue}({','.join(mods)})")
                        else:
                            sequence_parts.append(residue)
                        output_text += f"Identified as: {residue}\n"
                        output_text += f"Modifications: {mods}\n"
                    else:
                        output_text += f"Warning: Could not identify residue in segment: {segment['content']}\n"
                    output_text += "\n"
            else:
                # Just build sequence without detailed analysis in output
                for segment in segments:
                    residue, mods = analyzer.identify_residue(segment)
                    if residue:
                        if mods:
                            sequence_parts.append(f"{residue}({','.join(mods)})")
                        else:
                            sequence_parts.append(residue)

            # Check if cyclic using analyzer's method
            is_cyclic, peptide_cycles, aromatic_cycles = analyzer.is_cyclic(smiles)
            three_letter = '-'.join(sequence_parts)
            # Unknown residues map to 'X' in the one-letter code.
            one_letter = ''.join(analyzer.three_to_one.get(aa.split('(')[0], 'X') for aa in sequence_parts)

            if is_cyclic:
                three_letter = f"cyclo({three_letter})"
                one_letter = f"cyclo({one_letter})"

            # Create cyclic structure visualization
            img_cyclic = annotate_cyclic_structure(mol, three_letter)

            # Create linear representation if requested
            img_linear = None
            if show_linear:
                fig_linear = create_enhanced_linear_viz(three_letter, smiles)
                # Round-trip through a PNG buffer so a PIL image is returned
                # instead of a matplotlib figure.
                buf = BytesIO()
                fig_linear.savefig(buf, format='png', bbox_inches='tight', dpi=300)
                buf.seek(0)
                img_linear = Image.open(buf)
                plt.close(fig_linear)

            # Add summary to output
            summary = "Summary:\n"
            summary += f"Sequence: {three_letter}\n"
            summary += f"One-letter code: {one_letter}\n"
            summary += f"Is Cyclic: {'Yes' if is_cyclic else 'No'}\n"
            #if is_cyclic:
                #summary += f"Peptide Cycles: {', '.join(peptide_cycles)}\n"
                #summary += f"Aromatic Cycles: {', '.join(aromatic_cycles)}\n"

            if structure_files:
                summary += "\n3D Structures Generated:\n"
                for filepath in structure_files:
                    summary += f"- {os.path.basename(filepath)}\n"

            return summary + output_text, img_cyclic, img_linear, structure_files if structure_files else None

        except Exception as e:
            return f"Error processing SMILES: {str(e)}", None, None, None

    # Handle file input
    if file_obj is not None:
        try:
            # Handle file content — either a file-like object with a path, or
            # raw bytes/str content.
            if hasattr(file_obj, 'name'):
                with open(file_obj.name, 'r') as f:
                    content = f.read()
            else:
                content = file_obj.decode('utf-8') if isinstance(file_obj, bytes) else str(file_obj)

            output_text = ""
            for line in content.splitlines():
                smiles = line.strip()
                if smiles:
                    # Check if it's a peptide
                    if not analyzer.is_peptide(smiles):
                        output_text += f"Skipping non-peptide SMILES: {smiles}\n"
                        continue

                    # Process this SMILES
                    segments = analyzer.split_on_bonds(smiles)
                    sequence_parts = []

                    # Add segment details if requested
                    if show_segment_details:
                        output_text += f"\nSegment Analysis for SMILES: {smiles}\n"
                        for i, segment in enumerate(segments):
                            output_text += f"\nSegment {i}:\n"
                            output_text += f"Content: {segment['content']}\n"
                            output_text += f"Bond before: {segment.get('bond_before', 'None')}\n"
                            output_text += f"Bond after: {segment.get('bond_after', 'None')}\n"
                            residue, mods = analyzer.identify_residue(segment)
                            if residue:
                                if mods:
                                    sequence_parts.append(f"{residue}({','.join(mods)})")
                                else:
                                    sequence_parts.append(residue)
                                output_text += f"Identified as: {residue}\n"
                                output_text += f"Modifications: {mods}\n"
                    else:
                        for segment in segments:
                            residue, mods = analyzer.identify_residue(segment)
                            if residue:
                                if mods:
                                    sequence_parts.append(f"{residue}({','.join(mods)})")
                                else:
                                    sequence_parts.append(residue)

                    # Get cyclicity and create sequence
                    is_cyclic, peptide_cycles, aromatic_cycles = analyzer.is_cyclic(smiles)
                    sequence = f"cyclo({'-'.join(sequence_parts)})" if is_cyclic else '-'.join(sequence_parts)

                    output_text += f"\nSummary for SMILES: {smiles}\n"
                    output_text += f"Sequence: {sequence}\n"
                    output_text += f"Is Cyclic: {'Yes' if is_cyclic else 'No'}\n"
                    if is_cyclic:
                        output_text += f"Peptide Cycles: {', '.join(peptide_cycles)}\n"
                        #output_text += f"Aromatic Cycles: {', '.join(aromatic_cycles)}\n"
                    output_text += "-" * 50 + "\n"

            return output_text, None, None, None

        except Exception as e:
            return f"Error processing file: {str(e)}", None, None, None

    return "No input provided.", None, None, None
|
| 1255 |
+
|
src/utils/generate_utils.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import math
|
| 3 |
+
import sys
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def mask_for_de_novo(config, sequence_length):
    """Build a fully-masked template of ``sequence_length`` positions.

    HELM vocabularies get one concatenated string of ``[MASK]`` tokens;
    SMILES/SELFIES vocabularies get a list of ``<mask>`` tokens; anything
    else gets a list of ``[MASK]`` tokens.
    """
    vocab = config.vocab
    if vocab == 'helm':
        # HELM tokenizers consume a single string, not a token list.
        return "[MASK]" * sequence_length
    if vocab in ('new_smiles', 'selfies'):
        return ["<mask>"] * sequence_length
    return ["[MASK]"] * sequence_length
|
| 15 |
+
|
| 16 |
+
def generate_de_novo(sequence_length, tokenizer, model, config=None):
    """Sample a de novo sequence from a masked language model.

    Feeds a fully-masked template through the model and, at each mask
    position, samples one of the top-3 predicted tokens.

    Args:
        sequence_length: number of masked positions to fill.
        tokenizer: HF-style tokenizer providing ``mask_token_id`` and ``decode``.
        model: masked-LM with a ``.logits`` output and ``.device``.
        config: optional config with a ``.vocab`` attribute; when given, the
            mask template is built via ``mask_for_de_novo`` (new, backward-
            compatible parameter — the original called
            ``mask_for_de_novo(sequence_length)`` without the required config
            argument, which always raised TypeError).

    Returns:
        ``(generated_sequence, perplexity)``.
    """
    if config is not None:
        masked_sequence = mask_for_de_novo(config, sequence_length)
    else:
        masked_sequence = "[MASK]" * sequence_length
    inputs = tokenizer(masked_sequence, return_tensors='pt').to(model.device)

    with torch.no_grad():
        logits = model(**inputs).logits
        mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
        logits_at_masks = logits[0, mask_token_indices]

    pred_tokens = []
    # BUG FIX: logits_at_masks has one row per mask, so iterate row numbers —
    # the original indexed it with sequence positions (mask_token_indices),
    # which are offset by the leading special token and run out of range.
    for row in range(logits_at_masks.shape[0]):
        topk_logits, topk_indices = logits_at_masks[row].topk(k=3, dim=-1)
        probabilities = torch.nn.functional.softmax(topk_logits, dim=-1)
        predicted_index = torch.distributions.categorical.Categorical(probabilities).sample()
        predicted_token_id = topk_indices[predicted_index].item()
        predicted_token = tokenizer.decode([predicted_token_id], skip_special_tokens=True)
        pred_tokens.append(predicted_token)

    generated_sequence = ''.join(pred_tokens)
    # BUG FIX: calculate_perplexity requires the mask positions as its 4th
    # argument; the original omitted them.
    perplexity = calculate_perplexity(model, tokenizer, generated_sequence, mask_token_indices)

    return (generated_sequence, perplexity)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def calculate_perplexity(model, tokenizer, generated_sequence, mask_token_indices):
    """Pseudo-perplexity of ``generated_sequence`` under a masked LM.

    Each listed position is masked one at a time; the model's loss on
    recovering the true token is averaged and exponentiated. Returns the
    sentinel ``10000`` when ``mask_token_indices`` is empty.
    """
    token_ids = tokenizer.encode(generated_sequence, return_tensors='pt').to(model.device)

    n_masked = len(mask_token_indices)
    if n_masked == 0:
        # No positions to score — return the sentinel used by callers.
        return 10000

    loss_sum = 0.0
    for pos in mask_token_indices:
        # Mask exactly one position of the encoded sequence.
        probe = token_ids.clone()
        probe[0, pos] = tokenizer.mask_token_id

        # -100 labels are ignored by the loss; only the masked slot counts.
        targets = torch.full(token_ids.shape, -100).to(model.device)
        targets[0, pos] = token_ids[0, pos]

        with torch.no_grad():
            outputs = model(probe, labels=targets)
        loss_sum += outputs.loss.item()

    return math.exp(loss_sum / n_masked)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def calculate_cosine_sim(original_sequence, generated_sequence, tokenizer, pepclm_model, device):
    """Mean per-position cosine similarity between two sequences' encoder embeddings.

    NOTE(review): ``tokenizer`` and ``device`` are currently unused; they are
    kept to preserve the call signature used elsewhere.
    """
    encoder = pepclm_model.roformer.encoder
    reference_embeddings = encoder(original_sequence)
    candidate_embeddings = encoder(generated_sequence)

    per_position = torch.nn.functional.cosine_similarity(
        reference_embeddings, candidate_embeddings, dim=-1)
    return per_position.mean().item()
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def calculate_hamming_dist(original_sequence, generated_sequence):
    """Count positions where the two sequences differ (Hamming distance).

    Iterates over the length of ``original_sequence``; as before, a shorter
    ``generated_sequence`` raises IndexError. The original's no-op
    self-assignments have been removed.
    """
    return sum(
        original_sequence[i] != generated_sequence[i]
        for i in range(len(original_sequence))
    )
|
src/utils/utils.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Console logger utilities.
|
| 2 |
+
|
| 3 |
+
Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py
|
| 4 |
+
Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import math
|
| 9 |
+
|
| 10 |
+
import fsspec
|
| 11 |
+
import lightning
|
| 12 |
+
import torch
|
| 13 |
+
from timm.scheduler import CosineLRScheduler
|
| 14 |
+
from multiprocessing import Pool
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def fsspec_exists(filename):
    """Check if a file exists using fsspec."""
    filesystem = fsspec.core.url_to_fs(filename)[0]
    return filesystem.exists(filename)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def fsspec_listdir(dirname):
    """Listdir in manner compatible with fsspec."""
    filesystem = fsspec.core.url_to_fs(dirname)[0]
    return filesystem.ls(dirname)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def fsspec_mkdirs(dirname, exist_ok=True):
    """Mkdirs in manner compatible with fsspec."""
    filesystem = fsspec.core.url_to_fs(dirname)[0]
    filesystem.makedirs(dirname, exist_ok=exist_ok)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def print_nans(tensor, name):
    """Print *name* and *tensor* if the tensor contains any NaN entries."""
    contains_nan = bool(torch.isnan(tensor).any())
    if contains_nan:
        print(name, tensor)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class CosineDecayWarmupLRScheduler(
    CosineLRScheduler,
    torch.optim.lr_scheduler._LRScheduler):
    """Wrap timm.scheduler.CosineLRScheduler
    Enables calling scheduler.step() without passing in epoch.
    Supports resuming as well.
    Adapted from:
      https://github.com/HazyResearch/hyena-dna/blob/main/src/utils/optim/schedulers.py
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Own step counter; start at -1 so the step(epoch=0) call below
        # initializes the learning rate for the very first epoch/update.
        self._last_epoch = -1
        self.step(epoch=0)

    def step(self, epoch=None):
        # Track the epoch/update index ourselves so callers may invoke
        # step() with no argument (as Lightning does).
        if epoch is None:
            self._last_epoch += 1
        else:
            self._last_epoch = epoch
        # We call either step or step_update, depending on
        # whether we're using the scheduler every epoch or every
        # step.
        # Otherwise, lightning will always call step (i.e.,
        # meant for each epoch), and if we set scheduler
        # interval to "step", then the learning rate update will
        # be wrong.
        if self.t_in_epochs:
            super().step(epoch=self._last_epoch)
        else:
            super().step_update(num_updates=self._last_epoch)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class LoggingContext:
    """Context manager for selective logging.

    Temporarily overrides a logger's level and/or attaches a handler; on
    exit the previous level is restored, the handler is detached, and
    (when ``close`` is true) the handler is closed.
    """

    def __init__(self, logger, level=None, handler=None, close=True):
        self.logger = logger
        self.level = level
        self.handler = handler
        self.close = close

    def __enter__(self):
        if self.level is not None:
            # Remember the current level so __exit__ can restore it.
            self.old_level = self.logger.level
            self.logger.setLevel(self.level)
        if self.handler:
            self.logger.addHandler(self.handler)

    def __exit__(self, et, ev, tb):
        if self.level is not None:
            self.logger.setLevel(self.old_level)
        if self.handler:
            self.logger.removeHandler(self.handler)
            if self.close:
                self.handler.close()
        # Exceptions are never suppressed (implicit None return).
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Initializes multi-GPU-friendly python logger."""

    logger = logging.getLogger(name)
    logger.setLevel(level)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    # (loop variable renamed so it no longer shadows the `level` parameter)
    for method_name in ('debug', 'info', 'warning', 'error',
                        'exception', 'fatal', 'critical'):
        wrapped = lightning.pytorch.utilities.rank_zero_only(
            getattr(logger, method_name))
        setattr(logger, method_name, wrapped)

    return logger
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class Sampler:
    """Base class for straight-through samplers over a fixed logits shape.

    Subclasses supply the noise distribution (``_sampling_noise``), the
    discrete forward value (``_hard_sample``) and the differentiable
    surrogate (``_soft_sample``).
    """

    def __init__(self, shape):
        self.shape = shape

    def _sampling_noise(self):
        # To be provided by subclasses.
        pass

    def _hard_sample(self, logits):
        # To be provided by subclasses.
        pass

    def _soft_sample(self, logits):
        return 0

    def sample(self, logits):
        perturbation = self._sampling_noise()
        # Trim the pre-sized noise to the actual batch size.
        perturbation = perturbation[: logits.shape[0], :]
        noisy_logits = logits + perturbation.to(
            dtype=logits.dtype, device=logits.device)
        hard = self._hard_sample(noisy_logits)
        soft = self._soft_sample(noisy_logits)
        # Straight-through estimator: forward value is `hard`, gradients
        # flow through `soft`.
        return soft + (hard - soft).detach()
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class TopKSampler(Sampler):
    """Relaxed top-k sampler using averaged sums of Gamma(1/k, 1) noise."""

    def __init__(self, k, shape, gamma_tau=1.0):
        super().__init__(shape)
        self.k = k
        self.gamma_tau = gamma_tau
        self.num_betas = 10
        # Pre-built Gamma(1/k, 1) source of shape (num_betas, *shape).
        self.sampler = torch.distributions.gamma.Gamma(
            1 / k * torch.ones(self.num_betas, *self.shape), 1.0)

    def _sampling_noise(self):
        gamma_draws = self.sampler.sample()
        # Scale the i-th draw by beta_i = k / i before summing.
        beta = self.k / torch.arange(1, self.num_betas + 1, 1,
                                     dtype=torch.float32)
        beta = beta[:, None, None]
        assert beta.ndim == gamma_draws.ndim
        summed = (gamma_draws / beta).sum(dim=0)
        summed = summed - math.log(10.0)
        # Temperature-scaled, normalized by k.
        return self.gamma_tau * (summed / self.k)

    def _hard_sample(self, logits):
        assert logits.ndim == 2
        # k-th largest value per row is the inclusion threshold.
        sorted_vals, _ = torch.sort(logits, dim=-1)
        kth_value = sorted_vals[:, - self.k][:, None]
        return (logits >= kth_value).type(logits.dtype)

    def _soft_sample(self, logits):
        # Center each row and project onto the unit sphere.
        centered = logits - logits.mean(dim=-1, keepdim=True)
        return centered / centered.norm(dim=-1, keepdim=True)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class DeterministicTopK(TopKSampler):
    """Top-k sampler with the noise switched off: discretization is deterministic."""

    def __init__(self, k):
        super().__init__(k, shape=(1, 1))

    def _sampling_noise(self):
        # No stochastic perturbation.
        return 0

    def discreize(self, x):
        # NOTE: keeps the original (misspelled) public method name.
        hard = self._hard_sample(x)
        soft = self._soft_sample(x)
        # Straight-through: value equals the hard top-k mask, gradient
        # follows the soft projection.
        return soft + (hard - soft).detach()
|
| 183 |
+
|
| 184 |
+
class GumbelSampler(Sampler):
    """Gumbel-softmax straight-through sampler over 2-D (batch, vocab) logits."""

    def __init__(self, shape, temperature=1.0):
        super().__init__(shape)
        self.temperature = temperature

    def _sampling_noise(self):
        # Standard Gumbel(0, 1) noise: -log(-log(U)), with 1e-10 stabilizers.
        return - (1e-10 - (
            torch.rand(* self.shape) + 1e-10).log()).log()

    def _hard_sample(self, logits):
        # BUG FIX: the original asserted ndim == 2 but then indexed with 3-D
        # patterns (indices[:, :, None], logits[:, :, :1]), raising
        # IndexError for every input that passed the assert. Build the
        # one-hot argmax mask with 2-D indexing instead.
        assert logits.ndim == 2
        indices = torch.argmax(logits, dim=-1)
        zeros = logits * 0
        ones = torch.ones_like(logits[:, :1])
        return torch.scatter(zeros, -1, indices[:, None], ones)

    def _soft_sample(self, logits):
        # Differentiable surrogate: temperature-scaled softmax.
        return torch.nn.functional.softmax(
            logits / self.temperature, dim=-1)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class BinarySampler(GumbelSampler):
    """Straight-through Bernoulli sampler driven by a pair of Gumbel draws."""

    def sample(self, probs):
        # TODO(subhamsahoo): use the temperature parameter.
        positive = self._sampling_noise().to(
            dtype=probs.dtype, device=probs.device)
        negative = self._sampling_noise().to(
            dtype=probs.dtype, device=probs.device)
        # exp of the noise gap decides which side of the Bernoulli wins.
        noise_gap_exp = (negative - positive).exp()
        hard = (probs * (1 + noise_gap_exp) > 1).to(probs.dtype)
        soft = probs / (probs + (1 - probs) * noise_gap_exp)
        # Straight-through: hard forward value, soft gradient path.
        return soft + (hard - soft).detach()
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class GaussianSampler:
    """Reparameterized Gaussian sampler; input packs [mean | raw variance]."""

    def __init__(self):
        self.softplus = torch.nn.Softplus()

    def sample(self, x):
        assert x.ndim == 2
        # First half of the feature dim is the mean, second half the
        # (softplus-positivized) variance.
        half = x.shape[-1] // 2
        mean = x[:, :half]
        std = self.softplus(x[:, half:]).sqrt()
        # Reparameterization trick: mean + std * eps, eps ~ N(0, I).
        return mean + std * torch.randn_like(mean)
|
| 232 |
+
|
| 233 |
+
def mapper(n_jobs):
    '''
    Returns function for map call.
    If n_jobs == 1, will use standard map
    If n_jobs > 1, will use multiprocessing pool
    If n_jobs is a pool object, will return its map function
    '''
    if n_jobs == 1:
        # Serial path: plain built-in map, materialized to a list.
        def _serial_map(*args, **kwargs):
            return list(map(*args, **kwargs))

        return _serial_map
    if isinstance(n_jobs, int):
        pool = Pool(n_jobs)

        # Parallel path: the pool is torn down after a single map call.
        def _pool_map(*args, **kwargs):
            try:
                return pool.map(*args, **kwargs)
            finally:
                pool.terminate()

        return _pool_map
    # Assume a pool-like object was passed in directly.
    return n_jobs.map
|