添加PepGLAD初始代码
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .github/workflows/stale.yml +27 -0
- .gitignore +35 -0
- .idea/.gitignore +3 -0
- .idea/PepGLAD.iml +8 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +4 -0
- .idea/modules.xml +8 -0
- .idea/vcs.xml +6 -0
- LICENSE +21 -0
- README.md +214 -3
- api/detect_pocket.py +72 -0
- api/run.py +274 -0
- assets/1ssc_A_pocket.json +1 -0
- cal_metrics.py +228 -0
- configs/pepbdb/autoencoder/train_codesign.yaml +66 -0
- configs/pepbdb/autoencoder/train_fixseq.yaml +63 -0
- configs/pepbdb/ldm/setup_latent_guidance.yaml +12 -0
- configs/pepbdb/ldm/train_codesign.yaml +61 -0
- configs/pepbdb/ldm/train_fixseq.yaml +63 -0
- configs/pepbdb/test_codesign.yaml +18 -0
- configs/pepbdb/test_fixseq.yaml +19 -0
- configs/pepbench/autoencoder/train_codesign.yaml +66 -0
- configs/pepbench/autoencoder/train_fixseq.yaml +62 -0
- configs/pepbench/ldm/setup_latent_guidance.yaml +12 -0
- configs/pepbench/ldm/train_codesign.yaml +60 -0
- configs/pepbench/ldm/train_fixseq.yaml +61 -0
- configs/pepbench/test_codesign.yaml +17 -0
- configs/pepbench/test_fixseq.yaml +18 -0
- data/__init__.py +53 -0
- data/codesign.py +208 -0
- data/converter/blocks_interface.py +89 -0
- data/converter/blocks_to_data.py +110 -0
- data/converter/list_blocks_to_pdb.py +61 -0
- data/converter/pdb_to_list_blocks.py +99 -0
- data/dataset_wrapper.py +115 -0
- data/format.py +220 -0
- data/mmap_dataset.py +112 -0
- data/resample.py +19 -0
- env.yaml +32 -0
- evaluation/__init__.py +3 -0
- evaluation/dG/RosettaFastRelaxUtil.xml +190 -0
- evaluation/dG/base.py +148 -0
- evaluation/dG/energy.py +236 -0
- evaluation/dG/openmm_relaxer.py +107 -0
- evaluation/dG/run.py +92 -0
- evaluation/diversity.py +68 -0
- evaluation/dockq.py +15 -0
- evaluation/rmsd.py +11 -0
- evaluation/seq_metric.py +71 -0
- generate.py +235 -0
.github/workflows/stale.yml
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time.
|
| 2 |
+
#
|
| 3 |
+
# You can adjust the behavior by modifying this file.
|
| 4 |
+
# For more information, see:
|
| 5 |
+
# https://github.com/actions/stale
|
| 6 |
+
name: Close inactive issues
|
| 7 |
+
on:
|
| 8 |
+
schedule:
|
| 9 |
+
- cron: "30 1 * * *"
|
| 10 |
+
|
| 11 |
+
jobs:
|
| 12 |
+
close-issues:
|
| 13 |
+
runs-on: ubuntu-latest
|
| 14 |
+
permissions:
|
| 15 |
+
issues: write
|
| 16 |
+
pull-requests: write
|
| 17 |
+
steps:
|
| 18 |
+
- uses: actions/stale@v9
|
| 19 |
+
with:
|
| 20 |
+
days-before-issue-stale: 30
|
| 21 |
+
days-before-issue-close: 14
|
| 22 |
+
stale-issue-label: "stale"
|
| 23 |
+
stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
|
| 24 |
+
close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
|
| 25 |
+
days-before-pr-stale: -1
|
| 26 |
+
days-before-pr-close: -1
|
| 27 |
+
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
.gitignore
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
| 2 |
+
|
| 3 |
+
__cache__
|
| 4 |
+
|
| 5 |
+
__tmcache__
|
| 6 |
+
|
| 7 |
+
ckpts
|
| 8 |
+
|
| 9 |
+
checkpoints
|
| 10 |
+
|
| 11 |
+
*_results*
|
| 12 |
+
|
| 13 |
+
datasets
|
| 14 |
+
|
| 15 |
+
exps
|
| 16 |
+
|
| 17 |
+
DockQ
|
| 18 |
+
|
| 19 |
+
TMscore
|
| 20 |
+
|
| 21 |
+
*.txt
|
| 22 |
+
|
| 23 |
+
*.pt
|
| 24 |
+
|
| 25 |
+
*.png
|
| 26 |
+
|
| 27 |
+
*.pkl
|
| 28 |
+
|
| 29 |
+
*.svg
|
| 30 |
+
|
| 31 |
+
*.log
|
| 32 |
+
|
| 33 |
+
*.pdb
|
| 34 |
+
|
| 35 |
+
*.jsonl
|
.idea/.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 默认忽略的文件
|
| 2 |
+
/shelf/
|
| 3 |
+
/workspace.xml
|
.idea/PepGLAD.iml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<module type="PYTHON_MODULE" version="4">
|
| 3 |
+
<component name="NewModuleRootManager">
|
| 4 |
+
<content url="file://$MODULE_DIR$" />
|
| 5 |
+
<orderEntry type="jdk" jdkName="D:\Miniconda3" jdkType="Python SDK" />
|
| 6 |
+
<orderEntry type="sourceFolder" forTests="false" />
|
| 7 |
+
</component>
|
| 8 |
+
</module>
|
.idea/inspectionProfiles/profiles_settings.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<component name="InspectionProjectProfileManager">
|
| 2 |
+
<settings>
|
| 3 |
+
<option name="USE_PROJECT_PROFILE" value="false" />
|
| 4 |
+
<version value="1.0" />
|
| 5 |
+
</settings>
|
| 6 |
+
</component>
|
.idea/misc.xml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="ProjectRootManager" version="2" project-jdk-name="D:\Miniconda3" project-jdk-type="Python SDK" />
|
| 4 |
+
</project>
|
.idea/modules.xml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="ProjectModuleManager">
|
| 4 |
+
<modules>
|
| 5 |
+
<module fileurl="file://$PROJECT_DIR$/.idea/PepGLAD.iml" filepath="$PROJECT_DIR$/.idea/PepGLAD.iml" />
|
| 6 |
+
</modules>
|
| 7 |
+
</component>
|
| 8 |
+
</project>
|
.idea/vcs.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="VcsDirectoryMappings">
|
| 4 |
+
<mapping directory="" vcs="Git" />
|
| 5 |
+
</component>
|
| 6 |
+
</project>
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 THUNLP
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,3 +1,214 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# PepGLAD: Full-Atom Peptide Design with Geometric Latent Diffusion
|
| 2 |
+
|
| 3 |
+

|
| 4 |
+
|
| 5 |
+
## Quick Links
|
| 6 |
+
|
| 7 |
+
- [Setup](#setup)
|
| 8 |
+
- [Environment](#environment)
|
| 9 |
+
- [Datasets](#optional-datasets)
|
| 10 |
+
- [Trained Weights](#trained-weights)
|
| 11 |
+
- [Usage](#usage)
|
| 12 |
+
- [Peptide Sequence-Structure Co-Design](#peptide-sequence-structure-co-design)
|
| 13 |
+
- [Peptide Binding Conformation Generation](#peptide-binding-conformation-generation)
|
| 14 |
+
- [Reproduction of Paper Experiments](#reproduction-of-paper-experiments)
|
| 15 |
+
- [Codesign](#codesign)
|
| 16 |
+
- [Binding Conformation Generation](#binding-conformation-generation)
|
| 17 |
+
- [Contact](#contact)
|
| 18 |
+
- [Reference](#reference)
|
| 19 |
+
|
| 20 |
+
## Updates
|
| 21 |
+
|
| 22 |
+
Changes for compatibilities and extended functionalities are saved in [beta](https://github.com/THUNLP-MT/PepGLAD/tree/beta) branch. Thank [@Barry0121](https://github.com/Barry0121) for the help.
|
| 23 |
+
|
| 24 |
+
- pyTorch 2.6.0 and openmm 8.2.0 are supported, with new environment configure at [2025_env.yaml](https://github.com/THUNLP-MT/PepGLAD/blob/beta/2025_env.yml).
|
| 25 |
+
- Support non-canonical amino acids in `detect_pocket.py`.
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
## Setup
|
| 29 |
+
|
| 30 |
+
### Environment
|
| 31 |
+
|
| 32 |
+
The conda environment can be constructed with the configuration `env.yaml`:
|
| 33 |
+
|
| 34 |
+
```bash
|
| 35 |
+
conda env create -f env.yaml
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
The codes are tested with cuda version `11.7` and pytorch version `1.13.1`.
|
| 39 |
+
|
| 40 |
+
Don't forget to activate the environment before running the codes:
|
| 41 |
+
|
| 42 |
+
```bash
|
| 43 |
+
conda activate PepGLAD
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
#### (Optional) pyRosetta
|
| 47 |
+
|
| 48 |
+
PyRosetta is used to calculate interface energy of generated peptides. If you are interested in it, please follow the instruction [here](https://www.pyrosetta.org/downloads) to install.
|
| 49 |
+
|
| 50 |
+
### (Optional) Datasets
|
| 51 |
+
|
| 52 |
+
These datasets are only used for benchmarking models. If you just want to use the trained weights for inferencing on your cases, there is no need to download these datasets.
|
| 53 |
+
|
| 54 |
+
#### PepBench
|
| 55 |
+
|
| 56 |
+
1. Download
|
| 57 |
+
|
| 58 |
+
The datasets, which are originally introduced in this paper, are uploaded to Zenodo at [this url](https://zenodo.org/records/13373108). You can download them as follows:
|
| 59 |
+
|
| 60 |
+
```bash
|
| 61 |
+
mkdir datasets # all datasets will be put into this directory
|
| 62 |
+
wget https://zenodo.org/records/13373108/files/train_valid.tar.gz?download=1 -O ./datasets/train_valid.tar.gz # training/validation
|
| 63 |
+
wget https://zenodo.org/records/13373108/files/LNR.tar.gz?download=1 -O ./datasets/LNR.tar.gz # test set
|
| 64 |
+
wget https://zenodo.org/records/13373108/files/ProtFrag.tar.gz?download=1 -O ./datasets/ProtFrag.tar.gz # augmentation dataset
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
2. Decompresss
|
| 68 |
+
|
| 69 |
+
```bash
|
| 70 |
+
tar zxvf ./datasets/train_valid.tar.gz -C ./datasets
|
| 71 |
+
tar zxvf ./datasets/LNR.tar.gz -C ./datasets
|
| 72 |
+
tar zxvf ./datasets/ProtFrag.tar.gz -C ./datasets
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
3. Process
|
| 76 |
+
|
| 77 |
+
```bash
|
| 78 |
+
python -m scripts.data_process.process --index ./datasets/train_valid/all.txt --out_dir ./datasets/train_valid/processed # train/validation set
|
| 79 |
+
python -m scripts.data_process.process --index ./datasets/LNR/test.txt --out_dir ./datasets/LNR/processed # test set
|
| 80 |
+
python -m scripts.data_process.process --index ./datasets/ProtFrag/all.txt --out_dir ./datasets/ProtFrag/processed # augmentation dataset
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
The index of processed data for train/validation splits need to be generated as follows, which will result in `datasets/train_valid/processed/train_index.txt` and `datasets/train_valid/processed/valid_index.txt`:
|
| 84 |
+
|
| 85 |
+
```bash
|
| 86 |
+
python -m scripts.data_process.split --train_index datasets/train_valid/train.txt --valid_index datasets/train_valid/valid.txt --processed_dir datasets/train_valid/processed/
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
#### PepBDB
|
| 90 |
+
|
| 91 |
+
1. Download
|
| 92 |
+
|
| 93 |
+
```bash
|
| 94 |
+
wget http://huanglab.phys.hust.edu.cn/pepbdb/db/download/pepbdb-20200318.tgz -O ./datasets/pepbdb.tgz
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
2. Decompress
|
| 98 |
+
|
| 99 |
+
```bash
|
| 100 |
+
tar zxvf ./datasets/pepbdb.tgz -C ./datasets/pepbdb
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
3. Process
|
| 105 |
+
|
| 106 |
+
```bash
|
| 107 |
+
python -m scripts.data_process.pepbdb --index ./datasets/pepbdb/peptidelist.txt --out_dir ./datasets/pepbdb/processed
|
| 108 |
+
python -m scripts.data_process.split --train_index ./datasets/pepbdb/train.txt --valid_index ./datasets/pepbdb/valid.txt --test_index ./datasets/pepbdb/test.txt --processed_dir datasets/pepbdb/processed/
|
| 109 |
+
mv ./datasets/pepbdb/processed/pdbs ./dataset/pepbdb # re-locate
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
### Trained Weights
|
| 114 |
+
|
| 115 |
+
- codesign: `./checkpoint/codesign.ckpt`
|
| 116 |
+
- conformation generation: `./checkpoints/fixseq.ckpt`
|
| 117 |
+
|
| 118 |
+
Both can be downloaded at the [release page](https://github.com/THUNLP-MT/PepGLAD/releases/tag/v1.0). These checkpoints were trained on PepBench.
|
| 119 |
+
|
| 120 |
+
## Usage
|
| 121 |
+
|
| 122 |
+
:warning: Before using the following codes, please first download the trained weights mentioned above.
|
| 123 |
+
|
| 124 |
+
### Peptide Sequence-Structure Co-Design
|
| 125 |
+
|
| 126 |
+
Take `./assets/1ssc_A_B.pdb` as an example, where chain A is the target protein:
|
| 127 |
+
|
| 128 |
+
```bash
|
| 129 |
+
# obtain the binding site, which might also be manually crafted or from other ligands (e.g. small molecule, antibodies)
|
| 130 |
+
python -m api.detect_pocket --pdb assets/1ssc_A_B.pdb --target_chains A --ligand_chains B --out assets/1ssc_A_pocket.json
|
| 131 |
+
# sequence-structure codesign with length in [8, 15)
|
| 132 |
+
CUDA_VISIBLE_DEVICES=0 python -m api.run \
|
| 133 |
+
--mode codesign \
|
| 134 |
+
--pdb assets/1ssc_A_B.pdb \
|
| 135 |
+
--pocket assets/1ssc_A_pocket.json \
|
| 136 |
+
--out_dir ./output/codesign \
|
| 137 |
+
--length_min 8 \
|
| 138 |
+
--length_max 15 \
|
| 139 |
+
--n_samples 10
|
| 140 |
+
```
|
| 141 |
+
Then 10 generations will be outputed under the folder `./output/codesign`.
|
| 142 |
+
|
| 143 |
+
### Peptide Binding Conformation Generation
|
| 144 |
+
|
| 145 |
+
Take `./assets/1ssc_A_B.pdb` as an example, where chain A is the target protein:
|
| 146 |
+
|
| 147 |
+
```bash
|
| 148 |
+
# obtain the binding site, which might also be manually crafted or from other ligands (e.g. small molecule, antibodies)
|
| 149 |
+
python -m api.detect_pocket --pdb assets/1ssc_A_B.pdb --target_chains A --ligand_chains B --out assets/1ssc_A_pocket.json
|
| 150 |
+
# generate binding conformation
|
| 151 |
+
CUDA_VISIBLE_DEVICES=0 python -m api.run \
|
| 152 |
+
--mode struct_pred \
|
| 153 |
+
--pdb assets/1ssc_A_B.pdb \
|
| 154 |
+
--pocket assets/1ssc_A_pocket.json \
|
| 155 |
+
--out_dir ./output/struct_pred \
|
| 156 |
+
--peptide_seq PYVPVHFDASV \
|
| 157 |
+
--n_samples 10
|
| 158 |
+
```
|
| 159 |
+
Then 10 conformations will be outputed under the folder `./output/struct_pred`.
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
## Reproduction of Paper Experiments
|
| 163 |
+
|
| 164 |
+
Each task requires the following steps, which we have integrated into the script `./scripts/run_exp_pipe.sh`:
|
| 165 |
+
|
| 166 |
+
1. Train autoencoder
|
| 167 |
+
2. Train latent diffusion model
|
| 168 |
+
3. Calculate distribution of latent distances between consecutive residues
|
| 169 |
+
4. Generation & Evaluation
|
| 170 |
+
|
| 171 |
+
On the other hand, if you want to evaluate existing checkpoints, please follow the instructions below (e.g. conformation generation):
|
| 172 |
+
|
| 173 |
+
```bash
|
| 174 |
+
# generate results on the test set and save to ./results/fixseq
|
| 175 |
+
python generate.py --config configs/pepbench/test_fixseq.yaml --ckpt checkpoints/fixseq.ckpt --gpu 0 --save_dir ./results/fixseq
|
| 176 |
+
# calculate metrics
|
| 177 |
+
python cal_metrics.py --results ./results/fixseq/results.jsonl
|
| 178 |
+
```
|
| 179 |
+
|
| 180 |
+
### Codesign
|
| 181 |
+
|
| 182 |
+
Codesign experiments on PepBench:
|
| 183 |
+
|
| 184 |
+
```bash
|
| 185 |
+
GPU=0 bash scripts/run_exp_pipe.sh pepbench_codesign configs/pepbench/autoencoder/train_codesign.yaml configs/pepbench/ldm/train_codesign.yaml configs/pepbench/ldm/setup_latent_guidance.yaml configs/pepbench/test_codesign.yaml
|
| 186 |
+
```
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
### Binding Conformation Generation
|
| 190 |
+
|
| 191 |
+
Conformation generation experiments on PepBench:
|
| 192 |
+
|
| 193 |
+
```bash
|
| 194 |
+
GPU=0 bash scripts/run_exp_pipe.sh pepbench_fixseq configs/pepbench/autoencoder/train_fixseq.yaml configs/pepbench/ldm/train_fixseq.yaml configs/pepbench/ldm/setup_latent_guidance.yaml configs/pepbench/test_fixseq.yaml
|
| 195 |
+
```
|
| 196 |
+
|
| 197 |
+
## Contact
|
| 198 |
+
|
| 199 |
+
Thank you for your interest in our work!
|
| 200 |
+
|
| 201 |
+
Please feel free to ask about any questions about the algorithms, codes, as well as problems encountered in running them so that we can make it clearer and better. You can either create an issue in the github repo or contact us at jackie_kxz@outlook.com.
|
| 202 |
+
|
| 203 |
+
## Reference
|
| 204 |
+
|
| 205 |
+
```bibtex
|
| 206 |
+
@article{kong2025full,
|
| 207 |
+
title={Full-atom peptide design with geometric latent diffusion},
|
| 208 |
+
author={Kong, Xiangzhe and Jia, Yinjun and Huang, Wenbing and Liu, Yang},
|
| 209 |
+
journal={Advances in Neural Information Processing Systems},
|
| 210 |
+
volume={37},
|
| 211 |
+
pages={74808--74839},
|
| 212 |
+
year={2025}
|
| 213 |
+
}
|
| 214 |
+
```
|
api/detect_pocket.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
# -*- coding:utf-8 -*-
|
| 3 |
+
import argparse
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from data.converter.pdb_to_list_blocks import pdb_to_list_blocks
|
| 7 |
+
from data.converter.blocks_interface import blocks_cb_interface, dist_matrix_from_blocks
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_interface(pdb, receptor_chains, ligand_chains, pocket_th=10.0): # CB distance
|
| 11 |
+
list_blocks, chain_ids = pdb_to_list_blocks(pdb, receptor_chains + ligand_chains, return_chain_ids=True)
|
| 12 |
+
chain2blocks = {chain: block for chain, block in zip(chain_ids, list_blocks)}
|
| 13 |
+
for c in receptor_chains:
|
| 14 |
+
assert c in chain2blocks, f'Chain {c} not found for receptor'
|
| 15 |
+
for c in ligand_chains:
|
| 16 |
+
assert c in chain2blocks, f'Chain {c} not found for ligand'
|
| 17 |
+
|
| 18 |
+
rec_blocks, rec_block_chains, lig_blocks, lig_block_chains = [], [], [], []
|
| 19 |
+
for c in receptor_chains:
|
| 20 |
+
for block in chain2blocks[c]:
|
| 21 |
+
rec_blocks.append(block)
|
| 22 |
+
rec_block_chains.append(c)
|
| 23 |
+
for c in ligand_chains:
|
| 24 |
+
for block in chain2blocks[c]:
|
| 25 |
+
lig_blocks.append(block)
|
| 26 |
+
lig_block_chains.append(c)
|
| 27 |
+
|
| 28 |
+
_, (pocket_idx, lig_if_idx) = blocks_cb_interface(rec_blocks, lig_blocks, pocket_th) # 10A for pocket size based on CB
|
| 29 |
+
epitope = []
|
| 30 |
+
for i in pocket_idx:
|
| 31 |
+
epitope.append((rec_blocks[i], rec_block_chains[i], i))
|
| 32 |
+
|
| 33 |
+
dist_mat = dist_matrix_from_blocks([rec_blocks[i] for i in pocket_idx], [lig_blocks[i] for i in lig_if_idx])
|
| 34 |
+
min_dists = np.min(dist_mat, axis=-1) # [Nrec]
|
| 35 |
+
lig_idxs = np.argmin(dist_mat, axis=-1) # [Nrec]
|
| 36 |
+
dists = []
|
| 37 |
+
for i, d in zip(lig_idxs, min_dists):
|
| 38 |
+
i = lig_if_idx[i]
|
| 39 |
+
dists.append((lig_blocks[i], lig_block_chains[i], i, d))
|
| 40 |
+
|
| 41 |
+
return epitope, dists
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
if __name__ == '__main__':
|
| 45 |
+
import json
|
| 46 |
+
parser = argparse.ArgumentParser(description='get interface')
|
| 47 |
+
parser.add_argument('--pdb', type=str, required=True, help='Path to the complex pdb')
|
| 48 |
+
parser.add_argument('--target_chains', type=str, nargs='+', required=True, help='Specify target chain ids')
|
| 49 |
+
parser.add_argument('--ligand_chains', type=str, nargs='+', required=True, help='Specify ligand chain ids')
|
| 50 |
+
parser.add_argument('--pocket_th', type=int, default=10.0, help='CB distance threshold for defining the binding site')
|
| 51 |
+
parser.add_argument('--out', type=str, default=None, help='Save epitope information to json file if specified')
|
| 52 |
+
args = parser.parse_args()
|
| 53 |
+
epitope, dists = get_interface(args.pdb, args.target_chains, args.ligand_chains, args.pocket_th)
|
| 54 |
+
para_res = {}
|
| 55 |
+
for _, chain_name, i, d in dists:
|
| 56 |
+
key = f'{chain_name}-{i}'
|
| 57 |
+
para_res[key] = 1
|
| 58 |
+
print(f'REMARK: {len(epitope)} residues in the binding site on the target protein, with {len(para_res)} residues in ligand:')
|
| 59 |
+
print(f' \tchain\tresidue id\ttype\tchain\tresidue id\ttype\tdistance')
|
| 60 |
+
for i, (e, p) in enumerate(zip(epitope, dists)):
|
| 61 |
+
e_res, e_chain_name, _ = e
|
| 62 |
+
p_res, p_chain_name, _, d = p
|
| 63 |
+
print(f'{i+1}\t{e_chain_name}\t{e_res.id}\t{e_res.abrv}\t' + \
|
| 64 |
+
f'{p_chain_name}\t{p_res.id}\t{p_res.abrv}\t{round(d, 3)}')
|
| 65 |
+
|
| 66 |
+
if args.out:
|
| 67 |
+
data = []
|
| 68 |
+
for e in epitope:
|
| 69 |
+
res, chain_name, _ = e
|
| 70 |
+
data.append((chain_name, res.id))
|
| 71 |
+
with open(args.out, 'w') as fout:
|
| 72 |
+
json.dump(data, fout)
|
api/run.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
# -*- coding:utf-8 -*-
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import json
|
| 6 |
+
import argparse
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from os.path import splitext, basename
|
| 9 |
+
|
| 10 |
+
import ray
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
from torch.utils.data import DataLoader
|
| 14 |
+
|
| 15 |
+
from data.format import Atom, Block, VOCAB
|
| 16 |
+
from data.converter.pdb_to_list_blocks import pdb_to_list_blocks
|
| 17 |
+
from data.converter.list_blocks_to_pdb import list_blocks_to_pdb
|
| 18 |
+
from data.codesign import calculate_covariance_matrix
|
| 19 |
+
from utils.const import sidechain_atoms
|
| 20 |
+
from utils.logger import print_log
|
| 21 |
+
from evaluation.dG.openmm_relaxer import ForceFieldMinimizer
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class DesignDataset(torch.utils.data.Dataset):
|
| 25 |
+
|
| 26 |
+
MAX_N_ATOM = 14
|
| 27 |
+
|
| 28 |
+
def __init__(self, pdbs, epitopes, lengths_range=None, seqs=None) -> None:
|
| 29 |
+
super().__init__()
|
| 30 |
+
self.pdbs = pdbs
|
| 31 |
+
self.epitopes = epitopes
|
| 32 |
+
self.lengths_range = lengths_range
|
| 33 |
+
self.seqs = seqs
|
| 34 |
+
# structure prediction or codesign
|
| 35 |
+
assert (self.seqs is not None and self.lengths_range is None) | \
|
| 36 |
+
(self.seqs is None and self.lengths_range is not None)
|
| 37 |
+
|
| 38 |
+
def get_epitope(self, idx):
|
| 39 |
+
pdb, epitope_def = self.pdbs[idx], self.epitopes[idx]
|
| 40 |
+
|
| 41 |
+
with open(epitope_def, 'r') as fin:
|
| 42 |
+
epitope = json.load(fin)
|
| 43 |
+
to_str = lambda pos: f'{pos[0]}-{pos[1]}'
|
| 44 |
+
epi_map = {}
|
| 45 |
+
for chain_name, pos in epitope:
|
| 46 |
+
if chain_name not in epi_map:
|
| 47 |
+
epi_map[chain_name] = {}
|
| 48 |
+
epi_map[chain_name][to_str(pos)] = True
|
| 49 |
+
residues, position_ids = [], []
|
| 50 |
+
chain2blocks = pdb_to_list_blocks(pdb, list(epi_map.keys()), dict_form=True)
|
| 51 |
+
if len(chain2blocks) != len(epi_map):
|
| 52 |
+
print_log(f'Some chains in the epitope are missing. Parsed {list(chain2blocks.keys())}, given {list(epi_map.keys())}.', level='WARN')
|
| 53 |
+
for chain_name in chain2blocks:
|
| 54 |
+
chain = chain2blocks[chain_name]
|
| 55 |
+
for i, block in enumerate(chain): # residue
|
| 56 |
+
if to_str(block.id) in epi_map[chain_name]:
|
| 57 |
+
residues.append(block)
|
| 58 |
+
position_ids.append(i + 1) # position ids start from 1
|
| 59 |
+
return residues, position_ids, chain2blocks
|
| 60 |
+
|
| 61 |
+
def generate_pep_chain(self, idx):
|
| 62 |
+
if self.lengths_range is not None: # codesign
|
| 63 |
+
lmin, lmax = self.lengths_range[idx]
|
| 64 |
+
length = np.random.randint(lmin, lmax)
|
| 65 |
+
unk_block = Block(VOCAB.symbol_to_abrv(VOCAB.UNK), [Atom('CA', [0, 0, 0], 'C')])
|
| 66 |
+
return [unk_block] * length
|
| 67 |
+
else:
|
| 68 |
+
seq = self.seqs[idx]
|
| 69 |
+
blocks = []
|
| 70 |
+
for s in seq:
|
| 71 |
+
atoms = []
|
| 72 |
+
for atom_name in VOCAB.backbone_atoms + sidechain_atoms.get(s, []):
|
| 73 |
+
atoms.append(Atom(atom_name, [0, 0, 0], atom_name[0]))
|
| 74 |
+
blocks.append(Block(VOCAB.symbol_to_abrv(s), atoms))
|
| 75 |
+
return blocks
|
| 76 |
+
|
| 77 |
+
def __len__(self):
|
| 78 |
+
return len(self.pdbs)
|
| 79 |
+
|
| 80 |
+
def __getitem__(self, idx: int):
|
| 81 |
+
rec_blocks, rec_position_ids, rec_chain2blocks = self.get_epitope(idx)
|
| 82 |
+
lig_blocks = self.generate_pep_chain(idx)
|
| 83 |
+
|
| 84 |
+
mask = [0 for _ in rec_blocks] + [1 for _ in lig_blocks]
|
| 85 |
+
position_ids = rec_position_ids + [i + 1 for i, _ in enumerate(lig_blocks)]
|
| 86 |
+
X, S, atom_mask = [], [], []
|
| 87 |
+
for block in rec_blocks + lig_blocks:
|
| 88 |
+
symbol = VOCAB.abrv_to_symbol(block.abrv)
|
| 89 |
+
atom2coord = { unit.name: unit.get_coord() for unit in block.units }
|
| 90 |
+
bb_pos = np.mean(list(atom2coord.values()), axis=0).tolist()
|
| 91 |
+
coords, coord_mask = [], []
|
| 92 |
+
for atom_name in VOCAB.backbone_atoms + sidechain_atoms.get(symbol, []):
|
| 93 |
+
if atom_name in atom2coord:
|
| 94 |
+
coords.append(atom2coord[atom_name])
|
| 95 |
+
coord_mask.append(1)
|
| 96 |
+
else:
|
| 97 |
+
coords.append(bb_pos)
|
| 98 |
+
coord_mask.append(0)
|
| 99 |
+
n_pad = self.MAX_N_ATOM - len(coords)
|
| 100 |
+
for _ in range(n_pad):
|
| 101 |
+
coords.append(bb_pos)
|
| 102 |
+
coord_mask.append(0)
|
| 103 |
+
|
| 104 |
+
X.append(coords)
|
| 105 |
+
S.append(VOCAB.symbol_to_idx(symbol))
|
| 106 |
+
atom_mask.append(coord_mask)
|
| 107 |
+
|
| 108 |
+
X, atom_mask = torch.tensor(X, dtype=torch.float), torch.tensor(atom_mask, dtype=torch.bool)
|
| 109 |
+
mask = torch.tensor(mask, dtype=torch.bool)
|
| 110 |
+
cov = calculate_covariance_matrix(X[~mask][:, 1][atom_mask[~mask][:, 1]].numpy()) # only use the receptor to derive the affine transformation
|
| 111 |
+
eps = 1e-4
|
| 112 |
+
cov = cov + eps * np.identity(cov.shape[0])
|
| 113 |
+
L = torch.from_numpy(np.linalg.cholesky(cov)).float().unsqueeze(0)
|
| 114 |
+
|
| 115 |
+
return {
|
| 116 |
+
'X': X, # [N, 14] or [N, 4] if backbone_only == True
|
| 117 |
+
'S': torch.tensor(S, dtype=torch.long), # [N]
|
| 118 |
+
'position_ids': torch.tensor(position_ids, dtype=torch.long), # [N]
|
| 119 |
+
'mask': mask, # [N], 1 for generation
|
| 120 |
+
'atom_mask': atom_mask, # [N, 14] or [N, 4], 1 for having records in the PDB
|
| 121 |
+
'lengths': len(S),
|
| 122 |
+
'rec_chain2blocks': rec_chain2blocks,
|
| 123 |
+
'L': L
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
def collate_fn(self, batch):
|
| 127 |
+
results = {}
|
| 128 |
+
for key in batch[0]:
|
| 129 |
+
values = [item[key] for item in batch]
|
| 130 |
+
if key == 'lengths':
|
| 131 |
+
results[key] = torch.tensor(values, dtype=torch.long)
|
| 132 |
+
elif key == 'rec_chain2blocks':
|
| 133 |
+
results[key] = values
|
| 134 |
+
else:
|
| 135 |
+
results[key] = torch.cat(values, dim=0)
|
| 136 |
+
return results
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
@ray.remote(num_cpus=1, num_gpus=1/16)
|
| 140 |
+
def openmm_relax(pdb_path):
|
| 141 |
+
force_field = ForceFieldMinimizer()
|
| 142 |
+
force_field(pdb_path, pdb_path)
|
| 143 |
+
return pdb_path
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def design(mode, ckpt, gpu, pdbs, epitope_defs, n_samples, out_dir,
|
| 147 |
+
lengths_range=None, seqs=None, identifiers=None, batch_size=8, num_workers=4):
|
| 148 |
+
|
| 149 |
+
# create out dir
|
| 150 |
+
if not os.path.exists(out_dir):
|
| 151 |
+
os.makedirs(out_dir)
|
| 152 |
+
result_summary = open(os.path.join(out_dir, 'summary.jsonl'), 'w')
|
| 153 |
+
if identifiers is None:
|
| 154 |
+
identifiers = [splitext(basename(pdb))[0] for pdb in pdbs]
|
| 155 |
+
# load model
|
| 156 |
+
device = torch.device('cpu' if gpu == -1 else f'cuda:{gpu}')
|
| 157 |
+
model = torch.load(ckpt, map_location='cpu')
|
| 158 |
+
model.to(device)
|
| 159 |
+
model.eval()
|
| 160 |
+
|
| 161 |
+
# generate dataset
|
| 162 |
+
# expand data
|
| 163 |
+
if lengths_range is None: lengths_range = [None for _ in pdbs]
|
| 164 |
+
if seqs is None: seqs = [None for _ in pdbs]
|
| 165 |
+
expand_pdbs, expand_epitopes, expand_lens, expand_ids, expand_seqs = [], [], [], [], []
|
| 166 |
+
for _id, pdb, epitope, l, s, n in zip(identifiers, pdbs, epitope_defs, lengths_range, seqs, n_samples):
|
| 167 |
+
expand_ids.extend([f'{_id}_{i}' for i in range(n)])
|
| 168 |
+
expand_pdbs.extend([pdb for _ in range(n)])
|
| 169 |
+
expand_epitopes.extend([epitope for _ in range(n)])
|
| 170 |
+
expand_lens.extend([l for _ in range(n)])
|
| 171 |
+
expand_seqs.extend([s for _ in range(n)])
|
| 172 |
+
# create dataset
|
| 173 |
+
if expand_lens[0] is None: expand_lens = None
|
| 174 |
+
if expand_seqs[0] is None: expand_seqs = None
|
| 175 |
+
dataset = DesignDataset(expand_pdbs, expand_epitopes, expand_lens, expand_seqs)
|
| 176 |
+
dataloader = DataLoader(dataset, batch_size=batch_size,
|
| 177 |
+
num_workers=num_workers,
|
| 178 |
+
collate_fn=dataset.collate_fn,
|
| 179 |
+
shuffle=False
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
# generate peptides
|
| 183 |
+
cnt = 0
|
| 184 |
+
all_pdbs = []
|
| 185 |
+
for batch in tqdm(dataloader):
|
| 186 |
+
with torch.no_grad():
|
| 187 |
+
# move data
|
| 188 |
+
for k in batch:
|
| 189 |
+
if hasattr(batch[k], 'to'):
|
| 190 |
+
batch[k] = batch[k].to(device)
|
| 191 |
+
# generate
|
| 192 |
+
batch_X, batch_S, batch_pmetric = model.sample(
|
| 193 |
+
batch['X'], batch['S'],
|
| 194 |
+
batch['mask'], batch['position_ids'],
|
| 195 |
+
batch['lengths'], batch['atom_mask'],
|
| 196 |
+
L=batch['L'], sample_opt={
|
| 197 |
+
'energy_func': 'default',
|
| 198 |
+
'energy_lambda': 0.5 if mode == 'struct_pred' else 0.8
|
| 199 |
+
}
|
| 200 |
+
)
|
| 201 |
+
# save data
|
| 202 |
+
for X, S, pmetric, rec_chain2blocks in zip(batch_X, batch_S, batch_pmetric, batch['rec_chain2blocks']):
|
| 203 |
+
if S is None: S = expand_seqs[cnt] # structure prediction
|
| 204 |
+
lig_blocks = []
|
| 205 |
+
for x, s in zip(X, S):
|
| 206 |
+
abrv = VOCAB.symbol_to_abrv(s)
|
| 207 |
+
atoms = VOCAB.backbone_atoms + sidechain_atoms[VOCAB.abrv_to_symbol(abrv)]
|
| 208 |
+
units = [
|
| 209 |
+
Atom(atom_name, coord, atom_name[0]) for atom_name, coord in zip(atoms, x)
|
| 210 |
+
]
|
| 211 |
+
lig_blocks.append(Block(abrv, units))
|
| 212 |
+
list_blocks, chain_names = [], []
|
| 213 |
+
for chain in rec_chain2blocks:
|
| 214 |
+
list_blocks.append(rec_chain2blocks[chain])
|
| 215 |
+
chain_names.append(chain)
|
| 216 |
+
pep_chain_id = chr(max([ord(c) for c in chain_names]) + 1)
|
| 217 |
+
list_blocks.append(lig_blocks)
|
| 218 |
+
chain_names.append(pep_chain_id)
|
| 219 |
+
out_pdb = os.path.join(out_dir, expand_ids[cnt] + '.pdb')
|
| 220 |
+
list_blocks_to_pdb(list_blocks, chain_names, out_pdb)
|
| 221 |
+
all_pdbs.append(out_pdb)
|
| 222 |
+
result_summary.write(json.dumps({
|
| 223 |
+
'id': expand_ids[cnt],
|
| 224 |
+
'rec_chains': list(rec_chain2blocks.keys()),
|
| 225 |
+
'pep_chain': pep_chain_id,
|
| 226 |
+
'pep_seq': ''.join([VOCAB.abrv_to_symbol(block.abrv) for block in lig_blocks])
|
| 227 |
+
}) + '\n')
|
| 228 |
+
result_summary.flush()
|
| 229 |
+
cnt += 1
|
| 230 |
+
result_summary.close()
|
| 231 |
+
|
| 232 |
+
print_log(f'Running openmm relaxation...')
|
| 233 |
+
ray.init(num_cpus=8)
|
| 234 |
+
futures = [openmm_relax.remote(path) for path in all_pdbs]
|
| 235 |
+
pbar = tqdm(total=len(futures))
|
| 236 |
+
while len(futures) > 0:
|
| 237 |
+
done_ids, futures = ray.wait(futures, num_returns=1)
|
| 238 |
+
for done_id in done_ids:
|
| 239 |
+
done_path = ray.get(done_id)
|
| 240 |
+
pbar.update(1)
|
| 241 |
+
print_log(f'Done')
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def parse():
|
| 245 |
+
parser = argparse.ArgumentParser(description='run pepglad for codesign or structure prediction')
|
| 246 |
+
parser.add_argument('--mode', type=str, required=True, choices=['codesign', 'struct_pred'], help='Running mode')
|
| 247 |
+
parser.add_argument('--pdb', type=str, required=True, help='Path to the PDB file of the target protein')
|
| 248 |
+
parser.add_argument('--pocket', type=str, required=True, help='Path to the pocket definition (*.json generated by detect_pocket)')
|
| 249 |
+
parser.add_argument('--n_samples', type=int, default=10, help='Number of samples')
|
| 250 |
+
parser.add_argument('--out_dir', type=str, required=True, help='Output directory')
|
| 251 |
+
parser.add_argument('--peptide_seq', type=str, required='struct_pred' in sys.argv, help='Peptide sequence for structure prediction')
|
| 252 |
+
parser.add_argument('--length_min', type=int, required='codesign' in sys.argv, help='Minimum peptide length for codesign (inclusive)')
|
| 253 |
+
parser.add_argument('--length_max', type=int, required='codesign' in sys.argv, help='Maximum peptide length for codesign (exclusive)')
|
| 254 |
+
parser.add_argument('--gpu', type=int, default=0, help='GPU to use')
|
| 255 |
+
return parser.parse_args()
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
if __name__ == '__main__':
|
| 259 |
+
args = parse()
|
| 260 |
+
proj_dir = os.path.join(os.path.dirname(__file__), '..')
|
| 261 |
+
ckpt = os.path.join(proj_dir, 'checkpoints', 'fixseq.ckpt' if args.mode == 'struct_pred' else 'codesign.ckpt')
|
| 262 |
+
print_log(f'Loading checkpoint: {ckpt}')
|
| 263 |
+
design(
|
| 264 |
+
mode=args.mode,
|
| 265 |
+
ckpt=ckpt, # path to the checkpoint of the trained model
|
| 266 |
+
gpu=args.gpu, # the ID of the GPU to use
|
| 267 |
+
pdbs=[args.pdb], # paths to the PDB file of each antigen
|
| 268 |
+
epitope_defs=[args.pocket], # paths to the epitope (pocket) definitions
|
| 269 |
+
n_samples=[args.n_samples], # number of samples for each epitope
|
| 270 |
+
out_dir=args.out_dir, # output directory
|
| 271 |
+
identifiers=[os.path.basename(os.path.splitext(args.pdb)[0])], # file name (name of each output candidate)
|
| 272 |
+
lengths_range=[(args.length_min, args.length_max)] if args.mode == 'codesign' else None, # range of acceptable peptide lengths, left inclusive, right exclusive
|
| 273 |
+
seqs=[args.peptide_seq] if args.mode == 'struct_pred' else None # peptide sequences for structure prediction
|
| 274 |
+
)
|
assets/1ssc_A_pocket.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
[["A", [3, " "]], ["A", [4, " "]], ["A", [5, " "]], ["A", [6, " "]], ["A", [7, " "]], ["A", [8, " "]], ["A", [9, " "]], ["A", [11, " "]], ["A", [12, " "]], ["A", [13, " "]], ["A", [43, " "]], ["A", [44, " "]], ["A", [45, " "]], ["A", [46, " "]], ["A", [47, " "]], ["A", [51, " "]], ["A", [54, " "]], ["A", [55, " "]], ["A", [56, " "]], ["A", [57, " "]], ["A", [58, " "]], ["A", [59, " "]], ["A", [63, " "]], ["A", [64, " "]], ["A", [65, " "]], ["A", [66, " "]], ["A", [67, " "]], ["A", [69, " "]], ["A", [71, " "]], ["A", [72, " "]], ["A", [73, " "]], ["A", [74, " "]], ["A", [75, " "]], ["A", [78, " "]], ["A", [79, " "]], ["A", [81, " "]], ["A", [83, " "]], ["A", [102, " "]], ["A", [103, " "]], ["A", [104, " "]], ["A", [105, " "]], ["A", [106, " "]], ["A", [107, " "]], ["A", [108, " "]], ["A", [109, " "]], ["A", [110, " "]], ["A", [111, " "]], ["A", [112, " "]]]
|
cal_metrics.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
# -*- coding:utf-8 -*-
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
import random
|
| 7 |
+
from copy import deepcopy
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from tqdm.contrib.concurrent import process_map
|
| 11 |
+
import statistics
|
| 12 |
+
import warnings
|
| 13 |
+
warnings.filterwarnings("ignore")
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
from scipy.stats import spearmanr
|
| 17 |
+
|
| 18 |
+
from data.converter.pdb_to_list_blocks import pdb_to_list_blocks
|
| 19 |
+
from evaluation import diversity
|
| 20 |
+
from evaluation.dockq import dockq
|
| 21 |
+
from evaluation.rmsd import compute_rmsd
|
| 22 |
+
from utils.random_seed import setup_seed
|
| 23 |
+
from evaluation.seq_metric import aar, slide_aar
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _get_ref_pdb(_id, root_dir):
|
| 27 |
+
return os.path.join(root_dir, 'references', f'{_id}_ref.pdb')
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _get_gen_pdb(_id, number, root_dir, use_rosetta):
|
| 31 |
+
suffix = '_rosetta' if use_rosetta else ''
|
| 32 |
+
return os.path.join(root_dir, 'candidates', _id, f'{_id}_gen_{number}{suffix}.pdb')
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def cal_metrics(items):
|
| 36 |
+
# all of the items are conditioned on the same binding pocket
|
| 37 |
+
root_dir = items[0]['root_dir']
|
| 38 |
+
ref_pdb, rec_chain, lig_chain = items[0]['ref_pdb'], items[0]['rec_chain'], items[0]['lig_chain']
|
| 39 |
+
ref_pdb = _get_ref_pdb(items[0]['id'], root_dir)
|
| 40 |
+
seq_only, struct_only, backbone_only = items[0]['seq_only'], items[0]['struct_only'], items[0]['backbone_only']
|
| 41 |
+
|
| 42 |
+
# prepare
|
| 43 |
+
results = defaultdict(list)
|
| 44 |
+
cand_seqs, cand_ca_xs = [], []
|
| 45 |
+
rec_blocks, ref_pep_blocks = pdb_to_list_blocks(ref_pdb, [rec_chain, lig_chain])
|
| 46 |
+
ref_ca_x, ca_mask = [], []
|
| 47 |
+
for ref_block in ref_pep_blocks:
|
| 48 |
+
if ref_block.has_unit('CA'):
|
| 49 |
+
ca_mask.append(1)
|
| 50 |
+
ref_ca_x.append(ref_block.get_unit_by_name('CA').get_coord())
|
| 51 |
+
else:
|
| 52 |
+
ca_mask.append(0)
|
| 53 |
+
ref_ca_x.append([0, 0, 0])
|
| 54 |
+
ref_ca_x, ca_mask = np.array(ref_ca_x), np.array(ca_mask).astype(bool)
|
| 55 |
+
|
| 56 |
+
for item in items:
|
| 57 |
+
if not struct_only:
|
| 58 |
+
cand_seqs.append(item['gen_seq'])
|
| 59 |
+
results['Slide AAR'].append(slide_aar(item['gen_seq'], item['ref_seq'], aar))
|
| 60 |
+
|
| 61 |
+
# structure metrics
|
| 62 |
+
gen_pdb = _get_gen_pdb(item['id'], item['number'], root_dir, item['rosetta'])
|
| 63 |
+
_, gen_pep_blocks = pdb_to_list_blocks(gen_pdb, [rec_chain, lig_chain])
|
| 64 |
+
assert len(gen_pep_blocks) == len(ref_pep_blocks), f'{item}\t{len(ref_pep_blocks)}\t{len(gen_pep_blocks)}'
|
| 65 |
+
|
| 66 |
+
# CA RMSD
|
| 67 |
+
gen_ca_x = np.array([block.get_unit_by_name('CA').get_coord() for block in gen_pep_blocks])
|
| 68 |
+
cand_ca_xs.append(gen_ca_x)
|
| 69 |
+
rmsd = compute_rmsd(ref_ca_x[ca_mask], gen_ca_x[ca_mask], aligned=True)
|
| 70 |
+
results['RMSD(CA)'].append(rmsd)
|
| 71 |
+
if struct_only:
|
| 72 |
+
results['RMSD<=2.0'].append(1 if rmsd <= 2.0 else 0)
|
| 73 |
+
results['RMSD<=5.0'].append(1 if rmsd <= 5.0 else 0)
|
| 74 |
+
results['RMSD<=10.0'].append(1 if rmsd <= 10.0 else 0)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
if backbone_only:
|
| 78 |
+
continue
|
| 79 |
+
|
| 80 |
+
# 5. DockQ
|
| 81 |
+
dockq_score = dockq(gen_pdb, ref_pdb, lig_chain)
|
| 82 |
+
results['DockQ'].append(dockq_score)
|
| 83 |
+
if struct_only:
|
| 84 |
+
results['DockQ>=0.23'].append(1 if dockq_score >= 0.23 else 0)
|
| 85 |
+
results['DockQ>=0.49'].append(1 if dockq_score >= 0.49 else 0)
|
| 86 |
+
results['DockQ>=0.80'].append(1 if dockq_score >= 0.80 else 0)
|
| 87 |
+
|
| 88 |
+
# Full atom RMSD
|
| 89 |
+
if struct_only:
|
| 90 |
+
gen_all_x, ref_all_x = [], []
|
| 91 |
+
for gen_block, ref_block in zip(gen_pep_blocks, ref_pep_blocks):
|
| 92 |
+
for ref_atom in ref_block:
|
| 93 |
+
if gen_block.has_unit(ref_atom.name):
|
| 94 |
+
ref_all_x.append(ref_atom.get_coord())
|
| 95 |
+
gen_all_x.append(gen_block.get_unit_by_name(ref_atom.name).get_coord())
|
| 96 |
+
results['RMSD(full-atom)'].append(compute_rmsd(
|
| 97 |
+
np.array(gen_all_x), np.array(ref_all_x), aligned=True
|
| 98 |
+
))
|
| 99 |
+
|
| 100 |
+
pmets = [item['pmetric'] for item in items]
|
| 101 |
+
indexes = list(range(len(items)))
|
| 102 |
+
# aggregation
|
| 103 |
+
for name in results:
|
| 104 |
+
vals = results[name]
|
| 105 |
+
corr = spearmanr(vals, pmets, nan_policy='omit').statistic
|
| 106 |
+
if np.isnan(corr):
|
| 107 |
+
corr = 0
|
| 108 |
+
aggr_res = {
|
| 109 |
+
'max': max(vals),
|
| 110 |
+
'min': min(vals),
|
| 111 |
+
'mean': sum(vals) / len(vals),
|
| 112 |
+
'random': vals[0],
|
| 113 |
+
'max*': vals[(max if corr > 0 else min)(indexes, key=lambda i: pmets[i])],
|
| 114 |
+
'min*': vals[(min if corr > 0 else max)(indexes, key=lambda i: pmets[i])],
|
| 115 |
+
'pmet_corr': corr,
|
| 116 |
+
'individual': vals,
|
| 117 |
+
'individual_pmet': pmets
|
| 118 |
+
}
|
| 119 |
+
results[name] = aggr_res
|
| 120 |
+
|
| 121 |
+
if len(cand_seqs) > 1 and not seq_only:
|
| 122 |
+
seq_div, struct_div, co_div, consistency = diversity.diversity(cand_seqs, np.array(cand_ca_xs))
|
| 123 |
+
results['Sequence Diversity'] = seq_div
|
| 124 |
+
results['Struct Diversity'] = struct_div
|
| 125 |
+
results['Codesign Diversity'] = co_div
|
| 126 |
+
results['Consistency'] = consistency
|
| 127 |
+
|
| 128 |
+
return results
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def cnt_aa_dist(seqs):
|
| 132 |
+
cnts = {}
|
| 133 |
+
for seq in seqs:
|
| 134 |
+
for aa in seq:
|
| 135 |
+
if aa not in cnts:
|
| 136 |
+
cnts[aa] = 0
|
| 137 |
+
cnts[aa] += 1
|
| 138 |
+
aas = sorted(list(cnts.keys()), key=lambda aa: cnts[aa])
|
| 139 |
+
total = sum(cnts.values())
|
| 140 |
+
for aa in aas:
|
| 141 |
+
print(f'\t{aa}: {cnts[aa] / total}')
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def main(args):
|
| 145 |
+
root_dir = os.path.dirname(args.results)
|
| 146 |
+
# load dG filter
|
| 147 |
+
if args.filter_dG is None:
|
| 148 |
+
filter_func = lambda _id, n: True
|
| 149 |
+
else:
|
| 150 |
+
dG_results = json.load(open(args.filter_dG, 'r'))
|
| 151 |
+
filter_func = lambda _id, n: dG_results[_id]['all'][str(n)] < 0
|
| 152 |
+
# load results
|
| 153 |
+
with open(args.results, 'r') as fin:
|
| 154 |
+
lines = fin.read().strip().split('\n')
|
| 155 |
+
id2items = {}
|
| 156 |
+
for line in lines:
|
| 157 |
+
item = json.loads(line)
|
| 158 |
+
_id = item['id']
|
| 159 |
+
if not filter_func(_id, item['number']):
|
| 160 |
+
continue
|
| 161 |
+
if _id not in id2items:
|
| 162 |
+
id2items[_id] = []
|
| 163 |
+
item['root_dir'] = root_dir
|
| 164 |
+
item['rosetta'] = args.rosetta
|
| 165 |
+
id2items[_id].append(item)
|
| 166 |
+
ids = list(id2items.keys())
|
| 167 |
+
|
| 168 |
+
if args.filter_dG is not None:
|
| 169 |
+
# delete results with only one sample since it cannot calculate diversity
|
| 170 |
+
del_ids = [_id for _id in ids if len(id2items[_id]) < 2]
|
| 171 |
+
for _id in del_ids:
|
| 172 |
+
print(f'Deleting {_id} since it only has one sample passed the filter')
|
| 173 |
+
del id2items[_id]
|
| 174 |
+
|
| 175 |
+
if args.num_workers > 1:
|
| 176 |
+
metrics = process_map(cal_metrics, id2items.values(), max_workers=args.num_workers, chunksize=1)
|
| 177 |
+
else:
|
| 178 |
+
metrics = [cal_metrics(inputs) for inputs in tqdm(id2items.values())]
|
| 179 |
+
|
| 180 |
+
eval_results_path = os.path.join(os.path.dirname(args.results), 'eval_report.json')
|
| 181 |
+
with open(eval_results_path, 'w') as fout:
|
| 182 |
+
for i, _id in enumerate(id2items):
|
| 183 |
+
metric = deepcopy(metrics[i])
|
| 184 |
+
metric['id'] = _id
|
| 185 |
+
fout.write(json.dumps(metric) + '\n')
|
| 186 |
+
|
| 187 |
+
# individual level results
|
| 188 |
+
print('Point-wise evaluation results:')
|
| 189 |
+
for name in metrics[0]:
|
| 190 |
+
vals = [item[name] for item in metrics]
|
| 191 |
+
if isinstance(vals[0], dict):
|
| 192 |
+
if 'RMSD' in name and '<=' not in name:
|
| 193 |
+
aggr = 'min'
|
| 194 |
+
else:
|
| 195 |
+
aggr = 'max'
|
| 196 |
+
aggr_vals = [val[aggr] for val in vals]
|
| 197 |
+
if '>=' in name or '<=' in name: # percentage
|
| 198 |
+
print(f'{name}: {sum(aggr_vals) / len(aggr_vals)}')
|
| 199 |
+
else:
|
| 200 |
+
if 'RMSD' in name:
|
| 201 |
+
print(f'{name}(median): {statistics.median(aggr_vals)}') # unbounded, some extreme values will affect the mean but not the median
|
| 202 |
+
else:
|
| 203 |
+
print(f'{name}(mean): {sum(aggr_vals) / len(aggr_vals)}')
|
| 204 |
+
lowest_i = min([i for i in range(len(aggr_vals))], key=lambda i: aggr_vals[i])
|
| 205 |
+
highest_i = max([i for i in range(len(aggr_vals))], key=lambda i: aggr_vals[i])
|
| 206 |
+
print(f'\tlowest: {aggr_vals[lowest_i]}, id: {ids[lowest_i]}', end='')
|
| 207 |
+
print(f'\thighest: {aggr_vals[highest_i]}, id: {ids[highest_i]}')
|
| 208 |
+
else:
|
| 209 |
+
print(f'{name} (mean): {sum(vals) / len(vals)}')
|
| 210 |
+
lowest_i = min([i for i in range(len(vals))], key=lambda i: vals[i])
|
| 211 |
+
highest_i = max([i for i in range(len(vals))], key=lambda i: vals[i])
|
| 212 |
+
print(f'\tlowest: {vals[lowest_i]}, id: {ids[lowest_i]}')
|
| 213 |
+
print(f'\thighest: {vals[highest_i]}, id: {ids[highest_i]}')
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def parse():
|
| 217 |
+
parser = argparse.ArgumentParser(description='calculate metrics')
|
| 218 |
+
parser.add_argument('--results', type=str, required=True, help='Path to test set')
|
| 219 |
+
parser.add_argument('--num_workers', type=int, default=8, help='Number of workers to use')
|
| 220 |
+
parser.add_argument('--rosetta', action='store_true', help='Use the rosetta-refined structure')
|
| 221 |
+
parser.add_argument('--filter_dG', type=str, default=None, help='Only calculate results on samples with dG<0')
|
| 222 |
+
|
| 223 |
+
return parser.parse_args()
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
if __name__ == '__main__':
|
| 227 |
+
setup_seed(0)
|
| 228 |
+
main(parse())
|
configs/pepbdb/autoencoder/train_codesign.yaml
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset:
|
| 2 |
+
train:
|
| 3 |
+
- class: CoDesignDataset
|
| 4 |
+
mmap_dir: ./datasets/pepbdb/processed
|
| 5 |
+
specify_index: ./datasets/pepbdb/processed/train_index.txt
|
| 6 |
+
backbone_only: false
|
| 7 |
+
cluster: ./datasets/pepbdb/train.cluster
|
| 8 |
+
- class: CoDesignDataset
|
| 9 |
+
mmap_dir: ./datasets/ProtFrag/processed
|
| 10 |
+
backbone_only: false
|
| 11 |
+
valid:
|
| 12 |
+
class: CoDesignDataset
|
| 13 |
+
mmap_dir: ./datasets/pepbdb/processed
|
| 14 |
+
specify_index: ./datasets/pepbdb/processed/valid_index.txt
|
| 15 |
+
backbone_only: false
|
| 16 |
+
|
| 17 |
+
dataloader:
|
| 18 |
+
shuffle: true
|
| 19 |
+
num_workers: 4
|
| 20 |
+
wrapper:
|
| 21 |
+
class: DynamicBatchWrapper
|
| 22 |
+
complexity: n**2
|
| 23 |
+
ubound_per_batch: 60000 # batch size ~24
|
| 24 |
+
|
| 25 |
+
trainer:
|
| 26 |
+
class: AutoEncoderTrainer
|
| 27 |
+
config:
|
| 28 |
+
max_epoch: 100
|
| 29 |
+
save_topk: 10
|
| 30 |
+
save_dir: ./ckpts/autoencoder_codesign_pepbdb
|
| 31 |
+
patience: 10
|
| 32 |
+
metric_min_better: true
|
| 33 |
+
|
| 34 |
+
optimizer:
|
| 35 |
+
class: AdamW
|
| 36 |
+
lr: 1.0e-4
|
| 37 |
+
|
| 38 |
+
scheduler:
|
| 39 |
+
class: ReduceLROnPlateau
|
| 40 |
+
factor: 0.8
|
| 41 |
+
patience: 5
|
| 42 |
+
mode: min
|
| 43 |
+
frequency: val_epoch
|
| 44 |
+
min_lr: 5.0e-6
|
| 45 |
+
|
| 46 |
+
model:
|
| 47 |
+
class: AutoEncoder
|
| 48 |
+
embed_size: 128
|
| 49 |
+
hidden_size: 128
|
| 50 |
+
latent_size: 8
|
| 51 |
+
latent_n_channel: 1
|
| 52 |
+
n_layers: 3
|
| 53 |
+
n_channel: 14 # all atom
|
| 54 |
+
h_kl_weight: 0.3
|
| 55 |
+
z_kl_weight: 0.5
|
| 56 |
+
coord_loss_ratio: 0.5
|
| 57 |
+
coord_loss_weights:
|
| 58 |
+
Xloss: 1.0
|
| 59 |
+
ca_Xloss: 1.0
|
| 60 |
+
bb_bond_lengths_loss: 1.0
|
| 61 |
+
sc_bond_lengths_loss: 1.0
|
| 62 |
+
bb_dihedral_angles_loss: 0.0
|
| 63 |
+
sc_chi_angles_loss: 0.5
|
| 64 |
+
relative_position: false
|
| 65 |
+
anchor_at_ca: true
|
| 66 |
+
mask_ratio: 0.25
|
configs/pepbdb/autoencoder/train_fixseq.yaml
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset:
|
| 2 |
+
train:
|
| 3 |
+
class: CoDesignDataset
|
| 4 |
+
mmap_dir: ./datasets/pepbdb/processed
|
| 5 |
+
specify_index: ./datasets/pepbdb/processed/train_index.txt
|
| 6 |
+
backbone_only: false
|
| 7 |
+
cluster: ./datasets/pepbdb/train.cluster
|
| 8 |
+
valid:
|
| 9 |
+
class: CoDesignDataset
|
| 10 |
+
mmap_dir: ./datasets/pepbdb/processed
|
| 11 |
+
specify_index: ./datasets/pepbdb/processed/valid_index.txt
|
| 12 |
+
backbone_only: false
|
| 13 |
+
|
| 14 |
+
dataloader:
|
| 15 |
+
shuffle: true
|
| 16 |
+
num_workers: 4
|
| 17 |
+
wrapper:
|
| 18 |
+
class: DynamicBatchWrapper
|
| 19 |
+
complexity: n**2
|
| 20 |
+
ubound_per_batch: 60000 # batch size ~24
|
| 21 |
+
|
| 22 |
+
trainer:
|
| 23 |
+
class: AutoEncoderTrainer
|
| 24 |
+
config:
|
| 25 |
+
max_epoch: 150 # the best checkpoint should be obatained at about epoch 457
|
| 26 |
+
save_topk: 10
|
| 27 |
+
save_dir: ./ckpts/autoencoder_fixseq
|
| 28 |
+
patience: 10
|
| 29 |
+
metric_min_better: true
|
| 30 |
+
|
| 31 |
+
optimizer:
|
| 32 |
+
class: AdamW
|
| 33 |
+
lr: 1.0e-4
|
| 34 |
+
|
| 35 |
+
scheduler:
|
| 36 |
+
class: ReduceLROnPlateau
|
| 37 |
+
factor: 0.8
|
| 38 |
+
patience: 15
|
| 39 |
+
mode: min
|
| 40 |
+
frequency: val_epoch
|
| 41 |
+
min_lr: 5.0e-6
|
| 42 |
+
|
| 43 |
+
model:
|
| 44 |
+
class: AutoEncoder
|
| 45 |
+
embed_size: 128
|
| 46 |
+
hidden_size: 128
|
| 47 |
+
latent_size: 0
|
| 48 |
+
latent_n_channel: 1
|
| 49 |
+
n_layers: 3
|
| 50 |
+
n_channel: 14 # all atom
|
| 51 |
+
h_kl_weight: 0.0
|
| 52 |
+
z_kl_weight: 0.6
|
| 53 |
+
coord_loss_ratio: 1.0
|
| 54 |
+
coord_loss_weights:
|
| 55 |
+
Xloss: 1.0
|
| 56 |
+
ca_Xloss: 1.0
|
| 57 |
+
bb_bond_lengths_loss: 1.0
|
| 58 |
+
sc_bond_lengths_loss: 1.0
|
| 59 |
+
bb_dihedral_angles_loss: 0.0
|
| 60 |
+
sc_chi_angles_loss: 0.5
|
| 61 |
+
anchor_at_ca: true
|
| 62 |
+
mode: fixseq
|
| 63 |
+
additional_noise_scale: 1.0
|
configs/pepbdb/ldm/setup_latent_guidance.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset:
|
| 2 |
+
test:
|
| 3 |
+
class: CoDesignDataset
|
| 4 |
+
mmap_dir: ./datasets/pepbdb/processed
|
| 5 |
+
specify_index: ./datasets/pepbdb/processed/train_index.txt
|
| 6 |
+
backbone_only: false
|
| 7 |
+
|
| 8 |
+
dataloader:
|
| 9 |
+
num_workers: 2
|
| 10 |
+
batch_size: 32
|
| 11 |
+
|
| 12 |
+
backbone_only: false
|
configs/pepbdb/ldm/train_codesign.yaml
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset:
|
| 2 |
+
train:
|
| 3 |
+
- class: CoDesignDataset
|
| 4 |
+
mmap_dir: ./datasets/pepbdb/processed
|
| 5 |
+
specify_index: ./datasets/pepbdb/processed/train_index.txt
|
| 6 |
+
backbone_only: false
|
| 7 |
+
cluster: ./datasets/pepbdb/train.cluster
|
| 8 |
+
use_covariance_matrix: true
|
| 9 |
+
valid:
|
| 10 |
+
class: CoDesignDataset
|
| 11 |
+
mmap_dir: ./datasets/pepbdb/processed
|
| 12 |
+
specify_index: ./datasets/pepbdb/processed/valid_index.txt
|
| 13 |
+
backbone_only: false
|
| 14 |
+
use_covariance_matrix: true
|
| 15 |
+
|
| 16 |
+
dataloader:
|
| 17 |
+
shuffle: true
|
| 18 |
+
num_workers: 4
|
| 19 |
+
wrapper:
|
| 20 |
+
class: DynamicBatchWrapper
|
| 21 |
+
complexity: n**2
|
| 22 |
+
ubound_per_batch: 60000 # batch size ~32
|
| 23 |
+
|
| 24 |
+
trainer:
|
| 25 |
+
class: LDMTrainer
|
| 26 |
+
criterion: Loss
|
| 27 |
+
config:
|
| 28 |
+
max_epoch: 500 # the best checkpoint should be obtained at around epoch 380
|
| 29 |
+
save_topk: 10
|
| 30 |
+
val_freq: 10
|
| 31 |
+
save_dir: ./ckpts/LDM_codesign
|
| 32 |
+
patience: 10
|
| 33 |
+
metric_min_better: true
|
| 34 |
+
|
| 35 |
+
optimizer:
|
| 36 |
+
class: AdamW
|
| 37 |
+
lr: 1.0e-4
|
| 38 |
+
|
| 39 |
+
scheduler:
|
| 40 |
+
class: ReduceLROnPlateau
|
| 41 |
+
factor: 0.6
|
| 42 |
+
patience: 3
|
| 43 |
+
mode: min
|
| 44 |
+
frequency: val_epoch
|
| 45 |
+
min_lr: 5.0e-6
|
| 46 |
+
|
| 47 |
+
model:
|
| 48 |
+
class: LDMPepDesign
|
| 49 |
+
autoencoder_ckpt: ""
|
| 50 |
+
autoencoder_no_randomness: true
|
| 51 |
+
hidden_size: 128
|
| 52 |
+
num_steps: 100
|
| 53 |
+
n_layers: 3
|
| 54 |
+
n_rbf: 32
|
| 55 |
+
cutoff: 3.0 # the coordinates are in standard space
|
| 56 |
+
dist_rbf: 32
|
| 57 |
+
dist_rbf_cutoff: 7.0
|
| 58 |
+
diffusion_opt:
|
| 59 |
+
trans_seq_type: Diffusion
|
| 60 |
+
trans_pos_type: Diffusion
|
| 61 |
+
max_gen_position: 60
|
configs/pepbdb/ldm/train_fixseq.yaml
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset:
|
| 2 |
+
train:
|
| 3 |
+
- class: CoDesignDataset
|
| 4 |
+
mmap_dir: ./datasets/pepbdb/processed
|
| 5 |
+
specify_index: ./datasets/pepbdb/processed/train_index.txt
|
| 6 |
+
backbone_only: false
|
| 7 |
+
cluster: ./datasets/pepbdb/train.cluster
|
| 8 |
+
use_covariance_matrix: true
|
| 9 |
+
valid:
|
| 10 |
+
class: CoDesignDataset
|
| 11 |
+
mmap_dir: ./datasets/pepbdb/processed
|
| 12 |
+
specify_index: ./datasets/pepbdb/processed/valid_index.txt
|
| 13 |
+
backbone_only: false
|
| 14 |
+
use_covariance_matrix: true
|
| 15 |
+
|
| 16 |
+
dataloader:
|
| 17 |
+
shuffle: true
|
| 18 |
+
num_workers: 4
|
| 19 |
+
wrapper:
|
| 20 |
+
class: DynamicBatchWrapper
|
| 21 |
+
complexity: n**2
|
| 22 |
+
ubound_per_batch: 60000 # batch size ~32
|
| 23 |
+
|
| 24 |
+
trainer:
|
| 25 |
+
class: LDMTrainer
|
| 26 |
+
criterion: RMSD
|
| 27 |
+
config:
|
| 28 |
+
max_epoch: 1000 # the best checkpoint will be obtained at about 900 epoch
|
| 29 |
+
save_topk: 10
|
| 30 |
+
val_freq: 10
|
| 31 |
+
save_dir: ./ckpts/LDM_fixseq
|
| 32 |
+
patience: 10
|
| 33 |
+
metric_min_better: true
|
| 34 |
+
|
| 35 |
+
optimizer:
|
| 36 |
+
class: AdamW
|
| 37 |
+
lr: 1.0e-4
|
| 38 |
+
|
| 39 |
+
scheduler:
|
| 40 |
+
class: ReduceLROnPlateau
|
| 41 |
+
factor: 0.6
|
| 42 |
+
patience: 3
|
| 43 |
+
mode: min
|
| 44 |
+
frequency: val_epoch
|
| 45 |
+
min_lr: 5.0e-6
|
| 46 |
+
|
| 47 |
+
model:
|
| 48 |
+
class: LDMPepDesign
|
| 49 |
+
autoencoder_ckpt: ""
|
| 50 |
+
autoencoder_no_randomness: true
|
| 51 |
+
hidden_size: 128
|
| 52 |
+
num_steps: 100
|
| 53 |
+
n_layers: 6
|
| 54 |
+
n_rbf: 32
|
| 55 |
+
cutoff: 3.0 # the coordinates are in standard space
|
| 56 |
+
dist_rbf: 0
|
| 57 |
+
dist_rbf_cutoff: 0.0
|
| 58 |
+
diffusion_opt:
|
| 59 |
+
trans_seq_type: Diffusion
|
| 60 |
+
trans_pos_type: Diffusion
|
| 61 |
+
std: 20.0
|
| 62 |
+
mode: fixseq
|
| 63 |
+
max_gen_position: 60
|
configs/pepbdb/test_codesign.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset:
|
| 2 |
+
test:
|
| 3 |
+
class: CoDesignDataset
|
| 4 |
+
mmap_dir: ./datasets/pepbdb/processed
|
| 5 |
+
specify_index: ./datasets/pepbdb/processed/test_index.txt
|
| 6 |
+
backbone_only: false
|
| 7 |
+
use_covariance_matrix: true
|
| 8 |
+
|
| 9 |
+
dataloader:
|
| 10 |
+
num_workers: 4
|
| 11 |
+
batch_size: 64
|
| 12 |
+
|
| 13 |
+
backbone_only: false
|
| 14 |
+
n_samples: 40
|
| 15 |
+
|
| 16 |
+
sample_opt:
|
| 17 |
+
energy_func: default
|
| 18 |
+
energy_lambda: 0.8
|
configs/pepbdb/test_fixseq.yaml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset:
|
| 2 |
+
test:
|
| 3 |
+
class: CoDesignDataset
|
| 4 |
+
mmap_dir: ./datasets/pepbdb/processed
|
| 5 |
+
specify_index: ./datasets/pepbdb/processed/test_index.txt
|
| 6 |
+
backbone_only: false
|
| 7 |
+
use_covariance_matrix: true
|
| 8 |
+
|
| 9 |
+
dataloader:
|
| 10 |
+
num_workers: 4
|
| 11 |
+
batch_size: 64
|
| 12 |
+
|
| 13 |
+
backbone_only: false
|
| 14 |
+
struct_only: true
|
| 15 |
+
n_samples: 10
|
| 16 |
+
|
| 17 |
+
sample_opt:
|
| 18 |
+
energy_func: default
|
| 19 |
+
energy_lambda: 0.8
|
configs/pepbench/autoencoder/train_codesign.yaml
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset:
|
| 2 |
+
train:
|
| 3 |
+
- class: CoDesignDataset
|
| 4 |
+
mmap_dir: ./datasets/train_valid/processed
|
| 5 |
+
specify_index: ./datasets/train_valid/processed/train_index.txt
|
| 6 |
+
backbone_only: false
|
| 7 |
+
cluster: ./datasets/train_valid/train.cluster
|
| 8 |
+
- class: CoDesignDataset
|
| 9 |
+
mmap_dir: ./datasets/ProtFrag/processed
|
| 10 |
+
backbone_only: false
|
| 11 |
+
valid:
|
| 12 |
+
class: CoDesignDataset
|
| 13 |
+
mmap_dir: ./datasets/train_valid/processed
|
| 14 |
+
specify_index: ./datasets/train_valid/processed/valid_index.txt
|
| 15 |
+
backbone_only: false
|
| 16 |
+
|
| 17 |
+
dataloader:
|
| 18 |
+
shuffle: true
|
| 19 |
+
num_workers: 4
|
| 20 |
+
wrapper:
|
| 21 |
+
class: DynamicBatchWrapper
|
| 22 |
+
complexity: n**2
|
| 23 |
+
ubound_per_batch: 60000 # batch size ~24
|
| 24 |
+
|
| 25 |
+
trainer:
|
| 26 |
+
class: AutoEncoderTrainer
|
| 27 |
+
config:
|
| 28 |
+
max_epoch: 100
|
| 29 |
+
save_topk: 10
|
| 30 |
+
save_dir: ./ckpts/autoencoder_codesign
|
| 31 |
+
patience: 10
|
| 32 |
+
metric_min_better: true
|
| 33 |
+
|
| 34 |
+
optimizer:
|
| 35 |
+
class: AdamW
|
| 36 |
+
lr: 1.0e-4
|
| 37 |
+
|
| 38 |
+
scheduler:
|
| 39 |
+
class: ReduceLROnPlateau
|
| 40 |
+
factor: 0.8
|
| 41 |
+
patience: 5
|
| 42 |
+
mode: min
|
| 43 |
+
frequency: val_epoch
|
| 44 |
+
min_lr: 5.0e-6
|
| 45 |
+
|
| 46 |
+
model:
|
| 47 |
+
class: AutoEncoder
|
| 48 |
+
embed_size: 128
|
| 49 |
+
hidden_size: 128
|
| 50 |
+
latent_size: 8
|
| 51 |
+
latent_n_channel: 1
|
| 52 |
+
n_layers: 3
|
| 53 |
+
n_channel: 14 # all atom
|
| 54 |
+
h_kl_weight: 0.3
|
| 55 |
+
z_kl_weight: 0.5
|
| 56 |
+
coord_loss_ratio: 0.5
|
| 57 |
+
coord_loss_weights:
|
| 58 |
+
Xloss: 1.0
|
| 59 |
+
ca_Xloss: 1.0
|
| 60 |
+
bb_bond_lengths_loss: 1.0
|
| 61 |
+
sc_bond_lengths_loss: 1.0
|
| 62 |
+
bb_dihedral_angles_loss: 0.0
|
| 63 |
+
sc_chi_angles_loss: 0.5
|
| 64 |
+
relative_position: false
|
| 65 |
+
anchor_at_ca: true
|
| 66 |
+
mask_ratio: 0.25
|
configs/pepbench/autoencoder/train_fixseq.yaml
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset:
|
| 2 |
+
train:
|
| 3 |
+
class: CoDesignDataset
|
| 4 |
+
mmap_dir: ./datasets/train_valid/processed
|
| 5 |
+
specify_index: ./datasets/train_valid/processed/train_index.txt
|
| 6 |
+
backbone_only: false
|
| 7 |
+
cluster: ./datasets/train_valid/train.cluster
|
| 8 |
+
valid:
|
| 9 |
+
class: CoDesignDataset
|
| 10 |
+
mmap_dir: ./datasets/train_valid/processed
|
| 11 |
+
specify_index: ./datasets/train_valid/processed/valid_index.txt
|
| 12 |
+
backbone_only: false
|
| 13 |
+
|
| 14 |
+
dataloader:
|
| 15 |
+
shuffle: true
|
| 16 |
+
num_workers: 4
|
| 17 |
+
wrapper:
|
| 18 |
+
class: DynamicBatchWrapper
|
| 19 |
+
complexity: n**2
|
| 20 |
+
ubound_per_batch: 60000 # batch size ~24
|
| 21 |
+
|
| 22 |
+
trainer:
|
| 23 |
+
class: AutoEncoderTrainer
|
| 24 |
+
config:
|
| 25 |
+
max_epoch: 500 # the best checkpoint should be obatained at about epoch 457
|
| 26 |
+
save_topk: 10
|
| 27 |
+
save_dir: ./ckpts/autoencoder_fixseq
|
| 28 |
+
patience: 10
|
| 29 |
+
metric_min_better: true
|
| 30 |
+
|
| 31 |
+
optimizer:
|
| 32 |
+
class: AdamW
|
| 33 |
+
lr: 1.0e-4
|
| 34 |
+
|
| 35 |
+
scheduler:
|
| 36 |
+
class: ReduceLROnPlateau
|
| 37 |
+
factor: 0.8
|
| 38 |
+
patience: 15
|
| 39 |
+
mode: min
|
| 40 |
+
frequency: val_epoch
|
| 41 |
+
min_lr: 5.0e-6
|
| 42 |
+
|
| 43 |
+
model:
|
| 44 |
+
class: AutoEncoder
|
| 45 |
+
embed_size: 128
|
| 46 |
+
hidden_size: 128
|
| 47 |
+
latent_size: 0
|
| 48 |
+
latent_n_channel: 1
|
| 49 |
+
n_layers: 3
|
| 50 |
+
n_channel: 14 # all atom
|
| 51 |
+
h_kl_weight: 0.0
|
| 52 |
+
z_kl_weight: 1.0
|
| 53 |
+
coord_loss_ratio: 1.0
|
| 54 |
+
coord_loss_weights:
|
| 55 |
+
Xloss: 1.0
|
| 56 |
+
ca_Xloss: 1.0
|
| 57 |
+
bb_bond_lengths_loss: 1.0
|
| 58 |
+
sc_bond_lengths_loss: 1.0
|
| 59 |
+
bb_dihedral_angles_loss: 0.0
|
| 60 |
+
sc_chi_angles_loss: 0.5
|
| 61 |
+
anchor_at_ca: true
|
| 62 |
+
mode: fixseq
|
configs/pepbench/ldm/setup_latent_guidance.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset:
|
| 2 |
+
test:
|
| 3 |
+
class: CoDesignDataset
|
| 4 |
+
mmap_dir: ./datasets/train_valid/processed
|
| 5 |
+
specify_index: ./datasets/train_valid/processed/train_index.txt
|
| 6 |
+
backbone_only: false
|
| 7 |
+
|
| 8 |
+
dataloader:
|
| 9 |
+
num_workers: 2
|
| 10 |
+
batch_size: 32
|
| 11 |
+
|
| 12 |
+
backbone_only: false
|
configs/pepbench/ldm/train_codesign.yaml
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset:
|
| 2 |
+
train:
|
| 3 |
+
class: CoDesignDataset
|
| 4 |
+
mmap_dir: ./datasets/train_valid/processed
|
| 5 |
+
specify_index: ./datasets/train_valid/processed/train_index.txt
|
| 6 |
+
backbone_only: false
|
| 7 |
+
cluster: ./datasets/train_valid/train.cluster
|
| 8 |
+
use_covariance_matrix: true
|
| 9 |
+
valid:
|
| 10 |
+
class: CoDesignDataset
|
| 11 |
+
mmap_dir: ./datasets/train_valid/processed
|
| 12 |
+
specify_index: ./datasets/train_valid/processed/valid_index.txt
|
| 13 |
+
backbone_only: false
|
| 14 |
+
use_covariance_matrix: true
|
| 15 |
+
|
| 16 |
+
dataloader:
|
| 17 |
+
shuffle: true
|
| 18 |
+
num_workers: 4
|
| 19 |
+
wrapper:
|
| 20 |
+
class: DynamicBatchWrapper
|
| 21 |
+
complexity: n**2
|
| 22 |
+
ubound_per_batch: 60000 # batch size ~32
|
| 23 |
+
|
| 24 |
+
trainer:
|
| 25 |
+
class: LDMTrainer
|
| 26 |
+
criterion: Loss
|
| 27 |
+
config:
|
| 28 |
+
max_epoch: 500 # the best checkpoint should be obtained at around epoch 380
|
| 29 |
+
save_topk: 10
|
| 30 |
+
val_freq: 10
|
| 31 |
+
save_dir: ./ckpts/LDM_codesign
|
| 32 |
+
patience: 10
|
| 33 |
+
metric_min_better: true
|
| 34 |
+
|
| 35 |
+
optimizer:
|
| 36 |
+
class: AdamW
|
| 37 |
+
lr: 1.0e-4
|
| 38 |
+
|
| 39 |
+
scheduler:
|
| 40 |
+
class: ReduceLROnPlateau
|
| 41 |
+
factor: 0.6
|
| 42 |
+
patience: 3
|
| 43 |
+
mode: min
|
| 44 |
+
frequency: val_epoch
|
| 45 |
+
min_lr: 5.0e-6
|
| 46 |
+
|
| 47 |
+
model:
|
| 48 |
+
class: LDMPepDesign
|
| 49 |
+
autoencoder_ckpt: ""
|
| 50 |
+
autoencoder_no_randomness: true
|
| 51 |
+
hidden_size: 128
|
| 52 |
+
num_steps: 100
|
| 53 |
+
n_layers: 3
|
| 54 |
+
n_rbf: 32
|
| 55 |
+
cutoff: 3.0 # the coordinates are in standard space
|
| 56 |
+
dist_rbf: 32
|
| 57 |
+
dist_rbf_cutoff: 7.0
|
| 58 |
+
diffusion_opt:
|
| 59 |
+
trans_seq_type: Diffusion
|
| 60 |
+
trans_pos_type: Diffusion
|
configs/pepbench/ldm/train_fixseq.yaml
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset:
|
| 2 |
+
train:
|
| 3 |
+
class: CoDesignDataset
|
| 4 |
+
mmap_dir: ./datasets/train_valid/processed
|
| 5 |
+
specify_index: ./datasets/train_valid/processed/train_index.txt
|
| 6 |
+
backbone_only: false
|
| 7 |
+
cluster: ./datasets/train_valid/train.cluster
|
| 8 |
+
use_covariance_matrix: true
|
| 9 |
+
valid:
|
| 10 |
+
class: CoDesignDataset
|
| 11 |
+
mmap_dir: ./datasets/train_valid/processed
|
| 12 |
+
specify_index: ./datasets/train_valid/processed/valid_index.txt
|
| 13 |
+
backbone_only: false
|
| 14 |
+
use_covariance_matrix: true
|
| 15 |
+
|
| 16 |
+
dataloader:
|
| 17 |
+
shuffle: true
|
| 18 |
+
num_workers: 4
|
| 19 |
+
wrapper:
|
| 20 |
+
class: DynamicBatchWrapper
|
| 21 |
+
complexity: n**2
|
| 22 |
+
ubound_per_batch: 60000 # batch size ~32
|
| 23 |
+
|
| 24 |
+
trainer:
|
| 25 |
+
class: LDMTrainer
|
| 26 |
+
criterion: RMSD
|
| 27 |
+
config:
|
| 28 |
+
max_epoch: 1000 # the best checkpoint will be obtained at about 720 epoch
|
| 29 |
+
save_topk: 10
|
| 30 |
+
val_freq: 10
|
| 31 |
+
save_dir: ./ckpts/LDM_fixseq
|
| 32 |
+
patience: 10
|
| 33 |
+
metric_min_better: true
|
| 34 |
+
|
| 35 |
+
optimizer:
|
| 36 |
+
class: AdamW
|
| 37 |
+
lr: 1.0e-4
|
| 38 |
+
|
| 39 |
+
scheduler:
|
| 40 |
+
class: ReduceLROnPlateau
|
| 41 |
+
factor: 0.6
|
| 42 |
+
patience: 3
|
| 43 |
+
mode: min
|
| 44 |
+
frequency: val_epoch
|
| 45 |
+
min_lr: 5.0e-6
|
| 46 |
+
|
| 47 |
+
model:
|
| 48 |
+
class: LDMPepDesign
|
| 49 |
+
autoencoder_ckpt: ""
|
| 50 |
+
autoencoder_no_randomness: true
|
| 51 |
+
hidden_size: 128
|
| 52 |
+
num_steps: 100
|
| 53 |
+
n_layers: 3
|
| 54 |
+
n_rbf: 32
|
| 55 |
+
cutoff: 3.0 # the coordinates are in standard space
|
| 56 |
+
dist_rbf: 0
|
| 57 |
+
dist_rbf_cutoff: 0.0
|
| 58 |
+
diffusion_opt:
|
| 59 |
+
trans_seq_type: Diffusion
|
| 60 |
+
trans_pos_type: Diffusion
|
| 61 |
+
mode: fixseq
|
configs/pepbench/test_codesign.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset:
|
| 2 |
+
test:
|
| 3 |
+
class: CoDesignDataset
|
| 4 |
+
mmap_dir: ./datasets/LNR/processed
|
| 5 |
+
backbone_only: false
|
| 6 |
+
use_covariance_matrix: true
|
| 7 |
+
|
| 8 |
+
dataloader:
|
| 9 |
+
num_workers: 4
|
| 10 |
+
batch_size: 64
|
| 11 |
+
|
| 12 |
+
backbone_only: false
|
| 13 |
+
n_samples: 40
|
| 14 |
+
|
| 15 |
+
sample_opt:
|
| 16 |
+
energy_func: default
|
| 17 |
+
energy_lambda: 0.8
|
configs/pepbench/test_fixseq.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset:
|
| 2 |
+
test:
|
| 3 |
+
class: CoDesignDataset
|
| 4 |
+
mmap_dir: ./datasets/LNR/processed
|
| 5 |
+
backbone_only: false
|
| 6 |
+
use_covariance_matrix: true
|
| 7 |
+
|
| 8 |
+
dataloader:
|
| 9 |
+
num_workers: 4
|
| 10 |
+
batch_size: 64
|
| 11 |
+
|
| 12 |
+
backbone_only: false
|
| 13 |
+
struct_only: true
|
| 14 |
+
n_samples: 10
|
| 15 |
+
|
| 16 |
+
sample_opt:
|
| 17 |
+
energy_func: default
|
| 18 |
+
energy_lambda: 0.5
|
data/__init__.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
# -*- coding:utf-8 -*-
|
| 3 |
+
from .dataset_wrapper import MixDatasetWrapper
|
| 4 |
+
from .codesign import CoDesignDataset
|
| 5 |
+
from .resample import ClusterResampler
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch.utils.data import DataLoader
|
| 10 |
+
|
| 11 |
+
import utils.register as R
|
| 12 |
+
from utils.logger import print_log
|
| 13 |
+
|
| 14 |
+
def create_dataset(config: dict):
|
| 15 |
+
splits = []
|
| 16 |
+
for split_name in ['train', 'valid', 'test']:
|
| 17 |
+
split_config = config.get(split_name, None)
|
| 18 |
+
if split_config is None:
|
| 19 |
+
splits.append(None)
|
| 20 |
+
continue
|
| 21 |
+
if isinstance(split_config, list):
|
| 22 |
+
dataset = MixDatasetWrapper(
|
| 23 |
+
*[R.construct(cfg) for cfg in split_config]
|
| 24 |
+
)
|
| 25 |
+
else:
|
| 26 |
+
dataset = R.construct(split_config)
|
| 27 |
+
splits.append(dataset)
|
| 28 |
+
return splits # train/valid/test
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def create_dataloader(dataset, config: dict, n_gpu: int=1, validation: bool=False):
|
| 32 |
+
if 'wrapper' in config:
|
| 33 |
+
dataset = R.construct(config['wrapper'], dataset=dataset)
|
| 34 |
+
batch_size = config.get('batch_size', n_gpu) # default 1 on each gpu
|
| 35 |
+
if validation:
|
| 36 |
+
batch_size = config.get('val_batch_size', batch_size)
|
| 37 |
+
shuffle = config.get('shuffle', False)
|
| 38 |
+
num_workers = config.get('num_workers', 4)
|
| 39 |
+
collate_fn = dataset.collate_fn if hasattr(dataset, 'collate_fn') else None
|
| 40 |
+
if n_gpu > 1:
|
| 41 |
+
sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
|
| 42 |
+
batch_size = int(batch_size / n_gpu)
|
| 43 |
+
print_log(f'Batch size on a single GPU: {batch_size}')
|
| 44 |
+
else:
|
| 45 |
+
sampler = None
|
| 46 |
+
return DataLoader(
|
| 47 |
+
dataset=dataset,
|
| 48 |
+
batch_size=batch_size,
|
| 49 |
+
num_workers=num_workers,
|
| 50 |
+
shuffle=(shuffle and sampler is None),
|
| 51 |
+
collate_fn=collate_fn,
|
| 52 |
+
sampler=sampler
|
| 53 |
+
)
|
data/codesign.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
from typing import Optional, Any
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 8 |
+
|
| 9 |
+
from utils import register as R
|
| 10 |
+
from utils.const import sidechain_atoms
|
| 11 |
+
|
| 12 |
+
from data.converter.list_blocks_to_pdb import list_blocks_to_pdb
|
| 13 |
+
|
| 14 |
+
from .format import VOCAB, Block, Atom
|
| 15 |
+
from .mmap_dataset import MMAPDataset
|
| 16 |
+
from .resample import ClusterResampler
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def calculate_covariance_matrix(point_cloud):
|
| 21 |
+
# Calculate the covariance matrix of the point cloud
|
| 22 |
+
covariance_matrix = np.cov(point_cloud, rowvar=False)
|
| 23 |
+
return covariance_matrix
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@R.register('CoDesignDataset')
|
| 27 |
+
class CoDesignDataset(MMAPDataset):
|
| 28 |
+
|
| 29 |
+
MAX_N_ATOM = 14
|
| 30 |
+
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
mmap_dir: str,
|
| 34 |
+
backbone_only: bool, # only backbone (N, CA, C, O) or full-atom
|
| 35 |
+
specify_data: Optional[str] = None,
|
| 36 |
+
specify_index: Optional[str] = None,
|
| 37 |
+
padding_collate: bool = False,
|
| 38 |
+
cluster: Optional[str] = None,
|
| 39 |
+
use_covariance_matrix: bool = False
|
| 40 |
+
) -> None:
|
| 41 |
+
super().__init__(mmap_dir, specify_data, specify_index)
|
| 42 |
+
self.mmap_dir = mmap_dir
|
| 43 |
+
self.backbone_only = backbone_only
|
| 44 |
+
self._lengths = [len(prop[-1].split(',')) + int(prop[1]) for prop in self._properties]
|
| 45 |
+
self.padding_collate = padding_collate
|
| 46 |
+
self.resampler = ClusterResampler(cluster) if cluster else None # should only be used in training!
|
| 47 |
+
self.use_covariance_matrix = use_covariance_matrix
|
| 48 |
+
|
| 49 |
+
self.dynamic_idxs = [i for i in range(len(self))]
|
| 50 |
+
self.update_epoch() # should be called every epoch
|
| 51 |
+
|
| 52 |
+
def update_epoch(self):
|
| 53 |
+
if self.resampler is not None:
|
| 54 |
+
self.dynamic_idxs = self.resampler(len(self))
|
| 55 |
+
|
| 56 |
+
def get_len(self, idx):
|
| 57 |
+
return self._lengths[self.dynamic_idxs[idx]]
|
| 58 |
+
|
| 59 |
+
def get_summary(self, idx: int):
|
| 60 |
+
props = self._properties[idx]
|
| 61 |
+
_id = self._indexes[idx][0].split('.')[0]
|
| 62 |
+
ref_pdb = os.path.join(self.mmap_dir, '..', 'pdbs', _id + '.pdb')
|
| 63 |
+
rec_chain, lig_chain = props[4], props[5]
|
| 64 |
+
return _id, ref_pdb, rec_chain, lig_chain
|
| 65 |
+
|
| 66 |
+
def __getitem__(self, idx: int):
|
| 67 |
+
idx = self.dynamic_idxs[idx]
|
| 68 |
+
rec_blocks, lig_blocks = super().__getitem__(idx)
|
| 69 |
+
# receptor, (lig_chain_id, lig_blocks) = super().__getitem__(idx)
|
| 70 |
+
# pocket = {}
|
| 71 |
+
# for i in self._properties[idx][-1].split(','):
|
| 72 |
+
# chain, i = i.split(':')
|
| 73 |
+
# if chain not in pocket:
|
| 74 |
+
# pocket[chain] = []
|
| 75 |
+
# pocket[chain].append(int(i))
|
| 76 |
+
# rec_blocks = []
|
| 77 |
+
# for chain_id, blocks in receptor:
|
| 78 |
+
# for i in pocket[chain_id]:
|
| 79 |
+
# rec_blocks.append(blocks[i])
|
| 80 |
+
pocket_idx = [int(i) for i in self._properties[idx][-1].split(',')]
|
| 81 |
+
rec_position_ids = [i + 1 for i, _ in enumerate(rec_blocks)]
|
| 82 |
+
rec_blocks = [rec_blocks[i] for i in pocket_idx]
|
| 83 |
+
rec_position_ids = [rec_position_ids[i] for i in pocket_idx]
|
| 84 |
+
rec_blocks = [Block.from_tuple(tup) for tup in rec_blocks]
|
| 85 |
+
lig_blocks = [Block.from_tuple(tup) for tup in lig_blocks]
|
| 86 |
+
|
| 87 |
+
# for block in lig_blocks:
|
| 88 |
+
# block.units = [Atom('CA', [0, 0, 0], 'C')]
|
| 89 |
+
# if idx == 0:
|
| 90 |
+
# print(self._properties[idx])
|
| 91 |
+
# print(''.join(VOCAB.abrv_to_symbol(block.abrv) for block in lig_blocks))
|
| 92 |
+
# list_blocks_to_pdb([
|
| 93 |
+
# rec_blocks, lig_blocks
|
| 94 |
+
# ], ['B', 'A'], 'pocket.pdb')
|
| 95 |
+
|
| 96 |
+
mask = [0 for _ in rec_blocks] + [1 for _ in lig_blocks]
|
| 97 |
+
position_ids = rec_position_ids + [i + 1 for i, _ in enumerate(lig_blocks)]
|
| 98 |
+
X, S, atom_mask = [], [], []
|
| 99 |
+
for block in rec_blocks + lig_blocks:
|
| 100 |
+
symbol = VOCAB.abrv_to_symbol(block.abrv)
|
| 101 |
+
atom2coord = { unit.name: unit.get_coord() for unit in block.units }
|
| 102 |
+
bb_pos = np.mean(list(atom2coord.values()), axis=0).tolist()
|
| 103 |
+
coords, coord_mask = [], []
|
| 104 |
+
for atom_name in VOCAB.backbone_atoms + sidechain_atoms.get(symbol, []):
|
| 105 |
+
if atom_name in atom2coord:
|
| 106 |
+
coords.append(atom2coord[atom_name])
|
| 107 |
+
coord_mask.append(1)
|
| 108 |
+
else:
|
| 109 |
+
coords.append(bb_pos)
|
| 110 |
+
coord_mask.append(0)
|
| 111 |
+
n_pad = self.MAX_N_ATOM - len(coords)
|
| 112 |
+
for _ in range(n_pad):
|
| 113 |
+
coords.append(bb_pos)
|
| 114 |
+
coord_mask.append(0)
|
| 115 |
+
|
| 116 |
+
X.append(coords)
|
| 117 |
+
S.append(VOCAB.symbol_to_idx(symbol))
|
| 118 |
+
atom_mask.append(coord_mask)
|
| 119 |
+
|
| 120 |
+
X, atom_mask = torch.tensor(X, dtype=torch.float), torch.tensor(atom_mask, dtype=torch.bool)
|
| 121 |
+
mask = torch.tensor(mask, dtype=torch.bool)
|
| 122 |
+
if self.backbone_only:
|
| 123 |
+
X, atom_mask = X[:, :4], atom_mask[:, :4]
|
| 124 |
+
|
| 125 |
+
if self.use_covariance_matrix:
|
| 126 |
+
cov = calculate_covariance_matrix(X[~mask][:, 1][atom_mask[~mask][:, 1]].numpy()) # only use the receptor to derive the affine transformation
|
| 127 |
+
eps = 1e-4
|
| 128 |
+
cov = cov + eps * np.identity(cov.shape[0])
|
| 129 |
+
L = torch.from_numpy(np.linalg.cholesky(cov)).float().unsqueeze(0)
|
| 130 |
+
else:
|
| 131 |
+
L = None
|
| 132 |
+
|
| 133 |
+
item = {
|
| 134 |
+
'X': X, # [N, 14] or [N, 4] if backbone_only == True
|
| 135 |
+
'S': torch.tensor(S, dtype=torch.long), # [N]
|
| 136 |
+
'position_ids': torch.tensor(position_ids, dtype=torch.long), # [N]
|
| 137 |
+
'mask': mask, # [N], 1 for generation
|
| 138 |
+
'atom_mask': atom_mask, # [N, 14] or [N, 4], 1 for having records in the PDB
|
| 139 |
+
'lengths': len(S),
|
| 140 |
+
}
|
| 141 |
+
if L is not None:
|
| 142 |
+
item['L'] = L
|
| 143 |
+
return item
|
| 144 |
+
|
| 145 |
+
def collate_fn(self, batch):
|
| 146 |
+
if self.padding_collate:
|
| 147 |
+
results = {}
|
| 148 |
+
pad_idx = VOCAB.symbol_to_idx(VOCAB.PAD)
|
| 149 |
+
for key in batch[0]:
|
| 150 |
+
values = [item[key] for item in batch]
|
| 151 |
+
if values[0] is None:
|
| 152 |
+
results[key] = None
|
| 153 |
+
continue
|
| 154 |
+
if key == 'lengths':
|
| 155 |
+
results[key] = torch.tensor(values, dtype=torch.long)
|
| 156 |
+
elif key == 'S':
|
| 157 |
+
results[key] = pad_sequence(values, batch_first=True, padding_value=pad_idx)
|
| 158 |
+
else:
|
| 159 |
+
results[key] = pad_sequence(values, batch_first=True, padding_value=0)
|
| 160 |
+
return results
|
| 161 |
+
else:
|
| 162 |
+
results = {}
|
| 163 |
+
for key in batch[0]:
|
| 164 |
+
values = [item[key] for item in batch]
|
| 165 |
+
if values[0] is None:
|
| 166 |
+
results[key] = None
|
| 167 |
+
continue
|
| 168 |
+
if key == 'lengths':
|
| 169 |
+
results[key] = torch.tensor(values, dtype=torch.long)
|
| 170 |
+
else:
|
| 171 |
+
results[key] = torch.cat(values, dim=0)
|
| 172 |
+
return results
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
@R.register('ShapeDataset')
|
| 176 |
+
class ShapeDataset(CoDesignDataset):
|
| 177 |
+
def __init__(
|
| 178 |
+
self,
|
| 179 |
+
mmap_dir: str,
|
| 180 |
+
specify_data: Optional[str] = None,
|
| 181 |
+
specify_index: Optional[str] = None,
|
| 182 |
+
padding_collate: bool = False,
|
| 183 |
+
cluster: Optional[str] = None
|
| 184 |
+
) -> None:
|
| 185 |
+
super().__init__(mmap_dir, False, specify_data, specify_index, padding_collate, cluster)
|
| 186 |
+
self.ca_idx = VOCAB.backbone_atoms.index('CA')
|
| 187 |
+
|
| 188 |
+
def __getitem__(self, idx: int):
|
| 189 |
+
item = super().__getitem__(idx)
|
| 190 |
+
|
| 191 |
+
# refine coordinates to CA and the atom furthest from CA
|
| 192 |
+
X = item['X'] # [N, 14, 3]
|
| 193 |
+
atom_mask = item['atom_mask']
|
| 194 |
+
ca_x = X[:, self.ca_idx].unsqueeze(1) # [N, 1, 3]
|
| 195 |
+
sc_x = X[:, 4:] # [N, 10, 3], sidechain atom indexes
|
| 196 |
+
dist = torch.norm(sc_x - ca_x, dim=-1) # [N, 10]
|
| 197 |
+
dist = dist.masked_fill(~atom_mask[:, 4:], 1e10)
|
| 198 |
+
furthest_atom_x = sc_x[torch.arange(sc_x.shape[0]), torch.argmax(dist, dim=-1)] # [N, 3]
|
| 199 |
+
X = torch.cat([ca_x, furthest_atom_x.unsqueeze(1)], dim=1)
|
| 200 |
+
|
| 201 |
+
item['X'] = X
|
| 202 |
+
return item
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
if __name__ == '__main__':
|
| 206 |
+
import sys
|
| 207 |
+
dataset = CoDesignDataset(sys.argv[1], backbone_only=True)
|
| 208 |
+
print(dataset[0])
|
data/converter/blocks_interface.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
# -*- coding:utf-8 -*-
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def blocks_to_coords(blocks):
|
| 7 |
+
max_n_unit = 0
|
| 8 |
+
coords, masks = [], []
|
| 9 |
+
for block in blocks:
|
| 10 |
+
coords.append([unit.get_coord() for unit in block.units])
|
| 11 |
+
max_n_unit = max(max_n_unit, len(coords[-1]))
|
| 12 |
+
masks.append([1 for _ in coords[-1]])
|
| 13 |
+
|
| 14 |
+
for i in range(len(coords)):
|
| 15 |
+
num_pad = max_n_unit - len(coords[i])
|
| 16 |
+
coords[i] = coords[i] + [[0, 0, 0] for _ in range(num_pad)]
|
| 17 |
+
masks[i] = masks[i] + [0 for _ in range(num_pad)]
|
| 18 |
+
|
| 19 |
+
return np.array(coords), np.array(masks).astype('bool') # [N, M, 3], [N, M], M == max_n_unit, in mask 0 is for padding
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def dist_matrix_from_coords(coords1, masks1, coords2, masks2):
|
| 23 |
+
dist = np.linalg.norm(coords1[:, None] - coords2[None, :], axis=-1) # [N1, N2, M]
|
| 24 |
+
dist = dist + np.logical_not(masks1[:, None] * masks2[None, :]) * 1e6 # [N1, N2, M]
|
| 25 |
+
dist = np.min(dist, axis=-1) # [N1, N2]
|
| 26 |
+
return dist
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def dist_matrix_from_blocks(blocks1, blocks2):
|
| 30 |
+
blocks_coord, blocks_mask = blocks_to_coords(blocks1 + blocks2)
|
| 31 |
+
blocks1_coord, blocks1_mask = blocks_coord[:len(blocks1)], blocks_mask[:len(blocks1)]
|
| 32 |
+
blocks2_coord, blocks2_mask = blocks_coord[len(blocks1):], blocks_mask[len(blocks1):]
|
| 33 |
+
dist = dist_matrix_from_coords(blocks1_coord, blocks1_mask, blocks2_coord, blocks2_mask)
|
| 34 |
+
return dist
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def blocks_interface(blocks1, blocks2, dist_th):
|
| 38 |
+
dist = dist_matrix_from_blocks(blocks1, blocks2)
|
| 39 |
+
|
| 40 |
+
on_interface = dist < dist_th
|
| 41 |
+
indexes1 = np.nonzero(on_interface.sum(axis=1) > 0)[0]
|
| 42 |
+
indexes2 = np.nonzero(on_interface.sum(axis=0) > 0)[0]
|
| 43 |
+
|
| 44 |
+
blocks1 = [blocks1[i] for i in indexes1]
|
| 45 |
+
blocks2 = [blocks2[i] for i in indexes2]
|
| 46 |
+
|
| 47 |
+
return (blocks1, blocks2), (indexes1, indexes2)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def add_cb(input_array):
|
| 51 |
+
#from protein mpnn
|
| 52 |
+
#The virtual Cβ coordinates were calculated using ideal angle and bond length definitions: b = Cα - N, c = C - Cα, a = cross(b, c), Cβ = -0.58273431*a + 0.56802827*b - 0.54067466*c + Cα.
|
| 53 |
+
N,CA,C,O = input_array
|
| 54 |
+
b = CA - N
|
| 55 |
+
c = C - CA
|
| 56 |
+
a = np.cross(b,c)
|
| 57 |
+
CB = np.around(-0.58273431*a + 0.56802827*b - 0.54067466*c + CA,3)
|
| 58 |
+
return CB #np.array([N,CA,C,CB,O])
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def blocks_to_cb_coords(blocks):
|
| 62 |
+
cb_coords = []
|
| 63 |
+
for block in blocks:
|
| 64 |
+
try:
|
| 65 |
+
cb_coords.append(block.get_unit_by_name('CB').get_coord())
|
| 66 |
+
except KeyError:
|
| 67 |
+
tmp_coord = np.array([
|
| 68 |
+
block.get_unit_by_name('N').get_coord(),
|
| 69 |
+
block.get_unit_by_name('CA').get_coord(),
|
| 70 |
+
block.get_unit_by_name('C').get_coord(),
|
| 71 |
+
block.get_unit_by_name('O').get_coord()
|
| 72 |
+
])
|
| 73 |
+
cb_coords.append(add_cb(tmp_coord))
|
| 74 |
+
return np.array(cb_coords)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def blocks_cb_interface(blocks1, blocks2, dist_th=8.0):
|
| 78 |
+
cb_coords1 = blocks_to_cb_coords(blocks1)
|
| 79 |
+
cb_coords2 = blocks_to_cb_coords(blocks2)
|
| 80 |
+
dist = np.linalg.norm(cb_coords1[:, None] - cb_coords2[None, :], axis=-1) # [N1, N2]
|
| 81 |
+
|
| 82 |
+
on_interface = dist < dist_th
|
| 83 |
+
indexes1 = np.nonzero(on_interface.sum(axis=1) > 0)[0]
|
| 84 |
+
indexes2 = np.nonzero(on_interface.sum(axis=0) > 0)[0]
|
| 85 |
+
|
| 86 |
+
blocks1 = [blocks1[i] for i in indexes1]
|
| 87 |
+
blocks2 = [blocks2[i] for i in indexes2]
|
| 88 |
+
|
| 89 |
+
return (blocks1, blocks2), (indexes1, indexes2)
|
data/converter/blocks_to_data.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
# -*- coding:utf-8 -*-
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from data.format import VOCAB, Block
|
| 8 |
+
from utils import const
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def blocks_to_data(*blocks_list: List[List[Block]]):
|
| 12 |
+
B, A, X, atom_positions, block_lengths, segment_ids = [], [], [], [], [], []
|
| 13 |
+
atom_mask, is_ca = [], []
|
| 14 |
+
topo_edge_index, topo_edge_attr, atom_names = [], [], []
|
| 15 |
+
last_c_node_id = None
|
| 16 |
+
for i, blocks in enumerate(blocks_list):
|
| 17 |
+
if len(blocks) == 0:
|
| 18 |
+
continue
|
| 19 |
+
cur_B, cur_A, cur_X, cur_atom_positions, cur_block_lengths = [], [], [], [], []
|
| 20 |
+
cur_atom_mask, cur_is_ca = [], []
|
| 21 |
+
# other nodes
|
| 22 |
+
for block in blocks:
|
| 23 |
+
b, symbol = VOCAB.abrv_to_idx(block.abrv), VOCAB.abrv_to_symbol(block.abrv)
|
| 24 |
+
x, a, positions, m, ca = [], [], [], [], []
|
| 25 |
+
atom2node_id = {}
|
| 26 |
+
if symbol == '?':
|
| 27 |
+
atom_missing = {}
|
| 28 |
+
else:
|
| 29 |
+
atom_missing = { atom_name: True for atom_name in const.backbone_atoms + const.sidechain_atoms[symbol] }
|
| 30 |
+
for atom in block:
|
| 31 |
+
atom2node_id[atom.name] = len(A) + len(cur_A) + len(a)
|
| 32 |
+
a.append(VOCAB.atom_to_idx(atom.get_element()))
|
| 33 |
+
x.append(atom.get_coord())
|
| 34 |
+
pos_code = ''.join((c for c in atom.get_pos_code() if not c.isdigit()))
|
| 35 |
+
positions.append(VOCAB.atom_pos_to_idx(pos_code))
|
| 36 |
+
if atom.name in atom_missing:
|
| 37 |
+
atom_missing[atom.name] = False
|
| 38 |
+
m.append(1)
|
| 39 |
+
ca.append(atom.name == 'CA')
|
| 40 |
+
atom_names.append(atom.name)
|
| 41 |
+
for atom_name in atom_missing:
|
| 42 |
+
if atom_missing[atom_name]:
|
| 43 |
+
atom2node_id[atom_name] = len(A) + len(cur_A) + len(a)
|
| 44 |
+
a.append(VOCAB.atom_to_idx(atom_name[0])) # only C, N, O, S in proteins
|
| 45 |
+
x.append([0, 0, 0])
|
| 46 |
+
pos_code = ''.join((c for c in atom_name[1:] if not c.isdigit()))
|
| 47 |
+
positions.append(VOCAB.atom_pos_to_idx(pos_code))
|
| 48 |
+
m.append(0)
|
| 49 |
+
ca.append(atom_name == 'CA')
|
| 50 |
+
atom_names.append(atom_name)
|
| 51 |
+
block_len = len(a)
|
| 52 |
+
cur_B.append(b)
|
| 53 |
+
cur_A.extend(a)
|
| 54 |
+
cur_X.extend(x)
|
| 55 |
+
cur_atom_positions.extend(positions)
|
| 56 |
+
cur_block_lengths.append(block_len)
|
| 57 |
+
cur_atom_mask.extend(m)
|
| 58 |
+
cur_is_ca.extend(ca)
|
| 59 |
+
|
| 60 |
+
# topology edges
|
| 61 |
+
for src, dst, bond_type in const.sidechain_bonds.get(VOCAB.abrv_to_symbol(block.abrv), []):
|
| 62 |
+
src, dst = atom2node_id[src], atom2node_id[dst]
|
| 63 |
+
topo_edge_index.append((src, dst)) # no direction
|
| 64 |
+
topo_edge_index.append((dst, src))
|
| 65 |
+
topo_edge_attr.append(bond_type)
|
| 66 |
+
topo_edge_attr.append(bond_type)
|
| 67 |
+
if last_c_node_id is not None and ('CA' in atom2node_id):
|
| 68 |
+
src, dst = last_c_node_id, atom2node_id['N']
|
| 69 |
+
topo_edge_index.append((src, dst)) # no direction
|
| 70 |
+
topo_edge_index.append((dst, src))
|
| 71 |
+
topo_edge_attr.append(4)
|
| 72 |
+
topo_edge_attr.append(4)
|
| 73 |
+
if 'CA' not in atom2node_id:
|
| 74 |
+
last_c_node_id = None
|
| 75 |
+
else:
|
| 76 |
+
last_c_node_id = atom2node_id['C']
|
| 77 |
+
|
| 78 |
+
# update coordinates of the global node to the center
|
| 79 |
+
# cur_X[0] = np.mean(cur_X[1:], axis=0)
|
| 80 |
+
cur_segment_ids = [i for _ in cur_B]
|
| 81 |
+
|
| 82 |
+
# finish these blocks
|
| 83 |
+
B.extend(cur_B)
|
| 84 |
+
A.extend(cur_A)
|
| 85 |
+
X.extend(cur_X)
|
| 86 |
+
atom_positions.extend(cur_atom_positions)
|
| 87 |
+
block_lengths.extend(cur_block_lengths)
|
| 88 |
+
segment_ids.extend(cur_segment_ids)
|
| 89 |
+
atom_mask.extend(cur_atom_mask)
|
| 90 |
+
is_ca.extend(cur_is_ca)
|
| 91 |
+
|
| 92 |
+
X = np.array(X).tolist()
|
| 93 |
+
topo_edge_index = np.array(topo_edge_index).T.tolist()
|
| 94 |
+
topo_edge_attr = (np.array(topo_edge_attr) - 1).tolist() # type starts from 0 but bond type starts from 1
|
| 95 |
+
|
| 96 |
+
data = {
|
| 97 |
+
'X': X, # [Natom, 2, 3]
|
| 98 |
+
'B': B, # [Nb], block (residue) type
|
| 99 |
+
'A': A, # [Natom]
|
| 100 |
+
'atom_positions': atom_positions, # [Natom]
|
| 101 |
+
'block_lengths': block_lengths, # [Nresidue]
|
| 102 |
+
'segment_ids': segment_ids, # [Nresidue]
|
| 103 |
+
'atom_mask': atom_mask, # [Natom]
|
| 104 |
+
'is_ca': is_ca, # [Natom]
|
| 105 |
+
'atom_names': atom_names, # [Natom]
|
| 106 |
+
'topo_edge_index': topo_edge_index, # atom level
|
| 107 |
+
'topo_edge_attr': topo_edge_attr
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
return data
|
data/converter/list_blocks_to_pdb.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
# -*- coding:utf-8 -*-
|
| 3 |
+
import os
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from Bio.PDB import PDBParser, PDBIO
|
| 9 |
+
from Bio.PDB.Structure import Structure as BStructure
|
| 10 |
+
from Bio.PDB.Model import Model as BModel
|
| 11 |
+
from Bio.PDB.Chain import Chain as BChain
|
| 12 |
+
from Bio.PDB.Residue import Residue as BResidue
|
| 13 |
+
from Bio.PDB.Atom import Atom as BAtom
|
| 14 |
+
|
| 15 |
+
from data.format import Block, Atom, VOCAB
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def list_blocks_to_pdb(list_blocks: List[List[Block]], chain_names: List[str], out_path: str) -> None:
|
| 19 |
+
'''
|
| 20 |
+
Convert pdb file to a list of lists of blocks using Biopython.
|
| 21 |
+
Each chain will be a list of blocks.
|
| 22 |
+
|
| 23 |
+
Parameters:
|
| 24 |
+
list_blocks: A list of lists of blocks. Each list of blocks will be parsed into one chain in the pdb
|
| 25 |
+
chain_names: name of chains
|
| 26 |
+
out_path: Path to the pdb file
|
| 27 |
+
|
| 28 |
+
'''
|
| 29 |
+
pdb_id = os.path.basename(os.path.splitext(out_path)[0])
|
| 30 |
+
structure = BStructure(id=pdb_id)
|
| 31 |
+
model = BModel(id=0)
|
| 32 |
+
for blocks, chain_name in zip(list_blocks, chain_names):
|
| 33 |
+
chain = BChain(id=chain_name)
|
| 34 |
+
for i, block in enumerate(blocks):
|
| 35 |
+
chain.add(_block_to_biopython(block, i))
|
| 36 |
+
model.add(chain)
|
| 37 |
+
structure.add(model)
|
| 38 |
+
io = PDBIO()
|
| 39 |
+
io.set_structure(structure)
|
| 40 |
+
io.save(out_path)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _block_to_biopython(block: Block, pos_code: int) -> BResidue:
|
| 44 |
+
_id = (' ', pos_code, ' ')
|
| 45 |
+
residue = BResidue(_id, block.abrv, ' ')
|
| 46 |
+
for i, atom in enumerate(block):
|
| 47 |
+
fullname = ' ' + atom.name
|
| 48 |
+
while len(fullname) < 4:
|
| 49 |
+
fullname += ' '
|
| 50 |
+
bio_atom = BAtom(
|
| 51 |
+
name=atom,
|
| 52 |
+
coord=np.array(atom.coordinate, dtype=np.float32),
|
| 53 |
+
bfactor=0,
|
| 54 |
+
occupancy=1.0,
|
| 55 |
+
altloc=' ',
|
| 56 |
+
fullname=fullname,
|
| 57 |
+
serial_number=i,
|
| 58 |
+
element=atom.element
|
| 59 |
+
)
|
| 60 |
+
residue.add(bio_atom)
|
| 61 |
+
return residue
|
data/converter/pdb_to_list_blocks.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
# -*- coding:utf-8 -*-
|
| 3 |
+
from typing import Dict, List, Optional, Union
|
| 4 |
+
|
| 5 |
+
from Bio.PDB import PDBParser
|
| 6 |
+
|
| 7 |
+
from data.format import Block, Atom
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def pdb_to_list_blocks(pdb: str, selected_chains: Optional[List[str]]=None, return_chain_ids: bool=False, dict_form: bool=False) -> Union[List[List[Block]], Dict[str, List[Block]]]:
|
| 11 |
+
'''
|
| 12 |
+
Convert pdb file to a list of lists of blocks using Biopython.
|
| 13 |
+
Each chain will be a list of blocks.
|
| 14 |
+
|
| 15 |
+
Parameters:
|
| 16 |
+
pdb: Path to the pdb file
|
| 17 |
+
selected_chains: List of selected chain ids. The returned list will be ordered
|
| 18 |
+
according to the ordering of chain ids in this parameter. If not specified,
|
| 19 |
+
all chains will be returned. e.g. ['A', 'B']
|
| 20 |
+
return_chain_ids: Whether to return the ids of each chain
|
| 21 |
+
dict_form: Whether to return chains in dict form (chain id as the key and blocks
|
| 22 |
+
as the value)
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
A list of lists of blocks. Each chain in the pdb file will be parsed into
|
| 26 |
+
one list of blocks.
|
| 27 |
+
example:
|
| 28 |
+
[
|
| 29 |
+
[residueA1, residueA2, ...], # chain A
|
| 30 |
+
[residueB1, residueB2, ...] # chain B
|
| 31 |
+
],
|
| 32 |
+
where each residue is instantiated by Block data class.
|
| 33 |
+
'''
|
| 34 |
+
|
| 35 |
+
parser = PDBParser(QUIET=True)
|
| 36 |
+
structure = parser.get_structure('anonym', pdb)
|
| 37 |
+
|
| 38 |
+
list_blocks, chain_ids, chains = [], {}, []
|
| 39 |
+
|
| 40 |
+
for model in structure.get_models(): # use model 1 only
|
| 41 |
+
structure = model
|
| 42 |
+
break
|
| 43 |
+
|
| 44 |
+
for chain in structure.get_chains():
|
| 45 |
+
|
| 46 |
+
_id = chain.get_id()
|
| 47 |
+
if (selected_chains is not None) and (_id not in selected_chains):
|
| 48 |
+
continue
|
| 49 |
+
|
| 50 |
+
residues, res_ids = [], {}
|
| 51 |
+
|
| 52 |
+
for residue in chain:
|
| 53 |
+
abrv = residue.get_resname()
|
| 54 |
+
hetero_flag, res_number, insert_code = residue.get_id()
|
| 55 |
+
res_id = f'{res_number}-{insert_code}'
|
| 56 |
+
if hetero_flag == 'W':
|
| 57 |
+
continue # residue from glucose (WAT) or water (HOH)
|
| 58 |
+
if hetero_flag.strip() != '' and res_id in res_ids:
|
| 59 |
+
continue # the solvent (e.g. H_EDO (EDO))
|
| 60 |
+
if abrv in ['EDO', 'HOH', 'BME']: # solvent or other molecules
|
| 61 |
+
continue
|
| 62 |
+
if abrv == 'MSE':
|
| 63 |
+
abrv = 'MET' # MET is usually transformed to MSE for structural analysis
|
| 64 |
+
|
| 65 |
+
# filter Hs because not all data include them
|
| 66 |
+
atoms = [ Atom(atom.get_id(), atom.get_coord().tolist(), atom.element) for atom in residue if atom.element != 'H' ]
|
| 67 |
+
block = Block(abrv, atoms, id=(res_number, insert_code))
|
| 68 |
+
if block.is_residue():
|
| 69 |
+
residues.append(block)
|
| 70 |
+
res_ids[res_id] = True
|
| 71 |
+
|
| 72 |
+
if len(residues) == 0: # not a chain
|
| 73 |
+
continue
|
| 74 |
+
|
| 75 |
+
chain_ids[_id] = len(list_blocks)
|
| 76 |
+
list_blocks.append(residues)
|
| 77 |
+
chains.append(_id)
|
| 78 |
+
|
| 79 |
+
# reorder
|
| 80 |
+
if selected_chains is not None:
|
| 81 |
+
list_blocks = [list_blocks[chain_ids[chain_id]] for chain_id in selected_chains]
|
| 82 |
+
chains = selected_chains
|
| 83 |
+
|
| 84 |
+
if dict_form:
|
| 85 |
+
return { chain: blocks for chain, blocks in zip(chains, list_blocks)}
|
| 86 |
+
|
| 87 |
+
if return_chain_ids:
|
| 88 |
+
return list_blocks, chains
|
| 89 |
+
|
| 90 |
+
return list_blocks
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
if __name__ == '__main__':
|
| 94 |
+
import sys
|
| 95 |
+
list_blocks = pdb_to_list_blocks(sys.argv[1])
|
| 96 |
+
print(f'{sys.argv[1]} parsed')
|
| 97 |
+
print(f'number of chains: {len(list_blocks)}')
|
| 98 |
+
for i, chain in enumerate(list_blocks):
|
| 99 |
+
print(f'chain {i} lengths: {len(chain)}')
|
data/dataset_wrapper.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
from math import log
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import sympy
|
| 8 |
+
|
| 9 |
+
from utils import register as R
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class MixDatasetWrapper(torch.utils.data.Dataset):
|
| 13 |
+
def __init__(self, *datasets, collate_fn: Callable=None) -> None:
|
| 14 |
+
super().__init__()
|
| 15 |
+
self.datasets = datasets
|
| 16 |
+
self.cum_len = []
|
| 17 |
+
self.total_len = 0
|
| 18 |
+
for dataset in datasets:
|
| 19 |
+
self.total_len += len(dataset)
|
| 20 |
+
self.cum_len.append(self.total_len)
|
| 21 |
+
self.collate_fn = self.datasets[0].collate_fn if collate_fn is None else collate_fn
|
| 22 |
+
if hasattr(datasets[0], '_lengths'):
|
| 23 |
+
self._lengths = []
|
| 24 |
+
for dataset in datasets:
|
| 25 |
+
self._lengths.extend(dataset._lengths)
|
| 26 |
+
|
| 27 |
+
def update_epoch(self):
|
| 28 |
+
for dataset in self.datasets:
|
| 29 |
+
if hasattr(dataset, 'update_epoch'):
|
| 30 |
+
dataset.update_epoch()
|
| 31 |
+
|
| 32 |
+
def get_len(self, idx):
|
| 33 |
+
return self._lengths[idx]
|
| 34 |
+
|
| 35 |
+
def __len__(self):
|
| 36 |
+
return self.total_len
|
| 37 |
+
|
| 38 |
+
def __getitem__(self, idx):
|
| 39 |
+
last_cum_len = 0
|
| 40 |
+
for i, cum_len in enumerate(self.cum_len):
|
| 41 |
+
if idx < cum_len:
|
| 42 |
+
return self.datasets[i].__getitem__(idx - last_cum_len)
|
| 43 |
+
last_cum_len = cum_len
|
| 44 |
+
return None # this is not possible
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@R.register('DynamicBatchWrapper')
|
| 48 |
+
class DynamicBatchWrapper(torch.utils.data.Dataset):
|
| 49 |
+
def __init__(self, dataset, complexity, ubound_per_batch) -> None:
|
| 50 |
+
super().__init__()
|
| 51 |
+
self.dataset = dataset
|
| 52 |
+
self.indexes = [i for i in range(len(dataset))]
|
| 53 |
+
self.complexity = complexity
|
| 54 |
+
self.eval_func = sympy.lambdify('n', sympy.simplify(complexity))
|
| 55 |
+
self.ubound_per_batch = ubound_per_batch
|
| 56 |
+
self.total_size = None
|
| 57 |
+
self.batch_indexes = []
|
| 58 |
+
self._form_batch()
|
| 59 |
+
|
| 60 |
+
def __getattr__(self, attr):
|
| 61 |
+
if attr in self.__dict__:
|
| 62 |
+
return self.__dict__[attr]
|
| 63 |
+
elif hasattr(self.dataset, attr):
|
| 64 |
+
return getattr(self.dataset, attr)
|
| 65 |
+
else:
|
| 66 |
+
raise AttributeError(f"'DynamicBatchWrapper'(or '{type(self.dataset)}') object has no attribute '{attr}'")
|
| 67 |
+
|
| 68 |
+
def update_epoch(self):
|
| 69 |
+
if hasattr(self.dataset, 'update_epoch'):
|
| 70 |
+
self.dataset.update_epoch()
|
| 71 |
+
self._form_batch()
|
| 72 |
+
|
| 73 |
+
########## overload with your criterion ##########
|
| 74 |
+
def _form_batch(self):
|
| 75 |
+
|
| 76 |
+
np.random.shuffle(self.indexes)
|
| 77 |
+
last_batch_indexes = self.batch_indexes
|
| 78 |
+
self.batch_indexes = []
|
| 79 |
+
|
| 80 |
+
cur_complexity = 0
|
| 81 |
+
batch = []
|
| 82 |
+
|
| 83 |
+
for i in tqdm(self.indexes):
|
| 84 |
+
item_len = self.eval_func(self.dataset.get_len(i))
|
| 85 |
+
if item_len > self.ubound_per_batch:
|
| 86 |
+
continue
|
| 87 |
+
cur_complexity += item_len
|
| 88 |
+
if cur_complexity > self.ubound_per_batch:
|
| 89 |
+
self.batch_indexes.append(batch)
|
| 90 |
+
batch = []
|
| 91 |
+
cur_complexity = item_len
|
| 92 |
+
batch.append(i)
|
| 93 |
+
self.batch_indexes.append(batch)
|
| 94 |
+
|
| 95 |
+
if self.total_size is None:
|
| 96 |
+
self.total_size = len(self.batch_indexes)
|
| 97 |
+
else:
|
| 98 |
+
# control the lengths of the dataset, otherwise the dataloader will raise error
|
| 99 |
+
if len(self.batch_indexes) < self.total_size:
|
| 100 |
+
num_add = self.total_size - len(self.batch_indexes)
|
| 101 |
+
self.batch_indexes = self.batch_indexes + last_batch_indexes[:num_add]
|
| 102 |
+
else:
|
| 103 |
+
self.batch_indexes = self.batch_indexes[:self.total_size]
|
| 104 |
+
|
| 105 |
+
def __len__(self):
|
| 106 |
+
return len(self.batch_indexes)
|
| 107 |
+
|
| 108 |
+
def __getitem__(self, idx):
|
| 109 |
+
return [self.dataset[i] for i in self.batch_indexes[idx]]
|
| 110 |
+
|
| 111 |
+
def collate_fn(self, batched_batch):
|
| 112 |
+
batch = []
|
| 113 |
+
for minibatch in batched_batch:
|
| 114 |
+
batch.extend(minibatch)
|
| 115 |
+
return self.dataset.collate_fn(batch)
|
data/format.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
# -*- coding:utf-8 -*-
|
| 3 |
+
from copy import copy
|
| 4 |
+
from typing import List, Tuple, Iterator, Optional
|
| 5 |
+
|
| 6 |
+
from utils import const
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class MoleculeVocab:
|
| 10 |
+
|
| 11 |
+
MAX_ATOM_NUMBER = 14
|
| 12 |
+
|
| 13 |
+
def __init__(self):
|
| 14 |
+
self.backbone_atoms = ['N', 'CA', 'C', 'O']
|
| 15 |
+
self.PAD, self.MASK, self.UNK, self.LAT = '#', '*', '?', '&' # pad / mask / unk / latent node
|
| 16 |
+
specials = [# special added
|
| 17 |
+
(self.PAD, 'PAD'), (self.MASK, 'MASK'), (self.UNK, 'UNK'), # pad / mask / unk
|
| 18 |
+
(self.LAT, '<L>') # latent node in latent space
|
| 19 |
+
]
|
| 20 |
+
|
| 21 |
+
aas = const.aas
|
| 22 |
+
|
| 23 |
+
# sms = [(e.lower(), e) for e in const.periodic_table]
|
| 24 |
+
sms = [] # disable small molecule vocabulary
|
| 25 |
+
|
| 26 |
+
self.atom_pad, self.atom_mask, self.atom_latent = 'pad', 'msk', 'lat' # Avoid conflict with atom P
|
| 27 |
+
self.atom_pos_pad, self.atom_pos_mask, self.atom_pos_latent = 'pad', 'msk', 'lat'
|
| 28 |
+
self.atom_pos_sm = 'sml' # small molecule
|
| 29 |
+
|
| 30 |
+
# block level vocab
|
| 31 |
+
self.idx2block = specials + aas + sms
|
| 32 |
+
self.symbol2idx, self.abrv2idx = {}, {}
|
| 33 |
+
for i, (symbol, abrv) in enumerate(self.idx2block):
|
| 34 |
+
self.symbol2idx[symbol] = i
|
| 35 |
+
self.abrv2idx[abrv] = i
|
| 36 |
+
self.special_mask = [1 for _ in specials] + [0 for _ in aas] + [0 for _ in sms]
|
| 37 |
+
|
| 38 |
+
# atom level vocab
|
| 39 |
+
self.idx2atom = [self.atom_pad, self.atom_mask, self.atom_latent] + const.periodic_table
|
| 40 |
+
self.idx2atom_pos = [self.atom_pos_pad, self.atom_pos_mask, self.atom_pos_latent, '', 'A', 'B', 'G', 'D', 'E', 'Z', 'H', 'XT', 'P', self.atom_pos_sm] # SM is for atoms in small molecule, 'P' for O1P, O2P, O3P
|
| 41 |
+
self.atom2idx, self.atom_pos2idx = {}, {}
|
| 42 |
+
self.atom2idx = {}
|
| 43 |
+
for i, atom in enumerate(self.idx2atom):
|
| 44 |
+
self.atom2idx[atom] = i
|
| 45 |
+
for i, atom_pos in enumerate(self.idx2atom_pos):
|
| 46 |
+
self.atom_pos2idx[atom_pos] = i
|
| 47 |
+
|
| 48 |
+
# block level APIs
|
| 49 |
+
|
| 50 |
+
def abrv_to_symbol(self, abrv):
|
| 51 |
+
idx = self.abrv_to_idx(abrv)
|
| 52 |
+
return None if idx is None else self.idx2block[idx][0]
|
| 53 |
+
|
| 54 |
+
def symbol_to_abrv(self, symbol):
|
| 55 |
+
idx = self.symbol_to_idx(symbol)
|
| 56 |
+
return None if idx is None else self.idx2block[idx][1]
|
| 57 |
+
|
| 58 |
+
def abrv_to_idx(self, abrv):
|
| 59 |
+
abrv = abrv.upper()
|
| 60 |
+
return self.abrv2idx.get(abrv, self.abrv2idx['UNK'])
|
| 61 |
+
|
| 62 |
+
def symbol_to_idx(self, symbol):
|
| 63 |
+
# symbol = symbol.upper()
|
| 64 |
+
return self.symbol2idx.get(symbol, self.abrv2idx['UNK'])
|
| 65 |
+
|
| 66 |
+
def idx_to_symbol(self, idx):
|
| 67 |
+
return self.idx2block[idx][0]
|
| 68 |
+
|
| 69 |
+
def idx_to_abrv(self, idx):
|
| 70 |
+
return self.idx2block[idx][1]
|
| 71 |
+
|
| 72 |
+
def get_pad_idx(self):
|
| 73 |
+
return self.symbol_to_idx(self.PAD)
|
| 74 |
+
|
| 75 |
+
def get_mask_idx(self):
|
| 76 |
+
return self.symbol_to_idx(self.MASK)
|
| 77 |
+
|
| 78 |
+
def get_special_mask(self):
|
| 79 |
+
return copy(self.special_mask)
|
| 80 |
+
|
| 81 |
+
# atom level APIs
|
| 82 |
+
|
| 83 |
+
def get_atom_pad_idx(self):
|
| 84 |
+
return self.atom2idx[self.atom_pad]
|
| 85 |
+
|
| 86 |
+
def get_atom_mask_idx(self):
|
| 87 |
+
return self.atom2idx[self.atom_mask]
|
| 88 |
+
|
| 89 |
+
def get_atom_latent_idx(self):
|
| 90 |
+
return self.atom2idx[self.atom_latent]
|
| 91 |
+
|
| 92 |
+
def get_atom_pos_pad_idx(self):
|
| 93 |
+
return self.atom_pos2idx[self.atom_pos_pad]
|
| 94 |
+
|
| 95 |
+
def get_atom_pos_mask_idx(self):
|
| 96 |
+
return self.atom_pos2idx[self.atom_pos_mask]
|
| 97 |
+
|
| 98 |
+
def get_atom_pos_latent_idx(self):
|
| 99 |
+
return self.atom_pos2idx[self.atom_pos_latent]
|
| 100 |
+
|
| 101 |
+
def idx_to_atom(self, idx):
|
| 102 |
+
return self.idx2atom[idx]
|
| 103 |
+
|
| 104 |
+
def atom_to_idx(self, atom):
|
| 105 |
+
atom = atom.upper()
|
| 106 |
+
return self.atom2idx.get(atom, self.atom2idx[self.atom_mask])
|
| 107 |
+
|
| 108 |
+
def idx_to_atom_pos(self, idx):
|
| 109 |
+
return self.idx2atom_pos[idx]
|
| 110 |
+
|
| 111 |
+
def atom_pos_to_idx(self, atom_pos):
|
| 112 |
+
return self.atom_pos2idx.get(atom_pos, self.atom_pos2idx[self.atom_pos_mask])
|
| 113 |
+
|
| 114 |
+
# sizes
|
| 115 |
+
|
| 116 |
+
def get_num_atom_type(self):
|
| 117 |
+
return len(self.idx2atom)
|
| 118 |
+
|
| 119 |
+
def get_num_atom_pos(self):
|
| 120 |
+
return len(self.idx2atom_pos)
|
| 121 |
+
|
| 122 |
+
def get_num_block_type(self):
|
| 123 |
+
return len(self.special_mask) - sum(self.special_mask)
|
| 124 |
+
|
| 125 |
+
def __len__(self):
|
| 126 |
+
return len(self.symbol2idx)
|
| 127 |
+
|
| 128 |
+
# others
|
| 129 |
+
@property
|
| 130 |
+
def ca_channel_idx(self):
|
| 131 |
+
return self.backbone_atoms.index('CA')
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
VOCAB = MoleculeVocab()
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class Atom:
|
| 138 |
+
def __init__(self, atom_name: str, coordinate: List[float], element: str, pos_code: str=None):
|
| 139 |
+
self.name = atom_name
|
| 140 |
+
self.coordinate = coordinate
|
| 141 |
+
self.element = element
|
| 142 |
+
if pos_code is None:
|
| 143 |
+
pos_code = atom_name.lstrip(element)
|
| 144 |
+
self.pos_code = pos_code
|
| 145 |
+
else:
|
| 146 |
+
self.pos_code = pos_code
|
| 147 |
+
|
| 148 |
+
def get_element(self):
|
| 149 |
+
return self.element
|
| 150 |
+
|
| 151 |
+
def get_coord(self):
|
| 152 |
+
return copy(self.coordinate)
|
| 153 |
+
|
| 154 |
+
def get_pos_code(self):
|
| 155 |
+
return self.pos_code
|
| 156 |
+
|
| 157 |
+
def __str__(self) -> str:
|
| 158 |
+
return self.name
|
| 159 |
+
|
| 160 |
+
def __repr__(self) -> str:
|
| 161 |
+
return f"Atom ({self.name}): {self.element}({self.pos_code}) [{','.join(['{:.4f}'.format(num) for num in self.coordinate])}]"
|
| 162 |
+
|
| 163 |
+
def to_tuple(self):
|
| 164 |
+
return (
|
| 165 |
+
self.name,
|
| 166 |
+
self.coordinate,
|
| 167 |
+
self.element,
|
| 168 |
+
self.pos_code
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
@classmethod
|
| 172 |
+
def from_tuple(self, data):
|
| 173 |
+
return Atom(
|
| 174 |
+
atom_name=data[0],
|
| 175 |
+
coordinate=data[1],
|
| 176 |
+
element=data[2],
|
| 177 |
+
pos_code=data[3]
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class Block:
|
| 182 |
+
def __init__(self, abrv: str, units: List[Atom], id: Optional[any]=None) -> None:
|
| 183 |
+
self.abrv: str = abrv
|
| 184 |
+
self.units: List[Atom] = units
|
| 185 |
+
self._uname2idx = { unit.name: i for i, unit in enumerate(self.units) }
|
| 186 |
+
self.id = id
|
| 187 |
+
|
| 188 |
+
def __len__(self) -> int:
|
| 189 |
+
return len(self.units)
|
| 190 |
+
|
| 191 |
+
def __iter__(self) -> Iterator[Atom]:
|
| 192 |
+
return iter(self.units)
|
| 193 |
+
|
| 194 |
+
def get_unit_by_name(self, name: str) -> Atom:
|
| 195 |
+
idx = self._uname2idx[name]
|
| 196 |
+
return self.units[idx]
|
| 197 |
+
|
| 198 |
+
def has_unit(self, name: str) -> bool:
|
| 199 |
+
return name in self._uname2idx
|
| 200 |
+
|
| 201 |
+
def to_tuple(self):
|
| 202 |
+
return (
|
| 203 |
+
self.abrv,
|
| 204 |
+
[unit.to_tuple() for unit in self.units],
|
| 205 |
+
self.id
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
def is_residue(self):
|
| 209 |
+
return self.has_unit('CA') and self.has_unit('N') and self.has_unit('C') and self.has_unit('O')
|
| 210 |
+
|
| 211 |
+
@classmethod
|
| 212 |
+
def from_tuple(self, data):
|
| 213 |
+
return Block(
|
| 214 |
+
abrv=data[0],
|
| 215 |
+
units=[Atom.from_tuple(unit_data) for unit_data in data[1]],
|
| 216 |
+
id=data[2]
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
def __repr__(self) -> str:
|
| 220 |
+
return f"Block ({self.abrv}):\n\t" + '\n\t'.join([repr(at) for at in self.units]) + '\n'
|
data/mmap_dataset.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
# -*- coding:utf-8 -*-
|
| 3 |
+
import os
|
| 4 |
+
import io
|
| 5 |
+
import gzip
|
| 6 |
+
import json
|
| 7 |
+
import mmap
|
| 8 |
+
from typing import Optional
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def compress(x):
|
| 15 |
+
serialized_x = json.dumps(x).encode()
|
| 16 |
+
buf = io.BytesIO()
|
| 17 |
+
with gzip.GzipFile(fileobj=buf, mode='wb', compresslevel=6) as f:
|
| 18 |
+
f.write(serialized_x)
|
| 19 |
+
compressed = buf.getvalue()
|
| 20 |
+
return compressed
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def decompress(compressed_x):
|
| 24 |
+
buf = io.BytesIO(compressed_x)
|
| 25 |
+
with gzip.GzipFile(fileobj=buf, mode="rb") as f:
|
| 26 |
+
serialized_x = f.read().decode()
|
| 27 |
+
x = json.loads(serialized_x)
|
| 28 |
+
return x
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _find_measure_unit(num_bytes):
|
| 32 |
+
size, measure_unit = num_bytes, 'Bytes'
|
| 33 |
+
for unit in ['KB', 'MB', 'GB']:
|
| 34 |
+
if size > 1000:
|
| 35 |
+
size /= 1024
|
| 36 |
+
measure_unit = unit
|
| 37 |
+
else:
|
| 38 |
+
break
|
| 39 |
+
return size, measure_unit
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def create_mmap(iterator, out_dir, total_len=None, commit_batch=10000):
|
| 43 |
+
|
| 44 |
+
if not os.path.exists(out_dir):
|
| 45 |
+
os.makedirs(out_dir)
|
| 46 |
+
|
| 47 |
+
data_file_path = os.path.join(out_dir, 'data.bin')
|
| 48 |
+
data_file = open(data_file_path, 'wb')
|
| 49 |
+
index_file = open(os.path.join(out_dir, 'index.txt'), 'w')
|
| 50 |
+
|
| 51 |
+
i, offset, n_finished = 0, 0, 0
|
| 52 |
+
progress_bar = tqdm(iterator, total=total_len)
|
| 53 |
+
for _id, x, properties, entry_idx in iterator:
|
| 54 |
+
progress_bar.set_description(f'Processing {_id}')
|
| 55 |
+
compressed_x = compress(x)
|
| 56 |
+
bin_length = data_file.write(compressed_x)
|
| 57 |
+
properties = '\t'.join([str(prop) for prop in properties])
|
| 58 |
+
index_file.write(f'{_id}\t{offset}\t{offset + bin_length}\t{properties}\n') # tuple of (_id, start, end), data slice is [start, end)
|
| 59 |
+
offset += bin_length
|
| 60 |
+
i += 1
|
| 61 |
+
|
| 62 |
+
if entry_idx > n_finished:
|
| 63 |
+
progress_bar.update(entry_idx - n_finished)
|
| 64 |
+
n_finished = entry_idx
|
| 65 |
+
if total_len is not None:
|
| 66 |
+
expected_size = os.fstat(data_file.fileno()).st_size / n_finished * total_len
|
| 67 |
+
expected_size, measure_unit = _find_measure_unit(expected_size)
|
| 68 |
+
progress_bar.set_postfix({f'{i} saved; Estimated total size ({measure_unit})': expected_size})
|
| 69 |
+
|
| 70 |
+
if i % commit_batch == 0:
|
| 71 |
+
data_file.flush() # save from memory to disk
|
| 72 |
+
index_file.flush()
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
data_file.close()
|
| 76 |
+
index_file.close()
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class MMAPDataset(torch.utils.data.Dataset):
|
| 80 |
+
|
| 81 |
+
def __init__(self, mmap_dir: str, specify_data: Optional[str]=None, specify_index: Optional[str]=None) -> None:
|
| 82 |
+
super().__init__()
|
| 83 |
+
|
| 84 |
+
self._indexes = []
|
| 85 |
+
self._properties = []
|
| 86 |
+
_index_path = os.path.join(mmap_dir, 'index.txt') if specify_index is None else specify_index
|
| 87 |
+
with open(_index_path, 'r') as f:
|
| 88 |
+
for line in f.readlines():
|
| 89 |
+
messages = line.strip().split('\t')
|
| 90 |
+
_id, start, end = messages[:3]
|
| 91 |
+
_property = messages[3:]
|
| 92 |
+
self._indexes.append((_id, int(start), int(end)))
|
| 93 |
+
self._properties.append(_property)
|
| 94 |
+
_data_path = os.path.join(mmap_dir, 'data.bin') if specify_data is None else specify_data
|
| 95 |
+
self._data_file = open(_data_path, 'rb')
|
| 96 |
+
self._mmap = mmap.mmap(self._data_file.fileno(), 0, access=mmap.ACCESS_READ)
|
| 97 |
+
|
| 98 |
+
def __del__(self):
|
| 99 |
+
self._mmap.close()
|
| 100 |
+
self._data_file.close()
|
| 101 |
+
|
| 102 |
+
def __len__(self):
|
| 103 |
+
return len(self._indexes)
|
| 104 |
+
|
| 105 |
+
def __getitem__(self, idx: int):
|
| 106 |
+
if idx < 0 or idx >= len(self):
|
| 107 |
+
raise IndexError(idx)
|
| 108 |
+
|
| 109 |
+
_, start, end = self._indexes[idx]
|
| 110 |
+
data = decompress(self._mmap[start:end])
|
| 111 |
+
|
| 112 |
+
return data
|
data/resample.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
# -*- coding:utf-8 -*-
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ClusterResampler:
|
| 7 |
+
def __init__(self, cluster_path: str) -> None:
|
| 8 |
+
idx2prob = []
|
| 9 |
+
with open(cluster_path, 'r') as fin:
|
| 10 |
+
for line in fin:
|
| 11 |
+
cluster_n_member = int(line.strip().split('\t')[-1])
|
| 12 |
+
idx2prob.append(1 / cluster_n_member)
|
| 13 |
+
total = sum(idx2prob)
|
| 14 |
+
idx2prob = [p / total for p in idx2prob]
|
| 15 |
+
self.idx2prob = np.array(idx2prob)
|
| 16 |
+
|
| 17 |
+
def __call__(self, n_sample:int, replace: bool=False):
|
| 18 |
+
idxs = np.random.choice(len(self.idx2prob), size=n_sample, replace=replace, p=self.idx2prob)
|
| 19 |
+
return idxs
|
env.yaml
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: PepGLAD
|
| 2 |
+
channels:
|
| 3 |
+
- pytorch
|
| 4 |
+
- nvidia
|
| 5 |
+
- bioconda
|
| 6 |
+
- pyg
|
| 7 |
+
- salilab
|
| 8 |
+
- conda-forge
|
| 9 |
+
- defaults
|
| 10 |
+
dependencies:
|
| 11 |
+
- python=3.9
|
| 12 |
+
- pytorch::pytorch=1.13.1
|
| 13 |
+
- pytorch::pytorch-cuda=11.7
|
| 14 |
+
- nvidia::cudatoolkit=11.7.0
|
| 15 |
+
- pyg::pytorch-scatter
|
| 16 |
+
- mkl=2024.0.0
|
| 17 |
+
- salilab::dssp
|
| 18 |
+
- anaconda::libboost=1.73.0
|
| 19 |
+
- mmseqs2
|
| 20 |
+
- openmm=8.0.0
|
| 21 |
+
- pdbfixer
|
| 22 |
+
- pip
|
| 23 |
+
- pip:
|
| 24 |
+
- biopython==1.80
|
| 25 |
+
- rdkit-pypi==2022.3.5
|
| 26 |
+
- ray
|
| 27 |
+
- sympy
|
| 28 |
+
- scipy
|
| 29 |
+
- freesasa
|
| 30 |
+
- tensorboard
|
| 31 |
+
- pyyaml
|
| 32 |
+
- tqdm
|
evaluation/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
# -*- coding:utf-8 -*-
|
| 3 |
+
|
evaluation/dG/RosettaFastRelaxUtil.xml
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<ROSETTASCRIPTS>
|
| 2 |
+
<SCOREFXNS>
|
| 3 |
+
<ScoreFunction name="sfxn_soft" weights="beta_nov16_soft" />
|
| 4 |
+
<ScoreFunction name="sfxn" weights="beta_nov16" />
|
| 5 |
+
<ScoreFunction name="sfxn_relax" weights="beta_nov16" >
|
| 6 |
+
<Reweight scoretype="arg_cation_pi" weight="3" />
|
| 7 |
+
<Reweight scoretype="approximate_buried_unsat_penalty" weight="5" />
|
| 8 |
+
<Set approximate_buried_unsat_penalty_burial_atomic_depth="3.5" />
|
| 9 |
+
<Set approximate_buried_unsat_penalty_hbond_energy_threshold="-0.5" />
|
| 10 |
+
</ScoreFunction>
|
| 11 |
+
<ScoreFunction name="sfxn_softish" weights="beta_nov16" >
|
| 12 |
+
<Reweight scoretype="fa_rep" weight="0.15" />
|
| 13 |
+
</ScoreFunction>
|
| 14 |
+
<ScoreFunction name="sfxn_fa_atr" weights="empty" >
|
| 15 |
+
<Reweight scoretype="fa_atr" weight="1" />
|
| 16 |
+
</ScoreFunction>
|
| 17 |
+
<ScoreFunction name="vdw_sol" weights="empty" >
|
| 18 |
+
<Reweight scoretype="fa_atr" weight="1.0" />
|
| 19 |
+
<Reweight scoretype="fa_rep" weight="0.55" />
|
| 20 |
+
<Reweight scoretype="fa_sol" weight="1.0" />
|
| 21 |
+
</ScoreFunction>
|
| 22 |
+
|
| 23 |
+
</SCOREFXNS>
|
| 24 |
+
<RESIDUE_SELECTORS>
|
| 25 |
+
<Chain name="chainA" chains="A"/>
|
| 26 |
+
<Chain name="chainB" chains="B"/>
|
| 27 |
+
<Neighborhood name="interface_chA" selector="chainB" distance="8.0" />
|
| 28 |
+
<Neighborhood name="interface_chB" selector="chainA" distance="8.0" />
|
| 29 |
+
<And name="AB_interface" selectors="interface_chA,interface_chB" />
|
| 30 |
+
<Not name="Not_interface" selector="AB_interface" />
|
| 31 |
+
<And name="actual_interface_chB" selectors="AB_interface,chainB" />
|
| 32 |
+
<And name="not_interface_chB" selectors="Not_interface,chainB" />
|
| 33 |
+
|
| 34 |
+
<ResidueName name="apolar" residue_name3="ALA,CYS,PHE,ILE,LEU,MET,THR,PRO,VAL,TRP,TYR" />
|
| 35 |
+
<Not name="polar" selector="apolar" />
|
| 36 |
+
|
| 37 |
+
<True name="all" />
|
| 38 |
+
|
| 39 |
+
<ResidueName name="pro_and_gly_positions" residue_name3="PRO,GLY" />
|
| 40 |
+
|
| 41 |
+
<ResiduePDBInfoHasLabel name="HOTSPOT_res" property="HOTSPOT" />
|
| 42 |
+
</RESIDUE_SELECTORS>
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
<RESIDUE_SELECTORS>
|
| 46 |
+
<!-- Layer Design -->
|
| 47 |
+
<Layer name="surface" select_core="false" select_boundary="false" select_surface="true" use_sidechain_neighbors="true"/>
|
| 48 |
+
<Layer name="boundary" select_core="false" select_boundary="true" select_surface="false" use_sidechain_neighbors="true"/>
|
| 49 |
+
<Layer name="core" select_core="true" select_boundary="false" select_surface="false" use_sidechain_neighbors="true"/>
|
| 50 |
+
<SecondaryStructure name="sheet" overlap="0" minH="3" minE="2" include_terminal_loops="false" use_dssp="true" ss="E"/>
|
| 51 |
+
<SecondaryStructure name="entire_loop" overlap="0" minH="3" minE="2" include_terminal_loops="true" use_dssp="true" ss="L"/>
|
| 52 |
+
<SecondaryStructure name="entire_helix" overlap="0" minH="3" minE="2" include_terminal_loops="false" use_dssp="true" ss="H"/>
|
| 53 |
+
<And name="helix_cap" selectors="entire_loop">
|
| 54 |
+
<PrimarySequenceNeighborhood lower="1" upper="0" selector="entire_helix"/>
|
| 55 |
+
</And>
|
| 56 |
+
<And name="helix_start" selectors="entire_helix">
|
| 57 |
+
<PrimarySequenceNeighborhood lower="0" upper="1" selector="helix_cap"/>
|
| 58 |
+
</And>
|
| 59 |
+
<And name="helix" selectors="entire_helix">
|
| 60 |
+
<Not selector="helix_start"/>
|
| 61 |
+
</And>
|
| 62 |
+
<And name="loop" selectors="entire_loop">
|
| 63 |
+
<Not selector="helix_cap"/>
|
| 64 |
+
</And>
|
| 65 |
+
|
| 66 |
+
</RESIDUE_SELECTORS>
|
| 67 |
+
|
| 68 |
+
<TASKOPERATIONS>
|
| 69 |
+
<DesignRestrictions name="layer_design_no_core_polars">
|
| 70 |
+
<Action selector_logic="surface AND helix_start" aas="DEHKPQR"/>
|
| 71 |
+
<Action selector_logic="surface AND helix" aas="EHKQR"/>
|
| 72 |
+
<Action selector_logic="surface AND sheet" aas="EHKNQRST"/>
|
| 73 |
+
<Action selector_logic="surface AND loop" aas="DEGHKNPQRST"/>
|
| 74 |
+
<Action selector_logic="boundary AND helix_start" aas="ADEHIKLNPQRSTVWY"/>
|
| 75 |
+
<Action selector_logic="boundary AND helix" aas="ADEHIKLNQRSTVWY"/>
|
| 76 |
+
<Action selector_logic="boundary AND sheet" aas="DEFHIKLNQRSTVWY"/>
|
| 77 |
+
<Action selector_logic="boundary AND loop" aas="ADEFGHIKLNPQRSTVWY"/>
|
| 78 |
+
<Action selector_logic="core AND helix_start" aas="AFILMPVWY"/>
|
| 79 |
+
<Action selector_logic="core AND helix" aas="AFILVWY"/>
|
| 80 |
+
<Action selector_logic="core AND sheet" aas="FILVWY"/>
|
| 81 |
+
<Action selector_logic="core AND loop" aas="AFGILPVWY"/>
|
| 82 |
+
<Action selector_logic="helix_cap" aas="DNST"/>
|
| 83 |
+
</DesignRestrictions>
|
| 84 |
+
</TASKOPERATIONS>
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
<TASKOPERATIONS>
|
| 88 |
+
<ProteinProteinInterfaceUpweighter name="upweight_interface" interface_weight="3" />
|
| 89 |
+
<ProteinInterfaceDesign name="pack_long" design_chain1="0" design_chain2="0" jump="1" interface_distance_cutoff="15"/>
|
| 90 |
+
<InitializeFromCommandline name="init" />
|
| 91 |
+
<IncludeCurrent name="current" />
|
| 92 |
+
<LimitAromaChi2 name="limitchi2" chi2max="110" chi2min="70" include_trp="True" />
|
| 93 |
+
<ExtraRotamersGeneric name="ex1_ex2" ex1="1" ex2="1" />
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
<OperateOnResidueSubset name="restrict_target_not_interface" selector="not_interface_chB">
|
| 97 |
+
<PreventRepackingRLT/>
|
| 98 |
+
</OperateOnResidueSubset>
|
| 99 |
+
<OperateOnResidueSubset name="restrict2repacking" selector="all">
|
| 100 |
+
<RestrictToRepackingRLT/>
|
| 101 |
+
</OperateOnResidueSubset>
|
| 102 |
+
<OperateOnResidueSubset name="restrict_to_interface" selector="Not_interface">
|
| 103 |
+
<PreventRepackingRLT/>
|
| 104 |
+
</OperateOnResidueSubset>
|
| 105 |
+
<OperateOnResidueSubset name="restrict_target2repacking" selector="chainB">
|
| 106 |
+
<RestrictToRepackingRLT/>
|
| 107 |
+
</OperateOnResidueSubset>
|
| 108 |
+
<OperateOnResidueSubset name="restrict_hotspots2repacking" selector="HOTSPOT_res">
|
| 109 |
+
<RestrictToRepackingRLT/>
|
| 110 |
+
</OperateOnResidueSubset>
|
| 111 |
+
|
| 112 |
+
<DisallowIfNonnative name="disallow_GLY" resnum="0" disallow_aas="G" />
|
| 113 |
+
<DisallowIfNonnative name="disallow_PRO" resnum="0" disallow_aas="P" />
|
| 114 |
+
<SelectBySASA name="PR_monomer_core" mode="sc" state="monomer" probe_radius="2.2" core_asa="10" surface_asa="10" core="0" boundary="1" surface="1" verbose="0" />
|
| 115 |
+
|
| 116 |
+
<OperateOnResidueSubset name="restrict_PRO_GLY" selector="pro_and_gly_positions">
|
| 117 |
+
<PreventRepackingRLT/>
|
| 118 |
+
</OperateOnResidueSubset>
|
| 119 |
+
|
| 120 |
+
PruneBadRotamers name="prune_bad_rotamers" probability_cut="0.01" />
|
| 121 |
+
|
| 122 |
+
</TASKOPERATIONS>
|
| 123 |
+
<MOVERS>
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
<SwitchChainOrder name="chain1onlypre" chain_order="1" />
|
| 127 |
+
<ScoreMover name="scorepose" scorefxn="sfxn" verbose="false" />
|
| 128 |
+
<ParsedProtocol name="chain1only">
|
| 129 |
+
<Add mover="chain1onlypre" />
|
| 130 |
+
<Add mover="scorepose" />
|
| 131 |
+
</ParsedProtocol>
|
| 132 |
+
<TaskAwareMinMover name="min" scorefxn="sfxn" bb="0" chi="1" task_operations="pack_long" />
|
| 133 |
+
|
| 134 |
+
<DeleteRegionMover name="delete_polar" residue_selector="polar" rechain="false" />
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
</MOVERS>
|
| 138 |
+
<FILTERS>
|
| 139 |
+
|
| 140 |
+
<Time name="timed"/>
|
| 141 |
+
|
| 142 |
+
<Sasa name="interface_buried_sasa" confidence="0" />
|
| 143 |
+
<Ddg name="ddg" threshold="0" jump="1" repeats="1" repack="1" relax_mover="min" confidence="0" scorefxn="sfxn" />
|
| 144 |
+
<Ddg name="ddg_norepack" threshold="0" jump="1" repeats="1" repack="0" relax_mover="min" confidence="0" scorefxn="sfxn" />
|
| 145 |
+
<ShapeComplementarity name="interface_sc" verbose="0" min_sc="0.55" write_int_area="1" write_median_dist="1" jump="1" confidence="0"/>
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
### score function monomer terms
|
| 149 |
+
<ScoreType name="total_score_MBF" scorefxn="sfxn" score_type="total_score" threshold="0" confidence="0" />
|
| 150 |
+
<MoveBeforeFilter name="total_score_monomer" mover="chain1only" filter="total_score_MBF" confidence="0" />
|
| 151 |
+
<ResidueCount name="res_count_MBF" max_residue_count="9999" confidence="0"/>
|
| 152 |
+
<MoveBeforeFilter name="res_count_monomer" mover="chain1only" filter="res_count_MBF" confidence="0" />
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
<CalculatorFilter name="score_per_res" equation="total_score_monomer / res" threshold="-3.5" confidence="0">
|
| 156 |
+
<Var name="total_score_monomer" filter="total_score_monomer"/>
|
| 157 |
+
<Var name="res" filter="res_count_monomer"/>
|
| 158 |
+
</CalculatorFilter>
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
<InterfaceHydrophobicResidueContacts name="hydrophobic_residue_contacts" target_selector="chainB" binder_selector="chainA" scorefxn="sfxn_soft" confidence="0"/>
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
<Ddg name="ddg_hydrophobic_pre" threshold="-10" jump="1" repeats="1" repack="0" confidence="0" scorefxn="vdw_sol" />
|
| 165 |
+
<MoveBeforeFilter name="ddg_hydrophobic" mover="delete_polar" filter="ddg_hydrophobic_pre" confidence="0"/>
|
| 166 |
+
|
| 167 |
+
<ContactMolecularSurface name="contact_molecular_surface" distance_weight="0.5" target_selector="chainA" binder_selector="chainB" confidence="0" />
|
| 168 |
+
|
| 169 |
+
</FILTERS>
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
<MOVERS>
|
| 173 |
+
|
| 174 |
+
<FastRelax name="FastRelax" scorefxn="sfxn_relax" repeats="1" batch="false" ramp_down_constraints="false" cartesian="false" bondangle="false" bondlength="false" min_type="dfpmin_armijo_nonmonotone" task_operations="current,ex1_ex2,restrict_target_not_interface,limitchi2" >
|
| 175 |
+
<MoveMap name="MM" >
|
| 176 |
+
<Chain number="1" chi="true" bb="true" />
|
| 177 |
+
<Chain number="2" chi="true" bb="false" />
|
| 178 |
+
<Jump number="1" setting="true" />
|
| 179 |
+
</MoveMap>
|
| 180 |
+
</FastRelax>
|
| 181 |
+
|
| 182 |
+
</MOVERS>
|
| 183 |
+
<APPLY_TO_POSE>
|
| 184 |
+
</APPLY_TO_POSE>
|
| 185 |
+
|
| 186 |
+
<PROTOCOLS>
|
| 187 |
+
</PROTOCOLS>
|
| 188 |
+
|
| 189 |
+
<OUTPUT/>
|
| 190 |
+
</ROSETTASCRIPTS>
|
evaluation/dG/base.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
# -*- coding:utf-8 -*-
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
import json
|
| 6 |
+
from typing import Optional, Tuple, List
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
|
| 9 |
+
from .energy import pyrosetta_fastrelax, pyrosetta_interface_energy, rfdiff_refine
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
|
| 13 |
+
class RelaxTask:
|
| 14 |
+
in_path: str
|
| 15 |
+
current_path: str
|
| 16 |
+
info: dict
|
| 17 |
+
status: str
|
| 18 |
+
rec_chain: str
|
| 19 |
+
pep_chain: str
|
| 20 |
+
rfdiff_relax: bool = False
|
| 21 |
+
dG: Optional[float] = None
|
| 22 |
+
|
| 23 |
+
def set_dG(self, dG):
|
| 24 |
+
self.dG = dG
|
| 25 |
+
|
| 26 |
+
def get_in_path_with_tag(self, tag):
|
| 27 |
+
name, ext = os.path.splitext(self.in_path)
|
| 28 |
+
new_path = f'{name}_{tag}{ext}'
|
| 29 |
+
return new_path
|
| 30 |
+
|
| 31 |
+
def set_current_path_tag(self, tag):
|
| 32 |
+
new_path = self.get_in_path_with_tag(tag)
|
| 33 |
+
self.current_path = new_path
|
| 34 |
+
return new_path
|
| 35 |
+
|
| 36 |
+
def check_current_path_exists(self):
|
| 37 |
+
ok = os.path.exists(self.current_path)
|
| 38 |
+
if not ok:
|
| 39 |
+
self.mark_failure()
|
| 40 |
+
if os.path.getsize(self.current_path) == 0:
|
| 41 |
+
ok = False
|
| 42 |
+
self.mark_failure()
|
| 43 |
+
os.unlink(self.current_path)
|
| 44 |
+
return ok
|
| 45 |
+
|
| 46 |
+
def update_if_finished(self, tag):
|
| 47 |
+
out_path = self.get_in_path_with_tag(tag)
|
| 48 |
+
if os.path.exists(out_path) and os.path.getsize(out_path) > 0:
|
| 49 |
+
# print('Already finished', out_path)
|
| 50 |
+
self.set_current_path_tag(tag)
|
| 51 |
+
self.mark_success()
|
| 52 |
+
return True
|
| 53 |
+
return False
|
| 54 |
+
|
| 55 |
+
def can_proceed(self):
|
| 56 |
+
self.check_current_path_exists()
|
| 57 |
+
return self.status != 'failed'
|
| 58 |
+
|
| 59 |
+
def mark_success(self):
|
| 60 |
+
self.status = 'success'
|
| 61 |
+
|
| 62 |
+
def mark_failure(self):
|
| 63 |
+
self.status = 'failed'
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class TaskScanner:
|
| 67 |
+
|
| 68 |
+
def __init__(self, results, n_sample, rfdiff_relax):
|
| 69 |
+
super().__init__()
|
| 70 |
+
self.results = results
|
| 71 |
+
self.n_sample = n_sample
|
| 72 |
+
self.rfdiff_relax = rfdiff_relax
|
| 73 |
+
self.visited = set()
|
| 74 |
+
|
| 75 |
+
def scan(self) -> List[RelaxTask]:
|
| 76 |
+
tasks = []
|
| 77 |
+
root_dir = os.path.dirname(self.results)
|
| 78 |
+
with open(self.results, 'r') as fin:
|
| 79 |
+
lines = fin.readlines()
|
| 80 |
+
for line in lines:
|
| 81 |
+
item = json.loads(line)
|
| 82 |
+
if item['number'] >= self.n_sample:
|
| 83 |
+
continue
|
| 84 |
+
_id = f"{item['id']}_{item['number']}"
|
| 85 |
+
if _id in self.visited:
|
| 86 |
+
continue
|
| 87 |
+
gen_pdb = os.path.split(item['gen_pdb'])[-1]
|
| 88 |
+
# subdir = gen_pdb.split('_')[0]
|
| 89 |
+
subdir = '_'.join(gen_pdb.split('_')[:-2])
|
| 90 |
+
gen_pdb = os.path.join(root_dir, 'candidates', subdir, gen_pdb)
|
| 91 |
+
tasks.append(RelaxTask(
|
| 92 |
+
in_path=gen_pdb,
|
| 93 |
+
current_path=gen_pdb,
|
| 94 |
+
info=item,
|
| 95 |
+
status='created',
|
| 96 |
+
rec_chain=item['rec_chain'],
|
| 97 |
+
pep_chain=item['lig_chain'],
|
| 98 |
+
rfdiff_relax=self.rfdiff_relax
|
| 99 |
+
))
|
| 100 |
+
self.visited.add(_id)
|
| 101 |
+
return tasks
|
| 102 |
+
|
| 103 |
+
def scan_dataset(self) -> List[RelaxTask]:
|
| 104 |
+
tasks = []
|
| 105 |
+
root_dir = os.path.dirname(self.results)
|
| 106 |
+
with open(self.results, 'r') as fin: # index file of datasets
|
| 107 |
+
lines = fin.readlines()
|
| 108 |
+
for line in lines:
|
| 109 |
+
line = line.strip('\n').split('\t')
|
| 110 |
+
_id = line[0]
|
| 111 |
+
item = {
|
| 112 |
+
'id': _id,
|
| 113 |
+
'number': 0
|
| 114 |
+
}
|
| 115 |
+
pdb_path = os.path.join(root_dir, 'pdbs', _id + '.pdb')
|
| 116 |
+
tasks.append(RelaxTask(
|
| 117 |
+
in_path=pdb_path,
|
| 118 |
+
current_path=pdb_path,
|
| 119 |
+
info=item,
|
| 120 |
+
status='created',
|
| 121 |
+
rec_chain=line[7],
|
| 122 |
+
pep_chain=line[8],
|
| 123 |
+
rfdiff_relax=self.rfdiff_relax
|
| 124 |
+
))
|
| 125 |
+
self.visited.add(_id)
|
| 126 |
+
return tasks
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def run_pyrosetta(task: RelaxTask):
|
| 130 |
+
if not task.can_proceed() :
|
| 131 |
+
return task
|
| 132 |
+
# if task.update_if_finished('rosetta'):
|
| 133 |
+
# return task
|
| 134 |
+
|
| 135 |
+
out_path = task.set_current_path_tag('rosetta')
|
| 136 |
+
try:
|
| 137 |
+
if task.rfdiff_relax:
|
| 138 |
+
rfdiff_refine(task.in_path, out_path, task.pep_chain)
|
| 139 |
+
else:
|
| 140 |
+
pyrosetta_fastrelax(task.in_path, out_path, task.pep_chain, rfdiff_config=task.rfdiff_relax)
|
| 141 |
+
dG = pyrosetta_interface_energy(out_path, [task.rec_chain], [task.pep_chain])
|
| 142 |
+
task.mark_success()
|
| 143 |
+
except Exception as e:
|
| 144 |
+
print(e)
|
| 145 |
+
dG = 1e10
|
| 146 |
+
task.mark_failure()
|
| 147 |
+
task.set_dG(dG)
|
| 148 |
+
return task
|
evaluation/dG/energy.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
# -*- coding:utf-8 -*-
|
| 3 |
+
'''
|
| 4 |
+
From https://github.com/luost26/diffab/blob/main/diffab/tools/relax/pyrosetta_relaxer.py
|
| 5 |
+
'''
|
| 6 |
+
import os
|
| 7 |
+
import time
|
| 8 |
+
import pyrosetta
|
| 9 |
+
from pyrosetta.rosetta.protocols.analysis import InterfaceAnalyzerMover
|
| 10 |
+
# for fast relax
|
| 11 |
+
from pyrosetta.rosetta import protocols
|
| 12 |
+
from pyrosetta.rosetta.protocols.relax import FastRelax
|
| 13 |
+
from pyrosetta.rosetta.core.pack.task import TaskFactory
|
| 14 |
+
from pyrosetta.rosetta.core.pack.task import operation
|
| 15 |
+
from pyrosetta.rosetta.core.select import residue_selector as selections
|
| 16 |
+
from pyrosetta.rosetta.core.select.movemap import MoveMapFactory, move_map_action
|
| 17 |
+
from pyrosetta.rosetta.core.scoring import ScoreType
|
| 18 |
+
|
| 19 |
+
from Bio.PDB import PDBIO, PDBParser
|
| 20 |
+
from Bio.PDB.Structure import Structure as BStructure
|
| 21 |
+
from Bio.PDB.Model import Model as BModel
|
| 22 |
+
from Bio.PDB.Chain import Chain as BChain
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
pyrosetta.init(' '.join([
|
| 26 |
+
'-mute', 'all',
|
| 27 |
+
'-use_input_sc',
|
| 28 |
+
'-ignore_unrecognized_res',
|
| 29 |
+
'-ignore_zero_occupancy', 'false',
|
| 30 |
+
'-load_PDB_components', 'false',
|
| 31 |
+
'-relax:default_repeats', '2',
|
| 32 |
+
'-no_fconfig',
|
| 33 |
+
# below are from https://github.com/nrbennet/dl_binder_design/blob/main/mpnn_fr/dl_interface_design.py
|
| 34 |
+
# '-beta_nov16',
|
| 35 |
+
'-use_terminal_residues', 'true',
|
| 36 |
+
'-in:file:silent_struct_type', 'binary'
|
| 37 |
+
]))
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def current_milli_time():
|
| 41 |
+
return round(time.time() * 1000)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def get_scorefxn(scorefxn_name:str):
|
| 45 |
+
"""
|
| 46 |
+
Gets the scorefxn with appropriate corrections.
|
| 47 |
+
Taken from: https://gist.github.com/matteoferla/b33585f3aeab58b8424581279e032550
|
| 48 |
+
"""
|
| 49 |
+
import pyrosetta
|
| 50 |
+
|
| 51 |
+
corrections = {
|
| 52 |
+
'beta_july15': False,
|
| 53 |
+
'beta_nov16': False,
|
| 54 |
+
'gen_potential': False,
|
| 55 |
+
'restore_talaris_behavior': False,
|
| 56 |
+
}
|
| 57 |
+
if 'beta_july15' in scorefxn_name or 'beta_nov15' in scorefxn_name:
|
| 58 |
+
# beta_july15 is ref2015
|
| 59 |
+
corrections['beta_july15'] = True
|
| 60 |
+
elif 'beta_nov16' in scorefxn_name:
|
| 61 |
+
corrections['beta_nov16'] = True
|
| 62 |
+
elif 'genpot' in scorefxn_name:
|
| 63 |
+
corrections['gen_potential'] = True
|
| 64 |
+
pyrosetta.rosetta.basic.options.set_boolean_option('corrections:beta_july15', True)
|
| 65 |
+
elif 'talaris' in scorefxn_name: #2013 and 2014
|
| 66 |
+
corrections['restore_talaris_behavior'] = True
|
| 67 |
+
else:
|
| 68 |
+
pass
|
| 69 |
+
for corr, value in corrections.items():
|
| 70 |
+
pyrosetta.rosetta.basic.options.set_boolean_option(f'corrections:{corr}', value)
|
| 71 |
+
return pyrosetta.create_score_function(scorefxn_name)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class RelaxRegion(object):
|
| 75 |
+
|
| 76 |
+
def __init__(self, scorefxn='ref2015', max_iter=1000, subset='nbrs', move_bb=True, rfdiff_config=False):
|
| 77 |
+
super().__init__()
|
| 78 |
+
|
| 79 |
+
if rfdiff_config:
|
| 80 |
+
self.scorefxn = get_scorefxn('beta_nov16')
|
| 81 |
+
xml = os.path.join(os.path.dirname(__file__), 'RosettaFastRelaxUtil.xml')
|
| 82 |
+
objs = protocols.rosetta_scripts.XmlObjects.create_from_file(xml)
|
| 83 |
+
self.fast_relax = objs.get_mover('FastRelax')
|
| 84 |
+
self.fast_relax.max_iter(max_iter)
|
| 85 |
+
else:
|
| 86 |
+
self.scorefxn = get_scorefxn(scorefxn)
|
| 87 |
+
self.fast_relax = FastRelax()
|
| 88 |
+
self.fast_relax.set_scorefxn(self.scorefxn)
|
| 89 |
+
self.fast_relax.max_iter(max_iter)
|
| 90 |
+
|
| 91 |
+
assert subset in ('all', 'target', 'nbrs')
|
| 92 |
+
self.subset = subset
|
| 93 |
+
self.move_bb = move_bb
|
| 94 |
+
|
| 95 |
+
def __call__(self, pdb_path, ligand_chains): # flexible_residue_first, flexible_residue_last):
|
| 96 |
+
pose = pyrosetta.pose_from_pdb(pdb_path)
|
| 97 |
+
start_t = current_milli_time()
|
| 98 |
+
original_pose = pose.clone()
|
| 99 |
+
|
| 100 |
+
tf = TaskFactory()
|
| 101 |
+
tf.push_back(operation.InitializeFromCommandline())
|
| 102 |
+
tf.push_back(operation.RestrictToRepacking()) # Only allow residues to repack. No design at any position.
|
| 103 |
+
|
| 104 |
+
# Create selector for the region to be relaxed
|
| 105 |
+
# Turn off design and repacking on irrelevant positions
|
| 106 |
+
# if flexible_residue_first[-1] == ' ':
|
| 107 |
+
# flexible_residue_first = flexible_residue_first[:-1]
|
| 108 |
+
# if flexible_residue_last[-1] == ' ':
|
| 109 |
+
# flexible_residue_last = flexible_residue_last[:-1]
|
| 110 |
+
if self.subset != 'all':
|
| 111 |
+
chain_selectors = [selections.ChainSelector(chain) for chain in ligand_chains]
|
| 112 |
+
if len(chain_selectors) == 1:
|
| 113 |
+
gen_selector = chain_selectors[0]
|
| 114 |
+
else:
|
| 115 |
+
gen_selector = selections.OrResidueSelector(chain_selectors[0], chain_selectors[1])
|
| 116 |
+
for selector in chain_selectors[2:]:
|
| 117 |
+
gen_selector = selections.OrResidueSelector(gen_selector, selector)
|
| 118 |
+
# gen_selector = selections.ChainSelector(pep_chain)
|
| 119 |
+
# gen_selector = selections.ResidueIndexSelector()
|
| 120 |
+
# gen_selector.set_index_range(
|
| 121 |
+
# pose.pdb_info().pdb2pose(*flexible_residue_first),
|
| 122 |
+
# pose.pdb_info().pdb2pose(*flexible_residue_last),
|
| 123 |
+
# )
|
| 124 |
+
nbr_selector = selections.NeighborhoodResidueSelector()
|
| 125 |
+
nbr_selector.set_focus_selector(gen_selector)
|
| 126 |
+
nbr_selector.set_include_focus_in_subset(True)
|
| 127 |
+
|
| 128 |
+
if self.subset == 'nbrs':
|
| 129 |
+
subset_selector = nbr_selector
|
| 130 |
+
elif self.subset == 'target':
|
| 131 |
+
subset_selector = gen_selector
|
| 132 |
+
|
| 133 |
+
prevent_repacking_rlt = operation.PreventRepackingRLT()
|
| 134 |
+
prevent_subset_repacking = operation.OperateOnResidueSubset(
|
| 135 |
+
prevent_repacking_rlt,
|
| 136 |
+
subset_selector,
|
| 137 |
+
flip_subset=True,
|
| 138 |
+
)
|
| 139 |
+
tf.push_back(prevent_subset_repacking)
|
| 140 |
+
|
| 141 |
+
scorefxn = self.scorefxn
|
| 142 |
+
fr = self.fast_relax
|
| 143 |
+
|
| 144 |
+
pose = original_pose.clone()
|
| 145 |
+
# pos_list = pyrosetta.rosetta.utility.vector1_unsigned_long()
|
| 146 |
+
# for pos in range(pose.pdb_info().pdb2pose(*flexible_residue_first), pose.pdb_info().pdb2pose(*flexible_residue_last)+1):
|
| 147 |
+
# pos_list.append(pos)
|
| 148 |
+
# basic_idealize(pose, pos_list, scorefxn, fast=True)
|
| 149 |
+
|
| 150 |
+
mmf = MoveMapFactory()
|
| 151 |
+
if self.move_bb:
|
| 152 |
+
mmf.add_bb_action(move_map_action.mm_enable, gen_selector)
|
| 153 |
+
mmf.add_chi_action(move_map_action.mm_enable, subset_selector)
|
| 154 |
+
mm = mmf.create_movemap_from_pose(pose)
|
| 155 |
+
|
| 156 |
+
fr.set_movemap(mm)
|
| 157 |
+
fr.set_task_factory(tf)
|
| 158 |
+
fr.apply(pose)
|
| 159 |
+
|
| 160 |
+
e_before = scorefxn(original_pose)
|
| 161 |
+
e_relax = scorefxn(pose)
|
| 162 |
+
# print('\n\n[Finished in %.2f secs]' % ((current_milli_time() - start_t) / 1000))
|
| 163 |
+
# print(' > Energy (before): %.4f' % scorefxn(original_pose))
|
| 164 |
+
# print(' > Energy (optimized): %.4f' % scorefxn(pose))
|
| 165 |
+
return pose, e_before, e_relax
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def pyrosetta_fastrelax(pdb_path, out_path, pep_chain, rfdiff_config=False):
|
| 169 |
+
minimizer = RelaxRegion(rfdiff_config=rfdiff_config)
|
| 170 |
+
pose_min, _, _ = minimizer(
|
| 171 |
+
pdb_path=pdb_path,
|
| 172 |
+
ligand_chains=[pep_chain]
|
| 173 |
+
)
|
| 174 |
+
pose_min.dump_pdb(out_path)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def _rename_chain(pdb_path, out_path, src_pep_chain, tgt_pep_chain, tgt_rec_chain):
|
| 178 |
+
|
| 179 |
+
io = PDBIO()
|
| 180 |
+
parser = PDBParser()
|
| 181 |
+
|
| 182 |
+
structure = parser.get_structure('anonymous', pdb_path)
|
| 183 |
+
|
| 184 |
+
new_mapping = {}
|
| 185 |
+
pep_chain, rec_chain = BChain(id=tgt_pep_chain), BChain(id=tgt_rec_chain)
|
| 186 |
+
|
| 187 |
+
for model in structure:
|
| 188 |
+
for chain in model:
|
| 189 |
+
if chain.get_id() == src_pep_chain:
|
| 190 |
+
new_mapping[src_pep_chain] = tgt_pep_chain
|
| 191 |
+
for res in chain:
|
| 192 |
+
pep_chain.add(res.copy())
|
| 193 |
+
else:
|
| 194 |
+
new_mapping[chain.get_id()] = tgt_rec_chain
|
| 195 |
+
for res in chain:
|
| 196 |
+
rec_chain.add(res.copy())
|
| 197 |
+
|
| 198 |
+
structure = BStructure(id=structure.get_id())
|
| 199 |
+
model = BModel(id=0)
|
| 200 |
+
model.add(pep_chain)
|
| 201 |
+
model.add(rec_chain)
|
| 202 |
+
structure.add(model)
|
| 203 |
+
|
| 204 |
+
io.set_structure(structure)
|
| 205 |
+
io.save(out_path)
|
| 206 |
+
|
| 207 |
+
return new_mapping
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def rfdiff_refine(pdb_path, out_path, pep_chain):
|
| 211 |
+
# rename peptide chain to A and receptor to B
|
| 212 |
+
new_mapping = _rename_chain(pdb_path, out_path, pep_chain, 'A', 'B')
|
| 213 |
+
|
| 214 |
+
# force fields from RFDiffusion
|
| 215 |
+
get_scorefxn('beta_nov16')
|
| 216 |
+
xml = os.path.join(os.path.dirname(__file__), 'RosettaFastRelaxUtil.xml')
|
| 217 |
+
objs = protocols.rosetta_scripts.XmlObjects.create_from_file(xml)
|
| 218 |
+
fastrelax = objs.get_mover('FastRelax')
|
| 219 |
+
pose = pyrosetta.pose_from_pdb(out_path)
|
| 220 |
+
fastrelax.apply(pose)
|
| 221 |
+
pose.dump_pdb(out_path)
|
| 222 |
+
|
| 223 |
+
# get back to original chain ids
|
| 224 |
+
reverse_mapping = { new_mapping[key]: key for key in new_mapping }
|
| 225 |
+
_rename_chain(out_path, out_path, 'A', reverse_mapping['A'], reverse_mapping['B'])
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def pyrosetta_interface_energy(pdb_path, receptor_chains, ligand_chains, return_dict=False):
|
| 229 |
+
pose = pyrosetta.pose_from_pdb(pdb_path)
|
| 230 |
+
interface = ''.join(ligand_chains) + '_' + ''.join(receptor_chains)
|
| 231 |
+
mover = InterfaceAnalyzerMover(interface)
|
| 232 |
+
mover.set_pack_separated(True)
|
| 233 |
+
mover.apply(pose)
|
| 234 |
+
if return_dict:
|
| 235 |
+
return pose.scores
|
| 236 |
+
return pose.scores['dG_separated']
|
evaluation/dG/openmm_relaxer.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import io
|
| 4 |
+
import logging
|
| 5 |
+
import pdbfixer
|
| 6 |
+
import openmm
|
| 7 |
+
from openmm import app as openmm_app
|
| 8 |
+
from openmm import unit
|
| 9 |
+
ENERGY = unit.kilocalories_per_mole
|
| 10 |
+
LENGTH = unit.angstroms
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ForceFieldMinimizer(object):
|
| 14 |
+
|
| 15 |
+
def __init__(self, stiffness=10.0, max_iterations=0, tolerance=2.39*unit.kilocalories_per_mole, platform='CUDA'):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.stiffness = stiffness
|
| 18 |
+
self.max_iterations = max_iterations
|
| 19 |
+
self.tolerance = tolerance
|
| 20 |
+
assert platform in ('CUDA', 'CPU')
|
| 21 |
+
self.platform = platform
|
| 22 |
+
|
| 23 |
+
def _fix(self, pdb_str):
|
| 24 |
+
fixer = pdbfixer.PDBFixer(pdbfile=io.StringIO(pdb_str))
|
| 25 |
+
fixer.findNonstandardResidues()
|
| 26 |
+
fixer.replaceNonstandardResidues()
|
| 27 |
+
|
| 28 |
+
fixer.findMissingResidues()
|
| 29 |
+
fixer.findMissingAtoms()
|
| 30 |
+
fixer.addMissingAtoms(seed=0)
|
| 31 |
+
fixer.addMissingHydrogens()
|
| 32 |
+
|
| 33 |
+
out_handle = io.StringIO()
|
| 34 |
+
openmm_app.PDBFile.writeFile(fixer.topology, fixer.positions, out_handle, keepIds=True)
|
| 35 |
+
return out_handle.getvalue()
|
| 36 |
+
|
| 37 |
+
def _get_pdb_string(self, topology, positions):
|
| 38 |
+
with io.StringIO() as f:
|
| 39 |
+
openmm_app.PDBFile.writeFile(topology, positions, f, keepIds=True)
|
| 40 |
+
return f.getvalue()
|
| 41 |
+
|
| 42 |
+
def _minimize(self, pdb_str):
|
| 43 |
+
pdb = openmm_app.PDBFile(io.StringIO(pdb_str))
|
| 44 |
+
|
| 45 |
+
force_field = openmm_app.ForceField("charmm36.xml") # referring to http://docs.openmm.org/latest/userguide/application/02_running_sims.html
|
| 46 |
+
constraints = openmm_app.HBonds
|
| 47 |
+
system = force_field.createSystem(pdb.topology, constraints=constraints)
|
| 48 |
+
|
| 49 |
+
# Add constraints to non-generated regions
|
| 50 |
+
force = openmm.CustomExternalForce("0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)")
|
| 51 |
+
force.addGlobalParameter("k", self.stiffness)
|
| 52 |
+
for p in ["x0", "y0", "z0"]:
|
| 53 |
+
force.addPerParticleParameter(p)
|
| 54 |
+
|
| 55 |
+
for i, a in enumerate(pdb.topology.atoms()):
|
| 56 |
+
if a.element.name != 'hydrogen':
|
| 57 |
+
force.addParticle(i, pdb.positions[i])
|
| 58 |
+
|
| 59 |
+
system.addForce(force)
|
| 60 |
+
|
| 61 |
+
# Set up the integrator and simulation
|
| 62 |
+
integrator = openmm.LangevinIntegrator(0, 0.01, 0.0)
|
| 63 |
+
platform = openmm.Platform.getPlatformByName("CUDA")
|
| 64 |
+
simulation = openmm_app.Simulation(pdb.topology, system, integrator, platform)
|
| 65 |
+
simulation.context.setPositions(pdb.positions)
|
| 66 |
+
|
| 67 |
+
# Perform minimization
|
| 68 |
+
ret = {}
|
| 69 |
+
state = simulation.context.getState(getEnergy=True, getPositions=True)
|
| 70 |
+
ret["einit"] = state.getPotentialEnergy().value_in_unit(ENERGY)
|
| 71 |
+
ret["posinit"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
|
| 72 |
+
|
| 73 |
+
simulation.minimizeEnergy(maxIterations=self.max_iterations, tolerance=self.tolerance)
|
| 74 |
+
|
| 75 |
+
state = simulation.context.getState(getEnergy=True, getPositions=True)
|
| 76 |
+
ret["efinal"] = state.getPotentialEnergy().value_in_unit(ENERGY)
|
| 77 |
+
ret["pos"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
|
| 78 |
+
ret["min_pdb"] = self._get_pdb_string(simulation.topology, state.getPositions())
|
| 79 |
+
|
| 80 |
+
return ret['min_pdb'], ret
|
| 81 |
+
|
| 82 |
+
def _add_energy_remarks(self, pdb_str, ret):
|
| 83 |
+
pdb_lines = pdb_str.splitlines()
|
| 84 |
+
pdb_lines.insert(1, "REMARK 1 FINAL ENERGY: {:.3f} KCAL/MOL".format(ret['efinal']))
|
| 85 |
+
pdb_lines.insert(1, "REMARK 1 INITIAL ENERGY: {:.3f} KCAL/MOL".format(ret['einit']))
|
| 86 |
+
return "\n".join(pdb_lines)
|
| 87 |
+
|
| 88 |
+
def __call__(self, pdb_str, out_path, return_info=True):
|
| 89 |
+
if '\n' not in pdb_str and pdb_str.lower().endswith(".pdb"):
|
| 90 |
+
with open(pdb_str) as f:
|
| 91 |
+
pdb_str = f.read()
|
| 92 |
+
|
| 93 |
+
pdb_fixed = self._fix(pdb_str)
|
| 94 |
+
pdb_min, ret = self._minimize(pdb_fixed)
|
| 95 |
+
pdb_min = self._add_energy_remarks(pdb_min, ret)
|
| 96 |
+
with open(out_path, 'w') as f:
|
| 97 |
+
f.write(pdb_min)
|
| 98 |
+
if return_info:
|
| 99 |
+
return pdb_min, ret
|
| 100 |
+
else:
|
| 101 |
+
return pdb_min
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
if __name__ == '__main__':
|
| 105 |
+
import sys
|
| 106 |
+
force_field = ForceFieldMinimizer()
|
| 107 |
+
force_field(sys.argv[1], sys.argv[2])
|
evaluation/dG/run.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
# -*- coding:utf-8 -*-
|
| 3 |
+
import os
|
| 4 |
+
import json
|
| 5 |
+
import argparse
|
| 6 |
+
import statistics
|
| 7 |
+
|
| 8 |
+
import ray
|
| 9 |
+
|
| 10 |
+
from utils.logger import print_log
|
| 11 |
+
|
| 12 |
+
from .base import TaskScanner, run_pyrosetta
|
| 13 |
+
|
| 14 |
+
# @ray.remote(num_gpus=1/8, num_cpus=1)
|
| 15 |
+
# def run_openmm_remote(task):
|
| 16 |
+
# return run_openmm(task)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@ray.remote(num_cpus=1)
|
| 20 |
+
def run_pyrosetta_remote(task):
|
| 21 |
+
return run_pyrosetta(task)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@ray.remote
|
| 25 |
+
def pipeline_pyrosetta(task):
|
| 26 |
+
funcs = [
|
| 27 |
+
run_pyrosetta_remote,
|
| 28 |
+
]
|
| 29 |
+
for fn in funcs:
|
| 30 |
+
task = fn.remote(task)
|
| 31 |
+
return ray.get(task)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def parse():
|
| 35 |
+
parser = argparse.ArgumentParser(description='calculating dG using pyrosetta')
|
| 36 |
+
parser.add_argument('--results', type=str, required=True, help='Path to the summary of the results (.jsonl)')
|
| 37 |
+
parser.add_argument('--n_sample', type=int, default=float('inf'), help='Maximum number of samples for calculation')
|
| 38 |
+
parser.add_argument('--rfdiff_relax', action='store_true', help='Use rfdiff fastrelax')
|
| 39 |
+
parser.add_argument('--out_path', type=str, default=None, help='Output path, default dG_report.jsonl under the same directory as results')
|
| 40 |
+
return parser.parse_args()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def main(args):
|
| 44 |
+
# output summary
|
| 45 |
+
if args.out_path is None:
|
| 46 |
+
args.out_path = os.path.join(os.path.dirname(args.results), 'dG_report.jsonl')
|
| 47 |
+
results = {}
|
| 48 |
+
|
| 49 |
+
# parallel
|
| 50 |
+
ray.init()
|
| 51 |
+
scanner = TaskScanner(args.results, args.n_sample, args.rfdiff_relax)
|
| 52 |
+
if args.results.endswith('txt'):
|
| 53 |
+
tasks = scanner.scan_dataset()
|
| 54 |
+
else:
|
| 55 |
+
tasks = scanner.scan()
|
| 56 |
+
futures = [pipeline_pyrosetta.remote(t) for t in tasks]
|
| 57 |
+
if len(futures) > 0:
|
| 58 |
+
print_log(f'Submitted {len(futures)} tasks.')
|
| 59 |
+
while len(futures) > 0:
|
| 60 |
+
done_ids, futures = ray.wait(futures, num_returns=1)
|
| 61 |
+
for done_id in done_ids:
|
| 62 |
+
done_task = ray.get(done_id)
|
| 63 |
+
print_log(f'Remaining {len(futures)}. Finished {done_task.current_path}, dG {done_task.dG}')
|
| 64 |
+
_id, number = done_task.info['id'], done_task.info['number']
|
| 65 |
+
if _id not in results:
|
| 66 |
+
results[_id] = {
|
| 67 |
+
'min': float('inf'),
|
| 68 |
+
'all': {}
|
| 69 |
+
}
|
| 70 |
+
results[_id]['all'][number] = done_task.dG
|
| 71 |
+
results[_id]['min'] = min(results[_id]['min'], done_task.dG)
|
| 72 |
+
|
| 73 |
+
# write results
|
| 74 |
+
for _id in results:
|
| 75 |
+
success = 0
|
| 76 |
+
for n in results[_id]['all']:
|
| 77 |
+
if results[_id]['all'][n] < 0:
|
| 78 |
+
success += 1
|
| 79 |
+
results[_id]['success rate'] = success / len(results[_id]['all'])
|
| 80 |
+
json.dump(results, open(args.out_path, 'w'), indent=2)
|
| 81 |
+
|
| 82 |
+
# show results
|
| 83 |
+
vals = [results[_id]['min'] for _id in results]
|
| 84 |
+
print(f'median: {statistics.median(vals)}, mean: {sum(vals) / len(vals)}')
|
| 85 |
+
success = [results[_id]['success rate'] for _id in results]
|
| 86 |
+
print(f'mean success rate: {sum(success) / len(success)}')
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
if __name__ == '__main__':
|
| 90 |
+
import random
|
| 91 |
+
random.seed(12)
|
| 92 |
+
main(parse())
|
evaluation/diversity.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
# -*- coding:utf-8 -*-
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from scipy.cluster.hierarchy import linkage, fcluster
|
| 7 |
+
from scipy.spatial.distance import squareform
|
| 8 |
+
from scipy.stats.contingency import association
|
| 9 |
+
|
| 10 |
+
from evaluation.seq_metric import align_sequences
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def seq_diversity(seqs: List[str], th: float=0.4) -> float:
|
| 14 |
+
'''
|
| 15 |
+
th: sequence distance
|
| 16 |
+
'''
|
| 17 |
+
dists = []
|
| 18 |
+
for i, seq1 in enumerate(seqs):
|
| 19 |
+
dists.append([])
|
| 20 |
+
for j, seq2 in enumerate(seqs):
|
| 21 |
+
_, sim = align_sequences(seq1, seq2)
|
| 22 |
+
dists[i].append(1 - sim)
|
| 23 |
+
dists = np.array(dists)
|
| 24 |
+
Z = linkage(squareform(dists), 'single')
|
| 25 |
+
cluster = fcluster(Z, t=th, criterion='distance')
|
| 26 |
+
return len(np.unique(cluster)) / len(seqs), cluster
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def struct_diversity(structs: np.ndarray, th: float=4.0) -> float:
|
| 30 |
+
'''
|
| 31 |
+
structs: N*L*3, alpha carbon coordinates
|
| 32 |
+
th: threshold for clustering (distance < th)
|
| 33 |
+
'''
|
| 34 |
+
ca_dists = np.sum((structs[:, None] - structs[None, :]) ** 2, axis=-1) # [N, N, L]
|
| 35 |
+
rmsd = np.sqrt(np.mean(ca_dists, axis=-1))
|
| 36 |
+
Z = linkage(squareform(rmsd), 'single') # since the distances might not be euclidean distances (e.g. rmsd)
|
| 37 |
+
cluster = fcluster(Z, t=th, criterion='distance')
|
| 38 |
+
return len(np.unique(cluster)) / structs.shape[0], cluster
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def diversity(seqs: List[str], structs: np.ndarray):
|
| 42 |
+
seq_div, seq_clu = seq_diversity(seqs)
|
| 43 |
+
if structs is None:
|
| 44 |
+
return seq_div, None, seq_div, None
|
| 45 |
+
struct_div, struct_clu = struct_diversity(structs)
|
| 46 |
+
co_div = np.sqrt(seq_div * struct_div)
|
| 47 |
+
|
| 48 |
+
n_seq_clu, n_struct_clu = np.max(seq_clu), np.max(struct_clu) # clusters start from 1
|
| 49 |
+
if n_seq_clu == 1 or n_struct_clu == 1:
|
| 50 |
+
consistency = 1.0 if n_seq_clu == n_struct_clu else 0.0
|
| 51 |
+
else:
|
| 52 |
+
table = [[0 for _ in range(n_struct_clu)] for _ in range(n_seq_clu)]
|
| 53 |
+
for seq_c, struct_c in zip(seq_clu, struct_clu):
|
| 54 |
+
table[seq_c - 1][struct_c - 1] += 1
|
| 55 |
+
consistency = association(np.array(table), method='cramer')
|
| 56 |
+
|
| 57 |
+
return seq_div, struct_div, co_div, consistency
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
if __name__ == '__main__':
|
| 61 |
+
N, L = 100, 10
|
| 62 |
+
a = np.random.randn(N, L, 3)
|
| 63 |
+
print(struct_diversity(a))
|
| 64 |
+
from utils.const import aas
|
| 65 |
+
aas = [tup[0] for tup in aas]
|
| 66 |
+
seqs = np.random.randint(0, len(aas), (N, L))
|
| 67 |
+
seqs = [''.join([aas[i] for i in idx]) for idx in seqs]
|
| 68 |
+
print(seq_diversity(seqs, 0.4))
|
evaluation/dockq.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
# -*- coding:utf-8 -*-
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
|
| 6 |
+
from globals import DOCKQ_DIR
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def dockq(mod_pdb: str, native_pdb: str, pep_chain: str):
|
| 10 |
+
p = os.popen(f'{os.path.join(DOCKQ_DIR, "DockQ.py")} {mod_pdb} {native_pdb} -model_chain1 {pep_chain} -native_chain1 {pep_chain} -no_needle')
|
| 11 |
+
text = p.read()
|
| 12 |
+
p.close()
|
| 13 |
+
res = re.search(r'DockQ\s+([0-1]\.[0-9]+)', text)
|
| 14 |
+
score = float(res.group(1))
|
| 15 |
+
return score
|
evaluation/rmsd.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
# -*- coding:utf-8 -*-
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# a: [N, 3], b: [N, 3]
|
| 8 |
+
def compute_rmsd(a, b, aligned=False): # amino acids level rmsd
|
| 9 |
+
dist = np.sum((a - b) ** 2, axis=-1)
|
| 10 |
+
rmsd = np.sqrt(dist.sum() / a.shape[0])
|
| 11 |
+
return float(rmsd)
|
evaluation/seq_metric.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
# -*- coding:utf-8 -*-
|
| 3 |
+
from math import sqrt
|
| 4 |
+
|
| 5 |
+
from Bio.Align import substitution_matrices, PairwiseAligner
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def aar(candidate, reference):
|
| 9 |
+
hit = 0
|
| 10 |
+
for a, b in zip(candidate, reference):
|
| 11 |
+
if a == b:
|
| 12 |
+
hit += 1
|
| 13 |
+
return hit / len(reference)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def align_sequences(sequence_A, sequence_B, **kwargs):
|
| 17 |
+
"""
|
| 18 |
+
Performs a global pairwise alignment between two sequences
|
| 19 |
+
using the BLOSUM62 matrix and the Needleman-Wunsch algorithm
|
| 20 |
+
as implemented in Biopython. Returns the alignment, the sequence
|
| 21 |
+
identity and the residue mapping between both original sequences.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
sub_matrice = substitution_matrices.load('BLOSUM62')
|
| 25 |
+
aligner = PairwiseAligner()
|
| 26 |
+
aligner.substitution_matrix = sub_matrice
|
| 27 |
+
alns = aligner.align(sequence_A, sequence_B)
|
| 28 |
+
|
| 29 |
+
best_aln = alns[0]
|
| 30 |
+
aligned_A, aligned_B = best_aln
|
| 31 |
+
|
| 32 |
+
base = sqrt(aligner.score(sequence_A, sequence_A) * aligner.score(sequence_B, sequence_B))
|
| 33 |
+
seq_id = aligner.score(sequence_A, sequence_B) / base
|
| 34 |
+
return (aligned_A, aligned_B), seq_id
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def slide_aar(candidate, reference, aar_func):
|
| 38 |
+
'''
|
| 39 |
+
e.g.
|
| 40 |
+
candidate: AILPV
|
| 41 |
+
reference: ILPVH
|
| 42 |
+
|
| 43 |
+
should be matched as
|
| 44 |
+
AILPV
|
| 45 |
+
ILPVH
|
| 46 |
+
|
| 47 |
+
To do this, we slide the candidate and calculate the maximum aar:
|
| 48 |
+
A
|
| 49 |
+
AI
|
| 50 |
+
AIL
|
| 51 |
+
AILP
|
| 52 |
+
AILPV
|
| 53 |
+
ILPV
|
| 54 |
+
LPV
|
| 55 |
+
PV
|
| 56 |
+
V
|
| 57 |
+
'''
|
| 58 |
+
special_token = ' '
|
| 59 |
+
ref_len = len(reference)
|
| 60 |
+
padded_candidate = special_token * (ref_len - 1) + candidate + special_token * (ref_len - 1)
|
| 61 |
+
value = 0
|
| 62 |
+
for start in range(len(padded_candidate) - ref_len + 1):
|
| 63 |
+
value = max(value, aar_func(padded_candidate[start:start + ref_len], reference))
|
| 64 |
+
return value
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
if __name__ == '__main__':
|
| 68 |
+
print(align_sequences('PKGYAAPSA', 'KPAVYKFTL'))
|
| 69 |
+
print(align_sequences('KPAVYKFTL', 'PKGYAAPSA'))
|
| 70 |
+
print(align_sequences('PKGYAAPSA', 'PKGYAAPSA'))
|
| 71 |
+
print(align_sequences('KPAVYKFTL', 'KPAVYKFTL'))
|
generate.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
# -*- coding:utf-8 -*-
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
import pickle as pkl
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from copy import deepcopy
|
| 9 |
+
from multiprocessing import Pool
|
| 10 |
+
|
| 11 |
+
import yaml
|
| 12 |
+
import torch
|
| 13 |
+
from torch.utils.data import DataLoader
|
| 14 |
+
|
| 15 |
+
import models
|
| 16 |
+
from utils.config_utils import overwrite_values
|
| 17 |
+
from data.converter.pdb_to_list_blocks import pdb_to_list_blocks
|
| 18 |
+
from data.converter.list_blocks_to_pdb import list_blocks_to_pdb
|
| 19 |
+
from data.format import VOCAB, Atom
|
| 20 |
+
from data import create_dataloader, create_dataset
|
| 21 |
+
from utils.logger import print_log
|
| 22 |
+
from utils.random_seed import setup_seed
|
| 23 |
+
from utils.const import sidechain_atoms
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_best_ckpt(ckpt_dir):
|
| 27 |
+
with open(os.path.join(ckpt_dir, 'checkpoint', 'topk_map.txt'), 'r') as f:
|
| 28 |
+
ls = f.readlines()
|
| 29 |
+
ckpts = []
|
| 30 |
+
for l in ls:
|
| 31 |
+
k,v = l.strip().split(':')
|
| 32 |
+
k = float(k)
|
| 33 |
+
v = v.split('/')[-1]
|
| 34 |
+
ckpts.append((k,v))
|
| 35 |
+
|
| 36 |
+
# ckpts = sorted(ckpts, key=lambda x:x[0])
|
| 37 |
+
best_ckpt = ckpts[0][1]
|
| 38 |
+
return os.path.join(ckpt_dir, 'checkpoint', best_ckpt)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def to_device(data, device):
|
| 42 |
+
if isinstance(data, dict):
|
| 43 |
+
for key in data:
|
| 44 |
+
data[key] = to_device(data[key], device)
|
| 45 |
+
elif isinstance(data, list) or isinstance(data, tuple):
|
| 46 |
+
res = [to_device(item, device) for item in data]
|
| 47 |
+
data = type(data)(res)
|
| 48 |
+
elif hasattr(data, 'to'):
|
| 49 |
+
data = data.to(device)
|
| 50 |
+
return data
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def clamp_coord(coord):
|
| 54 |
+
# some models (e.g. diffab) will output very large coordinates (absolute value >1000) which will corrupt the pdb file
|
| 55 |
+
new_coord = []
|
| 56 |
+
for val in coord:
|
| 57 |
+
if abs(val) >= 1000:
|
| 58 |
+
val = 0
|
| 59 |
+
new_coord.append(val)
|
| 60 |
+
return new_coord
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def overwrite_blocks(blocks, seq=None, X=None):
|
| 64 |
+
if seq is not None:
|
| 65 |
+
assert len(blocks) == len(seq), f'{len(blocks)} {len(seq)}'
|
| 66 |
+
new_blocks = []
|
| 67 |
+
for i, block in enumerate(blocks):
|
| 68 |
+
block = deepcopy(block)
|
| 69 |
+
if seq is None:
|
| 70 |
+
abrv = block.abrv
|
| 71 |
+
else:
|
| 72 |
+
abrv = VOCAB.symbol_to_abrv(seq[i])
|
| 73 |
+
if block.abrv != abrv:
|
| 74 |
+
if X is None:
|
| 75 |
+
block.units = [atom for atom in block.units if atom.name in VOCAB.backbone_atoms]
|
| 76 |
+
if X is not None:
|
| 77 |
+
coords = X[i]
|
| 78 |
+
atoms = VOCAB.backbone_atoms + sidechain_atoms[VOCAB.abrv_to_symbol(abrv)]
|
| 79 |
+
block.units = [
|
| 80 |
+
Atom(atom_name, clamp_coord(coord), atom_name[0]) for atom_name, coord in zip(atoms, coords)
|
| 81 |
+
]
|
| 82 |
+
block.abrv = abrv
|
| 83 |
+
new_blocks.append(block)
|
| 84 |
+
return new_blocks
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def generate_wrapper(model, sample_opt={}):
|
| 88 |
+
if isinstance(model, models.AutoEncoder):
|
| 89 |
+
def wrapper(batch):
|
| 90 |
+
X, S, ppls = model.test(batch['X'], batch['S'], batch['mask'], batch['position_ids'], batch['lengths'], batch['atom_mask'])
|
| 91 |
+
return X, S, ppls
|
| 92 |
+
elif isinstance(model, models.LDMPepDesign):
|
| 93 |
+
def wrapper(batch):
|
| 94 |
+
X, S, ppls = model.sample(batch['X'], batch['S'], batch['mask'], batch['position_ids'], batch['lengths'], batch['atom_mask'],
|
| 95 |
+
L=batch['L'] if 'L' in batch else None, sample_opt=sample_opt)
|
| 96 |
+
return X, S, ppls
|
| 97 |
+
else:
|
| 98 |
+
raise NotImplementedError(f'Wrapper for {type(model)} not implemented')
|
| 99 |
+
return wrapper
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def save_data(
|
| 103 |
+
_id, n,
|
| 104 |
+
x_pkl_file, s_pkl_file, pmetric_pkl_file,
|
| 105 |
+
ref_pdb, rec_chain, lig_chain, ref_save_dir, cand_save_dir,
|
| 106 |
+
seq_only, struct_only, backbone_only
|
| 107 |
+
):
|
| 108 |
+
|
| 109 |
+
X, S, pmetric = pkl.load(open(x_pkl_file, 'rb')), pkl.load(open(s_pkl_file, 'rb')), pkl.load(open(pmetric_pkl_file, 'rb'))
|
| 110 |
+
os.remove(x_pkl_file), os.remove(s_pkl_file), os.remove(pmetric_pkl_file)
|
| 111 |
+
if seq_only:
|
| 112 |
+
X = None
|
| 113 |
+
elif struct_only:
|
| 114 |
+
S = None
|
| 115 |
+
rec_blocks, lig_blocks = pdb_to_list_blocks(ref_pdb, selected_chains=[rec_chain, lig_chain])
|
| 116 |
+
ref_pdb = os.path.join(ref_save_dir, _id + "_ref.pdb")
|
| 117 |
+
list_blocks_to_pdb([rec_blocks, lig_blocks], [rec_chain, lig_chain], ref_pdb)
|
| 118 |
+
# os.system(f'cp {ref_pdb} {os.path.join(ref_save_dir, _id + "_ref.pdb")}')
|
| 119 |
+
ref_seq = ''.join([VOCAB.abrv_to_symbol(block.abrv) for block in lig_blocks])
|
| 120 |
+
lig_blocks = overwrite_blocks(lig_blocks, S, X)
|
| 121 |
+
gen_seq = ''.join([VOCAB.abrv_to_symbol(block.abrv) for block in lig_blocks])
|
| 122 |
+
save_dir = os.path.join(cand_save_dir, _id)
|
| 123 |
+
if not os.path.exists(save_dir):
|
| 124 |
+
os.makedirs(save_dir)
|
| 125 |
+
gen_pdb = os.path.join(save_dir, _id + f'_gen_{n}.pdb')
|
| 126 |
+
list_blocks_to_pdb([rec_blocks, lig_blocks], [rec_chain, lig_chain], gen_pdb)
|
| 127 |
+
|
| 128 |
+
return {
|
| 129 |
+
'id': _id,
|
| 130 |
+
'number': n,
|
| 131 |
+
'gen_pdb': gen_pdb,
|
| 132 |
+
'ref_pdb': ref_pdb,
|
| 133 |
+
'pmetric': pmetric,
|
| 134 |
+
'rec_chain': rec_chain,
|
| 135 |
+
'lig_chain': lig_chain,
|
| 136 |
+
'ref_seq': ref_seq,
|
| 137 |
+
'gen_seq': gen_seq,
|
| 138 |
+
'seq_only': seq_only,
|
| 139 |
+
'struct_only': struct_only,
|
| 140 |
+
'backbone_only': backbone_only
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def main(args, opt_args):
|
| 145 |
+
config = yaml.safe_load(open(args.config, 'r'))
|
| 146 |
+
config = overwrite_values(config, opt_args)
|
| 147 |
+
struct_only = config.get('struct_only', False)
|
| 148 |
+
seq_only = config.get('seq_only', False)
|
| 149 |
+
assert not (seq_only and struct_only)
|
| 150 |
+
backbone_only = config.get('backbone_only', False)
|
| 151 |
+
# load model
|
| 152 |
+
b_ckpt = args.ckpt if args.ckpt.endswith('.ckpt') else get_best_ckpt(args.ckpt)
|
| 153 |
+
ckpt_dir = os.path.split(os.path.split(b_ckpt)[0])[0]
|
| 154 |
+
print(f'Using checkpoint {b_ckpt}')
|
| 155 |
+
model = torch.load(b_ckpt, map_location='cpu')
|
| 156 |
+
device = torch.device('cpu' if args.gpu == -1 else f'cuda:{args.gpu}')
|
| 157 |
+
model.to(device)
|
| 158 |
+
model.eval()
|
| 159 |
+
|
| 160 |
+
# load data
|
| 161 |
+
_, _, test_set = create_dataset(config['dataset'])
|
| 162 |
+
test_loader = create_dataloader(test_set, config['dataloader'])
|
| 163 |
+
|
| 164 |
+
# save path
|
| 165 |
+
if args.save_dir is None:
|
| 166 |
+
save_dir = os.path.join(ckpt_dir, 'results')
|
| 167 |
+
else:
|
| 168 |
+
save_dir = args.save_dir
|
| 169 |
+
ref_save_dir = os.path.join(save_dir, 'references')
|
| 170 |
+
cand_save_dir = os.path.join(save_dir, 'candidates')
|
| 171 |
+
for directory in [ref_save_dir, cand_save_dir]:
|
| 172 |
+
if not os.path.exists(directory):
|
| 173 |
+
os.makedirs(directory)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
fout = open(os.path.join(save_dir, 'results.jsonl'), 'w')
|
| 177 |
+
item_idx = 0
|
| 178 |
+
|
| 179 |
+
# multiprocessing
|
| 180 |
+
pool = Pool(args.n_cpu)
|
| 181 |
+
|
| 182 |
+
n_samples = config.get('n_samples', 1)
|
| 183 |
+
|
| 184 |
+
pbar = tqdm(total=n_samples * len(test_loader))
|
| 185 |
+
for n in range(n_samples):
|
| 186 |
+
item_idx = 0
|
| 187 |
+
with torch.no_grad():
|
| 188 |
+
for batch in test_loader:
|
| 189 |
+
batch = to_device(batch, device)
|
| 190 |
+
batch_X, batch_S, batch_pmetric = generate_wrapper(model, deepcopy(config.get('sample_opt', {})))(batch)
|
| 191 |
+
|
| 192 |
+
# parallel
|
| 193 |
+
inputs = []
|
| 194 |
+
for X, S, pmetric in zip(batch_X, batch_S, batch_pmetric):
|
| 195 |
+
_id, ref_pdb, rec_chain, lig_chain = test_set.get_summary(item_idx)
|
| 196 |
+
# save temporary pickle file
|
| 197 |
+
x_pkl_file = os.path.join(save_dir, _id + f'_gen_{n}_X.pkl')
|
| 198 |
+
pkl.dump(X, open(x_pkl_file, 'wb'))
|
| 199 |
+
s_pkl_file = os.path.join(save_dir, _id + f'_gen_{n}_S.pkl')
|
| 200 |
+
pkl.dump(S, open(s_pkl_file, 'wb'))
|
| 201 |
+
pmetric_pkl_file = os.path.join(save_dir, _id + f'_gen_{n}_pmetric.pkl')
|
| 202 |
+
pkl.dump(pmetric, open(pmetric_pkl_file, 'wb'))
|
| 203 |
+
inputs.append((
|
| 204 |
+
_id, n,
|
| 205 |
+
x_pkl_file, s_pkl_file, pmetric_pkl_file,
|
| 206 |
+
ref_pdb, rec_chain, lig_chain, ref_save_dir, cand_save_dir,
|
| 207 |
+
seq_only, struct_only, backbone_only
|
| 208 |
+
))
|
| 209 |
+
item_idx += 1
|
| 210 |
+
|
| 211 |
+
results = pool.starmap(save_data, inputs)
|
| 212 |
+
for result in results:
|
| 213 |
+
fout.write(json.dumps(result) + '\n')
|
| 214 |
+
|
| 215 |
+
pbar.update(1)
|
| 216 |
+
|
| 217 |
+
fout.close()
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def parse():
|
| 221 |
+
parser = argparse.ArgumentParser(description='Generate peptides given epitopes')
|
| 222 |
+
parser.add_argument('--config', type=str, required=True, help='Path to the test configuration')
|
| 223 |
+
parser.add_argument('--ckpt', type=str, required=True, help='Path to checkpoint')
|
| 224 |
+
parser.add_argument('--save_dir', type=str, default=None, help='Directory to save generated peptides')
|
| 225 |
+
|
| 226 |
+
parser.add_argument('--gpu', type=int, default=0, help='GPU to use, -1 for cpu')
|
| 227 |
+
parser.add_argument('--n_cpu', type=int, default=4, help='Number of CPU to use (for parallelly saving the generated results)')
|
| 228 |
+
return parser.parse_known_args()
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
if __name__ == '__main__':
|
| 232 |
+
args, opt_args = parse()
|
| 233 |
+
print_log(f'Overwritting args: {opt_args}')
|
| 234 |
+
setup_seed(12)
|
| 235 |
+
main(args, opt_args)
|