Sophia Tang committed on
Commit
40e7e76
·
0 Parent(s):

initial commit

Browse files
Files changed (46) hide show
  1. .gitattributes +2 -0
  2. .gitignore +1 -0
  3. README.md +193 -0
  4. assets/mcts.png +3 -0
  5. assets/mdlm.png +3 -0
  6. assets/peptune.png +3 -0
  7. assets/poster.png +3 -0
  8. data/dataloading_for_dynamic_batching.py +156 -0
  9. data/dataset.py +207 -0
  10. scripts/generate_mcts.sh +57 -0
  11. scripts/generate_unconditional.sh +16 -0
  12. scripts/train.sh +18 -0
  13. src/config.py +319 -0
  14. src/config.yaml +164 -0
  15. src/diffusion.py +1015 -0
  16. src/environment.yml +40 -0
  17. src/generate_mcts.py +365 -0
  18. src/generate_unconditional.py +111 -0
  19. src/metrics.py +72 -0
  20. src/noise_schedule.py +152 -0
  21. src/pareto_mcts.py +492 -0
  22. src/roformer.py +74 -0
  23. src/scoring/functions/binding.py +178 -0
  24. src/scoring/functions/binding_utils.py +290 -0
  25. src/scoring/functions/classifiers/hemolysis-xgboost.json +0 -0
  26. src/scoring/functions/classifiers/nonfouling-xgboost.json +0 -0
  27. src/scoring/functions/classifiers/permeability-xgboost.json +3 -0
  28. src/scoring/functions/classifiers/solubility-xgboost.json +0 -0
  29. src/scoring/functions/hemolysis.py +63 -0
  30. src/scoring/functions/nonfouling.py +66 -0
  31. src/scoring/functions/permeability.py +171 -0
  32. src/scoring/functions/scoring_utils.py +94 -0
  33. src/scoring/functions/solubility.py +63 -0
  34. src/scoring/scoring_functions.py +75 -0
  35. src/scoring/tokenizer/my_tokenizers.py +424 -0
  36. src/scoring/tokenizer/new_splits.txt +159 -0
  37. src/scoring/tokenizer/new_vocab.txt +587 -0
  38. src/tokenizer/__init__.py +0 -0
  39. src/tokenizer/my_tokenizers.py +441 -0
  40. src/tokenizer/new_splits.txt +159 -0
  41. src/tokenizer/new_vocab.txt +587 -0
  42. src/train.py +133 -0
  43. src/train_peptune.py +226 -0
  44. src/utils/app.py +1255 -0
  45. src/utils/generate_utils.py +77 -0
  46. src/utils/utils.py +256 -0
.gitattributes ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ src/scoring/functions/classifiers/permeability-xgboost.json filter=lfs diff=lfs merge=lfs -text
2
+ assets/*.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ logs/
README.md ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # [PepTune: De Novo Generation of Therapeutic Peptides with Multi-Objective-Guided Discrete Diffusion](https://arxiv.org/abs/2412.17780) 🧬🔮 (ICML 2025)
2
+
3
+ [**Sophia Tang**](https://sophtang.github.io/)\*, [**Yinuo Zhang**](https://www.linkedin.com/in/yinuozhang98/)\* and [**Pranam Chatterjee**](https://www.chatterjeelab.com/)
4
+
5
+ ![PepTune](assets/poster.png)
6
+
7
+ This is the repository for **[PepTune: De Novo Generation of Therapeutic Peptides with Multi-Objective-Guided Discrete Diffusion](https://arxiv.org/abs/2412.17780)** 🧬🔮 published at **ICML 2025**. It is partially built on the **[MDLM repo](https://github.com/kuleshov-group/mdlm)** ([Sahoo et al. 2024](https://arxiv.org/abs/2406.07524)).
8
+
9
+ PepTune leverages **Monte-Carlo Tree Search (MCTS)** to guide a generative masked discrete diffusion model which iteratively refines a set of Pareto non-dominated sequences optimized across a set of therapeutic properties, including binding affinity, cell membrane permeability, solubility, non-fouling, and non-hemolysis.
10
+
11
+ ## Environment Installation
12
+
13
+ ```bash
14
+ conda env create -f src/environment.yml
15
+
16
+ conda activate peptune
17
+ ```
18
+
19
+ ## Model Pretrained Weights Download
20
+
21
+ Follow the steps below to download the model weights required for this experiment.
22
+
23
+ 1. Download the PepTune pre-trained MDLM checkpoint and place in `checkpoints/`: https://drive.google.com/file/d/1oXGDpKLNF0KX0ZdOcl1NZj5Czk2lSFUn/view?usp=sharing
24
+ 2. Download the pre-trained binding affinity Transformer model and place in `src/scoring/functions/classifiers/`: https://drive.google.com/file/d/128shlEP_-rYAxPgZRCk_n0HBWVbOYSva/view?usp=sharing
25
+
26
+ ## Training Data Download
27
+
28
+ Download the peptide training dataset from https://drive.google.com/file/d/1yCDr641WVjCtECg3nbG0nsMNu8j7d7gp/view?usp=drive_link and unzip it into the `data/` directory:
29
+
30
+ ```bash
31
+ # Download peptide_data.zip into the data/ directory
32
+ cd data/
33
+
34
+ # Unzip the training data
35
+ unzip peptide_data.zip
36
+
37
+ cd ..
38
+ ```
39
+
40
+ After unzipping, the data should be located at `data/peptide_data/`.
41
+
42
+ ## Repository Structure
43
+
44
+ ```
45
+ PepTune/
46
+ ├── src/
47
+ │ ├── train_peptune.py # Main training script
48
+ │ ├── generate_mcts.py # MCTS-guided peptide generation
49
+ │ ├── generate_unconditional.py # Unconditional generation
50
+ │ ├── diffusion.py # Core masked discrete diffusion model
51
+ │ ├── pareto_mcts.py # Pareto-front MCTS implementation
52
+ │ ├── roformer.py # RoFormer backbone
53
+ │ ├── noise_schedule.py # Noise scheduling (loglinear, logpoly)
54
+ │ ├── config.yaml # Hydra configuration
55
+ │ ├── config.py # Argparse configuration
56
+ │ ├── environment.yml # Conda environment
57
+ │ ├── scoring/ # Therapeutic property scoring
58
+ │ │ ├── scoring_functions.py # Unified scoring interface
59
+ │ │ └── functions/ # Individual property predictors
60
+ │ │ ├── binding.py
61
+ │ │ ├── hemolysis.py
62
+ │ │ ├── nonfouling.py
63
+ │ │ ├── permeability.py
64
+ │ │ ├── solubility.py
65
+ │ │ └── classifiers/ # Pre-trained scoring model weights
66
+ │ ├── tokenizer/ # SMILES SPE tokenizer
67
+ │ │ ├── my_tokenizers.py
68
+ │ │ ├── new_vocab.txt
69
+ │ │ └── new_splits.txt
70
+ │ └── utils/ # Utilities & PeptideAnalyzer
71
+ │ ├── app.py
72
+ │ ├── generate_utils.py
73
+ │ └── utils.py
74
+ ├── scripts/ # Shell scripts for running experiments
75
+ │ ├── train.sh # Pre-training
76
+ │ ├── generate_mcts.sh # MCTS-guided generation
77
+ │ └── generate_unconditional.sh # Unconditional generation
78
+ ├── data/ # Training data
79
+ │ ├── dataloading_for_dynamic_batching.py
80
+ │ └── dataset.py
81
+ ├── checkpoints/ # Model checkpoints
82
+ └── assets/ # Figures
83
+ ```
84
+
85
+ ## Pre-training
86
+
87
+ Before running, fill in `HOME_LOC` and `ENV_LOC` in `scripts/train.sh` and `base_path` in `src/config.yaml` to match your paths.
88
+
89
+ ```bash
90
+ chmod +x scripts/train.sh
91
+
92
+ nohup ./scripts/train.sh > train.log 2>&1 &
93
+ ```
94
+
95
+ Training uses Hydra configuration from `src/config.yaml`. Key settings:
96
+ - **Backbone:** RoFormer (768 hidden, 8 layers, 12 heads)
97
+ - **Optimizer:** AdamW (lr=3e-4, weight_decay=0.075)
98
+ - **Data:** 11M SMILES peptide dataset with dynamic batching by length
99
+ - **Precision:** fp64
100
+ - Checkpoints saved to `checkpoints/` (monitors `val/nll`, saves top 10)
101
+
102
+ ## MCTS-Guided Peptide Generation
103
+
104
+ Generate therapeutic peptides optimized across multiple objectives using Monte-Carlo Tree Search.
105
+
106
+ 1. Fill in `base_path` in `src/config.yaml` and `src/scoring/scoring_functions.py`.
107
+ 2. Fill in `HOME_LOC` in `scripts/generate_mcts.sh`.
108
+ 3. Create output directories: `mkdir -p results logs`
109
+
110
+ ```bash
111
+ chmod +x scripts/generate_mcts.sh
112
+
113
+ # Usage: ./scripts/generate_mcts.sh [PROT_NAME] [PROT_NAME2] [MODE] [MODEL] [LENGTH] [EPOCH]
114
+ # Example: Generate peptides targeting GFAP with length 100
115
+ nohup ./scripts/generate_mcts.sh gfap "" 2 mcts 100 7 > generate.log 2>&1 &
116
+ ```
117
+
118
+ ### Available Target Proteins
119
+
120
+ | Name | Target |
121
+ |------|--------|
122
+ | `amhr` | AMH Receptor |
123
+ | `tfr` | Transferrin Receptor |
124
+ | `gfap` | Glial Fibrillary Acidic Protein |
125
+ | `glp1` | GLP-1 Receptor |
126
+ | `glast` | Excitatory Amino Acid Transporter |
127
+ | `ncam` | Neural Cell Adhesion Molecule |
128
+ | `cereblon` | Cereblon (CRBN) |
129
+ | `ligase` | E3 Ubiquitin Ligase |
130
+ | `skp2` | S-Phase Kinase-Associated Protein 2 |
131
+ | `p53` | Tumor Suppressor p53 |
132
+ | `egfp` | Enhanced Green Fluorescent Protein |
133
+
134
+ To specify a custom target protein, override `+prot_seq=<amino acid sequence>` and `+prot_name=<name>` as Hydra arguments in the generation script.
135
+
136
+ ### Scoring Objectives
137
+
138
+ PepTune jointly optimizes across five therapeutic properties via the integrated scoring suite:
139
+
140
+ | Objective | Property | Model |
141
+ |-----------|----------|-------|
142
+ | `binding_affinity1` | Binding affinity to target protein | Cross-attention Transformer |
143
+ | `solubility` | Aqueous solubility | XGBoost on SMILES CNN embeddings |
144
+ | `hemolysis` | Non-hemolytic | SMILES binary classifier |
145
+ | `nonfouling` | Non-fouling | SMILES binary classifier |
146
+ | `permeability` | Cell membrane permeability | PAMPA CNN |
147
+
148
+ ### Default MCTS Hyperparameters
149
+
150
+ These can be overridden via Hydra config overrides:
151
+
152
+ | Parameter | Default | Description |
153
+ |-----------|---------|-------------|
154
+ | `mcts.num_children` | 50 | Branching factor per MCTS node |
155
+ | `mcts.num_iter` | 128 | Number of MCTS iterations |
156
+ | `mcts.num_objectives` | 5 | Number of optimization objectives |
157
+ | `sampling.steps` | 128 | Diffusion denoising steps |
158
+ | `sampling.seq_length` | 200 | Generated peptide length |
159
+
160
+ ## Unconditional Generation
161
+
162
+ Generate peptides without property guidance:
163
+
164
+ ```bash
165
+ chmod +x scripts/generate_unconditional.sh
166
+
167
+ nohup ./scripts/generate_unconditional.sh > generate_unconditional.log 2>&1 &
168
+ ```
169
+
170
+ ## Evaluation
171
+
172
+ To summarize metrics after generation, fill in `path` and `prot_name` in `src/metrics.py` and run:
173
+
174
+ ```bash
175
+ python src/metrics.py
176
+ ```
177
+
178
+ ## Citation
179
+
180
+ If you find this repository helpful for your publications, please consider citing our paper:
181
+
182
+ ```bibtex
183
+ @article{tang2025peptune,
184
+ title={Peptune: De novo generation of therapeutic peptides with multi-objective-guided discrete diffusion},
185
+ author={Tang, Sophia and Zhang, Yinuo and Chatterjee, Pranam},
186
+ journal={42nd International Conference on Machine Learning},
187
+ year={2025}
188
+ }
189
+ ```
190
+
191
+ ## License
192
+
193
+ To use this repository, you agree to abide by the MIT License.
assets/mcts.png ADDED

Git LFS Details

  • SHA256: e63bdc835269660e4b7bda69973bd60611b61045f25c5c07a9baa277e31d2acd
  • Pointer size: 132 Bytes
  • Size of remote file: 1.67 MB
assets/mdlm.png ADDED

Git LFS Details

  • SHA256: 2944b0a2fde891d883a765f29dd235c877cea5bf3c5117bd7423cab7f3102fa3
  • Pointer size: 131 Bytes
  • Size of remote file: 432 kB
assets/peptune.png ADDED

Git LFS Details

  • SHA256: f6e3bbdab7e5e9c435248796b9cf9d7eca6a41354d80556bb37cf5d01920830c
  • Pointer size: 131 Bytes
  • Size of remote file: 210 kB
assets/poster.png ADDED

Git LFS Details

  • SHA256: 6c35b4c6a3c7e55f5ac821ba36b1a78fadbeb9fb6927e324984031c31428acee
  • Pointer size: 132 Bytes
  • Size of remote file: 1.05 MB
data/dataloading_for_dynamic_batching.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env
2
+ import torch
3
+ from torch.utils.data import Dataset, DataLoader
4
+ from datasets import Dataset,load_from_disk
5
+ import sys
6
+ import lightning.pytorch as pl
7
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
8
+ from functools import partial
9
+ import re
10
+
11
+
12
class DynamicBatchingDataset(Dataset):
    """Dataset over rows that are already pre-batched upstream.

    Supports indexing by a single ``int`` (one pre-built batch) or by a
    ``list`` of ints (a gathered selection of pre-built batches).
    """

    def __init__(self, dataset_dict, tokenizer):
        print('Initializing dataset...')
        # Tensor-ify masks and token ids up front; labels stay as-is
        # (they hold the raw sequences, not numeric data).
        self.dataset_dict = {
            'input_ids': [torch.tensor(row) for row in dataset_dict['input_ids']],
            'attention_mask': [torch.tensor(row) for row in dataset_dict['attention_mask']],
            'labels': dataset_dict['labels'],
        }
        self.tokenizer = tokenizer

    def __len__(self):
        # One entry per pre-built batch.
        return len(self.dataset_dict['attention_mask'])

    def __getitem__(self, idx):
        keys = ('input_ids', 'attention_mask', 'labels')
        if isinstance(idx, int):
            return {key: self.dataset_dict[key][idx] for key in keys}
        if isinstance(idx, list):
            # Gather: one list per field, in the order requested.
            return {key: [self.dataset_dict[key][i] for i in idx] for key in keys}
        raise ValueError(f"Expected idx to be int or list, but got {type(idx)}")
40
+
41
class CustomDataModule(pl.LightningDataModule):
    """Lightning data module for pre-tokenized, pre-batched peptide SMILES data.

    Each row of the on-disk dataset is assumed to already be a full batch
    (hence ``batch_size=1`` in the dataloaders below); ``collate_fn`` unwraps
    that row and attaches a per-token peptide-bond mask used to weight losses.
    """

    def __init__(self, dataset_path, tokenizer):
        super().__init__()
        # Loads a dataset previously written with `save_to_disk`; expected to
        # expose 'train' and 'val' splits (see the dataloaders below).
        self.dataset = load_from_disk(dataset_path)
        self.tokenizer = tokenizer

    def peptide_bond_mask(self, smiles_list):
        """
        Returns a mask with shape (batch_size, seq_length) that has 1 at the locations
        of recognized bonds in the positions dictionary and 0 elsewhere.

        Args:
            smiles_list: List of peptide SMILES strings (batch of SMILES strings).

        Returns:
            torch.Tensor: A mask of shape (batch_size, seq_length) with 1s at bond positions.
        """
        # Initialize the batch mask
        batch_size = len(smiles_list)
        # NOTE(review): the width is hard-coded to 1035 (the tokenizer max_length
        # used elsewhere in this repo) rather than the longest SMILES; bond
        # characters past position 1035 are silently dropped — confirm intended.
        max_seq_length = 1035 #max(len(smiles) for smiles in smiles_list) # Find the longest SMILES
        mask = torch.zeros((batch_size, max_seq_length), dtype=torch.int) # Mask filled with zeros

        # Bond motifs ordered by specificity: earlier patterns claim their
        # character span first, and later overlapping matches are skipped.
        bond_patterns = [
            (r'OC\(=O\)', 'ester'),
            (r'N\(C\)C\(=O\)', 'n_methyl'),
            (r'N[12]C\(=O\)', 'peptide'), # Pro peptide bonds
            (r'NC\(=O\)', 'peptide'), # Regular peptide bonds
            (r'C\(=O\)N\(C\)', 'n_methyl'),
            (r'C\(=O\)N[12]?', 'peptide')
        ]

        for batch_idx, smiles in enumerate(smiles_list):
            positions = []
            used = set()  # character positions already claimed by a match

            # Identify bonds
            for pattern, bond_type in bond_patterns:
                for match in re.finditer(pattern, smiles):
                    # Skip any match overlapping a span claimed by an earlier
                    # (more specific) pattern.
                    if not any(p in range(match.start(), match.end()) for p in used):
                        positions.append({
                            'start': match.start(),
                            'end': match.end(),
                            'type': bond_type,
                            'pattern': match.group()
                        })
                        used.update(range(match.start(), match.end()))

            # Update the mask for the current SMILES
            for pos in positions:
                mask[batch_idx, pos['start']:pos['end']] = 1

        return mask

    def peptide_token_mask(self, smiles_list, token_lists):
        """
        Returns a mask with shape (batch_size, num_tokens) that has 1 for tokens
        where any part of the token overlaps with a peptide bond, and 0 elsewhere.

        Args:
            smiles_list: List of peptide SMILES strings (batch of SMILES strings).
            token_lists: List of tokenized SMILES strings (split into tokens).

        Returns:
            torch.Tensor: A mask of shape (batch_size, num_tokens) with 1s for peptide bond tokens.
        """
        # Initialize the batch mask
        batch_size = len(smiles_list)
        token_seq_length = max(len(tokens) for tokens in token_lists) # Find the longest tokenized sequence
        tokenized_masks = torch.zeros((batch_size, token_seq_length), dtype=torch.int) # Mask filled with zeros
        # Character-level bond mask; each token is checked against the span of
        # characters it covers in the raw SMILES string.
        atomwise_masks = self.peptide_bond_mask(smiles_list)


        for batch_idx, atomwise_mask in enumerate(atomwise_masks):
            token_seq = token_lists[batch_idx]
            atom_idx = 0  # character cursor into the atom-wise mask

            for token_idx, token in enumerate(token_seq):
                # NOTE(review): the first and last tokens are skipped and do not
                # advance the cursor — this assumes they are special tokens with
                # no characters in the SMILES string; confirm with the tokenizer.
                if token_idx != 0 and token_idx != len(token_seq) - 1:
                    # Flag the token if any character it spans lies on a bond motif.
                    if torch.sum(atomwise_mask[atom_idx:atom_idx+len(token)]) >= 1:
                        tokenized_masks[batch_idx][token_idx] = 1
                    atom_idx += len(token)

        return tokenized_masks

    def collate_fn(self, batch):
        # batch_size=1 and each dataset row is already a full batch, so unwrap.
        item = batch[0]

        # Split pre-tokenized ids back into token strings, then mark tokens that
        # overlap peptide-bond motifs; 'labels' holds the raw SMILES strings here.
        token_array = self.tokenizer.get_token_split(item['input_ids'])
        bond_mask = self.peptide_token_mask(item['labels'], token_array)

        return {
            'input_ids': item['input_ids'],
            'attention_mask': item['attention_mask'],
            'bond_mask': bond_mask
        }

    def train_dataloader(self):
        train_dataset = DynamicBatchingDataset(self.dataset['train'], tokenizer=self.tokenizer)
        return DataLoader(
            train_dataset,
            batch_size=1,  # each row is a pre-built batch
            collate_fn=self.collate_fn, # Use the instance method
            shuffle=True,
            num_workers=12,
            pin_memory=True
        )

    def val_dataloader(self):
        val_dataset = DynamicBatchingDataset(self.dataset['val'], tokenizer=self.tokenizer)
        return DataLoader(
            val_dataset,
            batch_size=1,  # each row is a pre-built batch
            collate_fn=self.collate_fn, # Use the instance method
            num_workers=8,
            pin_memory=True
        )
data/dataset.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import re
3
+ import torch
4
+
5
+ import utils
6
+
7
+ from torch.utils.data import Dataset, DataLoader
8
+ import lightning.pytorch as pl
9
+ from functools import partial
10
+ import sys
11
+
12
class CustomDataset(Dataset):
    """A view over ``dataset`` restricted to (and reordered by) ``indices``."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __len__(self):
        # The view is as long as the index list, not the underlying dataset.
        return len(self.indices)

    def __getitem__(self, idx):
        # Translate the view position into the underlying dataset's index;
        # int() accepts numpy/tensor scalars as well as plain ints.
        return self.dataset[int(self.indices[idx])]
24
+
25
+
26
# for weighting losses of peptide bonds
def peptide_bond_mask(smiles_list):
    """
    Return a character-level mask with shape (batch_size, max_seq_length) that
    has 1 at characters belonging to a recognized backbone bond motif
    (peptide / ester / N-methyl) and 0 elsewhere.

    Args:
        smiles_list: List of peptide SMILES strings (batch of SMILES strings).

    Returns:
        torch.Tensor: int mask of shape (batch_size, max_seq_length), where
        max_seq_length is the length of the longest SMILES in the batch;
        (0, 0) for an empty batch.
    """
    batch_size = len(smiles_list)
    # Guard: an empty batch would make max() below raise ValueError.
    if batch_size == 0:
        return torch.zeros((0, 0), dtype=torch.int)
    max_seq_length = max(len(smiles) for smiles in smiles_list)  # longest SMILES sets the width
    mask = torch.zeros((batch_size, max_seq_length), dtype=torch.int)

    # Bond motifs ordered by specificity: earlier patterns claim their character
    # span first, and later matches overlapping a claimed span are skipped.
    bond_patterns = [
        (r'OC\(=O\)', 'ester'),
        (r'N\(C\)C\(=O\)', 'n_methyl'),
        (r'N[12]C\(=O\)', 'peptide'),   # Pro peptide bonds
        (r'NC\(=O\)', 'peptide'),       # Regular peptide bonds
        (r'C\(=O\)N\(C\)', 'n_methyl'),
        (r'C\(=O\)N[12]?', 'peptide')
    ]

    for batch_idx, smiles in enumerate(smiles_list):
        used = set()  # character positions already claimed by a match
        for pattern, bond_type in bond_patterns:
            for match in re.finditer(pattern, smiles):
                span = range(match.start(), match.end())
                # Single C-level disjointness test instead of scanning `used`
                # once per candidate position.
                if used.isdisjoint(span):
                    mask[batch_idx, match.start():match.end()] = 1
                    used.update(span)

    return mask
73
+
74
def peptide_token_mask(smiles_list, token_lists):
    """
    Returns a mask with shape (batch_size, num_tokens) that has 1 for tokens
    where any part of the token overlaps with a peptide bond, and 0 elsewhere.

    Args:
        smiles_list: List of peptide SMILES strings (batch of SMILES strings).
        token_lists: List of tokenized SMILES strings (split into tokens).

    Returns:
        torch.Tensor: A mask of shape (batch_size, num_tokens) with 1s for peptide bond tokens.
    """
    # Initialize the batch mask
    batch_size = len(smiles_list)
    token_seq_length = max(len(tokens) for tokens in token_lists) # Find the longest tokenized sequence
    tokenized_masks = torch.zeros((batch_size, token_seq_length), dtype=torch.int) # Mask filled with zeros
    # Character-level bond mask; each token is checked against the span of
    # characters it covers in the raw SMILES string.
    atomwise_masks = peptide_bond_mask(smiles_list)


    for batch_idx, atomwise_mask in enumerate(atomwise_masks):
        token_seq = token_lists[batch_idx]
        atom_idx = 0  # character cursor into the atom-wise mask

        for token_idx, token in enumerate(token_seq):
            # NOTE(review): the first and last tokens are skipped and do not
            # advance the cursor — this assumes they are special tokens (e.g.
            # BOS/EOS) with no characters in the SMILES string; confirm with
            # the tokenizer's get_token_split output.
            if token_idx != 0 and token_idx != len(token_seq) - 1:
                # Flag the token if any character it spans lies on a bond motif.
                if torch.sum(atomwise_mask[atom_idx:atom_idx+len(token)]) >= 1:
                    tokenized_masks[batch_idx][token_idx] = 1
                atom_idx += len(token)

    return tokenized_masks
104
+
105
def extract_amino_acid_sequence(helm_string):
    """
    Extract the amino acid sequence from a HELM peptide notation as a flat
    list, removing any brackets around each amino acid.

    Args:
        helm_string (str): The HELM notation string for a peptide.

    Returns:
        list: The amino acids in sequence without brackets, or an error string
        when no PEPTIDE{...} segment is found (kept for caller compatibility).
    """
    # Capture the body of every PEPTIDE<n>{...} segment.
    segments = re.findall(r'PEPTIDE\d+\{([^}]+)\}', helm_string)
    if not segments:
        return "Invalid HELM notation or no peptide sequence found."
    # Monomers are dot-separated; brackets around non-canonical residues
    # (e.g. [Aib]) are stripped before splitting.
    return [
        residue
        for segment in segments
        for residue in segment.replace('[', '').replace(']', '').split('.')
    ]
128
+
129
def helm_collate_fn(batch, tokenizer):
    """
    Collate a batch of HELM records into padded token tensors.

    Args:
        batch: List of dataset items, each a mapping with a 'HELM' string.
        tokenizer: Callable tokenizer supporting the HuggingFace ``__call__``
            API (``return_tensors``, ``padding``, ``truncation``, ``max_length``).

    Returns:
        dict: 'input_ids' and 'attention_mask' padded to the longest sequence
        in the batch, truncated at 1024 tokens.
    """
    sequences = [item['HELM'] for item in batch]
    # NOTE: an earlier version also computed the maximum amino-acid length of
    # the batch here, but the value was never used — the tokenizer's own
    # padding already handles batch length, so that dead code was removed.
    tokens = tokenizer(sequences, return_tensors='pt', padding=True, truncation=True, max_length=1024)
    return {
        'input_ids': tokens['input_ids'],
        'attention_mask': tokens['attention_mask']
    }
144
+
145
+
146
def collate_fn(batch, tokenizer):
    """Standard data collator: tokenizes SMILES with padding/truncation at
    max_length 1035 and attaches a per-token peptide-bond mask.

    Items whose SMILES cannot be tokenized are skipped (best-effort), with a
    message printed for each skipped item.
    """
    valid_sequences = []
    valid_items = []
    for entry in batch:
        try:
            # Probe-tokenize one sequence; any failure drops this item.
            smiles = entry['SMILES']
            tokenizer([smiles], return_tensors='pt', padding=False, truncation=True, max_length=1035)
        except Exception as err:
            print(f"Skipping sequence due to: {str(err)}")
            continue
        valid_sequences.append(smiles)
        valid_items.append(entry)

    # Batch-tokenize the survivors, padded to the longest sequence.
    tokens = tokenizer(valid_sequences, return_tensors='pt', padding=True, truncation=True, max_length=1035)

    # Split the ids back into token strings, then mark tokens that overlap a
    # peptide-bond motif in the raw SMILES.
    token_array = tokenizer.get_token_split(tokens['input_ids'])
    bond_mask = peptide_token_mask(valid_sequences, token_array)

    return {
        'input_ids': tokens['input_ids'],
        'attention_mask': tokens['attention_mask'],
        'bond_mask': bond_mask
    }
175
+
176
+
177
class CustomDataModule(pl.LightningDataModule):
    """Lightning data module over pre-built train/val datasets.

    Batching is delegated to ``collate_fn`` (default: this module's
    module-level ``collate_fn``), which receives the tokenizer via
    ``functools.partial``.
    """

    def __init__(self, train_dataset, val_dataset, test_dataset, tokenizer, batch_size, collate_fn=collate_fn):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        #self.test_dataset = test_dataset
        # NOTE(review): `test_dataset` is accepted but never stored (the line
        # above is commented out), and test_dataloader below is disabled too.
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.collate_fn = collate_fn

    def train_dataloader(self):
        # NOTE(review): no shuffle=True here — confirm whether the training
        # data is shuffled upstream before relying on this loader.
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
                          num_workers=8,
                          pin_memory=True
                          )


    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
                          num_workers=8,
                          pin_memory=True
                          )

    # Test-split dataloader intentionally disabled; kept for reference.
    """def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                          collate_fn=partial(self.collate_fn, tokenizer=self.tokenizer),
                          num_workers=8, pin_memory=True)"""
scripts/generate_mcts.sh ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash

# MCTS-guided peptide generation launcher.
# Usage: ./scripts/generate_mcts.sh [PROT_NAME1] [PROT_NAME2] [MODE] [MODEL] [LENGTH] [EPOCH]

HOME_LOC=/path/to/your/home/PepTune
SCRIPT_LOC=$HOME_LOC/src
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='mcts'
PYTHON_EXECUTABLE=python

# ===================================================================
# Default parameters (can be overridden by command line arguments)
# Available proteins: amhr, tfr, gfap, glp1, glast, ncam, cereblon, ligase, skp2, p53, egfp
PROT_NAME1=${1:-"gfap"}
PROT_NAME2=${2:-""}
MODE=${3:-"2"}
MODEL=${4:-"mcts"}
LENGTH=${5:-"100"}
EPOCH=${6:-"7"}
# NOTE(review): the checkpoint path is fixed to epoch13 and does not use the
# $EPOCH argument — confirm whether it should be derived from EPOCH instead.
CKPT=$HOME_LOC/checkpoints/epoch13-new-tokenizer.ckpt

# ===================================================================
echo "Activating conda environment..."
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate peptune

# Output directories: per-protein results and shared logs.
mkdir -p "${HOME_LOC}/${PROT_NAME1}"
mkdir -p "${LOG_LOC}"

echo "Running MCTS generation with parameters:"
echo "  Protein Name 1: $PROT_NAME1"
echo "  Protein Name 2: $PROT_NAME2"
echo "  Mode: $MODE"
echo "  Model: $MODEL"
echo "  Length: $LENGTH"
echo "  Epoch: $EPOCH"

# Build Hydra override arguments; the second protein is optional.
HYDRA_ARGS="+prot_name1=$PROT_NAME1 ++mode=$MODE +model_type=$MODEL +length=$LENGTH +epoch=$EPOCH"
if [ -n "$PROT_NAME2" ]; then
    HYDRA_ARGS="$HYDRA_ARGS +prot_name2=$PROT_NAME2"
fi

cd "$SCRIPT_LOC"

# Run the MCTS generation script with Hydra overrides.
# $HYDRA_ARGS is intentionally left unquoted so it word-splits into
# separate override tokens.
$PYTHON_EXECUTABLE "$SCRIPT_LOC/generate_mcts.py" \
    --config-path "$SCRIPT_LOC" \
    --config-name config \
    base_path="$HOME_LOC" \
    eval.checkpoint_path="$CKPT" \
    $HYDRA_ARGS >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_generate.log" 2>&1

echo "Generation complete. Check logs at: ${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_generate.log"

conda deactivate
scripts/generate_unconditional.sh ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash

# Unconditional peptide generation launcher (no property guidance).
# Fill in HOME_LOC before running.

HOME_LOC=/path/to/your/home/PepTune
SCRIPT_LOC=$HOME_LOC/src
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='unconditional'
PYTHON_EXECUTABLE=python

# ===================================================================
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate peptune

# Write the log under $LOG_LOC (previously it landed in the current working
# directory even though LOG_LOC was defined), matching generate_mcts.sh.
mkdir -p "${LOG_LOC}"
$PYTHON_EXECUTABLE "$SCRIPT_LOC/generate_unconditional.py" >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_generate.log" 2>&1

conda deactivate
scripts/train.sh ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash

# PepTune pre-training launcher.
# Fill in HOME_LOC and ENV_LOC before running.

HOME_LOC=/path/to/your/home/PepTune
ENV_LOC=/path/to/your/envs/peptune
SCRIPT_LOC=$HOME_LOC/src
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='11M-ablation-all-losses'
# set 3 have skip connection
PYTHON_EXECUTABLE=$ENV_LOC/bin/python

# ===================================================================
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate "$ENV_LOC"

# Write the training log under $LOG_LOC (previously it landed in the current
# working directory even though LOG_LOC was defined), matching generate_mcts.sh.
mkdir -p "${LOG_LOC}"
$PYTHON_EXECUTABLE "$SCRIPT_LOC/train_peptune.py" >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log" 2>&1

conda deactivate
src/config.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+
5
def get_parser():
    """Build the argument parser for PepTune training and evaluation.

    Fix: flags that default to ``True`` previously used
    ``action='store_true', default=True`` — they could never be turned off
    from the command line. They now use ``argparse.BooleanOptionalAction``
    (Python 3.9+), which keeps the positive flag working and adds a
    ``--no-<flag>`` form to disable it. All names, defaults, and dests are
    otherwise unchanged.

    Returns:
        argparse.ArgumentParser: fully configured parser.
    """
    parser = argparse.ArgumentParser(description='PepTune Training and Evaluation')

    # Noise parameters
    noise_group = parser.add_argument_group('noise')
    noise_group.add_argument('--noise-type', type=str, default='loglinear',
                             help='Type of noise schedule')
    noise_group.add_argument('--sigma-min', type=float, default=1e-4,
                             help='Minimum sigma value')
    noise_group.add_argument('--sigma-max', type=float, default=20,
                             help='Maximum sigma value')
    noise_group.add_argument('--state-dependent', action=argparse.BooleanOptionalAction,
                             default=True,
                             help='Use state-dependent noise')

    # Base parameters
    parser.add_argument('--base-path', type=str, default='/path/to/PepTune',
                        help='Base path to PepTune')
    parser.add_argument('--mode', type=str, default='ppl_eval',
                        choices=['train', 'ppl_eval', 'sample_eval'],
                        help='Running mode')
    parser.add_argument('--diffusion', type=str, default='absorbing_state',
                        help='Diffusion type')
    parser.add_argument('--vocab', type=str, default='old_smiles',
                        choices=['old_smiles', 'new_smiles', 'selfies', 'helm'],
                        help='Vocabulary type')
    parser.add_argument('--backbone', type=str, default='roformer',
                        choices=['peptideclm', 'helmgpt', 'dit', 'roformer', 'finetune_roformer'],
                        help='Model backbone')
    parser.add_argument('--parameterization', type=str, default='subs',
                        help='Parameterization type')
    parser.add_argument('--time-conditioning', action='store_true', default=False,
                        help='Use time conditioning')
    parser.add_argument('--T', type=int, default=0,
                        help='Number of diffusion steps (0 for continuous time, 1000 for discrete)')
    parser.add_argument('--subs-masking', action='store_true', default=False,
                        help='Use substitution masking')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')

    # MCTS parameters
    mcts_group = parser.add_argument_group('mcts')
    mcts_group.add_argument('--mcts-num-children', type=int, default=50,
                            help='Number of children in MCTS')
    mcts_group.add_argument('--mcts-num-objectives', type=int, default=5,
                            help='Number of objectives in MCTS')
    mcts_group.add_argument('--mcts-topk', type=int, default=100,
                            help='Top-k for MCTS')
    mcts_group.add_argument('--mcts-mask-token', type=int, default=4,
                            help='Mask token ID')
    mcts_group.add_argument('--mcts-num-iter', type=int, default=128,
                            help='Number of MCTS iterations')
    mcts_group.add_argument('--mcts-sampling', type=int, default=0,
                            help='Sampling strategy (0 for gumbel, >0 for top-k)')
    mcts_group.add_argument('--mcts-invalid-penalty', type=float, default=0.5,
                            help='Penalty for invalid sequences')
    mcts_group.add_argument('--mcts-sample-prob', type=float, default=1.0,
                            help='Sampling probability')
    mcts_group.add_argument('--mcts-perm', action=argparse.BooleanOptionalAction,
                            default=True,
                            help='Use permutation in MCTS')
    mcts_group.add_argument('--mcts-dual', action='store_true', default=False,
                            help='Use dual mode')
    mcts_group.add_argument('--mcts-single', action='store_true', default=False,
                            help='Use single mode')
    mcts_group.add_argument('--mcts-time-dependent', action=argparse.BooleanOptionalAction,
                            default=True,
                            help='Use time-dependent MCTS')

    # Data parameters
    data_group = parser.add_argument_group('data')
    data_group.add_argument('--train-data', type=str,
                            default='/path/to/your/home/PepTune/data/peptide_data',
                            help='Path to training data')
    data_group.add_argument('--valid-data', type=str,
                            default='/path/to/your/home/PepTune/data/peptide_data',
                            help='Path to validation data')
    data_group.add_argument('--data-batching', type=str, default='wrapping',
                            choices=['padding', 'wrapping'],
                            help='Batching strategy')

    # Loader parameters
    loader_group = parser.add_argument_group('loader')
    loader_group.add_argument('--global-batch-size', type=int, default=64,
                              help='Global batch size')
    loader_group.add_argument('--eval-global-batch-size', type=int, default=None,
                              help='Evaluation global batch size (defaults to global-batch-size)')
    loader_group.add_argument('--num-workers', type=int, default=None,
                              help='Number of dataloader workers (defaults to available CPUs)')
    loader_group.add_argument('--pin-memory', action=argparse.BooleanOptionalAction,
                              default=True,
                              help='Pin memory for dataloaders')

    # Sampling parameters
    sampling_group = parser.add_argument_group('sampling')
    sampling_group.add_argument('--predictor', type=str, default='ddpm_cache',
                                choices=['analytic', 'ddpm', 'ddpm_cache'],
                                help='Predictor type for sampling')
    sampling_group.add_argument('--num-sequences', type=int, default=100,
                                help='Number of sequences to generate')
    sampling_group.add_argument('--sampling-eps', type=float, default=1e-3,
                                help='Sampling epsilon')
    sampling_group.add_argument('--steps', type=int, default=128,
                                help='Number of sampling steps')
    sampling_group.add_argument('--seq-length', type=int, default=100,
                                help='Sequence length')
    sampling_group.add_argument('--noise-removal', action=argparse.BooleanOptionalAction,
                                default=True,
                                help='Use noise removal')
    sampling_group.add_argument('--num-sample-batches', type=int, default=2,
                                help='Number of sample batches')
    sampling_group.add_argument('--num-sample-log', type=int, default=2,
                                help='Number of samples to log')
    sampling_group.add_argument('--stride-length', type=int, default=1,
                                help='Stride length for sampling')
    sampling_group.add_argument('--num-strides', type=int, default=1,
                                help='Number of strides')

    # Training parameters
    training_group = parser.add_argument_group('training')
    training_group.add_argument('--antithetic-sampling', action=argparse.BooleanOptionalAction,
                                default=True,
                                help='Use antithetic sampling')
    training_group.add_argument('--training-sampling-eps', type=float, default=1e-3,
                                help='Training sampling epsilon')
    training_group.add_argument('--focus-mask', action='store_true', default=False,
                                help='Use focus mask')
    training_group.add_argument('--accumulator', action='store_true', default=False,
                                help='Use accumulator')

    # Evaluation parameters
    eval_group = parser.add_argument_group('eval')
    eval_group.add_argument('--checkpoint-path', type=str, default=None,
                            help='Path to checkpoint for evaluation')
    eval_group.add_argument('--disable-ema', action='store_true', default=False,
                            help='Disable EMA')
    eval_group.add_argument('--compute-generative-perplexity', action='store_true', default=False,
                            help='Compute generative perplexity')
    eval_group.add_argument('--perplexity-batch-size', type=int, default=8,
                            help='Batch size for perplexity computation')
    eval_group.add_argument('--compute-perplexity-on-sanity', action='store_true', default=False,
                            help='Compute perplexity on sanity check')
    eval_group.add_argument('--gen-ppl-eval-model', type=str, default='gpt2-large',
                            help='Model for generative perplexity evaluation')
    eval_group.add_argument('--generate-samples', action=argparse.BooleanOptionalAction,
                            default=True,
                            help='Generate samples during evaluation')
    eval_group.add_argument('--generation-model', type=str, default=None,
                            help='Model for generation')

    # Optimizer parameters
    optim_group = parser.add_argument_group('optim')
    optim_group.add_argument('--weight-decay', type=float, default=0.075,
                             help='Weight decay')
    optim_group.add_argument('--lr', type=float, default=3e-4,
                             help='Learning rate')
    optim_group.add_argument('--beta1', type=float, default=0.9,
                             help='Adam beta1')
    optim_group.add_argument('--beta2', type=float, default=0.999,
                             help='Adam beta2')
    optim_group.add_argument('--eps', type=float, default=1e-8,
                             help='Adam epsilon')

    # PepCLM model parameters
    pepclm_group = parser.add_argument_group('pepclm')
    pepclm_group.add_argument('--pepclm-hidden-size', type=int, default=768,
                              help='PepCLM hidden size')
    pepclm_group.add_argument('--pepclm-cond-dim', type=int, default=256,
                              help='PepCLM conditioning dimension')
    pepclm_group.add_argument('--pepclm-n-heads', type=int, default=20,
                              help='PepCLM number of attention heads')
    pepclm_group.add_argument('--pepclm-n-blocks', type=int, default=4,
                              help='PepCLM number of blocks')
    pepclm_group.add_argument('--pepclm-dropout', type=float, default=0.5,
                              help='PepCLM dropout rate')
    pepclm_group.add_argument('--pepclm-length', type=int, default=512,
                              help='PepCLM sequence length')

    # General model parameters
    model_group = parser.add_argument_group('model')
    model_group.add_argument('--model-type', type=str, default='ddit',
                             help='Model type')
    model_group.add_argument('--hidden-size', type=int, default=768,
                             help='Model hidden size')
    model_group.add_argument('--cond-dim', type=int, default=128,
                             help='Conditioning dimension')
    model_group.add_argument('--length', type=int, default=512,
                             help='Sequence length')
    model_group.add_argument('--n-blocks', type=int, default=12,
                             help='Number of blocks')
    model_group.add_argument('--n-heads', type=int, default=12,
                             help='Number of attention heads')
    model_group.add_argument('--scale-by-sigma', action=argparse.BooleanOptionalAction,
                             default=True,
                             help='Scale by sigma')
    model_group.add_argument('--dropout', type=float, default=0.1,
                             help='Dropout rate')

    # RoFormer parameters
    roformer_group = parser.add_argument_group('roformer')
    roformer_group.add_argument('--roformer-hidden-size', type=int, default=768,
                                help='RoFormer hidden size')
    roformer_group.add_argument('--roformer-n-layers', type=int, default=8,
                                help='RoFormer number of layers')
    roformer_group.add_argument('--roformer-n-heads', type=int, default=8,
                                help='RoFormer number of attention heads')
    roformer_group.add_argument('--roformer-max-position-embeddings', type=int, default=1035,
                                help='RoFormer max position embeddings')

    # HelmGPT parameters
    helmgpt_group = parser.add_argument_group('helmgpt')
    helmgpt_group.add_argument('--helmgpt-hidden-size', type=int, default=256,
                               help='HelmGPT hidden size')
    helmgpt_group.add_argument('--helmgpt-embd-pdrop', type=float, default=0.1,
                               help='HelmGPT embedding dropout')
    helmgpt_group.add_argument('--helmgpt-resid-pdrop', type=float, default=0.1,
                               help='HelmGPT residual dropout')
    helmgpt_group.add_argument('--helmgpt-attn-pdrop', type=float, default=0.1,
                               help='HelmGPT attention dropout')
    helmgpt_group.add_argument('--helmgpt-ff-dropout', type=float, default=0.0,
                               help='HelmGPT feedforward dropout')
    helmgpt_group.add_argument('--helmgpt-block-size', type=int, default=140,
                               help='HelmGPT block size')
    helmgpt_group.add_argument('--helmgpt-n-layer', type=int, default=8,
                               help='HelmGPT number of layers')
    helmgpt_group.add_argument('--helmgpt-n-heads', type=int, default=8,
                               help='HelmGPT number of attention heads')

    # Trainer parameters
    trainer_group = parser.add_argument_group('trainer')
    trainer_group.add_argument('--accelerator', type=str, default='cuda',
                               help='Accelerator type')
    trainer_group.add_argument('--num-nodes', type=int, default=1,
                               help='Number of nodes')
    trainer_group.add_argument('--devices', type=int, default=1,
                               help='Number of devices')
    trainer_group.add_argument('--gradient-clip-val', type=float, default=1.0,
                               help='Gradient clipping value')
    trainer_group.add_argument('--precision', type=str, default='64-true',
                               help='Training precision')
    trainer_group.add_argument('--num-sanity-val-steps', type=int, default=2,
                               help='Number of sanity validation steps')
    trainer_group.add_argument('--max-epochs', type=int, default=100,
                               help='Maximum number of epochs')
    trainer_group.add_argument('--max-steps', type=int, default=1_000_000,
                               help='Maximum number of steps')
    trainer_group.add_argument('--log-every-n-steps', type=int, default=10,
                               help='Log every n steps')
    trainer_group.add_argument('--limit-train-batches', type=float, default=1.0,
                               help='Limit training batches')
    trainer_group.add_argument('--limit-val-batches', type=float, default=1.0,
                               help='Limit validation batches')
    trainer_group.add_argument('--check-val-every-n-epoch', type=int, default=1,
                               help='Check validation every n epochs')

    # WandB parameters
    wandb_group = parser.add_argument_group('wandb')
    wandb_group.add_argument('--wandb-project', type=str, default='peptune',
                             help='WandB project name')
    wandb_group.add_argument('--wandb-notes', type=str, default=None,
                             help='WandB notes')
    wandb_group.add_argument('--wandb-group', type=str, default=None,
                             help='WandB group')
    wandb_group.add_argument('--wandb-job-type', type=str, default=None,
                             help='WandB job type')
    wandb_group.add_argument('--wandb-name', type=str, default='sophia-tang',
                             help='WandB run name')
    wandb_group.add_argument('--wandb-id', type=str, default=None,
                             help='WandB run ID')

    # Checkpointing parameters
    checkpoint_group = parser.add_argument_group('checkpointing')
    checkpoint_group.add_argument('--save-dir', type=str, default=None,
                                  help='Directory to save checkpoints')
    checkpoint_group.add_argument('--resume-from-ckpt', action=argparse.BooleanOptionalAction,
                                  default=True,
                                  help='Resume from checkpoint')
    checkpoint_group.add_argument('--resume-ckpt-path', type=str, default=None,
                                  help='Path to checkpoint to resume from')
    checkpoint_group.add_argument('--checkpoint-every-n-epochs', type=int, default=1,
                                  help='Save checkpoint every n epochs')
    checkpoint_group.add_argument('--checkpoint-monitor', type=str, default='val/nll',
                                  help='Metric to monitor for checkpointing')
    checkpoint_group.add_argument('--checkpoint-save-top-k', type=int, default=10,
                                  help='Save top k checkpoints')
    checkpoint_group.add_argument('--checkpoint-mode', type=str, default='min',
                                  choices=['min', 'max'],
                                  help='Mode for checkpoint monitoring')
    checkpoint_group.add_argument('--checkpoint-dirpath', type=str,
                                  default='./checkpoints/11M-old-tokenizer',
                                  help='Directory path for checkpoints')

    # LR Scheduler parameters
    scheduler_group = parser.add_argument_group('lr_scheduler')
    scheduler_group.add_argument('--lr-warmup-steps', type=int, default=2500,
                                 help='Number of warmup steps for learning rate')

    return parser
294
+
295
+
296
def get_args():
    """Parse command-line arguments and fill in derived defaults.

    Post-processing:
      * eval batch size falls back to the training global batch size,
      * worker count falls back to the number of usable CPUs,
      * wandb run id is derived from the run name when not given,
      * save dir defaults to the current working directory.

    Returns:
        argparse.Namespace: parsed arguments with defaults resolved.
    """
    parser = get_parser()
    args = parser.parse_args()

    # Post-process arguments
    if args.eval_global_batch_size is None:
        args.eval_global_batch_size = args.global_batch_size

    if args.num_workers is None:
        # BUG FIX: os.sched_getaffinity only exists on Linux; fall back to
        # os.cpu_count() on macOS/Windows instead of crashing.
        try:
            args.num_workers = len(os.sched_getaffinity(0))
        except AttributeError:
            args.num_workers = os.cpu_count() or 1

    if args.wandb_id is None:
        args.wandb_id = f"{args.wandb_name}_nov12_set2"

    if args.save_dir is None:
        args.save_dir = os.getcwd()

    return args


if __name__ == '__main__':
    args = get_args()
    print(args)
src/config.yaml ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ noise:
2
+ type: loglinear
3
+ sigma_min: 1e-4
4
+ sigma_max: 20
5
+ state_dependent: True
6
+
7
+ base_path: /path/to/your/home/PepTune
8
+ mode: train # train / ppl_eval / sample_eval
9
+ diffusion: absorbing_state
10
+ vocab: old_smiles # old_smiles / new_smiles / selfies / helm
11
+ backbone: roformer # peptideclm / helmgpt / dit / roformer / finetune_roformer
12
+ parameterization: subs # subs
13
+ time_conditioning: False
14
+ T: 0 # 0 (continuous time) / 1000
15
+ subs_masking: False
16
+
17
+ seed: 42
18
+
19
+ mcts:
20
+ num_children: 50
21
+ num_objectives: 5
22
+ topk: 100
23
+ mask_token: 4
24
+ num_iter: 128
25
+ sampling: 0 # 0 is gumbel sampling / > 0 samples children from top k probs
26
+ invalid_penalty: 0.5
27
+ sample_prob: 1.0
28
+ perm: True
29
+ dual: False
30
+ single: False
31
+ time_dependent: True
32
+
33
+ lr_scheduler:
34
+ _target_: transformers.get_constant_schedule_with_warmup
35
+ num_warmup_steps: 2500
36
+
37
+ loader:
38
+ global_batch_size: 64
39
+ eval_global_batch_size: ${.global_batch_size}
40
+ # Note: batch_size and eval_batch_size are **per machine**
41
+ batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
42
+ eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
43
+ num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
44
+ pin_memory: True
45
+
46
+ sampling:
47
+ predictor: ddpm_cache # analytic, ddpm, ddpm_cache
48
+ num_sequences: 100
49
+ sampling_eps: 1e-3
50
+ steps: 128
51
+ seq_length: 200
52
+ noise_removal: True
53
+ num_sample_batches: 2 # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
54
+ num_sample_log: 2
55
+ stride_length: 1
56
+ num_strides: 1
57
+
58
+ training:
59
+ antithetic_sampling: True
60
+ sampling_eps: 1e-3
61
+ focus_mask: False
62
+ #dynamic_batching: True
63
+ accumulator: False
64
+
65
+ eval:
66
+ checkpoint_path: null  # YAML null, not the string "None"
67
+ disable_ema: False
68
+ compute_generative_perplexity: False
69
+ perplexity_batch_size: 8
70
+ compute_perplexity_on_sanity: False
71
+ gen_ppl_eval_model_name_or_path: gpt2-large # gpt2-large, meta-llama/Llama-2-7b-hf
72
+ generate_samples: True
73
+ generation_model: null  # YAML null, not the string "None"
74
+
75
+ optim:
76
+ weight_decay: 0.075
77
+ lr: 3e-4
78
+ beta1: 0.9
79
+ beta2: 0.999
80
+ eps: 1e-8
81
+
82
+ pepclm:
83
+ hidden_size: 768
84
+ cond_dim: 256
85
+ n_heads: 20
86
+ n_blocks: 4
87
+ dropout: 0.5
88
+ length: 512
89
+ #scale_by_sigma: True
90
+
91
+ model:
92
+ type: ddit
93
+ hidden_size: 768
94
+ cond_dim: 128
95
+ length: 512
96
+ n_blocks: 12
97
+ n_heads: 12
98
+ scale_by_sigma: True
99
+ dropout: 0.1
100
+
101
+ roformer:
102
+ hidden_size: 768
103
+ n_layers: 8
104
+ n_heads: 8
105
+ max_position_embeddings: 1035
106
+
107
+ helmgpt:
108
+ hidden_size: 256
109
+ embd_pdrop: 0.1
110
+ resid_pdrop: 0.1
111
+ attn_pdrop: 0.1
112
+ ff_dropout: 0.
113
+ block_size: 140
114
+ n_layer: 8
115
+ n_heads: 8
116
+
117
+
118
+ trainer:
119
+ _target_: lightning.Trainer
120
+ accelerator: cuda
121
+ num_nodes: 1
122
+ devices: ${device_count:}
123
+ accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
124
+ gradient_clip_val: 1.0
125
+ precision: 64-true
126
+ num_sanity_val_steps: 2
127
+ max_epochs: 100
128
+ max_steps: 1_000_000
129
+ log_every_n_steps: 10
130
+ limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run
131
+ limit_val_batches: 1.0 # validate on full dataset, can be used to toggle quick run
132
+ #val_check_interval: 40 #954
133
+ check_val_every_n_epoch: 1
134
+
135
+
136
+ wandb:
137
+ project: peptune
138
+ notes: null
139
+ group: null
140
+ job_type: null
141
+ name: sophia-tang
142
+ id: ${.name}_nov12_set2
143
+
144
+ hydra:
145
+ run:
146
+ dir: ./${now:%Y.%m.%d}/
147
+ job:
148
+ chdir: True
149
+
150
+ checkpointing:
151
+ # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
152
+ save_dir: ${cwd:}
153
+ # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
154
+ resume_from_ckpt: True
155
+ resume_ckpt_path: null  # YAML null, not the string "None"
156
+
157
+ callbacks:
158
+ model_checkpoint:
159
+ _target_: pytorch_lightning.callbacks.ModelCheckpoint
160
+ every_n_epochs: 1
161
+ monitor: "val/nll"
162
+ save_top_k: 10
163
+ mode: "min"
164
+ dirpath: './checkpoints/'
src/diffusion.py ADDED
@@ -0,0 +1,1015 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Adapted from MDLM: https://github.com/kuleshov-group/mdlm
2
+
3
+ import numpy as np
4
+ import sys
5
+ import itertools
6
+ import time
7
+ import torch
8
+ from torch import Tensor
9
+ import math
10
+ import torch.nn.functional as F
11
+ import numpy as np
12
+ import random as rd
13
+ import lightning as L
14
+ import torchmetrics
15
+ from dataclasses import dataclass
16
+ import gc
17
+ import pickle
18
+ import utils.utils as utils
19
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
20
+ import noise_schedule
21
+ from torch.optim.lr_scheduler import _LRScheduler
22
+ import roformer as roformer
23
+ from utils.app import PeptideAnalyzer
24
+
25
@dataclass
class Loss:
    """Container bundling the diffusion training loss with its per-token parts."""
    loss: torch.FloatTensor       # aggregate loss used for optimization
    nlls: torch.FloatTensor       # per-token negative log-likelihoods
    attn_mask: torch.FloatTensor  # mask weighting which tokens count toward the loss
30
+
31
+
32
class NLL(torchmetrics.aggregation.MeanMetric):
    """Running mean of negative log-likelihood values (alias of MeanMetric)."""
    pass
34
+
35
+
36
class BPD(NLL):
    def compute(self) -> Tensor:
        """Computes the bits per dimension.

        Converts the accumulated mean NLL (in nats) to bits by dividing
        by ln(2).

        Returns:
            bpd: mean negative log-likelihood expressed in bits.
        """
        # mean_value / weight is the running mean NLL in nats; / ln(2) -> bits
        return self.mean_value / self.weight / math.log(2)
44
+
45
+
46
class Perplexity(NLL):
    def compute(self) -> Tensor:
        """Computes the Perplexity.

        Perplexity is exp of the running mean negative log-likelihood.

        Returns:
            Perplexity: exp(mean NLL).
        """
        return torch.exp(self.mean_value / self.weight)
54
+
55
+
56
+ class Diffusion(L.LightningModule):
57
+ def __init__(self, config, tokenizer):
58
+
59
+ super().__init__()
60
+ self.config = config
61
+ #self.save_hyperparameters()
62
+
63
+ # PeptideCLM tokenizer
64
+ self.tokenizer = tokenizer
65
+ self.vocab_size = self.tokenizer.vocab_size
66
+ self.mask_token_id = self.tokenizer.mask_token_id
67
+ self.sampler = self.config.sampling.predictor
68
+ self.analyzer = PeptideAnalyzer()
69
+
70
+ # backbone LM PeptideCLM model
71
+ if self.config.backbone == 'roformer':
72
+ self.backbone = roformer.Roformer(self.config, self.tokenizer)
73
+ self.backbone.unfreeze_all_layers()
74
+ elif self.config.backbone == 'finetune_roformer':
75
+ self.backbone = roformer.Roformer(self.config, self.tokenizer)
76
+ self.backbone.freeze_model()
77
+ self.backbone.unfreeze_n_layers(n=8)
78
+ else:
79
+ Exception('invalid backbone config')
80
+
81
+ self.neg_infinity = -1000000.0
82
+ self.T = config.T
83
+ # noise schedule for non-peptide bond tokens (default to log-linear)
84
+ self.noise = noise_schedule.get_noise(config)
85
+ # noise schedule for peptide bonds (log-polynomial)
86
+ self.bond_noise = noise_schedule.LogPolyNoise()
87
+ self.time_conditioning = self.config.time_conditioning
88
+ self.fast_forward_epochs = None
89
+ self.fast_forward_batches = None
90
+
91
+ self.gen_ppl_eval_model_name_or_path = self.config.eval.gen_ppl_eval_model_name_or_path
92
+ self.gen_ppl_metric = Perplexity()
93
+
94
+ self.lr = self.config.optim.lr
95
+ self.sampling_eps = self.config.training.sampling_eps
96
+
97
+ metrics = torchmetrics.MetricCollection({
98
+ 'nll': NLL(),
99
+ 'bpd': BPD(),
100
+ 'ppl': Perplexity(),
101
+ })
102
+ metrics.set_dtype(torch.float64)
103
+ self.train_metrics = metrics.clone(prefix='trainer/')
104
+ self.valid_metrics = metrics.clone(prefix='val/')
105
+ self.test_metrics = metrics.clone(prefix='test/')
106
+
107
+
108
+ """LOSS FOR INVALID PEPTIDES"""
109
+
110
    @torch.no_grad()
    def conditional_gumbel(self, logits, D, k):
        """
        Outputs k samples of Q = StandardGumbel(), such that argmax(logits
        + Q) is given by D (one-hot vector).

        This is the conditional-Gumbel sampling step of the Gumbel-Rao
        (Rao-Blackwellized Gumbel-Softmax) estimator: all k perturbations
        are consistent with the already-chosen category D.

        Input:
        - logits: Tensor of shape (batch_size, seq_len, vocab_size)
        - D: One-hot tensor of shape (batch_size, seq_len, vocab_size)
        - k: Number of Gumbel samples

        Output:
        - Adjusted logits with shape (k, batch_size, seq_len, vocab_size)
        """

        # iid. exponential samples of shape (k, batch_size, seq_len, vocab_size)
        E = torch.distributions.exponential.Exponential(rate=torch.ones_like(logits)).sample([k])

        # E of the chosen class, shape (k, batch_size, seq_len, 1)
        Ei = (D * E).sum(dim=-1, keepdim=True)

        # Partition function (normalization constant), shape (batch_size, seq_len, 1)
        Z = logits.exp().sum(dim=-1, keepdim=True)

        # Adjusted logits for Gumbel distribution:
        # the chosen class gets the max-Gumbel value, the rest are truncated below it.
        adjusted = (
            D * (-torch.log(Ei) + torch.log(Z)) +
            (1 - D) * -torch.log(E / logits.exp() + Ei / Z)
        )

        # Adjusted logits shape: (k, batch_size, seq_len, vocab_size)
        # Subtracting logits leaves the pure Gumbel noise samples.
        return adjusted - logits
142
+
143
+ def replace_gradient(self, value, surrogate):
144
+ """
145
+ Returns `value` but backpropagates gradients through `surrogate`.
146
+ """
147
+ return surrogate + (value - surrogate).detach()
148
+
149
    def gumbel_rao(self, logits, k, temp=1.0, I=None):
        """
        Returns a categorical sample from logits (over axis=-1) as a
        one-hot vector, with gumbel-rao gradient.

        The hard sample D is returned in the forward pass; the gradient
        flows through a Rao-Blackwellized surrogate: the average of k
        tempered softmaxes over Gumbel perturbations conditioned on D.

        Input:
        - logits: Tensor of shape (batch_size, seq_len, vocab_size)
        - k: Number of Gumbel samples for Rao-Blackwellization
        - temp: Temperature for softmax
        - I: Optional, precomputed categorical sample tensor of shape (batch_size, seq_len)

        Output:
        - One-hot tensor of shape (batch_size, seq_len, vocab_size)
          with Gumbel-Rao gradient.
        """
        # sanity check: logits must span the full tokenizer vocabulary
        assert logits.shape[-1] == self.tokenizer.vocab_size
        vocab_size = logits.shape[-1]

        if I is None:
            # Sample indices for each token in the batch
            I = torch.distributions.categorical.Categorical(logits=logits).sample() # (batch_size, seq_len)

        # Convert indices to one-hot encodings, shape (batch_size, seq_len, vocab_size)
        D = torch.nn.functional.one_hot(I, num_classes=vocab_size).float()

        # Generate k different adjusted logits that all evaluate to the same sequence
        adjusted = logits + self.conditional_gumbel(logits, D, k=k) # (k, batch_size, seq_len, vocab_size)

        # Compute the surrogate by averaging softmax across k samples
        surrogate = torch.nn.functional.softmax(adjusted / temp, dim=-1).mean(dim=0) # (batch_size, seq_len, vocab_size)

        # Return one-hot representation with surrogate gradient
        return self.replace_gradient(D, surrogate)
182
+
183
    def compute_invalid_loss(self, logits, k=None, temp=None):
        """
        Penalizes logits that produce invalid sequences using the `is_peptide` function,
        scaling penalties by the probabilities of the sampled tokens.

        NOTE: `k` and `temp` are currently unused — the Gumbel-Rao sampling
        path is commented out and tokens are taken by argmax instead. The
        validity check itself is non-differentiable; gradients flow only
        through the softmax probabilities of the argmax tokens.

        Args:
            logits: Tensor of shape [batch_size, seq_len, vocab_size].
            k: Number of samples for Gumbel-Rao (unused).
            temp: Temperature for softmax (unused).

        Returns:
            Tensor of shape (batch_size, seq_len): per-token penalty, equal to
            the token's probability for rows whose decoded sequence is not a
            valid peptide, and zero elsewhere.
        """

        #samples = self.gumbel_rao(logits, k=k, temp=temp) # (batch_size, seq_len, vocab_size)

        # Convert logits to sequences using the tokenizer (greedy argmax decode)
        batch_token_ids = logits.argmax(dim=-1).to(self.device) # (batch_size, seq_len)
        sampled_sequences = self.tokenizer.batch_decode(batch_token_ids)

        # Check validity of each sampled sequence (not differentiable)
        # 1 = invalid peptide, 0 = valid
        penalties = torch.tensor(
            [1 if not self.analyzer.is_peptide(seq) else 0 for seq in sampled_sequences],
            dtype=torch.float32,
            device=self.device
        )
        #print(penalties)

        # Compute probabilities for each token (batch_size, seq_length)
        sampled_probs = torch.softmax(logits, dim=-1).gather(dim=-1, index=batch_token_ids.unsqueeze(-1)).squeeze(-1).to(self.device)

        # scale penalties by softmax probability of sampled tokens, so the
        # model is pushed to lower confidence on tokens of invalid sequences
        scaled_penalty = penalties[:, None] * sampled_probs # (batch_size, seq_length)

        return scaled_penalty.to(self.device)
218
+
219
+ """DIFFUSION LOSS"""
220
+
221
+ def sample_t(self, n, device):
222
+ """
223
+ Sample random time steps for batch training
224
+ """
225
+ # sample values uniformly at random from [0, 1)
226
+ eps_t = torch.rand(n, device=device)
227
+ # antithetic sampling: reduce variance by pairing each sample with complementary sample
228
+ if self.config.training.antithetic_sampling:
229
+ # compute interval between sampled time steps
230
+ offset = torch.arange(n, device=device) / n
231
+ # ensure that each eps value is evenly spaced between [0, 1)
232
+ eps_t = ((eps_t / n) + offset) % 1
233
+
234
+ # ensures values are not exactly 0 or 1
235
+ t = (1 - self.config.training.sampling_eps) * eps_t + self.config.training.sampling_eps
236
+
237
+ return t
238
+
239
    def q_xt(self, x, mask_prob):
        """Computes the noisy sample xt.

        Independently masks each token with probability `mask_prob`, then
        caps the number of masked positions per sequence at 75% of the
        sequence's non-padding length.

        Args:
            x: int torch.Tensor with shape (batch_size,
                diffusion_model_input_length), input token ids.
            mask_prob: float torch.Tensor with shape (batch_size, 1),
                per-sequence masking probability.

        Returns:
            xt: tensor shaped like x with selected positions replaced by
                the tokenizer's mask token id.
        """

        # non-padding length; assumes token id 0 is padding — TODO confirm
        actual_seq_length = (x != 0).sum(dim=-1, keepdim=True)
        #print(actual_seq_length)

        # at most 75% of the real tokens may be masked
        max_mask_length = (actual_seq_length * 0.75).long()

        # Bernoulli(mask_prob) proposal per position
        mask_indices = torch.rand(*x.shape, device=x.device) < mask_prob

        restricted_move_indices = torch.zeros_like(mask_indices, dtype=torch.bool)

        # per-row cap enforcement
        for i in range(x.shape[0]):
            true_positions = torch.where(mask_indices[i])[0]
            if len(true_positions) > max_mask_length[i]:
                # NOTE(review): keeps the earliest sampled positions when over
                # the cap, which biases masking toward the sequence start —
                # a random subset may be intended; verify.
                selected_positions = true_positions[:max_mask_length[i].item()]
                restricted_move_indices[i, selected_positions] = True
            else:
                restricted_move_indices[i] = mask_indices[i]

        xt = torch.where(restricted_move_indices, self.tokenizer.mask_token_id, x)

        return xt
268
+
269
+
270
+ def sample_prior(self, *batch_dims):
271
+ """
272
+ Returns array of fully masked sequences with same shape as input
273
+ """
274
+ return self.mask_token_id * torch.ones(* batch_dims, dtype=torch.int64)
275
+
276
+
277
+ """COMPUTING LOSS"""
278
+
279
    def compute_diffusion_loss(self, model_output, xt, x0, t):
        """
        Computes the diffusion loss term of the ELBO: how accurately the model
        predicts the original tokens at each of the T discrete timesteps.

        Inputs:
            - model_output: per-position token log-probabilities from the
              denoiser. NOTE(review): the gather below implies shape
              (batch, seq_len, vocab) — the original comment claimed
              [seq_len, vocab, vocab]; confirm against the backbone output.
            - xt: corrupted version of original input x0 at timestep t
            - x0: original input token ids
            - t: timestep (broadcastable against x0)
        """
        # compute interval between each timestep
        dt = 1 / self.T

        # alpha_t = 1 - t under the log-linear schedule, broadcast to x0's shape
        alpha_t = 1 - t + torch.zeros_like(x0)
        # s = t - dt is the previous (less noisy) timestep
        alpha_s = 1 - (t - dt) + torch.zeros_like(x0)

        # log-probability the model assigns to the true token x0 at each position
        # log<x_theta, x>
        log_x_theta_at_x0 = torch.gather(model_output, -1, x0[:, :, None])
        # log-probability of assigning the mask token at each position
        # log<x_theta, m>
        log_x_theta_at_m = model_output[:, :, self.mask_token_id]
        # non-log probability of assigning a masked token
        # <xt, m>
        x_theta_at_m = log_x_theta_at_m.exp()

        # first term of the discrete-time variational bound
        term_1_coef = dt / t
        term_1_log_numerator = torch.log((alpha_t * x_theta_at_m) / t + 1)
        term_1_log_denom = log_x_theta_at_x0

        # second term of the discrete-time variational bound
        term_2_coef = 1 - (dt / t)
        term_2_log_numerator = term_1_log_numerator
        term_2_log_denom = torch.log((alpha_s * x_theta_at_m) / (t - dt) + 1)

        L_vb_masked = (term_1_coef * (term_1_log_numerator - term_1_log_denom) +
                       term_2_coef * (term_2_log_numerator - term_2_log_denom))

        # the bound only applies at positions that are masked in xt: <zt, m>
        L_vb = L_vb_masked * (xt == self.mask_token_id)

        # scale by the number of timesteps T and return
        return self.T * L_vb
326
+
327
    def _forward_pass_diffusion(self, x0, attn_mask, bond_mask=None, mask=None):
        """
        One training forward pass: noise x0 to a random timestep, denoise with
        the model, and return the per-token loss.

        Args:
            x0: (batch, seq_len) original token ids.
            attn_mask: (batch, seq_len) attention mask for the backbone.
            bond_mask: optional (batch, seq_len) indicator of peptide-bond
                token positions; with config.noise.state_dependent these get a
                separate (log-polynomial) masking schedule.
            mask: optional explicit 0/1 mask; when given it overrides the
                sampled masking (positions with mask != 1 become mask tokens).
        """
        # randomly sample a starting timestep for each sequence in the batch
        t = self.sample_t(x0.shape[0], self.device)

        # if we are training the intermediate transition blocks
        if self.T > 0:
            # quantize t onto the grid {1/T, 2/T, ..., 1}
            t = (t * self.T).to(torch.int)
            t = t / self.T
            # add 1/T to ensure no 0 values
            t += (1 / self.T)

        # noise level and its rate at timestep t
        # log-linear schedule: sigma = -log(1-t); dsigma = 1 / (1-t)
        sigma, dsigma = self.noise(t)
        # NOTE(review): time_conditioning is captured before sigma is
        # re-shaped/overwritten in the state-dependent branch below — so the
        # backbone always sees the base schedule's sigma; confirm intended.
        time_conditioning = sigma[:, None]

        # masking probability per token: log-linear gives 1 - alpha = t
        base_mask_prob = 1 - torch.exp(-sigma[:, None])  # (batch_size, L)

        if self.config.noise.state_dependent and (bond_mask is not None):
            # log-polynomial masking schedule for bond tokens: alpha = 1 - t^w
            # bond_sigma = -log(1-t^w); bond_dsigma = -wt^(w-1) / (1-t^w)
            bond_sigma, bond_dsigma = self.bond_noise(t)
            # expand dimensions for broadcasting to (B, L)
            bond_sigma = bond_sigma[:, None]
            bond_dsigma = bond_dsigma[:, None]
            sigma = sigma[:, None]
            dsigma = dsigma[:, None]

            # masking probability for peptide bonds: 1 - bond_alpha = t^w
            bond_mask_prob = 1 - torch.exp(-bond_sigma).to(self.device)
            # splice the bond schedule into the base schedule at bond positions
            mask_prob = torch.where(bond_mask == 1, bond_mask_prob, base_mask_prob).to(self.device)
            dsigma = torch.where(bond_mask == 1, bond_dsigma, dsigma).to(self.device)
            sigma = torch.where(bond_mask == 1, bond_sigma, sigma).to(self.device)
        else:
            mask_prob = base_mask_prob.to(self.device)

        # corrupt x0: either sample masks or apply the caller-provided mask
        if mask is None:
            zt = self.q_xt(x0, mask_prob).to(self.device)
        else:
            zt = x0.where(mask==1, torch.full_like(x0, self.mask_token_id)).to(self.device)

        model_output = self.forward(zt, attn_mask=attn_mask.to(self.device), sigma=time_conditioning).to(self.device)

        # debugging; NOTE(review): the is_cuda assert hard-requires GPU training
        assert not torch.isnan(model_output).any()
        assert model_output.is_cuda
        utils.print_nans(model_output, 'model_output')

        # penalty for chemically invalid decoded sequences, per position (B, L)
        invalid_loss = self.compute_invalid_loss(logits=model_output).to(self.device)

        if self.T > 0:
            # discrete-time variational bound
            # NOTE(review): invalid_loss is computed but not added on this
            # path — confirm whether that is intentional.
            diffusion_loss = self.compute_diffusion_loss(model_output, zt, x0, t)
            return diffusion_loss

        # continuous-time loss: -log p_theta(x0) weighted by dsigma/expm1(sigma)
        log_p_theta = torch.gather(input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1).to(self.device)  # (B, L)

        if self.config.noise.state_dependent and (bond_mask is not None):
            # sigma/dsigma are already (B, L) here
            return (-log_p_theta * (dsigma / torch.expm1(sigma)) + invalid_loss).to(self.device)
        else:
            # sigma/dsigma are (B,) here, hence the [:, None] broadcast
            return ((-log_p_theta * (dsigma / torch.expm1(sigma))[:, None]) + invalid_loss).to(self.device)
406
+
407
+ def _loss(self, x0, attn_mask, bond_mask=None, mask=None):
408
+ loss = self._forward_pass_diffusion(x0, attn_mask, bond_mask, mask)
409
+
410
+ # negative log loss
411
+ nlls = loss * attn_mask
412
+
413
+ # count number of tokens
414
+ num_tokens = attn_mask.sum()
415
+
416
+ # compute batch loss
417
+ batch_nll = nlls.sum()
418
+ # compute per token loss
419
+ token_nll = batch_nll / num_tokens
420
+ # return losses
421
+ return Loss(loss = token_nll.to(self.device), nlls = nlls.to(self.device), attn_mask = attn_mask.to(self.device))
422
+
423
+ def _compute_loss(self, batch, prefix, bond_mask=None):
424
+
425
+ attn_mask = batch['attention_mask'].to(self.device)
426
+
427
+ if 'mask' in batch:
428
+ mask = batch['mask'].to(self.device)
429
+ else:
430
+ mask = None
431
+
432
+ if 'bond_mask' in batch:
433
+ bond_mask = batch['bond_mask'].to(self.device)
434
+ else:
435
+ bond_mask = None
436
+
437
+ losses = self._loss(batch['input_ids'].to(self.device), attn_mask, bond_mask, mask)
438
+ loss = losses.loss
439
+
440
+ if prefix == 'train':
441
+ self.train_metrics.update(
442
+ losses.nlls.to(self.device),
443
+ losses.attn_mask.to(self.device)
444
+ )
445
+ metrics = self.train_metrics
446
+ elif prefix == 'val':
447
+ self.valid_metrics.update(
448
+ losses.nlls.to(self.device),
449
+ losses.attn_mask.to(self.device)
450
+ )
451
+ metrics = self.valid_metrics
452
+ elif prefix == 'test':
453
+ self.test_metrics.update(losses.nlls, losses.attn_mask)
454
+ metrics = self.test_metrics
455
+ else:
456
+ raise ValueError(f'Invalid prefix: {prefix}')
457
+
458
+ self.log_dict(metrics,
459
+ on_step=False,
460
+ on_epoch=True,
461
+ sync_dist=True)
462
+
463
+ return loss
464
+
465
+
466
+ """SAMPLING"""
467
+
468
+ def generate_from_masked(self, num_samples=None, seq_length=None, sample_steps=128, eps=1e-5):
469
+ # get number of timesteps
470
+ if sample_steps is None:
471
+ sample_steps = self.config.sampling.steps
472
+
473
+ if seq_length is None:
474
+ seq_length = self.config.sampling.seq_length
475
+
476
+ # sample fully masked sequences
477
+ z = self.sample_prior(num_samples, seq_length).to(self.device)
478
+
479
+ # create vector of sample_steps timesteps
480
+ timesteps = torch.linspace(1, eps, sample_steps + 1, device=self.device)
481
+
482
+ # compute interval between timesteps
483
+ dt = (1 - eps) / sample_steps
484
+
485
+ for i in range(sample_steps):
486
+ t = timesteps[i] * torch.ones(z.shape[0], 1, device=self.device)
487
+
488
+ z = self.single_reverse_step(z, t, dt)
489
+
490
+ return z
491
+
492
+
493
+ """SAMPLING STEP"""
494
+
495
    def single_reverse_step(self, zt, t, dt, attn_mask=None):
        """
        Take a single reverse diffusion step (t -> t - dt), e.g. for the
        expansion step of the MCTS algorithm.

        Args:
            zt: (batch, seq_len) partially masked token ids at time t.
            t: current timestep, broadcastable to (batch, 1).
            dt: step width.
            attn_mask: optional attention mask forwarded to the backbone.

        Returns:
            (batch, seq_len) token ids at time t - dt; already-unmasked
            positions are carried over unchanged.
        """
        # noise levels that determine the masking probabilities at t and s=t-dt
        sigma_t, _ = self.noise(t)
        sigma_s, _ = self.noise(t - dt)

        # flatten sigmas to 1-D per-batch vectors
        if sigma_t.ndim > 1:
            sigma_t = sigma_t.squeeze(-1)
        if sigma_s.ndim > 1:
            sigma_s = sigma_s.squeeze(-1)
        assert sigma_t.ndim == 1, sigma_t.shape
        assert sigma_s.ndim == 1, sigma_s.shape

        # probability a token is (still) masked at each timestep
        change_prob_t = 1 - torch.exp(-sigma_t)
        change_prob_s = 1 - torch.exp(-sigma_s)

        # expand to (batch, 1, 1) for broadcasting over (batch, L, vocab)
        change_prob_t = change_prob_t[:, None, None]
        change_prob_s = change_prob_s[:, None, None]

        # denoiser output: per-position token log-probabilities
        log_p_x0 = self.forward(zt, attn_mask=attn_mask, sigma=sigma_t)

        # check dimensions match
        assert change_prob_t.ndim == log_p_x0.ndim

        # reverse-diffusion probability of unmasking into each token at time s
        q_zs = log_p_x0.exp() * (change_prob_t - change_prob_s)

        # probability of remaining masked at time s
        q_zs[:, :, self.mask_token_id] = change_prob_s[:, :, 0]

        # sample the sequence at timestep s from the categorical q_zs
        z_changed = sample_categorical(q_zs)

        # carry-over: positions already unmasked in zt stay fixed
        copy_flag = (zt != self.mask_token_id).to(zt.dtype)
        return (copy_flag * zt) + ((1 - copy_flag) * z_changed)
538
+
539
    def cached_reverse_step(self, x, t, dt, p_x0=None, attn_mask=None):
        """
        Reverse step that reuses a cached denoiser output p_x0 when the
        sequence has not changed since the previous step (ddpm_cache sampler).

        Only valid for the log-linear schedule, where the masked probability
        at time t is exactly t (hence t is used directly below).

        Returns:
            (p_x0, x_s): the (possibly newly computed) denoiser probabilities
            and the sequence advanced from t to t - dt.
        """
        assert self.config.noise.type == 'loglinear'
        sigma_t, _ = self.noise(t)

        if t.ndim > 1:
            t = t.squeeze(-1)
        assert t.ndim == 1

        # log-linear schedule: P(masked at time u) = u
        change_prob_t = t[:, None, None]
        change_prob_s = (t - dt)[:, None, None]

        assert change_prob_t.ndim == 3, change_prob_t.shape

        # recompute the denoiser output only when no cached value was passed
        if p_x0 is None:
            p_x0 = self.forward(x, attn_mask=attn_mask, sigma=sigma_t).exp()

        assert change_prob_t.ndim == p_x0.ndim

        # probability of unmasking into each token between t and t - dt
        q_xs = p_x0 * (change_prob_t - change_prob_s)

        # probability of remaining masked at time t - dt
        q_xs[:, :, self.mask_token_id] = change_prob_s[:, :, 0]

        x_changed = sample_categorical(q_xs)

        # carry-over: keep already-unmasked positions
        copy_flag = (x != self.mask_token_id).to(x.dtype)

        return p_x0, copy_flag * x + (1 - copy_flag) * x_changed
566
+
567
+ # first step in expansion
568
+ def batch_cached_reverse_step(self, token_array, t, dt, batch_size, p_x0=None, attn_mask=None):
569
+
570
+ assert self.config.noise.type == 'loglinear'
571
+ sigma_t, _ = self.noise(t)
572
+
573
+ if t.ndim > 1:
574
+ t = t.squeeze(-1)
575
+ assert t.ndim == 1
576
+
577
+ change_prob_t = t[:, None, None]
578
+ change_prob_s = (t - dt)[:, None, None]
579
+
580
+ assert change_prob_t.ndim == 3, change_prob_t.shape
581
+
582
+ if token_array.dim() == 1:
583
+ token_array = token_array.unsqueeze(0)
584
+ #token_array = token_array.repeat(batch_size, 1)
585
+
586
+ attn_mask = torch.ones_like(token_array)
587
+
588
+ if p_x0 is None:
589
+ p_x0 = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t).exp()
590
+
591
+ assert change_prob_t.ndim == p_x0.ndim
592
+
593
+ q_xs = p_x0 * (change_prob_t - change_prob_s)
594
+
595
+ # zero-masking probability
596
+ q_xs[:, :, self.mask_token_id] = change_prob_s[:, :, 0]
597
+
598
+ # repeat the parent token along the first dimension which will be unmasked into distinct sequences
599
+ token_array = token_array.repeat(batch_size, 1)
600
+
601
+ if self.config.mcts.sampling == 0:
602
+ x_changed = sample_batched_categorical(q_xs.to(self.device), batch_size)
603
+ else:
604
+ x_changed = sample_batched_top_k(q_xs.to(self.device), batch_size, self.config.mcts.sampling)
605
+
606
+ copy_flag = (token_array != self.mask_token_id).to(token_array.dtype)
607
+
608
+ return p_x0, copy_flag * token_array + (1 - copy_flag) * x_changed
609
+
610
+ def _process_sigma(self, sigma):
611
+ if sigma.ndim > 1:
612
+ sigma = sigma.squeeze(-1)
613
+ if not self.time_conditioning:
614
+ sigma = torch.zeros_like(sigma)
615
+ assert sigma.ndim == 1, sigma.shape
616
+ return sigma
617
+
618
    def forward(self, zt, attn_mask, sigma):
        """
        Predicts the token log-probabilities for zt at noise level sigma.

        Runs the backbone and applies the SUBS parameterization (zero masking
        probability + carry-over unmasking).
        """
        sigma = self._process_sigma(sigma)

        # NOTE(review): autocast is requested with dtype=float32, which
        # effectively disables mixed precision — confirm that is intended.
        with torch.amp.autocast("cuda", enabled=True, dtype=torch.float32, cache_enabled=True):
            logits = self.backbone(zt, attn_mask).to(self.device)

        return self.subs_parameterization(logits, zt)
628
+
629
    def subs_parameterization(self, logits, zt):
        """
        Updates reverse diffusion logits based on SUBS parameterization:
        - zero masking probabilities: -infinity log-probability of predicting
          the mask token during reverse diffusion
        - carry-over unmasking: unmasked input tokens remain unchanged during
          reverse diffusion

        Args:
            logits: (batch, seq_len, vocab) token scores from the backbone
            zt: partially unmasked sequence at the current timestep

        Returns:
            (batch, seq_len, vocab) log-probabilities normalized over vocab.
        """
        # Zero masking probability: the model can never predict a mask token.
        logits[:, :, self.mask_token_id] += self.neg_infinity  # [sequence index, current token, next token]

        # Normalize over the vocabulary to obtain log-probabilities.
        logits = (logits - torch.logsumexp(logits, dim=-1, keepdim=True)).to(self.device)

        # Carry-over unmasking: find the already-revealed positions of zt.
        unmasked_indices = (zt != self.mask_token_id).to(self.device)  # (batch, seq_len) bool
        batch_idx, seq_idx = torch.where(unmasked_indices)  # Get explicit indices
        batch_idx = batch_idx.to(self.device)
        seq_idx = seq_idx.to(self.device)
        tokens = zt[batch_idx, seq_idx].to(self.device)  # Get the tokens at those positions

        # Sanity checks. NOTE(review): asserts are stripped under `python -O`.
        assert logits.is_contiguous(), "logits tensor is not contiguous"
        assert unmasked_indices.shape == zt.shape, "same shape"
        assert not torch.isnan(logits).any(), "NaN values found in logits"
        assert tokens.max() < logits.shape[-1], "token indices out of bounds"
        assert batch_idx.max() < logits.shape[0], "batch index out of bounds"
        assert seq_idx.max() < logits.shape[1], "seq index out of bounds"
        assert batch_idx.device == seq_idx.device == logits.device == tokens.device, "device inconsistent"

        # Pin revealed positions to their current token: log-prob 0 there,
        # -inf everywhere else at those positions.
        logits[batch_idx, seq_idx] = self.neg_infinity  # Set everything to -inf first
        logits[batch_idx, seq_idx, tokens] = 0  # Set only the specific token positions to 0
        # return logits with SUBS parameterization
        return logits.to(self.device)
663
+
664
+ """SAMPLING"""
665
    @torch.no_grad()
    def _sample(self, num_steps=None, eps=1e-5, x_input=None):
        """
        Generate samples by iterating the reverse diffusion chain.

        Args:
            num_steps: number of reverse steps (default: config.sampling.steps).
            eps: smallest timestep; the chain runs from t=1 down to t=eps.
            x_input: optional dict with 'input_ids'/'attention_mask' to start
                from a partially masked sequence instead of the all-mask prior.

        Returns:
            (batch, seq_len) generated token ids.
        """
        batch_size_per_gpu = self.config.eval.perplexity_batch_size

        if num_steps is None:
            num_steps = self.config.sampling.steps

        # start either from the caller's sequence or from the all-mask prior
        if x_input is not None:
            x = x_input['input_ids'].to(self.device)
            attn_mask = x_input['attention_mask'].to(self.device)
        else:
            x = self.sample_prior(batch_size_per_gpu, self.config.model.length).to(self.device)
            attn_mask = torch.ones_like(x).to(self.device)


        timesteps = torch.linspace(1, eps, num_steps+1, device=self.device)
        dt = (1 - eps) / num_steps
        p_x0_cache = None
        generation_history = []  # used to track which tokens are unmasked

        for i in range(num_steps):
            t = timesteps[i] * torch.ones(x.shape[0], 1, device = self.device)
            if self.sampler == 'ddpm':
                # NOTE(review): attn_mask is not forwarded on this branch —
                # confirm the ddpm sampler is meant to ignore it.
                x = self.single_reverse_step(x, t, dt).to(self.device)

            elif self.sampler == 'ddpm_cache':
                # reuse the cached denoiser output while x is unchanged
                p_x0_cache, x_next = self.cached_reverse_step(x, t, dt, p_x0=p_x0_cache, attn_mask=attn_mask)
                if (not torch.allclose(x_next, x) or self.time_conditioning):
                    # Disable caching
                    p_x0_cache = None
                x = x_next.to(self.device)
            else:
                # analytic (score-based) update
                x = self._analytic_update(x, t, dt, attn_mask).to(self.device)

        # final cleanup: deterministically resolve any residual mask tokens
        if self.config.sampling.noise_removal:
            t = timesteps[-1] * torch.ones(x.shape[0], 1, device=self.device)
            if self.sampler == 'analytic':
                x = self._denoiser_update(x, t).to(self.device)
            else:
                time_conditioning = self.noise(t)[0].to(self.device)
                x = self.forward(x, attn_mask=attn_mask, sigma=time_conditioning).argmax(dim=-1).to(self.device)
        return x.to(self.device)
712
+
713
+
714
+ def restore_model_and_sample(self, num_steps, eps=1e-5):
715
+ """Generate samples from the model."""
716
+ self.backbone.eval()
717
+ self.noise.eval()
718
+ samples = self._sample(num_steps=num_steps, eps=eps)
719
+ self.backbone.train()
720
+ self.noise.train()
721
+ return samples
722
+
723
    def get_score(self, zt, sigma, attn_mask=None):
        """
        Compute the (exponentiated) score ratio p_t(y)/p_t(x) for every
        candidate token y at every position of zt, at noise level sigma.

        score(x, t) = p_t(y) / p_t(x)
        => log score(x, t) = log p_t(y) - log p_t(x)

        case 1: x = masked
          (i) y = unmasked:
              log score(x, t) = log p_theta(x)|_y + log k
              where k = exp(-sigma) / (1 - exp(-sigma))
          (ii) y = masked:
              log score(x, t) = 0

        case 2: x = unmasked
          (i) y != masked, y != x: log score(x_i, t) = -inf
          (ii) y = x:              log score(x_i, t) = 0
          (iii) y = masked token:  log score(x_i, t) = -log k
        """
        # denoiser log-probabilities for every position/token
        model_output = self.forward(zt, attn_mask=attn_mask, sigma=sigma)

        # log k = -log(exp(sigma) - 1)
        log_k = -torch.log(torch.expm1(sigma)).squeeze(-1)
        assert log_k.ndim == 1

        # case 1: scores for currently-masked positions
        masked_score = model_output + log_k[:, None, None]
        masked_score[:, :, self.mask_token_id] = 0

        # case 2: scores for currently-unmasked positions
        unmasked_score = self.neg_infinity * torch.ones_like(model_output)
        unmasked_score = torch.scatter(
            unmasked_score, -1,
            zt[..., None],
            torch.zeros_like(unmasked_score[..., :1]))

        unmasked_score[:, :, self.mask_token_id] = - (log_k[:, None] * torch.ones_like(zt))

        # select case 1 or case 2 per position
        masked_indices = (zt == self.mask_token_id).to(model_output.dtype)[:, :, None]

        model_output = (masked_score * masked_indices + unmasked_score * (1 - masked_indices))

        # return the score itself (not its log)
        return model_output.exp()
765
+
766
+ def _staggered_score(self, score, dsigma):
767
+ score = score.clone()
768
+ extra_const = (1 - dsigma.exp()) * score.sum(dim=-1)
769
+ score *= dsigma.exp()[:, None]
770
+ score[..., self.mask_token_id] += extra_const
771
+ return score
772
+
773
+ def _analytic_update(self, x, t, step_size, attn_mask=None):
774
+ curr_sigma, _ = self.noise(t)
775
+ next_sigma, _ = self.noise(t - step_size)
776
+ dsigma = curr_sigma - next_sigma
777
+ score = self.get_score(x, attn_mask, curr_sigma)
778
+ stag_score = self._staggered_score(score, dsigma)
779
+ probs = stag_score * self._transp_transition(x, dsigma)
780
+ return sample_categorical(probs)
781
+
782
    def _denoiser_update(self, x, t):
        """
        Final denoising step for the analytic sampler: resolve remaining
        tokens with the mask category forced to zero probability.

        NOTE(review): sigma itself (not a sigma difference as in
        _analytic_update) is passed to _staggered_score/_transp_transition —
        presumably because this is the terminal jump to t=0; confirm.
        """
        sigma, _ = self.noise(t)
        score = self.get_score(x, sigma)
        stag_score = self._staggered_score(score, sigma)
        probs = stag_score * self._transp_transition(x, sigma)
        # forbid sampling the mask token on the final step
        probs[..., self.mask_token_id] = 0
        samples = sample_categorical(probs)
        return samples
790
+
791
    def _transp_transition(self, i, sigma):
        """
        Transposed transition kernel of the absorbing-state process for token
        ids i at noise level sigma: exp(-sigma) on the identity (one-hot)
        channel, plus 1 - exp(-sigma) spread at positions that are masked.
        """
        # pad sigma with trailing singleton dims to broadcast against one-hots
        sigma = unsqueeze(sigma, reference=i[..., None])
        edge = torch.exp(-sigma) * F.one_hot(
            i, num_classes=self.vocab_size)
        # masked positions additionally receive the absorbed mass
        edge += torch.where(i == self.mask_token_id,
                            1 - torch.exp(-sigma).squeeze(-1),
                            0)[..., None]
        return edge
799
+
800
+
801
+ def on_train_epoch_start(self):
802
+ torch.cuda.empty_cache()
803
+ self.backbone.train()
804
+ self.noise.train()
805
+
806
+
807
+ def training_step(self, batch, batch_idx):
808
+ # Initialize throughput calculation
809
+ start_time = time.time()
810
+
811
+ if self.config.vocab == 'old_smiles' or self.config.vocab == 'new_smiles':
812
+ loss = self._compute_loss(batch, prefix='train', bond_mask=batch['bond_mask'])
813
+ else:
814
+ loss = self._compute_loss(batch, prefix='train')
815
+
816
+ self.log(name='trainer/loss',
817
+ value=loss.item(),
818
+ on_step=True,
819
+ on_epoch=False,
820
+ sync_dist=True)
821
+
822
+ # Calculate throughput
823
+ elapsed_time = time.time() - start_time
824
+ total_tokens = batch['input_ids'].numel()
825
+ throughput = total_tokens / elapsed_time
826
+
827
+ self.log(name='trainer/throughput',
828
+ value=throughput,
829
+ on_step=True,
830
+ on_epoch=False,
831
+ sync_dist=True)
832
+
833
+ return loss
834
+
835
+
836
+ def on_load_checkpoint(self, checkpoint):
837
+ self.fast_forward_epochs = checkpoint['loops']['fit_loop']['epoch_progress']['current']['completed']
838
+ self.fast_forward_batches = checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed']
839
+
840
+ """VALIDATION"""
841
+ def on_validation_epoch_start(self):
842
+ gc.collect()
843
+ torch.cuda.empty_cache()
844
+ self.backbone.eval()
845
+ self.noise.eval()
846
+ assert self.valid_metrics.nll.mean_value == 0
847
+ assert self.valid_metrics.nll.weight == 0
848
+
849
+ def validation_step(self, batch, batch_idx):
850
+ if self.config.vocab == 'old_smiles' or self.config.vocab == 'new_smiles':
851
+ loss = self._compute_loss(batch, prefix='val', bond_mask=batch['bond_mask'])
852
+ else:
853
+ loss = self._compute_loss(batch, prefix='val')
854
+
855
+ self.log(name='trainer/val_loss',
856
+ value=loss.item(),
857
+ on_step=True,
858
+ on_epoch=False,
859
+ prog_bar=True,
860
+ sync_dist=True)
861
+ return loss
862
+
863
+ def on_validation_epoch_end(self):
864
+ gc.collect()
865
+ torch.cuda.empty_cache()
866
+
867
+ """OPTIMIZATION"""
868
+
869
    def optimizer_step(self, *args, **kwargs):
        """Lightning hook: run the default optimizer step, then aggressively
        release memory (gc + CUDA cache) after every parameter update.

        NOTE(review): emptying the CUDA cache on every step can slow training;
        presumably done to avoid fragmentation OOMs — confirm before removing.
        """
        super().optimizer_step(*args, **kwargs)

        gc.collect()
        torch.cuda.empty_cache()
874
+
875
    def configure_optimizers(self):
        """Lightning hook: AdamW over backbone + noise-schedule parameters,
        with a per-step cosine-warmup LR schedule.

        Returns the ([optimizer], [scheduler_dict]) pair Lightning expects.
        """
        optimizer = torch.optim.AdamW(
            itertools.chain(self.backbone.parameters(),self.noise.parameters()),
            lr=self.config.optim.lr,
            betas=(self.config.optim.beta1, self.config.optim.beta2),
            eps=self.config.optim.eps,
            weight_decay=self.config.optim.weight_decay
        )

        self.total_steps = self.config.trainer.max_steps
        scheduler = CosineWarmup(optimizer,
                                 warmup_steps=self.config.lr_scheduler.num_warmup_steps,
                                 total_steps=self.total_steps)

        scheduler_dict = {
            'scheduler': scheduler,
            # step the LR every optimizer step, not every epoch
            'interval': 'step',
            'frequency': 1,
            # NOTE(review): 'val/loss' is not a key logged elsewhere in this
            # module ('trainer/val_loss' is). Harmless for CosineWarmup, which
            # does not monitor a metric, but confirm before swapping in a
            # metric-driven scheduler.
            'monitor': 'val/loss',
            'name': 'trainer/lr'
        }

        return [optimizer], [scheduler_dict]
898
+
899
    @torch.no_grad()
    def compute_masked_perplexity(self, generated_ids, input_ids):
        """
        Computes a pseudo-perplexity for generated sequences against the
        masked positions of `input_ids`: cross-entropy is accumulated only
        where input_ids == mask_token_id (other targets are set to the
        ignore_index -100).

        Args:
            generated_ids: iterable of generated token-id sequences.
            input_ids: the (partially masked) input ids the model conditioned on.

        NOTE(review): the backbone is re-run on the same `input_ids` for every
        generated sequence, so every loop iteration produces identical logits;
        also the denominator counts all non-pad tokens while the numerator
        only sums loss over masked positions — confirm both are intentional.
        """

        total_nll = 0
        total_tokens = 0

        input_ids = torch.tensor(input_ids).to(self.device)

        for sequence in generated_ids:
            # ground-truth ids for this generated sequence
            gt_ids = torch.tensor(sequence).to(self.device)

            sys.stdout.flush()

            # all-ones attention over the conditioning input
            attn_mask = torch.ones_like(input_ids).to(self.device)

            # logits from the backbone for every position/vocab entry
            outputs = self.backbone.forward(input_ids=input_ids, attn_mask=attn_mask)

            # flatten to (positions, vocab) / (positions,) for cross_entropy
            logits = outputs.view(-1, outputs.size(-1))
            gt_ids = gt_ids.view(-1)

            # score only masked positions; everything else gets ignore_index
            loss = F.cross_entropy(logits,
                                   gt_ids.where(input_ids==self.mask_token_id, torch.full_like(gt_ids, -100)).view(-1),
                                   reduction='sum')

            total_nll += loss.item()
            # count all non-padding tokens (includes bos/eos)
            total_tokens += input_ids.ne(self.tokenizer.pad_token_id).sum().item()

        # pseudo-perplexity = exp(mean NLL per counted token)
        pseudo_perplexity = torch.exp(torch.tensor(total_nll / total_tokens))
        self.gen_ppl_metric.update(pseudo_perplexity)

        return pseudo_perplexity.item()
955
+
956
+
957
def sample_categorical(categorical_probs):
    """Gumbel-max style sampling: draw one category per position from
    (unnormalized) probabilities by dividing by exponential noise and
    taking the argmax over the last axis."""
    noise = torch.rand_like(categorical_probs)
    gumbel_norm = 1e-10 - (noise + 1e-10).log()
    return (categorical_probs / gumbel_norm).argmax(dim=-1)
962
+
963
def sample_batched_categorical(categorical_probs, batch_size):
    """Draw batch_size independent Gumbel-max samples from one shared
    (1, L, V) probability tensor; returns (batch_size, L) token indices."""
    _, seq_len, vocab = categorical_probs.shape

    # Independent Gumbel noise per sample / position / category.
    uniform = torch.rand(batch_size, seq_len, vocab) + 1e-10
    gumbel = (-torch.log(-torch.log(uniform) + 1e-10)).to(categorical_probs.device)

    # Perturb log-probabilities and pick the best category per position.
    perturbed = torch.log(categorical_probs) + gumbel
    return perturbed.argmax(dim=-1)
974
+
975
def sample_batched_top_k(categorical_probs, batch_size, k):
    """Draw batch_size sequences from one shared (1, L, V) probability tensor,
    restricting each position to its top-k Gumbel-perturbed candidates and
    re-sampling among them with softmax weights.

    Args:
        categorical_probs: (1, L, V) per-position category probabilities.
        batch_size: number of independent sequences to sample.
        k: number of top candidates kept per position.

    Returns:
        (batch_size, L) int tensor of sampled vocabulary indices.
    """
    _, sequence_length, vocab_length = categorical_probs.shape
    device = categorical_probs.device

    # Gumbel-perturbed log probabilities, one independent draw per sample.
    gumbel_noise = -torch.log(-torch.log(torch.rand(batch_size, sequence_length, vocab_length) + 1e-10) + 1e-10).to(device)
    # Fix: broadcast (1, L, V) against (batch_size, L, V) directly so the
    # result stays 3-D. The previous `categorical_probs[None, :, :]` produced
    # a 4-D (1, batch_size, L, V) tensor whose 4-D top-k indices could not be
    # gathered with the 3-D sampled-slot tensor (torch.gather requires the
    # index to have the same number of dimensions as the input).
    noisy_scores = torch.log(categorical_probs) + gumbel_noise  # (batch_size, L, V)

    # Keep only the k best candidates per position.
    top_k_scores, top_k_indices = torch.topk(noisy_scores, k, dim=-1)  # (batch_size, L, k)

    # Renormalize the kept candidates and draw one slot per position.
    top_k_probs = torch.softmax(top_k_scores, dim=-1)
    sampled_slots = torch.multinomial(top_k_probs.reshape(-1, k), num_samples=1)
    sampled_slots = sampled_slots.view(batch_size, sequence_length, 1)

    # Translate top-k slots back into vocabulary ids.
    return torch.gather(top_k_indices, -1, sampled_slots).squeeze(-1)
996
+
997
def unsqueeze(x, reference):
    """Append trailing singleton dims to x until it has as many dims as
    `reference`, enabling broadcasting against it."""
    missing = len(reference.shape) - len(x.shape)
    return x.reshape(x.shape + (1,) * missing)
999
+
1000
class CosineWarmup(_LRScheduler):
    """Linear LR warmup followed by cosine decay to eta_ratio * base_lr."""

    def __init__(self, optimizer, warmup_steps, total_steps, eta_ratio=0.1, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        # Floor of the schedule as a fraction of the base learning rate.
        self.eta_ratio = eta_ratio
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch

        # Warmup phase: scale linearly from 0 up to the base LR.
        if step < self.warmup_steps:
            warm_frac = step / self.warmup_steps
            return [warm_frac * lr for lr in self.base_lrs]

        # Cosine phase: decay from base LR down to eta_ratio * base LR.
        progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        scale = (1 - self.eta_ratio) * 0.5 * (1 + np.cos(np.pi * progress)) + self.eta_ratio
        return [scale * lr for lr in self.base_lrs]
src/environment.yml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: peptune
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - conda-forge
6
+ dependencies:
7
+ - python=3.10
8
+ - pip
9
+ - pytorch
10
+ - torchvision
11
+ - pytorch-cuda=12.1
12
+ - rdkit
13
+ - numpy
14
+ - pandas
15
+ - scikit-learn
16
+ - jupyterlab
17
+ - matplotlib-base
18
+ - seaborn
19
+ - tqdm
20
+ - pyyaml
21
+ - pip:
22
+ - pytorch-lightning==2.5.5
23
+ - lightning==2.5.5
24
+ - fair-esm==2.0.0
25
+ - transformers==4.56.2
26
+ - SmilesPE==0.0.3
27
+ - scipy==1.13.1
28
+ - wandb==0.22.0
29
+ - hydra-core==1.3.2
30
+ - hydra-submitit-launcher==1.2.0
31
+ - pathos==0.3.4
32
+ - matplotlib==3.10.1
33
+ - pandas==2.2.2
34
+ - seaborn==0.13.2
35
+ - timm==1.0.20
36
+ - xgboost==3.0.5
37
+ - loguru==0.7.3
38
+ - peft==0.17.1
39
+ - accelerate==1.11.0
40
+ - datasets
src/generate_mcts.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env
2
+ import time
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import math
6
+ import random
7
+ import sys
8
+ import pandas as pd
9
+ from utils.generate_utils import mask_for_de_novo
10
+ from diffusion import Diffusion
11
+ from pareto_mcts import Node, MCTS
12
+ import hydra
13
+ from tqdm import tqdm
14
+ from transformers import AutoTokenizer, AutoModel, pipeline
15
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
16
+ from utils.app import PeptideAnalyzer
17
+ import matplotlib.pyplot as plt
18
+ import os
19
+ import seaborn as sns
20
+ import pandas as pd
21
+ import numpy as np
22
+
23
+ # Protein sequence dictionary
24
+ PROTEIN_SEQUENCES = {
25
+ 'amhr': 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV',
26
+ 'tfr': 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF',
27
+ 'gfap': 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM',
28
+ 'glp1': 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS',
29
+ 'glast': 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM',
30
+ 'ncam': 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF',
31
+ 'cereblon': 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL',
32
+ 'ligase': 'MASQPPEDTAESQASDELECKICYNRYNLKQRKPKVLECCHRVCAKCLYKIIDFGDSPQGVIVCPFCRFETCLPDDEVSSLPDDNNILVNLTCGGKGKKCLPENPTELLLTPKRLASLVSPSHTSSNCLVITIMEVQRESSPSLSSTPVVEFYRPASFDSVTTVSHNWTVWNCTSLLFQTSIRVLVWLLGLLYFSSLPLGIYLLVSKKVTLGVVFVSLVPSSLVILMVYGFCQCVCHEFLDCMAPPS',
33
+ 'skp2': 'MHRKHLQEIPDLSSNVATSFTWGWDSSKTSELLSGMGVSALEKEEPDSENIPQELLSNLGHPESPPRKRLKSKGSDKDFVIVRRPKLNRENFPGVSWDSLPDELLLGIFSCLCLPELLKVSGVCKRWYRLASDESLWQTLDLTGKNLHPDVTGRLLSQGVIAFRCPRSFMDQPLAEHFSPFRVQHMDLSNSVIEVSTLHGILSQCSKLQNLSLEGLRLSDPIVNTLAKNSNLVRLNLSGCSGFSEFALQTLLSSCSRLDELNLSWCFDFTEKHVQVAVAHVSETITQLNLSGYRKNLQKSDLSTLVRRCPNLVHLDLSDSVMLKNDCFQEFFQLNYLQHLSLSRCYDIIPETLLELGEIPTLKTLQVFGIVPDGTLQLLKEALPHLQINCSHFTTIARPTIGNKKNQEIWGIKCRLTLQKPSCL',
34
+ 'p53': 'MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPPVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD',
35
+ 'egfp': 'VSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTLTYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITLGMDELYK'
36
+ }
37
+
38
def save_logs_to_file(config, valid_fraction_log, score_logs, output_path):
    """
    Saves the MCTS iteration logs to a CSV file (one row per iteration).

    Parameters:
        config: Run configuration (unused; kept for interface stability).
        valid_fraction_log (list): Log of valid fractions over iterations.
        score_logs (dict): Dict mapping score func names to lists of scores
            (each list must match len(valid_fraction_log)).
        output_path (str): Path to save the log CSV file.
    """
    # BUG FIX: os.makedirs('') raises FileNotFoundError when output_path has
    # no directory component; only create the parent when one exists.
    parent = os.path.dirname(output_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    log_data = {
        "Iteration": list(range(1, len(valid_fraction_log) + 1)),
        "Valid Fraction": valid_fraction_log,
    }
    # One column per score function.
    for name, log in score_logs.items():
        log_data[name] = log

    # Save to CSV
    pd.DataFrame(log_data).to_csv(output_path, index=False)
60
+
61
def plot_data(log1, log2=None,
              save_path=None,
              label1="Log 1",
              label2=None,
              title="Fraction of Valid Peptides Over Iterations",
              palette=None):
    """
    Plots one or two datasets with their mean values over iterations.

    Parameters:
        log1 (list): The first list of mean values for each iteration.
        log2 (list, optional): The second list of mean values for each iteration. Defaults to None.
        save_path (str): Path to save the plot. Defaults to None.
        label1 (str): Label for the first dataset. Defaults to "Log 1".
        label2 (str, optional): Label for the second dataset. Defaults to None.
        title (str): Title of the plot.
        palette (dict, optional): Custom colors keyed by dataset label; when
            None the default purple/pink pair is used.
    """
    # Prepare data for log1
    data1 = pd.DataFrame({
        "Iteration": range(1, len(log1) + 1),
        "Fraction of Valid Peptides": log1,
        "Dataset": label1
    })

    # Prepare data for log2 if provided
    if log2 is not None:
        data2 = pd.DataFrame({
            "Iteration": range(1, len(log2) + 1),
            "Fraction of Valid Peptides": log2,
            "Dataset": label2
        })
        data = pd.concat([data1, data2], ignore_index=True)
    else:
        data = data1

    # BUG FIX: the original unconditionally overwrote a caller-supplied
    # `palette`; only fall back to the defaults when none was given.
    if palette is None:
        palette = {
            label1: "#8181ED",  # Default color for log1
            label2: "#D577FF"   # Default color for log2 (if provided)
        }

    # Set Seaborn theme
    sns.set_theme()
    sns.set_context("paper")

    # Create the plot
    sns.lineplot(
        data=data,
        x="Iteration",
        y="Fraction of Valid Peptides",
        hue="Dataset",
        style="Dataset",
        markers=True,
        dashes=False,
        palette=palette
    )

    # Titles and labels
    plt.title(title)
    plt.xlabel("Iteration")
    plt.ylabel("Fraction of Valid Peptides")

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to {save_path}")
    plt.show()
127
+
128
def plot_data_with_distribution_seaborn(log1, log2=None,
                                        save_path=None,
                                        label1=None,
                                        label2=None,
                                        title=None):
    """
    Plots one or two datasets with the average values and distributions over iterations using Seaborn.

    Parameters:
        log1 (list of lists): The first list of scores (each element is a list of scores for an iteration).
        log2 (list of lists, optional): The second list of scores. Defaults to None.
        save_path (str): Path to save the plot. Defaults to None.
        label1 (str): Label for the first dataset.
        label2 (str, optional): Label for the second dataset. Defaults to None.
        title (str): Title of the plot.
    """
    # Prepare data for log1: one row per individual score, tagged by iteration.
    data1 = pd.DataFrame({
        "Iteration": np.repeat(range(1, len(log1) + 1), [len(scores) for scores in log1]),
        "Fraction of Valid Peptides": [float(score) for scores in log1 for score in scores],
        "Dataset": label1,
        "Style": "Log1"
    })

    # Prepare data for log2 if provided
    if log2 is not None:
        data2 = pd.DataFrame({
            "Iteration": np.repeat(range(1, len(log2) + 1), [len(scores) for scores in log2]),
            "Fraction of Valid Peptides": [float(score) for scores in log2 for score in scores],
            "Dataset": label2,
            "Style": "Log2"
        })
        data = pd.concat([data1, data2], ignore_index=True)
    else:
        data = data1

    palette = {
        label1: "#8181ED",  # Default color for log1
        label2: "#D577FF"   # Default color for log2 (if provided)
    }

    # Set Seaborn theme
    sns.set_theme()
    sns.set_context("paper")

    # Create the plot
    sns.relplot(
        data=data,
        kind="line",
        x="Iteration",
        y="Fraction of Valid Peptides",
        hue="Dataset",
        style="Style",
        markers=True,
        dashes=True,
        # FIX: `ci="sd"` is deprecated since seaborn 0.12; errorbar="sd" is the
        # documented replacement (pinned seaborn==0.13.2 in environment.yml).
        errorbar="sd",  # Show standard deviation band
        height=5,
        aspect=1.5,
        palette=palette
    )

    # Titles and labels
    plt.title(title)
    plt.xlabel("Iteration")
    plt.ylabel("Fraction of Valid Peptides")

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to {save_path}")
    plt.show()
198
+
199
@torch.no_grad()
def generate_valid_mcts(config, mdlm, prot1=None, prot2=None, filename=None, prot_name1=None, prot_name2 = None):
    """Run Pareto-MCTS generation starting from a fully masked token sequence.

    Parameters:
        config: Hydra run configuration; config.mcts.{perm,dual,single,time_dependent}
            selects the objective set used by the tree search.
        mdlm: Pretrained masked diffusion language model (supplies tokenizer/device).
        prot1 (str, optional): Target protein sequence for binding objectives.
        prot2 (str, optional): Second target protein sequence (dual-binding mode).
        filename (str, optional): Run identifier (not used in the paths below).
        prot_name1 (str, optional): Name of protein 1; used to build output paths.
        prot_name2 (str, optional): Name of protein 2 (unused in this function).

    Returns:
        (paretoFront, inputs): Pareto-optimal sequences found by the search and
        the tokenized fully-masked input used as the root state.
    """
    tokenizer = mdlm.tokenizer
    max_sequence_length = config.sampling.seq_length

    # generate array of [MASK] tokens
    masked_array = mask_for_de_novo(config, max_sequence_length)

    inputs = tokenizer.encode(masked_array)

    # move every tokenizer output tensor onto the model's device
    inputs = {key: value.to(mdlm.device) for key, value in inputs.items()}

    # initialize root node of the search tree (timestep 0 = fully masked)
    rootNode = Node(config=config, tokens=inputs, timestep=0)

    # Select the objective set; num_func mirrors score_func_names
    # (all zeros; re-zeroed below unless config.mcts.time_dependent).
    if config.mcts.perm:
        score_func_names = ['permeability', 'binding_affinity1', 'solubility', 'hemolysis', 'nonfouling']
        num_func = [0, 0, 0, 0, 0]
    elif config.mcts.dual:
        score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'binding_affinity2']
        num_func = [0, 0, 0, 0, 0]
    elif config.mcts.single:
        if config.mode == 'binding':
            score_func_names = ['binding_affinity1']
        else:
            score_func_names = ['permeability']
        num_func = [0]
    else:
        score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling']
        num_func = [0, 0, 0, 0]

    if not config.mcts.time_dependent:
        num_func = [0] * len(score_func_names)

    # NOTE(review): parses as `prot1 and (prot2 is not None)` — relies on prot1
    # being a truthy string; confirm this is the intended dual-target condition.
    if prot1 and prot2 is not None:
        mcts = MCTS(config=config, max_sequence_length=max_sequence_length, mdlm=mdlm, score_func_names=score_func_names, prot_seqs=[prot1, prot2], num_func=num_func)
    elif prot1 is not None:
        mcts = MCTS(config=config, max_sequence_length=max_sequence_length, mdlm=mdlm, score_func_names=score_func_names, prot_seqs=[prot1], num_func=num_func)
    # NOTE(review): the final two branches construct identical MCTS objects.
    elif config.mcts.single:
        mcts = MCTS(config=config, max_sequence_length=max_sequence_length, mdlm=mdlm, score_func_names=score_func_names, num_func=num_func)
    else:
        mcts = MCTS(config=config, max_sequence_length=max_sequence_length, mdlm=mdlm, score_func_names=score_func_names, num_func=num_func)

    # run the tree search from the fully masked root
    paretoFront = mcts.forward(rootNode)

    # persist per-iteration logs and diagnostic plots under the protein's folder
    output_log_path = f'{config.base_path}/{prot_name1}/log_(unknown).csv'
    save_logs_to_file(config, mcts.valid_fraction_log, mcts.score_logs, output_log_path)

    plot_data(mcts.valid_fraction_log,
              save_path=f'{config.base_path}/{prot_name1}/valid_(unknown).png')

    for name in mcts.score_func_names:
        plot_data_with_distribution_seaborn(log1=mcts.score_logs[name],
                                            save_path=f'{config.base_path}/{prot_name1}/{name}_(unknown).png',
                                            label1=f"Average {name}",
                                            title=f"Average {name} Over Iterations")

    return paretoFront, inputs
258
+
259
+
260
@hydra.main(version_base=None, config_path='.', config_name='config')
def main(config):
    """Run Pareto-MCTS guided generation for the configured target protein(s),
    score each Pareto-front sequence, and write results plus summary stats."""
    # Get parameters from config with defaults
    prot_name1 = config.get('prot_name1', 'gfap')
    prot_name2 = config.get('prot_name2', None)
    mode = config.get('mode', '2')
    model = config.get('model_type', 'mcts')
    length = config.get('length', '100')
    epoch = config.get('epoch', '7')

    filename = f'{mode}_{model}_length_{length}_epoch_{epoch}'

    tokenizer = SMILES_SPE_Tokenizer(f'{config.base_path}/src/tokenizer/new_vocab.txt',
                                     f'{config.base_path}/src/tokenizer/new_splits.txt')

    mdlm = Diffusion.load_from_checkpoint(config.eval.checkpoint_path, config=config, tokenizer=tokenizer, strict=False)

    mdlm.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
    mdlm.to(device)

    print("loaded models...")
    analyzer = PeptideAnalyzer()

    # Look up protein sequences from names
    prot_seq1 = PROTEIN_SEQUENCES.get(prot_name1.lower())
    prot_seq2 = PROTEIN_SEQUENCES.get(prot_name2.lower()) if prot_name2 else None

    if prot_seq1 is None:
        raise ValueError(f"Protein '{prot_name1}' not found in PROTEIN_SEQUENCES dictionary. Available proteins: {list(PROTEIN_SEQUENCES.keys())}")

    if prot_name2 and prot_seq2 is None:
        raise ValueError(f"Protein '{prot_name2}' not found in PROTEIN_SEQUENCES dictionary. Available proteins: {list(PROTEIN_SEQUENCES.keys())}")

    print(f"Using protein 1: {prot_name1}")
    if prot_name2:
        print(f"Using protein 2: {prot_name2}")

    t_start = time.time()
    paretoFront, input_array = generate_valid_mcts(config, mdlm, prot_seq1, prot_seq2, filename, prot_name1, prot_name2)
    generation_results = []

    for sequence, v in paretoFront.items():
        generated_array = v['token_ids'].to(mdlm.device)

        # compute perplexity against the fully-masked input
        perplexity = mdlm.compute_masked_perplexity(generated_array, input_array['input_ids'])
        perplexity = round(perplexity, 4)

        aa_seq, seq_length = analyzer.analyze_structure(sequence)
        scores = v['scores']

        # Shared objectives used by every multi-objective mode.
        if not config.mcts.single:
            binding1 = scores[0]
            solubility = scores[1]
            hemo = scores[2]
            nonfouling = scores[3]

        if config.mcts.perm:
            permeability = scores[4]
            generation_results.append([sequence, perplexity, aa_seq, binding1, solubility, hemo, nonfouling, permeability])
            print(f"perplexity: {perplexity} | length: {seq_length} | smiles sequence: {sequence} | amino acid sequence: {aa_seq} | Binding Affinity: {binding1} | Solubility: {solubility} | Hemolysis: {hemo} | Nonfouling: {nonfouling} | Permeability: {permeability}")
        elif config.mcts.dual:
            binding2 = scores[4]
            generation_results.append([sequence, perplexity, aa_seq, binding1, binding2, solubility, hemo, nonfouling])
            print(f"perplexity: {perplexity} | length: {seq_length} | smiles sequence: {sequence} | amino acid sequence: {aa_seq} | Binding Affinity 1: {binding1} | Binding Affinity 2: {binding2} | Solubility: {solubility} | Hemolysis: {hemo} | Nonfouling: {nonfouling}")
        elif config.mcts.single:
            # BUG FIX: single-objective results were computed but never recorded,
            # which left the 'Permeability' DataFrame below empty in single mode.
            permeability = scores[0]
            generation_results.append([sequence, perplexity, aa_seq, permeability])
            print(f"perplexity: {perplexity} | length: {seq_length} | smiles sequence: {sequence} | amino acid sequence: {aa_seq} | Permeability: {permeability}")
        else:
            generation_results.append([sequence, perplexity, aa_seq, binding1, solubility, hemo, nonfouling])
            print(f"perplexity: {perplexity} | length: {seq_length} | smiles sequence: {sequence} | amino acid sequence: {aa_seq} | Binding Affinity: {binding1} | Solubility: {solubility} | Hemolysis: {hemo} | Nonfouling: {nonfouling}")

        sys.stdout.flush()

    # Column layout must match the append order of the branch taken above.
    if config.mcts.perm:
        df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Perplexity', 'Peptide Sequence', 'Binding Affinity', 'Solubility', 'Hemolysis', 'Nonfouling', 'Permeability'])
    elif config.mcts.dual:
        df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Perplexity', 'Peptide Sequence', 'Binding Affinity 1', 'Binding Affinity 2', 'Solubility', 'Hemolysis', 'Nonfouling'])
    elif config.mcts.single:
        df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Perplexity', 'Peptide Sequence', 'Permeability'])
    else:
        df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Perplexity', 'Peptide Sequence', 'Binding Affinity', 'Solubility', 'Hemolysis', 'Nonfouling'])

    df.to_csv(f'{config.base_path}/{prot_name1}/(unknown).csv', index=False)

    # ── timing ──
    elapsed = time.time() - t_start
    print(f"\n{'='*60}")
    print(f"Generation complete in {elapsed:.1f}s ({elapsed/60:.1f} min)")
    print(f"Pareto front size: {len(df)}")

    # ── score statistics ──
    score_cols = [c for c in df.columns if c not in ('Generated SMILES', 'Peptide Sequence')]
    print(f"\n{'Score':<22} {'Mean':>8} {'Std':>8} {'Min':>8} {'Max':>8}")
    print('-' * 58)
    for col in score_cols:
        vals = pd.to_numeric(df[col], errors='coerce').dropna()
        if len(vals) == 0:
            continue
        print(f"{col:<22} {vals.mean():8.4f} {vals.std():8.4f} {vals.min():8.4f} {vals.max():8.4f}")
    print('=' * 60)
362
+
363
+
364
+ if __name__ == "__main__":
365
+ main()
src/generate_unconditional.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import sys
5
+ import pandas as pd
6
+ import omegaconf
7
+ from utils.generate_utils import mask_for_de_novo, calculate_cosine_sim, calculate_hamming_dist
8
+ from diffusion import Diffusion
9
+ import hydra
10
+ from tqdm import tqdm
11
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
12
+ from utils.app import PeptideAnalyzer
13
+ from scoring.scoring_functions import ScoringFunctions
14
+
15
# Register custom OmegaConf resolvers required by config.yaml
# (replace=True makes re-registration safe if this module is imported twice).
omegaconf.OmegaConf.register_new_resolver('cwd', os.getcwd, replace=True)
omegaconf.OmegaConf.register_new_resolver('device_count', torch.cuda.device_count, replace=True)
# NOTE(review): the 'eval' resolver executes arbitrary config expressions —
# safe only when configs come from trusted sources.
omegaconf.OmegaConf.register_new_resolver('eval', eval, replace=True)
omegaconf.OmegaConf.register_new_resolver('div_up', lambda x, y: (x + y - 1) // y, replace=True)

# Hard-coded install location and pretrained checkpoint; edit before running.
base_path = '/path/to/your/home/PepTune'
ckpt_path = base_path + '/checkpoints/peptune-pretrained.ckpt'
23
+
24
@torch.no_grad()
def generate_sequence_unconditional(config, sequence_length: int, mdlm: Diffusion):
    """Sample one unconditional sequence of `sequence_length` tokens.

    Builds a fully [MASK]ed input, tokenizes it, moves it to the model's
    device, and runs the diffusion sampler (config.sampling.steps controls
    the number of denoising steps).

    Returns:
        (sampled_tokens, encoded_input): the sampler output and the
        tokenized masked input it started from.
    """
    fully_masked = mask_for_de_novo(config, sequence_length)
    encoded = mdlm.tokenizer.encode(fully_masked)
    # move every tokenizer output tensor onto the model's device
    encoded = {name: tensor.to(mdlm.device) for name, tensor in encoded.items()}
    sampled = mdlm._sample(x_input=encoded)
    return sampled, encoded
38
+
39
+
40
@hydra.main(version_base=None, config_path='.', config_name='config')
def main(config):
    """Generate unconditional sequences, score the valid peptides against the
    GFAP target, and write per-sequence results to a CSV file."""

    tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/src/tokenizer/new_vocab.txt',
                                     f'{base_path}/src/tokenizer/new_splits.txt')

    # Build model with current config, then load weights manually
    # (load_from_checkpoint overrides config with saved hparams)
    mdlm_model = Diffusion(config=config, tokenizer=tokenizer)
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    mdlm_model.load_state_dict(ckpt["state_dict"], strict=False)

    mdlm_model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
    mdlm_model.to(device)

    print("loaded models...")
    analyzer = PeptideAnalyzer()

    # GFAP target protein sequence used for the binding-affinity objective.
    gfap = 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM'

    # scoring functions
    score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'permeability']
    score_functions = ScoringFunctions(score_func_names, [gfap])


    max_seq_length = config.sampling.seq_length
    num_sequences = config.sampling.num_sequences
    generation_results = []
    num_valid = 0.
    num_total = 0.
    # Sample until num_sequences attempts have been made (valid or not).
    while num_total < num_sequences:
        num_total += 1
        generated_array, input_array = generate_sequence_unconditional(config, max_seq_length, mdlm_model)

        # store in device
        generated_array = generated_array.to(mdlm_model.device)
        print(generated_array)

        # compute masked perplexity
        perplexity = mdlm_model.compute_masked_perplexity(generated_array, input_array['input_ids'])
        perplexity = round(perplexity, 4)

        smiles_seq = tokenizer.decode(generated_array)
        # Only chemically valid peptide SMILES are scored and recorded.
        if analyzer.is_peptide(smiles_seq):
            aa_seq, seq_length = analyzer.analyze_structure(smiles_seq)
            num_valid += 1
            scores = score_functions(input_seqs=[smiles_seq])

            # scores[0] holds the five objective values for the single input
            binding = scores[0][0]
            sol = scores[0][1]
            hemo = scores[0][2]
            nf = scores[0][3]
            perm = scores[0][4]

            generation_results.append([smiles_seq, perplexity, aa_seq, binding, sol, hemo, nf, perm])
        else:
            aa_seq = "not valid peptide"
            seq_length = '-'
            scores = "not valid peptide"


        print(f"perplexity: {perplexity} | length: {seq_length} | smiles sequence: {smiles_seq} | amino acid sequence: {aa_seq} | scores: {scores}")
        sys.stdout.flush()

    valid_frac = num_valid / num_total
    print(f"fraction of synthesizable peptides: {valid_frac}")
    df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Perplexity', 'Peptide Sequence', 'Binding Affinity', 'Solubility', 'Hemolysis', 'Nonfouling', 'Permeability'])
    df.to_csv(base_path + f'/results/test_generate.csv', index=False)
109
+
110
+ if __name__ == "__main__":
111
+ main()
src/metrics.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from math import sqrt
3
+
4
+ def summarize_metrics(skip, csv_path: str, save_path: str | None = None) -> pd.DataFrame:
5
+ """
6
+ Compute mean and standard deviation for all columns except the first
7
+ (assumed non-numeric identifier like 'Peptide Sequence').
8
+
9
+ Returns a DataFrame with rows = column names and columns = ['mean','std','count'].
10
+ Uses sample std (ddof=1). Non-numeric cells are coerced to NaN.
11
+ """
12
+ df = pd.read_csv(csv_path)
13
+ vals = df.iloc[:, skip:].apply(pd.to_numeric, errors='coerce') # columns 2..end
14
+ stats = vals.agg(['mean', 'std', 'count']).T # shape: (num_metrics, 3)
15
+ if save_path:
16
+ stats.to_csv(save_path, index=True)
17
+ return stats
18
+
19
def summarize_list(xs, ddof = 1):
    """Return {'mean', 'std', 'count'} for the numeric entries of xs.

    Entries that are None, empty strings, or not coercible to float are
    silently skipped. std uses the given delta degrees of freedom
    (ddof=1 gives the sample standard deviation).

    Raises:
        ValueError: if no numeric values remain, or too few for the ddof.
    """
    # Clean & coerce to float, skipping anything non-numeric.
    numbers = []
    for raw in xs:
        if raw is None or raw == "":
            continue
        try:
            numbers.append(float(raw))
        except (TypeError, ValueError):
            continue

    n = len(numbers)
    if n == 0:
        raise ValueError("No numeric values found.")
    if n <= ddof:
        raise ValueError(f"Need at least {ddof + 1} numeric values; got {n}.")

    # Two-pass mean/variance.
    mean = sum(numbers) / n
    var = sum((v - mean) ** 2 for v in numbers) / (n - ddof)

    return {"mean": mean, "std": sqrt(var), "count": n}
52
+
53
def csv_column_to_list(path: str, column: str, *, dropna: bool = True):
    """Read a single CSV column into a plain Python list.

    Parameters:
        path: CSV file to read.
        column: Name of the column to extract.
        dropna: When True (default), NaN entries are removed.

    Raises:
        KeyError: if the column is not present in the file.
    """
    frame = pd.read_csv(path)
    try:
        series = frame[column]
    except KeyError:
        raise KeyError(f"Column '{column}' not found. Available: {list(frame.columns)}") from None
    return series.dropna().tolist() if dropna else series.tolist()
61
+
62
def main():
    """Summarize one generation-results CSV and print/save the statistics."""
    # NOTE(review): hard-coded, user-specific paths; `csv_path` below is unused.
    csv_path = "/scratch/pranamlab/sophtang/home/tr2d2/peptides/plots/glast_resample20_no-mcts/"
    path = "/scratch/pranamlab/sophtang/home/TR2-D2/tr2d2-pep/results/tfr_resample10_buffer20_numiter10_children50_20260326_183626"
    prot_name = "tfr"
    # skip=1 drops the first (sequence) column from the summary.
    stats = summarize_metrics(skip=1, csv_path=f"{path}/{prot_name}_generation_results.csv",
                              save_path=f"{path}/results_summary.csv")

    print(stats)
71
+ if __name__ == '__main__':
72
+ main()
src/noise_schedule.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Adapted from MDLM: https://github.com/kuleshov-group/mdlm
2
+
3
+ import abc
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ torch._C._jit_set_profiling_mode(False)
9
+ torch._C._jit_set_profiling_executor(False)
10
+ torch._C._jit_override_can_fuse_on_cpu(True)
11
+ torch._C._jit_override_can_fuse_on_gpu(True)
12
+
13
def get_noise(config, dtype=torch.float32):
    """Instantiate the noise schedule named by config.noise.type.

    `dtype` is only consulted by the 'linear' schedule; `sigma_min`/`sigma_max`
    are only read for 'geometric' and 'linear'.

    Raises:
        ValueError: for an unrecognized noise type.
    """
    noise_type = config.noise.type
    if noise_type == 'geometric':
        return GeometricNoise(config.noise.sigma_min, config.noise.sigma_max)
    if noise_type == 'loglinear':
        return LogLinearNoise()
    if noise_type == 'cosine':
        return CosineNoise()
    if noise_type == 'cosinesqr':
        return CosineSqrNoise()
    if noise_type == 'linear':
        return Linear(config.noise.sigma_min, config.noise.sigma_max, dtype)
    raise ValueError(f'{config.noise.type} is not a valid noise')
26
+
27
+
28
def binary_discretization(z):
    """Straight-through sign: forward value is sign(z), while the gradient
    flows through the last-dim-normalized z (the hard - soft term is detached)."""
    hard = torch.sign(z)
    soft = z / torch.norm(z, dim=-1, keepdim=True)
    return soft + (hard - soft).detach()
32
+
33
+
34
class Noise(abc.ABC, nn.Module):
    """
    Base class for noise schedules: forward returns the total noise and the
    rate of noise at a timestep. Subclasses must provide total_noise(t) and
    rate_noise(t).
    """
    def forward(self, t):
        # Assume time goes from 0 to 1.
        return self.total_noise(t), self.rate_noise(t)
41
+
42
+
43
class CosineNoise(Noise):
    """Cosine noise schedule: total_noise(t) = -log(eps + (1 - eps) cos(pi t / 2))."""

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def rate_noise(self, t):
        # Derivative of total_noise with respect to t.
        scaled_cos = (1 - self.eps) * torch.cos(t * torch.pi / 2)
        scaled_sin = (1 - self.eps) * torch.sin(t * torch.pi / 2)
        return (torch.pi / 2) * scaled_sin / (scaled_cos + self.eps)

    def total_noise(self, t):
        return - torch.log(self.eps + (1 - self.eps) * torch.cos(t * torch.pi / 2))
57
+
58
+
59
class CosineSqrNoise(Noise):
    """Squared-cosine schedule: total_noise(t) = -log(eps + (1 - eps) cos^2(pi t / 2))."""

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def rate_noise(self, t):
        # Derivative of total_noise; sin(pi t) = 2 sin(pi t/2) cos(pi t/2).
        sq_cos = (1 - self.eps) * (
            torch.cos(t * torch.pi / 2) ** 2)
        full_sin = (1 - self.eps) * torch.sin(t * torch.pi)
        return (torch.pi / 2) * full_sin / (sq_cos + self.eps)

    def total_noise(self, t):
        return - torch.log(self.eps + (1 - self.eps) * (torch.cos(t * torch.pi / 2) ** 2))
74
+
75
+
76
class Linear(Noise):
    """Linear noise schedule: sigma(t) = sigma_min + t * (sigma_max - sigma_min)."""

    def __init__(self, sigma_min=0, sigma_max=10, dtype=torch.float32):
        super().__init__()
        self.sigma_min = torch.tensor(sigma_min, dtype=dtype)
        self.sigma_max = torch.tensor(sigma_max, dtype=dtype)

    def rate_noise(self, t):
        # BUG FIX: the original signature `rate_noise(self)` broke
        # Noise.forward(t), which calls self.rate_noise(t).
        # d sigma / dt is the constant (sigma_max - sigma_min).
        return self.sigma_max - self.sigma_min

    def total_noise(self, t):
        return self.sigma_min + t * (self.sigma_max - self.sigma_min)

    def importance_sampling_transformation(self, t):
        # Interpolate between log(1 - e^{-sigma}) endpoints, then invert sigma(t)
        # back to a time in [0, 1].
        f_T = torch.log1p(- torch.exp(- self.sigma_max))
        f_0 = torch.log1p(- torch.exp(- self.sigma_min))
        sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
        return (sigma_t - self.sigma_min) / (
            self.sigma_max - self.sigma_min)
94
+
95
+
96
class GeometricNoise(Noise):
    """Geometric interpolation between sigma_min and sigma_max:
    total_noise(t) = sigma_min^(1-t) * sigma_max^t."""

    def __init__(self, sigma_min=1e-3, sigma_max=1):
        super().__init__()
        # store both endpoints as a float tensor: [sigma_min, sigma_max]
        self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])

    def rate_noise(self, t):
        # d/dt of total_noise(t) = total_noise(t) * log(sigma_max / sigma_min)
        return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (
            self.sigmas[1].log() - self.sigmas[0].log())

    def total_noise(self, t):
        return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t
107
+
108
+
109
class LogLinearNoise(Noise):
    """Log-linear noise schedule.

    total_noise(t) = -log(1 - (1 - eps) * t), so 1 - exp(-total_noise(t))
    interpolates from 0 to ~1 as t goes from 0 to 1.
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps
        self.sigma_max = self.total_noise(torch.tensor(1.0))
        self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))

    def rate_noise(self, t):
        # d/dt of total_noise(t)
        return (1 - self.eps) / (1 - (1 - self.eps) * t)

    def total_noise(self, t):
        return -torch.log1p(-(1 - self.eps) * t)

    def importance_sampling_transformation(self, t):
        # Interpolate in log(1 - exp(-sigma)) space, then map sigma back to time.
        log_gamma_max = torch.log1p(- torch.exp(- self.sigma_max))
        log_gamma_min = torch.log1p(- torch.exp(- self.sigma_min))
        sigma_t = - torch.log1p(- torch.exp(t * log_gamma_max + (1 - t) * log_gamma_min))
        return - torch.expm1(- sigma_t) / (1 - self.eps)
135
+
136
class LogPolyNoise(Noise):
    """
    Log Polynomial noise schedule for slower masking of peptide bond tokens:
    total_noise(t) = -log(1 - (1 - eps) * t^3).
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps
        self.sigma_max = self.total_noise(torch.tensor(1.0))
        self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))

    def rate_noise(self, t):
        # Exact derivative of total_noise:
        #   d/dt[-log1p(-(1 - eps) t^3)] = 3 (1 - eps) t^2 / (1 - (1 - eps) t^3)
        # BUG FIX: the original numerator (3*t**2 - eps) did not match the
        # derivative of total_noise below.
        return (3 * (1 - self.eps) * (t**2)) / (1 - (1 - self.eps) * (t**3))

    def total_noise(self, t):
        # -log(1 - (1 - eps) * t^3)
        return -torch.log1p(-(1 - self.eps) * (t**3))
src/pareto_mcts.py ADDED
@@ -0,0 +1,492 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import random as rd
6
+
7
+ from diffusion import Diffusion
8
+ from scoring.scoring_functions import ScoringFunctions
9
+ from utils.app import PeptideAnalyzer
10
+ import noise_schedule
11
+
12
+ """"
13
+ Notes: store rolled out sequence?
14
+ path of node objects or strings?
15
+ should we only select valid expandable leaf nodes?
16
+ calculate similarity between sibling nodes?
17
+ should we evaluate generated sequences?
18
+ """
19
class Node:
    """
    Node class: partially unmasked SMILES string
    - parentNode: Node object at previous time step
    - childNodes: list of M Node objects generated from sampling M distinct unmasking schemes
    - totalReward: vector of cumulative rewards for all K objectives
    - visits: number of times the node has been visited by an iteration
    - timestep: the time step where the sequence was sampled
    - sampleProb: probability of sampling the sequence from the diffusion model
    """
    def __init__(self, config, tokens=None, parentNode=None, childNodes=None, scoreVector=None, totalReward=None, timestep=None, sampleProb=None):
        self.config = config
        self.parentNode = parentNode
        # BUG FIX: the default used to be a shared mutable list ([]), so every
        # node created without an explicit childNodes argument shared ONE list.
        self.childNodes = [] if childNodes is None else childNodes
        self.scoreVector = scoreVector

        # initialize total rewards to the reward of the rolled-out unmasked sequence
        if totalReward is not None:
            self.totalReward = totalReward
        else:
            self.totalReward = np.zeros(self.config.mcts.num_objectives)

        # every node starts with one visit (itself)
        self.visits = 1
        # timestep (value between 0 and config.sampling.steps)
        self.timestep = timestep
        # sampling probability under the reverse posterior (exploration prior)
        self.sampleProb = sampleProb

        # dict with 'input_ids' token array and 'attention_mask'
        self.tokens = tokens

    def getRoot(self):
        """Walk up the parent links and return the tree root."""
        node = self
        while node.parentNode is not None:
            node = node.parentNode
        return node

    def selectNode(self, num_func):
        """
        Selects a node to move to among the children nodes.

        Returns (selectedNode, status); when this node is not a legal
        non-leaf, returns itself with its own status.
        """
        nodeStatus = self.getExpandStatus()

        # only a legal non-leaf node (status 3) can be descended into
        if nodeStatus == 3:
            # Pareto front over the children's select-score vectors
            paretoFront = {}
            for childNode in self.childNodes:
                childStatus = childNode.getExpandStatus()
                # only consider children that are expandable leaves (2) or legal non-leaves (3)
                if childStatus == 2 or childStatus == 3:
                    selectScore = childNode.calcSelectScore()
                    paretoFront = updateParetoFront(paretoFront, childNode, selectScore, num_func)

            # if no selectable children (all terminal), return self as a leaf
            if len(paretoFront) == 0:
                return self, 1

            # randomly select a node on the Pareto front
            selected = rd.choice(list(paretoFront.keys()))
            return selected, selected.getExpandStatus()

        # not a valid non-leaf node: caller handles self directly
        return self, nodeStatus

    def addChildNode(self, tokens, totalReward, prob=None):
        """
        Creates a child node one timestep deeper, registers it, and returns it.
        """
        child = Node(config=self.config,
                     tokens=tokens,
                     parentNode=self,
                     totalReward=totalReward,
                     timestep=self.timestep + 1,
                     sampleProb=prob)

        self.childNodes.append(child)
        return child

    def updateNode(self, rewards):
        """
        Updates the cumulative rewards vector with the reward vector at a
        descendant leaf node and increments the visit count.
        """
        self.visits += 1
        self.totalReward += rewards

    def calcSelectScore(self):
        """
        Select score = mean reward per objective plus a PUCT-style
        exploration bonus scaled by the model's sampling probability.
        """
        # K-dimensional vector of normalized rewards for each objective
        normRewards = self.totalReward / self.visits
        if self.sampleProb is not None:
            print("Sample Prob")
            print(self.sampleProb)
            # BUG FIX: this used `self.root.visits`, but Node never defines a
            # `root` attribute (AttributeError); walk up to the actual root.
            rootVisits = self.getRoot().visits
            return normRewards + (self.config.mcts.sample_prob * self.sampleProb * np.sqrt(rootVisits) / self.visits)
        return normRewards

    def getExpandStatus(self):
        """
        Returns an integer indicating whether the node is a:
        1. terminal node (sequence is fully unmasked)
        2. legal leaf node (partially unmasked sequence that can be expanded)
        3. legal non-leaf node (already expanded sequence with M child nodes)
        """
        if self.timestep == self.config.sampling.steps:
            return 1
        elif (self.timestep < self.config.sampling.steps) and (len(self.childNodes) == 0):
            return 2
        return 3
140
+
141
+ """END OF NODE CLASS"""
142
+
143
def updateParetoFront(paretoFront, node, scoreVector, num_func):
    """
    Update a Pareto front (dict: node -> score vector) with a candidate.

    Only the first num_func objectives are compared. The candidate is added
    when it is at least as good as every current member on the active
    objectives; members it strictly dominates are removed.
    """
    if not paretoFront:
        # empty front: the candidate is trivially non-dominated
        paretoFront[node] = scoreVector
        return paretoFront

    dominated_members = []   # members strictly dominated by the candidate
    candidate_survives = []  # True where the candidate is not dominated by a member
    for member, member_scores in paretoFront.items():
        geq = scoreVector >= np.asarray(member_scores)
        gt = scoreVector > np.asarray(member_scores)

        # not enough objectives to compare against num_func: skip this member
        if num_func > len(geq):
            continue

        geq_active = geq[:num_func]
        gt_active = gt[:num_func]

        if geq_active.all():
            # candidate is at least as good on every active objective
            candidate_survives.append(True)
            if gt_active.any():
                # ...and strictly better on at least one: member is dominated
                dominated_members.append(member)
        else:
            # candidate is dominated by this member
            candidate_survives.append(False)

    # add the candidate only when no member dominates it
    if np.asarray(candidate_survives).all():
        paretoFront[node] = scoreVector

    # drop every member the candidate strictly dominates
    for member in dominated_members:
        del paretoFront[member]

    return paretoFront
192
+
193
+ ###BEGINNING OF MCTS CLASS###
194
+
195
+ class MCTS:
196
+ def __init__(self, config, max_sequence_length=None, mdlm=None, score_func_names=[], prot_seqs=None, num_func = []):
197
+ self.config = config
198
+ self.noise = noise_schedule.get_noise(config)
199
+ self.time_conditioning = self.config.time_conditioning
200
+ # dictionary of k (SMILES string) and v (score vector) of Pareto-optimal sequences
201
+ self.peptideParetoFront = {}
202
+ self.num_steps = config.sampling.steps
203
+ self.num_sequences = config.sampling.num_sequences
204
+
205
+ # mdlm model
206
+ self.mdlm = mdlm
207
+ self.tokenizer = mdlm.tokenizer
208
+ self.device = mdlm.device
209
+
210
+ if max_sequence_length is None:
211
+ self.sequence_length = self.config.sampling.seq_length
212
+ else:
213
+ self.sequence_length = max_sequence_length
214
+
215
+ self.num_iter = config.mcts.num_iter
216
+
217
+ self.num_child = config.mcts.num_children
218
+
219
+ # score functions
220
+ self.score_functions = ScoringFunctions(score_func_names, prot_seqs)
221
+ self.score_func_names = score_func_names
222
+ self.num_func = num_func # K-dimensional vector with the iteration number to start conditioning on each of the objectives in increasng order
223
+ self.iter_num = 0
224
+ self.curr_num_func = 1
225
+ self.analyzer = PeptideAnalyzer()
226
+
227
+ # track fraction of valid peptides
228
+ self.valid_fraction_log = []
229
+ self.score_logs = {name: [] for name in score_func_names}
230
+
231
+ def reset(self):
232
+ self.iter_num = 0
233
+ self.valid_fraction_log = []
234
+ self.score_logs = {name: [] for name in self.score_func_names}
235
+ self.peptideParetoFront = {}
236
+
237
+ def forward(self, rootNode):
238
+ self.reset()
239
+
240
+ while (self.iter_num < self.num_iter):
241
+ self.iter_num += 1
242
+
243
+ # traverse the tree form the root node until a leaf node
244
+ leafNode, _ = self.select(rootNode)
245
+ #print(leafNode.tokens['input_ids'])
246
+
247
+ # expand leaf node into num_children partially unmasked sequences at the next timestep
248
+ self.expand(leafNode)
249
+
250
+ # return dictionary of pareto front peptides and their score vectors
251
+ return self.peptideParetoFront
252
+
253
    # change to include more even if dominated? since there is error in the scores
    def updateParetoFront(self, sequence, scoreVector, tokens):
        """
        Update the global Pareto front with a fully unmasked candidate.

        Removes sequences that are dominated by scoreVector and adds the
        SMILES sequence (with its score vector and token ids) if it is
        non-dominated on the currently active objectives.

        - sequence: decoded SMILES string of the candidate
        - scoreVector: per-objective scores (numpy-comparable)
        - tokens: unmasked token ids stored alongside the scores

        Returns a reward vector: per-objective fraction of current front
        members the candidate is at least as good as (all ones when the
        front was empty).
        """
        paretoSize = len(self.peptideParetoFront)

        # number of objectives currently conditioned on; advances each time
        # iter_num passes the next threshold in self.num_func
        self.curr_num_func = 1

        for i in range(len(self.num_func)):
            if self.iter_num >= self.num_func[i]:
                self.curr_num_func = i+1

        if paretoSize == 0:
            # if pareto front is empty, add sequence and scoreVector
            self.peptideParetoFront[sequence] = {'scores': scoreVector, 'token_ids': tokens}
            # if pareto front is empty, set reward vector to 1s
            rewardVector = np.ones(len(scoreVector))
        else:
            # nondominate[i]: True when the candidate is not dominated by the
            # i-th front member (restricted to the active objectives)
            nondominate = []
            # front members strictly dominated by the candidate
            delete = []
            # initialize reward vector with zeros
            rewardVector = np.zeros(len(scoreVector))
            for k, v in self.peptideParetoFront.items():
                # elementwise comparisons over ALL objectives
                nondominated = scoreVector >= np.asarray(v['scores'])  # [num_objectives]
                dominant = scoreVector > np.asarray(v['scores'])
                # reward counts how many members the candidate matches or beats
                rewardVector += nondominated  # [num_objectives]

                if self.curr_num_func <= len(nondominated):
                    # restrict the domination test to the active objectives
                    attn_nondominated = nondominated[:self.curr_num_func]
                    attn_dominant = dominant[:self.curr_num_func]

                    # only delete pareto-optimal sequence if
                    # all scores are greater than or equal to v and at least one score is strictly greater than v
                    if attn_nondominated.all() and attn_dominant.any():
                        # add the dominated sequence to be deleted
                        delete.append(k)
                        # sequence is dominant
                        nondominate.append(True)
                    elif attn_nondominated.all():
                        # sequence is non-dominated
                        nondominate.append(True)
                    else:
                        # sequence is completely dominated
                        nondominate.append(False)

            # NOTE(review): assumes curr_num_func <= len(scoreVector) for every
            # member — otherwise nothing is appended above and this fires.
            assert len(nondominate) == paretoSize
            nondominate = np.asarray(nondominate)
            # if sequence is either dominant or non-dominated by all sequences in pareto-front -> add to pareto front
            # or if the pareto front does not have enough sequences
            if nondominate.all() or paretoSize < self.num_sequences:
                self.peptideParetoFront[sequence] = {'scores': scoreVector, 'token_ids': tokens}

            # normalize the reward by the (pre-insertion) front size
            rewardVector = rewardVector / paretoSize

            # delete all dominated sequences if pareto front is larger than num_sequences
            while (paretoSize > self.num_sequences) and (len(delete) > 0):
                del self.peptideParetoFront[delete[0]]
                del delete[0]
                paretoSize -= 1

        return rewardVector
327
+
328
+ def isPathEnd(self, path, maxDepth):
329
+ """
330
+ Checks if the node is completely unmasked (ie. end of path)
331
+ or if the path is at the max depth
332
+ """
333
+ if (path[-1] != self.config.mcts.mask_token).all():
334
+ return True
335
+ elif len(path) >= maxDepth:
336
+ return True
337
+ return False
338
+
339
+ def select(self, currNode):
340
+ """
341
+ Traverse the tree from the root node until reaching a legal leaf node
342
+ """
343
+ while True:
344
+ currNode, nodeStatus = currNode.selectNode(self.curr_num_func)
345
+ if nodeStatus != 3:
346
+ return currNode, nodeStatus
347
+
348
    def expand(self, parentNode, eps=1e-5, checkSimilarity = True):
        """
        Expand a leaf: sample num_children one-step unmaskings from the MDLM,
        roll each out to a fully unmasked sequence, score the valid peptides,
        update the Pareto front, attach the children, and backpropagate the
        summed (invalid-penalized) rewards.

        - parentNode: leaf whose tokens are the starting partial sequence
        - eps: smallest rollout time (avoids sampling exactly at t=0)
        - checkSimilarity: currently unused
        """

        num_children = self.config.mcts.num_children
        # initialize child rewards that will be added to total rewards
        allChildReward = np.zeros_like(parentNode.totalReward)  # (n_objectives)

        # compute number of rollout steps
        # if parentNode.timestep = self.num_steps then num_rollout_steps = 1
        num_rollout_steps = self.num_steps - parentNode.timestep
        # array of rollout timesteps from the timestep of parent node to 0
        rollout_t = torch.linspace(1, eps, num_rollout_steps, device=self.device)
        dt = (1 - eps) / self.num_steps
        p_x0_cache = None

        # initialize x and attn_mask from the parent's partial sequence
        x = parentNode.tokens['input_ids'].to(self.device)
        attn_mask = parentNode.tokens['attention_mask'].to(self.device)

        t = rollout_t[0] * torch.ones(num_children, 1, device=self.device)
        # generate (n_children, seq_length) array of sampled children nodes
        print("token array:")
        print(x)
        p_x0_cache, x_children = self.mdlm.batch_cached_reverse_step(token_array=x,
                                                                    t=t, dt=dt,
                                                                    batch_size=num_children,
                                                                    attn_mask=attn_mask)
        # x_children are the one-step children; x_rollout continues to full unmasking
        x_rollout = x_children

        for i in range(1, num_rollout_steps):
            t = rollout_t[i] * torch.ones(num_children, 1, device=self.device)

            p_x0_cache, x_next = self.mdlm.cached_reverse_step(x=x_rollout,
                                                               t=t, dt=dt, p_x0=p_x0_cache,
                                                               attn_mask=attn_mask)

            # NOTE(review): compares x_next against the parent tokens x (not
            # the previous rollout state) — confirm this is the intended
            # cache-invalidation condition
            if (not torch.allclose(x_next, x) or self.time_conditioning):
                # Disable caching
                p_x0_cache = None

            x_rollout = x_next

        if self.config.sampling.noise_removal:
            # final denoising pass: greedily unmask any remaining positions
            t = rollout_t[-1] * torch.ones(x.shape[0], 1, device=self.device)

            time_cond = self.noise(t)[0]
            x_rollout = self.mdlm.forward(x_rollout, attn_mask, time_cond).argmax(dim=-1)  # (n_children, seq_length)

        childSequences = self.tokenizer.batch_decode(x_rollout)

        validSequences = []
        maskedTokens = []
        unmaskedTokens = []
        for i in range(num_children):
            childSeq = childSequences[i]
            rewardVector = np.zeros(self.config.mcts.num_objectives)

            # check if the peptide is valid
            if self.analyzer.is_peptide(childSeq):
                validSequences.append(childSeq)
                maskedTokens.append(x_children[i])
                unmaskedTokens.append(x_rollout[i])
            else:
                # invalid rollout: still attach the child, but with zero reward
                childTokens = {'input_ids': x_children[i], 'attention_mask': attn_mask}
                parentNode.addChildNode(tokens=childTokens,
                                        totalReward=rewardVector)

        if (len(validSequences) != 0):
            # score all valid peptides in one batch; scoreVectors is indexed
            # [sequence, objective]
            scoreVectors = self.score_functions(input_seqs=validSequences)
            average_scores = scoreVectors.T
            for i, name in enumerate(self.score_func_names):
                self.score_logs[name].append(average_scores[i])
        else:
            # keep the log aligned across iterations even with no valid peptides
            for name in self.score_func_names:
                self.score_logs[name].append(np.zeros(0))

        for i, validSeq in enumerate(validSequences):
            scoreVector = scoreVectors[i]

            # update pareto front
            rewardVector = self.updateParetoFront(validSeq, scoreVector, unmaskedTokens[i])
            print(scoreVector)
            print(rewardVector)

            # add to all child reward vector for backprop
            allChildReward += rewardVector

            # create node for sequence and add to the children node of parent
            childTokens = {'input_ids': maskedTokens[i], 'attention_mask': attn_mask}
            parentNode.addChildNode(tokens=childTokens,
                                    totalReward=rewardVector)

        # compute fraction of invalid child sequences
        invalid = (num_children - len(validSequences)) / num_children

        valid_fraction = len(validSequences) / num_children
        print(f"Valid fraction: {valid_fraction}")
        self.valid_fraction_log.append(valid_fraction)

        print(self.config.mcts.invalid_penalty)
        # subtract score using fraction of invalid sequences from reward
        allChildReward = allChildReward - (self.config.mcts.invalid_penalty * invalid)
        # backpropogate all child rewards
        self.backprop(parentNode, allChildReward)
458
+
459
+
460
+ def backprop(self, node, reward_vector):
461
+ # backpropogate rewards through the path leading to the leaf node from the root
462
+ while node:
463
+ node.updateNode(reward_vector)
464
+ node = node.parentNode
465
+
466
+
467
+ def getSequenceForObjective(self, objective_index, k):
468
+ """
469
+ Returns the top-k sequences in the pareto front that has the best score for
470
+ a given objective and their score vectors for all objectives
471
+ """
472
+
473
+ # dictionary of top-k peptides for the objective
474
+ topk = {}
475
+
476
+ peptides = []
477
+ objectiveScores = []
478
+ for k, v in self.peptideParetoFront.items():
479
+ # store peptides in list
480
+ peptides.append(k)
481
+ # store score for objective
482
+ objectiveScores.append(v['token_ids'][objective_index])
483
+
484
+ objectiveScores = torch.tensor(objectiveScores)
485
+ topKScores = torch.topk(objectiveScores, k)
486
+ for (_, index) in topKScores.items():
487
+ seq = peptides[index]
488
+
489
+ topk[seq] = self.peptideParetoFront.get(seq)
490
+
491
+ return topk
492
+
src/roformer.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import RoFormerConfig, RoFormerForMaskedLM
2
+ import torch.nn as nn
3
+ from torch.nn.parallel import DistributedDataParallel as DDP
4
+ import torch
5
+
6
class Roformer(nn.Module):
    """RoFormer masked-LM wrapper used as the MDLM backbone.

    Builds a RoFormerForMaskedLM sized from the experiment config and the
    tokenizer vocabulary, and exposes freeze/unfreeze helpers for staged
    fine-tuning.
    """
    def __init__(self, config, tokenizer):
        super().__init__()

        self.tokenizer = tokenizer
        self.vocab_size = self.tokenizer.vocab_size
        # place the model on GPU when available
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        roformer_config = RoFormerConfig(
            vocab_size=self.tokenizer.vocab_size,
            embedding_size=config.roformer.hidden_size,
            hidden_size=config.roformer.hidden_size,
            num_hidden_layers=config.roformer.n_layers,
            num_attention_heads=config.roformer.n_heads,
            intermediate_size=config.roformer.hidden_size * 4,
            max_position_embeddings=config.roformer.max_position_embeddings,
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            pad_token_id=0,
            rotary_value=False
        )

        self.model = RoFormerForMaskedLM(roformer_config).to(self.device)

    def freeze_model(self):
        """Freeze every parameter of the underlying RoFormer."""
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze_all_layers(self):
        """Make every parameter trainable again."""
        for param in self.model.parameters():
            param.requires_grad = True

    def unfreeze_n_layers(self, n):
        """Unfreeze the query/key projection weights of the final n encoder layers.

        BUG FIX: the layer count was hard-coded to 8; it is now derived from
        the actual encoder, so this works for any config.roformer.n_layers.
        """
        layers = self.model.roformer.encoder.layer
        num_layers = len(layers)

        for i, layer in enumerate(layers):
            # fine-tune only the final n layers
            if i >= num_layers - n:
                # unfreeze query weights
                for param in layer.attention.self.query.parameters():
                    param.requires_grad = True
                # unfreeze key weights
                for param in layer.attention.self.key.parameters():
                    param.requires_grad = True

    def forward(self, input_ids, attn_mask):
        """Return masked-LM logits of shape (batch, seq_len, vocab_size)."""
        input_ids = input_ids.to(self.device)
        attn_mask = attn_mask.to(self.device)

        # get logits from the masked-LM head
        logits = self.model(input_ids=input_ids, attention_mask=attn_mask)
        return logits.logits

    def save_model(self, save_dir):
        """Persist the model weights and tokenizer to save_dir."""
        self.model.save_pretrained(save_dir)
        self.tokenizer.save_pretrained(save_dir)

    @classmethod
    def load_model(cls, save_dir, config, tokenizer):
        """Rebuild a wrapper and load pretrained weights from save_dir."""
        roformer = cls(config, tokenizer)
        roformer.model = RoFormerForMaskedLM.from_pretrained(save_dir)
        return roformer
src/scoring/functions/binding.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os, torch
3
+ import numpy as np
4
+ import torch
5
+ import pandas as pd
6
+ import torch.nn as nn
7
+ import esm
8
+ from transformers import AutoModelForMaskedLM
9
+
10
class ImprovedBindingPredictor(nn.Module):
    """Cross-attention model predicting peptide-protein binding affinity.

    Consumes an ESM-2 protein embedding and a PeptideCLM SMILES embedding,
    projects both into a shared hidden space, runs bidirectional
    cross-attention, and produces both a regression score and 3-way
    binding-class logits (tight / medium / weak).
    """
    def __init__(self,
                 esm_dim=1280,
                 smiles_dim=768,
                 hidden_dim=512,
                 n_heads=8,
                 n_layers=3,
                 dropout=0.1):
        super().__init__()

        # Define binding thresholds (on the -log10 affinity scale)
        self.tight_threshold = 7.5  # Kd/Ki/IC50 ≤ ~30nM
        self.weak_threshold = 6.0  # Kd/Ki/IC50 > 1μM

        # Project both modalities to the same hidden dimension
        self.smiles_projection = nn.Linear(smiles_dim, hidden_dim)
        self.protein_projection = nn.Linear(esm_dim, hidden_dim)
        self.protein_norm = nn.LayerNorm(hidden_dim)
        self.smiles_norm = nn.LayerNorm(hidden_dim)

        # Cross attention blocks with layer norm; the same ModuleDict is used
        # for both attention directions within a layer
        self.cross_attention_layers = nn.ModuleList([
            nn.ModuleDict({
                'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
                'norm1': nn.LayerNorm(hidden_dim),
                'ffn': nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim * 4),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim * 4, hidden_dim)
                ),
                'norm2': nn.LayerNorm(hidden_dim)
            }) for _ in range(n_layers)
        ])

        # Prediction heads
        self.shared_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Regression head (scalar affinity)
        self.regression_head = nn.Linear(hidden_dim, 1)

        # Classification head (3 classes: tight, medium, loose binding)
        self.classification_head = nn.Linear(hidden_dim, 3)

    def get_binding_class(self, affinity):
        """Convert affinity values to class indices
        0: tight binding (>= 7.5)
        1: medium binding (6.0-7.5)
        2: weak binding (< 6.0)
        """
        if isinstance(affinity, torch.Tensor):
            # vectorized path: classify an entire tensor at once
            tight_mask = affinity >= self.tight_threshold
            weak_mask = affinity < self.weak_threshold
            medium_mask = ~(tight_mask | weak_mask)

            # default 0 (tight); overwrite medium and weak entries
            classes = torch.zeros_like(affinity, dtype=torch.long)
            classes[medium_mask] = 1
            classes[weak_mask] = 2
            return classes
        else:
            if affinity >= self.tight_threshold:
                return 0  # tight binding
            elif affinity < self.weak_threshold:
                return 2  # weak binding
            else:
                return 1  # medium binding

    def forward(self, protein_emb, smiles_emb):
        """Return (regression_output, classification_logits).

        NOTE(review): inputs appear to be unbatched (seq_len, emb_dim)
        tensors — the commented-out transposes and the dim=0 pooling below
        suggest there is no leading batch dimension; confirm against callers.
        """
        protein = self.protein_norm(self.protein_projection(protein_emb))
        smiles = self.smiles_norm(self.smiles_projection(smiles_emb))

        #protein = protein.transpose(0, 1)
        #smiles = smiles.transpose(0, 1)

        # Cross attention layers
        for layer in self.cross_attention_layers:
            # Protein attending to SMILES
            attended_protein = layer['attention'](
                protein, smiles, smiles
            )[0]
            protein = layer['norm1'](protein + attended_protein)
            protein = layer['norm2'](protein + layer['ffn'](protein))

            # SMILES attending to protein (already-updated protein states)
            attended_smiles = layer['attention'](
                smiles, protein, protein
            )[0]
            smiles = layer['norm1'](smiles + attended_smiles)
            smiles = layer['norm2'](smiles + layer['ffn'](smiles))

        # Get sequence-level representations (mean-pool over dim 0)
        protein_pool = torch.mean(protein, dim=0)
        smiles_pool = torch.mean(smiles, dim=0)

        # Concatenate both representations
        combined = torch.cat([protein_pool, smiles_pool], dim=-1)

        # Shared features feed both heads
        shared_features = self.shared_head(combined)

        regression_output = self.regression_head(shared_features)
        classification_logits = self.classification_head(shared_features)

        return regression_output, classification_logits
118
+
119
class BindingAffinity:
    """Scores peptide SMILES strings for binding affinity to one fixed protein.

    The target protein is embedded once with ESM-2 at construction time and
    mean-pooled; each candidate peptide is embedded with PeptideCLM (or a
    supplied embedding model) and fed through a pre-trained
    ImprovedBindingPredictor checkpoint.
    """
    def __init__(self, prot_seq, tokenizer, base_path, device=None, emb_model=None):
        # plain object base, so this super().__init__() is a no-op
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device

        # peptide embedding model (frozen, eval mode)
        if emb_model is not None:
            self.pep_model = emb_model.to(self.device).eval()
        else:
            # default: pretrained PeptideCLM encoder (roformer backbone only)
            self.pep_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()

        self.pep_tokenizer = tokenizer

        # binding-affinity head, restored from a local checkpoint
        self.model = ImprovedBindingPredictor().to(self.device)
        checkpoint = torch.load(f'{base_path}/src/scoring/functions/classifiers/binding-affinity.pt',
                                map_location=self.device,
                                weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])

        self.model.eval()

        self.esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()  # load ESM-2 model
        self.esm_model = self.esm_model.to(self.device).eval()
        self.prot_tokenizer = alphabet.get_batch_converter()  # load esm tokenizer

        data = [("target", prot_seq)]
        # get tokenized protein
        _, _, prot_tokens = self.prot_tokenizer(data)
        prot_tokens = prot_tokens.to(self.device)
        with torch.no_grad():
            # layer-33 representations are the final ESM-2 hidden states
            results = self.esm_model.forward(prot_tokens, repr_layers=[33])  # Example with ESM-2
            prot_emb = results["representations"][33]

        # mean-pool the per-residue embedding into a single (1, esm_dim) vector
        self.prot_emb = prot_emb[0].to(self.device)
        self.prot_emb = torch.mean(self.prot_emb, dim=0, keepdim=True)

    def forward(self, input_seqs):
        """Return one predicted affinity score (float) per input SMILES string."""
        with torch.no_grad():
            scores = []
            for seq in input_seqs:
                pep_tokens = self.pep_tokenizer(seq, return_tensors='pt', padding=True)

                pep_tokens = {k: v.to(self.device) for k, v in pep_tokens.items()}

                # inner no_grad is redundant (already inside one) but harmless
                with torch.no_grad():
                    emb = self.pep_model(input_ids=pep_tokens['input_ids'],
                                         attention_mask=pep_tokens['attention_mask'],
                                         output_hidden_states=True)

                #emb = self.pep_model(input_ids=pep_tokens['input_ids'], attention_mask=pep_tokens['attention_mask'])
                # mean-pool the peptide token embeddings into (1, smiles_dim)
                pep_emb = emb.last_hidden_state.squeeze(0)
                pep_emb = torch.mean(pep_emb, dim=0, keepdim=True)

                # only the regression output is used; classification logits dropped
                score, logits = self.model.forward(self.prot_emb, pep_emb)
                scores.append(score.item())
            return scores

    def __call__(self, input_seqs: list):
        return self.forward(input_seqs)
src/scoring/functions/binding_utils.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch
3
+ import numpy as np
4
+
5
def to_var(x):
    """Move x to the GPU when CUDA is available; otherwise return it unchanged."""
    return x.cuda() if torch.cuda.is_available() else x
9
+
10
class MultiHeadAttentionSequence(nn.Module):
    """Multi-head scaled dot-product attention with residual + layer norm.

    Projects q/k/v into n_head subspaces, attends, concatenates the heads,
    and applies dropout, a residual connection to q, and layer norm.
    """

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v

        # per-head query/key/value projections and the output projection
        self.W_Q = nn.Linear(d_model, n_head*d_k)
        self.W_K = nn.Linear(d_model, n_head*d_k)
        self.W_V = nn.Linear(d_model, n_head*d_v)
        self.W_O = nn.Linear(n_head*d_v, d_model)

        self.layer_norm = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v):
        """
        q/k/v: (batch, len, d_model). Returns (output, attention) with
        output (batch, len_q, d_model) and attention
        (batch, n_head, len_q, len_k).
        """
        batch, len_q, _ = q.size()
        batch, len_k, _ = k.size()
        batch, len_v, _ = v.size()

        Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k])
        K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k])
        V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v])

        # (batch, n_head, len, d) layout; K additionally transposed for matmul
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2).transpose(2, 3)
        V = V.transpose(1, 2)

        # scaled dot-product attention scores
        attention = torch.matmul(Q, K)
        attention = attention / np.sqrt(self.d_k)

        # BUG FIX: the original called F.softmax, but this module never
        # imports torch.nn.functional as F (NameError at runtime);
        # torch.softmax is equivalent and already in scope.
        attention = torch.softmax(attention, dim=-1)

        output = torch.matmul(attention, V)

        # merge the heads back to (batch, len_q, n_head*d_v)
        output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head])

        output = self.W_O(output)
        output = self.dropout(output)
        # residual connection + layer norm
        output = self.layer_norm(output + q)

        return output, attention
61
+
62
class MultiHeadAttentionReciprocal(nn.Module):
    """Bidirectional cross-attention.

    Computes attention from q to k and, by transposing the same score
    matrix, from k back to q; each direction has its own value/output
    projections and its own residual + layer norm.
    """

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v

        self.W_Q = nn.Linear(d_model, n_head*d_k)
        self.W_K = nn.Linear(d_model, n_head*d_k)
        self.W_V = nn.Linear(d_model, n_head*d_v)
        self.W_O = nn.Linear(n_head*d_v, d_model)
        # dedicated value/output projections for the reverse (k -> q) direction
        self.W_V_2 = nn.Linear(d_model, n_head*d_v)
        self.W_O_2 = nn.Linear(n_head*d_v, d_model)

        self.layer_norm = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

        self.layer_norm_2 = nn.LayerNorm(d_model)

        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, q, k, v, v_2):
        """
        q: (batch, len_q, d_model); k: (batch, len_k, d_model);
        v must have len_k rows (values for q->k attention) and v_2 must have
        len_q rows (values for k->q attention).
        Returns (output, output_2, attention, attention_2).
        """
        batch, len_q, _ = q.size()
        batch, len_k, _ = k.size()
        batch, len_v, _ = v.size()
        batch, len_v_2, _ = v_2.size()

        Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k])
        K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k])
        V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v])
        V_2 = self.W_V_2(v_2).view([batch, len_v_2, self.n_head, self.d_v])

        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2).transpose(2, 3)
        V = V.transpose(1, 2)
        V_2 = V_2.transpose(1, 2)

        # shared scaled scores; the reverse direction reuses the transpose
        attention = torch.matmul(Q, K)
        attention = attention / np.sqrt(self.d_k)
        attention_2 = attention.transpose(-2, -1)

        # BUG FIX: the original called F.softmax, but torch.nn.functional is
        # never imported as F in this module (NameError at runtime).
        attention = torch.softmax(attention, dim=-1)
        attention_2 = torch.softmax(attention_2, dim=-1)

        output = torch.matmul(attention, V)
        output_2 = torch.matmul(attention_2, V_2)

        output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head])
        output_2 = output_2.transpose(1, 2).reshape([batch, len_k, self.d_v*self.n_head])

        output = self.W_O(output)
        output_2 = self.W_O_2(output_2)

        output = self.dropout(output)
        output = self.layer_norm(output + q)

        # BUG FIX: the reverse branch now uses its own dropout_2/layer_norm_2;
        # they were constructed but never used — the forward branch's modules
        # were applied to both outputs.
        output_2 = self.dropout_2(output_2)
        output_2 = self.layer_norm_2(output_2 + k)

        return output, output_2, attention, attention_2
142
+
143
+
144
class FFN(nn.Module):
    """Position-wise feed-forward block (two 1x1 convolutions) with dropout,
    a residual connection and layer normalisation."""

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()

        self.layer_1 = nn.Conv1d(d_in, d_hid, 1)
        self.layer_2 = nn.Conv1d(d_hid, d_in, 1)
        self.relu = nn.ReLU()
        self.layer_norm = nn.LayerNorm(d_in)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """x: (batch, length, d_in) -> (batch, length, d_in)."""
        # Conv1d expects (batch, channels, length), hence the transposes.
        hidden = self.layer_2(self.relu(self.layer_1(x.transpose(1, 2))))
        hidden = self.dropout(hidden)
        return self.layer_norm(hidden.transpose(1, 2) + x)
170
+
171
class ConvLayer(nn.Module):
    """A single dilated 1-D convolution followed by a ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size, padding, dilation):
        super(ConvLayer, self).__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              padding=padding, dilation=dilation)
        self.relu = nn.ReLU()

    def forward(self, x):
        """x: (batch, in_channels, length) -> (batch, out_channels, length')."""
        return self.relu(self.conv(x))
181
+
182
+
183
class DilatedCNN(nn.Module):
    """Three parallel stacks of dilated convolutions (kernel sizes 3/5/7).

    Each stack applies three ConvLayer blocks with dilation rates 1, 2, 3;
    padding of (kernel // 2) * dilation preserves the sequence length.  The
    three stack outputs are summed.
    """

    def __init__(self, d_model, d_hidden):
        super(DilatedCNN, self).__init__()
        self.first_ = nn.ModuleList()
        self.second_ = nn.ModuleList()
        self.third_ = nn.ModuleList()

        dilation_rates = (1, 2, 3)
        in_dims = (d_model, d_hidden, d_hidden)
        out_dims = (d_hidden, d_hidden, d_hidden)

        # Stacks are built in the same order as before (first_, second_,
        # third_) so parameter/state-dict layout is unchanged.
        for stack, kernel in ((self.first_, 3), (self.second_, 5), (self.third_, 7)):
            for d_in, d_out, rate in zip(in_dims, out_dims, dilation_rates):
                stack.append(ConvLayer(d_in, d_out, kernel_size=kernel,
                                       padding=(kernel // 2) * rate,
                                       dilation=rate))

    def forward(self, protein_seq_enc):
        """protein_seq_enc: (B, L, d_model) -> (B, L, d_hidden)."""
        # Conv1d expects channels-first: (B, d_model, L).
        x = protein_seq_enc.transpose(1, 2)

        first_embedding, second_embedding, third_embedding = x, x, x
        for layer in self.first_:
            first_embedding = layer(first_embedding)
        for layer in self.second_:
            second_embedding = layer(second_embedding)
        for layer in self.third_:
            third_embedding = layer(third_embedding)

        combined = first_embedding + second_embedding + third_embedding
        return combined.transpose(1, 2)
228
+
229
+
230
class ReciprocalLayerwithCNN(nn.Module):
    """Transformer-style block with a dilated-CNN protein encoder:
    per-stream self-attention, reciprocal cross-attention, then a
    feed-forward head for each stream."""

    def __init__(self, d_model, d_inner, d_hidden, n_head, d_k, d_v):
        super().__init__()
        self.cnn = DilatedCNN(d_model, d_hidden)
        self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, d_k, d_v)
        self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, d_k, d_v)
        self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_hidden, d_k, d_v)
        self.ffn_seq = FFN(d_hidden, d_inner)
        self.ffn_protein = FFN(d_hidden, d_inner)

    def forward(self, sequence_enc, protein_seq_enc):
        """Returns updated (protein, sequence) encodings plus the four
        attention maps (self x2, cross x2)."""
        # Encode the protein stream with the CNN, then self-attend each stream.
        prot_enc = self.cnn(protein_seq_enc)
        prot_enc, prot_attention = self.protein_attention_layer(prot_enc, prot_enc, prot_enc)
        seq_enc, sequence_attention = self.sequence_attention_layer(sequence_enc, sequence_enc, sequence_enc)

        # Cross-attend protein <-> sequence in both directions at once.
        prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = \
            self.reciprocal_attention_layer(prot_enc, seq_enc, seq_enc, prot_enc)

        prot_enc = self.ffn_protein(prot_enc)
        seq_enc = self.ffn_seq(seq_enc)

        return prot_enc, seq_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention
261
+
262
+
263
class ReciprocalLayer(nn.Module):
    """Reciprocal attention block without the CNN front end: per-stream
    self-attention, bidirectional cross-attention, then feed-forward heads."""

    def __init__(self, d_model, d_inner, n_head, d_k, d_v):
        super().__init__()
        self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_model, d_k, d_v)
        self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_model, d_k, d_v)
        self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_model, d_k, d_v)
        self.ffn_seq = FFN(d_model, d_inner)
        self.ffn_protein = FFN(d_model, d_inner)

    def forward(self, sequence_enc, protein_seq_enc):
        """Returns updated (protein, sequence) encodings plus the four
        attention maps (self x2, cross x2)."""
        prot_enc, prot_attention = self.protein_attention_layer(
            protein_seq_enc, protein_seq_enc, protein_seq_enc)
        seq_enc, sequence_attention = self.sequence_attention_layer(
            sequence_enc, sequence_enc, sequence_enc)

        # Cross-attend in both directions using a shared score matrix.
        prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = \
            self.reciprocal_attention_layer(prot_enc, seq_enc, seq_enc, prot_enc)

        prot_enc = self.ffn_protein(prot_enc)
        seq_enc = self.ffn_seq(seq_enc)

        return prot_enc, seq_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention
src/scoring/functions/classifiers/hemolysis-xgboost.json ADDED
The diff for this file is too large to render. See raw diff
 
src/scoring/functions/classifiers/nonfouling-xgboost.json ADDED
The diff for this file is too large to render. See raw diff
 
src/scoring/functions/classifiers/permeability-xgboost.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e5d8c84bdad75f7091b5b3963133d4b0ebd180ae45654618ca6c090eee0bc06
3
+ size 45249160
src/scoring/functions/classifiers/solubility-xgboost.json ADDED
The diff for this file is too large to render. See raw diff
 
src/scoring/functions/hemolysis.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import xgboost as xgb
2
+ import torch
3
+ import numpy as np
4
+ from transformers import AutoModelForMaskedLM
5
+ import warnings
6
+ import numpy as np
7
+ from rdkit import rdBase
8
+
9
+ rdBase.DisableLog('rdApp.error')
10
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
11
+ warnings.filterwarnings("ignore", category=UserWarning)
12
+ warnings.filterwarnings("ignore", category=FutureWarning)
13
+
14
class Hemolysis:
    """Hemolysis scorer for peptide SMILES.

    Embeds each sequence with a PeptideCLM masked-LM encoder (mean-pooled
    last hidden state) and classifies with a pretrained XGBoost model.
    Scores are 1 - P(hemolytic), i.e. higher is safer.
    """

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.predictor = xgb.Booster(model_file=f'{base_path}/src/scoring/functions/classifiers/hemolysis-xgboost.json')
        if emb_model is None:
            emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(device).eval()
        self.emb_model = emb_model
        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Mean-pooled PeptideCLM embedding per sequence, as a numpy array."""
        pooled = []
        for seq in sequences:
            enc = self.tokenizer(seq, return_tensors='pt')
            enc = {name: tens.to(self.device) for name, tens in enc.items()}
            with torch.no_grad():
                hidden = self.emb_model(**enc).last_hidden_state
                # Mean pooling across sequence length.
                pooled.append(hidden.mean(dim=1).squeeze(0).cpu().numpy())
        return np.array(pooled)

    def get_scores(self, input_seqs: list):
        """Return 1 - P(hemolytic) per sequence; ones for empty input."""
        fallback = np.ones(len(input_seqs))
        feats = self.generate_embeddings(input_seqs)
        if len(feats) == 0:
            return fallback

        feats = np.nan_to_num(feats, nan=0.)
        feats = np.clip(feats, np.finfo(np.float32).min, np.finfo(np.float32).max)
        probs = self.predictor.predict(xgb.DMatrix(feats))
        # Probability of being NOT hemolytic.
        return fallback - probs

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
53
+
54
def unittest():
    """Smoke test for the Hemolysis scorer.

    NOTE(review): Hemolysis.__init__ requires positional arguments
    (tokenizer, base_path), so calling Hemolysis() with no arguments raises
    TypeError as written — supply a tokenizer and a base path to run this.
    """
    hemo = Hemolysis()
    seq = ["[te]NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
    print(hemo.tokenizer.vocab_size)
    scores = hemo(input_seqs=seq)
    print(scores)


if __name__ == '__main__':
    unittest()
src/scoring/functions/nonfouling.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import xgboost as xgb
4
+ import torch
5
+ import numpy as np
6
+ from transformers import AutoModelForMaskedLM
7
+ import warnings
8
+ import numpy as np
9
+ from rdkit import Chem, rdBase, DataStructs
10
+
11
+
12
+ rdBase.DisableLog('rdApp.error')
13
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
14
+ warnings.filterwarnings("ignore", category=UserWarning)
15
+ warnings.filterwarnings("ignore", category=FutureWarning)
16
+
17
class Nonfouling:
    """Nonfouling scorer for peptide SMILES.

    Embeds each sequence with a PeptideCLM masked-LM encoder (mean-pooled
    last hidden state) and classifies with a pretrained XGBoost model.
    """

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.predictor = xgb.Booster(model_file=f'{base_path}/src/scoring/functions/classifiers/nonfouling-xgboost.json')
        # NOTE(review): moves the freshly loaded model with .to(device) (the
        # raw argument, a no-op when None) rather than .to(self.device) —
        # confirm intended when device is None and CUDA is available.
        self.emb_model = emb_model if emb_model is not None else AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(device).eval()
        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Return one mean-pooled PeptideCLM embedding row per sequence."""
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            with torch.no_grad():
                output = self.emb_model(**tokenized)
                # Mean pooling across sequence length
                embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_scores(self, input_seqs: list):
        """Return the predicted nonfouling probability per sequence;
        zeros for empty input."""
        scores = np.zeros(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        scores = self.predictor.predict(features)
        # Predicted probability of being nonfouling (the previous comment,
        # "probability of it being not hemolytic", was copy-pasted from
        # hemolysis.py and did not apply here).
        return scores

    def __call__(self, input_seqs: list):
        scores = self.get_scores(input_seqs)
        return scores
56
+
57
+ def unittest():
58
+ nf = Nonfouling()
59
+ seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
60
+
61
+ scores = nf(input_seqs=seq)
62
+ print(scores)
63
+
64
+
65
+ if __name__ == '__main__':
66
+ unittest()
src/scoring/functions/permeability.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import xgboost as xgb
4
+ import torch
5
+ import numpy as np
6
+ from transformers import AutoModelForMaskedLM
7
+ import warnings
8
+ import numpy as np
9
+ from rdkit.Chem import Descriptors, rdMolDescriptors
10
+ from rdkit import Chem, rdBase, DataStructs
11
+ from rdkit.Chem import AllChem
12
+ from typing import List
13
+ from transformers import AutoModelForMaskedLM
14
+
15
+
16
+ rdBase.DisableLog('rdApp.error')
17
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
18
+ warnings.filterwarnings("ignore", category=UserWarning)
19
+ warnings.filterwarnings("ignore", category=FutureWarning)
20
+
21
def fingerprints_from_smiles(smiles: List, size=2048):
    """Create ECFP fingerprints of smiles, with validity check.

    Invalid SMILES yield an all-zero fingerprint row and a 0 in the mask.

    Args:
        smiles: list of SMILES strings.
        size: fingerprint bit length.

    Returns:
        (fps, valid_mask): fps is an (n, size) array, valid_mask a list of
        0/1 flags marking which inputs parsed as valid molecules.
    """
    fps = []
    valid_mask = []
    for i, smile in enumerate(smiles):
        mol = Chem.MolFromSmiles(smile)
        valid_mask.append(int(mol is not None))
        fp = fingerprints_from_mol(mol, size=size) if mol else np.zeros((1, size))
        fps.append(fp)

    # BUG FIX: np.concatenate raises ValueError on an empty list; mirror the
    # guard already used by scoring_utils.fingerprints_from_smiles.
    fps = np.concatenate(fps, axis=0) if fps else np.zeros((0, size))
    return fps, valid_mask
33
+
34
+
35
+ def fingerprints_from_mol(molecule, radius=3, size=2048, hashed=False):
36
+ """ Create ECFP fingerprint of a molecule """
37
+ if hashed:
38
+ fp_bits = AllChem.GetHashedMorganFingerprint(molecule, radius, nBits=size)
39
+ else:
40
+ fp_bits = AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=size)
41
+ fp_np = np.zeros((1,))
42
+ DataStructs.ConvertToNumpyArray(fp_bits, fp_np)
43
+ return fp_np.reshape(1, -1)
44
+
45
def getMolDescriptors(mol, missingVal=0):
    """Calculate the full list of RDKit descriptors for a molecule, plus a
    few Lipinski-style counts.

    Descriptor functions that raise are recorded as ``missingVal``.

    Returns:
        (values, names): parallel lists of descriptor values and names.
    """
    values, names = [], []
    for nm, fn in Descriptors._descList:
        try:
            val = fn(mol)
        except Exception:  # BUG FIX: bare except also swallowed KeyboardInterrupt/SystemExit
            val = missingVal
        values.append(val)
        names.append(nm)

    custom_descriptors = {'hydrogen-bond donors': rdMolDescriptors.CalcNumLipinskiHBD,
                          'hydrogen-bond acceptors': rdMolDescriptors.CalcNumLipinskiHBA,
                          'rotatable bonds': rdMolDescriptors.CalcNumRotatableBonds,}

    for nm, fn in custom_descriptors.items():
        try:
            val = fn(mol)
        except Exception:
            val = missingVal
        values.append(val)
        names.append(nm)
    return values, names
69
+
70
+ def get_pep_dps_from_smi(smi):
71
+ try:
72
+ mol = Chem.MolFromSmiles(smi)
73
+ except:
74
+ print(f"convert smi {smi} to molecule failed!")
75
+ mol = None
76
+
77
+ dps, _ = getMolDescriptors(mol)
78
+ return np.array(dps)
79
+
80
+
81
def get_pep_dps(smi_list):
    # Returns an (n, n_descriptors) array of RDKit descriptors per SMILES.
    # NOTE(review): the empty-input width is hard-coded to 213 here but 211
    # in scoring_utils.get_pep_dps — confirm which value matches
    # len(Descriptors._descList) + 3 for the installed RDKit version.
    if len(smi_list) == 0:
        return np.zeros((0, 213))
    return np.array([get_pep_dps_from_smi(smi) for smi in smi_list])
85
+
86
+ def check_smi_validity(smiles: list):
87
+ valid_smi, valid_idx = [], []
88
+ for idx, smi in enumerate(smiles):
89
+ try:
90
+ mol = Chem.MolFromSmiles(smi) if smi else None
91
+ if mol:
92
+ valid_smi.append(smi)
93
+ valid_idx.append(idx)
94
+ except Exception as e:
95
+ # logger.debug(f'Error: {e} in smiles {smi}')
96
+ pass
97
+ return valid_smi, valid_idx
98
+
99
class Permeability:
    """Membrane-permeability scorer for peptide SMILES.

    Builds a feature matrix of [fingerprints | descriptors | PeptideCLM
    embeddings] (fingerprint/descriptor blocks are empty unless enabled)
    and predicts with a pretrained XGBoost model.
    """

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.predictor = xgb.Booster(model_file=f'{base_path}/src/scoring/functions/classifiers/permeability-xgboost.json')
        if emb_model is not None:
            self.emb_model = emb_model.to(self.device).eval()
        else:
            # BUG FIX: was .to(device), a no-op when device is None, which left
            # the model on CPU while generate_embeddings moves inputs to
            # self.device (a device mismatch when CUDA is available).
            self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()

        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Mean-pooled PeptideCLM embedding per sequence, as a numpy array."""
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            with torch.no_grad():
                output = self.emb_model(**tokenized)
                # Mean pooling across sequence length
                embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_features(self, input_seqs: list, dps=False, fps=False):
        """Assemble the feature matrix [fingerprints | descriptors | embeddings].

        Args:
            input_seqs: SMILES strings.
            dps: include RDKit descriptors.
            fps: include ECFP fingerprints.
        """
        if fps:
            fingerprints = fingerprints_from_smiles(input_seqs)[0]
        else:
            # Was torch.empty; keep everything numpy for np.concatenate below.
            fingerprints = np.empty((len(input_seqs), 0))

        if dps:
            descriptors = get_pep_dps(input_seqs)
        else:
            descriptors = np.empty((len(input_seqs), 0))

        embeddings = self.generate_embeddings(input_seqs)

        features = np.concatenate([fingerprints, descriptors, embeddings], axis=1)
        return features

    def get_scores(self, input_seqs: list):
        """Predicted permeability per sequence; -10 fallback for empty input."""
        scores = -10 * np.ones(len(input_seqs))
        features = self.get_features(input_seqs)

        if len(features) == 0:
            return scores

        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        scores = self.predictor.predict(features)
        return scores

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
162
+
163
+ def unittest():
164
+ permeability = Permeability()
165
+ seq = ['N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1cNc2c1cc(O)cc2)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1ccccc1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H]([C@@H](O)C(C)C)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@H](CC(=CN2)C1=C2C=CC=C1)C(=O)O']
166
+ scores = permeability(input_seqs=seq)
167
+ print(scores)
168
+
169
+
170
+ if __name__ == '__main__':
171
+ unittest()
src/scoring/functions/scoring_utils.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ import numpy as np
3
+ from loguru import logger
4
+ from sklearn.ensemble import RandomForestRegressor
5
+ from rdkit.Chem import Descriptors, rdMolDescriptors
6
+ import joblib
7
+ from rdkit import Chem, rdBase, DataStructs
8
+ from rdkit.Chem import AllChem
9
+ from typing import List
10
+
11
+
12
+ def fingerprints_from_mol(molecule, radius=3, size=2048, hashed=False):
13
+ """
14
+ Create ECFP fingerprint of a molecule
15
+ """
16
+ if hashed:
17
+ fp_bits = AllChem.GetHashedMorganFingerprint(molecule, radius, nBits=size)
18
+ else:
19
+ fp_bits = AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=size)
20
+ fp_np = np.zeros((1,))
21
+ DataStructs.ConvertToNumpyArray(fp_bits, fp_np)
22
+ return fp_np.reshape(1, -1)
23
+
24
+
25
+ def fingerprints_from_smiles(smiles: List, size=2048):
26
+ """ Create ECFP fingerprints of smiles, with validity check """
27
+ fps = []
28
+ valid_mask = []
29
+ for i, smile in enumerate(smiles):
30
+ mol = Chem.MolFromSmiles(smile)
31
+ valid_mask.append(int(mol is not None))
32
+ fp = fingerprints_from_mol(mol, size=size) if mol else np.zeros((1, size))
33
+ fps.append(fp)
34
+
35
+ fps = np.concatenate(fps, axis=0) if len(fps) > 0 else np.zeros((0, size))
36
+ return fps, valid_mask
37
+
38
+
39
+ def getMolDescriptors(mol, missingVal=0):
40
+ """ calculate the full list of descriptors for a molecule """
41
+
42
+ values, names = [], []
43
+ for nm, fn in Descriptors._descList:
44
+ try:
45
+ val = fn(mol)
46
+ except:
47
+ val = missingVal
48
+ values.append(val)
49
+ names.append(nm)
50
+
51
+ custom_descriptors = {'hydrogen-bond donors': rdMolDescriptors.CalcNumLipinskiHBD,
52
+ 'hydrogen-bond acceptors': rdMolDescriptors.CalcNumLipinskiHBA,
53
+ 'rotatable bonds': rdMolDescriptors.CalcNumRotatableBonds,}
54
+
55
+ for nm, fn in custom_descriptors.items():
56
+ try:
57
+ val = fn(mol)
58
+ except:
59
+ val = missingVal
60
+ values.append(val)
61
+ names.append(nm)
62
+ return values, names
63
+
64
+
65
+ def get_pep_dps_from_smi(smi):
66
+ try:
67
+ mol = Chem.MolFromSmiles(smi)
68
+ except:
69
+ print(f"convert smi {smi} to molecule failed!")
70
+ mol = None
71
+
72
+ dps, _ = getMolDescriptors(mol)
73
+ return np.array(dps)
74
+
75
+
76
+ def get_pep_dps(smi_list):
77
+ if len(smi_list) == 0:
78
+ return np.zeros((0, 211))
79
+ return np.array([get_pep_dps_from_smi(smi) for smi in smi_list])
80
+
81
+
82
+
83
+ def check_smi_validity(smiles: list):
84
+ valid_smi, valid_idx = [], []
85
+ for idx, smi in enumerate(smiles):
86
+ try:
87
+ mol = Chem.MolFromSmiles(smi) if smi else None
88
+ if mol:
89
+ valid_smi.append(smi)
90
+ valid_idx.append(idx)
91
+ except Exception as e:
92
+ # logger.debug(f'Error: {e} in smiles {smi}')
93
+ pass
94
+ return valid_smi, valid_idx
src/scoring/functions/solubility.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import xgboost as xgb
2
+ import torch
3
+ import numpy as np
4
+ from transformers import AutoModelForMaskedLM
5
+ import warnings
6
+ import numpy as np
7
+ from rdkit import rdBase
8
+
9
+ rdBase.DisableLog('rdApp.error')
10
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
11
+ warnings.filterwarnings("ignore", category=UserWarning)
12
+ warnings.filterwarnings("ignore", category=FutureWarning)
13
+
14
class Solubility:
    """Solubility scorer for peptide SMILES: an XGBoost classifier over
    mean-pooled PeptideCLM embeddings."""

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.predictor = xgb.Booster(model_file=f'{base_path}/src/scoring/functions/classifiers/solubility-xgboost.json')
        if emb_model is not None:
            self.emb_model = emb_model.to(self.device).eval()
        else:
            self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()

        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Mean-pooled last-hidden-state embedding per sequence (numpy)."""
        pooled = []
        for seq in sequences:
            enc = self.tokenizer(seq, return_tensors='pt')
            enc = {name: tens.to(self.device) for name, tens in enc.items()}
            with torch.no_grad():
                hidden = self.emb_model(**enc).last_hidden_state
                # Mean pooling across sequence length.
                pooled.append(hidden.mean(dim=1).squeeze(0).cpu().numpy())
        return np.array(pooled)

    def get_scores(self, input_seqs: list):
        """Predicted solubility probability per sequence; zeros for empty input."""
        fallback = np.zeros(len(input_seqs))
        feats = self.generate_embeddings(input_seqs)
        if len(feats) == 0:
            return fallback

        feats = np.nan_to_num(feats, nan=0.)
        feats = np.clip(feats, np.finfo(np.float32).min, np.finfo(np.float32).max)
        return self.predictor.predict(xgb.DMatrix(feats))

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
55
+
56
+ def unittest():
57
+ solubility = Solubility()
58
+ seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
59
+ scores = solubility(input_seqs=seq)
60
+ print(scores)
61
+
62
+ if __name__ == '__main__':
63
+ unittest()
src/scoring/scoring_functions.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
2
+ from transformers import AutoModelForMaskedLM
3
+ import numpy as np
4
+ from scoring.functions.binding import BindingAffinity
5
+ from scoring.functions.permeability import Permeability
6
+ from scoring.functions.solubility import Solubility
7
+ from scoring.functions.hemolysis import Hemolysis
8
+ from scoring.functions.nonfouling import Nonfouling
9
+
10
+ base_path = '/path/to/your/home'
11
+
12
class ScoringFunctions:
    def __init__(self, score_func_names=None, prot_seqs=None, device=None):
        """
        Class for generating score vectors given generated sequences.

        Args:
            score_func_names: list of scoring function names to be evaluated
                (None means no scoring functions — unmasking based on
                validity of peptide bonds only)
            prot_seqs: sequences of target protein binders (0, 1 or 2 entries)
            device: torch device for the shared embedding model
        """
        emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(device).eval()
        tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/src/scoring/functions/tokenizer/new_vocab.txt',
                                         f'{base_path}/src/scoring/functions/tokenizer/new_splits.txt')
        prot_seqs = prot_seqs if prot_seqs is not None else []

        # BUG FIX: normalise to a list BEFORE the membership checks below;
        # the original evaluated `'...' in score_func_names` with a possible
        # None, raising TypeError.
        self.score_func_names = score_func_names if score_func_names is not None else []
        score_func_names = self.score_func_names

        # binding affinities
        self.target_protein = prot_seqs

        binding_affinity1 = None
        binding_affinity2 = None
        if ('binding_affinity1' in score_func_names) and (len(prot_seqs) == 1):
            binding_affinity1 = BindingAffinity(prot_seqs[0], tokenizer=tokenizer, base_path=base_path, device=device)
        elif ('binding_affinity1' in score_func_names) and ('binding_affinity2' in score_func_names) and (len(prot_seqs) == 2):
            binding_affinity1 = BindingAffinity(prot_seqs[0], tokenizer=tokenizer, base_path=base_path, device=device)
            binding_affinity2 = BindingAffinity(prot_seqs[1], tokenizer=tokenizer, base_path=base_path, device=device)

        # Non-binding scorers share the embedding model loaded above.
        permeability = Permeability(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
        sol = Solubility(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
        nonfouling = Nonfouling(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
        hemo = Hemolysis(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)

        self.all_funcs = {'binding_affinity1': binding_affinity1,
                          'binding_affinity2': binding_affinity2,
                          'permeability': permeability,
                          'nonfouling': nonfouling,
                          'solubility': sol,
                          'hemolysis': hemo
                          }

    def forward(self, input_seqs):
        """Evaluate every configured scoring function on the sequences.

        Returns:
            np.ndarray of shape (num_sequences, num_functions), float32.
        """
        scores = [self.all_funcs[name](input_seqs=input_seqs)
                  for name in self.score_func_names]

        # convert to numpy array with shape (num_sequences, num_functions)
        return np.float32(scores).T

    def __call__(self, input_seqs: list):
        return self.forward(input_seqs)
src/scoring/tokenizer/my_tokenizers.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import os
3
+ import re
4
+ from typing import List, Optional
5
+ from transformers import PreTrainedTokenizer
6
+ from SmilesPE.tokenizer import SPE_Tokenizer
7
+ import torch
8
+
9
def load_vocab(vocab_file):
    """Load a vocabulary file (one token per line) into an ordered
    token -> index mapping."""
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        lines = reader.readlines()
    for index, line in enumerate(lines):
        vocab[line.rstrip("\n")] = index
    return vocab
18
+
19
class Atomwise_Tokenizer(object):
    """Atom-level SMILES tokenizer driven by a single regular expression."""

    def __init__(self):
        """Construct an atom-level tokenizer.

        The pattern matches, in priority order: short parenthesised groups,
        bracket atoms, two-letter elements (Br/Cl), single-letter atoms,
        and SMILES punctuation (bonds, branches, ring-closure digits, etc.).
        """
        # self.regex_pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
        self.regex_pattern = r"(\([^\(\)]{0,4}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/\/?|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
        self.regex = re.compile(self.regex_pattern)

    def tokenize(self, text):
        """Split a SMILES string into atom-level tokens."""
        return self.regex.findall(text)
35
+
36
class SMILES_SPE_Tokenizer(PreTrainedTokenizer):
    r"""
    Constructs a SMILES tokenizer based on SMILES Pair Encoding
    (https://github.com/XinhaoLi74/SmilesPE).

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer`
    which contains most of the methods. Users should refer to the superclass
    for more information regarding methods.

    Args:
        vocab_file (:obj:`string`):
            File containing the vocabulary (one token per line; line number
            is the token id).
        spe_file (:obj:`string`):
            File containing the trained SMILES Pair Encoding merges.
        unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
            The unknown token; out-of-vocabulary tokens map to its id.
        sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
            The separator token; also the last token of a sequence built
            with special tokens.
        pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
            The token used for padding when batching sequences.
        cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
            The classifier token; first token of a sequence built with
            special tokens.
        mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
            The token used for masked-language-model training.
    """

    def __init__(self, vocab_file, spe_file,
                 unk_token="[UNK]",
                 sep_token="[SEP]",
                 pad_token="[PAD]",
                 cls_token="[CLS]",
                 mask_token="[MASK]",
                 **kwargs):
        if not os.path.isfile(vocab_file):
            raise ValueError("Can't find a vocabulary file at path '{}'.".format(vocab_file))
        if not os.path.isfile(spe_file):
            raise ValueError("Can't find a SPE vocabulary file at path '{}'.".format(spe_file))

        # self.vocab must exist before super().__init__ because the base
        # class may call vocab_size / get_vocab during its initialization.
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict(
            [(ids, tok) for tok, ids in self.vocab.items()])
        # Fix: the SPE merges file was previously opened and never closed
        # (resource leak; the open handle was kept as self.spe_vocab but only
        # ever consumed here). SPE_Tokenizer reads the file fully in its
        # constructor, so closing it immediately afterwards is safe.
        with open(spe_file, 'r', encoding='utf-8') as spe_vocab:
            self.spe_tokenizer = SPE_Tokenizer(spe_vocab)

        super().__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs)

    @property
    def vocab_size(self):
        """Number of tokens in the base vocabulary (excluding added tokens)."""
        return len(self.vocab)

    def get_vocab(self):
        """Return the full token -> id mapping (base vocab + added tokens)."""
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        # SPE_Tokenizer.tokenize returns a single space-joined string of
        # merged tokens; split it back into a token list.
        return self.spe_tokenizer.tokenize(text).split(' ')

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    # NOTE: encode/decode intentionally override the transformers signatures;
    # they operate on pre-tokenized token arrays / id tensors for this project.
    def encode(self, token_array):
        """Encode a pre-tokenized SMILES into model-ready tensors.

        Args:
            token_array: iterable of string tokens (already SPE-tokenized).

        Returns:
            dict with 'input_ids' of shape [1, len(token_array) + 2] and an
            all-ones 'attention_mask' of the same shape (torch tensors).
        """
        # 2 and 3 are the hard-coded [CLS] and [SEP] ids; they match the
        # positions of "[CLS]"/"[SEP]" in the bundled vocab file.
        token_ids = [2]
        token_ids.extend(self._convert_token_to_id(token) for token in token_array)
        token_ids.append(3)
        token_ids = torch.tensor([token_ids])
        attn_mask = torch.ones_like(token_ids)
        return {'input_ids': token_ids, 'attention_mask': attn_mask}

    def decode(self, token_ids, skip_special_tokens=True):
        """Decode an id tensor back into a SMILES string.

        Decoding stops at the first [SEP] id (3); when skip_special_tokens
        is True, ids in self.all_special_ids are dropped.

        Args:
            token_ids (torch.Tensor): id tensor, optionally with a leading
                batch dimension of 1 (it is squeezed away).
            skip_special_tokens (bool): drop special-token ids.

        Returns:
            str: concatenated token strings.
        """
        token_ids = token_ids.squeeze(0).cpu().tolist()
        token_array = []
        for idx in token_ids:
            if idx == 3:  # stop at [SEP]
                break
            if skip_special_tokens and idx in self.all_special_ids:
                continue
            token_array.append(self._convert_id_to_token(idx))
        return "".join(token_array)

    def batch_decode(self, batch_token_ids, skip_special_tokens=True):
        """Decode a batch of id tensors into SMILES strings.

        Fix: skip_special_tokens is now forwarded to decode(); it was
        previously accepted but silently ignored.
        """
        return [self.decode(token_ids, skip_special_tokens=skip_special_tokens)
                for token_ids in batch_token_ids]

    def get_token_split(self, token_ids):
        """Map a batch of id sequences to lists of token strings (no joining,
        no special-token filtering)."""
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.cpu().tolist()
        return [[self._convert_id_to_token(idx) for idx in seq_ids]
                for seq_ids in token_ids]

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        return " ".join(tokens).replace(" ##", "").strip()

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences by adding
        special tokens, BERT style:
        - single sequence: ``[CLS] X [SEP]``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``

        Args:
            token_ids_0 (:obj:`List[int]`): ids of the first sequence.
            token_ids_1 (:obj:`List[int]`, `optional`): ids of the second
                sequence for sequence pairs.

        Returns:
            :obj:`List[int]`: input ids with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Return a mask with 1 at special-token positions and 0 elsewhere.

        Args:
            token_ids_0 (:obj:`List[int]`): ids of the first sequence.
            token_ids_1 (:obj:`List[int]`, `optional`): ids of the second
                sequence for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`): set True if
                token_ids_0 is already formatted with special tokens.

        Returns:
            :obj:`List[int]`: list of 0/1 flags, 1 marking a special token.
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formated with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create BERT-style token-type ids: 0 for the first sequence (including
        [CLS] and its [SEP]) and 1 for the second sequence (including its
        [SEP]). When token_ids_1 is None only the 0-mask is returned.

        Returns:
            :obj:`List[int]`: token type ids for the combined sequence.
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, vocab_path):
        """
        Save the vocabulary to *vocab_path*, one token per line in id order.

        Args:
            vocab_path (:obj:`str`): destination file path.

        Returns:
            :obj:`Tuple(str)`: path to the written file.
        """
        index = 0
        vocab_file = vocab_path
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                # Non-consecutive ids indicate holes in the vocabulary;
                # resync the counter so writing continues (mirrors the
                # original BERT tokenizer's save logic).
                if index != token_index:
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)
252
+
253
class SMILES_Atomwise_Tokenizer(PreTrainedTokenizer):
    r"""
    Constructs an atom-level SMILES tokenizer (regex-based, one token per
    atom/bond/bracket group; see Atomwise_Tokenizer).

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer`
    which contains most of the methods. Users should refer to the superclass
    for more information regarding methods.

    Args:
        vocab_file (:obj:`string`):
            File containing the vocabulary (one token per line; line number
            is the token id).
        unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
            The unknown token; out-of-vocabulary tokens map to its id.
        sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
            The separator token; also the last token of a sequence built
            with special tokens.
        pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
            The token used for padding when batching sequences.
        cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
            The classifier token; first token of a sequence built with
            special tokens.
        mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
            The token used for masked-language-model training.
    """

    def __init__(
        self,
        vocab_file,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        **kwargs
    ):
        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'.".format(vocab_file)
            )
        # Fix: load the vocab BEFORE calling super().__init__, for
        # consistency with SMILES_SPE_Tokenizer -- the transformers base
        # class may query vocab_size / get_vocab during its initialization,
        # which would fail while self.vocab did not exist yet.
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict(
            [(ids, tok) for tok, ids in self.vocab.items()])
        self.tokenizer = Atomwise_Tokenizer()

        super().__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Number of tokens in the base vocabulary (excluding added tokens)."""
        return len(self.vocab)

    def get_vocab(self):
        """Return the full token -> id mapping (base vocab + added tokens)."""
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        # Delegate to the regex-based atom-level tokenizer.
        return self.tokenizer.tokenize(text)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        return " ".join(tokens).replace(" ##", "").strip()

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences by adding
        special tokens, BERT style:
        - single sequence: ``[CLS] X [SEP]``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``

        Args:
            token_ids_0 (:obj:`List[int]`): ids of the first sequence.
            token_ids_1 (:obj:`List[int]`, `optional`): ids of the second
                sequence for sequence pairs.

        Returns:
            :obj:`List[int]`: input ids with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Return a mask with 1 at special-token positions and 0 elsewhere.

        Args:
            token_ids_0 (:obj:`List[int]`): ids of the first sequence.
            token_ids_1 (:obj:`List[int]`, `optional`): ids of the second
                sequence for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`): set True if
                token_ids_0 is already formatted with special tokens.

        Returns:
            :obj:`List[int]`: list of 0/1 flags, 1 marking a special token.
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formated with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create BERT-style token-type ids: 0 for the first sequence (including
        [CLS] and its [SEP]) and 1 for the second sequence (including its
        [SEP]). When token_ids_1 is None only the 0-mask is returned.

        Returns:
            :obj:`List[int]`: token type ids for the combined sequence.
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, vocab_path):
        """
        Save the vocabulary to *vocab_path*, one token per line in id order.

        Args:
            vocab_path (:obj:`str`): destination file path.

        Returns:
            :obj:`Tuple(str)`: path to the written file.
        """
        index = 0
        vocab_file = vocab_path
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                # Non-consecutive ids indicate holes in the vocabulary;
                # resync the counter so writing continues (mirrors the
                # original BERT tokenizer's save logic).
                if index != token_index:
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)
src/scoring/tokenizer/new_splits.txt ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ c 1
2
+ c 2
3
+ c 3
4
+ c 4
5
+ c 5
6
+ c 6
7
+ c 7
8
+ c 8
9
+ c 9
10
+ ( c1
11
+ ( c2
12
+ c1 )
13
+ c2 )
14
+ n 1
15
+ n 2
16
+ n 3
17
+ n 4
18
+ n 5
19
+ n 6
20
+ n 7
21
+ n 8
22
+ n 9
23
+ ( n1
24
+ ( n2
25
+ n1 )
26
+ n2 )
27
+ O 1
28
+ O 2
29
+ O 3
30
+ O 4
31
+ O 5
32
+ O 6
33
+ O 7
34
+ O 8
35
+ O 9
36
+ ( O1
37
+ ( O2
38
+ O2 )
39
+ O2 )
40
+ = O
41
+ = C
42
+ = c
43
+ = N
44
+ = n
45
+ =C C
46
+ =C N
47
+ =C c
48
+ =c c
49
+ =N C
50
+ =N c
51
+ =n C
52
+ =n c
53
+ # N
54
+ # C
55
+ #N C
56
+ #C C
57
+ #C N
58
+ #N N
59
+ ( C
60
+ C )
61
+ ( O
62
+ O )
63
+ ( N
64
+ N )
65
+ Br c
66
+ ( =O
67
+ (=O )
68
+ C (=O)
69
+ C =O
70
+ C =N
71
+ C #N
72
+ C #C
73
+ C C
74
+ CC C
75
+ CC N
76
+ CC O
77
+ CC S
78
+ CC c
79
+ CC n
80
+ C N
81
+ CN C
82
+ CN c
83
+ C O
84
+ CO C
85
+ CO N
86
+ CO c
87
+ C S
88
+ CS C
89
+ CS S
90
+ CS c
91
+ C c
92
+ Cl c
93
+ C n
94
+ F c
95
+ N C
96
+ NC C
97
+ NC c
98
+ N N
99
+ N O
100
+ N c
101
+ N n
102
+ O C
103
+ OC C
104
+ OC O
105
+ OC c
106
+ O N
107
+ O O
108
+ O c
109
+ S C
110
+ SC C
111
+ SC c
112
+ S S
113
+ S c
114
+ c c
115
+ cc c
116
+ cc n
117
+ cc o
118
+ cc s
119
+ cc cc
120
+ c n
121
+ cn c
122
+ cn n
123
+ c o
124
+ co c
125
+ c s
126
+ cs c
127
+ cs n
128
+ n c
129
+ nc c
130
+ nc n
131
+ nc o
132
+ nc s
133
+ n n
134
+ nn c
135
+ nn n
136
+ n o
137
+ no c
138
+ no n
139
+ n s
140
+ ns c
141
+ ns n
142
+ o c
143
+ oc c
144
+ o n
145
+ s c
146
+ sc c
147
+ sc n
148
+ s n
149
+ N P
150
+ P N
151
+ C P
152
+ P C
153
+ N S
154
+ S N
155
+ C S
156
+ S C
157
+ S P
158
+ P S
159
+ C I
src/scoring/tokenizer/new_vocab.txt ADDED
@@ -0,0 +1,587 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [PAD]
2
+ [UNK]
3
+ [CLS]
4
+ [SEP]
5
+ [MASK]
6
+ #
7
+ %
8
+ (
9
+ )
10
+ +
11
+ -
12
+ /
13
+ 0
14
+ 1
15
+ 2
16
+ 3
17
+ 4
18
+ 5
19
+ 6
20
+ 7
21
+ 8
22
+ 9
23
+ =
24
+ @
25
+ A
26
+ B
27
+ Br
28
+ Brc
29
+ C
30
+ CC
31
+ CCC
32
+ CCN
33
+ CCO
34
+ CCS
35
+ CCc
36
+ CCn
37
+ CN
38
+ CNC
39
+ CNc
40
+ CO
41
+ COC
42
+ CON
43
+ COc
44
+ CS
45
+ CSC
46
+ CSS
47
+ CSc
48
+ Cc
49
+ Cl
50
+ Clc
51
+ Cn
52
+ F
53
+ Fc
54
+ H
55
+ I
56
+ K
57
+ L
58
+ M
59
+ N
60
+ NC
61
+ NCC
62
+ NCc
63
+ NN
64
+ NO
65
+ Nc
66
+ Nn
67
+ O
68
+ OC
69
+ OCC
70
+ OCO
71
+ OCc
72
+ ON
73
+ OO
74
+ Oc
75
+ P
76
+ R
77
+ S
78
+ SC
79
+ SCC
80
+ SCc
81
+ SS
82
+ Sc
83
+ T
84
+ X
85
+ Z
86
+ [
87
+ \\
88
+ (/
89
+ ]
90
+ a
91
+ b
92
+ c
93
+ cc
94
+ ccc
95
+ cccc
96
+ ccn
97
+ cco
98
+ ccs
99
+ cn
100
+ cnc
101
+ cnn
102
+ co
103
+ coc
104
+ cs
105
+ csc
106
+ csn
107
+ e
108
+ g
109
+ i
110
+ l
111
+ n
112
+ nc
113
+ ncc
114
+ ncn
115
+ nco
116
+ ncs
117
+ nn
118
+ nnc
119
+ nnn
120
+ no
121
+ noc
122
+ non
123
+ ns
124
+ nsc
125
+ nsn
126
+ o
127
+ oc
128
+ occ
129
+ on
130
+ p
131
+ r
132
+ s
133
+ sc
134
+ scc
135
+ scn
136
+ sn
137
+ t
138
+ c1
139
+ c2
140
+ c3
141
+ c4
142
+ c5
143
+ c6
144
+ c7
145
+ c8
146
+ c9
147
+ n1
148
+ n2
149
+ n3
150
+ n4
151
+ n5
152
+ n6
153
+ n7
154
+ n8
155
+ n9
156
+ O1
157
+ O2
158
+ O3
159
+ O4
160
+ O5
161
+ O6
162
+ O7
163
+ O8
164
+ O9
165
+ (c1
166
+ (c2
167
+ c1)
168
+ c2)
169
+ (n1
170
+ (n2
171
+ n1)
172
+ n2)
173
+ (O1
174
+ (O2
175
+ O2)
176
+ =O
177
+ =C
178
+ =c
179
+ =N
180
+ =n
181
+ =CC
182
+ =CN
183
+ =Cc
184
+ =cc
185
+ =NC
186
+ =Nc
187
+ =nC
188
+ =nc
189
+ #C
190
+ #CC
191
+ #CN
192
+ #N
193
+ #NC
194
+ #NN
195
+ (C
196
+ C)
197
+ (O
198
+ O)
199
+ (N
200
+ N)
201
+ NP
202
+ PN
203
+ CP
204
+ PC
205
+ NS
206
+ SN
207
+ SP
208
+ PS
209
+ C(=O)
210
+ (/Br)
211
+ (/C#N)
212
+ (/C)
213
+ (/C=N)
214
+ (/C=O)
215
+ (/CBr)
216
+ (/CC)
217
+ (/CCC)
218
+ (/CCF)
219
+ (/CCN)
220
+ (/CCO)
221
+ (/CCl)
222
+ (/CI)
223
+ (/CN)
224
+ (/CO)
225
+ (/CS)
226
+ (/Cl)
227
+ (/F)
228
+ (/I)
229
+ (/N)
230
+ (/NC)
231
+ (/NCC)
232
+ (/NO)
233
+ (/O)
234
+ (/OC)
235
+ (/OCC)
236
+ (/S)
237
+ (/SC)
238
+ (=C)
239
+ (=C/C)
240
+ (=C/F)
241
+ (=C/I)
242
+ (=C/N)
243
+ (=C/O)
244
+ (=CBr)
245
+ (=CC)
246
+ (=CCF)
247
+ (=CCN)
248
+ (=CCO)
249
+ (=CCl)
250
+ (=CF)
251
+ (=CI)
252
+ (=CN)
253
+ (=CO)
254
+ (=C\\C)
255
+ (=C\\F)
256
+ (=C\\I)
257
+ (=C\\N)
258
+ (=C\\O)
259
+ (=N)
260
+ (=N/C)
261
+ (=N/N)
262
+ (=N/O)
263
+ (=NBr)
264
+ (=NC)
265
+ (=NCC)
266
+ (=NCl)
267
+ (=NN)
268
+ (=NO)
269
+ (=NOC)
270
+ (=N\\C)
271
+ (=N\\N)
272
+ (=N\\O)
273
+ (=O)
274
+ (=S)
275
+ (B)
276
+ (Br)
277
+ (C#C)
278
+ (C#CC)
279
+ (C#CI)
280
+ (C#CO)
281
+ (C#N)
282
+ (C#SN)
283
+ (C)
284
+ (C=C)
285
+ (C=CF)
286
+ (C=CI)
287
+ (C=N)
288
+ (C=NN)
289
+ (C=NO)
290
+ (C=O)
291
+ (C=S)
292
+ (CBr)
293
+ (CC#C)
294
+ (CC#N)
295
+ (CC)
296
+ (CC=C)
297
+ (CC=O)
298
+ (CCBr)
299
+ (CCC)
300
+ (CCCC)
301
+ (CCCF)
302
+ (CCCI)
303
+ (CCCN)
304
+ (CCCO)
305
+ (CCCS)
306
+ (CCCl)
307
+ (CCF)
308
+ (CCI)
309
+ (CCN)
310
+ (CCNC)
311
+ (CCNN)
312
+ (CCNO)
313
+ (CCO)
314
+ (CCOC)
315
+ (CCON)
316
+ (CCS)
317
+ (CCSC)
318
+ (CCl)
319
+ (CF)
320
+ (CI)
321
+ (CN)
322
+ (CN=O)
323
+ (CNC)
324
+ (CNCC)
325
+ (CNCO)
326
+ (CNN)
327
+ (CNNC)
328
+ (CNO)
329
+ (CNOC)
330
+ (CO)
331
+ (COC)
332
+ (COCC)
333
+ (COCI)
334
+ (COCN)
335
+ (COCO)
336
+ (COF)
337
+ (CON)
338
+ (COO)
339
+ (CS)
340
+ (CSC)
341
+ (CSCC)
342
+ (CSCF)
343
+ (CSO)
344
+ (Cl)
345
+ (F)
346
+ (I)
347
+ (N)
348
+ (N=N)
349
+ (N=NO)
350
+ (N=O)
351
+ (N=S)
352
+ (NBr)
353
+ (NC#N)
354
+ (NC)
355
+ (NC=N)
356
+ (NC=O)
357
+ (NC=S)
358
+ (NCBr)
359
+ (NCC)
360
+ (NCCC)
361
+ (NCCF)
362
+ (NCCN)
363
+ (NCCO)
364
+ (NCCS)
365
+ (NCCl)
366
+ (NCNC)
367
+ (NCO)
368
+ (NCS)
369
+ (NCl)
370
+ (NN)
371
+ (NN=O)
372
+ (NNC)
373
+ (NO)
374
+ (NOC)
375
+ (O)
376
+ (OC#N)
377
+ (OC)
378
+ (OC=C)
379
+ (OC=O)
380
+ (OC=S)
381
+ (OCBr)
382
+ (OCC)
383
+ (OCCC)
384
+ (OCCF)
385
+ (OCCI)
386
+ (OCCN)
387
+ (OCCO)
388
+ (OCCS)
389
+ (OCCl)
390
+ (OCF)
391
+ (OCI)
392
+ (OCO)
393
+ (OCOC)
394
+ (OCON)
395
+ (OCSC)
396
+ (OCl)
397
+ (OI)
398
+ (ON)
399
+ (OO)
400
+ (OOC)
401
+ (OOCC)
402
+ (OOSN)
403
+ (OSC)
404
+ (P)
405
+ (S)
406
+ (SC#N)
407
+ (SC)
408
+ (SCC)
409
+ (SCCC)
410
+ (SCCF)
411
+ (SCCN)
412
+ (SCCO)
413
+ (SCCS)
414
+ (SCCl)
415
+ (SCF)
416
+ (SCN)
417
+ (SCOC)
418
+ (SCSC)
419
+ (SCl)
420
+ (SI)
421
+ (SN)
422
+ (SN=O)
423
+ (SO)
424
+ (SOC)
425
+ (SOOO)
426
+ (SS)
427
+ (SSC)
428
+ (SSCC)
429
+ ([At])
430
+ ([O-])
431
+ ([O])
432
+ ([S-])
433
+ (\\Br)
434
+ (\\C#N)
435
+ (\\C)
436
+ (\\C=N)
437
+ (\\C=O)
438
+ (\\CBr)
439
+ (\\CC)
440
+ (\\CCC)
441
+ (\\CCO)
442
+ (\\CCl)
443
+ (\\CF)
444
+ (\\CN)
445
+ (\\CNC)
446
+ (\\CO)
447
+ (\\COC)
448
+ (\\Cl)
449
+ (\\F)
450
+ (\\I)
451
+ (\\N)
452
+ (\\NC)
453
+ (\\NCC)
454
+ (\\NN)
455
+ (\\NO)
456
+ (\\NOC)
457
+ (\\O)
458
+ (\\OC)
459
+ (\\OCC)
460
+ (\\ON)
461
+ (\\S)
462
+ (\\SC)
463
+ (\\SCC)
464
+ [Ag+]
465
+ [Ag-4]
466
+ [Ag]
467
+ [Al-3]
468
+ [Al]
469
+ [As+]
470
+ [AsH3]
471
+ [AsH]
472
+ [As]
473
+ [At]
474
+ [B-]
475
+ [B@-]
476
+ [B@@-]
477
+ [BH-]
478
+ [BH2-]
479
+ [BH3-]
480
+ [B]
481
+ [Ba]
482
+ [Br+2]
483
+ [BrH]
484
+ [Br]
485
+ [C+]
486
+ [C-]
487
+ [C@@H]
488
+ [C@@]
489
+ [C@H]
490
+ [C@]
491
+ [CH-]
492
+ [CH2]
493
+ [CH3]
494
+ [CH]
495
+ [C]
496
+ [CaH2]
497
+ [Ca]
498
+ [Cl+2]
499
+ [Cl+3]
500
+ [Cl+]
501
+ [Cs]
502
+ [FH]
503
+ [F]
504
+ [H]
505
+ [He]
506
+ [I+2]
507
+ [I+3]
508
+ [I+]
509
+ [IH]
510
+ [I]
511
+ [K]
512
+ [Kr]
513
+ [Li+]
514
+ [LiH]
515
+ [MgH2]
516
+ [Mg]
517
+ [N+]
518
+ [N-]
519
+ [N@+]
520
+ [N@@+]
521
+ [N@@]
522
+ [N@]
523
+ [NH+]
524
+ [NH-]
525
+ [NH2+]
526
+ [NH3]
527
+ [NH]
528
+ [N]
529
+ [Na]
530
+ [O+]
531
+ [O-]
532
+ [OH+]
533
+ [OH2]
534
+ [OH]
535
+ [O]
536
+ [P+]
537
+ [P@+]
538
+ [P@@+]
539
+ [P@@]
540
+ [P@]
541
+ [PH2]
542
+ [PH]
543
+ [P]
544
+ [Ra]
545
+ [Rb]
546
+ [S+]
547
+ [S-]
548
+ [S@+]
549
+ [S@@+]
550
+ [S@@]
551
+ [S@]
552
+ [SH+]
553
+ [SH2]
554
+ [SH]
555
+ [S]
556
+ [Se+]
557
+ [Se-2]
558
+ [SeH2]
559
+ [SeH]
560
+ [Se]
561
+ [Si@]
562
+ [SiH2]
563
+ [SiH]
564
+ [Si]
565
+ [SrH2]
566
+ [TeH]
567
+ [Te]
568
+ [Xe]
569
+ [Zn+2]
570
+ [Zn-2]
571
+ [Zn]
572
+ [b-]
573
+ [c+]
574
+ [c-]
575
+ [cH-]
576
+ [cH]
577
+ [c]
578
+ [n+]
579
+ [n-]
580
+ [nH]
581
+ [n]
582
+ [o+]
583
+ [s+]
584
+ [se+]
585
+ [se]
586
+ [te+]
587
+ [te]
src/tokenizer/__init__.py ADDED
File without changes
src/tokenizer/my_tokenizers.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import logging
3
+ import os
4
+ import re
5
+ import codecs
6
+ import unicodedata
7
+ from typing import List, Optional
8
+ from transformers import PreTrainedTokenizer
9
+ from SmilesPE.tokenizer import SPE_Tokenizer
10
+ import torch
11
+
12
def load_vocab(vocab_file):
    """Read a vocabulary file and map each token to its line index.

    Each line of the file is one token; the token's id is its 0-based line
    number. Only the trailing newline is stripped, so tokens containing
    other whitespace survive intact.
    """
    with open(vocab_file, "r", encoding="utf-8") as handle:
        lines = handle.readlines()
    return collections.OrderedDict(
        (line.rstrip("\n"), idx) for idx, line in enumerate(lines)
    )
21
+
22
class Atomwise_Tokenizer(object):
    """Atom-level SMILES tokenizer.

    Splits a SMILES string into atom-level units (bracket atoms, two-letter
    halogens Br/Cl, short parenthesised groups of up to four characters,
    ring-closure digits, bond symbols, ...) with a single compiled regular
    expression.
    """

    def __init__(self):
        """Compile the atom-level tokenization pattern."""
        self.regex_pattern = r"(\([^\(\)]{0,4}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/\/?|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
        self.regex = re.compile(self.regex_pattern)

    def tokenize(self, text):
        """Return the list of atom-level tokens found in *text*."""
        return list(self.regex.findall(text))
38
+
39
+ class SMILES_SPE_Tokenizer(PreTrainedTokenizer):
40
+ r"""
41
+ Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
42
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
43
+ should refer to the superclass for more information regarding methods.
44
+ Args:
45
+ vocab_file (:obj:`string`):
46
+ File containing the vocabulary.
47
+ spe_file (:obj:`string`):
48
+ File containing the trained SMILES Pair Encoding vocabulary.
49
+ unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
50
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
51
+ token instead.
52
+ sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
53
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
54
+ for sequence classification or for a text and a question for question answering.
55
+ It is also used as the last token of a sequence built with special tokens.
56
+ pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
57
+ The token used for padding, for example when batching sequences of different lengths.
58
+ cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
59
+ The classifier token which is used when doing sequence classification (classification of the whole
60
+ sequence instead of per-token classification). It is the first token of the sequence when built with
61
+ special tokens.
62
+ mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
63
+ The token used for masking values. This is the token used when training this model with masked language
64
+ modeling. This is the token which the model will try to predict.
65
+ """
66
+
67
+ def __init__(self, vocab_file, spe_file,
68
+ unk_token="[UNK]",
69
+ sep_token="[SEP]",
70
+ pad_token="[PAD]",
71
+ cls_token="[CLS]",
72
+ mask_token="[MASK]",
73
+ **kwargs):
74
+ if not os.path.isfile(vocab_file):
75
+ raise ValueError("Can't find a vocabulary file at path '{}'.".format(vocab_file))
76
+ if not os.path.isfile(spe_file):
77
+ raise ValueError("Can't find a SPE vocabulary file at path '{}'.".format(spe_file))
78
+
79
+ self.vocab = load_vocab(vocab_file)
80
+ self.spe_vocab = open(spe_file, 'r', encoding='utf-8')
81
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
82
+ self.spe_tokenizer = SPE_Tokenizer(self.spe_vocab)
83
+
84
+ super().__init__(
85
+ unk_token=unk_token,
86
+ sep_token=sep_token,
87
+ pad_token=pad_token,
88
+ cls_token=cls_token,
89
+ mask_token=mask_token,
90
+ **kwargs)
91
+
92
+ @property
93
+ def vocab_size(self):
94
+ return len(self.vocab)
95
+
96
+ def get_vocab(self):
97
+ return dict(self.vocab, **self.added_tokens_encoder)
98
+
99
+ def _tokenize(self, text):
100
+ return self.spe_tokenizer.tokenize(text).split(' ')
101
+
102
+ def _convert_token_to_id(self, token):
103
+ """ Converts a token (str) in an id using the vocab. """
104
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
105
+
106
+ # changed encode and decode functions
107
    def encode(self, token_array):
        """Encode a pre-tokenized SMILES into model-ready tensors.

        NOTE: overrides the transformers ``encode`` signature -- it takes an
        iterable of string tokens (already SPE-tokenized), not raw text.

        Args:
            token_array: iterable of string tokens.

        Returns:
            dict with 'input_ids' of shape [1, len(token_array) + 2] and an
            all-ones 'attention_mask' of the same shape (torch tensors).
        """
        token_ids = []
        # 2 is the hard-coded [CLS] id and 3 the [SEP] id; these match the
        # positions of "[CLS]"/"[SEP]" in the bundled vocab file.
        token_ids.append(2)
        for token in token_array:
            id = self._convert_token_to_id(token)
            token_ids.append(id)
        token_ids.append(3)
        # Wrap in a batch dimension of 1 to mimic tokenizer __call__ output.
        token_ids = torch.tensor([token_ids])
        attn_mask = torch.ones_like(token_ids)
        return {'input_ids': token_ids, 'attention_mask': attn_mask}
117
+
118
    def decode(self, token_ids, skip_special_tokens=True):
        """Decode an id tensor back into a SMILES string.

        Decoding stops at the first id 3 ([SEP] in the bundled vocab); when
        skip_special_tokens is True, ids in self.all_special_ids are dropped.

        Args:
            token_ids (torch.Tensor): id tensor, optionally with a leading
                batch dimension of 1 (it is squeezed away).
            skip_special_tokens (bool): drop special-token ids.

        Returns:
            str: concatenated token strings (no separators).
        """
        token_ids = token_ids.squeeze(0).cpu().tolist()
        token_array = []
        for idx in token_ids:
            if idx == 3:  # Stop decoding when token ID 3 is encountered
                break
            if skip_special_tokens and idx in self.all_special_ids:
                continue
            token = self._convert_id_to_token(idx)
            token_array.append(token)
        sequence = "".join(token_array)
        return sequence
130
+
131
+ def batch_decode(self, batch_token_ids, skip_special_tokens=True):
132
+ sequences = []
133
+ for token_ids in batch_token_ids:
134
+ sequences.append(self.decode(token_ids))
135
+ return sequences
136
+
137
+ def get_token_split(self, token_ids):
138
+ if isinstance(token_ids, torch.Tensor):
139
+ token_ids = token_ids.cpu().tolist()
140
+
141
+ token_array = []
142
+ for seq_ids in token_ids:
143
+ seq_array = []
144
+ for id in seq_ids:
145
+ token = self._convert_id_to_token(id)
146
+ seq_array.append(token)
147
+ token_array.append(seq_array)
148
+
149
+ return token_array
150
+
151
+ def _convert_id_to_token(self, index):
152
+ """Converts an index (integer) in a token (str) using the vocab."""
153
+ return self.ids_to_tokens.get(index, self.unk_token)
154
+
155
+ def convert_tokens_to_string(self, tokens):
156
+ """ Converts a sequence of tokens (string) in a single string. """
157
+ out_string = " ".join(tokens).replace(" ##", "").strip()
158
+ return out_string
159
+
160
+ def build_inputs_with_special_tokens(
161
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
162
+ ) -> List[int]:
163
+ """
164
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
165
+ by concatenating and adding special tokens.
166
+ A BERT sequence has the following format:
167
+ - single sequence: ``[CLS] X [SEP]``
168
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
169
+ Args:
170
+ token_ids_0 (:obj:`List[int]`):
171
+ List of IDs to which the special tokens will be added
172
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
173
+ Optional second list of IDs for sequence pairs.
174
+ Returns:
175
+ :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
176
+ """
177
+ if token_ids_1 is None:
178
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
179
+ cls = [self.cls_token_id]
180
+ sep = [self.sep_token_id]
181
+ return cls + token_ids_0 + sep + token_ids_1 + sep
182
+
183
+ def get_special_tokens_mask(
184
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
185
+ ) -> List[int]:
186
+ """
187
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
188
+ special tokens using the tokenizer ``prepare_for_model`` method.
189
+ Args:
190
+ token_ids_0 (:obj:`List[int]`):
191
+ List of ids.
192
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
193
+ Optional second list of IDs for sequence pairs.
194
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
195
+ Set to True if the token list is already formatted with special tokens for the model
196
+ Returns:
197
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
198
+ """
199
+
200
+ if already_has_special_tokens:
201
+ if token_ids_1 is not None:
202
+ raise ValueError(
203
+ "You should not supply a second sequence if the provided sequence of "
204
+ "ids is already formated with special tokens for the model."
205
+ )
206
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
207
+
208
+ if token_ids_1 is not None:
209
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
210
+ return [1] + ([0] * len(token_ids_0)) + [1]
211
+
212
+ def create_token_type_ids_from_sequences(
213
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
214
+ ) -> List[int]:
215
+ """
216
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
217
+ A BERT sequence pair mask has the following format:
218
+ ::
219
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
220
+ | first sequence | second sequence |
221
+ if token_ids_1 is None, only returns the first portion of the mask (0's).
222
+ Args:
223
+ token_ids_0 (:obj:`List[int]`):
224
+ List of ids.
225
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
226
+ Optional second list of IDs for sequence pairs.
227
+ Returns:
228
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
229
+ sequence(s).
230
+ """
231
+ sep = [self.sep_token_id]
232
+ cls = [self.cls_token_id]
233
+ if token_ids_1 is None:
234
+ return len(cls + token_ids_0 + sep) * [0]
235
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
236
+
237
+ def save_vocabulary(self, vocab_path):
238
+ """
239
+ Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
240
+ Args:
241
+ vocab_path (:obj:`str`):
242
+ The directory in which to save the vocabulary.
243
+ Returns:
244
+ :obj:`Tuple(str)`: Paths to the files saved.
245
+ """
246
+ index = 0
247
+ if os.path.isdir(vocab_path):
248
+ vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
249
+ else:
250
+ vocab_file = vocab_path
251
+ with open(vocab_file, "w", encoding="utf-8") as writer:
252
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
253
+ if index != token_index:
254
+ logger.warning(
255
+ "Saving vocabulary to {}: vocabulary indices are not consecutive."
256
+ " Please check that the vocabulary is not corrupted!".format(vocab_file)
257
+ )
258
+ index = token_index
259
+ writer.write(token + "\n")
260
+ index += 1
261
+ return (vocab_file,)
262
+
263
class SMILES_Atomwise_Tokenizer(PreTrainedTokenizer):
    r"""
    Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
    should refer to the superclass for more information regarding methods.
    Args:
        vocab_file (:obj:`string`):
            File containing the vocabulary.
        unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
            for sequence classification or for a text and a question for question answering.
            It is also used as the last token of a sequence built with special tokens.
        pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
            The classifier token which is used when doing sequence classification (classification of the whole
            sequence instead of per-token classification). It is the first token of the sequence when built with
            special tokens.
        mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
    """

    def __init__(
        self,
        vocab_file,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        **kwargs
    ):
        # NOTE(review): super().__init__ runs before self.vocab is assigned; on
        # recent transformers versions PreTrainedTokenizer.__init__ may call
        # get_vocab(), which would then fail — confirm the pinned transformers
        # version tolerates this ordering.
        super().__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs,
        )

        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'.".format(vocab_file)
            )
        # token -> id mapping loaded from one-token-per-line vocab file.
        self.vocab = load_vocab(vocab_file)
        # Inverse mapping, id -> token, preserving vocab order.
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        # Character/atom-level splitter that produces the raw token stream.
        self.tokenizer = Atomwise_Tokenizer()

    @property
    def vocab_size(self):
        """Number of entries in the base vocabulary (added tokens excluded)."""
        return len(self.vocab)

    def get_vocab(self):
        """Return the full token -> id map, including tokens added after init."""
        return dict(self.vocab, **self.added_tokens_encoder)


    def _tokenize(self, text):
        """Split *text* into atom-level tokens via the wrapped Atomwise_Tokenizer."""
        return self.tokenizer.tokenize(text)

    def _convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab. """
        # Unknown tokens fall back to the [UNK] id.
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string. """
        # '##' continuation pieces are merged onto the preceding token.
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks
        by concatenating and adding special tokens.
        A BERT sequence has the following format:
        - single sequence: ``[CLS] X [SEP]``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` method.
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Set to True if the token list is already formatted with special tokens for the model
        Returns:
            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formated with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        # Mask mirrors the layout produced by build_inputs_with_special_tokens.
        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
        A BERT sequence pair mask has the following format:
        ::
            0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
            | first sequence | second sequence |
        if token_ids_1 is None, only returns the first portion of the mask (0's).
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
            sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, vocab_path):
        """
        Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
        Args:
            vocab_path (:obj:`str`):
                The directory in which to save the vocabulary.
        Returns:
            :obj:`Tuple(str)`: Paths to the files saved.
        """
        index = 0
        # Accept either a directory (canonical file name) or an explicit file path.
        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
        else:
            vocab_file = vocab_path
        with open(vocab_file, "w", encoding="utf-8") as writer:
            # One token per line, ordered by id; warn if ids are not consecutive,
            # since line number then no longer equals index on reload.
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        "Saving vocabulary to {}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!".format(vocab_file)
                    )
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)
src/tokenizer/new_splits.txt ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ c 1
2
+ c 2
3
+ c 3
4
+ c 4
5
+ c 5
6
+ c 6
7
+ c 7
8
+ c 8
9
+ c 9
10
+ ( c1
11
+ ( c2
12
+ c1 )
13
+ c2 )
14
+ n 1
15
+ n 2
16
+ n 3
17
+ n 4
18
+ n 5
19
+ n 6
20
+ n 7
21
+ n 8
22
+ n 9
23
+ ( n1
24
+ ( n2
25
+ n1 )
26
+ n2 )
27
+ O 1
28
+ O 2
29
+ O 3
30
+ O 4
31
+ O 5
32
+ O 6
33
+ O 7
34
+ O 8
35
+ O 9
36
+ ( O1
37
+ ( O2
38
+ O2 )
39
+ O2 )
40
+ = O
41
+ = C
42
+ = c
43
+ = N
44
+ = n
45
+ =C C
46
+ =C N
47
+ =C c
48
+ =c c
49
+ =N C
50
+ =N c
51
+ =n C
52
+ =n c
53
+ # N
54
+ # C
55
+ #N C
56
+ #C C
57
+ #C N
58
+ #N N
59
+ ( C
60
+ C )
61
+ ( O
62
+ O )
63
+ ( N
64
+ N )
65
+ Br c
66
+ ( =O
67
+ (=O )
68
+ C (=O)
69
+ C =O
70
+ C =N
71
+ C #N
72
+ C #C
73
+ C C
74
+ CC C
75
+ CC N
76
+ CC O
77
+ CC S
78
+ CC c
79
+ CC n
80
+ C N
81
+ CN C
82
+ CN c
83
+ C O
84
+ CO C
85
+ CO N
86
+ CO c
87
+ C S
88
+ CS C
89
+ CS S
90
+ CS c
91
+ C c
92
+ Cl c
93
+ C n
94
+ F c
95
+ N C
96
+ NC C
97
+ NC c
98
+ N N
99
+ N O
100
+ N c
101
+ N n
102
+ O C
103
+ OC C
104
+ OC O
105
+ OC c
106
+ O N
107
+ O O
108
+ O c
109
+ S C
110
+ SC C
111
+ SC c
112
+ S S
113
+ S c
114
+ c c
115
+ cc c
116
+ cc n
117
+ cc o
118
+ cc s
119
+ cc cc
120
+ c n
121
+ cn c
122
+ cn n
123
+ c o
124
+ co c
125
+ c s
126
+ cs c
127
+ cs n
128
+ n c
129
+ nc c
130
+ nc n
131
+ nc o
132
+ nc s
133
+ n n
134
+ nn c
135
+ nn n
136
+ n o
137
+ no c
138
+ no n
139
+ n s
140
+ ns c
141
+ ns n
142
+ o c
143
+ oc c
144
+ o n
145
+ s c
146
+ sc c
147
+ sc n
148
+ s n
149
+ N P
150
+ P N
151
+ C P
152
+ P C
153
+ N S
154
+ S N
155
+ C S
156
+ S C
157
+ S P
158
+ P S
159
+ C I
src/tokenizer/new_vocab.txt ADDED
@@ -0,0 +1,587 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [PAD]
2
+ [UNK]
3
+ [CLS]
4
+ [SEP]
5
+ [MASK]
6
+ #
7
+ %
8
+ (
9
+ )
10
+ +
11
+ -
12
+ /
13
+ 0
14
+ 1
15
+ 2
16
+ 3
17
+ 4
18
+ 5
19
+ 6
20
+ 7
21
+ 8
22
+ 9
23
+ =
24
+ @
25
+ A
26
+ B
27
+ Br
28
+ Brc
29
+ C
30
+ CC
31
+ CCC
32
+ CCN
33
+ CCO
34
+ CCS
35
+ CCc
36
+ CCn
37
+ CN
38
+ CNC
39
+ CNc
40
+ CO
41
+ COC
42
+ CON
43
+ COc
44
+ CS
45
+ CSC
46
+ CSS
47
+ CSc
48
+ Cc
49
+ Cl
50
+ Clc
51
+ Cn
52
+ F
53
+ Fc
54
+ H
55
+ I
56
+ K
57
+ L
58
+ M
59
+ N
60
+ NC
61
+ NCC
62
+ NCc
63
+ NN
64
+ NO
65
+ Nc
66
+ Nn
67
+ O
68
+ OC
69
+ OCC
70
+ OCO
71
+ OCc
72
+ ON
73
+ OO
74
+ Oc
75
+ P
76
+ R
77
+ S
78
+ SC
79
+ SCC
80
+ SCc
81
+ SS
82
+ Sc
83
+ T
84
+ X
85
+ Z
86
+ [
87
+ \\
88
+ (/
89
+ ]
90
+ a
91
+ b
92
+ c
93
+ cc
94
+ ccc
95
+ cccc
96
+ ccn
97
+ cco
98
+ ccs
99
+ cn
100
+ cnc
101
+ cnn
102
+ co
103
+ coc
104
+ cs
105
+ csc
106
+ csn
107
+ e
108
+ g
109
+ i
110
+ l
111
+ n
112
+ nc
113
+ ncc
114
+ ncn
115
+ nco
116
+ ncs
117
+ nn
118
+ nnc
119
+ nnn
120
+ no
121
+ noc
122
+ non
123
+ ns
124
+ nsc
125
+ nsn
126
+ o
127
+ oc
128
+ occ
129
+ on
130
+ p
131
+ r
132
+ s
133
+ sc
134
+ scc
135
+ scn
136
+ sn
137
+ t
138
+ c1
139
+ c2
140
+ c3
141
+ c4
142
+ c5
143
+ c6
144
+ c7
145
+ c8
146
+ c9
147
+ n1
148
+ n2
149
+ n3
150
+ n4
151
+ n5
152
+ n6
153
+ n7
154
+ n8
155
+ n9
156
+ O1
157
+ O2
158
+ O3
159
+ O4
160
+ O5
161
+ O6
162
+ O7
163
+ O8
164
+ O9
165
+ (c1
166
+ (c2
167
+ c1)
168
+ c2)
169
+ (n1
170
+ (n2
171
+ n1)
172
+ n2)
173
+ (O1
174
+ (O2
175
+ O2)
176
+ =O
177
+ =C
178
+ =c
179
+ =N
180
+ =n
181
+ =CC
182
+ =CN
183
+ =Cc
184
+ =cc
185
+ =NC
186
+ =Nc
187
+ =nC
188
+ =nc
189
+ #C
190
+ #CC
191
+ #CN
192
+ #N
193
+ #NC
194
+ #NN
195
+ (C
196
+ C)
197
+ (O
198
+ O)
199
+ (N
200
+ N)
201
+ NP
202
+ PN
203
+ CP
204
+ PC
205
+ NS
206
+ SN
207
+ SP
208
+ PS
209
+ C(=O)
210
+ (/Br)
211
+ (/C#N)
212
+ (/C)
213
+ (/C=N)
214
+ (/C=O)
215
+ (/CBr)
216
+ (/CC)
217
+ (/CCC)
218
+ (/CCF)
219
+ (/CCN)
220
+ (/CCO)
221
+ (/CCl)
222
+ (/CI)
223
+ (/CN)
224
+ (/CO)
225
+ (/CS)
226
+ (/Cl)
227
+ (/F)
228
+ (/I)
229
+ (/N)
230
+ (/NC)
231
+ (/NCC)
232
+ (/NO)
233
+ (/O)
234
+ (/OC)
235
+ (/OCC)
236
+ (/S)
237
+ (/SC)
238
+ (=C)
239
+ (=C/C)
240
+ (=C/F)
241
+ (=C/I)
242
+ (=C/N)
243
+ (=C/O)
244
+ (=CBr)
245
+ (=CC)
246
+ (=CCF)
247
+ (=CCN)
248
+ (=CCO)
249
+ (=CCl)
250
+ (=CF)
251
+ (=CI)
252
+ (=CN)
253
+ (=CO)
254
+ (=C\\C)
255
+ (=C\\F)
256
+ (=C\\I)
257
+ (=C\\N)
258
+ (=C\\O)
259
+ (=N)
260
+ (=N/C)
261
+ (=N/N)
262
+ (=N/O)
263
+ (=NBr)
264
+ (=NC)
265
+ (=NCC)
266
+ (=NCl)
267
+ (=NN)
268
+ (=NO)
269
+ (=NOC)
270
+ (=N\\C)
271
+ (=N\\N)
272
+ (=N\\O)
273
+ (=O)
274
+ (=S)
275
+ (B)
276
+ (Br)
277
+ (C#C)
278
+ (C#CC)
279
+ (C#CI)
280
+ (C#CO)
281
+ (C#N)
282
+ (C#SN)
283
+ (C)
284
+ (C=C)
285
+ (C=CF)
286
+ (C=CI)
287
+ (C=N)
288
+ (C=NN)
289
+ (C=NO)
290
+ (C=O)
291
+ (C=S)
292
+ (CBr)
293
+ (CC#C)
294
+ (CC#N)
295
+ (CC)
296
+ (CC=C)
297
+ (CC=O)
298
+ (CCBr)
299
+ (CCC)
300
+ (CCCC)
301
+ (CCCF)
302
+ (CCCI)
303
+ (CCCN)
304
+ (CCCO)
305
+ (CCCS)
306
+ (CCCl)
307
+ (CCF)
308
+ (CCI)
309
+ (CCN)
310
+ (CCNC)
311
+ (CCNN)
312
+ (CCNO)
313
+ (CCO)
314
+ (CCOC)
315
+ (CCON)
316
+ (CCS)
317
+ (CCSC)
318
+ (CCl)
319
+ (CF)
320
+ (CI)
321
+ (CN)
322
+ (CN=O)
323
+ (CNC)
324
+ (CNCC)
325
+ (CNCO)
326
+ (CNN)
327
+ (CNNC)
328
+ (CNO)
329
+ (CNOC)
330
+ (CO)
331
+ (COC)
332
+ (COCC)
333
+ (COCI)
334
+ (COCN)
335
+ (COCO)
336
+ (COF)
337
+ (CON)
338
+ (COO)
339
+ (CS)
340
+ (CSC)
341
+ (CSCC)
342
+ (CSCF)
343
+ (CSO)
344
+ (Cl)
345
+ (F)
346
+ (I)
347
+ (N)
348
+ (N=N)
349
+ (N=NO)
350
+ (N=O)
351
+ (N=S)
352
+ (NBr)
353
+ (NC#N)
354
+ (NC)
355
+ (NC=N)
356
+ (NC=O)
357
+ (NC=S)
358
+ (NCBr)
359
+ (NCC)
360
+ (NCCC)
361
+ (NCCF)
362
+ (NCCN)
363
+ (NCCO)
364
+ (NCCS)
365
+ (NCCl)
366
+ (NCNC)
367
+ (NCO)
368
+ (NCS)
369
+ (NCl)
370
+ (NN)
371
+ (NN=O)
372
+ (NNC)
373
+ (NO)
374
+ (NOC)
375
+ (O)
376
+ (OC#N)
377
+ (OC)
378
+ (OC=C)
379
+ (OC=O)
380
+ (OC=S)
381
+ (OCBr)
382
+ (OCC)
383
+ (OCCC)
384
+ (OCCF)
385
+ (OCCI)
386
+ (OCCN)
387
+ (OCCO)
388
+ (OCCS)
389
+ (OCCl)
390
+ (OCF)
391
+ (OCI)
392
+ (OCO)
393
+ (OCOC)
394
+ (OCON)
395
+ (OCSC)
396
+ (OCl)
397
+ (OI)
398
+ (ON)
399
+ (OO)
400
+ (OOC)
401
+ (OOCC)
402
+ (OOSN)
403
+ (OSC)
404
+ (P)
405
+ (S)
406
+ (SC#N)
407
+ (SC)
408
+ (SCC)
409
+ (SCCC)
410
+ (SCCF)
411
+ (SCCN)
412
+ (SCCO)
413
+ (SCCS)
414
+ (SCCl)
415
+ (SCF)
416
+ (SCN)
417
+ (SCOC)
418
+ (SCSC)
419
+ (SCl)
420
+ (SI)
421
+ (SN)
422
+ (SN=O)
423
+ (SO)
424
+ (SOC)
425
+ (SOOO)
426
+ (SS)
427
+ (SSC)
428
+ (SSCC)
429
+ ([At])
430
+ ([O-])
431
+ ([O])
432
+ ([S-])
433
+ (\\Br)
434
+ (\\C#N)
435
+ (\\C)
436
+ (\\C=N)
437
+ (\\C=O)
438
+ (\\CBr)
439
+ (\\CC)
440
+ (\\CCC)
441
+ (\\CCO)
442
+ (\\CCl)
443
+ (\\CF)
444
+ (\\CN)
445
+ (\\CNC)
446
+ (\\CO)
447
+ (\\COC)
448
+ (\\Cl)
449
+ (\\F)
450
+ (\\I)
451
+ (\\N)
452
+ (\\NC)
453
+ (\\NCC)
454
+ (\\NN)
455
+ (\\NO)
456
+ (\\NOC)
457
+ (\\O)
458
+ (\\OC)
459
+ (\\OCC)
460
+ (\\ON)
461
+ (\\S)
462
+ (\\SC)
463
+ (\\SCC)
464
+ [Ag+]
465
+ [Ag-4]
466
+ [Ag]
467
+ [Al-3]
468
+ [Al]
469
+ [As+]
470
+ [AsH3]
471
+ [AsH]
472
+ [As]
473
+ [At]
474
+ [B-]
475
+ [B@-]
476
+ [B@@-]
477
+ [BH-]
478
+ [BH2-]
479
+ [BH3-]
480
+ [B]
481
+ [Ba]
482
+ [Br+2]
483
+ [BrH]
484
+ [Br]
485
+ [C+]
486
+ [C-]
487
+ [C@@H]
488
+ [C@@]
489
+ [C@H]
490
+ [C@]
491
+ [CH-]
492
+ [CH2]
493
+ [CH3]
494
+ [CH]
495
+ [C]
496
+ [CaH2]
497
+ [Ca]
498
+ [Cl+2]
499
+ [Cl+3]
500
+ [Cl+]
501
+ [Cs]
502
+ [FH]
503
+ [F]
504
+ [H]
505
+ [He]
506
+ [I+2]
507
+ [I+3]
508
+ [I+]
509
+ [IH]
510
+ [I]
511
+ [K]
512
+ [Kr]
513
+ [Li+]
514
+ [LiH]
515
+ [MgH2]
516
+ [Mg]
517
+ [N+]
518
+ [N-]
519
+ [N@+]
520
+ [N@@+]
521
+ [N@@]
522
+ [N@]
523
+ [NH+]
524
+ [NH-]
525
+ [NH2+]
526
+ [NH3]
527
+ [NH]
528
+ [N]
529
+ [Na]
530
+ [O+]
531
+ [O-]
532
+ [OH+]
533
+ [OH2]
534
+ [OH]
535
+ [O]
536
+ [P+]
537
+ [P@+]
538
+ [P@@+]
539
+ [P@@]
540
+ [P@]
541
+ [PH2]
542
+ [PH]
543
+ [P]
544
+ [Ra]
545
+ [Rb]
546
+ [S+]
547
+ [S-]
548
+ [S@+]
549
+ [S@@+]
550
+ [S@@]
551
+ [S@]
552
+ [SH+]
553
+ [SH2]
554
+ [SH]
555
+ [S]
556
+ [Se+]
557
+ [Se-2]
558
+ [SeH2]
559
+ [SeH]
560
+ [Se]
561
+ [Si@]
562
+ [SiH2]
563
+ [SiH]
564
+ [Si]
565
+ [SrH2]
566
+ [TeH]
567
+ [Te]
568
+ [Xe]
569
+ [Zn+2]
570
+ [Zn-2]
571
+ [Zn]
572
+ [b-]
573
+ [c+]
574
+ [c-]
575
+ [cH-]
576
+ [cH]
577
+ [c]
578
+ [n+]
579
+ [n-]
580
+ [nH]
581
+ [n]
582
+ [o+]
583
+ [s+]
584
+ [se+]
585
+ [se]
586
+ [te+]
587
+ [te]
src/train.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # direct reward backpropagation
2
+ from diffusion import Diffusion
3
+ from hydra import initialize, compose
4
+ from hydra.core.global_hydra import GlobalHydra
5
+ import numpy as np
6
+ from scipy.stats import pearsonr
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import argparse
10
+ import wandb
11
+ import os
12
+ import datetime
13
+ from finetune_peptides import finetune
14
+ from peptide_mcts import MCTS
15
+ from utils.utils import str2bool, set_seed
16
+ from scoring.scoring_functions import ScoringFunctions
17
+
18
+
19
# Command-line interface for PepTune fine-tuning.  Defaults are shown in
# --help via ArgumentDefaultsHelpFormatter.
argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# --- optimization / training loop ---
argparser.add_argument('--base_path', type=str, default='')
argparser.add_argument('--learning_rate', type=float, default=1e-4)
argparser.add_argument('--num_epochs', type=int, default=100)
argparser.add_argument('--num_accum_steps', type=int, default=4)
argparser.add_argument('--truncate_steps', type=int, default=50)
argparser.add_argument("--truncate_kl", type=str2bool, default=False)
argparser.add_argument('--gumbel_temp', type=float, default=1.0)
argparser.add_argument('--gradnorm_clip', type=float, default=1.0)
argparser.add_argument('--batch_size', type=int, default=32)
argparser.add_argument('--name', type=str, default='debug')
argparser.add_argument('--total_num_steps', type=int, default=128)
argparser.add_argument('--copy_flag_temp', type=float, default=None)
argparser.add_argument('--save_every_n_epochs', type=int, default=10)
argparser.add_argument('--alpha_schedule_warmup', type=int, default=0)
argparser.add_argument("--seed", type=int, default=0)
# new: run bookkeeping (names, device, output directory)
argparser.add_argument('--run_name', type=str, default='peptides')
argparser.add_argument("--device", default="cuda:0", type=str)
argparser.add_argument("--save_path_dir", default="/path/to/your/home/PepTune/checkpoints/", type=str)
# mcts: Monte-Carlo tree search hyperparameters
argparser.add_argument('--num_sequences', type=int, default=10)
argparser.add_argument('--num_children', type=int, default=50)
argparser.add_argument('--num_iter', type=int, default=30) # iterations of mcts
argparser.add_argument('--seq_length', type=int, default=200)
argparser.add_argument('--time_conditioning', action='store_true', default=False)
argparser.add_argument('--mcts_sampling', type=int, default=0) # for batched categorical sampling: '0' means gumbel noise
argparser.add_argument('--buffer_size', type=int, default=100)
argparser.add_argument('--wdce_num_replicates', type=int, default=16)
argparser.add_argument('--noise_removal', action='store_true', default=False)
argparser.add_argument('--grad_clip', action='store_true', default=False)
argparser.add_argument('--resample_every_n_step', type=int, default=10)
argparser.add_argument('--exploration', type=float, default=0.1)
argparser.add_argument('--reset_every_n_step', type=int, default=100)
argparser.add_argument('--alpha', type=float, default=0.01)
argparser.add_argument('--scalarization', type=str, default='sum')
# --no_mcts switches to direct-reward fine-tuning without tree search.
argparser.add_argument('--no_mcts', action='store_true', default=False)
argparser.add_argument("--centering", action='store_true', default=False)

# objectives: number of reward objectives and optional custom target protein
argparser.add_argument('--num_obj', type=int, default=5)
argparser.add_argument('--prot_seq', type=str, default=None)
argparser.add_argument('--prot_name', type=str, default=None)

args = argparser.parse_args()
print(args)
65
+
66
# pretrained model path (checkpoint shipped under <base_path>/checkpoints)
ckpt_path = f'{args.base_path}/checkpoints/peptune-pretrained.ckpt'

# reinitialize Hydra — clear any global state left by a previous initialize()
GlobalHydra.instance().clear()

# Initialize Hydra and compose the configuration
initialize(config_path="configs", job_name="load_model")
cfg = compose(config_name="peptune_config.yaml")
# Timestamp used to make MCTS run names unique.
curr_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
76
+
77
# proteins
# Candidate target protein sequences for binding-affinity scoring.
# NOTE(review): identities below are inferred from the variable names only
# (AMHR2, transferrin receptor, GFAP, GLP-1R, GLAST, NCAM, cereblon, an E3
# ligase, SKP2) — confirm against the original data source.
amhr = 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV'
tfr = 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF'
gfap = 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM'
glp1 = 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS'
glast = 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM'
ncam = 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF'
cereblon = 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL'
ligase = 'MASQPPEDTAESQASDELECKICYNRYNLKQRKPKVLECCHRVCAKCLYKIIDFGDSPQGVIVCPFCRFETCLPDDEVSSLPDDNNILVNLTCGGKGKKCLPENPTELLLTPKRLASLVSPSHTSSNCLVITIMEVQRESSPSLSSTPVVEFYRPASFDSVTTVSHNWTVWNCTSLLFQTSIRVLVWLLGLLYFSSLPLGIYLLVSKKVTLGVVFVSLVPSSLVILMVYGFCQCVCHEFLDCMAPPS'
skp2 = 'MHRKHLQEIPDLSSNVATSFTWGWDSSKTSELLSGMGVSALEKEEPDSENIPQELLSNLGHPESPPRKRLKSKGSDKDFVIVRRPKLNRENFPGVSWDSLPDELLLGIFSCLCLPELLKVSGVCKRWYRLASDESLWQTLDLTGKNLHPDVTGRLLSQGVIAFRCPRSFMDQPLAEHFSPFRVQHMDLSNSVIEVSTLHGILSQCSKLQNLSLEGLRLSDPIVNTLAKNSNLVRLNLSGCSGFSEFALQTLLSSCSRLDELNLSWCFDFTEKHVQVAVAHVSETITQLNLSGYRKNLQKSDLSTLVRRCPNLVHLDLSDSVMLKNDCFQEFFQLNYLQHLSLSRCYDIIPETLLELGEIPTLKTLQVFGIVPDGTLQLLKEALPHLQINCSHFTTIARPTIGNKKNQEIWGIKCRLTLQKPSCL'

# Use the user-supplied target if given; otherwise default to the
# transferrin receptor (tfr).
if args.prot_seq is not None:
    prot = args.prot_seq
    prot_name = args.prot_name
    filename = args.prot_name
else:
    prot = tfr
    prot_name = "tfr"
    filename = "tfr"
96
+
97
# Encode the key hyperparameters into the run name so checkpoints/logs are
# self-describing; MCTS runs additionally get a timestamp for uniqueness.
if args.no_mcts:
    args.run_name = f'{prot_name}_resample{args.resample_every_n_step}_no-mcts'
else:
    args.run_name = f'{prot_name}_resample{args.resample_every_n_step}_buffer{args.buffer_size}_numiter{args.num_iter}_children{args.num_children}_{curr_time}'

args.save_path = os.path.join(args.save_path_dir, args.run_name)
os.makedirs(args.save_path, exist_ok=True)
# wandb init
wandb.init(project='tree-multi', name=args.run_name, config=args, dir=args.save_path)

log_path = os.path.join(args.save_path, 'log.txt')

# Seed python/numpy/torch (including CUDA) for reproducibility.
set_seed(args.seed, use_cuda=True)

# Initialize the model: two copies of the same pretrained checkpoint — one
# trainable policy and one frozen reference ("eval" mode) for regularization.
policy_model = Diffusion.load_from_checkpoint(ckpt_path,
                                              config=cfg,
                                              mode="train",
                                              device=args.device,
                                              map_location=args.device)
pretrained = Diffusion.load_from_checkpoint(ckpt_path,
                                            config=cfg,
                                            mode="eval",
                                            device=args.device,
                                            map_location=args.device)
122
+
123
# define mcts
# Objectives optimized during fine-tuning: target binding plus four
# developability properties.
score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'permeability']

if args.no_mcts:
    # Direct-reward fine-tuning: score sequences with the raw scoring
    # functions, no tree search.
    reward_model = ScoringFunctions(score_func_names, prot_seqs=[prot], device=args.device)
    finetune(args, cfg, policy_model, reward_model=reward_model, mcts=None, pretrained=pretrained, filename=filename, prot_name=prot_name)
else:
    # MCTS-guided fine-tuning; the search tree owns the reward functions.
    # Fix: MCTS was previously also constructed unconditionally before this
    # branch and then discarded (rebuilt here, unused in the no-mcts path) —
    # that redundant, expensive construction is removed.
    mcts = MCTS(args, cfg, policy_model, pretrained, score_func_names, prot_seqs=[prot])
    finetune(args, cfg, policy_model, reward_model=mcts.rewardFunc, mcts=mcts, pretrained=None, filename=filename, prot_name=prot_name)
src/train_peptune.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ import sys
4
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
5
+
6
+ import wandb
7
+ import fsspec
8
+ import hydra
9
+ import lightning as L
10
+ from lightning.pytorch import Trainer
11
+ from lightning.pytorch.callbacks import ModelCheckpoint, GradientAccumulationScheduler
12
+ import omegaconf
13
+ import rich.syntax
14
+ import rich.tree
15
+ import torch
16
+ import sys
17
+ import torch.distributed as dist
18
+ from torch.nn.parallel import DistributedDataParallel as DDP
19
+ import dataloading_for_dynamic_batching as dynamic_dataloader
20
+ from diffusion import Diffusion
21
+ import utils.utils as utils
22
+
23
+ from lightning.pytorch.strategies import DDPStrategy
24
+ from datasets import load_dataset
25
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
26
+
27
+ omegaconf.OmegaConf.register_new_resolver('cwd', os.getcwd)
28
+ omegaconf.OmegaConf.register_new_resolver('device_count', torch.cuda.device_count)
29
+ omegaconf.OmegaConf.register_new_resolver('eval', eval)
30
+ omegaconf.OmegaConf.register_new_resolver('div_up', lambda x, y: (x + y - 1) // y)
31
+
32
def _load_from_checkpoint(config, tokenizer):
    """Build a Diffusion model, fresh or restored from a checkpoint.

    Args:
        config: composed Hydra config; ``config.backbone`` selects the path
            and ``config.eval.checkpoint_path`` locates the checkpoint.
        tokenizer: tokenizer instance handed to the model.

    Returns:
        A ``Diffusion`` model (moved to CUDA for the HuggingFace backbone).
    """
    # HuggingFace backbones are constructed fresh rather than restored.
    if 'hf' in config.backbone:
        return Diffusion(config, tokenizer=tokenizer).to('cuda')
    return Diffusion.load_from_checkpoint(
        config.eval.checkpoint_path,
        tokenizer=tokenizer,
        config=config)
43
+
44
@L.pytorch.utilities.rank_zero_only
def print_config(
    config: omegaconf.DictConfig,
    resolve: bool = True,
    save_cfg: bool = True) -> None:
    """
    Prints content of DictConfig using Rich library and its tree structure.

    Args:
        config (DictConfig): Configuration composed by Hydra.
        resolve (bool): Whether to resolve reference fields of DictConfig.
        save_cfg (bool): Whether to save the configuration tree to a file.
    """
    style = 'dim'
    tree = rich.tree.Tree('CONFIG', style=style, guide_style=style)

    for key in config.keys():
        branch = tree.add(key, style=style, guide_style=style)
        section = config.get(key)
        # Nested sections are rendered as YAML; scalars via plain str().
        if isinstance(section, omegaconf.DictConfig):
            rendered = omegaconf.OmegaConf.to_yaml(section, resolve=resolve)
        else:
            rendered = str(section)
        branch.add(rich.syntax.Syntax(rendered, 'yaml'))

    rich.print(tree)

    if save_cfg:
        # Persist the rendered tree next to the checkpoints for provenance.
        out_path = '{}/config_tree.txt'.format(config.checkpointing.save_dir)
        with fsspec.open(out_path, 'w') as fp:
            rich.print(tree, file=fp)
78
+
79
+
80
@L.pytorch.utilities.rank_zero_only
def print_batch(train_ds, valid_ds, tokenizer, k=64):
    """Log one training batch: tensor shape plus first/last *k* decoded tokens.

    ``valid_ds`` is accepted for signature compatibility but is currently not
    inspected (the valid-loader branch was disabled upstream).
    """
    loaders = [('train', train_ds)]
    for dl_type, loader in loaders:
        print(f'Printing {dl_type} dataloader batch.')
        batch = next(iter(loader))
        ids = batch['input_ids']
        print('Batch input_ids.shape', ids.shape)
        head = ids[0, :k]
        tail = ids[0, -k:]
        print(f'First {k} tokens:', tokenizer.decode(head))
        print('ids:', head)
        print(f'Last {k} tokens:', tokenizer.decode(tail))
        print('ids:', tail)
96
+
97
+
98
def generate_samples(config, logger, tokenizer):
    """Sample peptide sequences from a checkpointed model and report perplexity.

    Args:
        config: Hydra config; reads ``sampling.num_sample_batches`` and
            ``sampling.steps``.
        logger: logger used for status messages.
        tokenizer: tokenizer forwarded to model loading.

    Returns:
        Decoded peptide sequences from the final sampled batch, or an empty
        list when ``num_sample_batches`` is 0.
    """
    logger.info('Generating samples.')
    model = _load_from_checkpoint(config=config, tokenizer=tokenizer)

    # Bug fix: initialize so the return below cannot raise UnboundLocalError
    # when sampling.num_sample_batches == 0.
    peptide_sequences = []
    for _ in range(config.sampling.num_sample_batches):
        samples = model.restore_model_and_sample(num_steps=config.sampling.steps)
        peptide_sequences = model.tokenizer.batch_decode(samples)
        # Accumulates per-batch perplexity state on the model.
        model.compute_generative_perplexity(peptide_sequences)

    print('Peptide samples:', peptide_sequences)
    print('Generative perplexity:', model.compute_masked_perplexity())

    return peptide_sequences
116
+
117
+
118
def ppl_eval(config, logger, tokenizer, data_module):
    """Run zero-shot (perplexity) evaluation of a checkpointed model."""
    logger.info('Starting Zero Shot Eval.')

    model = _load_from_checkpoint(config=config, tokenizer=tokenizer)

    # Optional wandb logging, mirrored from the training path.
    wandb_logger = None
    if config.get('wandb', None) is not None:
        wandb_logger = L.pytorch.loggers.WandbLogger(
            config=omegaconf.OmegaConf.to_object(config),
            **config.wandb)

    callbacks = []
    if 'callbacks' in config:
        callbacks = [hydra.utils.instantiate(cb)
                     for _, cb in config.callbacks.items()]

    trainer = hydra.utils.instantiate(
        config.trainer,
        default_root_dir=os.getcwd(),
        callbacks=callbacks,
        strategy=DDPStrategy(find_unused_parameters=True),
        logger=wandb_logger)

    trainer.test(model, data_module)
144
+
145
+
146
def _train(config, logger, tokenizer, data_module):
    """Train (or resume training of) the diffusion model with Lightning DDP."""
    logger.info('Starting Training.')

    wandb_logger = None
    if config.get('wandb', None) is not None:
        # Suffix a UUID so repeated runs never collide on the same wandb id.
        config.wandb.id = f"{config.wandb.id}_{str(uuid.uuid4())}"
        wandb_logger = L.pytorch.loggers.WandbLogger(
            config=omegaconf.OmegaConf.to_object(config),
            **config.wandb)

    # Resume only when explicitly requested and the checkpoint actually exists.
    resume_path = None
    if (config.checkpointing.resume_from_ckpt
            and config.checkpointing.resume_ckpt_path is not None
            and utils.fsspec_exists(config.checkpointing.resume_ckpt_path)):
        resume_path = config.checkpointing.resume_ckpt_path

    # Lightning callbacks
    callbacks = []
    if 'callbacks' in config:
        for cb_name, cb_conf in config.callbacks.items():
            if cb_name == 'model_checkpoint':
                # Instantiate directly, stripping Hydra's _target_ key.
                cb_kwargs = {k: v for k, v in cb_conf.items() if k != '_target_'}
                callbacks.append(ModelCheckpoint(**cb_kwargs))
            else:
                callbacks.append(hydra.utils.instantiate(cb_conf))

    if config.training.accumulator:
        # Anneal gradient accumulation: 5 batches at epoch 1 down to 1 at epoch 4.
        callbacks.append(
            GradientAccumulationScheduler(scheduling={1: 5, 2: 4, 3: 3, 4: 1}))

    # NOTE(review): accelerator/devices are hardcoded here and override whatever
    # config.trainer specifies; devices=[2,3,4,5,6,7] assumes a specific
    # multi-GPU host -- confirm before reusing on other machines.
    trainer = hydra.utils.instantiate(
        config.trainer,
        default_root_dir=os.getcwd(),
        callbacks=callbacks,
        accelerator='cuda',
        strategy=DDPStrategy(find_unused_parameters=True),
        devices=[2, 3, 4, 5, 6, 7],
        logger=wandb_logger)

    model = Diffusion(config, tokenizer=tokenizer)

    if config.backbone == 'finetune_roformer':
        # Warm-start weights from the checkpoint before fine-tuning.
        state = torch.load(resume_path, map_location='cpu')
        model.load_state_dict(state['state_dict'])

    trainer.fit(model, datamodule=data_module, ckpt_path=resume_path)
197
+
198
@hydra.main(version_base=None, config_path=f'{os.getcwd()}/src', config_name='config')
def main(config):
    """
    Main entry point for training
    """
    wandb.init(project="peptune")
    L.seed_everything(config.seed)

    logger = utils.get_logger(__name__)

    # PeptideCLM SMILES tokenizer with the project's custom vocab/splits.
    tokenizer = SMILES_SPE_Tokenizer(
        f'{config.base_path}/src/tokenizer/new_vocab.txt',
        f'{config.base_path}/src/tokenizer/new_splits.txt')

    data_module = dynamic_dataloader.CustomDataModule(
        f'{config.base_path}/data/peptide_data', tokenizer)

    # Dispatch on run mode; anything unrecognized falls through to training.
    mode = config.mode
    if mode == 'sample_eval':
        generate_samples(config, logger, tokenizer)
    elif mode == 'ppl_eval':
        ppl_eval(config, logger, tokenizer, data_module)
    else:
        _train(config, logger, tokenizer, data_module)
223
+
224
+
225
# Hydra parses CLI overrides when main() is invoked as a script.
if __name__ == '__main__':
    main()
src/utils/app.py ADDED
@@ -0,0 +1,1255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import pandas as pd
4
+ from io import StringIO
5
+ import rdkit
6
+ from rdkit import Chem
7
+ from rdkit.Chem import AllChem, Draw
8
+ import numpy as np
9
+ from PIL import Image, ImageDraw, ImageFont
10
+ import matplotlib.pyplot as plt
11
+ import matplotlib.patches as patches
12
+ from io import BytesIO
13
+ import tempfile
14
+ from rdkit import Chem
15
+
16
+ class PeptideAnalyzer:
17
def __init__(self):
    """Set up backbone-bond regex patterns and the 3-to-1 residue code table."""
    # Ordered most-specific first; split_on_bonds claims matches in this order,
    # so e.g. the N-methylated amide wins over the plain peptide bond.
    self.bond_patterns = [
        (r'OC\(=O\)', 'ester'),  # Ester bond
        (r'N\(C\)C\(=O\)', 'n_methyl'),  # N-methylated peptide bond
        (r'N[0-9]C\(=O\)', 'proline'),  # Proline peptide bond
        (r'NC\(=O\)', 'peptide'),  # Standard peptide bond
        (r'C\(=O\)N\(C\)', 'n_methyl_reverse'),  # Reverse N-methylated
        (r'C\(=O\)N[12]?', 'peptide_reverse')  # Reverse peptide bond
    ]
    # Three to one letter code mapping
    self.three_to_one = {
        'Ala': 'A', 'Cys': 'C', 'Asp': 'D', 'Glu': 'E',
        'Phe': 'F', 'Gly': 'G', 'His': 'H', 'Ile': 'I',
        'Lys': 'K', 'Leu': 'L', 'Met': 'M', 'Asn': 'N',
        'Pro': 'P', 'Gln': 'Q', 'Arg': 'R', 'Ser': 'S',
        'Thr': 'T', 'Val': 'V', 'Trp': 'W', 'Tyr': 'Y'
    }
34
+
35
def is_peptide(self, smiles):
    """Return True when *smiles* contains a (possibly N-methylated) peptide bond."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return False

    # Standard backbone amide first, then the N-methylated variant.
    amide_smarts = ('[NH][C](=O)', '[N;H0;$(NC)](C)[C](=O)')
    for smarts in amide_smarts:
        if mol.HasSubstructMatch(Chem.MolFromSmarts(smarts)):
            return True

    return False
52
+
53
def is_cyclic(self, smiles):
    """Classify ring closures in *smiles* as peptide macrocycle vs aromatic.

    Returns:
        Tuple ``(is_cyclic, peptide_cycles, aromatic_cycles)``. A chain ending
        in a free C-terminal carboxyl is reported as linear immediately.
    """
    # A trailing free acid marks a linear peptide.
    if smiles.endswith('C(=O)O'):
        return False, [], []

    # Ring-closure digits not preceded by an aromatic carbon. Note that
    # re.findall returns the FULL match here (preceding atom char included,
    # e.g. 'C1'), which the comparison below relies on.
    ring_numbers = re.findall(r'(?:^|[^c])[0-9](?=[A-Z@\(\)])', smiles)

    # Digits belonging to benzene/indole-style aromatic rings.
    aromatic_cycles = []
    for fragment in re.findall(r'c[0-9](?:ccccc|c\[nH\]c)[0-9]', smiles):
        aromatic_cycles.extend(re.findall(r'[0-9]', fragment))

    # Everything not claimed by an aromatic ring counts as a peptide cycle.
    peptide_cycles = [num for num in ring_numbers if num not in aromatic_cycles]

    return bool(peptide_cycles), peptide_cycles, aromatic_cycles
74
+
75
def split_on_bonds(self, smiles):
    """Split a peptide SMILES into residue segments at backbone-bond matches.

    Returns a list of dicts, each with 'content' (the text between two bond
    matches) and the neighboring 'bond_before' / 'bond_after' pattern strings.
    Glycine ('NCC(=O)') is located first because its side-chain-free backbone
    would otherwise be consumed by the generic bond patterns.
    """
    positions = []
    used = set()  # character indices already claimed by an earlier match

    # Find Gly pattern first
    gly_pattern = r'NCC\(=O\)'
    for match in re.finditer(gly_pattern, smiles):
        # Skip matches overlapping an already-claimed span.
        if not any(p in range(match.start(), match.end()) for p in used):
            positions.append({
                'start': match.start(),
                'end': match.end(),
                'type': 'gly',
                'pattern': match.group()
            })
            used.update(range(match.start(), match.end()))

    # Claim remaining bond sites; pattern order (see __init__) makes the
    # more specific patterns win over the generic peptide bond.
    for pattern, bond_type in self.bond_patterns:
        for match in re.finditer(pattern, smiles):
            if not any(p in range(match.start(), match.end()) for p in used):
                positions.append({
                    'start': match.start(),
                    'end': match.end(),
                    'type': bond_type,
                    'pattern': match.group()
                })
                used.update(range(match.start(), match.end()))

    # Sort by position
    positions.sort(key=lambda x: x['start'])

    # Create segments
    segments = []

    if positions:
        # First segment: anything before the first bond match.
        if positions[0]['start'] > 0:
            segments.append({
                'content': smiles[0:positions[0]['start']],
                'bond_after': positions[0]['pattern']
            })

        # Process segments between consecutive bond matches.
        for i in range(len(positions)-1):
            current = positions[i]
            next_pos = positions[i+1]

            if current['type'] == 'gly':
                # The Gly match itself IS the residue: emit canonical backbone.
                segments.append({
                    'content': 'NCC(=O)',
                    'bond_before': positions[i-1]['pattern'] if i > 0 else None,
                    'bond_after': next_pos['pattern']
                })
            else:
                content = smiles[current['end']:next_pos['start']]
                if content:
                    segments.append({
                        'content': content,
                        'bond_before': current['pattern'],
                        'bond_after': next_pos['pattern']
                    })

        # Last segment: anything after the final bond match.
        # NOTE(review): a trailing 'gly' match never emits its own segment here
        # (the loop above stops before the last position) -- confirm intent.
        if positions[-1]['end'] < len(smiles):
            segments.append({
                'content': smiles[positions[-1]['end']:],
                'bond_before': positions[-1]['pattern']
            })

    return segments
145
+
146
def clean_terminal_carboxyl(self, segment):
    """Strip a C-terminal carboxyl from the final segment's content.

    Fires only when the segment has no ``bond_after`` (i.e. it is the last
    segment) and its content contains ``C(=O)O``; removes the parenthesised
    ``(C(=O)O)`` group and any empty parentheses that removal leaves behind.
    A bare (un-parenthesised) trailing ``C(=O)O`` is left untouched, matching
    the original substitution pattern.

    Args:
        segment: dict with 'content' and optional 'bond_after' keys.

    Returns:
        The cleaned (or unchanged) content string.
    """
    content = segment['content']

    if 'C(=O)O' in content and not segment.get('bond_after'):
        # Drop the parenthesised carboxyl group wherever it occurs...
        cleaned = re.sub(r'\(C\(=O\)O\)', '', content)
        # ...and collapse any now-empty parentheses.
        # (Leftover debug print statements removed.)
        return re.sub(r'\(\)', '', cleaned)

    return content
163
+
164
+ def identify_residue(self, segment):
165
+ """Identify residue with Pro reconstruction"""
166
+ # Only clean terminal carboxyl if this is the last segment
167
+ content = self.clean_terminal_carboxyl(segment)
168
+ mods = self.get_modifications(segment)
169
+
170
+ # UAA pattern matching section - before regular residues
171
+ # Phenylglycine and derivatives
172
+ if 'c1ccccc1' in content:
173
+ if '[C@@H](c1ccccc1)' in content or '[C@H](c1ccccc1)' in content:
174
+ return '4', mods # Base phenylglycine
175
+
176
+ # 4-substituted phenylalanines
177
+ if 'Cc1ccc' in content:
178
+ if 'OMe' in content or 'OCc1ccc' in content:
179
+ return '0A1', mods # 4-methoxy-Phenylalanine
180
+ elif 'Clc1ccc' in content:
181
+ return '200', mods # 4-chloro-Phenylalanine
182
+ elif 'Brc1ccc' in content:
183
+ return '4BF', mods # 4-Bromo-phenylalanine
184
+ elif 'C#Nc1ccc' in content:
185
+ return '4CF', mods # 4-cyano-phenylalanine
186
+ elif 'Ic1ccc' in content:
187
+ return 'PHI', mods # 4-Iodo-phenylalanine
188
+ elif 'Fc1ccc' in content:
189
+ return 'PFF', mods # 4-Fluoro-phenylalanine
190
+
191
+ # Modified tryptophans
192
+ if 'c[nH]c2' in content:
193
+ if 'Oc2cccc2' in content:
194
+ return '0AF', mods # 7-hydroxy-tryptophan
195
+ elif 'Fc2cccc2' in content:
196
+ return '4FW', mods # 4-fluoro-tryptophan
197
+ elif 'Clc2cccc2' in content:
198
+ return '6CW', mods # 6-chloro-tryptophan
199
+ elif 'Brc2cccc2' in content:
200
+ return 'BTR', mods # 6-bromo-tryptophan
201
+ elif 'COc2cccc2' in content:
202
+ return 'MOT5', mods # 5-Methoxy-tryptophan
203
+ elif 'Cc2cccc2' in content:
204
+ return 'MTR5', mods # 5-Methyl-tryptophan
205
+
206
+ # Special amino acids
207
+ if 'CC(C)(C)[C@@H]' in content or 'CC(C)(C)[C@H]' in content:
208
+ return 'BUG', mods # Tertleucine
209
+
210
+ if 'CCCNC(=N)N' in content:
211
+ return 'CIR', mods # Citrulline
212
+
213
+ if '[SeH]' in content:
214
+ return 'CSE', mods # Selenocysteine
215
+
216
+ if '[NH3]CC[C@@H]' in content or '[NH3]CC[C@H]' in content:
217
+ return 'DAB', mods # Diaminobutyric acid
218
+
219
+ if 'C1CCCCC1' in content:
220
+ if 'C1CCCCC1[C@@H]' in content or 'C1CCCCC1[C@H]' in content:
221
+ return 'CHG', mods # Cyclohexylglycine
222
+ elif 'C1CCCCC1C[C@@H]' in content or 'C1CCCCC1C[C@H]' in content:
223
+ return 'ALC', mods # 3-cyclohexyl-alanine
224
+
225
+ # Naphthalene derivatives
226
+ if 'c1cccc2c1cccc2' in content:
227
+ if 'c1cccc2c1cccc2[C@@H]' in content or 'c1cccc2c1cccc2[C@H]' in content:
228
+ return 'NAL', mods # 2-Naphthyl-alanine
229
+
230
+ # Heteroaromatic derivatives
231
+ if 'c1cncc' in content:
232
+ return 'PYR4', mods # 3-(4-Pyridyl)-alanine
233
+ if 'c1cscc' in content:
234
+ return 'THA3', mods # 3-(3-thienyl)-alanine
235
+ if 'c1nnc' in content:
236
+ return 'TRZ4', mods # 3-(1,2,4-Triazol-1-yl)-alanine
237
+
238
+ # Modified serines and threonines
239
+ if 'OP(O)(O)O' in content:
240
+ if '[C@@H](COP' in content or '[C@H](COP' in content:
241
+ return 'SEP', mods # phosphoserine
242
+ elif '[C@@H](OP' in content or '[C@H](OP' in content:
243
+ return 'TPO', mods # phosphothreonine
244
+
245
+ # Specialized ring systems
246
+ if 'c1c2ccccc2cc2c1cccc2' in content:
247
+ return 'ANTH', mods # 3-(9-anthryl)-alanine
248
+ if 'c1csc2c1cccc2' in content:
249
+ return 'BTH3', mods # 3-(3-benzothienyl)-alanine
250
+ if '[C@]12C[C@H]3C[C@@H](C2)C[C@@H](C1)C3' in content:
251
+ return 'ADAM', mods # Adamanthane
252
+
253
+ # Fluorinated derivatives
254
+ if 'FC(F)(F)' in content:
255
+ if 'CC(F)(F)F' in content:
256
+ return 'FLA', mods # Trifluoro-alanine
257
+ if 'C(F)(F)F)c1' in content:
258
+ if 'c1ccccc1C(F)(F)F' in content:
259
+ return 'TFG2', mods # 2-(Trifluoromethyl)-phenylglycine
260
+ if 'c1cccc(c1)C(F)(F)F' in content:
261
+ return 'TFG3', mods # 3-(Trifluoromethyl)-phenylglycine
262
+ if 'c1ccc(cc1)C(F)(F)F' in content:
263
+ return 'TFG4', mods # 4-(Trifluoromethyl)-phenylglycine
264
+
265
+ # Multiple halogen patterns
266
+ if 'F' in content and 'c1' in content:
267
+ if 'c1ccc(c(c1)F)F' in content:
268
+ return 'F2F', mods # 3,4-Difluoro-phenylalanine
269
+ if 'cc(F)cc(c1)F' in content:
270
+ return 'WFP', mods # 3,5-Difluoro-phenylalanine
271
+ if 'Cl' in content and 'c1' in content:
272
+ if 'c1ccc(cc1Cl)Cl' in content:
273
+ return 'CP24', mods # 2,4-dichloro-phenylalanine
274
+ if 'c1ccc(c(c1)Cl)Cl' in content:
275
+ return 'CP34', mods # 3,4-dichloro-phenylalanine
276
+
277
+ # Hydroxy and amino derivatives
278
+ if 'O' in content and 'c1' in content:
279
+ if 'c1cc(O)cc(c1)O' in content:
280
+ return '3FG', mods # (2s)-amino(3,5-dihydroxyphenyl)-ethanoic acid
281
+ if 'c1ccc(c(c1)O)O' in content:
282
+ return 'DAH', mods # 3,4-Dihydroxy-phenylalanine
283
+
284
+ # Cyclic amino acids
285
+ if 'C1CCCC1' in content:
286
+ return 'CPA3', mods # 3-Cyclopentyl-alanine
287
+ if 'C1CCCCC1' in content:
288
+ if 'CC1CCCCC1' in content:
289
+ return 'ALC', mods # 3-cyclohexyl-alanine
290
+ else:
291
+ return 'CHG', mods # Cyclohexylglycine
292
+
293
+ # Chain-length variants
294
+ if 'CCC[C@@H]' in content or 'CCC[C@H]' in content:
295
+ return 'NLE', mods # Norleucine
296
+ if 'CC[C@@H]' in content or 'CC[C@H]' in content:
297
+ if not any(x in content for x in ['CC(C)', 'COC', 'CN(']):
298
+ return 'ABA', mods # 2-Aminobutyric acid
299
+
300
+ # Modified histidines
301
+ if 'c1cnc' in content:
302
+ if '[C@@H]1CN[C@@H](N1)F' in content:
303
+ return '2HF', mods # 2-fluoro-l-histidine
304
+ if 'c1cnc([nH]1)F' in content:
305
+ return '2HF1', mods # 2-fluoro-l-histidine variant
306
+ if 'c1c[nH]c(n1)F' in content:
307
+ return '2HF2', mods # 2-fluoro-l-histidine variant
308
+
309
+ # Sulfur and selenium containing
310
+ if '[SeH]' in content:
311
+ return 'CSE', mods # Selenocysteine
312
+ if 'S' in content:
313
+ if 'CSCc1ccccc1' in content:
314
+ return 'BCS', mods # benzylcysteine
315
+ if 'CCSC' in content:
316
+ return 'ESC', mods # Ethionine
317
+ if 'CCS' in content:
318
+ return 'HCS', mods # homocysteine
319
+
320
+ # Additional modifications
321
+ if 'CN=[N]=N' in content:
322
+ return 'AZDA', mods # azido-alanine
323
+ if '[NH]=[C](=[NH2])=[NH2]' in content:
324
+ if 'CCC[NH]=' in content:
325
+ return 'AGM', mods # 5-methyl-arginine
326
+ if 'CC[NH]=' in content:
327
+ return 'GDPR', mods # 2-Amino-3-guanidinopropionic acid
328
+
329
+ if 'CCON' in content:
330
+ return 'CAN', mods # canaline
331
+ if '[C@@H]1C=C[C@@H](C=C1)' in content:
332
+ return 'ACZ', mods # cis-amiclenomycin
333
+ if 'CCC(=O)[NH3]' in content:
334
+ return 'ONL', mods # 5-oxo-l-norleucine
335
+ if 'c1ccncc1' in content:
336
+ return 'PYR4', mods # 3-(4-Pyridyl)-alanine
337
+ if 'c1ccco1' in content:
338
+ return 'FUA2', mods # (2-furyl)-alanine
339
+
340
+ if 'c1ccc' in content:
341
+ if 'c1ccc(cc1)c1ccccc1' in content:
342
+ return 'BIF', mods # 4,4-biphenylalanine
343
+ if 'c1ccc(cc1)C(=O)c1ccccc1' in content:
344
+ return 'PBF', mods # 4-benzoyl-phenylalanine
345
+ if 'c1ccc(cc1)C(C)(C)C' in content:
346
+ return 'TBP4', mods # 4-tert-butyl-phenylalanine
347
+ if 'c1ccc(cc1)[C](=[NH2])=[NH2]' in content:
348
+ return '0BN', mods # 4-carbamimidoyl-l-phenylalanine
349
+ if 'c1cccc(c1)[C](=[NH2])=[NH2]' in content:
350
+ return 'APM', mods # m-amidinophenyl-3-alanine
351
+
352
+ # Multiple hydroxy patterns
353
+ if 'O' in content:
354
+ if '[C@H]([C@H](C)O)O' in content:
355
+ return 'ILX', mods # 4,5-dihydroxy-isoleucine
356
+ if '[C@H]([C@@H](C)O)O' in content:
357
+ return 'ALO', mods # Allo-threonine
358
+ if '[C@H](COP(O)(O)O)' in content:
359
+ return 'SEP', mods # phosphoserine
360
+ if '[C@H]([C@@H](C)OP(O)(O)O)' in content:
361
+ return 'TPO', mods # phosphothreonine
362
+ if '[C@H](c1ccc(O)cc1)O' in content:
363
+ return 'OMX', mods # (betar)-beta-hydroxy-l-tyrosine
364
+ if '[C@H](c1ccc(c(Cl)c1)O)O' in content:
365
+ return 'OMY', mods # (betar)-3-chloro-beta-hydroxy-l-tyrosine
366
+
367
+ # Heterocyclic patterns
368
+ if 'n1' in content:
369
+ if 'n1cccn1' in content:
370
+ return 'PYZ1', mods # 3-(1-Pyrazolyl)-alanine
371
+ if 'n1nncn1' in content:
372
+ return 'TEZA', mods # 3-(2-Tetrazolyl)-alanine
373
+ if 'c2c(n1)cccc2' in content:
374
+ return 'QU32', mods # 3-(2-Quinolyl)-alanine
375
+ if 'c1cnc2c(c1)cccc2' in content:
376
+ return 'QU33', mods # 3-(3-quinolyl)-alanine
377
+ if 'c1ccnc2c1cccc2' in content:
378
+ return 'QU34', mods # 3-(4-quinolyl)-alanine
379
+ if 'c1ccc2c(c1)nccc2' in content:
380
+ return 'QU35', mods # 3-(5-Quinolyl)-alanine
381
+ if 'c1ccc2c(c1)cncc2' in content:
382
+ return 'QU36', mods # 3-(6-Quinolyl)-alanine
383
+ if 'c1cnc2c(n1)cccc2' in content:
384
+ return 'QX32', mods # 3-(2-quinoxalyl)-alanine
385
+
386
+ # Multiple nitrogen patterns
387
+ if 'N' in content:
388
+ if '[NH3]CC[C@@H]' in content:
389
+ return 'DAB', mods # Diaminobutyric acid
390
+ if '[NH3]C[C@@H]' in content:
391
+ return 'DPP', mods # 2,3-Diaminopropanoic acid
392
+ if '[NH3]CCCCCC[C@@H]' in content:
393
+ return 'HHK', mods # (2s)-2,8-diaminooctanoic acid
394
+ if 'CCC[NH]=[C](=[NH2])=[NH2]' in content:
395
+ return 'GBUT', mods # 2-Amino-4-guanidinobutryric acid
396
+ if '[NH]=[C](=S)=[NH2]' in content:
397
+ return 'THIC', mods # Thio-citrulline
398
+
399
+ # Chain modified amino acids
400
+ if 'CC' in content:
401
+ if 'CCCC[C@@H]' in content:
402
+ return 'AHP', mods # 2-Aminoheptanoic acid
403
+ if 'CCC([C@@H])(C)C' in content:
404
+ return 'I2M', mods # 3-methyl-l-alloisoleucine
405
+ if 'CC[C@H]([C@@H])C' in content:
406
+ return 'IIL', mods # Allo-Isoleucine
407
+ if '[C@H](CCC(C)C)' in content:
408
+ return 'HLEU', mods # Homoleucine
409
+ if '[C@@H]([C@@H](C)O)C' in content:
410
+ return 'HLU', mods # beta-hydroxyleucine
411
+
412
+ # Modified glutamate/aspartate patterns
413
+ if '[C@@H]' in content:
414
+ if '[C@@H](C[C@@H](F))' in content:
415
+ return 'FGA4', mods # 4-Fluoro-glutamic acid
416
+ if '[C@@H](C[C@@H](O))' in content:
417
+ return '3GL', mods # 4-hydroxy-glutamic-acid
418
+ if '[C@@H](C[C@H](C))' in content:
419
+ return 'LME', mods # (3r)-3-methyl-l-glutamic acid
420
+ if '[C@@H](CC[C@H](C))' in content:
421
+ return 'MEG', mods # (3s)-3-methyl-l-glutamic acid
422
+
423
+ # Sulfur and selenium modifications
424
+ if 'S' in content:
425
+ if 'SCC[C@@H]' in content:
426
+ return 'HSER', mods # homoserine
427
+ if 'SCCN' in content:
428
+ return 'SLZ', mods # thialysine
429
+ if 'SC(=O)' in content:
430
+ return 'CSA', mods # s-acetonylcysteine
431
+ if '[S@@](=O)' in content:
432
+ return 'SME', mods # Methionine sulfoxide
433
+ if 'S(=O)(=O)' in content:
434
+ return 'OMT', mods # Methionine sulfone
435
+
436
+ # Double bond containing
437
+ if 'C=' in content:
438
+ if 'C=C[C@@H]' in content:
439
+ return '2AG', mods # 2-Allyl-glycine
440
+ if 'C=C[C@@H]' in content:
441
+ return 'LVG', mods # vinylglycine
442
+ if 'C=Cc1ccccc1' in content:
443
+ return 'STYA', mods # Styrylalanine
444
+
445
+ # Special cases
446
+ if '[C@@H]1Cc2c(C1)cccc2' in content:
447
+ return 'IGL', mods # alpha-amino-2-indanacetic acid
448
+ if '[C](=[C](=O)=O)=O' in content:
449
+ return '26P', mods # 2-amino-6-oxopimelic acid
450
+ if '[C](=[C](=O)=O)=C' in content:
451
+ return '2NP', mods # l-2-amino-6-methylene-pimelic acid
452
+ if 'c2cnc[nH]2' in content:
453
+ return 'HIS', mods # histidine core
454
+ if 'c1cccc2c1cc(O)cc2' in content:
455
+ return 'NAO1', mods # 5-hydroxy-1-naphthalene
456
+ if 'c1ccc2c(c1)cc(O)cc2' in content:
457
+ return 'NAO2', mods # 6-hydroxy-2-naphthalene
458
+
459
+ # Proline (P) - flexible ring numbers
460
+ if any([
461
+ # Check for any ring number in bond patterns
462
+ (segment.get('bond_after', '').startswith(f'N{n}C(=O)') and 'CCC' in content and
463
+ any(f'[C@@H]{n}' in content or f'[C@H]{n}' in content for n in '123456789'))
464
+ for n in '123456789'
465
+ ]) or any([
466
+ # Check ending patterns with any ring number
467
+ (f'CCCN{n}' in content and content.endswith('=O') and
468
+ any(f'[C@@H]{n}' in content or f'[C@H]{n}' in content for n in '123456789'))
469
+ for n in '123456789'
470
+ ]) or any([
471
+ # Handle CCC[C@H]n patterns
472
+ (content == f'CCC[C@H]{n}' and segment.get('bond_before', '').startswith(f'C(=O)N{n}')) or
473
+ (content == f'CCC[C@@H]{n}' and segment.get('bond_before', '').startswith(f'C(=O)N{n}')) or
474
+ # N-terminal Pro with any ring number
475
+ (f'N{n}CCC[C@H]{n}' in content) or
476
+ (f'N{n}CCC[C@@H]{n}' in content)
477
+ for n in '123456789'
478
+ ]):
479
+ return 'Pro', mods
480
+
481
+ # Tryptophan (W) - more specific indole pattern
482
+ if re.search(r'c[0-9]c\[nH\]c[0-9]ccccc[0-9][0-9]', content) and \
483
+ 'c[nH]c' in content.replace(' ', ''):
484
+ return 'Trp', mods
485
+
486
+ # Lysine (K) - both patterns
487
+ if '[C@@H](CCCCN)' in content or '[C@H](CCCCN)' in content:
488
+ return 'Lys', mods
489
+
490
+ # Arginine (R) - both patterns
491
+ if '[C@@H](CCCNC(=N)N)' in content or '[C@H](CCCNC(=N)N)' in content:
492
+ return 'Arg', mods
493
+
494
+ if ('C[C@H](CCCC)' in content or 'C[C@@H](CCCC)' in content) and 'CC(C)' not in content:
495
+ return 'Nle', mods
496
+
497
+ # Ornithine (Orn) - 3-carbon chain with NH2
498
+ if ('C[C@H](CCCN)' in content or 'C[C@@H](CCCN)' in content) and 'CC(C)' not in content:
499
+ return 'Orn', mods
500
+
501
+ # 2-Naphthylalanine (2Nal) - distinct from Phe pattern
502
+ if ('Cc3cc2ccccc2c3' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
503
+ return '2Nal', mods
504
+
505
+ # Cyclohexylalanine (Cha) - already in your code but moved here for clarity
506
+ if 'N2CCCCC2' in content or 'CCCCC2' in content:
507
+ return 'Cha', mods
508
+
509
+ # Aminobutyric acid (Abu) - 2-carbon chain
510
+ if ('C[C@H](CC)' in content or 'C[C@@H](CC)' in content) and not any(p in content for p in ['CC(C)', 'CCCC', 'CCC(C)']):
511
+ return 'Abu', mods
512
+
513
+ # Pipecolic acid (Pip) - 6-membered ring like Pro
514
+ if ('N3CCCCC3' in content or 'CCCCC3' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
515
+ return 'Pip', mods
516
+
517
+ # Cyclohexylglycine (Chg) - direct cyclohexyl without CH2
518
+ if ('C[C@H](C1CCCCC1)' in content or 'C[C@@H](C1CCCCC1)' in content):
519
+ return 'Chg', mods
520
+
521
+ # 4-Fluorophenylalanine (4F-Phe)
522
+ if ('Cc2ccc(F)cc2' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
523
+ return '4F-Phe', mods
524
+
525
+ # Regular residue identification
526
+ if ('NCC(=O)' in content) or (content == 'C'):
527
+ # Middle case - between bonds
528
+ if segment.get('bond_before') and segment.get('bond_after'):
529
+ if ('C(=O)N' in segment['bond_before'] or 'C(=O)N(C)' in segment['bond_before']):
530
+ return 'Gly', mods
531
+ # Terminal case - at the end
532
+ elif segment.get('bond_before') and segment.get('bond_before').startswith('C(=O)N'):
533
+ return 'Gly', mods
534
+
535
+ if 'CC(C)C[C@H]' in content or 'CC(C)C[C@@H]' in content:
536
+ return 'Leu', mods
537
+ if '[C@@H](CC(C)C)' in content or '[C@H](CC(C)C)' in content:
538
+ return 'Leu', mods
539
+
540
+ if '[C@@H]([C@@H](C)O)' in content or '[C@H]([C@H](C)O)' in content:
541
+ return 'Thr', mods
542
+
543
+ if '[C@H](Cc2ccccc2)' in content or '[C@@H](Cc2ccccc2)' in content:
544
+ return 'Phe', mods
545
+
546
+ if ('[C@H](C(C)C)' in content or # With outer parentheses
547
+ '[C@@H](C(C)C)' in content or # With outer parentheses
548
+ '[C@H]C(C)C' in content or # Without outer parentheses
549
+ '[C@@H]C(C)C' in content): # Without outer parentheses
550
+ if not any(p in content for p in ['CC(C)C[C@H]', 'CC(C)C[C@@H]']): # Still check not Leu
551
+ return 'Val', mods
552
+
553
+ if '[C@H](COC(C)(C)C)' in content or '[C@@H](COC(C)(C)C)' in content:
554
+ return 'O-tBu', mods
555
+
556
+ if any([
557
+ 'CC[C@H](C)' in content,
558
+ 'CC[C@@H](C)' in content,
559
+ 'C(C)C[C@H]' in content and 'CC(C)C' not in content,
560
+ 'C(C)C[C@@H]' in content and 'CC(C)C' not in content
561
+ ]):
562
+ return 'Ile', mods
563
+
564
+ if ('[C@H](C)' in content or '[C@@H](C)' in content):
565
+ if not any(p in content for p in ['C(C)C', 'COC', 'CN(', 'C(C)O', 'CC[C@H]', 'CC[C@@H]']):
566
+ return 'Ala', mods
567
+
568
+ # Tyrosine (Tyr) - 4-hydroxybenzyl side chain
569
+ if re.search(r'Cc[0-9]ccc\(O\)cc[0-9]', content):
570
+ return 'Tyr', mods
571
+
572
+
573
+ # Serine (Ser) - Hydroxymethyl side chain
574
+ if '[C@H](CO)' in content or '[C@@H](CO)' in content:
575
+ if not ('C(C)O' in content or 'COC' in content):
576
+ return 'Ser', mods
577
+
578
+ # Threonine (Thr) - 1-hydroxyethyl side chain
579
+ if '[C@@H]([C@@H](C)O)' in content or '[C@H]([C@H](C)O)' in content or '[C@@H](C)O' in content or '[C@H](C)O' in content:
580
+ return 'Thr', mods
581
+
582
+ # Cysteine (Cys) - Thiol side chain
583
+ if '[C@H](CS)' in content or '[C@@H](CS)' in content:
584
+ return 'Cys', mods
585
+
586
+ # Methionine (Met) - Methylthioethyl side chain
587
+ if ('C[C@H](CCSC)' in content or 'C[C@@H](CCSC)' in content):
588
+ return 'Met', mods
589
+
590
+ # Asparagine (Asn) - Carbamoylmethyl side chain
591
+ if ('CC(=O)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
592
+ return 'Asn', mods
593
+
594
+ # Glutamine (Gln) - Carbamoylethyl side chain
595
+ if ('CCC(=O)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
596
+ return 'Gln', mods
597
+
598
+ # Aspartic acid (Asp) - Carboxymethyl side chain
599
+ if ('CC(=O)O' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
600
+ return 'Asp', mods
601
+
602
+ # Glutamic acid (Glu) - Carboxyethyl side chain
603
+ if ('CCC(=O)O' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
604
+ return 'Glu', mods
605
+
606
+ # Arginine (Arg) - 3-guanidinopropyl side chain
607
+ if ('CCCNC(=N)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
608
+ return 'Arg', mods
609
+
610
+ # Histidine (His) - Imidazole side chain
611
+ if ('Cc2cnc[nH]2' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
612
+ return 'His', mods
613
+
614
+ return None, mods
615
+
616
+ def get_modifications(self, segment):
617
+ """Get modifications based on bond types"""
618
+ mods = []
619
+ if segment.get('bond_after'):
620
+ if 'N(C)' in segment['bond_after'] or segment['bond_after'].startswith('C(=O)N(C)'):
621
+ mods.append('N-Me')
622
+ if 'OC(=O)' in segment['bond_after']:
623
+ mods.append('O-linked')
624
+ return mods
625
+
626
+ def analyze_structure(self, smiles):
627
+ """Main analysis function with debug output"""
628
+ print("\nAnalyzing structure:", smiles)
629
+
630
+ # Split into segments
631
+ segments = self.split_on_bonds(smiles)
632
+
633
+ print("\nSegment Analysis:")
634
+ sequence = []
635
+ for i, segment in enumerate(segments):
636
+ print(f"\nSegment {i}:")
637
+ print(f"Content: {segment['content']}")
638
+ print(f"Bond before: {segment.get('bond_before', 'None')}")
639
+ print(f"Bond after: {segment.get('bond_after', 'None')}")
640
+
641
+ residue, mods = self.identify_residue(segment)
642
+ if residue:
643
+ if mods:
644
+ sequence.append(f"{residue}({','.join(mods)})")
645
+ else:
646
+ sequence.append(residue)
647
+ print(f"Identified as: {residue}")
648
+ print(f"Modifications: {mods}")
649
+ else:
650
+ print(f"Warning: Could not identify residue in segment: {segment['content']}")
651
+
652
+ # Check if cyclic
653
+ is_cyclic, peptide_cycles, aromatic_cycles = self.is_cyclic(smiles)
654
+ three_letter = '-'.join(sequence)
655
+ one_letter = ''.join(self.three_to_one.get(aa.split('(')[0], 'X') for aa in sequence)
656
+
657
+ if is_cyclic:
658
+ three_letter = f"cyclo({three_letter})"
659
+ one_letter = f"cyclo({one_letter})"
660
+
661
+ print(f"\nFinal sequence: {three_letter}")
662
+ print(f"One-letter code: {one_letter}")
663
+ print(f"Is cyclic: {is_cyclic}")
664
+ #print(f"Peptide cycles: {peptide_cycles}")
665
+ #print(f"Aromatic cycles: {aromatic_cycles}")
666
+
667
+ return three_letter, len(segments)
668
+ """return {
669
+ 'three_letter': three_letter,
670
+ #'one_letter': one_letter,
671
+ 'is_cyclic': is_cyclic
672
+ }"""
673
+
674
+ def return_sequence(self, smiles):
675
+ """Main analysis function with debug output"""
676
+ print("\nAnalyzing structure:", smiles)
677
+
678
+ # Split into segments
679
+ segments = self.split_on_bonds(smiles)
680
+
681
+ print("\nSegment Analysis:")
682
+ sequence = []
683
+ for i, segment in enumerate(segments):
684
+ print(f"\nSegment {i}:")
685
+ print(f"Content: {segment['content']}")
686
+ print(f"Bond before: {segment.get('bond_before', 'None')}")
687
+ print(f"Bond after: {segment.get('bond_after', 'None')}")
688
+
689
+ residue, mods = self.identify_residue(segment)
690
+ if residue:
691
+ if mods:
692
+ sequence.append(f"{residue}({','.join(mods)})")
693
+ else:
694
+ sequence.append(residue)
695
+ print(f"Identified as: {residue}")
696
+ print(f"Modifications: {mods}")
697
+ else:
698
+ print(f"Warning: Could not identify residue in segment: {segment['content']}")
699
+
700
+ return sequence
701
+
702
+ """
703
+ def annotate_cyclic_structure(mol, sequence):
704
+ '''Create annotated 2D structure with clear, non-overlapping residue labels'''
705
+ # Generate 2D coordinates
706
+ # Generate 2D coordinates
707
+ AllChem.Compute2DCoords(mol)
708
+
709
+ # Create drawer with larger size for annotations
710
+ drawer = Draw.rdMolDraw2D.MolDraw2DCairo(2000, 2000) # Even larger size
711
+
712
+ # Get residue list and reverse it to match structural representation
713
+ if sequence.startswith('cyclo('):
714
+ residues = sequence[6:-1].split('-')
715
+ else:
716
+ residues = sequence.split('-')
717
+ residues = list(reversed(residues)) # Reverse the sequence
718
+
719
+ # Draw molecule first to get its bounds
720
+ drawer.drawOptions().addAtomIndices = False
721
+ drawer.DrawMolecule(mol)
722
+ drawer.FinishDrawing()
723
+
724
+ # Convert to PIL Image
725
+ img = Image.open(BytesIO(drawer.GetDrawingText()))
726
+ draw = ImageDraw.Draw(img)
727
+
728
+ try:
729
+ # Try to use DejaVuSans as it's commonly available on Linux systems
730
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
731
+ small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
732
+ except OSError:
733
+ try:
734
+ # Fallback to Arial if available (common on Windows)
735
+ font = ImageFont.truetype("arial.ttf", 60)
736
+ small_font = ImageFont.truetype("arial.ttf", 60)
737
+ except OSError:
738
+ # If no TrueType fonts are available, fall back to default
739
+ print("Warning: TrueType fonts not available, using default font")
740
+ font = ImageFont.load_default()
741
+ small_font = ImageFont.load_default()
742
+ # Get molecule bounds
743
+ conf = mol.GetConformer()
744
+ positions = []
745
+ for i in range(mol.GetNumAtoms()):
746
+ pos = conf.GetAtomPosition(i)
747
+ positions.append((pos.x, pos.y))
748
+
749
+ x_coords = [p[0] for p in positions]
750
+ y_coords = [p[1] for p in positions]
751
+ min_x, max_x = min(x_coords), max(x_coords)
752
+ min_y, max_y = min(y_coords), max(y_coords)
753
+
754
+ # Calculate scaling factors
755
+ scale = 150 # Increased scale factor
756
+ center_x = 1000 # Image center
757
+ center_y = 1000
758
+
759
+ # Add residue labels in a circular arrangement around the structure
760
+ n_residues = len(residues)
761
+ radius = 700 # Distance of labels from center
762
+
763
+ # Start from the rightmost point (3 o'clock position) and go counterclockwise
764
+ # Offset by -3 positions to align with structure
765
+ offset = 0 # Adjust this value to match the structure alignment
766
+ for i, residue in enumerate(residues):
767
+ # Calculate position in a circle around the structure
768
+ # Start from 0 (3 o'clock) and go counterclockwise
769
+ angle = -(2 * np.pi * ((i + offset) % n_residues) / n_residues)
770
+
771
+ # Calculate label position
772
+ label_x = center_x + radius * np.cos(angle)
773
+ label_y = center_y + radius * np.sin(angle)
774
+
775
+ # Draw residue label
776
+ text = f"{i+1}. {residue}"
777
+ bbox = draw.textbbox((label_x, label_y), text, font=font)
778
+ padding = 10
779
+ draw.rectangle([bbox[0]-padding, bbox[1]-padding,
780
+ bbox[2]+padding, bbox[3]+padding],
781
+ fill='white', outline='white')
782
+ draw.text((label_x, label_y), text,
783
+ font=font, fill='black', anchor="mm")
784
+
785
+ # Add sequence at the top with white background
786
+ seq_text = f"Sequence: {sequence}"
787
+ bbox = draw.textbbox((center_x, 100), seq_text, font=small_font)
788
+ padding = 10
789
+ draw.rectangle([bbox[0]-padding, bbox[1]-padding,
790
+ bbox[2]+padding, bbox[3]+padding],
791
+ fill='white', outline='white')
792
+ draw.text((center_x, 100), seq_text,
793
+ font=small_font, fill='black', anchor="mm")
794
+
795
+ return img
796
+
797
+ """
798
+ def annotate_cyclic_structure(mol, sequence):
799
+ """Create structure visualization with just the sequence header"""
800
+ # Generate 2D coordinates
801
+ AllChem.Compute2DCoords(mol)
802
+
803
+ # Create drawer with larger size for annotations
804
+ drawer = Draw.rdMolDraw2D.MolDraw2DCairo(2000, 2000)
805
+
806
+ # Draw molecule first
807
+ drawer.drawOptions().addAtomIndices = False
808
+ drawer.DrawMolecule(mol)
809
+ drawer.FinishDrawing()
810
+
811
+ # Convert to PIL Image
812
+ img = Image.open(BytesIO(drawer.GetDrawingText()))
813
+ draw = ImageDraw.Draw(img)
814
+ try:
815
+ small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
816
+ except OSError:
817
+ try:
818
+ small_font = ImageFont.truetype("arial.ttf", 60)
819
+ except OSError:
820
+ print("Warning: TrueType fonts not available, using default font")
821
+ small_font = ImageFont.load_default()
822
+
823
+ # Add just the sequence header at the top
824
+ seq_text = f"Sequence: {sequence}"
825
+ bbox = draw.textbbox((1000, 100), seq_text, font=small_font)
826
+ padding = 10
827
+ draw.rectangle([bbox[0]-padding, bbox[1]-padding,
828
+ bbox[2]+padding, bbox[3]+padding],
829
+ fill='white', outline='white')
830
+ draw.text((1000, 100), seq_text,
831
+ font=small_font, fill='black', anchor="mm")
832
+
833
+ return img
834
+
835
+ def create_enhanced_linear_viz(sequence, smiles):
836
+ """Create an enhanced linear representation using PeptideAnalyzer"""
837
+ analyzer = PeptideAnalyzer() # Create analyzer instance
838
+
839
+ # Create figure with two subplots
840
+ fig = plt.figure(figsize=(15, 10))
841
+ gs = fig.add_gridspec(2, 1, height_ratios=[1, 2])
842
+ ax_struct = fig.add_subplot(gs[0])
843
+ ax_detail = fig.add_subplot(gs[1])
844
+
845
+ # Parse sequence and get residues
846
+ if sequence.startswith('cyclo('):
847
+ residues = sequence[6:-1].split('-')
848
+ else:
849
+ residues = sequence.split('-')
850
+
851
+ # Get segments using analyzer
852
+ segments = analyzer.split_on_bonds(smiles)
853
+
854
+ # Debug print
855
+ print(f"Number of residues: {len(residues)}")
856
+ print(f"Number of segments: {len(segments)}")
857
+
858
+ # Top subplot - Basic structure
859
+ ax_struct.set_xlim(0, 10)
860
+ ax_struct.set_ylim(0, 2)
861
+
862
+ num_residues = len(residues)
863
+ spacing = 9.0 / (num_residues - 1) if num_residues > 1 else 9.0
864
+
865
+ # Draw basic structure
866
+ y_pos = 1.5
867
+ for i in range(num_residues):
868
+ x_pos = 0.5 + i * spacing
869
+
870
+ # Draw amino acid box
871
+ rect = patches.Rectangle((x_pos-0.3, y_pos-0.2), 0.6, 0.4,
872
+ facecolor='lightblue', edgecolor='black')
873
+ ax_struct.add_patch(rect)
874
+
875
+ # Draw connecting bonds if not the last residue
876
+ if i < num_residues - 1:
877
+ segment = segments[i] if i < len(segments) else None
878
+ if segment:
879
+ # Determine bond type from segment info
880
+ bond_type = 'ester' if 'O-linked' in segment.get('bond_after', '') else 'peptide'
881
+ is_n_methylated = 'N-Me' in segment.get('bond_after', '')
882
+
883
+ bond_color = 'red' if bond_type == 'ester' else 'black'
884
+ linestyle = '--' if bond_type == 'ester' else '-'
885
+
886
+ # Draw bond line
887
+ ax_struct.plot([x_pos+0.3, x_pos+spacing-0.3], [y_pos, y_pos],
888
+ color=bond_color, linestyle=linestyle, linewidth=2)
889
+
890
+ # Add bond type label
891
+ mid_x = x_pos + spacing/2
892
+ bond_label = f"{bond_type}"
893
+ if is_n_methylated:
894
+ bond_label += "\n(N-Me)"
895
+ ax_struct.text(mid_x, y_pos+0.1, bond_label,
896
+ ha='center', va='bottom', fontsize=10,
897
+ color=bond_color)
898
+
899
+ # Add residue label
900
+ ax_struct.text(x_pos, y_pos-0.5, residues[i],
901
+ ha='center', va='top', fontsize=14)
902
+
903
+ # Bottom subplot - Detailed breakdown
904
+ ax_detail.set_ylim(0, len(segments)+1)
905
+ ax_detail.set_xlim(0, 1)
906
+
907
+ # Create detailed breakdown
908
+ segment_y = len(segments) # Start from top
909
+ for i, segment in enumerate(segments):
910
+ y = segment_y - i
911
+
912
+ # Check if this is a bond or residue
913
+ residue, mods = analyzer.identify_residue(segment)
914
+ if residue:
915
+ text = f"Residue {i+1}: {residue}"
916
+ if mods:
917
+ text += f" ({', '.join(mods)})"
918
+ color = 'blue'
919
+ else:
920
+ # Must be a bond
921
+ text = f"Bond {i}: "
922
+ if 'O-linked' in segment.get('bond_after', ''):
923
+ text += "ester"
924
+ elif 'N-Me' in segment.get('bond_after', ''):
925
+ text += "peptide (N-methylated)"
926
+ else:
927
+ text += "peptide"
928
+ color = 'red'
929
+
930
+ # Add segment analysis
931
+ ax_detail.text(0.05, y, text, fontsize=12, color=color)
932
+ ax_detail.text(0.5, y, f"SMILES: {segment.get('content', '')}", fontsize=10, color='gray')
933
+
934
+ # If cyclic, add connection indicator
935
+ if sequence.startswith('cyclo('):
936
+ ax_struct.annotate('', xy=(9.5, y_pos), xytext=(0.5, y_pos),
937
+ arrowprops=dict(arrowstyle='<->', color='red', lw=2))
938
+ ax_struct.text(5, y_pos+0.3, 'Cyclic Connection',
939
+ ha='center', color='red', fontsize=14)
940
+
941
+ # Add titles and adjust layout
942
+ ax_struct.set_title("Peptide Structure Overview", pad=20)
943
+ ax_detail.set_title("Segment Analysis Breakdown", pad=20)
944
+
945
+ # Remove axes
946
+ for ax in [ax_struct, ax_detail]:
947
+ ax.set_xticks([])
948
+ ax.set_yticks([])
949
+ ax.axis('off')
950
+
951
+ plt.tight_layout()
952
+ return fig
953
+
954
+ class PeptideStructureGenerator:
955
+ """A class to generate 3D structures of peptides using different embedding methods"""
956
+
957
+ @staticmethod
958
+ def prepare_molecule(smiles):
959
+ """Prepare molecule with proper hydrogen handling"""
960
+ mol = Chem.MolFromSmiles(smiles, sanitize=False)
961
+ if mol is None:
962
+ raise ValueError("Failed to create molecule from SMILES")
963
+
964
+ # Calculate valence for each atom
965
+ for atom in mol.GetAtoms():
966
+ atom.UpdatePropertyCache(strict=False)
967
+
968
+ # Sanitize with reduced requirements
969
+ Chem.SanitizeMol(mol,
970
+ sanitizeOps=Chem.SANITIZE_FINDRADICALS|
971
+ Chem.SANITIZE_KEKULIZE|
972
+ Chem.SANITIZE_SETAROMATICITY|
973
+ Chem.SANITIZE_SETCONJUGATION|
974
+ Chem.SANITIZE_SETHYBRIDIZATION|
975
+ Chem.SANITIZE_CLEANUPCHIRALITY)
976
+
977
+ mol = Chem.AddHs(mol)
978
+ return mol
979
+
980
+ @staticmethod
981
+ def get_etkdg_params(attempt=0):
982
+ """Get ETKDG parameters with optional modifications based on attempt number"""
983
+ params = AllChem.ETKDGv3()
984
+ params.randomSeed = -1
985
+ params.maxIterations = 200
986
+ params.numThreads = 4 # Reduced for web interface
987
+ params.useBasicKnowledge = True
988
+ params.enforceChirality = True
989
+ params.useExpTorsionAnglePrefs = True
990
+ params.useSmallRingTorsions = True
991
+ params.useMacrocycleTorsions = True
992
+ params.ETversion = 2
993
+ params.pruneRmsThresh = -1
994
+ params.embedRmsThresh = 0.5
995
+
996
+ if attempt > 10:
997
+ params.bondLength = 1.5 + (attempt - 10) * 0.02
998
+ params.useExpTorsionAnglePrefs = False
999
+
1000
+ return params
1001
+
1002
+ def generate_structure_etkdg(self, smiles, max_attempts=20):
1003
+ """Generate 3D structure using ETKDG without UFF optimization"""
1004
+ success = False
1005
+ mol = None
1006
+
1007
+ for attempt in range(max_attempts):
1008
+ try:
1009
+ mol = self.prepare_molecule(smiles)
1010
+ params = self.get_etkdg_params(attempt)
1011
+
1012
+ if AllChem.EmbedMolecule(mol, params) == 0:
1013
+ success = True
1014
+ break
1015
+ except Exception as e:
1016
+ continue
1017
+
1018
+ if not success:
1019
+ raise ValueError("Failed to generate structure with ETKDG")
1020
+
1021
+ return mol
1022
+
1023
+ def generate_structure_uff(self, smiles, max_attempts=20):
1024
+ """Generate 3D structure using ETKDG followed by UFF optimization"""
1025
+ best_mol = None
1026
+ lowest_energy = float('inf')
1027
+
1028
+ for attempt in range(max_attempts):
1029
+ try:
1030
+ test_mol = self.prepare_molecule(smiles)
1031
+ params = self.get_etkdg_params(attempt)
1032
+
1033
+ if AllChem.EmbedMolecule(test_mol, params) == 0:
1034
+ res = AllChem.UFFOptimizeMolecule(test_mol, maxIters=2000,
1035
+ vdwThresh=10.0, confId=0,
1036
+ ignoreInterfragInteractions=True)
1037
+
1038
+ if res == 0:
1039
+ ff = AllChem.UFFGetMoleculeForceField(test_mol)
1040
+ if ff:
1041
+ current_energy = ff.CalcEnergy()
1042
+ if current_energy < lowest_energy:
1043
+ lowest_energy = current_energy
1044
+ best_mol = Chem.Mol(test_mol)
1045
+ except Exception:
1046
+ continue
1047
+
1048
+ if best_mol is None:
1049
+ raise ValueError("Failed to generate optimized structure")
1050
+
1051
+ return best_mol
1052
+
1053
+ @staticmethod
1054
+ def mol_to_sdf_bytes(mol):
1055
+ """Convert RDKit molecule to SDF file bytes"""
1056
+ # First write to StringIO in text mode
1057
+ sio = StringIO()
1058
+ writer = Chem.SDWriter(sio)
1059
+ writer.write(mol)
1060
+ writer.close()
1061
+
1062
+ # Convert the string to bytes
1063
+ return sio.getvalue().encode('utf-8')
1064
+
1065
+ def process_input(smiles_input=None, file_obj=None, show_linear=False,
1066
+ show_segment_details=False, generate_3d=False, use_uff=False):
1067
+ """Process input and create visualizations using PeptideAnalyzer"""
1068
+ analyzer = PeptideAnalyzer()
1069
+ temp_dir = tempfile.mkdtemp() if generate_3d else None
1070
+ structure_files = []
1071
+
1072
+ # Handle direct SMILES input
1073
+ if smiles_input:
1074
+ smiles = smiles_input.strip()
1075
+
1076
+ # First check if it's a peptide using analyzer's method
1077
+ if not analyzer.is_peptide(smiles):
1078
+ return "Error: Input SMILES does not appear to be a peptide structure.", None, None
1079
+
1080
+ try:
1081
+ # Create molecule
1082
+ mol = Chem.MolFromSmiles(smiles)
1083
+ if mol is None:
1084
+ return "Error: Invalid SMILES notation.", None, None
1085
+
1086
+ # Generate 3D structures if requested
1087
+ if generate_3d:
1088
+ generator = PeptideStructureGenerator()
1089
+
1090
+ try:
1091
+ # Generate ETKDG structure
1092
+ mol_etkdg = generator.generate_structure_etkdg(smiles)
1093
+ etkdg_path = os.path.join(temp_dir, "structure_etkdg.sdf")
1094
+ writer = Chem.SDWriter(etkdg_path)
1095
+ writer.write(mol_etkdg)
1096
+ writer.close()
1097
+ structure_files.append(etkdg_path)
1098
+
1099
+ # Generate UFF structure if requested
1100
+ if use_uff:
1101
+ mol_uff = generator.generate_structure_uff(smiles)
1102
+ uff_path = os.path.join(temp_dir, "structure_uff.sdf")
1103
+ writer = Chem.SDWriter(uff_path)
1104
+ writer.write(mol_uff)
1105
+ writer.close()
1106
+ structure_files.append(uff_path)
1107
+
1108
+ except Exception as e:
1109
+ return f"Error generating 3D structures: {str(e)}", None, None, None
1110
+
1111
+ # Use analyzer to get sequence
1112
+ segments = analyzer.split_on_bonds(smiles)
1113
+
1114
+ # Process segments and build sequence
1115
+ sequence_parts = []
1116
+ output_text = ""
1117
+
1118
+ # Only include segment analysis in output if requested
1119
+ if show_segment_details:
1120
+ output_text += "Segment Analysis:\n"
1121
+ for i, segment in enumerate(segments):
1122
+ output_text += f"\nSegment {i}:\n"
1123
+ output_text += f"Content: {segment['content']}\n"
1124
+ output_text += f"Bond before: {segment.get('bond_before', 'None')}\n"
1125
+ output_text += f"Bond after: {segment.get('bond_after', 'None')}\n"
1126
+
1127
+ residue, mods = analyzer.identify_residue(segment)
1128
+ if residue:
1129
+ if mods:
1130
+ sequence_parts.append(f"{residue}({','.join(mods)})")
1131
+ else:
1132
+ sequence_parts.append(residue)
1133
+ output_text += f"Identified as: {residue}\n"
1134
+ output_text += f"Modifications: {mods}\n"
1135
+ else:
1136
+ output_text += f"Warning: Could not identify residue in segment: {segment['content']}\n"
1137
+ output_text += "\n"
1138
+ else:
1139
+ # Just build sequence without detailed analysis in output
1140
+ for segment in segments:
1141
+ residue, mods = analyzer.identify_residue(segment)
1142
+ if residue:
1143
+ if mods:
1144
+ sequence_parts.append(f"{residue}({','.join(mods)})")
1145
+ else:
1146
+ sequence_parts.append(residue)
1147
+
1148
+ # Check if cyclic using analyzer's method
1149
+ is_cyclic, peptide_cycles, aromatic_cycles = analyzer.is_cyclic(smiles)
1150
+ three_letter = '-'.join(sequence_parts)
1151
+ one_letter = ''.join(analyzer.three_to_one.get(aa.split('(')[0], 'X') for aa in sequence_parts)
1152
+
1153
+ if is_cyclic:
1154
+ three_letter = f"cyclo({three_letter})"
1155
+ one_letter = f"cyclo({one_letter})"
1156
+
1157
+ # Create cyclic structure visualization
1158
+ img_cyclic = annotate_cyclic_structure(mol, three_letter)
1159
+
1160
+ # Create linear representation if requested
1161
+ img_linear = None
1162
+ if show_linear:
1163
+ fig_linear = create_enhanced_linear_viz(three_letter, smiles)
1164
+ buf = BytesIO()
1165
+ fig_linear.savefig(buf, format='png', bbox_inches='tight', dpi=300)
1166
+ buf.seek(0)
1167
+ img_linear = Image.open(buf)
1168
+ plt.close(fig_linear)
1169
+
1170
+ # Add summary to output
1171
+ summary = "Summary:\n"
1172
+ summary += f"Sequence: {three_letter}\n"
1173
+ summary += f"One-letter code: {one_letter}\n"
1174
+ summary += f"Is Cyclic: {'Yes' if is_cyclic else 'No'}\n"
1175
+ #if is_cyclic:
1176
+ #summary += f"Peptide Cycles: {', '.join(peptide_cycles)}\n"
1177
+ #summary += f"Aromatic Cycles: {', '.join(aromatic_cycles)}\n"
1178
+
1179
+ if structure_files:
1180
+ summary += "\n3D Structures Generated:\n"
1181
+ for filepath in structure_files:
1182
+ summary += f"- {os.path.basename(filepath)}\n"
1183
+
1184
+ return summary + output_text, img_cyclic, img_linear, structure_files if structure_files else None
1185
+
1186
+ except Exception as e:
1187
+ return f"Error processing SMILES: {str(e)}", None, None, None
1188
+
1189
+ # Handle file input
1190
+ if file_obj is not None:
1191
+ try:
1192
+ # Handle file content
1193
+ if hasattr(file_obj, 'name'):
1194
+ with open(file_obj.name, 'r') as f:
1195
+ content = f.read()
1196
+ else:
1197
+ content = file_obj.decode('utf-8') if isinstance(file_obj, bytes) else str(file_obj)
1198
+
1199
+ output_text = ""
1200
+ for line in content.splitlines():
1201
+ smiles = line.strip()
1202
+ if smiles:
1203
+ # Check if it's a peptide
1204
+ if not analyzer.is_peptide(smiles):
1205
+ output_text += f"Skipping non-peptide SMILES: {smiles}\n"
1206
+ continue
1207
+
1208
+ # Process this SMILES
1209
+ segments = analyzer.split_on_bonds(smiles)
1210
+ sequence_parts = []
1211
+
1212
+ # Add segment details if requested
1213
+ if show_segment_details:
1214
+ output_text += f"\nSegment Analysis for SMILES: {smiles}\n"
1215
+ for i, segment in enumerate(segments):
1216
+ output_text += f"\nSegment {i}:\n"
1217
+ output_text += f"Content: {segment['content']}\n"
1218
+ output_text += f"Bond before: {segment.get('bond_before', 'None')}\n"
1219
+ output_text += f"Bond after: {segment.get('bond_after', 'None')}\n"
1220
+ residue, mods = analyzer.identify_residue(segment)
1221
+ if residue:
1222
+ if mods:
1223
+ sequence_parts.append(f"{residue}({','.join(mods)})")
1224
+ else:
1225
+ sequence_parts.append(residue)
1226
+ output_text += f"Identified as: {residue}\n"
1227
+ output_text += f"Modifications: {mods}\n"
1228
+ else:
1229
+ for segment in segments:
1230
+ residue, mods = analyzer.identify_residue(segment)
1231
+ if residue:
1232
+ if mods:
1233
+ sequence_parts.append(f"{residue}({','.join(mods)})")
1234
+ else:
1235
+ sequence_parts.append(residue)
1236
+
1237
+ # Get cyclicity and create sequence
1238
+ is_cyclic, peptide_cycles, aromatic_cycles = analyzer.is_cyclic(smiles)
1239
+ sequence = f"cyclo({'-'.join(sequence_parts)})" if is_cyclic else '-'.join(sequence_parts)
1240
+
1241
+ output_text += f"\nSummary for SMILES: {smiles}\n"
1242
+ output_text += f"Sequence: {sequence}\n"
1243
+ output_text += f"Is Cyclic: {'Yes' if is_cyclic else 'No'}\n"
1244
+ if is_cyclic:
1245
+ output_text += f"Peptide Cycles: {', '.join(peptide_cycles)}\n"
1246
+ #output_text += f"Aromatic Cycles: {', '.join(aromatic_cycles)}\n"
1247
+ output_text += "-" * 50 + "\n"
1248
+
1249
+ return output_text, None, None
1250
+
1251
+ except Exception as e:
1252
+ return f"Error processing file: {str(e)}", None, None
1253
+
1254
+ return "No input provided.", None, None
1255
+
src/utils/generate_utils.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ import sys
4
+ import pandas as pd
5
+ from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer
6
+
7
+
8
+ def mask_for_de_novo(config, sequence_length):
9
+ if config.vocab == 'helm':
10
+ return "[MASK]" * sequence_length
11
+ elif config.vocab == 'new_smiles' or config.vocab == 'selfies':
12
+ return ["<mask>"] * sequence_length
13
+ else:
14
+ return ["[MASK]"] * sequence_length
15
+
16
+ def generate_de_novo(sequence_length, tokenizer, model):
17
+ masked_sequence = mask_for_de_novo(sequence_length)
18
+ inputs = tokenizer(masked_sequence, return_tensors='pt').to(model.device)
19
+
20
+ with torch.no_grad():
21
+ logits = model(**inputs).logits
22
+ mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
23
+ logits_at_masks = logits[0, mask_token_indices]
24
+
25
+ pred_tokens = []
26
+ for i in mask_token_indices:
27
+ topk_logits, topk_indices = logits_at_masks[i].topk(k=3, dim=-1)
28
+ probabilities = torch.nn.functional.softmax(topk_logits, dim=-1)
29
+ predicted_index = torch.distributions.categorical.Categorical(probabilities).sample()
30
+ predicted_token_id = topk_indices[predicted_index].item()
31
+ predicted_token = tokenizer.decode([predicted_token_id], skip_special_tokens=True)
32
+ pred_tokens.append(predicted_token)
33
+
34
+ generated_sequence = ''.join(pred_tokens)
35
+ perplexity = calculate_perplexity(model, tokenizer, generated_sequence)
36
+
37
+ return (generated_sequence, perplexity)
38
+
39
+
40
+ def calculate_perplexity(model, tokenizer, generated_sequence, mask_token_indices):
41
+ total_loss = 0.0
42
+ tensor_input = tokenizer.encode(generated_sequence, return_tensors='pt').to(model.device)
43
+
44
+ for i in mask_token_indices:
45
+ masked_input = tensor_input.clone()
46
+ masked_input[0, i] = tokenizer.mask_token_id
47
+
48
+ labels = torch.full(tensor_input.shape, -100).to(model.device)
49
+ labels[0, i] = tensor_input[0, i]
50
+
51
+ with torch.no_grad():
52
+ outputs = model(masked_input, labels=labels)
53
+ total_loss += outputs.loss.item()
54
+
55
+ num_mask_tokens = len(mask_token_indices)
56
+ if num_mask_tokens == 0:
57
+ perplexity = 10000
58
+ else:
59
+ avg_loss = total_loss / num_mask_tokens
60
+ perplexity = math.exp(avg_loss)
61
+
62
+ return perplexity
63
+
64
+
65
+ def calculate_cosine_sim(original_sequence, generated_sequence, tokenizer, pepclm_model, device):
66
+ og_embeddings = pepclm_model.roformer.encoder(original_sequence)
67
+ new_embeddings = pepclm_model.roformer.encoder(generated_sequence)
68
+
69
+ sequence_similarity = torch.nn.functional.cosine_similarity(og_embeddings, new_embeddings, dim=-1)
70
+ cosine_similarity = torch.mean(sequence_similarity).item()
71
+ return cosine_similarity
72
+
73
+
74
+ def calculate_hamming_dist(original_sequence, generated_sequence):
75
+ generated_sequence = generated_sequence
76
+ original_sequence = original_sequence
77
+ return sum(1 if original_sequence[i] != generated_sequence[i] else 0 for i in range(len(original_sequence)))
src/utils/utils.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Console logger utilities.
2
+
3
+ Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py
4
+ Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
5
+ """
6
+
7
+ import logging
8
+ import math
9
+
10
+ import fsspec
11
+ import lightning
12
+ import torch
13
+ from timm.scheduler import CosineLRScheduler
14
+ from multiprocessing import Pool
15
+
16
+
17
def fsspec_exists(filename):
    """Return True if *filename* exists on its (possibly remote) filesystem."""
    filesystem, _ = fsspec.core.url_to_fs(filename)
    return filesystem.exists(filename)
21
+
22
+
23
def fsspec_listdir(dirname):
    """List the entries of *dirname* via its fsspec filesystem."""
    filesystem, _ = fsspec.core.url_to_fs(dirname)
    return filesystem.ls(dirname)
27
+
28
+
29
def fsspec_mkdirs(dirname, exist_ok=True):
    """Create *dirname* (and parents) via its fsspec filesystem."""
    filesystem, _ = fsspec.core.url_to_fs(dirname)
    filesystem.makedirs(dirname, exist_ok=exist_ok)
33
+
34
+
35
def print_nans(tensor, name):
    """Print *name* and *tensor* to stdout when the tensor contains any NaN."""
    contains_nan = torch.isnan(tensor).any()
    if contains_nan:
        print(name, tensor)
38
+
39
+
40
class CosineDecayWarmupLRScheduler(
        CosineLRScheduler,
        torch.optim.lr_scheduler._LRScheduler):
    """Wrap ``timm.scheduler.CosineLRScheduler``.

    Tracks its own epoch counter so ``scheduler.step()`` can be called with
    no argument (as Lightning does), and supports resuming.
    Adapted from:
    https://github.com/HazyResearch/hyena-dna/blob/main/src/utils/optim/schedulers.py
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._last_epoch = -1
        self.step(epoch=0)

    def step(self, epoch=None):
        # Maintain our own counter so a bare step() advances by one.
        self._last_epoch = self._last_epoch + 1 if epoch is None else epoch
        # timm distinguishes per-epoch stepping (step) from per-update
        # stepping (step_update). Lightning always calls step(), so we
        # dispatch here based on t_in_epochs; otherwise an "interval: step"
        # configuration would get wrong learning rates.
        if self.t_in_epochs:
            super().step(epoch=self._last_epoch)
        else:
            super().step_update(num_updates=self._last_epoch)
71
+
72
+
73
class LoggingContext:
    """Context manager that temporarily overrides a logger's level and/or
    attaches an extra handler, restoring the previous state on exit."""

    def __init__(self, logger, level=None, handler=None, close=True):
        self.logger = logger
        self.level = level
        self.handler = handler
        self.close = close

    def __enter__(self):
        override_level = self.level is not None
        if override_level:
            # Remember the current level so __exit__ can restore it.
            self.old_level = self.logger.level
            self.logger.setLevel(self.level)
        if self.handler:
            self.logger.addHandler(self.handler)

    def __exit__(self, exc_type, exc_value, traceback):
        if self.level is not None:
            self.logger.setLevel(self.old_level)
        if self.handler:
            self.logger.removeHandler(self.handler)
            if self.close:
                self.handler.close()
95
+
96
+
97
def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Initialize a multi-GPU-friendly python logger.

    Every level method is wrapped with Lightning's ``rank_zero_only`` so
    that in multi-GPU runs only rank 0 emits records (otherwise each GPU
    process would duplicate every log line).
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    level_names = ('debug', 'info', 'warning', 'error',
                   'exception', 'fatal', 'critical')
    for method_name in level_names:
        wrapped = lightning.pytorch.utilities.rank_zero_only(
            getattr(logger, method_name))
        setattr(logger, method_name, wrapped)

    return logger
113
+
114
+
115
class Sampler:
    """Base class for straight-through samplers over logits.

    Subclasses provide the noise distribution and the hard/soft sampling
    rules; ``sample`` combines them with the straight-through estimator
    (the forward value is the hard sample, gradients flow via the soft one).
    """

    def __init__(self, shape):
        self.shape = shape

    def _sampling_noise(self):
        pass

    def _hard_sample(self, logits):
        pass

    def _soft_sample(self, logits):
        return 0

    def sample(self, logits):
        perturbation = self._sampling_noise()
        # Noise is pre-shaped for a maximum batch; trim to this batch.
        perturbation = perturbation[: logits.shape[0], :]
        perturbed = logits + perturbation.to(
            dtype=logits.dtype, device=logits.device)
        hard = self._hard_sample(perturbed)
        soft = self._soft_sample(perturbed)
        # Straight-through: value of `hard`, gradient of `soft`.
        return soft + (hard - soft).detach()
136
+
137
+
138
class TopKSampler(Sampler):
    """Relaxed top-k sampler.

    Gamma-distributed noise yields a differentiable approximation to
    selecting the k largest logits per row.
    """

    def __init__(self, k, shape, gamma_tau=1.0):
        super().__init__(shape)
        self.k = k
        self.gamma_tau = gamma_tau
        self.num_betas = 10
        self.sampler = torch.distributions.gamma.Gamma(
            1 / k * torch.ones(self.num_betas, *self.shape), 1.0)

    def _sampling_noise(self):
        gamma_draw = self.sampler.sample()
        # Rates k/1, k/2, ..., k/num_betas — one per gamma component.
        beta = self.k / torch.arange(
            1, self.num_betas + 1, 1, dtype=torch.float32)
        beta = beta[:, None, None]
        assert beta.ndim == gamma_draw.ndim
        # Sum the scaled components, then shift and rescale.
        combined = torch.sum(gamma_draw / beta, axis=0)
        combined = combined - math.log(10.0)
        return self.gamma_tau * (combined / self.k)

    def _hard_sample(self, logits):
        assert logits.ndim == 2
        sorted_vals, _ = torch.sort(logits, dim=-1)
        # The k-th largest value in each row is the inclusion threshold.
        cutoff = sorted_vals[:, -self.k][:, None]
        return (logits >= cutoff).type(logits.dtype)

    def _soft_sample(self, logits):
        centered = logits - torch.mean(logits, dim=-1, keepdim=True)
        return centered / torch.norm(centered, dim=-1, keepdim=True)
170
+
171
+
172
class DeterministicTopK(TopKSampler):
    """TopKSampler variant with no sampling noise, so the top-k selection is
    fully deterministic."""

    def __init__(self, k):
        super().__init__(k, shape=(1, 1))

    def _sampling_noise(self):
        # No stochastic perturbation: selection depends only on the logits.
        return 0

    def discretize(self, x):
        """Straight-through top-k indicator of *x*: hard one/zero mask in the
        forward pass, soft-sample gradient in the backward pass."""
        hard_sample = self._hard_sample(x)
        soft_sample = self._soft_sample(x)
        return soft_sample + (hard_sample - soft_sample).detach()

    # Backward-compatible alias: the original method name was misspelled
    # ("discreize"); keep it so existing callers continue to work.
    discreize = discretize
183
+
184
class GumbelSampler(Sampler):
    """Gumbel-softmax sampler: one-hot argmax forward pass, temperature-scaled
    softmax gradients."""

    def __init__(self, shape, temperature=1.0):
        super().__init__(shape)
        self.temperature = temperature  # softmax temperature for the soft sample

    def _sampling_noise(self):
        # Standard Gumbel(0, 1) noise via -log(-log(U)); the 1e-10 terms
        # guard against taking log(0).
        return - (1e-10 - (
            torch.rand(* self.shape) + 1e-10).log()).log()

    def _hard_sample(self, logits):
        # NOTE(review): the assert requires 2-D logits, but the indexing
        # below ([:, :, :1] and indices[:, :, None]) only works on 3-D
        # input — one of the two looks wrong; confirm the intended rank
        # against the call sites before relying on this method.
        assert logits.ndim == 2
        indices = torch.argmax(logits, dim=-1)
        zeros = logits * 0
        ones = torch.ones_like(logits[:, :, :1])
        # Scatter ones at the argmax positions to build a one-hot tensor.
        return torch.scatter(zeros, -1, indices[:, :, None],
                             ones)

    def _soft_sample(self, logits):
        # Temperature-scaled softmax provides the differentiable path.
        return torch.nn.functional.softmax(
            logits / self.temperature, dim=-1)
205
+
206
+
207
class BinarySampler(GumbelSampler):
    """Bernoulli straight-through sampler driven by two Gumbel noise draws."""

    def sample(self, probs):
        # TODO(subhamsahoo): use the temperature parameter.
        pos_noise = self._sampling_noise().to(
            dtype=probs.dtype, device=probs.device)
        neg_noise = self._sampling_noise().to(
            dtype=probs.dtype, device=probs.device)
        # exp of the Gumbel-noise difference — equivalent to a logistic draw.
        noise_ratio = (neg_noise - pos_noise).exp()
        hard = (probs * (1 + noise_ratio) > 1).to(probs.dtype)
        soft = probs / (probs + (1 - probs) * noise_ratio)
        # Straight-through: hard value forward, soft gradient backward.
        return soft + (hard - soft).detach()
220
+
221
+
222
class GaussianSampler:
    """Reparameterized Gaussian sampler.

    The input's first half (along the last dim) is the mean; the second
    half parameterizes the variance via softplus.
    """

    def __init__(self):
        self.softplus = torch.nn.Softplus()

    def sample(self, x):
        assert x.ndim == 2
        half = x.shape[-1] // 2
        mu, raw_var = x[:, :half], x[:, half:]
        # softplus keeps the variance positive; sqrt gives the std dev.
        sigma = self.softplus(raw_var).sqrt()
        # Reparameterization trick: mu + sigma * eps, eps ~ N(0, I).
        return mu + sigma * torch.randn_like(mu)
232
+
233
def mapper(n_jobs):
    """Return a map-like callable.

    * ``n_jobs == 1``      -> built-in ``map`` (result materialized to a list)
    * ``n_jobs`` int > 1   -> multiprocessing pool map with that many workers
    * pool object          -> that object's ``map`` method

    Fix over the original: the pool used to be created once in ``mapper``
    and then terminated inside the returned function's ``finally`` block,
    which made the returned mapper unusable after its first call. A fresh
    pool is now created per invocation and cleaned up via the context
    manager, so the mapper is safely reusable.
    """
    if n_jobs == 1:
        def _mapper(*args, **kwargs):
            return list(map(*args, **kwargs))

        return _mapper
    if isinstance(n_jobs, int):
        def _mapper(*args, **kwargs):
            # New pool per call; the context manager terminates the
            # workers even if pool.map raises.
            with Pool(n_jobs) as pool:
                return pool.map(*args, **kwargs)

        return _mapper
    return n_jobs.map