Irwiny123 committed
Commit 52007f8 · 1 Parent(s): d0c4cc5

Add initial PepGLAD code

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .github/workflows/stale.yml +27 -0
  2. .gitignore +35 -0
  3. .idea/.gitignore +3 -0
  4. .idea/PepGLAD.iml +8 -0
  5. .idea/inspectionProfiles/profiles_settings.xml +6 -0
  6. .idea/misc.xml +4 -0
  7. .idea/modules.xml +8 -0
  8. .idea/vcs.xml +6 -0
  9. LICENSE +21 -0
  10. README.md +214 -3
  11. api/detect_pocket.py +72 -0
  12. api/run.py +274 -0
  13. assets/1ssc_A_pocket.json +1 -0
  14. cal_metrics.py +228 -0
  15. configs/pepbdb/autoencoder/train_codesign.yaml +66 -0
  16. configs/pepbdb/autoencoder/train_fixseq.yaml +63 -0
  17. configs/pepbdb/ldm/setup_latent_guidance.yaml +12 -0
  18. configs/pepbdb/ldm/train_codesign.yaml +61 -0
  19. configs/pepbdb/ldm/train_fixseq.yaml +63 -0
  20. configs/pepbdb/test_codesign.yaml +18 -0
  21. configs/pepbdb/test_fixseq.yaml +19 -0
  22. configs/pepbench/autoencoder/train_codesign.yaml +66 -0
  23. configs/pepbench/autoencoder/train_fixseq.yaml +62 -0
  24. configs/pepbench/ldm/setup_latent_guidance.yaml +12 -0
  25. configs/pepbench/ldm/train_codesign.yaml +60 -0
  26. configs/pepbench/ldm/train_fixseq.yaml +61 -0
  27. configs/pepbench/test_codesign.yaml +17 -0
  28. configs/pepbench/test_fixseq.yaml +18 -0
  29. data/__init__.py +53 -0
  30. data/codesign.py +208 -0
  31. data/converter/blocks_interface.py +89 -0
  32. data/converter/blocks_to_data.py +110 -0
  33. data/converter/list_blocks_to_pdb.py +61 -0
  34. data/converter/pdb_to_list_blocks.py +99 -0
  35. data/dataset_wrapper.py +115 -0
  36. data/format.py +220 -0
  37. data/mmap_dataset.py +112 -0
  38. data/resample.py +19 -0
  39. env.yaml +32 -0
  40. evaluation/__init__.py +3 -0
  41. evaluation/dG/RosettaFastRelaxUtil.xml +190 -0
  42. evaluation/dG/base.py +148 -0
  43. evaluation/dG/energy.py +236 -0
  44. evaluation/dG/openmm_relaxer.py +107 -0
  45. evaluation/dG/run.py +92 -0
  46. evaluation/diversity.py +68 -0
  47. evaluation/dockq.py +15 -0
  48. evaluation/rmsd.py +11 -0
  49. evaluation/seq_metric.py +71 -0
  50. generate.py +235 -0
.github/workflows/stale.yml ADDED
@@ -0,0 +1,27 @@
+ # This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time.
+ #
+ # You can adjust the behavior by modifying this file.
+ # For more information, see:
+ # https://github.com/actions/stale
+ name: Close inactive issues
+ on:
+   schedule:
+     - cron: "30 1 * * *"
+
+ jobs:
+   close-issues:
+     runs-on: ubuntu-latest
+     permissions:
+       issues: write
+       pull-requests: write
+     steps:
+       - uses: actions/stale@v9
+         with:
+           days-before-issue-stale: 30
+           days-before-issue-close: 14
+           stale-issue-label: "stale"
+           stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
+           close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
+           days-before-pr-stale: -1
+           days-before-pr-close: -1
+           repo-token: ${{ secrets.GITHUB_TOKEN }}
.gitignore ADDED
@@ -0,0 +1,35 @@
+ __pycache__
+
+ __cache__
+
+ __tmcache__
+
+ ckpts
+
+ checkpoints
+
+ *_results*
+
+ datasets
+
+ exps
+
+ DockQ
+
+ TMscore
+
+ *.txt
+
+ *.pt
+
+ *.png
+
+ *.pkl
+
+ *.svg
+
+ *.log
+
+ *.pdb
+
+ *.jsonl
.idea/.gitignore ADDED
@@ -0,0 +1,3 @@
+ # Files ignored by default
+ /shelf/
+ /workspace.xml
.idea/PepGLAD.iml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <module type="PYTHON_MODULE" version="4">
+   <component name="NewModuleRootManager">
+     <content url="file://$MODULE_DIR$" />
+     <orderEntry type="jdk" jdkName="D:\Miniconda3" jdkType="Python SDK" />
+     <orderEntry type="sourceFolder" forTests="false" />
+   </component>
+ </module>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
+ <component name="InspectionProjectProfileManager">
+   <settings>
+     <option name="USE_PROJECT_PROFILE" value="false" />
+     <version value="1.0" />
+   </settings>
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,4 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ProjectRootManager" version="2" project-jdk-name="D:\Miniconda3" project-jdk-type="Python SDK" />
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ProjectModuleManager">
+     <modules>
+       <module fileurl="file://$PROJECT_DIR$/.idea/PepGLAD.iml" filepath="$PROJECT_DIR$/.idea/PepGLAD.iml" />
+     </modules>
+   </component>
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="VcsDirectoryMappings">
+     <mapping directory="" vcs="Git" />
+   </component>
+ </project>
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 THUNLP
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,214 @@
- ---
- license: mit
- ---
+ # PepGLAD: Full-Atom Peptide Design with Geometric Latent Diffusion
+
+ ![cover](./assets/cover.png)
+
+ ## Quick Links
+
+ - [Setup](#setup)
+   - [Environment](#environment)
+   - [Datasets](#optional-datasets)
+   - [Trained Weights](#trained-weights)
+ - [Usage](#usage)
+   - [Peptide Sequence-Structure Co-Design](#peptide-sequence-structure-co-design)
+   - [Peptide Binding Conformation Generation](#peptide-binding-conformation-generation)
+ - [Reproduction of Paper Experiments](#reproduction-of-paper-experiments)
+   - [Codesign](#codesign)
+   - [Binding Conformation Generation](#binding-conformation-generation)
+ - [Contact](#contact)
+ - [Reference](#reference)
+
+ ## Updates
+
+ Changes for compatibility and extended functionality are kept in the [beta](https://github.com/THUNLP-MT/PepGLAD/tree/beta) branch. Thanks to [@Barry0121](https://github.com/Barry0121) for the help.
+
+ - PyTorch 2.6.0 and OpenMM 8.2.0 are supported, with a new environment configuration at [2025_env.yaml](https://github.com/THUNLP-MT/PepGLAD/blob/beta/2025_env.yml).
+ - Support for non-canonical amino acids in `detect_pocket.py`.
+
+
+ ## Setup
+
+ ### Environment
+
+ The conda environment can be constructed from the configuration `env.yaml`:
+
+ ```bash
+ conda env create -f env.yaml
+ ```
+
+ The code is tested with CUDA `11.7` and PyTorch `1.13.1`.
+
+ Don't forget to activate the environment before running the code:
+
+ ```bash
+ conda activate PepGLAD
+ ```
+
+ #### (Optional) PyRosetta
+
+ PyRosetta is used to calculate the interface energy of generated peptides. If you need this metric, please follow the instructions [here](https://www.pyrosetta.org/downloads) to install it.
+
+ ### (Optional) Datasets
+
+ These datasets are only used for benchmarking models. If you just want to run inference on your own cases with the trained weights, there is no need to download them.
+
+ #### PepBench
+
+ 1. Download
+
+ The datasets, originally introduced in this paper, are uploaded to Zenodo at [this URL](https://zenodo.org/records/13373108). You can download them as follows:
+
+ ```bash
+ mkdir datasets # all datasets will be put into this directory
+ wget https://zenodo.org/records/13373108/files/train_valid.tar.gz?download=1 -O ./datasets/train_valid.tar.gz # training/validation
+ wget https://zenodo.org/records/13373108/files/LNR.tar.gz?download=1 -O ./datasets/LNR.tar.gz # test set
+ wget https://zenodo.org/records/13373108/files/ProtFrag.tar.gz?download=1 -O ./datasets/ProtFrag.tar.gz # augmentation dataset
+ ```
+
+ 2. Decompress
+
+ ```bash
+ tar zxvf ./datasets/train_valid.tar.gz -C ./datasets
+ tar zxvf ./datasets/LNR.tar.gz -C ./datasets
+ tar zxvf ./datasets/ProtFrag.tar.gz -C ./datasets
+ ```
+
+ 3. Process
+
+ ```bash
+ python -m scripts.data_process.process --index ./datasets/train_valid/all.txt --out_dir ./datasets/train_valid/processed # train/validation set
+ python -m scripts.data_process.process --index ./datasets/LNR/test.txt --out_dir ./datasets/LNR/processed # test set
+ python -m scripts.data_process.process --index ./datasets/ProtFrag/all.txt --out_dir ./datasets/ProtFrag/processed # augmentation dataset
+ ```
+
+ The index files of the processed data for the train/validation splits need to be generated as follows, which will produce `datasets/train_valid/processed/train_index.txt` and `datasets/train_valid/processed/valid_index.txt`:
+
+ ```bash
+ python -m scripts.data_process.split --train_index datasets/train_valid/train.txt --valid_index datasets/train_valid/valid.txt --processed_dir datasets/train_valid/processed/
+ ```
+
+ #### PepBDB
+
+ 1. Download
+
+ ```bash
+ wget http://huanglab.phys.hust.edu.cn/pepbdb/db/download/pepbdb-20200318.tgz -O ./datasets/pepbdb.tgz
+ ```
+
+ 2. Decompress
+
+ ```bash
+ mkdir -p ./datasets/pepbdb
+ tar zxvf ./datasets/pepbdb.tgz -C ./datasets/pepbdb
+ ```
+
+ 3. Process
+
+ ```bash
+ python -m scripts.data_process.pepbdb --index ./datasets/pepbdb/peptidelist.txt --out_dir ./datasets/pepbdb/processed
+ python -m scripts.data_process.split --train_index ./datasets/pepbdb/train.txt --valid_index ./datasets/pepbdb/valid.txt --test_index ./datasets/pepbdb/test.txt --processed_dir datasets/pepbdb/processed/
+ mv ./datasets/pepbdb/processed/pdbs ./datasets/pepbdb # re-locate
+ ```
+
+ ### Trained Weights
+
+ - codesign: `./checkpoints/codesign.ckpt`
+ - conformation generation: `./checkpoints/fixseq.ckpt`
+
+ Both can be downloaded from the [release page](https://github.com/THUNLP-MT/PepGLAD/releases/tag/v1.0). These checkpoints were trained on PepBench.
+
+ ## Usage
+
+ :warning: Before running the following commands, please first download the trained weights mentioned above.
+
+ ### Peptide Sequence-Structure Co-Design
+
+ Take `./assets/1ssc_A_B.pdb` as an example, where chain A is the target protein:
+
+ ```bash
+ # obtain the binding site, which might also be manually crafted or derived from other ligands (e.g. small molecules, antibodies)
+ python -m api.detect_pocket --pdb assets/1ssc_A_B.pdb --target_chains A --ligand_chains B --out assets/1ssc_A_pocket.json
+ # sequence-structure codesign with length in [8, 15)
+ CUDA_VISIBLE_DEVICES=0 python -m api.run \
+     --mode codesign \
+     --pdb assets/1ssc_A_B.pdb \
+     --pocket assets/1ssc_A_pocket.json \
+     --out_dir ./output/codesign \
+     --length_min 8 \
+     --length_max 15 \
+     --n_samples 10
+ ```
+
+ Then 10 generated candidates will be output under the folder `./output/codesign`.
+
+ ### Peptide Binding Conformation Generation
+
+ Take `./assets/1ssc_A_B.pdb` as an example, where chain A is the target protein:
+
+ ```bash
+ # obtain the binding site, which might also be manually crafted or derived from other ligands (e.g. small molecules, antibodies)
+ python -m api.detect_pocket --pdb assets/1ssc_A_B.pdb --target_chains A --ligand_chains B --out assets/1ssc_A_pocket.json
+ # generate binding conformation
+ CUDA_VISIBLE_DEVICES=0 python -m api.run \
+     --mode struct_pred \
+     --pdb assets/1ssc_A_B.pdb \
+     --pocket assets/1ssc_A_pocket.json \
+     --out_dir ./output/struct_pred \
+     --peptide_seq PYVPVHFDASV \
+     --n_samples 10
+ ```
+
+ Then 10 conformations will be output under the folder `./output/struct_pred`.
+
+
+ ## Reproduction of Paper Experiments
+
+ Each task requires the following steps, which we have integrated into the script `./scripts/run_exp_pipe.sh`:
+
+ 1. Train the autoencoder
+ 2. Train the latent diffusion model
+ 3. Calculate the distribution of latent distances between consecutive residues
+ 4. Generation & evaluation
+
+ Alternatively, if you want to evaluate existing checkpoints, please follow the instructions below (e.g. for conformation generation):
+
+ ```bash
+ # generate results on the test set and save to ./results/fixseq
+ python generate.py --config configs/pepbench/test_fixseq.yaml --ckpt checkpoints/fixseq.ckpt --gpu 0 --save_dir ./results/fixseq
+ # calculate metrics
+ python cal_metrics.py --results ./results/fixseq/results.jsonl
+ ```
+
+ ### Codesign
+
+ Codesign experiments on PepBench:
+
+ ```bash
+ GPU=0 bash scripts/run_exp_pipe.sh pepbench_codesign configs/pepbench/autoencoder/train_codesign.yaml configs/pepbench/ldm/train_codesign.yaml configs/pepbench/ldm/setup_latent_guidance.yaml configs/pepbench/test_codesign.yaml
+ ```
+
+ ### Binding Conformation Generation
+
+ Conformation generation experiments on PepBench:
+
+ ```bash
+ GPU=0 bash scripts/run_exp_pipe.sh pepbench_fixseq configs/pepbench/autoencoder/train_fixseq.yaml configs/pepbench/ldm/train_fixseq.yaml configs/pepbench/ldm/setup_latent_guidance.yaml configs/pepbench/test_fixseq.yaml
+ ```
+
+ ## Contact
+
+ Thank you for your interest in our work!
+
+ Please feel free to ask questions about the algorithms and code, or report problems encountered in running them, so that we can make the project clearer and better. You can either create an issue in the GitHub repo or contact us at jackie_kxz@outlook.com.
+
+ ## Reference
+
+ ```bibtex
+ @article{kong2025full,
+   title={Full-atom peptide design with geometric latent diffusion},
+   author={Kong, Xiangzhe and Jia, Yinjun and Huang, Wenbing and Liu, Yang},
+   journal={Advances in Neural Information Processing Systems},
+   volume={37},
+   pages={74808--74839},
+   year={2025}
+ }
+ ```
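
Note that the released checkpoints are full pickled models rather than bare state dicts: `api/run.py` below restores them with a single `torch.load` call. A minimal sketch of loading one for inspection (paths as listed under Trained Weights; on PyTorch >= 2.6 you may additionally need `weights_only=False`):

```python
import torch

# The checkpoints ship as pickled nn.Module objects (see api/run.py below),
# so torch.load returns the model directly rather than a state dict.
model = torch.load('./checkpoints/codesign.ckpt', map_location='cpu')
model.eval()
print(type(model))  # expected to be the repo's LDMPepDesign class
```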
api/detect_pocket.py ADDED
@@ -0,0 +1,72 @@
+ #!/usr/bin/python
+ # -*- coding:utf-8 -*-
+ import argparse
+ import numpy as np
+
+ from data.converter.pdb_to_list_blocks import pdb_to_list_blocks
+ from data.converter.blocks_interface import blocks_cb_interface, dist_matrix_from_blocks
+
+
+ def get_interface(pdb, receptor_chains, ligand_chains, pocket_th=10.0):  # CB distance
+     list_blocks, chain_ids = pdb_to_list_blocks(pdb, receptor_chains + ligand_chains, return_chain_ids=True)
+     chain2blocks = {chain: block for chain, block in zip(chain_ids, list_blocks)}
+     for c in receptor_chains:
+         assert c in chain2blocks, f'Chain {c} not found for receptor'
+     for c in ligand_chains:
+         assert c in chain2blocks, f'Chain {c} not found for ligand'
+
+     rec_blocks, rec_block_chains, lig_blocks, lig_block_chains = [], [], [], []
+     for c in receptor_chains:
+         for block in chain2blocks[c]:
+             rec_blocks.append(block)
+             rec_block_chains.append(c)
+     for c in ligand_chains:
+         for block in chain2blocks[c]:
+             lig_blocks.append(block)
+             lig_block_chains.append(c)
+
+     _, (pocket_idx, lig_if_idx) = blocks_cb_interface(rec_blocks, lig_blocks, pocket_th)  # 10A for pocket size based on CB
+     epitope = []
+     for i in pocket_idx:
+         epitope.append((rec_blocks[i], rec_block_chains[i], i))
+
+     dist_mat = dist_matrix_from_blocks([rec_blocks[i] for i in pocket_idx], [lig_blocks[i] for i in lig_if_idx])
+     min_dists = np.min(dist_mat, axis=-1)  # [Nrec]
+     lig_idxs = np.argmin(dist_mat, axis=-1)  # [Nrec]
+     dists = []
+     for i, d in zip(lig_idxs, min_dists):
+         i = lig_if_idx[i]
+         dists.append((lig_blocks[i], lig_block_chains[i], i, d))
+
+     return epitope, dists
+
+
+ if __name__ == '__main__':
+     import json
+     parser = argparse.ArgumentParser(description='get interface')
+     parser.add_argument('--pdb', type=str, required=True, help='Path to the complex pdb')
+     parser.add_argument('--target_chains', type=str, nargs='+', required=True, help='Specify target chain ids')
+     parser.add_argument('--ligand_chains', type=str, nargs='+', required=True, help='Specify ligand chain ids')
+     parser.add_argument('--pocket_th', type=float, default=10.0, help='CB distance threshold for defining the binding site')
+     parser.add_argument('--out', type=str, default=None, help='Save epitope information to json file if specified')
+     args = parser.parse_args()
+     epitope, dists = get_interface(args.pdb, args.target_chains, args.ligand_chains, args.pocket_th)
+     para_res = {}
+     for _, chain_name, i, d in dists:
+         key = f'{chain_name}-{i}'
+         para_res[key] = 1
+     print(f'REMARK: {len(epitope)} residues in the binding site on the target protein, with {len(para_res)} residues in ligand:')
+     print(f' \tchain\tresidue id\ttype\tchain\tresidue id\ttype\tdistance')
+     for i, (e, p) in enumerate(zip(epitope, dists)):
+         e_res, e_chain_name, _ = e
+         p_res, p_chain_name, _, d = p
+         print(f'{i+1}\t{e_chain_name}\t{e_res.id}\t{e_res.abrv}\t' + \
+               f'{p_chain_name}\t{p_res.id}\t{p_res.abrv}\t{round(d, 3)}')
+
+     if args.out:
+         data = []
+         for e in epitope:
+             res, chain_name, _ = e
+             data.append((chain_name, res.id))
+         with open(args.out, 'w') as fout:
+             json.dump(data, fout)
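
`blocks_cb_interface` (from `data/converter/blocks_interface.py`, listed above but not shown in this view) is what actually selects the pocket; conceptually it thresholds pairwise Cβ distances at `pocket_th` (10 Å by default). A rough NumPy sketch of that idea, with random coordinates standing in for the parsed blocks:

```python
import numpy as np

# Hypothetical CB coordinates standing in for the parsed residue blocks:
# shapes [Nrec, 3] for the receptor and [Nlig, 3] for the ligand.
rec_cb = np.random.rand(100, 3) * 30.0
lig_cb = np.random.rand(12, 3) * 30.0

# Pairwise CB distance matrix of shape [Nrec, Nlig].
dist = np.linalg.norm(rec_cb[:, None, :] - lig_cb[None, :, :], axis=-1)

pocket_th = 10.0  # Angstroms, matching the default --pocket_th above
contact = dist < pocket_th
pocket_idx = np.nonzero(contact.any(axis=1))[0]   # receptor residues in the pocket
lig_if_idx = np.nonzero(contact.any(axis=0))[0]   # ligand residues at the interface
```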
api/run.py ADDED
@@ -0,0 +1,274 @@
+ #!/usr/bin/python
+ # -*- coding:utf-8 -*-
+ import os
+ import sys
+ import json
+ import argparse
+ from tqdm import tqdm
+ from os.path import splitext, basename
+
+ import ray
+ import numpy as np
+ import torch
+ from torch.utils.data import DataLoader
+
+ from data.format import Atom, Block, VOCAB
+ from data.converter.pdb_to_list_blocks import pdb_to_list_blocks
+ from data.converter.list_blocks_to_pdb import list_blocks_to_pdb
+ from data.codesign import calculate_covariance_matrix
+ from utils.const import sidechain_atoms
+ from utils.logger import print_log
+ from evaluation.dG.openmm_relaxer import ForceFieldMinimizer
+
+
+ class DesignDataset(torch.utils.data.Dataset):
+
+     MAX_N_ATOM = 14
+
+     def __init__(self, pdbs, epitopes, lengths_range=None, seqs=None) -> None:
+         super().__init__()
+         self.pdbs = pdbs
+         self.epitopes = epitopes
+         self.lengths_range = lengths_range
+         self.seqs = seqs
+         # structure prediction or codesign: exactly one of seqs / lengths_range must be given
+         assert (self.seqs is not None and self.lengths_range is None) or \
+                (self.seqs is None and self.lengths_range is not None)
+
+     def get_epitope(self, idx):
+         pdb, epitope_def = self.pdbs[idx], self.epitopes[idx]
+
+         with open(epitope_def, 'r') as fin:
+             epitope = json.load(fin)
+         to_str = lambda pos: f'{pos[0]}-{pos[1]}'
+         epi_map = {}
+         for chain_name, pos in epitope:
+             if chain_name not in epi_map:
+                 epi_map[chain_name] = {}
+             epi_map[chain_name][to_str(pos)] = True
+         residues, position_ids = [], []
+         chain2blocks = pdb_to_list_blocks(pdb, list(epi_map.keys()), dict_form=True)
+         if len(chain2blocks) != len(epi_map):
+             print_log(f'Some chains in the epitope are missing. Parsed {list(chain2blocks.keys())}, given {list(epi_map.keys())}.', level='WARN')
+         for chain_name in chain2blocks:
+             chain = chain2blocks[chain_name]
+             for i, block in enumerate(chain):  # residue
+                 if to_str(block.id) in epi_map[chain_name]:
+                     residues.append(block)
+                     position_ids.append(i + 1)  # position ids start from 1
+         return residues, position_ids, chain2blocks
+
+     def generate_pep_chain(self, idx):
+         if self.lengths_range is not None:  # codesign
+             lmin, lmax = self.lengths_range[idx]
+             length = np.random.randint(lmin, lmax)
+             unk_block = Block(VOCAB.symbol_to_abrv(VOCAB.UNK), [Atom('CA', [0, 0, 0], 'C')])
+             return [unk_block] * length
+         else:
+             seq = self.seqs[idx]
+             blocks = []
+             for s in seq:
+                 atoms = []
+                 for atom_name in VOCAB.backbone_atoms + sidechain_atoms.get(s, []):
+                     atoms.append(Atom(atom_name, [0, 0, 0], atom_name[0]))
+                 blocks.append(Block(VOCAB.symbol_to_abrv(s), atoms))
+             return blocks
+
+     def __len__(self):
+         return len(self.pdbs)
+
+     def __getitem__(self, idx: int):
+         rec_blocks, rec_position_ids, rec_chain2blocks = self.get_epitope(idx)
+         lig_blocks = self.generate_pep_chain(idx)
+
+         mask = [0 for _ in rec_blocks] + [1 for _ in lig_blocks]
+         position_ids = rec_position_ids + [i + 1 for i, _ in enumerate(lig_blocks)]
+         X, S, atom_mask = [], [], []
+         for block in rec_blocks + lig_blocks:
+             symbol = VOCAB.abrv_to_symbol(block.abrv)
+             atom2coord = { unit.name: unit.get_coord() for unit in block.units }
+             bb_pos = np.mean(list(atom2coord.values()), axis=0).tolist()
+             coords, coord_mask = [], []
+             for atom_name in VOCAB.backbone_atoms + sidechain_atoms.get(symbol, []):
+                 if atom_name in atom2coord:
+                     coords.append(atom2coord[atom_name])
+                     coord_mask.append(1)
+                 else:
+                     coords.append(bb_pos)
+                     coord_mask.append(0)
+             n_pad = self.MAX_N_ATOM - len(coords)
+             for _ in range(n_pad):
+                 coords.append(bb_pos)
+                 coord_mask.append(0)
+
+             X.append(coords)
+             S.append(VOCAB.symbol_to_idx(symbol))
+             atom_mask.append(coord_mask)
+
+         X, atom_mask = torch.tensor(X, dtype=torch.float), torch.tensor(atom_mask, dtype=torch.bool)
+         mask = torch.tensor(mask, dtype=torch.bool)
+         cov = calculate_covariance_matrix(X[~mask][:, 1][atom_mask[~mask][:, 1]].numpy())  # only use the receptor to derive the affine transformation
+         eps = 1e-4
+         cov = cov + eps * np.identity(cov.shape[0])
+         L = torch.from_numpy(np.linalg.cholesky(cov)).float().unsqueeze(0)
+
+         return {
+             'X': X,  # [N, 14] or [N, 4] if backbone_only == True
+             'S': torch.tensor(S, dtype=torch.long),  # [N]
+             'position_ids': torch.tensor(position_ids, dtype=torch.long),  # [N]
+             'mask': mask,  # [N], 1 for generation
+             'atom_mask': atom_mask,  # [N, 14] or [N, 4], 1 for having records in the PDB
+             'lengths': len(S),
+             'rec_chain2blocks': rec_chain2blocks,
+             'L': L
+         }
+
+     def collate_fn(self, batch):
+         results = {}
+         for key in batch[0]:
+             values = [item[key] for item in batch]
+             if key == 'lengths':
+                 results[key] = torch.tensor(values, dtype=torch.long)
+             elif key == 'rec_chain2blocks':
+                 results[key] = values
+             else:
+                 results[key] = torch.cat(values, dim=0)
+         return results
+
+
+ @ray.remote(num_cpus=1, num_gpus=1/16)
+ def openmm_relax(pdb_path):
+     force_field = ForceFieldMinimizer()
+     force_field(pdb_path, pdb_path)
+     return pdb_path
+
+
+ def design(mode, ckpt, gpu, pdbs, epitope_defs, n_samples, out_dir,
+            lengths_range=None, seqs=None, identifiers=None, batch_size=8, num_workers=4):
+
+     # create out dir
+     if not os.path.exists(out_dir):
+         os.makedirs(out_dir)
+     result_summary = open(os.path.join(out_dir, 'summary.jsonl'), 'w')
+     if identifiers is None:
+         identifiers = [splitext(basename(pdb))[0] for pdb in pdbs]
+     # load model
+     device = torch.device('cpu' if gpu == -1 else f'cuda:{gpu}')
+     model = torch.load(ckpt, map_location='cpu')
+     model.to(device)
+     model.eval()
+
+     # generate dataset
+     # expand data
+     if lengths_range is None: lengths_range = [None for _ in pdbs]
+     if seqs is None: seqs = [None for _ in pdbs]
+     expand_pdbs, expand_epitopes, expand_lens, expand_ids, expand_seqs = [], [], [], [], []
+     for _id, pdb, epitope, l, s, n in zip(identifiers, pdbs, epitope_defs, lengths_range, seqs, n_samples):
+         expand_ids.extend([f'{_id}_{i}' for i in range(n)])
+         expand_pdbs.extend([pdb for _ in range(n)])
+         expand_epitopes.extend([epitope for _ in range(n)])
+         expand_lens.extend([l for _ in range(n)])
+         expand_seqs.extend([s for _ in range(n)])
+     # create dataset
+     if expand_lens[0] is None: expand_lens = None
+     if expand_seqs[0] is None: expand_seqs = None
+     dataset = DesignDataset(expand_pdbs, expand_epitopes, expand_lens, expand_seqs)
+     dataloader = DataLoader(dataset, batch_size=batch_size,
+                             num_workers=num_workers,
+                             collate_fn=dataset.collate_fn,
+                             shuffle=False)
+
+     # generate peptides
+     cnt = 0
+     all_pdbs = []
+     for batch in tqdm(dataloader):
+         with torch.no_grad():
+             # move data
+             for k in batch:
+                 if hasattr(batch[k], 'to'):
+                     batch[k] = batch[k].to(device)
+             # generate
+             batch_X, batch_S, batch_pmetric = model.sample(
+                 batch['X'], batch['S'],
+                 batch['mask'], batch['position_ids'],
+                 batch['lengths'], batch['atom_mask'],
+                 L=batch['L'], sample_opt={
+                     'energy_func': 'default',
+                     'energy_lambda': 0.5 if mode == 'struct_pred' else 0.8
+                 }
+             )
+             # save data
+             for X, S, pmetric, rec_chain2blocks in zip(batch_X, batch_S, batch_pmetric, batch['rec_chain2blocks']):
+                 if S is None: S = expand_seqs[cnt]  # structure prediction
+                 lig_blocks = []
+                 for x, s in zip(X, S):
+                     abrv = VOCAB.symbol_to_abrv(s)
+                     atoms = VOCAB.backbone_atoms + sidechain_atoms[VOCAB.abrv_to_symbol(abrv)]
+                     units = [
+                         Atom(atom_name, coord, atom_name[0]) for atom_name, coord in zip(atoms, x)
+                     ]
+                     lig_blocks.append(Block(abrv, units))
+                 list_blocks, chain_names = [], []
+                 for chain in rec_chain2blocks:
+                     list_blocks.append(rec_chain2blocks[chain])
+                     chain_names.append(chain)
+                 pep_chain_id = chr(max([ord(c) for c in chain_names]) + 1)
+                 list_blocks.append(lig_blocks)
+                 chain_names.append(pep_chain_id)
+                 out_pdb = os.path.join(out_dir, expand_ids[cnt] + '.pdb')
+                 list_blocks_to_pdb(list_blocks, chain_names, out_pdb)
+                 all_pdbs.append(out_pdb)
+                 result_summary.write(json.dumps({
+                     'id': expand_ids[cnt],
+                     'rec_chains': list(rec_chain2blocks.keys()),
+                     'pep_chain': pep_chain_id,
+                     'pep_seq': ''.join([VOCAB.abrv_to_symbol(block.abrv) for block in lig_blocks])
+                 }) + '\n')
+                 result_summary.flush()
+                 cnt += 1
+     result_summary.close()
+
+     print_log(f'Running openmm relaxation...')
+     ray.init(num_cpus=8)
+     futures = [openmm_relax.remote(path) for path in all_pdbs]
+     pbar = tqdm(total=len(futures))
+     while len(futures) > 0:
+         done_ids, futures = ray.wait(futures, num_returns=1)
+         for done_id in done_ids:
+             done_path = ray.get(done_id)
+             pbar.update(1)
+     print_log(f'Done')
+
+
+ def parse():
+     parser = argparse.ArgumentParser(description='run pepglad for codesign or structure prediction')
+     parser.add_argument('--mode', type=str, required=True, choices=['codesign', 'struct_pred'], help='Running mode')
+     parser.add_argument('--pdb', type=str, required=True, help='Path to the PDB file of the target protein')
+     parser.add_argument('--pocket', type=str, required=True, help='Path to the pocket definition (*.json generated by detect_pocket)')
+     parser.add_argument('--n_samples', type=int, default=10, help='Number of samples')
+     parser.add_argument('--out_dir', type=str, required=True, help='Output directory')
+     parser.add_argument('--peptide_seq', type=str, required='struct_pred' in sys.argv, help='Peptide sequence for structure prediction')
+     parser.add_argument('--length_min', type=int, required='codesign' in sys.argv, help='Minimum peptide length for codesign (inclusive)')
+     parser.add_argument('--length_max', type=int, required='codesign' in sys.argv, help='Maximum peptide length for codesign (exclusive)')
+     parser.add_argument('--gpu', type=int, default=0, help='GPU to use')
+     return parser.parse_args()
+
+
+ if __name__ == '__main__':
+     args = parse()
+     proj_dir = os.path.join(os.path.dirname(__file__), '..')
+     ckpt = os.path.join(proj_dir, 'checkpoints', 'fixseq.ckpt' if args.mode == 'struct_pred' else 'codesign.ckpt')
+     print_log(f'Loading checkpoint: {ckpt}')
+     design(
+         mode=args.mode,
+         ckpt=ckpt,  # path to the checkpoint of the trained model
+         gpu=args.gpu,  # the ID of the GPU to use
+         pdbs=[args.pdb],  # paths to the PDB file of each antigen
+         epitope_defs=[args.pocket],  # paths to the epitope (pocket) definitions
+         n_samples=[args.n_samples],  # number of samples for each epitope
+         out_dir=args.out_dir,  # output directory
+         identifiers=[os.path.basename(os.path.splitext(args.pdb)[0])],  # file name (name of each output candidate)
+         lengths_range=[(args.length_min, args.length_max)] if args.mode == 'codesign' else None,  # range of acceptable peptide lengths, left inclusive, right exclusive
+         seqs=[args.peptide_seq] if args.mode == 'struct_pred' else None  # peptide sequences for structure prediction
+     )
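
`design` can also be imported and called directly instead of going through the CLI; the `__main__` block above is only a thin wrapper. A sketch mirroring the README's co-design example (all per-target arguments are lists, as documented in the inline comments above):

```python
from api.run import design

# Mirrors the CLI call in the README; one target, ten samples.
design(
    mode='codesign',
    ckpt='./checkpoints/codesign.ckpt',
    gpu=0,
    pdbs=['assets/1ssc_A_B.pdb'],
    epitope_defs=['assets/1ssc_A_pocket.json'],
    n_samples=[10],
    out_dir='./output/codesign',
    lengths_range=[(8, 15)],  # left inclusive, right exclusive
)
```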
assets/1ssc_A_pocket.json ADDED
@@ -0,0 +1 @@
+ [["A", [3, " "]], ["A", [4, " "]], ["A", [5, " "]], ["A", [6, " "]], ["A", [7, " "]], ["A", [8, " "]], ["A", [9, " "]], ["A", [11, " "]], ["A", [12, " "]], ["A", [13, " "]], ["A", [43, " "]], ["A", [44, " "]], ["A", [45, " "]], ["A", [46, " "]], ["A", [47, " "]], ["A", [51, " "]], ["A", [54, " "]], ["A", [55, " "]], ["A", [56, " "]], ["A", [57, " "]], ["A", [58, " "]], ["A", [59, " "]], ["A", [63, " "]], ["A", [64, " "]], ["A", [65, " "]], ["A", [66, " "]], ["A", [67, " "]], ["A", [69, " "]], ["A", [71, " "]], ["A", [72, " "]], ["A", [73, " "]], ["A", [74, " "]], ["A", [75, " "]], ["A", [78, " "]], ["A", [79, " "]], ["A", [81, " "]], ["A", [83, " "]], ["A", [102, " "]], ["A", [103, " "]], ["A", [104, " "]], ["A", [105, " "]], ["A", [106, " "]], ["A", [107, " "]], ["A", [108, " "]], ["A", [109, " "]], ["A", [110, " "]], ["A", [111, " "]], ["A", [112, " "]]]
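
The pocket file is a JSON list of `[chain_id, [residue_number, insertion_code]]` entries (a single space means no insertion code), which is exactly what `get_epitope` in `api/run.py` reads back. A small sketch of parsing it:

```python
import json

with open('assets/1ssc_A_pocket.json') as fin:
    pocket = json.load(fin)

# Each entry is [chain_id, [residue_number, insertion_code]];
# ' ' as the insertion code means none.
for chain_id, (resnum, icode) in pocket[:3]:
    print(chain_id, resnum, repr(icode))
```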
cal_metrics.py ADDED
@@ -0,0 +1,228 @@
+ #!/usr/bin/python
+ # -*- coding:utf-8 -*-
+ import argparse
+ import json
+ import os
+ import random
+ from copy import deepcopy
+ from collections import defaultdict
+ from tqdm import tqdm
+ from tqdm.contrib.concurrent import process_map
+ import statistics
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ import numpy as np
+ from scipy.stats import spearmanr
+
+ from data.converter.pdb_to_list_blocks import pdb_to_list_blocks
+ from evaluation import diversity
+ from evaluation.dockq import dockq
+ from evaluation.rmsd import compute_rmsd
+ from utils.random_seed import setup_seed
+ from evaluation.seq_metric import aar, slide_aar
+
+
+ def _get_ref_pdb(_id, root_dir):
+     return os.path.join(root_dir, 'references', f'{_id}_ref.pdb')
+
+
+ def _get_gen_pdb(_id, number, root_dir, use_rosetta):
+     suffix = '_rosetta' if use_rosetta else ''
+     return os.path.join(root_dir, 'candidates', _id, f'{_id}_gen_{number}{suffix}.pdb')
+
+
+ def cal_metrics(items):
+     # all of the items are conditioned on the same binding pocket
+     root_dir = items[0]['root_dir']
+     ref_pdb, rec_chain, lig_chain = items[0]['ref_pdb'], items[0]['rec_chain'], items[0]['lig_chain']
+     ref_pdb = _get_ref_pdb(items[0]['id'], root_dir)
+     seq_only, struct_only, backbone_only = items[0]['seq_only'], items[0]['struct_only'], items[0]['backbone_only']
+
+     # prepare
+     results = defaultdict(list)
+     cand_seqs, cand_ca_xs = [], []
+     rec_blocks, ref_pep_blocks = pdb_to_list_blocks(ref_pdb, [rec_chain, lig_chain])
+     ref_ca_x, ca_mask = [], []
+     for ref_block in ref_pep_blocks:
+         if ref_block.has_unit('CA'):
+             ca_mask.append(1)
+             ref_ca_x.append(ref_block.get_unit_by_name('CA').get_coord())
+         else:
+             ca_mask.append(0)
+             ref_ca_x.append([0, 0, 0])
+     ref_ca_x, ca_mask = np.array(ref_ca_x), np.array(ca_mask).astype(bool)
+
+     for item in items:
+         if not struct_only:
+             cand_seqs.append(item['gen_seq'])
+             results['Slide AAR'].append(slide_aar(item['gen_seq'], item['ref_seq'], aar))
+
+         # structure metrics
+         gen_pdb = _get_gen_pdb(item['id'], item['number'], root_dir, item['rosetta'])
+         _, gen_pep_blocks = pdb_to_list_blocks(gen_pdb, [rec_chain, lig_chain])
+         assert len(gen_pep_blocks) == len(ref_pep_blocks), f'{item}\t{len(ref_pep_blocks)}\t{len(gen_pep_blocks)}'
+
+         # CA RMSD
+         gen_ca_x = np.array([block.get_unit_by_name('CA').get_coord() for block in gen_pep_blocks])
+         cand_ca_xs.append(gen_ca_x)
+         rmsd = compute_rmsd(ref_ca_x[ca_mask], gen_ca_x[ca_mask], aligned=True)
+         results['RMSD(CA)'].append(rmsd)
+         if struct_only:
+             results['RMSD<=2.0'].append(1 if rmsd <= 2.0 else 0)
+             results['RMSD<=5.0'].append(1 if rmsd <= 5.0 else 0)
+             results['RMSD<=10.0'].append(1 if rmsd <= 10.0 else 0)
+
+         if backbone_only:
+             continue
+
+         # DockQ
+         dockq_score = dockq(gen_pdb, ref_pdb, lig_chain)
+         results['DockQ'].append(dockq_score)
+         if struct_only:
+             results['DockQ>=0.23'].append(1 if dockq_score >= 0.23 else 0)
+             results['DockQ>=0.49'].append(1 if dockq_score >= 0.49 else 0)
+             results['DockQ>=0.80'].append(1 if dockq_score >= 0.80 else 0)
+
+         # full-atom RMSD
+         if struct_only:
+             gen_all_x, ref_all_x = [], []
+             for gen_block, ref_block in zip(gen_pep_blocks, ref_pep_blocks):
+                 for ref_atom in ref_block:
+                     if gen_block.has_unit(ref_atom.name):
+                         ref_all_x.append(ref_atom.get_coord())
+                         gen_all_x.append(gen_block.get_unit_by_name(ref_atom.name).get_coord())
+             results['RMSD(full-atom)'].append(compute_rmsd(
+                 np.array(gen_all_x), np.array(ref_all_x), aligned=True
+             ))
+
+     pmets = [item['pmetric'] for item in items]
+     indexes = list(range(len(items)))
+     # aggregation
+     for name in results:
+         vals = results[name]
+         corr = spearmanr(vals, pmets, nan_policy='omit').statistic
+         if np.isnan(corr):
+             corr = 0
+         aggr_res = {
+             'max': max(vals),
+             'min': min(vals),
+             'mean': sum(vals) / len(vals),
+             'random': vals[0],
+             'max*': vals[(max if corr > 0 else min)(indexes, key=lambda i: pmets[i])],
+             'min*': vals[(min if corr > 0 else max)(indexes, key=lambda i: pmets[i])],
+             'pmet_corr': corr,
+             'individual': vals,
+             'individual_pmet': pmets
+         }
+         results[name] = aggr_res
+
+     if len(cand_seqs) > 1 and not seq_only:
+         seq_div, struct_div, co_div, consistency = diversity.diversity(cand_seqs, np.array(cand_ca_xs))
+         results['Sequence Diversity'] = seq_div
+         results['Struct Diversity'] = struct_div
+         results['Codesign Diversity'] = co_div
+         results['Consistency'] = consistency
+
+     return results
+
+
+ def cnt_aa_dist(seqs):
+     cnts = {}
+     for seq in seqs:
+         for aa in seq:
+             if aa not in cnts:
+                 cnts[aa] = 0
+             cnts[aa] += 1
+     aas = sorted(list(cnts.keys()), key=lambda aa: cnts[aa])
+     total = sum(cnts.values())
+     for aa in aas:
+         print(f'\t{aa}: {cnts[aa] / total}')
+
+
+ def main(args):
+     root_dir = os.path.dirname(args.results)
+     # load dG filter
+     if args.filter_dG is None:
+         filter_func = lambda _id, n: True
+     else:
+         dG_results = json.load(open(args.filter_dG, 'r'))
+         filter_func = lambda _id, n: dG_results[_id]['all'][str(n)] < 0
+     # load results
+     with open(args.results, 'r') as fin:
+         lines = fin.read().strip().split('\n')
+     id2items = {}
+     for line in lines:
+         item = json.loads(line)
+         _id = item['id']
+         if not filter_func(_id, item['number']):
+             continue
+         if _id not in id2items:
+             id2items[_id] = []
+         item['root_dir'] = root_dir
+         item['rosetta'] = args.rosetta
+         id2items[_id].append(item)
+     ids = list(id2items.keys())
+
+     if args.filter_dG is not None:
+         # drop targets with only one remaining sample, since diversity cannot be computed for them
+         del_ids = [_id for _id in ids if len(id2items[_id]) < 2]
+         for _id in del_ids:
+             print(f'Deleting {_id} since only one of its samples passed the filter')
+             del id2items[_id]
+         ids = list(id2items.keys())  # refresh so indices stay aligned with the metrics below
+
+     if args.num_workers > 1:
+         metrics = process_map(cal_metrics, id2items.values(), max_workers=args.num_workers, chunksize=1)
+     else:
+         metrics = [cal_metrics(inputs) for inputs in tqdm(id2items.values())]
+
+     eval_results_path = os.path.join(os.path.dirname(args.results), 'eval_report.json')
+     with open(eval_results_path, 'w') as fout:
+         for i, _id in enumerate(id2items):
+             metric = deepcopy(metrics[i])
+             metric['id'] = _id
+             fout.write(json.dumps(metric) + '\n')
+
+     # individual level results
+     print('Point-wise evaluation results:')
+     for name in metrics[0]:
+         vals = [item[name] for item in metrics]
+         if isinstance(vals[0], dict):
+             if 'RMSD' in name and '<=' not in name:
+                 aggr = 'min'
+             else:
+                 aggr = 'max'
+             aggr_vals = [val[aggr] for val in vals]
+             if '>=' in name or '<=' in name:  # percentage
+                 print(f'{name}: {sum(aggr_vals) / len(aggr_vals)}')
+             else:
+                 if 'RMSD' in name:
+                     print(f'{name}(median): {statistics.median(aggr_vals)}')  # unbounded, so extreme values affect the mean but not the median
+                 else:
+                     print(f'{name}(mean): {sum(aggr_vals) / len(aggr_vals)}')
+             lowest_i = min([i for i in range(len(aggr_vals))], key=lambda i: aggr_vals[i])
+             highest_i = max([i for i in range(len(aggr_vals))], key=lambda i: aggr_vals[i])
+             print(f'\tlowest: {aggr_vals[lowest_i]}, id: {ids[lowest_i]}', end='')
+             print(f'\thighest: {aggr_vals[highest_i]}, id: {ids[highest_i]}')
+         else:
+             print(f'{name} (mean): {sum(vals) / len(vals)}')
+             lowest_i = min([i for i in range(len(vals))], key=lambda i: vals[i])
+             highest_i = max([i for i in range(len(vals))], key=lambda i: vals[i])
+             print(f'\tlowest: {vals[lowest_i]}, id: {ids[lowest_i]}')
+             print(f'\thighest: {vals[highest_i]}, id: {ids[highest_i]}')
+
+
+ def parse():
+     parser = argparse.ArgumentParser(description='calculate metrics')
+     parser.add_argument('--results', type=str, required=True, help='Path to the results.jsonl produced by generate.py')
+     parser.add_argument('--num_workers', type=int, default=8, help='Number of workers to use')
+     parser.add_argument('--rosetta', action='store_true', help='Use the rosetta-refined structure')
+     parser.add_argument('--filter_dG', type=str, default=None, help='Only calculate results on samples with dG<0')
+
+     return parser.parse_args()
+
+
+ if __name__ == '__main__':
+     setup_seed(0)
+     main(parse())
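
`compute_rmsd` comes from `evaluation/rmsd.py`, which is not shown in this view. For reference, a standard RMSD sketch under the assumption that `aligned=True` means the coordinates are compared as-is, without superposition:

```python
import numpy as np

def rmsd_sketch(x: np.ndarray, y: np.ndarray) -> float:
    # x, y: [N, 3] coordinates of matched atoms; no superposition is
    # performed here, which is presumably what aligned=True means above.
    assert x.shape == y.shape
    return float(np.sqrt(np.mean(np.sum((x - y) ** 2, axis=-1))))
```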
configs/pepbdb/autoencoder/train_codesign.yaml ADDED
@@ -0,0 +1,66 @@
+ dataset:
+   train:
+     - class: CoDesignDataset
+       mmap_dir: ./datasets/pepbdb/processed
+       specify_index: ./datasets/pepbdb/processed/train_index.txt
+       backbone_only: false
+       cluster: ./datasets/pepbdb/train.cluster
+     - class: CoDesignDataset
+       mmap_dir: ./datasets/ProtFrag/processed
+       backbone_only: false
+   valid:
+     class: CoDesignDataset
+     mmap_dir: ./datasets/pepbdb/processed
+     specify_index: ./datasets/pepbdb/processed/valid_index.txt
+     backbone_only: false
+
+ dataloader:
+   shuffle: true
+   num_workers: 4
+   wrapper:
+     class: DynamicBatchWrapper
+     complexity: n**2
+     ubound_per_batch: 60000 # batch size ~24
+
+ trainer:
+   class: AutoEncoderTrainer
+   config:
+     max_epoch: 100
+     save_topk: 10
+     save_dir: ./ckpts/autoencoder_codesign_pepbdb
+     patience: 10
+     metric_min_better: true
+
+ optimizer:
+   class: AdamW
+   lr: 1.0e-4
+
+ scheduler:
+   class: ReduceLROnPlateau
+   factor: 0.8
+   patience: 5
+   mode: min
+   frequency: val_epoch
+   min_lr: 5.0e-6
+
+ model:
+   class: AutoEncoder
+   embed_size: 128
+   hidden_size: 128
+   latent_size: 8
+   latent_n_channel: 1
+   n_layers: 3
+   n_channel: 14 # all atom
+   h_kl_weight: 0.3
+   z_kl_weight: 0.5
+   coord_loss_ratio: 0.5
+   coord_loss_weights:
+     Xloss: 1.0
+     ca_Xloss: 1.0
+     bb_bond_lengths_loss: 1.0
+     sc_bond_lengths_loss: 1.0
+     bb_dihedral_angles_loss: 0.0
+     sc_chi_angles_loss: 0.5
+   relative_position: false
+   anchor_at_ca: true
+   mask_ratio: 0.25
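
`DynamicBatchWrapper` (defined in `data/dataset_wrapper.py`, listed above) packs variable-size complexes so that the summed per-sample cost, here `n**2`, stays below `ubound_per_batch`; that is where the "batch size ~24" estimate comes from (24 complexes of ~50 residues give 24 * 50**2 = 60000). A toy sketch of that packing rule, not the wrapper's actual interface:

```python
def pack_by_quadratic_cost(lengths, ubound=60000):
    # Greedily group sample indices so that sum(n**2) per batch <= ubound,
    # mimicking complexity: n**2 with ubound_per_batch: 60000 above.
    batches, cur, cur_cost = [], [], 0
    for i, n in enumerate(lengths):
        cost = n ** 2
        if cur and cur_cost + cost > ubound:
            batches.append(cur)
            cur, cur_cost = [], 0
        cur.append(i)
        cur_cost += cost
    if cur:
        batches.append(cur)
    return batches

print(len(pack_by_quadratic_cost([50] * 48)))  # -> 2 batches of 24 samples each
```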
configs/pepbdb/autoencoder/train_fixseq.yaml ADDED
@@ -0,0 +1,63 @@
+ dataset:
+   train:
+     class: CoDesignDataset
+     mmap_dir: ./datasets/pepbdb/processed
+     specify_index: ./datasets/pepbdb/processed/train_index.txt
+     backbone_only: false
+     cluster: ./datasets/pepbdb/train.cluster
+   valid:
+     class: CoDesignDataset
+     mmap_dir: ./datasets/pepbdb/processed
+     specify_index: ./datasets/pepbdb/processed/valid_index.txt
+     backbone_only: false
+
+ dataloader:
+   shuffle: true
+   num_workers: 4
+   wrapper:
+     class: DynamicBatchWrapper
+     complexity: n**2
+     ubound_per_batch: 60000 # batch size ~24
+
+ trainer:
+   class: AutoEncoderTrainer
+   config:
+     max_epoch: 150
+     save_topk: 10
+     save_dir: ./ckpts/autoencoder_fixseq
+     patience: 10
+     metric_min_better: true
+
+ optimizer:
+   class: AdamW
+   lr: 1.0e-4
+
+ scheduler:
+   class: ReduceLROnPlateau
+   factor: 0.8
+   patience: 15
+   mode: min
+   frequency: val_epoch
+   min_lr: 5.0e-6
+
+ model:
+   class: AutoEncoder
+   embed_size: 128
+   hidden_size: 128
+   latent_size: 0
+   latent_n_channel: 1
+   n_layers: 3
+   n_channel: 14 # all atom
+   h_kl_weight: 0.0
+   z_kl_weight: 0.6
+   coord_loss_ratio: 1.0
+   coord_loss_weights:
+     Xloss: 1.0
+     ca_Xloss: 1.0
+     bb_bond_lengths_loss: 1.0
+     sc_bond_lengths_loss: 1.0
+     bb_dihedral_angles_loss: 0.0
+     sc_chi_angles_loss: 0.5
+   anchor_at_ca: true
+   mode: fixseq
+   additional_noise_scale: 1.0
configs/pepbdb/ldm/setup_latent_guidance.yaml ADDED
@@ -0,0 +1,12 @@
+ dataset:
+   test:
+     class: CoDesignDataset
+     mmap_dir: ./datasets/pepbdb/processed
+     specify_index: ./datasets/pepbdb/processed/train_index.txt
+     backbone_only: false
+
+ dataloader:
+   num_workers: 2
+   batch_size: 32
+
+ backbone_only: false
configs/pepbdb/ldm/train_codesign.yaml ADDED
@@ -0,0 +1,61 @@
+ dataset:
+   train:
+     - class: CoDesignDataset
+       mmap_dir: ./datasets/pepbdb/processed
+       specify_index: ./datasets/pepbdb/processed/train_index.txt
+       backbone_only: false
+       cluster: ./datasets/pepbdb/train.cluster
+       use_covariance_matrix: true
+   valid:
+     class: CoDesignDataset
+     mmap_dir: ./datasets/pepbdb/processed
+     specify_index: ./datasets/pepbdb/processed/valid_index.txt
+     backbone_only: false
+     use_covariance_matrix: true
+
+ dataloader:
+   shuffle: true
+   num_workers: 4
+   wrapper:
+     class: DynamicBatchWrapper
+     complexity: n**2
+     ubound_per_batch: 60000 # batch size ~32
+
+ trainer:
+   class: LDMTrainer
+   criterion: Loss
+   config:
+     max_epoch: 500 # the best checkpoint should be obtained at around epoch 380
+     save_topk: 10
+     val_freq: 10
+     save_dir: ./ckpts/LDM_codesign
+     patience: 10
+     metric_min_better: true
+
+ optimizer:
+   class: AdamW
+   lr: 1.0e-4
+
+ scheduler:
+   class: ReduceLROnPlateau
+   factor: 0.6
+   patience: 3
+   mode: min
+   frequency: val_epoch
+   min_lr: 5.0e-6
+
+ model:
+   class: LDMPepDesign
+   autoencoder_ckpt: ""
+   autoencoder_no_randomness: true
+   hidden_size: 128
+   num_steps: 100
+   n_layers: 3
+   n_rbf: 32
+   cutoff: 3.0 # the coordinates are in standard space
+   dist_rbf: 32
+   dist_rbf_cutoff: 7.0
+   diffusion_opt:
+     trans_seq_type: Diffusion
+     trans_pos_type: Diffusion
+   max_gen_position: 60
configs/pepbdb/ldm/train_fixseq.yaml ADDED
@@ -0,0 +1,63 @@
+ dataset:
+   train:
+     - class: CoDesignDataset
+       mmap_dir: ./datasets/pepbdb/processed
+       specify_index: ./datasets/pepbdb/processed/train_index.txt
+       backbone_only: false
+       cluster: ./datasets/pepbdb/train.cluster
+       use_covariance_matrix: true
+   valid:
+     class: CoDesignDataset
+     mmap_dir: ./datasets/pepbdb/processed
+     specify_index: ./datasets/pepbdb/processed/valid_index.txt
+     backbone_only: false
+     use_covariance_matrix: true
+
+ dataloader:
+   shuffle: true
+   num_workers: 4
+   wrapper:
+     class: DynamicBatchWrapper
+     complexity: n**2
+     ubound_per_batch: 60000 # batch size ~32
+
+ trainer:
+   class: LDMTrainer
+   criterion: RMSD
+   config:
+     max_epoch: 1000 # the best checkpoint will be obtained at about epoch 900
+     save_topk: 10
+     val_freq: 10
+     save_dir: ./ckpts/LDM_fixseq
+     patience: 10
+     metric_min_better: true
+
+ optimizer:
+   class: AdamW
+   lr: 1.0e-4
+
+ scheduler:
+   class: ReduceLROnPlateau
+   factor: 0.6
+   patience: 3
+   mode: min
+   frequency: val_epoch
+   min_lr: 5.0e-6
+
+ model:
+   class: LDMPepDesign
+   autoencoder_ckpt: ""
+   autoencoder_no_randomness: true
+   hidden_size: 128
+   num_steps: 100
+   n_layers: 6
+   n_rbf: 32
+   cutoff: 3.0 # the coordinates are in standard space
+   dist_rbf: 0
+   dist_rbf_cutoff: 0.0
+   diffusion_opt:
+     trans_seq_type: Diffusion
+     trans_pos_type: Diffusion
+     std: 20.0
+   mode: fixseq
+   max_gen_position: 60
configs/pepbdb/test_codesign.yaml ADDED
@@ -0,0 +1,18 @@
+ dataset:
+   test:
+     class: CoDesignDataset
+     mmap_dir: ./datasets/pepbdb/processed
+     specify_index: ./datasets/pepbdb/processed/test_index.txt
+     backbone_only: false
+     use_covariance_matrix: true
+
+ dataloader:
+   num_workers: 4
+   batch_size: 64
+
+ backbone_only: false
+ n_samples: 40
+
+ sample_opt:
+   energy_func: default
+   energy_lambda: 0.8
configs/pepbdb/test_fixseq.yaml ADDED
@@ -0,0 +1,19 @@
+ dataset:
+   test:
+     class: CoDesignDataset
+     mmap_dir: ./datasets/pepbdb/processed
+     specify_index: ./datasets/pepbdb/processed/test_index.txt
+     backbone_only: false
+     use_covariance_matrix: true
+
+ dataloader:
+   num_workers: 4
+   batch_size: 64
+
+ backbone_only: false
+ struct_only: true
+ n_samples: 10
+
+ sample_opt:
+   energy_func: default
+   energy_lambda: 0.8
configs/pepbench/autoencoder/train_codesign.yaml ADDED
@@ -0,0 +1,66 @@
+ dataset:
+   train:
+     - class: CoDesignDataset
+       mmap_dir: ./datasets/train_valid/processed
+       specify_index: ./datasets/train_valid/processed/train_index.txt
+       backbone_only: false
+       cluster: ./datasets/train_valid/train.cluster
+     - class: CoDesignDataset
+       mmap_dir: ./datasets/ProtFrag/processed
+       backbone_only: false
+   valid:
+     class: CoDesignDataset
+     mmap_dir: ./datasets/train_valid/processed
+     specify_index: ./datasets/train_valid/processed/valid_index.txt
+     backbone_only: false
+
+ dataloader:
+   shuffle: true
+   num_workers: 4
+   wrapper:
+     class: DynamicBatchWrapper
+     complexity: n**2
+     ubound_per_batch: 60000 # batch size ~24
+
+ trainer:
+   class: AutoEncoderTrainer
+   config:
+     max_epoch: 100
+     save_topk: 10
+     save_dir: ./ckpts/autoencoder_codesign
+     patience: 10
+     metric_min_better: true
+
+ optimizer:
+   class: AdamW
+   lr: 1.0e-4
+
+ scheduler:
+   class: ReduceLROnPlateau
+   factor: 0.8
+   patience: 5
+   mode: min
+   frequency: val_epoch
+   min_lr: 5.0e-6
+
+ model:
+   class: AutoEncoder
+   embed_size: 128
+   hidden_size: 128
+   latent_size: 8
+   latent_n_channel: 1
+   n_layers: 3
+   n_channel: 14 # all atom
+   h_kl_weight: 0.3
+   z_kl_weight: 0.5
+   coord_loss_ratio: 0.5
+   coord_loss_weights:
+     Xloss: 1.0
+     ca_Xloss: 1.0
+     bb_bond_lengths_loss: 1.0
+     sc_bond_lengths_loss: 1.0
+     bb_dihedral_angles_loss: 0.0
+     sc_chi_angles_loss: 0.5
+   relative_position: false
+   anchor_at_ca: true
+   mask_ratio: 0.25
configs/pepbench/autoencoder/train_fixseq.yaml ADDED
@@ -0,0 +1,62 @@
+ dataset:
+   train:
+     class: CoDesignDataset
+     mmap_dir: ./datasets/train_valid/processed
+     specify_index: ./datasets/train_valid/processed/train_index.txt
+     backbone_only: false
+     cluster: ./datasets/train_valid/train.cluster
+   valid:
+     class: CoDesignDataset
+     mmap_dir: ./datasets/train_valid/processed
+     specify_index: ./datasets/train_valid/processed/valid_index.txt
+     backbone_only: false
+
+ dataloader:
+   shuffle: true
+   num_workers: 4
+   wrapper:
+     class: DynamicBatchWrapper
+     complexity: n**2
+     ubound_per_batch: 60000 # batch size ~24
+
+ trainer:
+   class: AutoEncoderTrainer
+   config:
+     max_epoch: 500 # the best checkpoint should be obtained at about epoch 457
+     save_topk: 10
+     save_dir: ./ckpts/autoencoder_fixseq
+     patience: 10
+     metric_min_better: true
+
+ optimizer:
+   class: AdamW
+   lr: 1.0e-4
+
+ scheduler:
+   class: ReduceLROnPlateau
+   factor: 0.8
+   patience: 15
+   mode: min
+   frequency: val_epoch
+   min_lr: 5.0e-6
+
+ model:
+   class: AutoEncoder
+   embed_size: 128
+   hidden_size: 128
+   latent_size: 0
+   latent_n_channel: 1
+   n_layers: 3
+   n_channel: 14 # all atom
+   h_kl_weight: 0.0
+   z_kl_weight: 1.0
+   coord_loss_ratio: 1.0
+   coord_loss_weights:
+     Xloss: 1.0
+     ca_Xloss: 1.0
+     bb_bond_lengths_loss: 1.0
+     sc_bond_lengths_loss: 1.0
+     bb_dihedral_angles_loss: 0.0
+     sc_chi_angles_loss: 0.5
+   anchor_at_ca: true
+   mode: fixseq
configs/pepbench/ldm/setup_latent_guidance.yaml ADDED
@@ -0,0 +1,12 @@
+ dataset:
+   test:
+     class: CoDesignDataset
+     mmap_dir: ./datasets/train_valid/processed
+     specify_index: ./datasets/train_valid/processed/train_index.txt
+     backbone_only: false
+
+ dataloader:
+   num_workers: 2
+   batch_size: 32
+
+ backbone_only: false
configs/pepbench/ldm/train_codesign.yaml ADDED
@@ -0,0 +1,60 @@
+ dataset:
+   train:
+     class: CoDesignDataset
+     mmap_dir: ./datasets/train_valid/processed
+     specify_index: ./datasets/train_valid/processed/train_index.txt
+     backbone_only: false
+     cluster: ./datasets/train_valid/train.cluster
+     use_covariance_matrix: true
+   valid:
+     class: CoDesignDataset
+     mmap_dir: ./datasets/train_valid/processed
+     specify_index: ./datasets/train_valid/processed/valid_index.txt
+     backbone_only: false
+     use_covariance_matrix: true
+
+ dataloader:
+   shuffle: true
+   num_workers: 4
+   wrapper:
+     class: DynamicBatchWrapper
+     complexity: n**2
+     ubound_per_batch: 60000 # batch size ~32
+
+ trainer:
+   class: LDMTrainer
+   criterion: Loss
+   config:
+     max_epoch: 500 # the best checkpoint should be obtained at around epoch 380
+     save_topk: 10
+     val_freq: 10
+     save_dir: ./ckpts/LDM_codesign
+     patience: 10
+     metric_min_better: true
+
+ optimizer:
+   class: AdamW
+   lr: 1.0e-4
+
+ scheduler:
+   class: ReduceLROnPlateau
+   factor: 0.6
+   patience: 3
+   mode: min
+   frequency: val_epoch
+   min_lr: 5.0e-6
+
+ model:
+   class: LDMPepDesign
+   autoencoder_ckpt: ""
+   autoencoder_no_randomness: true
+   hidden_size: 128
+   num_steps: 100
+   n_layers: 3
+   n_rbf: 32
+   cutoff: 3.0 # the coordinates are in standard space
+   dist_rbf: 32
+   dist_rbf_cutoff: 7.0
+   diffusion_opt:
+     trans_seq_type: Diffusion
+     trans_pos_type: Diffusion
configs/pepbench/ldm/train_fixseq.yaml ADDED
@@ -0,0 +1,61 @@
+ dataset:
+   train:
+     class: CoDesignDataset
+     mmap_dir: ./datasets/train_valid/processed
+     specify_index: ./datasets/train_valid/processed/train_index.txt
+     backbone_only: false
+     cluster: ./datasets/train_valid/train.cluster
+     use_covariance_matrix: true
+   valid:
+     class: CoDesignDataset
+     mmap_dir: ./datasets/train_valid/processed
+     specify_index: ./datasets/train_valid/processed/valid_index.txt
+     backbone_only: false
+     use_covariance_matrix: true
+
+ dataloader:
+   shuffle: true
+   num_workers: 4
+   wrapper:
+     class: DynamicBatchWrapper
+     complexity: n**2
+     ubound_per_batch: 60000 # batch size ~32
+
+ trainer:
+   class: LDMTrainer
+   criterion: RMSD
+   config:
+     max_epoch: 1000 # the best checkpoint will be obtained at about epoch 720
+     save_topk: 10
+     val_freq: 10
+     save_dir: ./ckpts/LDM_fixseq
+     patience: 10
+     metric_min_better: true
+
+ optimizer:
+   class: AdamW
+   lr: 1.0e-4
+
+ scheduler:
+   class: ReduceLROnPlateau
+   factor: 0.6
+   patience: 3
+   mode: min
+   frequency: val_epoch
+   min_lr: 5.0e-6
+
+ model:
+   class: LDMPepDesign
+   autoencoder_ckpt: ""
+   autoencoder_no_randomness: true
+   hidden_size: 128
+   num_steps: 100
+   n_layers: 3
+   n_rbf: 32
+   cutoff: 3.0 # the coordinates are in standard space
+   dist_rbf: 0
+   dist_rbf_cutoff: 0.0
+   diffusion_opt:
+     trans_seq_type: Diffusion
+     trans_pos_type: Diffusion
+   mode: fixseq
configs/pepbench/test_codesign.yaml ADDED
@@ -0,0 +1,17 @@
1
+ dataset:
2
+ test:
3
+ class: CoDesignDataset
4
+ mmap_dir: ./datasets/LNR/processed
5
+ backbone_only: false
6
+ use_covariance_matrix: true
7
+
8
+ dataloader:
9
+ num_workers: 4
10
+ batch_size: 64
11
+
12
+ backbone_only: false
13
+ n_samples: 40
14
+
15
+ sample_opt:
16
+ energy_func: default
17
+ energy_lambda: 0.8
configs/pepbench/test_fixseq.yaml ADDED
@@ -0,0 +1,18 @@
1
+ dataset:
2
+ test:
3
+ class: CoDesignDataset
4
+ mmap_dir: ./datasets/LNR/processed
5
+ backbone_only: false
6
+ use_covariance_matrix: true
7
+
8
+ dataloader:
9
+ num_workers: 4
10
+ batch_size: 64
11
+
12
+ backbone_only: false
13
+ struct_only: true
14
+ n_samples: 10
15
+
16
+ sample_opt:
17
+ energy_func: default
18
+ energy_lambda: 0.5
data/__init__.py ADDED
@@ -0,0 +1,53 @@
1
+ #!/usr/bin/python
2
+ # -*- coding:utf-8 -*-
3
+ from .dataset_wrapper import MixDatasetWrapper
4
+ from .codesign import CoDesignDataset
5
+ from .resample import ClusterResampler
6
+
7
+
8
+ import torch
9
+ from torch.utils.data import DataLoader
10
+
11
+ import utils.register as R
12
+ from utils.logger import print_log
13
+
14
+ def create_dataset(config: dict):
15
+ splits = []
16
+ for split_name in ['train', 'valid', 'test']:
17
+ split_config = config.get(split_name, None)
18
+ if split_config is None:
19
+ splits.append(None)
20
+ continue
21
+ if isinstance(split_config, list):
22
+ dataset = MixDatasetWrapper(
23
+ *[R.construct(cfg) for cfg in split_config]
24
+ )
25
+ else:
26
+ dataset = R.construct(split_config)
27
+ splits.append(dataset)
28
+ return splits # train/valid/test
29
+
30
+
31
+ def create_dataloader(dataset, config: dict, n_gpu: int=1, validation: bool=False):
32
+ if 'wrapper' in config:
33
+ dataset = R.construct(config['wrapper'], dataset=dataset)
34
+ batch_size = config.get('batch_size', n_gpu) # default 1 on each gpu
35
+ if validation:
36
+ batch_size = config.get('val_batch_size', batch_size)
37
+ shuffle = config.get('shuffle', False)
38
+ num_workers = config.get('num_workers', 4)
39
+ collate_fn = dataset.collate_fn if hasattr(dataset, 'collate_fn') else None
40
+ if n_gpu > 1:
41
+ sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
42
+ batch_size = int(batch_size / n_gpu)
43
+ print_log(f'Batch size on a single GPU: {batch_size}')
44
+ else:
45
+ sampler = None
46
+ return DataLoader(
47
+ dataset=dataset,
48
+ batch_size=batch_size,
49
+ num_workers=num_workers,
50
+ shuffle=(shuffle and sampler is None),
51
+ collate_fn=collate_fn,
52
+ sampler=sampler
53
+ )
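A hedged usage sketch of `create_dataset` and `create_dataloader` above; the config dict mirrors the YAML files in this commit, and the dataset path is illustrative:

```python
from data import create_dataset, create_dataloader

dataset_config = {
    'train': {
        'class': 'CoDesignDataset',
        'mmap_dir': './datasets/train_valid/processed',  # illustrative path
        'backbone_only': False,
    },
}
train_set, valid_set, test_set = create_dataset(dataset_config)  # valid/test are None here
loader = create_dataloader(train_set, {'shuffle': True, 'batch_size': 8})
batch = next(iter(loader))
```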
data/codesign.py ADDED
@@ -0,0 +1,208 @@
1
+
2
+ import os
3
+ from typing import Optional, Any
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch.nn.utils.rnn import pad_sequence
8
+
9
+ from utils import register as R
10
+ from utils.const import sidechain_atoms
11
+
12
+ from data.converter.list_blocks_to_pdb import list_blocks_to_pdb
13
+
14
+ from .format import VOCAB, Block, Atom
15
+ from .mmap_dataset import MMAPDataset
16
+ from .resample import ClusterResampler
17
+
18
+
19
+
20
+ def calculate_covariance_matrix(point_cloud):
21
+ # Calculate the covariance matrix of the point cloud
22
+ covariance_matrix = np.cov(point_cloud, rowvar=False)
23
+ return covariance_matrix
24
+
25
+
26
+ @R.register('CoDesignDataset')
27
+ class CoDesignDataset(MMAPDataset):
28
+
29
+ MAX_N_ATOM = 14
30
+
31
+ def __init__(
32
+ self,
33
+ mmap_dir: str,
34
+ backbone_only: bool, # only backbone (N, CA, C, O) or full-atom
35
+ specify_data: Optional[str] = None,
36
+ specify_index: Optional[str] = None,
37
+ padding_collate: bool = False,
38
+ cluster: Optional[str] = None,
39
+ use_covariance_matrix: bool = False
40
+ ) -> None:
41
+ super().__init__(mmap_dir, specify_data, specify_index)
42
+ self.mmap_dir = mmap_dir
43
+ self.backbone_only = backbone_only
44
+ self._lengths = [len(prop[-1].split(',')) + int(prop[1]) for prop in self._properties]
45
+ self.padding_collate = padding_collate
46
+ self.resampler = ClusterResampler(cluster) if cluster else None # should only be used in training!
47
+ self.use_covariance_matrix = use_covariance_matrix
48
+
49
+ self.dynamic_idxs = [i for i in range(len(self))]
50
+ self.update_epoch() # should be called every epoch
51
+
52
+ def update_epoch(self):
53
+ if self.resampler is not None:
54
+ self.dynamic_idxs = self.resampler(len(self))
55
+
56
+ def get_len(self, idx):
57
+ return self._lengths[self.dynamic_idxs[idx]]
58
+
59
+ def get_summary(self, idx: int):
60
+ props = self._properties[idx]
61
+ _id = self._indexes[idx][0].split('.')[0]
62
+ ref_pdb = os.path.join(self.mmap_dir, '..', 'pdbs', _id + '.pdb')
63
+ rec_chain, lig_chain = props[4], props[5]
64
+ return _id, ref_pdb, rec_chain, lig_chain
65
+
66
+ def __getitem__(self, idx: int):
67
+ idx = self.dynamic_idxs[idx]
68
+ rec_blocks, lig_blocks = super().__getitem__(idx)
69
+ # receptor, (lig_chain_id, lig_blocks) = super().__getitem__(idx)
70
+ # pocket = {}
71
+ # for i in self._properties[idx][-1].split(','):
72
+ # chain, i = i.split(':')
73
+ # if chain not in pocket:
74
+ # pocket[chain] = []
75
+ # pocket[chain].append(int(i))
76
+ # rec_blocks = []
77
+ # for chain_id, blocks in receptor:
78
+ # for i in pocket[chain_id]:
79
+ # rec_blocks.append(blocks[i])
80
+ pocket_idx = [int(i) for i in self._properties[idx][-1].split(',')]
81
+ rec_position_ids = [i + 1 for i, _ in enumerate(rec_blocks)]
82
+ rec_blocks = [rec_blocks[i] for i in pocket_idx]
83
+ rec_position_ids = [rec_position_ids[i] for i in pocket_idx]
84
+ rec_blocks = [Block.from_tuple(tup) for tup in rec_blocks]
85
+ lig_blocks = [Block.from_tuple(tup) for tup in lig_blocks]
86
+
87
+ # for block in lig_blocks:
88
+ # block.units = [Atom('CA', [0, 0, 0], 'C')]
89
+ # if idx == 0:
90
+ # print(self._properties[idx])
91
+ # print(''.join(VOCAB.abrv_to_symbol(block.abrv) for block in lig_blocks))
92
+ # list_blocks_to_pdb([
93
+ # rec_blocks, lig_blocks
94
+ # ], ['B', 'A'], 'pocket.pdb')
95
+
96
+ mask = [0 for _ in rec_blocks] + [1 for _ in lig_blocks]
97
+ position_ids = rec_position_ids + [i + 1 for i, _ in enumerate(lig_blocks)]
98
+ X, S, atom_mask = [], [], []
99
+ for block in rec_blocks + lig_blocks:
100
+ symbol = VOCAB.abrv_to_symbol(block.abrv)
101
+ atom2coord = { unit.name: unit.get_coord() for unit in block.units }
102
+ bb_pos = np.mean(list(atom2coord.values()), axis=0).tolist()
103
+ coords, coord_mask = [], []
104
+ for atom_name in VOCAB.backbone_atoms + sidechain_atoms.get(symbol, []):
105
+ if atom_name in atom2coord:
106
+ coords.append(atom2coord[atom_name])
107
+ coord_mask.append(1)
108
+ else:
109
+ coords.append(bb_pos)
110
+ coord_mask.append(0)
111
+ n_pad = self.MAX_N_ATOM - len(coords)
112
+ for _ in range(n_pad):
113
+ coords.append(bb_pos)
114
+ coord_mask.append(0)
115
+
116
+ X.append(coords)
117
+ S.append(VOCAB.symbol_to_idx(symbol))
118
+ atom_mask.append(coord_mask)
119
+
120
+ X, atom_mask = torch.tensor(X, dtype=torch.float), torch.tensor(atom_mask, dtype=torch.bool)
121
+ mask = torch.tensor(mask, dtype=torch.bool)
122
+ if self.backbone_only:
123
+ X, atom_mask = X[:, :4], atom_mask[:, :4]
124
+
125
+ if self.use_covariance_matrix:
126
+ cov = calculate_covariance_matrix(X[~mask][:, 1][atom_mask[~mask][:, 1]].numpy()) # only use the receptor to derive the affine transformation
127
+ eps = 1e-4
128
+ cov = cov + eps * np.identity(cov.shape[0])
129
+ L = torch.from_numpy(np.linalg.cholesky(cov)).float().unsqueeze(0)
130
+ else:
131
+ L = None
132
+
133
+ item = {
134
+ 'X': X, # [N, 14] or [N, 4] if backbone_only == True
135
+ 'S': torch.tensor(S, dtype=torch.long), # [N]
136
+ 'position_ids': torch.tensor(position_ids, dtype=torch.long), # [N]
137
+ 'mask': mask, # [N], 1 for generation
138
+ 'atom_mask': atom_mask, # [N, 14] or [N, 4], 1 for having records in the PDB
139
+ 'lengths': len(S),
140
+ }
141
+ if L is not None:
142
+ item['L'] = L
143
+ return item
144
+
145
+ def collate_fn(self, batch):
146
+ if self.padding_collate:
147
+ results = {}
148
+ pad_idx = VOCAB.symbol_to_idx(VOCAB.PAD)
149
+ for key in batch[0]:
150
+ values = [item[key] for item in batch]
151
+ if values[0] is None:
152
+ results[key] = None
153
+ continue
154
+ if key == 'lengths':
155
+ results[key] = torch.tensor(values, dtype=torch.long)
156
+ elif key == 'S':
157
+ results[key] = pad_sequence(values, batch_first=True, padding_value=pad_idx)
158
+ else:
159
+ results[key] = pad_sequence(values, batch_first=True, padding_value=0)
160
+ return results
161
+ else:
162
+ results = {}
163
+ for key in batch[0]:
164
+ values = [item[key] for item in batch]
165
+ if values[0] is None:
166
+ results[key] = None
167
+ continue
168
+ if key == 'lengths':
169
+ results[key] = torch.tensor(values, dtype=torch.long)
170
+ else:
171
+ results[key] = torch.cat(values, dim=0)
172
+ return results
173
+
174
+
175
+ @R.register('ShapeDataset')
176
+ class ShapeDataset(CoDesignDataset):
177
+ def __init__(
178
+ self,
179
+ mmap_dir: str,
180
+ specify_data: Optional[str] = None,
181
+ specify_index: Optional[str] = None,
182
+ padding_collate: bool = False,
183
+ cluster: Optional[str] = None
184
+ ) -> None:
185
+ super().__init__(mmap_dir, False, specify_data, specify_index, padding_collate, cluster)
186
+ self.ca_idx = VOCAB.backbone_atoms.index('CA')
187
+
188
+ def __getitem__(self, idx: int):
189
+ item = super().__getitem__(idx)
190
+
191
+ # refine coordinates to CA and the atom furthest from CA
192
+ X = item['X'] # [N, 14, 3]
193
+ atom_mask = item['atom_mask']
194
+ ca_x = X[:, self.ca_idx].unsqueeze(1) # [N, 1, 3]
195
+ sc_x = X[:, 4:] # [N, 10, 3], sidechain atom indexes
196
+ dist = torch.norm(sc_x - ca_x, dim=-1) # [N, 10]
197
+ dist = dist.masked_fill(~atom_mask[:, 4:], 1e10)
198
+ furthest_atom_x = sc_x[torch.arange(sc_x.shape[0]), torch.argmax(dist, dim=-1)] # [N, 3]
199
+ X = torch.cat([ca_x, furthest_atom_x.unsqueeze(1)], dim=1)
200
+
201
+ item['X'] = X
202
+ return item
203
+
204
+
205
+ if __name__ == '__main__':
206
+ import sys
207
+ dataset = CoDesignDataset(sys.argv[1], backbone_only=True)
208
+ print(dataset[0])
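The covariance/Cholesky branch of `__getitem__` can be illustrated in isolation: the receptor CA coordinates define a covariance matrix whose regularized Cholesky factor `L` lets coordinates be mapped into an approximately whitened "standard space". A standalone sketch on a random anisotropic point cloud:

```python
# Standalone illustration of the whitening trick used in CoDesignDataset.__getitem__.
import numpy as np

point_cloud = np.random.randn(200, 3) * np.array([5.0, 2.0, 1.0])  # anisotropic cloud
cov = np.cov(point_cloud, rowvar=False)          # as in calculate_covariance_matrix
cov = cov + 1e-4 * np.identity(3)                # same eps regularization as above
L = np.linalg.cholesky(cov)                      # cov ~= L @ L.T

# z = L^{-1} (x - mu): whitened coordinates have roughly unit covariance
z = (point_cloud - point_cloud.mean(axis=0)) @ np.linalg.inv(L).T
print(np.cov(z, rowvar=False).round(2))          # close to the identity matrix
```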
data/converter/blocks_interface.py ADDED
@@ -0,0 +1,89 @@
1
+ #!/usr/bin/python
2
+ # -*- coding:utf-8 -*-
3
+ import numpy as np
4
+
5
+
6
+ def blocks_to_coords(blocks):
7
+ max_n_unit = 0
8
+ coords, masks = [], []
9
+ for block in blocks:
10
+ coords.append([unit.get_coord() for unit in block.units])
11
+ max_n_unit = max(max_n_unit, len(coords[-1]))
12
+ masks.append([1 for _ in coords[-1]])
13
+
14
+ for i in range(len(coords)):
15
+ num_pad = max_n_unit - len(coords[i])
16
+ coords[i] = coords[i] + [[0, 0, 0] for _ in range(num_pad)]
17
+ masks[i] = masks[i] + [0 for _ in range(num_pad)]
18
+
19
+ return np.array(coords), np.array(masks).astype('bool') # [N, M, 3], [N, M], M == max_n_unit, in mask 0 is for padding
20
+
21
+
22
+ def dist_matrix_from_coords(coords1, masks1, coords2, masks2):
23
+ dist = np.linalg.norm(coords1[:, None] - coords2[None, :], axis=-1) # [N1, N2, M]
24
+ dist = dist + np.logical_not(masks1[:, None] * masks2[None, :]) * 1e6 # [N1, N2, M]
25
+ dist = np.min(dist, axis=-1) # [N1, N2]
26
+ return dist
27
+
28
+
29
+ def dist_matrix_from_blocks(blocks1, blocks2):
30
+ blocks_coord, blocks_mask = blocks_to_coords(blocks1 + blocks2)
31
+ blocks1_coord, blocks1_mask = blocks_coord[:len(blocks1)], blocks_mask[:len(blocks1)]
32
+ blocks2_coord, blocks2_mask = blocks_coord[len(blocks1):], blocks_mask[len(blocks1):]
33
+ dist = dist_matrix_from_coords(blocks1_coord, blocks1_mask, blocks2_coord, blocks2_mask)
34
+ return dist
35
+
36
+
37
+ def blocks_interface(blocks1, blocks2, dist_th):
38
+ dist = dist_matrix_from_blocks(blocks1, blocks2)
39
+
40
+ on_interface = dist < dist_th
41
+ indexes1 = np.nonzero(on_interface.sum(axis=1) > 0)[0]
42
+ indexes2 = np.nonzero(on_interface.sum(axis=0) > 0)[0]
43
+
44
+ blocks1 = [blocks1[i] for i in indexes1]
45
+ blocks2 = [blocks2[i] for i in indexes2]
46
+
47
+ return (blocks1, blocks2), (indexes1, indexes2)
48
+
49
+
50
+ def add_cb(input_array):
51
+ # from ProteinMPNN
52
+ #The virtual Cβ coordinates were calculated using ideal angle and bond length definitions: b = Cα - N, c = C - Cα, a = cross(b, c), Cβ = -0.58273431*a + 0.56802827*b - 0.54067466*c + Cα.
53
+ N,CA,C,O = input_array
54
+ b = CA - N
55
+ c = C - CA
56
+ a = np.cross(b,c)
57
+ CB = np.around(-0.58273431*a + 0.56802827*b - 0.54067466*c + CA,3)
58
+ return CB #np.array([N,CA,C,CB,O])
59
+
60
+
61
+ def blocks_to_cb_coords(blocks):
62
+ cb_coords = []
63
+ for block in blocks:
64
+ try:
65
+ cb_coords.append(block.get_unit_by_name('CB').get_coord())
66
+ except KeyError:
67
+ tmp_coord = np.array([
68
+ block.get_unit_by_name('N').get_coord(),
69
+ block.get_unit_by_name('CA').get_coord(),
70
+ block.get_unit_by_name('C').get_coord(),
71
+ block.get_unit_by_name('O').get_coord()
72
+ ])
73
+ cb_coords.append(add_cb(tmp_coord))
74
+ return np.array(cb_coords)
75
+
76
+
77
+ def blocks_cb_interface(blocks1, blocks2, dist_th=8.0):
78
+ cb_coords1 = blocks_to_cb_coords(blocks1)
79
+ cb_coords2 = blocks_to_cb_coords(blocks2)
80
+ dist = np.linalg.norm(cb_coords1[:, None] - cb_coords2[None, :], axis=-1) # [N1, N2]
81
+
82
+ on_interface = dist < dist_th
83
+ indexes1 = np.nonzero(on_interface.sum(axis=1) > 0)[0]
84
+ indexes2 = np.nonzero(on_interface.sum(axis=0) > 0)[0]
85
+
86
+ blocks1 = [blocks1[i] for i in indexes1]
87
+ blocks2 = [blocks2[i] for i in indexes2]
88
+
89
+ return (blocks1, blocks2), (indexes1, indexes2)
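The virtual-Cβ formula in `add_cb` can be checked numerically: on a roughly idealized backbone fragment, the reconstructed CB should land about 1.5 Å from CA (the coordinates below are approximate and for illustration only):

```python
import numpy as np
from data.converter.blocks_interface import add_cb

# Rough idealized backbone geometry (angstroms); O is unused by the formula.
N  = np.array([1.458, 0.000, 0.000])
CA = np.array([0.000, 0.000, 0.000])
C  = np.array([-0.551, 1.420, 0.000])
O  = np.array([-1.692, 1.541, 0.451])

CB = add_cb(np.array([N, CA, C, O]))
print(np.linalg.norm(CB - CA))  # ~1.5, a typical CA-CB bond length
```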
data/converter/blocks_to_data.py ADDED
@@ -0,0 +1,110 @@
1
+ #!/usr/bin/python
2
+ # -*- coding:utf-8 -*-
3
+ from typing import List
4
+
5
+ import numpy as np
6
+
7
+ from data.format import VOCAB, Block
8
+ from utils import const
9
+
10
+
11
+ def blocks_to_data(*blocks_list: List[List[Block]]):
12
+ B, A, X, atom_positions, block_lengths, segment_ids = [], [], [], [], [], []
13
+ atom_mask, is_ca = [], []
14
+ topo_edge_index, topo_edge_attr, atom_names = [], [], []
15
+ last_c_node_id = None
16
+ for i, blocks in enumerate(blocks_list):
17
+ if len(blocks) == 0:
18
+ continue
19
+ cur_B, cur_A, cur_X, cur_atom_positions, cur_block_lengths = [], [], [], [], []
20
+ cur_atom_mask, cur_is_ca = [], []
21
+ # other nodes
22
+ for block in blocks:
23
+ b, symbol = VOCAB.abrv_to_idx(block.abrv), VOCAB.abrv_to_symbol(block.abrv)
24
+ x, a, positions, m, ca = [], [], [], [], []
25
+ atom2node_id = {}
26
+ if symbol == '?':
27
+ atom_missing = {}
28
+ else:
29
+ atom_missing = { atom_name: True for atom_name in const.backbone_atoms + const.sidechain_atoms[symbol] }
30
+ for atom in block:
31
+ atom2node_id[atom.name] = len(A) + len(cur_A) + len(a)
32
+ a.append(VOCAB.atom_to_idx(atom.get_element()))
33
+ x.append(atom.get_coord())
34
+ pos_code = ''.join((c for c in atom.get_pos_code() if not c.isdigit()))
35
+ positions.append(VOCAB.atom_pos_to_idx(pos_code))
36
+ if atom.name in atom_missing:
37
+ atom_missing[atom.name] = False
38
+ m.append(1)
39
+ ca.append(atom.name == 'CA')
40
+ atom_names.append(atom.name)
41
+ for atom_name in atom_missing:
42
+ if atom_missing[atom_name]:
43
+ atom2node_id[atom_name] = len(A) + len(cur_A) + len(a)
44
+ a.append(VOCAB.atom_to_idx(atom_name[0])) # only C, N, O, S in proteins
45
+ x.append([0, 0, 0])
46
+ pos_code = ''.join((c for c in atom_name[1:] if not c.isdigit()))
47
+ positions.append(VOCAB.atom_pos_to_idx(pos_code))
48
+ m.append(0)
49
+ ca.append(atom_name == 'CA')
50
+ atom_names.append(atom_name)
51
+ block_len = len(a)
52
+ cur_B.append(b)
53
+ cur_A.extend(a)
54
+ cur_X.extend(x)
55
+ cur_atom_positions.extend(positions)
56
+ cur_block_lengths.append(block_len)
57
+ cur_atom_mask.extend(m)
58
+ cur_is_ca.extend(ca)
59
+
60
+ # topology edges
61
+ for src, dst, bond_type in const.sidechain_bonds.get(VOCAB.abrv_to_symbol(block.abrv), []):
62
+ src, dst = atom2node_id[src], atom2node_id[dst]
63
+ topo_edge_index.append((src, dst)) # no direction
64
+ topo_edge_index.append((dst, src))
65
+ topo_edge_attr.append(bond_type)
66
+ topo_edge_attr.append(bond_type)
67
+ if last_c_node_id is not None and ('CA' in atom2node_id):
68
+ src, dst = last_c_node_id, atom2node_id['N']
69
+ topo_edge_index.append((src, dst)) # no direction
70
+ topo_edge_index.append((dst, src))
71
+ topo_edge_attr.append(4)
72
+ topo_edge_attr.append(4)
73
+ if 'CA' not in atom2node_id:
74
+ last_c_node_id = None
75
+ else:
76
+ last_c_node_id = atom2node_id['C']
77
+
78
+ # update coordinates of the global node to the center
79
+ # cur_X[0] = np.mean(cur_X[1:], axis=0)
80
+ cur_segment_ids = [i for _ in cur_B]
81
+
82
+ # finish these blocks
83
+ B.extend(cur_B)
84
+ A.extend(cur_A)
85
+ X.extend(cur_X)
86
+ atom_positions.extend(cur_atom_positions)
87
+ block_lengths.extend(cur_block_lengths)
88
+ segment_ids.extend(cur_segment_ids)
89
+ atom_mask.extend(cur_atom_mask)
90
+ is_ca.extend(cur_is_ca)
91
+
92
+ X = np.array(X).tolist()
93
+ topo_edge_index = np.array(topo_edge_index).T.tolist()
94
+ topo_edge_attr = (np.array(topo_edge_attr) - 1).tolist() # type starts from 0 but bond type starts from 1
95
+
96
+ data = {
97
+ 'X': X, # [Natom, 3]
98
+ 'B': B, # [Nb], block (residue) type
99
+ 'A': A, # [Natom]
100
+ 'atom_positions': atom_positions, # [Natom]
101
+ 'block_lengths': block_lengths, # [Nresidue]
102
+ 'segment_ids': segment_ids, # [Nresidue]
103
+ 'atom_mask': atom_mask, # [Natom]
104
+ 'is_ca': is_ca, # [Natom]
105
+ 'atom_names': atom_names, # [Natom]
106
+ 'topo_edge_index': topo_edge_index, # atom level
107
+ 'topo_edge_attr': topo_edge_attr
108
+ }
109
+
110
+ return data
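A short usage sketch for `blocks_to_data`, assuming chains parsed by `pdb_to_list_blocks` below (`complex.pdb` and the chain ids are placeholders):

```python
from data.converter.pdb_to_list_blocks import pdb_to_list_blocks
from data.converter.blocks_to_data import blocks_to_data

rec_blocks, lig_blocks = pdb_to_list_blocks('complex.pdb', selected_chains=['A', 'B'])
data = blocks_to_data(rec_blocks, lig_blocks)
print(len(data['B']), len(data['A']))  # number of residues, number of atoms
```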
data/converter/list_blocks_to_pdb.py ADDED
@@ -0,0 +1,61 @@
1
+ #!/usr/bin/python
2
+ # -*- coding:utf-8 -*-
3
+ import os
4
+ from typing import List
5
+
6
+ import numpy as np
7
+
8
+ from Bio.PDB import PDBParser, PDBIO
9
+ from Bio.PDB.Structure import Structure as BStructure
10
+ from Bio.PDB.Model import Model as BModel
11
+ from Bio.PDB.Chain import Chain as BChain
12
+ from Bio.PDB.Residue import Residue as BResidue
13
+ from Bio.PDB.Atom import Atom as BAtom
14
+
15
+ from data.format import Block, Atom, VOCAB
16
+
17
+
18
+ def list_blocks_to_pdb(list_blocks: List[List[Block]], chain_names: List[str], out_path: str) -> None:
19
+ '''
20
+ Write a list of lists of blocks to a pdb file using Biopython.
21
+ Each list of blocks becomes one chain in the output.
22
+
23
+ Parameters:
24
+ list_blocks: A list of lists of blocks. Each list of blocks will be parsed into one chain in the pdb
25
+ chain_names: name of chains
26
+ out_path: Path to the pdb file
27
+
28
+ '''
29
+ pdb_id = os.path.basename(os.path.splitext(out_path)[0])
30
+ structure = BStructure(id=pdb_id)
31
+ model = BModel(id=0)
32
+ for blocks, chain_name in zip(list_blocks, chain_names):
33
+ chain = BChain(id=chain_name)
34
+ for i, block in enumerate(blocks):
35
+ chain.add(_block_to_biopython(block, i))
36
+ model.add(chain)
37
+ structure.add(model)
38
+ io = PDBIO()
39
+ io.set_structure(structure)
40
+ io.save(out_path)
41
+
42
+
43
+ def _block_to_biopython(block: Block, pos_code: int) -> BResidue:
44
+ _id = (' ', pos_code, ' ')
45
+ residue = BResidue(_id, block.abrv, ' ')
46
+ for i, atom in enumerate(block):
47
+ fullname = ' ' + atom.name
48
+ while len(fullname) < 4:
49
+ fullname += ' '
50
+ bio_atom = BAtom(
51
+ name=atom.name,
52
+ coord=np.array(atom.coordinate, dtype=np.float32),
53
+ bfactor=0,
54
+ occupancy=1.0,
55
+ altloc=' ',
56
+ fullname=fullname,
57
+ serial_number=i,
58
+ element=atom.element
59
+ )
60
+ residue.add(bio_atom)
61
+ return residue
data/converter/pdb_to_list_blocks.py ADDED
@@ -0,0 +1,99 @@
1
+ #!/usr/bin/python
2
+ # -*- coding:utf-8 -*-
3
+ from typing import Dict, List, Optional, Union
4
+
5
+ from Bio.PDB import PDBParser
6
+
7
+ from data.format import Block, Atom
8
+
9
+
10
+ def pdb_to_list_blocks(pdb: str, selected_chains: Optional[List[str]]=None, return_chain_ids: bool=False, dict_form: bool=False) -> Union[List[List[Block]], Dict[str, List[Block]]]:
11
+ '''
12
+ Convert pdb file to a list of lists of blocks using Biopython.
13
+ Each chain will be a list of blocks.
14
+
15
+ Parameters:
16
+ pdb: Path to the pdb file
17
+ selected_chains: List of selected chain ids. The returned list will be ordered
18
+ according to the ordering of chain ids in this parameter. If not specified,
19
+ all chains will be returned. e.g. ['A', 'B']
20
+ return_chain_ids: Whether to return the ids of each chain
21
+ dict_form: Whether to return chains in dict form (chain id as the key and blocks
22
+ as the value)
23
+
24
+ Returns:
25
+ A list of lists of blocks. Each chain in the pdb file will be parsed into
26
+ one list of blocks.
27
+ example:
28
+ [
29
+ [residueA1, residueA2, ...], # chain A
30
+ [residueB1, residueB2, ...] # chain B
31
+ ],
32
+ where each residue is instantiated by Block data class.
33
+ '''
34
+
35
+ parser = PDBParser(QUIET=True)
36
+ structure = parser.get_structure('anonym', pdb)
37
+
38
+ list_blocks, chain_ids, chains = [], {}, []
39
+
40
+ for model in structure.get_models(): # use model 1 only
41
+ structure = model
42
+ break
43
+
44
+ for chain in structure.get_chains():
45
+
46
+ _id = chain.get_id()
47
+ if (selected_chains is not None) and (_id not in selected_chains):
48
+ continue
49
+
50
+ residues, res_ids = [], {}
51
+
52
+ for residue in chain:
53
+ abrv = residue.get_resname()
54
+ hetero_flag, res_number, insert_code = residue.get_id()
55
+ res_id = f'{res_number}-{insert_code}'
56
+ if hetero_flag == 'W':
57
+ continue # water residue (hetero flag 'W', e.g. HOH/WAT)
58
+ if hetero_flag.strip() != '' and res_id in res_ids:
59
+ continue # the solvent (e.g. H_EDO (EDO))
60
+ if abrv in ['EDO', 'HOH', 'BME']: # solvent or other molecules
61
+ continue
62
+ if abrv == 'MSE':
63
+ abrv = 'MET' # MSE (selenomethionine) often substitutes MET in crystal structures; map it back to MET
64
+
65
+ # filter Hs because not all data include them
66
+ atoms = [ Atom(atom.get_id(), atom.get_coord().tolist(), atom.element) for atom in residue if atom.element != 'H' ]
67
+ block = Block(abrv, atoms, id=(res_number, insert_code))
68
+ if block.is_residue():
69
+ residues.append(block)
70
+ res_ids[res_id] = True
71
+
72
+ if len(residues) == 0: # not a chain
73
+ continue
74
+
75
+ chain_ids[_id] = len(list_blocks)
76
+ list_blocks.append(residues)
77
+ chains.append(_id)
78
+
79
+ # reorder
80
+ if selected_chains is not None:
81
+ list_blocks = [list_blocks[chain_ids[chain_id]] for chain_id in selected_chains]
82
+ chains = selected_chains
83
+
84
+ if dict_form:
85
+ return { chain: blocks for chain, blocks in zip(chains, list_blocks)}
86
+
87
+ if return_chain_ids:
88
+ return list_blocks, chains
89
+
90
+ return list_blocks
91
+
92
+
93
+ if __name__ == '__main__':
94
+ import sys
95
+ list_blocks = pdb_to_list_blocks(sys.argv[1])
96
+ print(f'{sys.argv[1]} parsed')
97
+ print(f'number of chains: {len(list_blocks)}')
98
+ for i, chain in enumerate(list_blocks):
99
+ print(f'chain {i} lengths: {len(chain)}')
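Together with `list_blocks_to_pdb` above, this parser supports a simple round trip (paths and chain ids are placeholders):

```python
from data.converter.pdb_to_list_blocks import pdb_to_list_blocks
from data.converter.list_blocks_to_pdb import list_blocks_to_pdb

chains, chain_ids = pdb_to_list_blocks('input.pdb', return_chain_ids=True)
list_blocks_to_pdb(chains, chain_ids, 'roundtrip.pdb')
```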
data/dataset_wrapper.py ADDED
@@ -0,0 +1,115 @@
1
+ from typing import Callable
2
+ from tqdm import tqdm
3
+ from math import log
4
+
5
+ import numpy as np
6
+ import torch
7
+ import sympy
8
+
9
+ from utils import register as R
10
+
11
+
12
+ class MixDatasetWrapper(torch.utils.data.Dataset):
13
+ def __init__(self, *datasets, collate_fn: Callable=None) -> None:
14
+ super().__init__()
15
+ self.datasets = datasets
16
+ self.cum_len = []
17
+ self.total_len = 0
18
+ for dataset in datasets:
19
+ self.total_len += len(dataset)
20
+ self.cum_len.append(self.total_len)
21
+ self.collate_fn = self.datasets[0].collate_fn if collate_fn is None else collate_fn
22
+ if hasattr(datasets[0], '_lengths'):
23
+ self._lengths = []
24
+ for dataset in datasets:
25
+ self._lengths.extend(dataset._lengths)
26
+
27
+ def update_epoch(self):
28
+ for dataset in self.datasets:
29
+ if hasattr(dataset, 'update_epoch'):
30
+ dataset.update_epoch()
31
+
32
+ def get_len(self, idx):
33
+ return self._lengths[idx]
34
+
35
+ def __len__(self):
36
+ return self.total_len
37
+
38
+ def __getitem__(self, idx):
39
+ last_cum_len = 0
40
+ for i, cum_len in enumerate(self.cum_len):
41
+ if idx < cum_len:
42
+ return self.datasets[i].__getitem__(idx - last_cum_len)
43
+ last_cum_len = cum_len
44
+ return None # this is not possible
45
+
46
+
47
+ @R.register('DynamicBatchWrapper')
48
+ class DynamicBatchWrapper(torch.utils.data.Dataset):
49
+ def __init__(self, dataset, complexity, ubound_per_batch) -> None:
50
+ super().__init__()
51
+ self.dataset = dataset
52
+ self.indexes = [i for i in range(len(dataset))]
53
+ self.complexity = complexity
54
+ self.eval_func = sympy.lambdify('n', sympy.simplify(complexity))
55
+ self.ubound_per_batch = ubound_per_batch
56
+ self.total_size = None
57
+ self.batch_indexes = []
58
+ self._form_batch()
59
+
60
+ def __getattr__(self, attr):
61
+ if attr in self.__dict__:
62
+ return self.__dict__[attr]
63
+ elif hasattr(self.dataset, attr):
64
+ return getattr(self.dataset, attr)
65
+ else:
66
+ raise AttributeError(f"'DynamicBatchWrapper'(or '{type(self.dataset)}') object has no attribute '{attr}'")
67
+
68
+ def update_epoch(self):
69
+ if hasattr(self.dataset, 'update_epoch'):
70
+ self.dataset.update_epoch()
71
+ self._form_batch()
72
+
73
+ ########## overload with your criterion ##########
74
+ def _form_batch(self):
75
+
76
+ np.random.shuffle(self.indexes)
77
+ last_batch_indexes = self.batch_indexes
78
+ self.batch_indexes = []
79
+
80
+ cur_complexity = 0
81
+ batch = []
82
+
83
+ for i in tqdm(self.indexes):
84
+ item_len = self.eval_func(self.dataset.get_len(i))
85
+ if item_len > self.ubound_per_batch:
86
+ continue
87
+ cur_complexity += item_len
88
+ if cur_complexity > self.ubound_per_batch:
89
+ self.batch_indexes.append(batch)
90
+ batch = []
91
+ cur_complexity = item_len
92
+ batch.append(i)
93
+ self.batch_indexes.append(batch)
94
+
95
+ if self.total_size is None:
96
+ self.total_size = len(self.batch_indexes)
97
+ else:
98
+ # control the lengths of the dataset, otherwise the dataloader will raise error
99
+ if len(self.batch_indexes) < self.total_size:
100
+ num_add = self.total_size - len(self.batch_indexes)
101
+ self.batch_indexes = self.batch_indexes + last_batch_indexes[:num_add]
102
+ else:
103
+ self.batch_indexes = self.batch_indexes[:self.total_size]
104
+
105
+ def __len__(self):
106
+ return len(self.batch_indexes)
107
+
108
+ def __getitem__(self, idx):
109
+ return [self.dataset[i] for i in self.batch_indexes[idx]]
110
+
111
+ def collate_fn(self, batched_batch):
112
+ batch = []
113
+ for minibatch in batched_batch:
114
+ batch.extend(minibatch)
115
+ return self.dataset.collate_fn(batch)
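A sketch of how `DynamicBatchWrapper` is meant to be consumed; here `dataset` stands for any dataset exposing `get_len` and `collate_fn` (e.g. `CoDesignDataset`):

```python
from torch.utils.data import DataLoader
from data.dataset_wrapper import DynamicBatchWrapper

wrapped = DynamicBatchWrapper(dataset, complexity='n**2', ubound_per_batch=60000)
# each wrapped[i] is already a full variable-size batch, hence batch_size=1 outside
loader = DataLoader(wrapped, batch_size=1, collate_fn=wrapped.collate_fn)
```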
data/format.py ADDED
@@ -0,0 +1,220 @@
1
+ #!/usr/bin/python
2
+ # -*- coding:utf-8 -*-
3
+ from copy import copy
4
+ from typing import List, Tuple, Iterator, Optional
5
+
6
+ from utils import const
7
+
8
+
9
+ class MoleculeVocab:
10
+
11
+ MAX_ATOM_NUMBER = 14
12
+
13
+ def __init__(self):
14
+ self.backbone_atoms = ['N', 'CA', 'C', 'O']
15
+ self.PAD, self.MASK, self.UNK, self.LAT = '#', '*', '?', '&' # pad / mask / unk / latent node
16
+ specials = [# special added
17
+ (self.PAD, 'PAD'), (self.MASK, 'MASK'), (self.UNK, 'UNK'), # pad / mask / unk
18
+ (self.LAT, '<L>') # latent node in latent space
19
+ ]
20
+
21
+ aas = const.aas
22
+
23
+ # sms = [(e.lower(), e) for e in const.periodic_table]
24
+ sms = [] # disable small molecule vocabulary
25
+
26
+ self.atom_pad, self.atom_mask, self.atom_latent = 'pad', 'msk', 'lat' # Avoid conflict with atom P
27
+ self.atom_pos_pad, self.atom_pos_mask, self.atom_pos_latent = 'pad', 'msk', 'lat'
28
+ self.atom_pos_sm = 'sml' # small molecule
29
+
30
+ # block level vocab
31
+ self.idx2block = specials + aas + sms
32
+ self.symbol2idx, self.abrv2idx = {}, {}
33
+ for i, (symbol, abrv) in enumerate(self.idx2block):
34
+ self.symbol2idx[symbol] = i
35
+ self.abrv2idx[abrv] = i
36
+ self.special_mask = [1 for _ in specials] + [0 for _ in aas] + [0 for _ in sms]
37
+
38
+ # atom level vocab
39
+ self.idx2atom = [self.atom_pad, self.atom_mask, self.atom_latent] + const.periodic_table
40
+ self.idx2atom_pos = [self.atom_pos_pad, self.atom_pos_mask, self.atom_pos_latent, '', 'A', 'B', 'G', 'D', 'E', 'Z', 'H', 'XT', 'P', self.atom_pos_sm] # SM is for atoms in small molecule, 'P' for O1P, O2P, O3P
41
+ self.atom2idx, self.atom_pos2idx = {}, {}
42
+ self.atom2idx = {}
43
+ for i, atom in enumerate(self.idx2atom):
44
+ self.atom2idx[atom] = i
45
+ for i, atom_pos in enumerate(self.idx2atom_pos):
46
+ self.atom_pos2idx[atom_pos] = i
47
+
48
+ # block level APIs
49
+
50
+ def abrv_to_symbol(self, abrv):
51
+ idx = self.abrv_to_idx(abrv)
52
+ return None if idx is None else self.idx2block[idx][0]
53
+
54
+ def symbol_to_abrv(self, symbol):
55
+ idx = self.symbol_to_idx(symbol)
56
+ return None if idx is None else self.idx2block[idx][1]
57
+
58
+ def abrv_to_idx(self, abrv):
59
+ abrv = abrv.upper()
60
+ return self.abrv2idx.get(abrv, self.abrv2idx['UNK'])
61
+
62
+ def symbol_to_idx(self, symbol):
63
+ # symbol = symbol.upper()
64
+ return self.symbol2idx.get(symbol, self.abrv2idx['UNK'])
65
+
66
+ def idx_to_symbol(self, idx):
67
+ return self.idx2block[idx][0]
68
+
69
+ def idx_to_abrv(self, idx):
70
+ return self.idx2block[idx][1]
71
+
72
+ def get_pad_idx(self):
73
+ return self.symbol_to_idx(self.PAD)
74
+
75
+ def get_mask_idx(self):
76
+ return self.symbol_to_idx(self.MASK)
77
+
78
+ def get_special_mask(self):
79
+ return copy(self.special_mask)
80
+
81
+ # atom level APIs
82
+
83
+ def get_atom_pad_idx(self):
84
+ return self.atom2idx[self.atom_pad]
85
+
86
+ def get_atom_mask_idx(self):
87
+ return self.atom2idx[self.atom_mask]
88
+
89
+ def get_atom_latent_idx(self):
90
+ return self.atom2idx[self.atom_latent]
91
+
92
+ def get_atom_pos_pad_idx(self):
93
+ return self.atom_pos2idx[self.atom_pos_pad]
94
+
95
+ def get_atom_pos_mask_idx(self):
96
+ return self.atom_pos2idx[self.atom_pos_mask]
97
+
98
+ def get_atom_pos_latent_idx(self):
99
+ return self.atom_pos2idx[self.atom_pos_latent]
100
+
101
+ def idx_to_atom(self, idx):
102
+ return self.idx2atom[idx]
103
+
104
+ def atom_to_idx(self, atom):
105
+ atom = atom.upper()
106
+ return self.atom2idx.get(atom, self.atom2idx[self.atom_mask])
107
+
108
+ def idx_to_atom_pos(self, idx):
109
+ return self.idx2atom_pos[idx]
110
+
111
+ def atom_pos_to_idx(self, atom_pos):
112
+ return self.atom_pos2idx.get(atom_pos, self.atom_pos2idx[self.atom_pos_mask])
113
+
114
+ # sizes
115
+
116
+ def get_num_atom_type(self):
117
+ return len(self.idx2atom)
118
+
119
+ def get_num_atom_pos(self):
120
+ return len(self.idx2atom_pos)
121
+
122
+ def get_num_block_type(self):
123
+ return len(self.special_mask) - sum(self.special_mask)
124
+
125
+ def __len__(self):
126
+ return len(self.symbol2idx)
127
+
128
+ # others
129
+ @property
130
+ def ca_channel_idx(self):
131
+ return self.backbone_atoms.index('CA')
132
+
133
+
134
+ VOCAB = MoleculeVocab()
135
+
136
+
137
+ class Atom:
138
+ def __init__(self, atom_name: str, coordinate: List[float], element: str, pos_code: str=None):
139
+ self.name = atom_name
140
+ self.coordinate = coordinate
141
+ self.element = element
142
+ if pos_code is None:
143
+ pos_code = atom_name.lstrip(element)
144
+ self.pos_code = pos_code
145
+ else:
146
+ self.pos_code = pos_code
147
+
148
+ def get_element(self):
149
+ return self.element
150
+
151
+ def get_coord(self):
152
+ return copy(self.coordinate)
153
+
154
+ def get_pos_code(self):
155
+ return self.pos_code
156
+
157
+ def __str__(self) -> str:
158
+ return self.name
159
+
160
+ def __repr__(self) -> str:
161
+ return f"Atom ({self.name}): {self.element}({self.pos_code}) [{','.join(['{:.4f}'.format(num) for num in self.coordinate])}]"
162
+
163
+ def to_tuple(self):
164
+ return (
165
+ self.name,
166
+ self.coordinate,
167
+ self.element,
168
+ self.pos_code
169
+ )
170
+
171
+ @classmethod
172
+ def from_tuple(cls, data):
173
+ return Atom(
174
+ atom_name=data[0],
175
+ coordinate=data[1],
176
+ element=data[2],
177
+ pos_code=data[3]
178
+ )
179
+
180
+
181
+ class Block:
182
+ def __init__(self, abrv: str, units: List[Atom], id: Optional[any]=None) -> None:
183
+ self.abrv: str = abrv
184
+ self.units: List[Atom] = units
185
+ self._uname2idx = { unit.name: i for i, unit in enumerate(self.units) }
186
+ self.id = id
187
+
188
+ def __len__(self) -> int:
189
+ return len(self.units)
190
+
191
+ def __iter__(self) -> Iterator[Atom]:
192
+ return iter(self.units)
193
+
194
+ def get_unit_by_name(self, name: str) -> Atom:
195
+ idx = self._uname2idx[name]
196
+ return self.units[idx]
197
+
198
+ def has_unit(self, name: str) -> bool:
199
+ return name in self._uname2idx
200
+
201
+ def to_tuple(self):
202
+ return (
203
+ self.abrv,
204
+ [unit.to_tuple() for unit in self.units],
205
+ self.id
206
+ )
207
+
208
+ def is_residue(self):
209
+ return self.has_unit('CA') and self.has_unit('N') and self.has_unit('C') and self.has_unit('O')
210
+
211
+ @classmethod
212
+ def from_tuple(cls, data):
213
+ return Block(
214
+ abrv=data[0],
215
+ units=[Atom.from_tuple(unit_data) for unit_data in data[1]],
216
+ id=data[2]
217
+ )
218
+
219
+ def __repr__(self) -> str:
220
+ return f"Block ({self.abrv}):\n\t" + '\n\t'.join([repr(at) for at in self.units]) + '\n'
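A quick tour of the vocabulary and data classes above (this assumes `utils.const.aas` holds the standard one-letter/three-letter amino-acid pairs):

```python
from data.format import VOCAB, Atom, Block

idx = VOCAB.abrv_to_idx('GLY')
print(VOCAB.idx_to_symbol(idx))   # 'G'
print(VOCAB.symbol_to_abrv('G'))  # 'GLY'
print(len(VOCAB), VOCAB.get_num_atom_type())

# Blocks serialize losslessly through tuples; the mmap storage relies on this.
block = Block('GLY', [Atom('CA', [0.0, 0.0, 0.0], 'C')])
assert Block.from_tuple(block.to_tuple()).abrv == 'GLY'
```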
data/mmap_dataset.py ADDED
@@ -0,0 +1,112 @@
1
+ #!/usr/bin/python
2
+ # -*- coding:utf-8 -*-
3
+ import os
4
+ import io
5
+ import gzip
6
+ import json
7
+ import mmap
8
+ from typing import Optional
9
+ from tqdm import tqdm
10
+
11
+ import torch
12
+
13
+
14
+ def compress(x):
15
+ serialized_x = json.dumps(x).encode()
16
+ buf = io.BytesIO()
17
+ with gzip.GzipFile(fileobj=buf, mode='wb', compresslevel=6) as f:
18
+ f.write(serialized_x)
19
+ compressed = buf.getvalue()
20
+ return compressed
21
+
22
+
23
+ def decompress(compressed_x):
24
+ buf = io.BytesIO(compressed_x)
25
+ with gzip.GzipFile(fileobj=buf, mode="rb") as f:
26
+ serialized_x = f.read().decode()
27
+ x = json.loads(serialized_x)
28
+ return x
29
+
30
+
31
+ def _find_measure_unit(num_bytes):
32
+ size, measure_unit = num_bytes, 'Bytes'
33
+ for unit in ['KB', 'MB', 'GB']:
34
+ if size > 1024:
35
+ size /= 1024
36
+ measure_unit = unit
37
+ else:
38
+ break
39
+ return size, measure_unit
40
+
41
+
42
+ def create_mmap(iterator, out_dir, total_len=None, commit_batch=10000):
43
+
44
+ if not os.path.exists(out_dir):
45
+ os.makedirs(out_dir)
46
+
47
+ data_file_path = os.path.join(out_dir, 'data.bin')
48
+ data_file = open(data_file_path, 'wb')
49
+ index_file = open(os.path.join(out_dir, 'index.txt'), 'w')
50
+
51
+ i, offset, n_finished = 0, 0, 0
52
+ progress_bar = tqdm(iterator, total=total_len)
53
+ for _id, x, properties, entry_idx in iterator:
54
+ progress_bar.set_description(f'Processing {_id}')
55
+ compressed_x = compress(x)
56
+ bin_length = data_file.write(compressed_x)
57
+ properties = '\t'.join([str(prop) for prop in properties])
58
+ index_file.write(f'{_id}\t{offset}\t{offset + bin_length}\t{properties}\n') # tuple of (_id, start, end), data slice is [start, end)
59
+ offset += bin_length
60
+ i += 1
61
+
62
+ if entry_idx > n_finished:
63
+ progress_bar.update(entry_idx - n_finished)
64
+ n_finished = entry_idx
65
+ if total_len is not None:
66
+ expected_size = os.fstat(data_file.fileno()).st_size / n_finished * total_len
67
+ expected_size, measure_unit = _find_measure_unit(expected_size)
68
+ progress_bar.set_postfix({f'{i} saved; Estimated total size ({measure_unit})': expected_size})
69
+
70
+ if i % commit_batch == 0:
71
+ data_file.flush() # save from memory to disk
72
+ index_file.flush()
73
+
74
+
75
+ data_file.close()
76
+ index_file.close()
77
+
78
+
79
+ class MMAPDataset(torch.utils.data.Dataset):
80
+
81
+ def __init__(self, mmap_dir: str, specify_data: Optional[str]=None, specify_index: Optional[str]=None) -> None:
82
+ super().__init__()
83
+
84
+ self._indexes = []
85
+ self._properties = []
86
+ _index_path = os.path.join(mmap_dir, 'index.txt') if specify_index is None else specify_index
87
+ with open(_index_path, 'r') as f:
88
+ for line in f.readlines():
89
+ messages = line.strip().split('\t')
90
+ _id, start, end = messages[:3]
91
+ _property = messages[3:]
92
+ self._indexes.append((_id, int(start), int(end)))
93
+ self._properties.append(_property)
94
+ _data_path = os.path.join(mmap_dir, 'data.bin') if specify_data is None else specify_data
95
+ self._data_file = open(_data_path, 'rb')
96
+ self._mmap = mmap.mmap(self._data_file.fileno(), 0, access=mmap.ACCESS_READ)
97
+
98
+ def __del__(self):
99
+ self._mmap.close()
100
+ self._data_file.close()
101
+
102
+ def __len__(self):
103
+ return len(self._indexes)
104
+
105
+ def __getitem__(self, idx: int):
106
+ if idx < 0 or idx >= len(self):
107
+ raise IndexError(idx)
108
+
109
+ _, start, end = self._indexes[idx]
110
+ data = decompress(self._mmap[start:end])
111
+
112
+ return data
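A toy end-to-end example of the mmap helpers above; the iterator must yield `(_id, payload, properties, entry_idx)` with a JSON-serializable payload, matching `compress`/`decompress`:

```python
from data.mmap_dataset import create_mmap, MMAPDataset

def toy_iterator():
    for i in range(3):
        yield f'item{i}', {'value': i}, [i, i * 2], i + 1

create_mmap(toy_iterator(), './toy_mmap', total_len=3)

dataset = MMAPDataset('./toy_mmap')
print(len(dataset), dataset[0])  # 3 {'value': 0}
```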
data/resample.py ADDED
@@ -0,0 +1,19 @@
1
+ #!/usr/bin/python
2
+ # -*- coding:utf-8 -*-
3
+ import numpy as np
4
+
5
+
6
+ class ClusterResampler:
7
+ def __init__(self, cluster_path: str) -> None:
8
+ idx2prob = []
9
+ with open(cluster_path, 'r') as fin:
10
+ for line in fin:
11
+ cluster_n_member = int(line.strip().split('\t')[-1])
12
+ idx2prob.append(1 / cluster_n_member)
13
+ total = sum(idx2prob)
14
+ idx2prob = [p / total for p in idx2prob]
15
+ self.idx2prob = np.array(idx2prob)
16
+
17
+ def __call__(self, n_sample:int, replace: bool=False):
18
+ idxs = np.random.choice(len(self.idx2prob), size=n_sample, replace=replace, p=self.idx2prob)
19
+ return idxs
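Sampling probability is proportional to the inverse cluster size, so members of large clusters are down-weighted. A toy example (the tab-separated layout with the member count in the last column is an assumption based on the parsing above):

```python
from data.resample import ClusterResampler

with open('toy.cluster', 'w') as f:
    f.write('entry0\tclusterA\t4\n')  # cluster with 4 members
    f.write('entry1\tclusterB\t1\n')  # singleton cluster

sampler = ClusterResampler('toy.cluster')
print(sampler.idx2prob)          # [0.2, 0.8]
print(sampler(2, replace=True))
```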
env.yaml ADDED
@@ -0,0 +1,32 @@
1
+ name: PepGLAD
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - bioconda
6
+ - pyg
7
+ - salilab
8
+ - conda-forge
9
+ - defaults
10
+ dependencies:
11
+ - python=3.9
12
+ - pytorch::pytorch=1.13.1
13
+ - pytorch::pytorch-cuda=11.7
14
+ - nvidia::cudatoolkit=11.7.0
15
+ - pyg::pytorch-scatter
16
+ - mkl=2024.0.0
17
+ - salilab::dssp
18
+ - anaconda::libboost=1.73.0
19
+ - mmseqs2
20
+ - openmm=8.0.0
21
+ - pdbfixer
22
+ - pip
23
+ - pip:
24
+ - biopython==1.80
25
+ - rdkit-pypi==2022.3.5
26
+ - ray
27
+ - sympy
28
+ - scipy
29
+ - freesasa
30
+ - tensorboard
31
+ - pyyaml
32
+ - tqdm
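The environment is created in the usual way, e.g. `conda env create -f env.yaml` followed by `conda activate PepGLAD`.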
evaluation/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ #!/usr/bin/python
2
+ # -*- coding:utf-8 -*-
3
+
evaluation/dG/RosettaFastRelaxUtil.xml ADDED
@@ -0,0 +1,190 @@
1
+ <ROSETTASCRIPTS>
2
+ <SCOREFXNS>
3
+ <ScoreFunction name="sfxn_soft" weights="beta_nov16_soft" />
4
+ <ScoreFunction name="sfxn" weights="beta_nov16" />
5
+ <ScoreFunction name="sfxn_relax" weights="beta_nov16" >
6
+ <Reweight scoretype="arg_cation_pi" weight="3" />
7
+ <Reweight scoretype="approximate_buried_unsat_penalty" weight="5" />
8
+ <Set approximate_buried_unsat_penalty_burial_atomic_depth="3.5" />
9
+ <Set approximate_buried_unsat_penalty_hbond_energy_threshold="-0.5" />
10
+ </ScoreFunction>
11
+ <ScoreFunction name="sfxn_softish" weights="beta_nov16" >
12
+ <Reweight scoretype="fa_rep" weight="0.15" />
13
+ </ScoreFunction>
14
+ <ScoreFunction name="sfxn_fa_atr" weights="empty" >
15
+ <Reweight scoretype="fa_atr" weight="1" />
16
+ </ScoreFunction>
17
+ <ScoreFunction name="vdw_sol" weights="empty" >
18
+ <Reweight scoretype="fa_atr" weight="1.0" />
19
+ <Reweight scoretype="fa_rep" weight="0.55" />
20
+ <Reweight scoretype="fa_sol" weight="1.0" />
21
+ </ScoreFunction>
22
+
23
+ </SCOREFXNS>
24
+ <RESIDUE_SELECTORS>
25
+ <Chain name="chainA" chains="A"/>
26
+ <Chain name="chainB" chains="B"/>
27
+ <Neighborhood name="interface_chA" selector="chainB" distance="8.0" />
28
+ <Neighborhood name="interface_chB" selector="chainA" distance="8.0" />
29
+ <And name="AB_interface" selectors="interface_chA,interface_chB" />
30
+ <Not name="Not_interface" selector="AB_interface" />
31
+ <And name="actual_interface_chB" selectors="AB_interface,chainB" />
32
+ <And name="not_interface_chB" selectors="Not_interface,chainB" />
33
+
34
+ <ResidueName name="apolar" residue_name3="ALA,CYS,PHE,ILE,LEU,MET,THR,PRO,VAL,TRP,TYR" />
35
+ <Not name="polar" selector="apolar" />
36
+
37
+ <True name="all" />
38
+
39
+ <ResidueName name="pro_and_gly_positions" residue_name3="PRO,GLY" />
40
+
41
+ <ResiduePDBInfoHasLabel name="HOTSPOT_res" property="HOTSPOT" />
42
+ </RESIDUE_SELECTORS>
43
+
44
+
45
+ <RESIDUE_SELECTORS>
46
+ <!-- Layer Design -->
47
+ <Layer name="surface" select_core="false" select_boundary="false" select_surface="true" use_sidechain_neighbors="true"/>
48
+ <Layer name="boundary" select_core="false" select_boundary="true" select_surface="false" use_sidechain_neighbors="true"/>
49
+ <Layer name="core" select_core="true" select_boundary="false" select_surface="false" use_sidechain_neighbors="true"/>
50
+ <SecondaryStructure name="sheet" overlap="0" minH="3" minE="2" include_terminal_loops="false" use_dssp="true" ss="E"/>
51
+ <SecondaryStructure name="entire_loop" overlap="0" minH="3" minE="2" include_terminal_loops="true" use_dssp="true" ss="L"/>
52
+ <SecondaryStructure name="entire_helix" overlap="0" minH="3" minE="2" include_terminal_loops="false" use_dssp="true" ss="H"/>
53
+ <And name="helix_cap" selectors="entire_loop">
54
+ <PrimarySequenceNeighborhood lower="1" upper="0" selector="entire_helix"/>
55
+ </And>
56
+ <And name="helix_start" selectors="entire_helix">
57
+ <PrimarySequenceNeighborhood lower="0" upper="1" selector="helix_cap"/>
58
+ </And>
59
+ <And name="helix" selectors="entire_helix">
60
+ <Not selector="helix_start"/>
61
+ </And>
62
+ <And name="loop" selectors="entire_loop">
63
+ <Not selector="helix_cap"/>
64
+ </And>
65
+
66
+ </RESIDUE_SELECTORS>
67
+
68
+ <TASKOPERATIONS>
69
+ <DesignRestrictions name="layer_design_no_core_polars">
70
+ <Action selector_logic="surface AND helix_start" aas="DEHKPQR"/>
71
+ <Action selector_logic="surface AND helix" aas="EHKQR"/>
72
+ <Action selector_logic="surface AND sheet" aas="EHKNQRST"/>
73
+ <Action selector_logic="surface AND loop" aas="DEGHKNPQRST"/>
74
+ <Action selector_logic="boundary AND helix_start" aas="ADEHIKLNPQRSTVWY"/>
75
+ <Action selector_logic="boundary AND helix" aas="ADEHIKLNQRSTVWY"/>
76
+ <Action selector_logic="boundary AND sheet" aas="DEFHIKLNQRSTVWY"/>
77
+ <Action selector_logic="boundary AND loop" aas="ADEFGHIKLNPQRSTVWY"/>
78
+ <Action selector_logic="core AND helix_start" aas="AFILMPVWY"/>
79
+ <Action selector_logic="core AND helix" aas="AFILVWY"/>
80
+ <Action selector_logic="core AND sheet" aas="FILVWY"/>
81
+ <Action selector_logic="core AND loop" aas="AFGILPVWY"/>
82
+ <Action selector_logic="helix_cap" aas="DNST"/>
83
+ </DesignRestrictions>
84
+ </TASKOPERATIONS>
85
+
86
+
87
+ <TASKOPERATIONS>
88
+ <ProteinProteinInterfaceUpweighter name="upweight_interface" interface_weight="3" />
89
+ <ProteinInterfaceDesign name="pack_long" design_chain1="0" design_chain2="0" jump="1" interface_distance_cutoff="15"/>
90
+ <InitializeFromCommandline name="init" />
91
+ <IncludeCurrent name="current" />
92
+ <LimitAromaChi2 name="limitchi2" chi2max="110" chi2min="70" include_trp="True" />
93
+ <ExtraRotamersGeneric name="ex1_ex2" ex1="1" ex2="1" />
94
+
95
+
96
+ <OperateOnResidueSubset name="restrict_target_not_interface" selector="not_interface_chB">
97
+ <PreventRepackingRLT/>
98
+ </OperateOnResidueSubset>
99
+ <OperateOnResidueSubset name="restrict2repacking" selector="all">
100
+ <RestrictToRepackingRLT/>
101
+ </OperateOnResidueSubset>
102
+ <OperateOnResidueSubset name="restrict_to_interface" selector="Not_interface">
103
+ <PreventRepackingRLT/>
104
+ </OperateOnResidueSubset>
105
+ <OperateOnResidueSubset name="restrict_target2repacking" selector="chainB">
106
+ <RestrictToRepackingRLT/>
107
+ </OperateOnResidueSubset>
108
+ <OperateOnResidueSubset name="restrict_hotspots2repacking" selector="HOTSPOT_res">
109
+ <RestrictToRepackingRLT/>
110
+ </OperateOnResidueSubset>
111
+
112
+ <DisallowIfNonnative name="disallow_GLY" resnum="0" disallow_aas="G" />
113
+ <DisallowIfNonnative name="disallow_PRO" resnum="0" disallow_aas="P" />
114
+ <SelectBySASA name="PR_monomer_core" mode="sc" state="monomer" probe_radius="2.2" core_asa="10" surface_asa="10" core="0" boundary="1" surface="1" verbose="0" />
115
+
116
+ <OperateOnResidueSubset name="restrict_PRO_GLY" selector="pro_and_gly_positions">
117
+ <PreventRepackingRLT/>
118
+ </OperateOnResidueSubset>
119
+
120
+ <!-- PruneBadRotamers name="prune_bad_rotamers" probability_cut="0.01" (task operation intentionally left disabled) -->
121
+
122
+ </TASKOPERATIONS>
123
+ <MOVERS>
124
+
125
+
126
+ <SwitchChainOrder name="chain1onlypre" chain_order="1" />
127
+ <ScoreMover name="scorepose" scorefxn="sfxn" verbose="false" />
128
+ <ParsedProtocol name="chain1only">
129
+ <Add mover="chain1onlypre" />
130
+ <Add mover="scorepose" />
131
+ </ParsedProtocol>
132
+ <TaskAwareMinMover name="min" scorefxn="sfxn" bb="0" chi="1" task_operations="pack_long" />
133
+
134
+ <DeleteRegionMover name="delete_polar" residue_selector="polar" rechain="false" />
135
+
136
+
137
+ </MOVERS>
138
+ <FILTERS>
139
+
140
+ <Time name="timed"/>
141
+
142
+ <Sasa name="interface_buried_sasa" confidence="0" />
143
+ <Ddg name="ddg" threshold="0" jump="1" repeats="1" repack="1" relax_mover="min" confidence="0" scorefxn="sfxn" />
144
+ <Ddg name="ddg_norepack" threshold="0" jump="1" repeats="1" repack="0" relax_mover="min" confidence="0" scorefxn="sfxn" />
145
+ <ShapeComplementarity name="interface_sc" verbose="0" min_sc="0.55" write_int_area="1" write_median_dist="1" jump="1" confidence="0"/>
146
+
147
+
148
+ ### score function monomer terms
149
+ <ScoreType name="total_score_MBF" scorefxn="sfxn" score_type="total_score" threshold="0" confidence="0" />
150
+ <MoveBeforeFilter name="total_score_monomer" mover="chain1only" filter="total_score_MBF" confidence="0" />
151
+ <ResidueCount name="res_count_MBF" max_residue_count="9999" confidence="0"/>
152
+ <MoveBeforeFilter name="res_count_monomer" mover="chain1only" filter="res_count_MBF" confidence="0" />
153
+
154
+
155
+ <CalculatorFilter name="score_per_res" equation="total_score_monomer / res" threshold="-3.5" confidence="0">
156
+ <Var name="total_score_monomer" filter="total_score_monomer"/>
157
+ <Var name="res" filter="res_count_monomer"/>
158
+ </CalculatorFilter>
159
+
160
+
161
+ <InterfaceHydrophobicResidueContacts name="hydrophobic_residue_contacts" target_selector="chainB" binder_selector="chainA" scorefxn="sfxn_soft" confidence="0"/>
162
+
163
+
164
+ <Ddg name="ddg_hydrophobic_pre" threshold="-10" jump="1" repeats="1" repack="0" confidence="0" scorefxn="vdw_sol" />
165
+ <MoveBeforeFilter name="ddg_hydrophobic" mover="delete_polar" filter="ddg_hydrophobic_pre" confidence="0"/>
166
+
167
+ <ContactMolecularSurface name="contact_molecular_surface" distance_weight="0.5" target_selector="chainA" binder_selector="chainB" confidence="0" />
168
+
169
+ </FILTERS>
170
+
171
+
172
+ <MOVERS>
173
+
174
+ <FastRelax name="FastRelax" scorefxn="sfxn_relax" repeats="1" batch="false" ramp_down_constraints="false" cartesian="false" bondangle="false" bondlength="false" min_type="dfpmin_armijo_nonmonotone" task_operations="current,ex1_ex2,restrict_target_not_interface,limitchi2" >
175
+ <MoveMap name="MM" >
176
+ <Chain number="1" chi="true" bb="true" />
177
+ <Chain number="2" chi="true" bb="false" />
178
+ <Jump number="1" setting="true" />
179
+ </MoveMap>
180
+ </FastRelax>
181
+
182
+ </MOVERS>
183
+ <APPLY_TO_POSE>
184
+ </APPLY_TO_POSE>
185
+
186
+ <PROTOCOLS>
187
+ </PROTOCOLS>
188
+
189
+ <OUTPUT/>
190
+ </ROSETTASCRIPTS>
evaluation/dG/base.py ADDED
@@ -0,0 +1,148 @@
1
+ #!/usr/bin/python
2
+ # -*- coding:utf-8 -*-
3
+ import os
4
+ import re
5
+ import json
6
+ from typing import Optional, Tuple, List
7
+ from dataclasses import dataclass
8
+
9
+ from .energy import pyrosetta_fastrelax, pyrosetta_interface_energy, rfdiff_refine
10
+
11
+
12
+ @dataclass
13
+ class RelaxTask:
14
+ in_path: str
15
+ current_path: str
16
+ info: dict
17
+ status: str
18
+ rec_chain: str
19
+ pep_chain: str
20
+ rfdiff_relax: bool = False
21
+ dG: Optional[float] = None
22
+
23
+ def set_dG(self, dG):
24
+ self.dG = dG
25
+
26
+ def get_in_path_with_tag(self, tag):
27
+ name, ext = os.path.splitext(self.in_path)
28
+ new_path = f'{name}_{tag}{ext}'
29
+ return new_path
30
+
31
+ def set_current_path_tag(self, tag):
32
+ new_path = self.get_in_path_with_tag(tag)
33
+ self.current_path = new_path
34
+ return new_path
35
+
36
+ def check_current_path_exists(self):
37
+ ok = os.path.exists(self.current_path)
38
+ if not ok:
39
+ self.mark_failure()
40
+ elif os.path.getsize(self.current_path) == 0: # only probe the size if the file exists
41
+ ok = False
42
+ self.mark_failure()
43
+ os.unlink(self.current_path)
44
+ return ok
45
+
46
+ def update_if_finished(self, tag):
47
+ out_path = self.get_in_path_with_tag(tag)
48
+ if os.path.exists(out_path) and os.path.getsize(out_path) > 0:
49
+ # print('Already finished', out_path)
50
+ self.set_current_path_tag(tag)
51
+ self.mark_success()
52
+ return True
53
+ return False
54
+
55
+ def can_proceed(self):
56
+ self.check_current_path_exists()
57
+ return self.status != 'failed'
58
+
59
+ def mark_success(self):
60
+ self.status = 'success'
61
+
62
+ def mark_failure(self):
63
+ self.status = 'failed'
64
+
65
+
66
+ class TaskScanner:
67
+
68
+ def __init__(self, results, n_sample, rfdiff_relax):
69
+ super().__init__()
70
+ self.results = results
71
+ self.n_sample = n_sample
72
+ self.rfdiff_relax = rfdiff_relax
73
+ self.visited = set()
74
+
75
+ def scan(self) -> List[RelaxTask]:
76
+ tasks = []
77
+ root_dir = os.path.dirname(self.results)
78
+ with open(self.results, 'r') as fin:
79
+ lines = fin.readlines()
80
+ for line in lines:
81
+ item = json.loads(line)
82
+ if item['number'] >= self.n_sample:
83
+ continue
84
+ _id = f"{item['id']}_{item['number']}"
85
+ if _id in self.visited:
86
+ continue
87
+ gen_pdb = os.path.split(item['gen_pdb'])[-1]
88
+ # subdir = gen_pdb.split('_')[0]
89
+ subdir = '_'.join(gen_pdb.split('_')[:-2])
90
+ gen_pdb = os.path.join(root_dir, 'candidates', subdir, gen_pdb)
91
+ tasks.append(RelaxTask(
92
+ in_path=gen_pdb,
93
+ current_path=gen_pdb,
94
+ info=item,
95
+ status='created',
96
+ rec_chain=item['rec_chain'],
97
+ pep_chain=item['lig_chain'],
98
+ rfdiff_relax=self.rfdiff_relax
99
+ ))
100
+ self.visited.add(_id)
101
+ return tasks
102
+
103
+ def scan_dataset(self) -> List[RelaxTask]:
104
+ tasks = []
105
+ root_dir = os.path.dirname(self.results)
106
+ with open(self.results, 'r') as fin: # index file of datasets
107
+ lines = fin.readlines()
108
+ for line in lines:
109
+ line = line.strip('\n').split('\t')
110
+ _id = line[0]
111
+ item = {
112
+ 'id': _id,
113
+ 'number': 0
114
+ }
115
+ pdb_path = os.path.join(root_dir, 'pdbs', _id + '.pdb')
116
+ tasks.append(RelaxTask(
117
+ in_path=pdb_path,
118
+ current_path=pdb_path,
119
+ info=item,
120
+ status='created',
121
+ rec_chain=line[7],
122
+ pep_chain=line[8],
123
+ rfdiff_relax=self.rfdiff_relax
124
+ ))
125
+ self.visited.add(_id)
126
+ return tasks
127
+
128
+
129
+ def run_pyrosetta(task: RelaxTask):
130
+ if not task.can_proceed() :
131
+ return task
132
+ # if task.update_if_finished('rosetta'):
133
+ # return task
134
+
135
+ out_path = task.set_current_path_tag('rosetta')
136
+ try:
137
+ if task.rfdiff_relax:
138
+ rfdiff_refine(task.in_path, out_path, task.pep_chain)
139
+ else:
140
+ pyrosetta_fastrelax(task.in_path, out_path, task.pep_chain, rfdiff_config=task.rfdiff_relax)
141
+ dG = pyrosetta_interface_energy(out_path, [task.rec_chain], [task.pep_chain])
142
+ task.mark_success()
143
+ except Exception as e:
144
+ print(e)
145
+ dG = 1e10
146
+ task.mark_failure()
147
+ task.set_dG(dG)
148
+ return task
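A hedged driver sketch for the relax pipeline above (`path/to/results.jsonl` is a placeholder; `run_pyrosetta` requires a working PyRosetta installation):

```python
from evaluation.dG.base import TaskScanner, run_pyrosetta

scanner = TaskScanner('path/to/results.jsonl', n_sample=10, rfdiff_relax=False)
for task in scanner.scan():
    task = run_pyrosetta(task)
    print(task.info['id'], task.status, task.dG)
```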
evaluation/dG/energy.py ADDED
@@ -0,0 +1,236 @@
+ #!/usr/bin/python
+ # -*- coding:utf-8 -*-
+ '''
+ From https://github.com/luost26/diffab/blob/main/diffab/tools/relax/pyrosetta_relaxer.py
+ '''
+ import os
+ import time
+ import pyrosetta
+ from pyrosetta.rosetta.protocols.analysis import InterfaceAnalyzerMover
+ # for fast relax
+ from pyrosetta.rosetta import protocols
+ from pyrosetta.rosetta.protocols.relax import FastRelax
+ from pyrosetta.rosetta.core.pack.task import TaskFactory
+ from pyrosetta.rosetta.core.pack.task import operation
+ from pyrosetta.rosetta.core.select import residue_selector as selections
+ from pyrosetta.rosetta.core.select.movemap import MoveMapFactory, move_map_action
+ from pyrosetta.rosetta.core.scoring import ScoreType
+
+ from Bio.PDB import PDBIO, PDBParser
+ from Bio.PDB.Structure import Structure as BStructure
+ from Bio.PDB.Model import Model as BModel
+ from Bio.PDB.Chain import Chain as BChain
+
+
+ pyrosetta.init(' '.join([
+     '-mute', 'all',
+     '-use_input_sc',
+     '-ignore_unrecognized_res',
+     '-ignore_zero_occupancy', 'false',
+     '-load_PDB_components', 'false',
+     '-relax:default_repeats', '2',
+     '-no_fconfig',
+     # below are from https://github.com/nrbennet/dl_binder_design/blob/main/mpnn_fr/dl_interface_design.py
+     # '-beta_nov16',
+     '-use_terminal_residues', 'true',
+     '-in:file:silent_struct_type', 'binary'
+ ]))
+
+
+ def current_milli_time():
+     return round(time.time() * 1000)
+
+
+ def get_scorefxn(scorefxn_name: str):
+     """
+     Gets the scorefxn with appropriate corrections.
+     Taken from: https://gist.github.com/matteoferla/b33585f3aeab58b8424581279e032550
+     """
+     corrections = {
+         'beta_july15': False,
+         'beta_nov16': False,
+         'gen_potential': False,
+         'restore_talaris_behavior': False,
+     }
+     if 'beta_july15' in scorefxn_name or 'beta_nov15' in scorefxn_name:
+         # beta_july15 is ref2015
+         corrections['beta_july15'] = True
+     elif 'beta_nov16' in scorefxn_name:
+         corrections['beta_nov16'] = True
+     elif 'genpot' in scorefxn_name:
+         corrections['gen_potential'] = True
+         pyrosetta.rosetta.basic.options.set_boolean_option('corrections:beta_july15', True)
+     elif 'talaris' in scorefxn_name:  # 2013 and 2014
+         corrections['restore_talaris_behavior'] = True
+     for corr, value in corrections.items():
+         pyrosetta.rosetta.basic.options.set_boolean_option(f'corrections:{corr}', value)
+     return pyrosetta.create_score_function(scorefxn_name)
+
+
+ class RelaxRegion(object):
+
+     def __init__(self, scorefxn='ref2015', max_iter=1000, subset='nbrs', move_bb=True, rfdiff_config=False):
+         super().__init__()
+
+         if rfdiff_config:
+             self.scorefxn = get_scorefxn('beta_nov16')
+             xml = os.path.join(os.path.dirname(__file__), 'RosettaFastRelaxUtil.xml')
+             objs = protocols.rosetta_scripts.XmlObjects.create_from_file(xml)
+             self.fast_relax = objs.get_mover('FastRelax')
+             self.fast_relax.max_iter(max_iter)
+         else:
+             self.scorefxn = get_scorefxn(scorefxn)
+             self.fast_relax = FastRelax()
+             self.fast_relax.set_scorefxn(self.scorefxn)
+             self.fast_relax.max_iter(max_iter)
+
+         assert subset in ('all', 'target', 'nbrs')
+         self.subset = subset
+         self.move_bb = move_bb
+
+     def __call__(self, pdb_path, ligand_chains):  # formerly also took flexible_residue_first / flexible_residue_last
+         pose = pyrosetta.pose_from_pdb(pdb_path)
+         start_t = current_milli_time()
+         original_pose = pose.clone()
+
+         tf = TaskFactory()
+         tf.push_back(operation.InitializeFromCommandline())
+         tf.push_back(operation.RestrictToRepacking())  # Only allow residues to repack. No design at any position.
+
+         # Create selector for the region to be relaxed
+         # Turn off design and repacking on irrelevant positions
+         if self.subset != 'all':
+             chain_selectors = [selections.ChainSelector(chain) for chain in ligand_chains]
+             if len(chain_selectors) == 1:
+                 gen_selector = chain_selectors[0]
+             else:
+                 gen_selector = selections.OrResidueSelector(chain_selectors[0], chain_selectors[1])
+                 for selector in chain_selectors[2:]:
+                     gen_selector = selections.OrResidueSelector(gen_selector, selector)
+             nbr_selector = selections.NeighborhoodResidueSelector()
+             nbr_selector.set_focus_selector(gen_selector)
+             nbr_selector.set_include_focus_in_subset(True)
+
+             if self.subset == 'nbrs':
+                 subset_selector = nbr_selector
+             elif self.subset == 'target':
+                 subset_selector = gen_selector
+
+             prevent_repacking_rlt = operation.PreventRepackingRLT()
+             prevent_subset_repacking = operation.OperateOnResidueSubset(
+                 prevent_repacking_rlt,
+                 subset_selector,
+                 flip_subset=True,
+             )
+             tf.push_back(prevent_subset_repacking)
+         else:
+             # relax the whole complex: enable movement and repacking for every residue
+             # (without this branch, gen_selector/subset_selector would be undefined below)
+             gen_selector = selections.TrueResidueSelector()
+             subset_selector = gen_selector
+
+         scorefxn = self.scorefxn
+         fr = self.fast_relax
+
+         pose = original_pose.clone()
+
+         mmf = MoveMapFactory()
+         if self.move_bb:
+             mmf.add_bb_action(move_map_action.mm_enable, gen_selector)
+         mmf.add_chi_action(move_map_action.mm_enable, subset_selector)
+         mm = mmf.create_movemap_from_pose(pose)
+
+         fr.set_movemap(mm)
+         fr.set_task_factory(tf)
+         fr.apply(pose)
+
+         e_before = scorefxn(original_pose)
+         e_relax = scorefxn(pose)
+         # print('\n\n[Finished in %.2f secs]' % ((current_milli_time() - start_t) / 1000))
+         # print(' > Energy (before): %.4f' % e_before)
+         # print(' > Energy (optimized): %.4f' % e_relax)
+         return pose, e_before, e_relax
+
+
+ def pyrosetta_fastrelax(pdb_path, out_path, pep_chain, rfdiff_config=False):
+     minimizer = RelaxRegion(rfdiff_config=rfdiff_config)
+     pose_min, _, _ = minimizer(
+         pdb_path=pdb_path,
+         ligand_chains=[pep_chain]
+     )
+     pose_min.dump_pdb(out_path)
+
+
+ def _rename_chain(pdb_path, out_path, src_pep_chain, tgt_pep_chain, tgt_rec_chain):
+     io = PDBIO()
+     parser = PDBParser()
+
+     structure = parser.get_structure('anonymous', pdb_path)
+
+     new_mapping = {}
+     pep_chain, rec_chain = BChain(id=tgt_pep_chain), BChain(id=tgt_rec_chain)
+
+     for model in structure:
+         for chain in model:
+             if chain.get_id() == src_pep_chain:
+                 new_mapping[src_pep_chain] = tgt_pep_chain
+                 for res in chain:
+                     pep_chain.add(res.copy())
+             else:
+                 new_mapping[chain.get_id()] = tgt_rec_chain
+                 for res in chain:
+                     rec_chain.add(res.copy())
+
+     structure = BStructure(id=structure.get_id())
+     model = BModel(id=0)
+     model.add(pep_chain)
+     model.add(rec_chain)
+     structure.add(model)
+
+     io.set_structure(structure)
+     io.save(out_path)
+
+     return new_mapping
+
+
+ def rfdiff_refine(pdb_path, out_path, pep_chain):
+     # rename peptide chain to A and receptor to B
+     new_mapping = _rename_chain(pdb_path, out_path, pep_chain, 'A', 'B')
+
+     # force fields from RFDiffusion
+     get_scorefxn('beta_nov16')
+     xml = os.path.join(os.path.dirname(__file__), 'RosettaFastRelaxUtil.xml')
+     objs = protocols.rosetta_scripts.XmlObjects.create_from_file(xml)
+     fastrelax = objs.get_mover('FastRelax')
+     pose = pyrosetta.pose_from_pdb(out_path)
+     fastrelax.apply(pose)
+     pose.dump_pdb(out_path)
+
+     # get back to original chain ids
+     reverse_mapping = {new_mapping[key]: key for key in new_mapping}
+     _rename_chain(out_path, out_path, 'A', reverse_mapping['A'], reverse_mapping['B'])
+
+
+ def pyrosetta_interface_energy(pdb_path, receptor_chains, ligand_chains, return_dict=False):
+     pose = pyrosetta.pose_from_pdb(pdb_path)
+     interface = ''.join(ligand_chains) + '_' + ''.join(receptor_chains)
+     mover = InterfaceAnalyzerMover(interface)
+     mover.set_pack_separated(True)
+     mover.apply(pose)
+     if return_dict:
+         return pose.scores
+     return pose.scores['dG_separated']
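A minimal sketch of using the helpers above directly, assuming hypothetical file paths and a receptor chain A / peptide chain B:

```python
from evaluation.dG.energy import pyrosetta_fastrelax, pyrosetta_interface_energy

# relax around the peptide, then score the relaxed interface
pyrosetta_fastrelax('input_complex.pdb', 'relaxed.pdb', pep_chain='B')
dG = pyrosetta_interface_energy('relaxed.pdb', receptor_chains=['A'], ligand_chains=['B'])
scores = pyrosetta_interface_energy('relaxed.pdb', ['A'], ['B'], return_dict=True)
print(dG, scores['dG_separated'])   # the two values agree by construction
```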
evaluation/dG/openmm_relaxer.py ADDED
@@ -0,0 +1,107 @@
+ import io
+ import pdbfixer
+ import openmm
+ from openmm import app as openmm_app
+ from openmm import unit
+
+ ENERGY = unit.kilocalories_per_mole
+ LENGTH = unit.angstroms
+
+
+ class ForceFieldMinimizer(object):
+
+     def __init__(self, stiffness=10.0, max_iterations=0, tolerance=2.39*unit.kilocalories_per_mole, platform='CUDA'):
+         super().__init__()
+         self.stiffness = stiffness  # plain number: interpreted in OpenMM's default MD units (kJ/mol/nm^2)
+         self.max_iterations = max_iterations
+         self.tolerance = tolerance
+         assert platform in ('CUDA', 'CPU')
+         self.platform = platform
+
+     def _fix(self, pdb_str):
+         fixer = pdbfixer.PDBFixer(pdbfile=io.StringIO(pdb_str))
+         fixer.findNonstandardResidues()
+         fixer.replaceNonstandardResidues()
+
+         fixer.findMissingResidues()
+         fixer.findMissingAtoms()
+         fixer.addMissingAtoms(seed=0)
+         fixer.addMissingHydrogens()
+
+         out_handle = io.StringIO()
+         openmm_app.PDBFile.writeFile(fixer.topology, fixer.positions, out_handle, keepIds=True)
+         return out_handle.getvalue()
+
+     def _get_pdb_string(self, topology, positions):
+         with io.StringIO() as f:
+             openmm_app.PDBFile.writeFile(topology, positions, f, keepIds=True)
+             return f.getvalue()
+
+     def _minimize(self, pdb_str):
+         pdb = openmm_app.PDBFile(io.StringIO(pdb_str))
+
+         force_field = openmm_app.ForceField("charmm36.xml")  # see http://docs.openmm.org/latest/userguide/application/02_running_sims.html
+         constraints = openmm_app.HBonds
+         system = force_field.createSystem(pdb.topology, constraints=constraints)
+
+         # Add harmonic positional restraints to all heavy atoms
+         force = openmm.CustomExternalForce("0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)")
+         force.addGlobalParameter("k", self.stiffness)
+         for p in ["x0", "y0", "z0"]:
+             force.addPerParticleParameter(p)
+
+         for i, a in enumerate(pdb.topology.atoms()):
+             if a.element.name != 'hydrogen':
+                 force.addParticle(i, pdb.positions[i])
+
+         system.addForce(force)
+
+         # Set up the integrator and simulation
+         integrator = openmm.LangevinIntegrator(0, 0.01, 0.0)
+         platform = openmm.Platform.getPlatformByName(self.platform)  # honor the configured platform (was hard-coded to 'CUDA')
+         simulation = openmm_app.Simulation(pdb.topology, system, integrator, platform)
+         simulation.context.setPositions(pdb.positions)
+
+         # Perform minimization
+         ret = {}
+         state = simulation.context.getState(getEnergy=True, getPositions=True)
+         ret["einit"] = state.getPotentialEnergy().value_in_unit(ENERGY)
+         ret["posinit"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
+
+         simulation.minimizeEnergy(maxIterations=self.max_iterations, tolerance=self.tolerance)
+
+         state = simulation.context.getState(getEnergy=True, getPositions=True)
+         ret["efinal"] = state.getPotentialEnergy().value_in_unit(ENERGY)
+         ret["pos"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
+         ret["min_pdb"] = self._get_pdb_string(simulation.topology, state.getPositions())
+
+         return ret['min_pdb'], ret
+
+     def _add_energy_remarks(self, pdb_str, ret):
+         pdb_lines = pdb_str.splitlines()
+         pdb_lines.insert(1, "REMARK 1 FINAL ENERGY: {:.3f} KCAL/MOL".format(ret['efinal']))
+         pdb_lines.insert(1, "REMARK 1 INITIAL ENERGY: {:.3f} KCAL/MOL".format(ret['einit']))
+         return "\n".join(pdb_lines)
+
+     def __call__(self, pdb_str, out_path, return_info=True):
+         # accept either a PDB string or a path to a .pdb file
+         if '\n' not in pdb_str and pdb_str.lower().endswith(".pdb"):
+             with open(pdb_str) as f:
+                 pdb_str = f.read()
+
+         pdb_fixed = self._fix(pdb_str)
+         pdb_min, ret = self._minimize(pdb_fixed)
+         pdb_min = self._add_energy_remarks(pdb_min, ret)
+         with open(out_path, 'w') as f:
+             f.write(pdb_min)
+         if return_info:
+             return pdb_min, ret
+         else:
+             return pdb_min
+
+
+ if __name__ == '__main__':
+     import sys
+     force_field = ForceFieldMinimizer()
+     force_field(sys.argv[1], sys.argv[2])
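A minimal sketch of running the minimizer on CPU (relying on the `platform` argument above); the paths are hypothetical placeholders:

```python
from evaluation.dG.openmm_relaxer import ForceFieldMinimizer

minimize = ForceFieldMinimizer(platform='CPU')   # 'CUDA' by default
pdb_min, info = minimize('model.pdb', 'model_min.pdb')
print(info['einit'], info['efinal'])             # potential energy before/after, kcal/mol
```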
evaluation/dG/run.py ADDED
@@ -0,0 +1,92 @@
+ #!/usr/bin/python
+ # -*- coding:utf-8 -*-
+ import os
+ import json
+ import argparse
+ import statistics
+
+ import ray
+
+ from utils.logger import print_log
+
+ from .base import TaskScanner, run_pyrosetta
+
+ # @ray.remote(num_gpus=1/8, num_cpus=1)
+ # def run_openmm_remote(task):
+ #     return run_openmm(task)
+
+
+ @ray.remote(num_cpus=1)
+ def run_pyrosetta_remote(task):
+     return run_pyrosetta(task)
+
+
+ @ray.remote
+ def pipeline_pyrosetta(task):
+     funcs = [
+         run_pyrosetta_remote,
+     ]
+     for fn in funcs:
+         task = fn.remote(task)
+     return ray.get(task)
+
+
+ def parse():
+     parser = argparse.ArgumentParser(description='calculating dG using pyrosetta')
+     parser.add_argument('--results', type=str, required=True, help='Path to the summary of the results (.jsonl)')
+     parser.add_argument('--n_sample', type=int, default=float('inf'), help='Maximum number of samples for calculation')
+     parser.add_argument('--rfdiff_relax', action='store_true', help='Use rfdiff fastrelax')
+     parser.add_argument('--out_path', type=str, default=None, help='Output path, defaults to dG_report.jsonl in the same directory as the results')
+     return parser.parse_args()
+
+
+ def main(args):
+     # output summary
+     if args.out_path is None:
+         args.out_path = os.path.join(os.path.dirname(args.results), 'dG_report.jsonl')
+     results = {}
+
+     # parallel
+     ray.init()
+     scanner = TaskScanner(args.results, args.n_sample, args.rfdiff_relax)
+     if args.results.endswith('txt'):
+         tasks = scanner.scan_dataset()
+     else:
+         tasks = scanner.scan()
+     futures = [pipeline_pyrosetta.remote(t) for t in tasks]
+     if len(futures) > 0:
+         print_log(f'Submitted {len(futures)} tasks.')
+     while len(futures) > 0:
+         done_ids, futures = ray.wait(futures, num_returns=1)
+         for done_id in done_ids:
+             done_task = ray.get(done_id)
+             print_log(f'Remaining {len(futures)}. Finished {done_task.current_path}, dG {done_task.dG}')
+             _id, number = done_task.info['id'], done_task.info['number']
+             if _id not in results:
+                 results[_id] = {
+                     'min': float('inf'),
+                     'all': {}
+                 }
+             results[_id]['all'][number] = done_task.dG
+             results[_id]['min'] = min(results[_id]['min'], done_task.dG)
+
+     # write results
+     for _id in results:
+         success = 0
+         for n in results[_id]['all']:
+             if results[_id]['all'][n] < 0:
+                 success += 1
+         results[_id]['success rate'] = success / len(results[_id]['all'])
+     json.dump(results, open(args.out_path, 'w'), indent=2)
+
+     # show results
+     vals = [results[_id]['min'] for _id in results]
+     print(f'median: {statistics.median(vals)}, mean: {sum(vals) / len(vals)}')
+     success = [results[_id]['success rate'] for _id in results]
+     print(f'mean success rate: {sum(success) / len(success)}')
+
+
+ if __name__ == '__main__':
+     import random
+     random.seed(12)
+     main(parse())
evaluation/diversity.py ADDED
@@ -0,0 +1,68 @@
+ #!/usr/bin/python
+ # -*- coding:utf-8 -*-
+ from typing import List, Tuple
+
+ import numpy as np
+ from scipy.cluster.hierarchy import linkage, fcluster
+ from scipy.spatial.distance import squareform
+ from scipy.stats.contingency import association
+
+ from evaluation.seq_metric import align_sequences
+
+
+ def seq_diversity(seqs: List[str], th: float = 0.4) -> Tuple[float, np.ndarray]:
+     '''
+     th: distance threshold for clustering (distance = 1 - sequence identity)
+     returns (fraction of unique clusters, cluster assignments)
+     '''
+     dists = []
+     for i, seq1 in enumerate(seqs):
+         dists.append([])
+         for j, seq2 in enumerate(seqs):
+             _, sim = align_sequences(seq1, seq2)
+             dists[i].append(1 - sim)
+     dists = np.array(dists)
+     Z = linkage(squareform(dists), 'single')
+     cluster = fcluster(Z, t=th, criterion='distance')
+     return len(np.unique(cluster)) / len(seqs), cluster
+
+
+ def struct_diversity(structs: np.ndarray, th: float = 4.0) -> Tuple[float, np.ndarray]:
+     '''
+     structs: N*L*3, alpha carbon coordinates
+     th: RMSD threshold for clustering (distance < th)
+     returns (fraction of unique clusters, cluster assignments)
+     '''
+     ca_dists = np.sum((structs[:, None] - structs[None, :]) ** 2, axis=-1)  # [N, N, L]
+     rmsd = np.sqrt(np.mean(ca_dists, axis=-1))
+     Z = linkage(squareform(rmsd), 'single')  # single linkage since the distances might not be euclidean (e.g. rmsd)
+     cluster = fcluster(Z, t=th, criterion='distance')
+     return len(np.unique(cluster)) / structs.shape[0], cluster
+
+
+ def diversity(seqs: List[str], structs: np.ndarray):
+     seq_div, seq_clu = seq_diversity(seqs)
+     if structs is None:
+         return seq_div, None, seq_div, None
+     struct_div, struct_clu = struct_diversity(structs)
+     co_div = np.sqrt(seq_div * struct_div)
+
+     n_seq_clu, n_struct_clu = np.max(seq_clu), np.max(struct_clu)  # clusters start from 1
+     if n_seq_clu == 1 or n_struct_clu == 1:
+         consistency = 1.0 if n_seq_clu == n_struct_clu else 0.0
+     else:
+         table = [[0 for _ in range(n_struct_clu)] for _ in range(n_seq_clu)]
+         for seq_c, struct_c in zip(seq_clu, struct_clu):
+             table[seq_c - 1][struct_c - 1] += 1
+         consistency = association(np.array(table), method='cramer')
+
+     return seq_div, struct_div, co_div, consistency
+
+
+ if __name__ == '__main__':
+     N, L = 100, 10
+     a = np.random.randn(N, L, 3)
+     print(struct_diversity(a))
+     from utils.const import aas
+     aas = [tup[0] for tup in aas]
+     seqs = np.random.randint(0, len(aas), (N, L))
+     seqs = [''.join([aas[i] for i in idx]) for idx in seqs]
+     print(seq_diversity(seqs, 0.4))
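A minimal sketch of the joint metric on a handful of candidates; the sequences and the [N, L, 3] alpha-carbon coordinates below are hypothetical stand-ins for generated peptides:

```python
import numpy as np
from evaluation.diversity import diversity

seqs = ['PKGYAAPSA', 'PKGYAAPSV', 'KPAVYKFTL', 'KPAVYKFTV']
ca_coords = np.random.randn(4, 9, 3)   # stand-in for real alpha-carbon coordinates
seq_div, struct_div, co_div, consistency = diversity(seqs, ca_coords)
print(seq_div, struct_div, co_div, consistency)
```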
evaluation/dockq.py ADDED
@@ -0,0 +1,15 @@
+ #!/usr/bin/python
+ # -*- coding:utf-8 -*-
+ import os
+ import re
+
+ from globals import DOCKQ_DIR
+
+
+ def dockq(mod_pdb: str, native_pdb: str, pep_chain: str):
+     p = os.popen(f'{os.path.join(DOCKQ_DIR, "DockQ.py")} {mod_pdb} {native_pdb} -model_chain1 {pep_chain} -native_chain1 {pep_chain} -no_needle')
+     text = p.read()
+     p.close()
+     res = re.search(r'DockQ\s+([0-1]\.[0-9]+)', text)
+     score = float(res.group(1))
+     return score
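Usage sketch (paths are hypothetical); note that `-no_needle` skips sequence alignment, so the model and native PDBs must share chain ids and residue numbering:

```python
from evaluation.dockq import dockq

score = dockq('candidate.pdb', 'native.pdb', pep_chain='B')
print(score)   # in [0, 1]; higher means a closer match to the native interface
```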
evaluation/rmsd.py ADDED
@@ -0,0 +1,11 @@
+ #!/usr/bin/python
+ # -*- coding:utf-8 -*-
+ import numpy as np
+
+
+ # a: [N, 3], b: [N, 3]
+ def compute_rmsd(a, b, aligned=False):  # amino-acid-level rmsd; assumes a and b are already superimposed (aligned is unused)
+     dist = np.sum((a - b) ** 2, axis=-1)
+     rmsd = np.sqrt(dist.sum() / a.shape[0])
+     return float(rmsd)
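Since `compute_rmsd` assumes pre-superimposed coordinates, a standard Kabsch superposition can be applied first for the unaligned case. The helper below is a hypothetical sketch, not part of this module:

```python
import numpy as np

def kabsch_align(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    '''Optimally rotate/translate a ([N, 3]) onto b ([N, 3]) before computing RMSD.'''
    a_c, b_c = a - a.mean(axis=0), b - b.mean(axis=0)
    U, _, Vt = np.linalg.svd(a_c.T @ b_c)      # SVD of the covariance matrix
    d = np.sign(np.linalg.det(U @ Vt))         # avoid improper rotations (reflections)
    R = U @ np.diag([1.0, 1.0, d]) @ Vt
    return a_c @ R + b.mean(axis=0)

# compute_rmsd(kabsch_align(a, b), b) then gives the aligned RMSD
```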
evaluation/seq_metric.py ADDED
@@ -0,0 +1,71 @@
+ #!/usr/bin/python
+ # -*- coding:utf-8 -*-
+ from math import sqrt
+
+ from Bio.Align import substitution_matrices, PairwiseAligner
+
+
+ def aar(candidate, reference):
+     hit = 0
+     for a, b in zip(candidate, reference):
+         if a == b:
+             hit += 1
+     return hit / len(reference)
+
+
+ def align_sequences(sequence_A, sequence_B, **kwargs):
+     """
+     Performs a global pairwise alignment between two sequences
+     using the BLOSUM62 matrix and the Needleman-Wunsch algorithm
+     as implemented in Biopython. Returns the alignment and the
+     sequence identity (alignment score normalized by the geometric
+     mean of the two self-alignment scores).
+     """
+     sub_matrix = substitution_matrices.load('BLOSUM62')
+     aligner = PairwiseAligner()
+     aligner.substitution_matrix = sub_matrix
+     alns = aligner.align(sequence_A, sequence_B)
+
+     best_aln = alns[0]
+     aligned_A, aligned_B = best_aln
+
+     base = sqrt(aligner.score(sequence_A, sequence_A) * aligner.score(sequence_B, sequence_B))
+     seq_id = aligner.score(sequence_A, sequence_B) / base
+     return (aligned_A, aligned_B), seq_id
+
+
+ def slide_aar(candidate, reference, aar_func):
+     '''
+     e.g.
+     candidate: AILPV
+     reference: ILPVH
+
+     should be matched as
+     AILPV
+      ILPVH
+
+     To do this, we slide the candidate against the reference and take the maximum aar:
+     A
+     AI
+     AIL
+     AILP
+     AILPV
+     ILPV
+     LPV
+     PV
+     V
+     '''
+     special_token = ' '
+     ref_len = len(reference)
+     padded_candidate = special_token * (ref_len - 1) + candidate + special_token * (ref_len - 1)
+     value = 0
+     for start in range(len(padded_candidate) - ref_len + 1):
+         value = max(value, aar_func(padded_candidate[start:start + ref_len], reference))
+     return value
+
+
+ if __name__ == '__main__':
+     print(align_sequences('PKGYAAPSA', 'KPAVYKFTL'))
+     print(align_sequences('KPAVYKFTL', 'PKGYAAPSA'))
+     print(align_sequences('PKGYAAPSA', 'PKGYAAPSA'))
+     print(align_sequences('KPAVYKFTL', 'KPAVYKFTL'))
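For example, `aar` compares positions directly, while `slide_aar` recovers matches between offset sequences, following the docstring above:

```python
from evaluation.seq_metric import aar, slide_aar

print(aar('AILPV', 'ILPVH'))              # 0.0: no position-wise matches
print(slide_aar('AILPV', 'ILPVH', aar))   # 0.8: 'ILPV' overlaps after sliding by one
```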
generate.py ADDED
@@ -0,0 +1,235 @@
+ #!/usr/bin/python
+ # -*- coding:utf-8 -*-
+ import argparse
+ import json
+ import os
+ import pickle as pkl
+ from tqdm import tqdm
+ from copy import deepcopy
+ from multiprocessing import Pool
+
+ import yaml
+ import torch
+
+ import models
+ from utils.config_utils import overwrite_values
+ from data.converter.pdb_to_list_blocks import pdb_to_list_blocks
+ from data.converter.list_blocks_to_pdb import list_blocks_to_pdb
+ from data.format import VOCAB, Atom
+ from data import create_dataloader, create_dataset
+ from utils.logger import print_log
+ from utils.random_seed import setup_seed
+ from utils.const import sidechain_atoms
+
+
+ def get_best_ckpt(ckpt_dir):
+     with open(os.path.join(ckpt_dir, 'checkpoint', 'topk_map.txt'), 'r') as f:
+         ls = f.readlines()
+     ckpts = []
+     for l in ls:
+         k, v = l.strip().split(':')
+         k = float(k)
+         v = v.split('/')[-1]
+         ckpts.append((k, v))
+
+     # ckpts = sorted(ckpts, key=lambda x: x[0])
+     best_ckpt = ckpts[0][1]
+     return os.path.join(ckpt_dir, 'checkpoint', best_ckpt)
+
+
+ def to_device(data, device):
+     if isinstance(data, dict):
+         for key in data:
+             data[key] = to_device(data[key], device)
+     elif isinstance(data, list) or isinstance(data, tuple):
+         res = [to_device(item, device) for item in data]
+         data = type(data)(res)
+     elif hasattr(data, 'to'):
+         data = data.to(device)
+     return data
+
+
+ def clamp_coord(coord):
+     # some models (e.g. diffab) output very large coordinates (absolute value >= 1000) which corrupt the pdb file
+     new_coord = []
+     for val in coord:
+         if abs(val) >= 1000:
+             val = 0
+         new_coord.append(val)
+     return new_coord
+
+
+ def overwrite_blocks(blocks, seq=None, X=None):
+     if seq is not None:
+         assert len(blocks) == len(seq), f'{len(blocks)} {len(seq)}'
+     new_blocks = []
+     for i, block in enumerate(blocks):
+         block = deepcopy(block)
+         if seq is None:
+             abrv = block.abrv
+         else:
+             abrv = VOCAB.symbol_to_abrv(seq[i])
+         if block.abrv != abrv:
+             if X is None:
+                 # residue type changed but no new coordinates: keep only the backbone atoms
+                 block.units = [atom for atom in block.units if atom.name in VOCAB.backbone_atoms]
+         if X is not None:
+             coords = X[i]
+             atoms = VOCAB.backbone_atoms + sidechain_atoms[VOCAB.abrv_to_symbol(abrv)]
+             block.units = [
+                 Atom(atom_name, clamp_coord(coord), atom_name[0]) for atom_name, coord in zip(atoms, coords)
+             ]
+         block.abrv = abrv
+         new_blocks.append(block)
+     return new_blocks
+
+
+ def generate_wrapper(model, sample_opt={}):
+     if isinstance(model, models.AutoEncoder):
+         def wrapper(batch):
+             X, S, ppls = model.test(batch['X'], batch['S'], batch['mask'], batch['position_ids'], batch['lengths'], batch['atom_mask'])
+             return X, S, ppls
+     elif isinstance(model, models.LDMPepDesign):
+         def wrapper(batch):
+             X, S, ppls = model.sample(
+                 batch['X'], batch['S'], batch['mask'], batch['position_ids'], batch['lengths'], batch['atom_mask'],
+                 L=batch['L'] if 'L' in batch else None, sample_opt=sample_opt
+             )
+             return X, S, ppls
+     else:
+         raise NotImplementedError(f'Wrapper for {type(model)} not implemented')
+     return wrapper
+
+
+ def save_data(
+     _id, n,
+     x_pkl_file, s_pkl_file, pmetric_pkl_file,
+     ref_pdb, rec_chain, lig_chain, ref_save_dir, cand_save_dir,
+     seq_only, struct_only, backbone_only
+ ):
+     X = pkl.load(open(x_pkl_file, 'rb'))
+     S = pkl.load(open(s_pkl_file, 'rb'))
+     pmetric = pkl.load(open(pmetric_pkl_file, 'rb'))
+     for tmp_file in (x_pkl_file, s_pkl_file, pmetric_pkl_file):
+         os.remove(tmp_file)
+     if seq_only:
+         X = None
+     elif struct_only:
+         S = None
+     rec_blocks, lig_blocks = pdb_to_list_blocks(ref_pdb, selected_chains=[rec_chain, lig_chain])
+     ref_pdb = os.path.join(ref_save_dir, _id + '_ref.pdb')
+     list_blocks_to_pdb([rec_blocks, lig_blocks], [rec_chain, lig_chain], ref_pdb)
+     ref_seq = ''.join([VOCAB.abrv_to_symbol(block.abrv) for block in lig_blocks])
+     lig_blocks = overwrite_blocks(lig_blocks, S, X)
+     gen_seq = ''.join([VOCAB.abrv_to_symbol(block.abrv) for block in lig_blocks])
+     save_dir = os.path.join(cand_save_dir, _id)
+     if not os.path.exists(save_dir):
+         os.makedirs(save_dir)
+     gen_pdb = os.path.join(save_dir, _id + f'_gen_{n}.pdb')
+     list_blocks_to_pdb([rec_blocks, lig_blocks], [rec_chain, lig_chain], gen_pdb)
+
+     return {
+         'id': _id,
+         'number': n,
+         'gen_pdb': gen_pdb,
+         'ref_pdb': ref_pdb,
+         'pmetric': pmetric,
+         'rec_chain': rec_chain,
+         'lig_chain': lig_chain,
+         'ref_seq': ref_seq,
+         'gen_seq': gen_seq,
+         'seq_only': seq_only,
+         'struct_only': struct_only,
+         'backbone_only': backbone_only
+     }
+
+
+ def main(args, opt_args):
+     config = yaml.safe_load(open(args.config, 'r'))
+     config = overwrite_values(config, opt_args)
+     struct_only = config.get('struct_only', False)
+     seq_only = config.get('seq_only', False)
+     assert not (seq_only and struct_only)
+     backbone_only = config.get('backbone_only', False)
+     # load model
+     b_ckpt = args.ckpt if args.ckpt.endswith('.ckpt') else get_best_ckpt(args.ckpt)
+     ckpt_dir = os.path.split(os.path.split(b_ckpt)[0])[0]
+     print(f'Using checkpoint {b_ckpt}')
+     model = torch.load(b_ckpt, map_location='cpu')
+     device = torch.device('cpu' if args.gpu == -1 else f'cuda:{args.gpu}')
+     model.to(device)
+     model.eval()
+
+     # load data
+     _, _, test_set = create_dataset(config['dataset'])
+     test_loader = create_dataloader(test_set, config['dataloader'])
+
+     # save path
+     if args.save_dir is None:
+         save_dir = os.path.join(ckpt_dir, 'results')
+     else:
+         save_dir = args.save_dir
+     ref_save_dir = os.path.join(save_dir, 'references')
+     cand_save_dir = os.path.join(save_dir, 'candidates')
+     for directory in [ref_save_dir, cand_save_dir]:
+         if not os.path.exists(directory):
+             os.makedirs(directory)
+
+     fout = open(os.path.join(save_dir, 'results.jsonl'), 'w')
+
+     # multiprocessing
+     pool = Pool(args.n_cpu)
+
+     n_samples = config.get('n_samples', 1)
+
+     pbar = tqdm(total=n_samples * len(test_loader))
+     for n in range(n_samples):
+         item_idx = 0
+         with torch.no_grad():
+             for batch in test_loader:
+                 batch = to_device(batch, device)
+                 batch_X, batch_S, batch_pmetric = generate_wrapper(model, deepcopy(config.get('sample_opt', {})))(batch)
+
+                 # parallel saving
+                 inputs = []
+                 for X, S, pmetric in zip(batch_X, batch_S, batch_pmetric):
+                     _id, ref_pdb, rec_chain, lig_chain = test_set.get_summary(item_idx)
+                     # save temporary pickle files
+                     x_pkl_file = os.path.join(save_dir, _id + f'_gen_{n}_X.pkl')
+                     pkl.dump(X, open(x_pkl_file, 'wb'))
+                     s_pkl_file = os.path.join(save_dir, _id + f'_gen_{n}_S.pkl')
+                     pkl.dump(S, open(s_pkl_file, 'wb'))
+                     pmetric_pkl_file = os.path.join(save_dir, _id + f'_gen_{n}_pmetric.pkl')
+                     pkl.dump(pmetric, open(pmetric_pkl_file, 'wb'))
+                     inputs.append((
+                         _id, n,
+                         x_pkl_file, s_pkl_file, pmetric_pkl_file,
+                         ref_pdb, rec_chain, lig_chain, ref_save_dir, cand_save_dir,
+                         seq_only, struct_only, backbone_only
+                     ))
+                     item_idx += 1
+
+                 results = pool.starmap(save_data, inputs)
+                 for result in results:
+                     fout.write(json.dumps(result) + '\n')
+
+                 pbar.update(1)
+
+     fout.close()
+
+
+ def parse():
+     parser = argparse.ArgumentParser(description='Generate peptides given epitopes')
+     parser.add_argument('--config', type=str, required=True, help='Path to the test configuration')
+     parser.add_argument('--ckpt', type=str, required=True, help='Path to checkpoint')
+     parser.add_argument('--save_dir', type=str, default=None, help='Directory to save generated peptides')
+
+     parser.add_argument('--gpu', type=int, default=0, help='GPU to use, -1 for cpu')
+     parser.add_argument('--n_cpu', type=int, default=4, help='Number of CPUs to use (for saving generated results in parallel)')
+     return parser.parse_known_args()
+
+
+ if __name__ == '__main__':
+     args, opt_args = parse()
+     print_log(f'Overwriting args: {opt_args}')
+     setup_seed(12)
+     main(args, opt_args)
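A hypothetical invocation of the script above: e.g. `python generate.py --config configs/pepbench/test_codesign.yaml --ckpt <experiment_dir> --gpu 0 --n_cpu 4`, where the config file ships with this commit and `<experiment_dir>` is a placeholder for a training run directory containing `checkpoint/topk_map.txt`.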