initial commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +0 -34
- .gitignore +3 -0
- .gitmodules +3 -0
- .gradio/certificate.pem +31 -0
- Flexpert-Design/README.md +69 -0
- Flexpert-Design/configs/ANMAwareFlexibilityProtTrans.yaml +12 -0
- Flexpert-Design/configs/Flexpert-Design-inference.yaml +1 -0
- Flexpert-Design/configs/ProteinMPNN.py +14 -0
- Flexpert-Design/configs/ProteinMPNN.yaml +13 -0
- Flexpert-Design/data_interface.py +205 -0
- Flexpert-Design/data_utils.py +535 -0
- Flexpert-Design/download-cath-data.sh +17 -0
- Flexpert-Design/model_interface.py +631 -0
- Flexpert-Design/predict.py +148 -0
- Flexpert-Design/predict_example/1ah7_A.pdb +0 -0
- Flexpert-Design/predict_example/1ah7_A_instructions.csv +1 -0
- Flexpert-Design/predict_example/compare_seqs.py +59 -0
- Flexpert-Design/predict_example/predictions.txt +2 -0
- Flexpert-Design/requirements.txt +23 -0
- Flexpert-Design/src/__init__.py +49 -0
- Flexpert-Design/src/datasets/__init__.py +15 -0
- Flexpert-Design/src/datasets/alphafold_dataset.py +112 -0
- Flexpert-Design/src/datasets/atlas_dataset.py +133 -0
- Flexpert-Design/src/datasets/casp_dataset.py +57 -0
- Flexpert-Design/src/datasets/cath_dataset.py +141 -0
- Flexpert-Design/src/datasets/dataloader.py +161 -0
- Flexpert-Design/src/datasets/fast_dataloader.py +52 -0
- Flexpert-Design/src/datasets/featurizer.py +743 -0
- Flexpert-Design/src/datasets/flex_cath_dataset.py +155 -0
- Flexpert-Design/src/datasets/foldswitchers_dataset.py +128 -0
- Flexpert-Design/src/datasets/mpnn_dataset.py +492 -0
- Flexpert-Design/src/datasets/pdb_inference.py +329 -0
- Flexpert-Design/src/datasets/ts_dataset.py +47 -0
- Flexpert-Design/src/datasets/utils.py +99 -0
- Flexpert-Design/src/interface/__init__.py +0 -0
- Flexpert-Design/src/interface/data_interface.py +66 -0
- Flexpert-Design/src/interface/model_interface.py +89 -0
- Flexpert-Design/src/interface/pretrain_interface.py +405 -0
- Flexpert-Design/src/models/E3PiFold_model.py +90 -0
- Flexpert-Design/src/models/MemoryESM.py +164 -0
- Flexpert-Design/src/models/MemoryESMIF.py +116 -0
- Flexpert-Design/src/models/MemoryPiFold.py +143 -0
- Flexpert-Design/src/models/MemoryTuning.py +213 -0
- Flexpert-Design/src/models/PretrainESMIF_model.py +32 -0
- Flexpert-Design/src/models/PretrainESM_model.py +35 -0
- Flexpert-Design/src/models/PretrainPiFold_model.py +64 -0
- Flexpert-Design/src/models/Tuning.py +275 -0
- Flexpert-Design/src/models/__init__.py +16 -0
- Flexpert-Design/src/models/alphadesign_model.py +138 -0
- Flexpert-Design/src/models/anm_prottrans.py +677 -0
.gitattributes
CHANGED
|
@@ -1,35 +1 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
models/weights/
|
| 2 |
+
Flexpert-Design/
|
| 3 |
+
data/atlas/
|
.gitmodules
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[submodule "gradio_molecule3d"]
|
| 2 |
+
path = gradio_molecule3d
|
| 3 |
+
url = https://github.com/Honzus/gradio_molecule3d
|
.gradio/certificate.pem
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-----BEGIN CERTIFICATE-----
|
| 2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
| 3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
| 4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
| 5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
| 6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
| 7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
| 8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
| 9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
| 10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
| 11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
| 12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
| 13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
| 14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
| 15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
| 16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
| 17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
| 18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
| 19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
| 20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
| 21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
| 22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
| 23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
| 24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
| 25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
| 26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
| 27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
| 28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
| 29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
| 30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
| 31 |
+
-----END CERTIFICATE-----
|
Flexpert-Design/README.md
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Flexpert-Design
|
| 2 |
+
|
| 3 |
+
In this directory we provide the code to train and run inference with Flexpert-Design. To expedite the release of the codebase, this part of the code was not thoroughly curated and contains redundant files and code. The codebase might be revised in the future, but it will more likely be completely rewritten as part of a future project with an improved model.
|
| 4 |
+
|
| 5 |
+
## Environment
|
| 6 |
+
|
| 7 |
+
Tested with Python 3.9. For other versions, the environment might need to be adapted.
|
| 8 |
+
|
| 9 |
+
Assuming you have already installed the environment for Flexpert-3D and Flexpert-Seq, install the additional dependencies for Flexpert-Design using the `requirements.txt` file in this directory.
|
| 10 |
+
|
| 11 |
+
```bash
|
| 12 |
+
pip install -r requirements.txt
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
## Inference
|
| 17 |
+
|
| 18 |
+
In this example we illustrate how to run inference with the trained model (trained weights are provided inside the train/results directory, so you do not necessarily need to train the model again).
|
| 19 |
+
|
| 20 |
+
Place the PDB files you want to predict in the `predict_example` directory. The files are expected to be named like `PDBCODE_CHAINID.pdb`; an example file, '1ah7_A.pdb', is provided. For each PDB file in that folder, add the instructions on flexibility you want to be considered by the ProteinMPNN model in the `PDBCODE_CHAINID_instructions.csv` file; an example file, '1ah7_A_instructions.csv', is provided. Then run the following command to run inference.
|
| 21 |
+
|
| 22 |
+
```bash
|
| 23 |
+
python3 predict.py \
|
| 24 |
+
--infer_path predict_example/
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
The output will be in the `predict_example/predictions.txt` file.
|
| 28 |
+
|
| 29 |
+
The original sequence and the regenerated sequence can be compared using the following script.
|
| 30 |
+
|
| 31 |
+
```bash
|
| 32 |
+
python3 predict_example/compare_seqs.py \
|
| 33 |
+
--pdb_code 1ah7_A
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
## Training
|
| 37 |
+
|
| 38 |
+
First make sure you have the Flexpert-3D model weights in the `Flexpert/models/weights` directory. Alternatively run the following script to download the weights.
|
| 39 |
+
|
| 40 |
+
```bash
|
| 41 |
+
. ../download_flexpert_weights.sh
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
Download the training data:
|
| 45 |
+
|
| 46 |
+
```bash
|
| 47 |
+
. ../download-cath-data.sh
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
Then run the following command to train the model.
|
| 51 |
+
|
| 52 |
+
```bash
|
| 53 |
+
export HF_HOME=./HF_cache
|
| 54 |
+
python3 train.py \
|
| 55 |
+
--batch_size 4 \
|
| 56 |
+
--model_name 'ProteinMPNN' \
|
| 57 |
+
--stage 'fit' \
|
| 58 |
+
--dataset FLEX_CATH4.3 \
|
| 59 |
+
--ex_name training-reproduction \
|
| 60 |
+
--offline 0 \
|
| 61 |
+
--gpus 1 \
|
| 62 |
+
--epoch 11 \
|
| 63 |
+
--use_dynamics 1 \
|
| 64 |
+
--flex_loss_coeff 0.8 \
|
| 65 |
+
--init_flex_features 1 \
|
| 66 |
+
--grad_normalization 0 \
|
| 67 |
+
--loss_fn MSE \
|
| 68 |
+
--use_pmpnn_checkpoint 1
|
| 69 |
+
```
|
Flexpert-Design/configs/ANMAwareFlexibilityProtTrans.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
checkpoint_path: ../models/weights/flexpert_3d_weights.bin
|
| 2 |
+
data_jsonl_name: /NEWcath_ANM_gt_flex_annotated.jsonl #/debug_cath_ANM_gt_flex_annotated.jsonl #/cath_ANM_gt_flex_annotated.jsonl
|
| 3 |
+
half_precision: False #mixed_precision
|
| 4 |
+
gumbel_temperature: 0.2
|
| 5 |
+
num_labels: 1
|
| 6 |
+
add_pearson_loss: False
|
| 7 |
+
add_sse_loss: False
|
| 8 |
+
adaptor_architecture: 'conv'
|
| 9 |
+
enm_embed_dim: 128
|
| 10 |
+
enm_att_heads: 8
|
| 11 |
+
num_layers: 3
|
| 12 |
+
kernel_size: 5
|
Flexpert-Design/configs/Flexpert-Design-inference.yaml
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
pmpnn_model_path: 'train/results/MSEloss_flex_cath_coeff_0.8/checkpoints/last.ckpt'
|
Flexpert-Design/configs/ProteinMPNN.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Configuration constants for the ProteinMPNN method.
# NOTE(review): the grouping comments below are inferred from the names only;
# confirm against the code that consumes this module.

# Method identifier.
method = 'ProteinMPNN'

# Network size / graph settings (presumably).
hidden = 128
k_neighbors = 30
num_letters = 33
num_encoder_layers = 3
num_decoder_layers = 3
vocab = 33

# Regularisation (presumably).
dropout = 0.1
smoothing = 0.1

# Optimisation / schedule (presumably).
batch_size = 8
lr = 0.001
proteinmpnn_type = 0
patience = 100
epoch = 100
|
Flexpert-Design/configs/ProteinMPNN.yaml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
augment_eps: 0.0
|
| 2 |
+
num_encoder_layers: 3
|
| 3 |
+
hidden_dim: 128
|
| 4 |
+
hidden: 128
|
| 5 |
+
k_neighbors: 30
|
| 6 |
+
num_letters: 33
|
| 7 |
+
num_decoder_layers: 3
|
| 8 |
+
vocab: 33
|
| 9 |
+
dropout: 0.1
|
| 10 |
+
smoothing: 0.1
|
| 11 |
+
proteinmpnn_type: 0
|
| 12 |
+
init_flex_features: 1
|
| 13 |
+
starting_checkpoint_path: 'vanilla_mpnn_weights/best-epoch=99-recovery=0.485.ckpt'
|
Flexpert-Design/data_interface.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
from torch.utils.data import DataLoader
|
| 3 |
+
from src.interface.data_interface import DInterface_base
|
| 4 |
+
import torch
|
| 5 |
+
import os.path as osp
|
| 6 |
+
from src.tools.utils import cuda
|
| 7 |
+
import pdb
|
| 8 |
+
from src.tools.utils import load_yaml_config
|
| 9 |
+
|
| 10 |
+
class MyDataLoader(DataLoader):
    """DataLoader that copies every batch to a CUDA device before yielding it.

    Batches come from the regular ``DataLoader`` iterator. Each one is then
    moved to ``self.pretrain_device`` inside a freshly created CUDA stream.
    """

    def __init__(self, dataset, model_name, batch_size=64, num_workers=8, *args, **kwargs):
        """Create the loader.

        Args:
            dataset: Dataset forwarded to ``DataLoader``.
            model_name: Name of the model being fed. ``'GVP'`` batches are
                moved with a single ``batch.cuda(...)`` call; every other
                batch is treated as a mapping and moved tensor by tensor.
            batch_size: Forwarded to ``DataLoader``.
            num_workers: Forwarded to ``DataLoader``.
            *args: Extra positional arguments forwarded to ``DataLoader``.
            **kwargs: Extra keyword arguments forwarded to ``DataLoader``.
        """
        super().__init__(dataset, batch_size=batch_size, num_workers=num_workers, *args, **kwargs)
        self.pretrain_device = 'cuda:0'
        self.model_name = model_name

    @staticmethod
    def _distributed_device():
        """Return ``'cuda:<rank>'`` when a process group is initialised, else ``'cuda:0'``."""
        dist = getattr(torch, 'distributed', None)
        if dist is not None and dist.is_available() and dist.is_initialized():
            # NOTE(review): get_rank() is the global rank; on multi-node jobs
            # it may differ from the local GPU index -- confirm.
            return f'cuda:{dist.get_rank()}'
        return 'cuda:0'

    def __iter__(self):
        """Yield batches after moving their tensors to ``self.pretrain_device``."""
        for batch in super().__iter__():
            # Post-process the batch here: resolve the target device for every
            # batch so a process group initialised after construction is
            # still picked up.
            self.pretrain_device = self._distributed_device()

            stream = torch.cuda.Stream(self.pretrain_device)
            with torch.cuda.stream(stream):
                if self.model_name == 'GVP':
                    batch = batch.cuda(non_blocking=True, device=self.pretrain_device)
                else:
                    # Only existing keys are reassigned, so mutating the
                    # mapping while iterating over it is safe.
                    for key, val in batch.items():
                        if isinstance(val, torch.Tensor):
                            batch[key] = val.cuda(non_blocking=True, device=self.pretrain_device)

                yield batch
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class DInterface(DInterface_base):
    """Data interface that picks a dataset class and builds its dataloaders.

    The dataset class is chosen from ``hparams.dataset`` in
    :meth:`load_data_module`; the collate function is chosen from
    ``hparams.model_name`` in :meth:`setup`.
    """

    # Model names that share the ``featurize_GTrans`` collate function.
    _GTRANS_MODELS = ('AlphaDesign', 'PiFold', 'KWDesign', 'GraphTrans',
                      'StructGNN', 'GCA', 'E3PiFold')

    # Dataset name -> directory joined onto ``hparams.data_root`` for the
    # datasets served by ``AtlasDataset``.
    _ATLAS_PATHS = {
        'ATLAS_DIST_1': 'atlas/distant-frame-pairs_NO_SUPERPOSITION/frames_1',
        'ATLAS_DIST_2': 'atlas/distant-frame-pairs_NO_SUPERPOSITION/frames_2',
        'ATLAS_CLUSTER_1': 'atlas/cluster-representatives/frames_1',
        'ATLAS_CLUSTER_2': 'atlas/cluster-representatives/frames_2',
        'ATLAS_PDB': '../atlas_pdb_inference/',
        'ATLAS_FULL_MINIMIZED': '../atlas_eval_proteinmpnn/atlas_full/minimized_PDBs/pdbs/',
        'ATLAS_FULL_REFOLDED': '../atlas_eval_proteinmpnn/atlas_full/refolded_PDBs/pdbs/',
        'ATLAS_FULL_CRYSTAL': '../atlas_eval_proteinmpnn/atlas_full/crystal_PDBs/pdbs/',
    }

    # Same idea for the datasets served by ``FoldswitchersDataset``.
    _FOLDSWITCHERS_PATHS = {
        'FOLDSWITCHERS_1': 'fold_switchers/fold_1',
        'FOLDSWITCHERS_2': 'fold_switchers/fold_2',
    }

    # Dataset name -> (version, directory) for the plain CATH datasets.
    _CATH_VARIANTS = {
        'CATH4.2': (4.2, 'cath4.2'),
        'CATH4.3': (4.3, 'cath4.3'),
    }

    def __init__(self, **kwargs):
        """Store the hyper-parameters and resolve the dataset class."""
        super().__init__(**kwargs)
        self.save_hyperparameters()
        self.load_data_module()

    def setup(self, stage=None):
        """Select the collate function and instantiate the datasets for ``stage``.

        ``stage`` of ``'fit'`` builds the train/valid sets, ``'test'`` the
        test set, ``None`` all three, and ``'predict'``/``'eval'`` the
        predict set.
        """
        from src.datasets.featurizer import (featurize_AF, featurize_GTrans, featurize_GVP,
                                             featurize_ProteinMPNN, featurize_Inversefolding)

        model_name = self.hparams.model_name
        if model_name in self._GTRANS_MODELS:
            self.collate_fn = featurize_GTrans
        elif model_name == 'GVP':
            featurizer = featurize_GVP()
            self.collate_fn = featurizer.collate
        elif model_name == 'ProteinMPNN':
            self.collate_fn = featurize_ProteinMPNN
        elif model_name == 'ESMIF':
            self.collate_fn = featurize_Inversefolding
        # Any other model name leaves ``collate_fn`` untouched here.

        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            self.trainset = self.instancialize(split='train')
            self.valset = self.instancialize(split='valid')

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.testset = self.instancialize(split='test')

        if stage in ['predict', 'eval']:
            self.predictset = self.instancialize(split='predict')

    def _build_loader(self, dataset, shuffle, **extra):
        """Build a ``MyDataLoader`` with the settings shared by all splits."""
        return MyDataLoader(
            dataset,
            model_name=self.hparams.model_name,
            batch_size=self.batch_size,
            num_workers=self.hparams.num_workers,
            shuffle=shuffle,
            pin_memory=True,
            collate_fn=self.collate_fn,
            **extra,
        )

    def train_dataloader(self):
        """Return the shuffled training loader."""
        # ``prefetch_factor`` is only valid with worker processes; DataLoader
        # raises ValueError if it is given together with num_workers == 0.
        extra = {'prefetch_factor': 8} if self.hparams.num_workers > 0 else {}
        return self._build_loader(self.trainset, shuffle=True, **extra)

    def val_dataloader(self):
        """Return the validation loader (no shuffling)."""
        return self._build_loader(self.valset, shuffle=False)

    def test_dataloader(self):
        """Return the test loader (no shuffling)."""
        return self._build_loader(self.testset, shuffle=False)

    def predict_dataloader(self):
        """Return the prediction loader (no shuffling)."""
        return self._build_loader(self.predictset, shuffle=False)

    def load_data_module(self):
        """Resolve ``hparams.dataset`` to a dataset class and its data path.

        Sets ``self.data_module`` and, for most datasets, ``hparams['path']``
        (plus ``hparams['version']`` for the CATH variants). Dataset modules
        are imported lazily so only the selected one is loaded. An
        unrecognised name leaves everything untouched.
        """
        name = self.hparams.dataset

        if name == 'AF2DB':
            from src.datasets.AF2DB_dataset_lmdb import Af2dbDataset
            self.data_module = Af2dbDataset

        elif name == 'TS':
            from src.datasets.ts_dataset import TSDataset
            self.data_module = TSDataset
            self.hparams['path'] = osp.join(self.hparams.data_root, 'ts')

        elif name == 'CASP15':
            from src.datasets.casp_dataset import CASPDataset
            self.data_module = CASPDataset
            self.hparams['path'] = osp.join(self.hparams.data_root, 'casp15')

        elif name in self._CATH_VARIANTS:
            from src.datasets.cath_dataset import CATHDataset
            version, subdir = self._CATH_VARIANTS[name]
            self.data_module = CATHDataset
            self.hparams['version'] = version
            self.hparams['path'] = osp.join(self.hparams.data_root, subdir)

        elif name == 'MPNN':
            from src.datasets.mpnn_dataset import MPNNDataset
            self.data_module = MPNNDataset

        elif name in self._FOLDSWITCHERS_PATHS:
            from src.datasets.foldswitchers_dataset import FoldswitchersDataset
            self.data_module = FoldswitchersDataset
            self.hparams['path'] = osp.join(self.hparams.data_root, self._FOLDSWITCHERS_PATHS[name])

        elif name == 'PDBInference':
            from src.datasets.pdb_inference import PDBInference
            self.data_module = PDBInference
            self.hparams['path'] = osp.join(self.hparams.infer_path)

        elif name in self._ATLAS_PATHS:
            from src.datasets.atlas_dataset import AtlasDataset
            self.data_module = AtlasDataset
            self.hparams['path'] = osp.join(self.hparams.data_root, self._ATLAS_PATHS[name])

        elif name == 'FLEX_CATH4.3':
            from src.datasets.flex_cath_dataset import FlexCATHDataset
            self.data_module = FlexCATHDataset
            self.hparams['version'] = 4.3
            self.hparams['path'] = osp.join(self.hparams.data_root, 'cath4.3')

    def instancialize(self, **other_args):
        """Instantiate ``self.data_module`` from the stored hyper-parameters.

        Every constructor parameter of the dataset class that is also a key
        of ``self.hparams`` is passed through. ``other_args`` override those
        values. When ``use_dynamics`` is truthy, ``data_jsonl_name`` is read
        from ``configs/ANMAwareFlexibilityProtTrans.yaml`` and takes
        precedence over both.

        Returns:
            The dataset instance built by ``self.data_module(**args)``.
        """
        # Skip the first parameter (``self``) of the dataset constructor.
        class_args = list(inspect.signature(self.data_module.__init__).parameters)[1:]
        args1 = {arg: self.hparams[arg] for arg in class_args if arg in self.hparams}
        args1.update(other_args)

        # ``.get`` keeps runs that never defined the flag from failing with a
        # KeyError; a missing flag behaves like a falsy one.
        if self.hparams.get('use_dynamics'):
            config = load_yaml_config('configs/ANMAwareFlexibilityProtTrans.yaml')
            args1['data_jsonl_name'] = config['data_jsonl_name']
        return self.data_module(**args1)
|
Flexpert-Design/data_utils.py
ADDED
|
@@ -0,0 +1,535 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#From https://github.com/JoreyYan/zetadesign/blob/master/data/data.py
|
| 2 |
+
import glob
|
| 3 |
+
import json
|
| 4 |
+
import numpy as np
|
| 5 |
+
import gzip
|
| 6 |
+
import re
|
| 7 |
+
import multiprocessing
|
| 8 |
+
import tqdm
|
| 9 |
+
import shutil
|
| 10 |
+
SENTINEL = 1
|
| 11 |
+
import biotite.structure as struc
|
| 12 |
+
import biotite.application.dssp as dssp
|
| 13 |
+
import biotite.structure.io.pdb.file as file
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def parse_PDB_biounits(x, sse, ssedssp, atoms=('N', 'CA', 'C'), chain=None):
    '''
    Parse one chain of a PDB file into coordinates, sequence and per-residue labels.

    input:  x       = PDB filename (only the text before the first newline is used)
            sse     = labels consumed in order, one per residue entry that has a CA atom
            ssedssp = labels consumed in order, one per residue entry that has every
                      atom listed in `atoms`
            atoms   = atoms to extract (optional)
            chain   = chain ID to keep; None keeps every chain. ATOM records whose
                      chain column is blank are always kept.
    output: (coords, [sequence], sse_labels, ssedssp_labels)
            coords         = float array of shape (length, len(atoms), 3); NaN for
                             missing atoms and for gaps in the residue numbering
            [sequence]     = one-element list with the one-letter sequence, '-' for
                             gaps and unknown residue names
            sse_labels     = array with one label per sequence position
            ssedssp_labels = array with one label per sequence position
            Positions that cannot be labelled (gap, required atom missing, or the
            label sequence is exhausted/None) get '-', so both label arrays always
            have the same length as the sequence.
    '''
    alpha_1 = "ARNDCQEGHILKMFPSTWYV-"
    alpha_3 = ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
               'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL', 'GAP']
    aa_3_N = {a: n for n, a in enumerate(alpha_3)}
    aa_N_1 = {n: a for n, a in enumerate(alpha_1)}
    gap_index = len(alpha_1) - 1

    def next_label(labels, index):
        # Return (label, next_index); '-' when no further label is available.
        if labels is not None and index < len(labels):
            return labels[index], index + 1
        return '-', index

    # xyz[resn][insertion_code][atom_name] -> coordinates
    # seq[resn][insertion_code]            -> three-letter residue name
    xyz, seq = {}, {}

    with open(x.split('\n')[0]) as handle:
        for line in handle:
            if line[:6] == "HETATM" and line[17:17 + 3] == "MSE":
                # Selenomethionine is treated as methionine. The replacement keeps
                # the 6-character record field so the fixed columns below stay aligned.
                line = line.replace("HETATM", "ATOM  ")
                line = line.replace("MSE", "MET")

            if line[:4] != "ATOM":
                continue
            ch = line[21:22]
            if not (ch == chain or chain is None or ch == ' '):
                continue

            atom = line[12:12 + 4].strip()
            resi = line[17:17 + 3]
            resn = line[22:22 + 5].strip()
            coord = np.array([float(line[i:(i + 8)]) for i in [30, 38, 46]])

            # A trailing letter is an insertion code; residue numbers become 0-based.
            if resn[-1].isalpha():
                resa, resn = resn[-1], int(resn[:-1]) - 1
            else:
                resa, resn = "_", int(resn) - 1

            # First occurrence wins for both the residue name and each atom.
            seq.setdefault(resn, {}).setdefault(resa, resi)
            xyz.setdefault(resn, {}).setdefault(resa, {}).setdefault(atom, coord)

    # convert to numpy arrays, fill in missing values
    seq_, xyz_, sse_, ssedssp_ = [], [], [], []
    sseidx = 0
    dsspidx = 0
    residue_numbers = range(min(seq), max(seq) + 1) if seq else range(0)

    for resn in residue_numbers:
        if resn not in seq:
            # Gap in the numbering: placeholder residue with NaN coordinates.
            seq_.append(gap_index)
            sse_.append('-')
            ssedssp_.append('-')
            for _ in atoms:
                xyz_.append(np.full(3, np.nan))
            continue

        for k in sorted(seq[resn]):
            entry = xyz[resn][k]
            seq_.append(aa_3_N.get(seq[resn][k], gap_index))

            # sse labels are consumed only by entries that have a CA atom.
            if 'CA' in entry:
                label, sseidx = next_label(sse, sseidx)
            else:
                label = '-'
            sse_.append(label)

            complete = True
            for atom in atoms:
                if atom in entry:
                    xyz_.append(entry[atom])
                else:
                    xyz_.append(np.full(3, np.nan))
                    complete = False

            # ssedssp labels are consumed only by entries with every requested atom.
            if complete:
                label, dsspidx = next_label(ssedssp, dsspidx)
            else:
                label = '-'
            ssedssp_.append(label)

    sequence = "".join(aa_N_1.get(n, "-") for n in seq_)
    return np.array(xyz_).reshape(-1, len(atoms), 3), [sequence], np.array(sse_), np.array(ssedssp_)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def parse_PDB(path_to_pdb, name, input_chain_list=None):
    """
    Parse a single PDB file into a one-element list holding a chain dictionary.

    :param path_to_pdb: path of the PDB file to read
    :param name: value stored under the 'name' key of the result
    :param input_chain_list: chain IDs to extract; when empty or None, the IDs
        A-Z, a-z and '0'..'299' are all probed
    :return: ``[my_dict]`` where my_dict has, for every chain found,
        'seq_chain_<id>' (one-letter sequence) and 'coords_chain_<id>'
        (dict with 'N', 'CA', 'C', 'O' coordinate lists), plus 'name',
        'num_of_chains' and 'seq' (all chain sequences concatenated in the
        order the chain IDs were probed)

    Chain IDs that yield no residues are skipped, so probing the default
    alphabet does not add empty chains or inflate 'num_of_chains'.
    """
    if input_chain_list:
        chain_alphabet = input_chain_list
    else:
        init_alphabet = list("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz")
        extra_alphabet = [str(item) for item in range(300)]
        chain_alphabet = init_alphabet + extra_alphabet

    # Read and parse the structure once instead of once per probed chain ID.
    array_stack = file.PDBFile.read(path_to_pdb).get_structure(altloc="all")
    model = array_stack[0]

    my_dict = {}
    s = 0
    concat_seq = ''

    for letter in chain_alphabet:
        sse1 = struc.annotate_sse(model, chain_id=letter).tolist()
        if len(sse1) == 0:
            sse1 = struc.annotate_sse(model, chain_id='').tolist()
        ssedssp1 = []  # not annotating dssp for now

        xyz, seq, _, _ = parse_PDB_biounits(path_to_pdb, sse1, ssedssp1,
                                            atoms=['N', 'CA', 'C', 'O'], chain=letter)

        # The helper returns an empty sequence when the chain is not in the file.
        if type(xyz) == str or not seq[0]:
            continue

        concat_seq += seq[0]
        my_dict['seq_chain_' + letter] = seq[0]
        my_dict['coords_chain_' + letter] = {
            'N': xyz[:, 0, :].tolist(),
            'CA': xyz[:, 1, :].tolist(),
            'C': xyz[:, 2, :].tolist(),
            'O': xyz[:, 3, :].tolist(),
        }
        s += 1

    my_dict['name'] = name
    my_dict['num_of_chains'] = s
    my_dict['seq'] = concat_seq
    return [my_dict]
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def parse_pdb_split_chain(pdbgzFile):
    """
    Find the chain IDs used in a plain-text PDB file and parse those chains.

    :param pdbgzFile: path of the (uncompressed) PDB file
    :return: the list returned by ``parse_PDB`` for this file, with the file's
        base name as the entry name

    The whole file is scanned for ATOM records (not only its first line), and
    the chain IDs are sorted so the chain order in the result is reproducible.
    """
    with open(pdbgzFile) as f:
        pdbcontent = f.read()

    # Capture group = the single word character following the residue name.
    pattern = re.compile(r'ATOM\s+\d+\s*\w+\s*[A-Z]{3,4}\s*(\w)\s*.+\n', re.MULTILINE)
    match = sorted(set(pattern.findall(pdbcontent)))

    name = pdbgzFile.split('/')[-1]
    pdb_data = parse_PDB(pdbgzFile, name, match)

    return pdb_data
|
| 299 |
+
def parse_pdb_split_chain_af(pdbgzFile, pdb_dir='/media/junyu/data/perotin/aftest080_1000/'):
    """
    Find the chain IDs in a gzipped PDB file and parse the matching plain PDB.

    :param pdbgzFile: path of the gzip-compressed PDB file used for chain discovery
    :param pdb_dir: directory prefix (including the trailing separator) holding
        the uncompressed ``<name>.pdb`` that is actually parsed; defaults to the
        location that was previously hard-coded
    :return: the list returned by ``parse_PDB`` for ``pdb_dir + name + '.pdb'``
    :raises: whatever the gzip read raises; the file path is printed first
    """
    with gzip.open(pdbgzFile, 'rb') as pdbF:
        try:
            pdbcontent = pdbF.read()
        except Exception:
            # Report which file is unreadable, then propagate the real error
            # instead of continuing with an undefined buffer.
            print(pdbgzFile)
            raise

    pdbcontent = pdbcontent.decode()

    pattern = re.compile(r'ATOM\s+\d+\s*\w+\s*[A-Z]{3,4}\s*(\w)\s*.+\n', re.MULTILINE)
    # Sorted so the chain order does not depend on set iteration order.
    match = sorted(set(pattern.findall(pdbcontent)))

    name = pdbgzFile.split('/')[-1].split('.')[0]
    pdb_data = parse_PDB(pdb_dir + name + '.pdb', name, match)

    return pdb_data
|
| 321 |
+
|
| 322 |
+
def parse_pdb_split_chain_af_3dcnn(pdbgzFile):
    """
    List ``<name>__<chain>`` identifiers for every chain in a gzipped PDB file.

    :param pdbgzFile: path of the gzip-compressed PDB file
    :return: list of strings, one per distinct chain ID found in the ATOM
        records, sorted by chain ID; ``<name>`` is the file's base name up to
        its first '.'
    :raises: whatever the gzip read raises; the file path is printed first
    """
    with gzip.open(pdbgzFile, 'rb') as pdbF:
        try:
            pdbcontent = pdbF.read()
        except Exception:
            # Report which file is unreadable, then propagate the real error
            # instead of continuing with an undefined buffer.
            print(pdbgzFile)
            raise

    pdbcontent = pdbcontent.decode()

    pattern = re.compile(r'ATOM\s+\d+\s*\w+\s*[A-Z]{3,4}\s*(\w)\s*.+\n', re.MULTILINE)
    # Sorted so the output order does not depend on set iteration order.
    match = sorted(set(pattern.findall(pdbcontent)))

    name = pdbgzFile.split('/')[-1].split('.')[0]
    return [name + '__' + chain for chain in match]
|
| 346 |
+
def run_net(files_path, output_path):
    """
    Convert every ``*.pdb`` file matching ``files_path + '*.pdb'`` to one JSONL file.

    :param files_path: path prefix that is concatenated directly with '*.pdb'
        (so a directory must end with its separator)
    :param output_path: JSONL file to write; one JSON object per input file,
        taken from the first element returned by ``parse_pdb_split_chain``
    """
    pdb_files = glob.glob(files_path + '*.pdb')

    data = [parse_pdb_split_chain(pdb_file)[0] for pdb_file in tqdm.tqdm(pdb_files)]

    print('we want to write now')
    with open(output_path, 'w') as out:
        for entry in data:
            out.write(json.dumps(entry) + '\n')

    print('finished')
|
| 365 |
+
|
| 366 |
+
def run_netbyondif(filelist, output_path):
    """
    Write per-chain name / length / mean pLDDT records for the files in a list.

    :param filelist: text file; the text between the first pair of double
        quotes on each line is used as a PDB path
    :param output_path: prefix for the output; the records are written to
        ``output_path + <filelist base name> + '_detail.jsonl'``

    NOTE(review): this unpacks ``parse_pdb_split_chain`` into two values and
    reads 'plddts_chain_<id>' keys, neither of which the version of
    ``parse_pdb_split_chain`` / ``parse_PDB`` in this file provides — confirm
    which parser this function is meant to run against.
    """
    with open(filelist) as listing:
        lines = listing.readlines()

    data = []
    # Count of chains per mean-pLDDT decade (1 -> [10, 20), ..., 10 -> 100).
    nums_dict = {decade: 0 for decade in range(1, 11)}

    for line in tqdm.tqdm(lines):
        data_chains, match = parse_pdb_split_chain(line.split('"')[1])

        for chian in data_chains:
            for chain_id in match:
                plddts = np.asarray(chian['plddts_chain_' + chain_id])
                meanplddt = round(float(plddts.mean()), 2)
                data.append({'name': chian['name'], 'lens': len(chian['seq']), 'meanplddt': meanplddt})

                decade = int(meanplddt / 10)
                if decade in nums_dict:
                    nums_dict[decade] = nums_dict[decade] + 1
                else:
                    # Outside 10..100: report the entry instead of counting it.
                    print(chian['name'])

    output_pathindex = output_path + filelist.split('/')[-1].split('.')[0] + '_detail.jsonl'
    print('we want to write now')
    with open(output_pathindex, 'w') as out:
        for entry in data:
            out.write(json.dumps(entry) + '\n')

    print('finished')
|
| 437 |
+
def list_of_groups(list_info, per_list_len):
    '''
    Split a sequence into consecutive groups of a fixed length.

    :param list_info: the sequence to split
    :param per_list_len: length of each group
    :return: list of groups; every full group is a new list, and a shorter
        trailing remainder (if any) is appended as a slice of ``list_info``
    '''
    leftover = len(list_info) % per_list_len

    # Repeating one shared iterator makes zip pull per_list_len items per group.
    source = iter(list_info)
    groups = [list(chunk) for chunk in zip(*[source] * per_list_len)]

    if leftover != 0:
        groups.append(list_info[-leftover:])
    return groups
|
| 448 |
+
|
| 449 |
+
def count(filelist):
    """
    Print how many records of a JSONL file fall into each mean-pLDDT decade.

    :param filelist: JSONL file whose records carry a numeric 'meanplddt' field

    Each value is bucketed as ``int(meanplddt / 10)``; the counts for buckets
    0 through 9 are printed, one line per bucket.
    """
    with open(filelist) as handle:
        records = handle.readlines()

    buckets = [int(json.loads(record)['meanplddt'] / 10) for record in tqdm.tqdm(records)]

    for bucket in range(10):
        print('counts ' + str(bucket), buckets.count(bucket))
|
| 461 |
+
|
| 462 |
+
def run_net_aftest(files_path, output_path, data_root='/media/junyu/data/point_cloud/'):
    """
    Parse the gzipped PDB files named in a list and write one line per result.

    :param files_path: text file; the text between the first pair of double
        quotes on each line is a path relative to ``data_root``
    :param output_path: prefix for the output; results are written to
        ``output_path + '80bigthanclass_1000.text'``
    :param data_root: directory prefix (including the trailing separator)
        prepended to every listed path; defaults to the location that was
        previously hard-coded

    Each entry returned by ``parse_pdb_split_chain_af`` is written on its own
    line: strings as-is, anything else JSON-encoded (the parser returns
    dictionaries, which cannot be concatenated with a newline directly).
    """
    with open(files_path) as listing:
        lines = listing.readlines()

    data = []
    for line in tqdm.tqdm(lines):
        data_chains = parse_pdb_split_chain_af(data_root + line.split('"')[1])
        data.extend(data_chains)

    output_pathindex = output_path + str(80) + 'bigthanclass_1000.text'
    print('we want to write now')
    with open(output_pathindex, 'w') as out:
        for entry in data:
            text = entry if isinstance(entry, str) else json.dumps(entry)
            out.write(text + '\n')
|
| 491 |
+
|
| 492 |
+
# if __name__ == "__main__":
|
| 493 |
+
# files_path='/media/junyu/data/perotin/chain_set/AFDATA/details/80bigthanclass_1000.jsonl' #'/home/junyu/下载/splits/'#
|
| 494 |
+
# output_path='/media/junyu/data/perotin/chain_set/'
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
# # run_net_aftest(files_path,output_path)
|
| 499 |
+
|
| 500 |
+
# fakedata='//home/oem/pdb-tools/pdbtools/fixed/'
|
| 501 |
+
# run_net(fakedata,output_path+'tim184.jsonl')
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
#
|
| 506 |
+
# f.close()
|
| 507 |
+
# # print(nums_dict)
|
| 508 |
+
# print('finished ' +str(i))
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
# alllist=list_of_groups(lists,10000)
|
| 512 |
+
|
| 513 |
+
# for i in range(len(alllist)):
|
| 514 |
+
# thislist=alllist[i]
|
| 515 |
+
# with open(output_path+'_'+str(i)+'.jsonl', 'w') as f:
|
| 516 |
+
# for entry in thislist:
|
| 517 |
+
# f.write(json.dumps(entry) + '\n')
|
| 518 |
+
#
|
| 519 |
+
# f.close()
|
| 520 |
+
# # print(nums_dict)
|
| 521 |
+
# print('finished ' +str(i))
|
| 522 |
+
|
| 523 |
+
# _processes = []
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
# q = multiprocessing.Queue()
|
| 528 |
+
#
|
| 529 |
+
# proc.start()
|
| 530 |
+
# for eachlist in alllist:
|
| 531 |
+
# _process = multiprocessing.Process(target=run_netbyondif, args=(eachlist,))
|
| 532 |
+
# _process.start()
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
# run_netbyondif(lists,output_path)
|
Flexpert-Design/download-cath-data.sh
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Download the CATH 4.3 data set into ../data/cath4.3 (relative to the
# directory this script is run from).
set -u

echo "Downloading CATH data..."

# Create data directory if it doesn't exist
mkdir -p ../data/

# Set file information
URL="https://data.ciirc.cvut.cz/public/projects/2025Flexpert/cath4.3/"
OUTPUT_DIR="../data/cath4.3"

# Download directory recursively
# NOTE(review): --no-check-certificate disables TLS certificate verification;
# drop it once the server certificate validates.
if wget --no-check-certificate -r -np -nH --cut-dirs=3 --reject "index.html*" \
    --directory-prefix="${OUTPUT_DIR}" "${URL}"; then
    echo "CATH data download completed."
else
    status=$?
    # Do not claim success when wget reported a problem.
    echo "wget exited with status ${status}; the CATH download may be incomplete." >&2
    exit "${status}"
fi
|
| 17 |
+
|
Flexpert-Design/model_interface.py
ADDED
|
@@ -0,0 +1,631 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys; sys.path.append('/huyuqi/xmyu/DiffSDS')
|
| 2 |
+
import inspect
|
| 3 |
+
import torch
|
| 4 |
+
from src.tools.utils import cuda
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import os
|
| 7 |
+
from torcheval.metrics.text import Perplexity
|
| 8 |
+
from src.interface.model_interface import MInterface_base
|
| 9 |
+
import math
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from omegaconf import OmegaConf
|
| 12 |
+
from src.tools.utils import load_yaml_config
|
| 13 |
+
import torchmetrics
|
| 14 |
+
|
| 15 |
+
class MInterface(MInterface_base):
|
| 16 |
+
def __init__(self, model_name=None, loss=None, lr=None, **kwargs):
    """
    Build the design model and, optionally, the flexibility-aware training setup.

    Keyword options read from ``kwargs``:
        use_dynamics         -- truthy to load the flexibility predictor and a flex loss
        flex_loss_coeff      -- weight of the flex gradient when gradients are mixed
        loss_fn              -- 'MSE', 'L1' or 'DPO'; only read when use_dynamics is set
        grad_normalization   -- truthy to normalise per-parameter gradients
        use_pmpnn_checkpoint -- truthy to initialise ``self.model`` from
                                ``self.model.pmpnn_init_weights_path``

    ``self.hparams`` must also provide ``res_dir``, ``ex_name`` and ``use_dynamics``.

    :raises ValueError: if use_dynamics is set and ``loss_fn`` is not recognised
    """
    super().__init__()
    self.save_hyperparameters()
    self.load_model()
    self.use_dynamics = kwargs.get('use_dynamics', 0)

    # Use the first GPU when there is one; fall back to CPU so that CPU-only
    # hosts can still construct the module.
    coeff_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    self.flex_loss_coeff = torch.tensor([kwargs.get('flex_loss_coeff', 0)],
                                        dtype=torch.float, device=coeff_device)
    self.flex_loss_coeff.requires_grad = False

    if self.use_dynamics:
        self.load_flex_predictor()
        self.flex_loss_type = kwargs.get('loss_fn', 0)
        if self.flex_loss_type == 'MSE':
            self.flex_loss_fn = nn.MSELoss(reduction='none')
        elif self.flex_loss_type == 'L1':
            self.flex_loss_fn = nn.L1Loss(reduction='none')
        elif self.flex_loss_type == 'DPO':
            # Placeholder: no DPO loss is defined here.
            self.flex_loss_fn = ...
        else:
            raise ValueError(f"Not recognized type of loss function {self.flex_loss_type}")

    self.cross_entropy = nn.NLLLoss(reduction='none')
    os.makedirs(os.path.join(self.hparams.res_dir, self.hparams.ex_name), exist_ok=True)

    self.control_sum_recovery = 0
    self.control_sum_batch_sizes = 0

    self.grad_normalization = kwargs.get('grad_normalization', 0)
    self.use_pmpnn_checkpoint = kwargs.get('use_pmpnn_checkpoint', 0)

    if self.use_pmpnn_checkpoint:
        print('Loading pmpnn checkpoint from {}'.format(self.model.pmpnn_init_weights_path))
        # map_location='cpu' lets a checkpoint saved on a GPU load on any host;
        # load_state_dict then copies the values into the model's own tensors.
        state_dict = torch.load(self.model.pmpnn_init_weights_path, map_location='cpu')['state_dict']
        # Keep only the wrapped model's weights and drop the "model." marker.
        state_dict = {key: value for key, value in state_dict.items() if 'model.' in key[:6]}
        state_dict = {key.replace("model.", ""): value for key, value in state_dict.items()}
        self.model.load_state_dict(state_dict)

    self.MSE = nn.MSELoss(reduction='none')
    # Optimisation is driven manually by the training step.
    self.automatic_optimization = False

    if self.hparams.use_dynamics:
        self.pearson = torchmetrics.PearsonCorrCoef()
        self.spearman = torchmetrics.SpearmanCorrCoef()
    self.validation_step_outputs = []
    self.test_step_outputs = []
|
| 58 |
+
|
| 59 |
+
#### setting forward hook
|
| 60 |
+
|
| 61 |
+
# def forward_hook(module, input, output):
|
| 62 |
+
# def check_nan(tensor):
|
| 63 |
+
# if isinstance(tensor, torch.Tensor):
|
| 64 |
+
# if torch.isnan(tensor).any():
|
| 65 |
+
# print(f"NaN detected in the output of {type(module).__name__}")
|
| 66 |
+
# print(f"Tensor shape: {tensor.shape}")
|
| 67 |
+
# print(f"Tensor stats: mean={tensor.mean()}, std={tensor.std()}, min={tensor.min()}, max={tensor.max()}, all={torch.isnan(tensor).all()}")
|
| 68 |
+
# elif isinstance(tensor, tuple):
|
| 69 |
+
# for i, t in enumerate(tensor):
|
| 70 |
+
# if isinstance(t, torch.Tensor):
|
| 71 |
+
# if torch.isnan(t).any():
|
| 72 |
+
# print(f"NaN detected in the output[{i}] of {type(module).__name__}")
|
| 73 |
+
# print(f"Tensor shape: {t.shape}")
|
| 74 |
+
# print(f"Tensor stats: mean={t.mean()}, std={t.std()}, min={t.min()}, max={t.max()}, all={torch.isnan(tensor).all()}")
|
| 75 |
+
|
| 76 |
+
# if isinstance(output, tuple):
|
| 77 |
+
# for i, out in enumerate(output):
|
| 78 |
+
# check_nan(out)
|
| 79 |
+
# else:
|
| 80 |
+
# check_nan(output)
|
| 81 |
+
|
| 82 |
+
# for name, module in self.model.named_modules():
|
| 83 |
+
# module.register_forward_hook(forward_hook)
|
| 84 |
+
|
| 85 |
+
# for name, module in self.flex_model.named_modules():
|
| 86 |
+
# module.register_forward_hook(forward_hook)
|
| 87 |
+
|
| 88 |
+
####
|
| 89 |
+
|
| 90 |
+
def forward(self, batch, mode='train', temperature=1.0):
    """
    Run the design model on a batch and compute the loss and sequence recovery.

    :param batch: input batch; 'X' is perturbed with Gaussian noise when
        ``hparams.augment_eps`` is positive, then the batch is passed through
        ``self.model._get_features`` and ``self.model``
    :param mode: 'predict' or 'eval' return a result dictionary; any other
        value returns ``(loss, recovery)``
    :param temperature: accepted but not used here
    :return: see ``mode``. With 3-D log-probabilities and
        ``hparams.use_dynamics`` set, ``loss`` is whatever
        ``combined_flex_aware_loss`` returns; otherwise it is the masked mean
        negative log-likelihood.
    """
    eps = self.hparams.augment_eps
    if eps > 0:
        batch['X'] = batch['X'] + eps * torch.randn_like(batch['X'])

    batch = self.model._get_features(batch)
    results = self.model(batch)

    log_probs = results['log_probs']
    mask = batch['mask']
    rank = len(log_probs.shape)

    if rank == 3:
        if self.hparams.use_dynamics:
            loss = self.combined_flex_aware_loss(batch, pred_log_probs=log_probs)
        else:
            per_position = self.cross_entropy(log_probs.permute(0, 2, 1), batch['S'])
            loss = (per_position * mask).sum() / (mask.sum())
    elif rank == 2:
        target = batch.seq if self.hparams.model_name == 'GVP' else batch['S']
        loss = self.cross_entropy(log_probs, target)

        if self.hparams.model_name == 'AlphaDesign':
            loss += self.cross_entropy(results['log_probs0'], batch['S'])
        loss = (loss * mask).sum() / (mask.sum())

    cmp = log_probs.argmax(dim=-1) == batch['S']
    recovery = (cmp * mask).sum() / (mask.sum())

    if mode == 'predict' or mode == 'eval':
        return {
            'original_sequence': batch['S'],
            'correct_positions': cmp,
            'mask': mask,
            'loss': loss,
            'recovery': recovery,
            'title': batch['title'],
            'log_probs': log_probs,
            'batch': batch,
        }
    return loss, recovery
|
| 124 |
+
|
| 125 |
+
def avgCorrelations(self, preds, gts, masks):
    """
    Average the per-sample Pearson and Spearman correlations over a set of samples.

    :param preds: iterable of prediction tensors
    :param gts: iterable of ground-truth tensors, paired with ``preds``
    :param masks: iterable of masks; only positions selected by
        ``torch.where(mask)`` enter the correlation
    :return: ``(mean_pearson, mean_spearman)``. Samples whose Pearson value is
        NaN are left out of both averages. When no sample remains, both values
        are NaN tensors instead of raising ``ZeroDivisionError``.
    """
    pearson_R = 0
    spearman_R = 0
    valid_datapoints = 0
    for pred, gt, mask in zip(preds, gts, masks):
        selected = torch.where(mask)
        masked_pred, masked_gt = pred[selected], gt[selected]

        dpR = self.pearson(masked_pred, masked_gt)
        if torch.isnan(dpR):
            continue
        pearson_R += dpR
        spearman_R += self.spearman(masked_pred, masked_gt)
        valid_datapoints += 1

    if valid_datapoints == 0:
        return torch.tensor(float('nan')), torch.tensor(float('nan'))
    return pearson_R / valid_datapoints, spearman_R / valid_datapoints
|
| 138 |
+
|
| 139 |
+
def temperature_schedular(self, batch_idx):
    """
    Return a cosine-modulated, linearly decaying value for the given step.

    The total step count is ``hparams.steps_per_epoch * hparams.epoch``. The
    envelope falls as ``1 - 2*x`` (x = batch_idx / total steps) until x reaches
    0.48, then continues linearly from that level towards zero at x = 1. It is
    multiplied by ``(1 + cos(pi * batch_idx / cycle)) / 2``, where ``cycle`` is
    one hundredth of the total steps.

    ``cycle`` is clamped to at least 1 so that runs shorter than 100 steps do
    not divide by zero.

    :param batch_idx: current step index
    :return: the scheduled value as a float
    """
    total_steps = self.hparams.steps_per_epoch * self.hparams.epoch

    initial_lr = 1.0
    circle_steps = max(1, total_steps // 100)
    x = batch_idx / total_steps
    threshold = 0.48
    if x < threshold:
        linear_decay = 1 - 2 * x
    else:
        K = 1 - 2 * threshold
        linear_decay = K - K * (x - threshold) / (1 - threshold)

    new_lr = (1 + math.cos(batch_idx / circle_steps * math.pi)) / 2 * linear_decay * initial_lr

    return new_lr
|
| 155 |
+
|
| 156 |
+
# def get_grad_norm(self):
|
| 157 |
+
# total_norm = 0
|
| 158 |
+
# parameters = [p for p in self.parameters() if p.grad is not None and p.requires_grad]
|
| 159 |
+
# for p in parameters:
|
| 160 |
+
# param_norm = p.grad.detach().data.norm(2)
|
| 161 |
+
# total_norm += param_norm.item() ** 2
|
| 162 |
+
# total_norm = total_norm ** 0.5
|
| 163 |
+
# return total_norm
|
| 164 |
+
|
| 165 |
+
#https://lightning.ai/docs/pytorch/1.9.0/notebooks/lightning_examples/basic-gan.html
|
| 166 |
+
def training_step(self, batch, batch_idx, **kwargs):
    """
    Run one training step and return the loss that was logged.

    Three paths, chosen by ``self.use_dynamics`` and by whether the forward
    pass returned a dictionary of losses:

    * dynamics off, dict loss: log the dictionary (minus 'pred_flex',
      'gt_flex', 'flex_mask') and return its 'combined_loss';
    * plain loss (either setting): log it as 'loss' and return it;
    * dynamics on, dict loss: compute the gradients of 'flex_loss' and
      'seq_loss' separately, optionally normalise each per parameter, mix them
      with ``self.flex_loss_coeff``, write the result into ``param.grad``,
      clip, step the optimizer and the scheduler (if any), and return
      ``flex_loss + seq_loss``.
    """
    raw_loss, recovery = self(batch)
    loss_is_dict = type(raw_loss) is dict

    if not self.use_dynamics:
        if loss_is_dict:
            loss = raw_loss['combined_loss']
            # Tensors that are not scalar metrics must not reach log_dict.
            for key in ('pred_flex', 'gt_flex', 'flex_mask'):
                raw_loss.pop(key)
            self.log_dict(raw_loss, on_step=True, on_epoch=True, prog_bar=True)
        else:
            loss = raw_loss
            self.log('loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    if not loss_is_dict:
        loss = raw_loss
        self.log('loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    flex_loss = raw_loss['flex_loss']
    seq_loss = raw_loss['seq_loss']
    opt = self.optimizers()
    opt.zero_grad()

    trainable = [p for p in self.model.parameters() if p.requires_grad]
    trainable += [p for p in self.flex_model.parameters() if p.requires_grad]

    grads_flex = torch.autograd.grad(flex_loss, trainable, create_graph=True)
    grads_seq = torch.autograd.grad(seq_loss, trainable, create_graph=True)
    if self.grad_normalization:
        # Scale every parameter's gradient to (almost) unit norm.
        norm_grads_flex = [g / (g.norm() + 1e-10) for g in grads_flex]
        norm_grads_seq = [g / (g.norm() + 1e-10) for g in grads_seq]
    else:
        norm_grads_flex = grads_flex
        norm_grads_seq = grads_seq

    combined_grads = [
        self.flex_loss_coeff * gflex + (1 - self.flex_loss_coeff) * gseq
        for gflex, gseq in zip(norm_grads_flex, norm_grads_seq)
    ]

    def mean_norm(grads):
        # Mean of the per-parameter gradient norms.
        return torch.mean(torch.tensor([g.detach().norm() for g in grads]))

    self.log_dict(
        {
            'flex_grad_norm': mean_norm(norm_grads_flex),
            'seq_grad_norm': mean_norm(norm_grads_seq),
            'combined_grad_norm': mean_norm(combined_grads),
        },
        on_step=True, on_epoch=False, prog_bar=True,
    )

    for param, grad in zip(trainable, combined_grads):
        if param.grad is None:
            param.grad = grad.detach()
        else:
            param.grad += grad.detach()

    self.clip_gradients(opt, gradient_clip_val=1., gradient_clip_algorithm="norm")
    opt.step()

    # Update learning rate
    sch = self.lr_schedulers()
    if sch is not None:
        sch.step()

    loss = flex_loss + seq_loss

    self.log_dict({'train_flex_loss': flex_loss, 'train_seq_loss': seq_loss},
                  on_step=True, on_epoch=False, prog_bar=True)

    # Log the current learning rate
    if sch is not None:
        current_lr = sch.get_last_lr()[0]
        self.log('learning_rate', current_lr, on_step=True, on_epoch=False, prog_bar=True)

    return loss
|
| 233 |
+
|
| 234 |
+
def validation_step(self, batch, batch_idx):
    """Run one validation batch, log its metrics and return the batch loss.

    When the forward pass returns a dict, the flexibility prediction, the
    ground-truth flexibility from the batch and the flexibility mask are
    stored in ``self.validation_step_outputs`` so that
    ``on_validation_epoch_end`` can compute epoch-level correlations.

    Args:
        batch: Batch dict; ``batch['gt_flex']`` is read in the dict case.
        batch_idx: Index of the batch (unused).

    Returns:
        The validation loss of this batch.
    """
    raw_loss, recovery = self(batch)
    if isinstance(raw_loss, dict):
        # Unweighted sum of both terms (deliberately not raw_loss['combined_loss']).
        loss = raw_loss['flex_loss'] + raw_loss['seq_loss']
        raw_loss['recovery'] = recovery
        pred_flex = raw_loss.pop('pred_flex')
        gt_flex = batch['gt_flex']
        flex_mask = raw_loss.pop('flex_mask')

        epoch_metric_ingredients = {'pred_flex': pred_flex, 'gt_flex': gt_flex, 'flex_mask': flex_mask}
        self.validation_step_outputs.append(epoch_metric_ingredients)
        self.log_dict({"val_combined_loss": loss,
                       "val_seq_loss": raw_loss['seq_loss'],
                       "val_flex_loss": raw_loss['flex_loss'],
                       "recovery": recovery})
    else:
        loss = raw_loss
        self.log_dict({"val_loss": loss,
                       "recovery": recovery})
    # If there is an issue with validation metrics, see test_step below.
    # Return the loss itself, not the bound ``self.log_dict`` method.
    return loss
|
| 256 |
+
|
| 257 |
+
def on_validation_epoch_end(self):
    """Compute and log epoch-level flexibility correlations, then reset the buffer.

    Only active when ``self.hparams.use_dynamics`` is truthy. The per-batch
    tensors collected by ``validation_step`` are right-padded with zeros
    along dimension 1 to a common length, concatenated along dimension 0
    and handed to ``self.avgCorrelations``.
    """
    if self.hparams.use_dynamics:
        collected = self.validation_step_outputs
        longest = max([out['pred_flex'].size()[1] for out in collected])

        merged = []
        for key in ('pred_flex', 'gt_flex', 'flex_mask'):
            padded = [
                F.pad(out[key], (0, longest - out[key].shape[1], 0, 0), value=float(0))
                for out in collected
            ]
            merged.append(torch.cat(padded, dim=0))
        all_preds, all_gts, all_masks = merged

        pearson_R, spearman_R = self.avgCorrelations(all_preds, all_gts, all_masks)
        self.log_dict({"val_pearson_R": pearson_R, "val_spearman_R": spearman_R})

    self.validation_step_outputs.clear()  # free memory
    return super().on_validation_epoch_end()
|
| 285 |
+
|
| 286 |
+
def on_test_epoch_end(self):
    """Persist the collected test outputs and produce the epoch-level test results.

    Always pickles ``self.test_step_outputs`` into ``rebuttal_experiments/``.
    Afterwards:

    * engineering mode (``test_engineering`` and ``use_dynamics`` both
      truthy): pools flexibility values and sequences inside the
      engineered region, plus sequences outside it, and writes them as
      JSON into ``paper_figures/``;
    * otherwise: pads and concatenates predictions, targets and masks and
      logs the test correlations.

    The output buffer is cleared in both cases.
    """
    import json
    import os
    import pickle

    import numpy as np

    hp = self.hparams
    # Tag shared by both output files: third-from-last component of the
    # checkpoint path, the init_flex_features flag and the engineering data
    # file name minus its last five characters (presumably ".json" - confirm).
    run_tag = (
        f'{hp.starting_checkpoint_path.split("/")[-3]}'
        f'_initFF{hp.init_flex_features}'
        f'_{hp.test_eng_data_path.split("/")[-1][:-5]}'
    )

    # open() does not create missing directories, so make sure they exist.
    os.makedirs('rebuttal_experiments', exist_ok=True)
    with open(f'rebuttal_experiments/test_step_outputs_{run_tag}.pkl', 'wb') as f:
        pickle.dump(self.test_step_outputs, f)

    outputs = self.test_step_outputs
    if hp.test_engineering and hp.use_dynamics:
        avg_sequence_recovery = sum([b['sequence_recovery'] for b in outputs]) / len(outputs)
        avg_sequence_recovery = avg_sequence_recovery.cpu().tolist()

        pred_flex_pool, eng_gt_flex_pool, original_gt_flex_pool = [], [], []
        pred_seq_pool, gt_seq_pool = [], []
        outside_pred_seq_pool, outside_gt_seq_pool = [], []

        for b in outputs:
            inside_eng = b['eng_mask'] == 1
            # Each sample is filtered with its own flex mask; a mask carried
            # over from another sample would select the wrong residues.
            outside_eng = (b['eng_mask'] == 0) & (b['flex_mask'] == 1)

            # Flexibility values inside the engineered region.
            pred_flex_pool.append(b['pred_flex'][inside_eng].cpu().numpy())
            eng_gt_flex_pool.append(b['gt_flex'][inside_eng].cpu().numpy())
            original_gt_flex_pool.append(b['original_gt_flex'][inside_eng].cpu().numpy())

            # Argmax-decoded and ground-truth tokens inside the engineered region.
            pred_seq_pool.append(torch.argmax(b['pred_logprobs'][inside_eng], dim=1).cpu().numpy())
            gt_seq_pool.append(b['gt_seq'][inside_eng].cpu().numpy())

            # Same, for flex-valid residues outside the engineered region.
            outside_pred_seq_pool.append(torch.argmax(b['pred_logprobs'][outside_eng], dim=1).cpu().numpy())
            outside_gt_seq_pool.append(b['gt_seq'][outside_eng].cpu().numpy())

        os.makedirs('paper_figures', exist_ok=True)
        with open(f'paper_figures/pools_{run_tag}.json', 'w') as f:
            json.dump({
                '_pred_flex_pool': np.concatenate(pred_flex_pool).tolist(),
                '_eng_gt_flex_pool': np.concatenate(eng_gt_flex_pool).tolist(),
                '_original_gt_flex_pool': np.concatenate(original_gt_flex_pool).tolist(),
                '_pred_seq_pool': np.concatenate(pred_seq_pool).tolist(),
                '_gt_seq_pool': np.concatenate(gt_seq_pool).tolist(),
                '_sequence_recovery': avg_sequence_recovery,
                '_outside_eng_region_pred_seq_pool': np.concatenate(outside_pred_seq_pool).tolist(),
                '_outside_eng_region_gt_seq_pool': np.concatenate(outside_gt_seq_pool).tolist(),
            }, f)
    else:
        all_preds = [b['pred_flex'] for b in outputs]
        all_gts = [b['gt_flex'] for b in outputs]
        all_masks = [b['flex_mask'] for b in outputs]

        max_seq_length = max([pred.size()[1] for pred in all_preds])

        # Right-pad dimension 1 with zeros so the batches can be concatenated.
        for set_of_tensors in [all_preds, all_gts, all_masks]:
            for i, tensor in enumerate(set_of_tensors):
                set_of_tensors[i] = F.pad(tensor, (0, max_seq_length - tensor.shape[1], 0, 0), value=float(0))

        all_preds = torch.cat(all_preds, dim=0)
        all_gts = torch.cat(all_gts, dim=0)
        all_masks = torch.cat(all_masks, dim=0)

        pearson_R, spearman_R = self.avgCorrelations(all_preds, all_gts, all_masks)
        # The Spearman value from avgCorrelations is replaced by the one
        # computed over all masked positions; fall back to Pearson on IndexError.
        try:
            spearman_R = self.spearman(all_preds[torch.where(all_masks)], all_gts[torch.where(all_masks)])
        except IndexError:
            spearman_R = pearson_R
        self.log_dict({"test_pearson_R": pearson_R, "test_spearman_R": spearman_R})

    self.test_step_outputs.clear()  # free memory
    return super().on_test_epoch_end()
|
| 409 |
+
|
| 410 |
+
def test_step(self, batch, batch_idx):
|
| 411 |
+
# Here we just reuse the validation_step for testing
|
| 412 |
+
#return self.validation_step(batch, batch_idx)
|
| 413 |
+
|
| 414 |
+
raw_loss, recovery = self(batch)
|
| 415 |
+
if type(raw_loss) == dict:
|
| 416 |
+
#loss = raw_loss['combined_loss']
|
| 417 |
+
loss = raw_loss['flex_loss']+raw_loss['seq_loss'] #raw_loss['combined_loss']
|
| 418 |
+
raw_loss['recovery'] = recovery
|
| 419 |
+
# pred_bfactors = raw_loss.pop('pred_bfactors')
|
| 420 |
+
pred_flex = raw_loss.pop('pred_flex')
|
| 421 |
+
# gt_bfactors = raw_loss.pop('gt_bfactors')
|
| 422 |
+
gt_flex = raw_loss.pop('gt_flex')
|
| 423 |
+
flex_mask = raw_loss.pop('flex_mask')
|
| 424 |
+
epoch_metric_ingredients = {'pred_flex':pred_flex, 'gt_flex':gt_flex, 'flex_mask':flex_mask}
|
| 425 |
+
|
| 426 |
+
if self.hparams.test_engineering and self.hparams.use_dynamics:
|
| 427 |
+
eng_mask = raw_loss.pop('eng_mask')
|
| 428 |
+
original_gt_flex = raw_loss.pop('original_gt_flex')
|
| 429 |
+
epoch_metric_ingredients['eng_mask'] = eng_mask
|
| 430 |
+
epoch_metric_ingredients['original_gt_flex'] = original_gt_flex
|
| 431 |
+
epoch_metric_ingredients['gt_seq'] = raw_loss['gt_seq']
|
| 432 |
+
epoch_metric_ingredients['pred_logprobs'] = raw_loss['pred_logprobs']
|
| 433 |
+
epoch_metric_ingredients['sequence_recovery'] = raw_loss['recovery']
|
| 434 |
+
epoch_metric_ingredients['id'] = batch['title']
|
| 435 |
+
|
| 436 |
+
self.test_step_outputs.append(epoch_metric_ingredients)
|
| 437 |
+
out_dict = {"val_combined_loss":loss,
|
| 438 |
+
"val_seq_loss":raw_loss['seq_loss'],
|
| 439 |
+
"val_flex_loss":raw_loss['flex_loss'],
|
| 440 |
+
"recovery": recovery}
|
| 441 |
+
else:
|
| 442 |
+
out_dict = {"val_loss":raw_loss, "recovery": recovery}
|
| 443 |
+
self.log_dict(out_dict,on_step=True,on_epoch=True, sync_dist=True)
|
| 444 |
+
#print(out_dict) #This print statement is fixing it - ultimately fixed by setting 'n_step=True' above
|
| 445 |
+
#Below validation of the correctness of the above loging
|
| 446 |
+
self.control_sum_batch_sizes += len(batch['X'])
|
| 447 |
+
self.control_sum_recovery += len(batch['X'])*recovery
|
| 448 |
+
return out_dict
|
| 449 |
+
|
| 450 |
+
def predict_step(self, batch, batch_idx):
    """Forward the batch through the module in the mode given by ``self.hparams.stage``."""
    return self(batch, mode=self.hparams.stage)
|
| 453 |
+
|
| 454 |
+
def combined_flex_aware_loss(self, batch, pred_log_probs):
    """Compute the sequence loss and the flexibility loss for one batch.

    Reads ``batch['mask']``, ``batch['S']``, ``batch['gt_flex']`` and
    ``batch['enm_vals']``; in engineering test mode it also forwards
    ``batch['eng_mask']`` and ``batch['original_gt_flex']``.

    Args:
        batch: Batch dict with the keys listed above.
        pred_log_probs: Predicted per-residue log-probabilities; permuted
            with ``(0, 2, 1)`` before use, so presumably shaped
            (batch, length, classes) - TODO confirm.

    Returns:
        Dict with ``seq_loss``, ``flex_loss``, ``pred_flex``, ``flex_mask``
        and ``gt_flex``; in engineering test mode additionally ``eng_mask``,
        ``original_gt_flex``, ``gt_seq`` and ``pred_logprobs``.
    """
    residue_mask = batch['mask']
    target_seq = batch['S']
    target_flex = batch['gt_flex']
    enm_input = batch['enm_vals']  # TODO: manage the loading of the anm input

    # Index of the first 0 token per row; rows where that index is 0 get the
    # full padded length (intended for sequences without padding).
    # NOTE(review): a row whose very first token is 0 is treated the same way.
    trail_idcs = torch.argmax((target_seq == 0).int(), dim=1)
    trail_idcs[trail_idcs == 0] = target_seq.shape[1]

    # Ones up to each row's trail index, zeros afterwards.
    attention_mask = torch.zeros_like(residue_mask)
    for row, stop in zip(attention_mask, trail_idcs):
        row[:stop] = 1

    # Original sequence loss, averaged over masked-in positions.
    permuted_log_probs = pred_log_probs.permute(0, 2, 1)
    seq_loss = self.cross_entropy(permuted_log_probs, target_seq)
    seq_loss = (seq_loss * residue_mask).sum() / (residue_mask.sum())

    # New dynamics-aware loss: predict flexibility from the predicted sequence
    # distribution; the last position of the predictor output is dropped.
    pred_flex = self.flex_model(
        permuted_log_probs, enm_input, trail_idcs, attention_mask=attention_mask
    )['predicted_flex'][:, :-1, 0]

    # Only positions where neither prediction nor target is NaN contribute.
    finite = ~(torch.isnan(pred_flex) | torch.isnan(target_flex))
    flex_loss = self.flex_loss_fn(
        pred_flex[finite] * residue_mask[finite],
        target_flex[finite] * residue_mask[finite],
    )
    flex_mask = (residue_mask * finite).int()
    flex_loss = flex_loss.sum() / flex_mask.sum()

    retVal = {
        'seq_loss': seq_loss,
        'flex_loss': flex_loss,
        'pred_flex': pred_flex,
        'flex_mask': flex_mask,
        'gt_flex': target_flex,
    }
    if self.hparams.test_engineering and self.hparams.use_dynamics:
        retVal['eng_mask'] = batch['eng_mask']
        retVal['original_gt_flex'] = batch['original_gt_flex']
        retVal['gt_seq'] = batch['S']
        retVal['pred_logprobs'] = pred_log_probs
    return retVal
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
def configure_loss(self):
    """Attach ``self.loss_function``, a closure computing angle loss, sequence loss and perplexity."""

    def encode_angles(values):
        # Keep the first channel as is; channels 1-2 are replaced by their sine and cosine.
        periodic = values[..., 1:3]
        return torch.cat([values[..., :1], torch.sin(periodic), torch.cos(periodic)], dim=-1)

    def loss_function(pred_angle, angles, pred_seq, seqs, seq_loss_mask, angle_loss_mask):
        # Angle loss on the encoded angles, restricted to masked-in positions.
        angle_loss = self.MSE(encode_angles(angles), encode_angles(pred_angle))
        angle_loss = angle_loss[angle_loss_mask].sum(dim=-1).mean()

        # Cross-entropy on the permuted sequence predictions, masked mean.
        seq_loss = self.cross_entropy(pred_seq.permute(0, 2, 1), seqs)
        seq_loss = seq_loss[seq_loss_mask].mean()

        # Perplexity over the masked-in positions, evaluated on CPU.
        metric = Perplexity()
        metric.update(pred_seq[seq_loss_mask][None, ...].cpu(), seqs[seq_loss_mask][None, ...].cpu())
        perp = metric.compute()

        return {"angle_loss": angle_loss, "seq_loss": seq_loss, "perp": perp}

    self.loss_function = loss_function
|
| 560 |
+
|
| 561 |
+
def load_model(self):
    """Instantiate the backbone selected by ``self.hparams.model_name`` as ``self.model``.

    Parameters come from ``configs/<model_name>.yaml`` and are updated with
    ``self.hparams``. For ``'ESMIF'`` and for unrecognised names no model
    is created and ``self.model`` is left untouched.
    """
    model_name = self.hparams.model_name
    params = OmegaConf.load(f'configs/{model_name}.yaml')
    params.update(self.hparams)

    # Each backbone is imported lazily so only the selected one is loaded.
    model_cls = None
    if model_name == 'GraphTrans':
        from src.models.graphtrans_model import GraphTrans_Model as model_cls
    elif model_name == 'StructGNN':
        from src.models.structgnn_model import StructGNN_Model as model_cls
    elif model_name == 'GVP':
        from src.models.gvp_model import GVP_Model as model_cls
    elif model_name == 'GCA':
        from src.models.gca_model import GCA_Model as model_cls
    elif model_name == 'AlphaDesign':
        from src.models.alphadesign_model import AlphaDesign_Model as model_cls
    elif model_name == 'ProteinMPNN':
        from src.models.proteinmpnn_model import ProteinMPNN_Model as model_cls
    elif model_name == 'PiFold':
        from src.models.pifold_model import PiFold_Model as model_cls
    elif model_name == 'KWDesign':
        # KWDesign_model is used instead of Design_Model, whose constructor
        # would have needed significant changes.
        from src.models.kwdesign_model import KWDesign_model as model_cls
    elif model_name == 'E3PiFold':
        from src.models.E3PiFold_model import E3PiFold as model_cls

    if model_cls is not None:
        self.model = model_cls(params)
|
| 603 |
+
|
| 604 |
+
def load_flex_predictor(self):
    """Build the flexibility predictor as ``self.flex_model`` and freeze it.

    The predictor is configured from
    ``configs/ANMAwareFlexibilityProtTrans.yaml``, switched to eval mode,
    and gradients are disabled on all of its parameters.
    """
    from src.models.anm_prottrans import ANMAwareFlexibilityProtTrans

    config = load_yaml_config('configs/ANMAwareFlexibilityProtTrans.yaml')
    predictor = ANMAwareFlexibilityProtTrans(**config)

    # Consider turning the gradients on for debugging purposes.
    predictor.eval()
    for weight in predictor.parameters():
        weight.requires_grad = False

    self.flex_model = predictor
|
| 614 |
+
|
| 615 |
+
#also pass it to proteinmpnn:
|
| 616 |
+
# self.model.flex_model = self.flex_model
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
def instancialize(self, Model, **other_args):
    """Instantiate ``Model`` with matching parameters taken from ``self.hparams``.

    Every named ``__init__`` argument of ``Model`` (excluding ``self``) that
    is also a key of ``self.hparams`` is passed through; values given in
    ``other_args`` override the ones from ``self.hparams``.

    Args:
        Model: Class to instantiate.
        **other_args: Keyword arguments that take precedence over hparams.

    Returns:
        The constructed ``Model`` instance.
    """
    # inspect.getargspec was removed in Python 3.11; getfullargspec is the
    # drop-in replacement and also exposes keyword-only arguments.
    spec = inspect.getfullargspec(Model.__init__)
    class_args = spec.args[1:] + spec.kwonlyargs
    inkeys = self.hparams.keys()
    args1 = {arg: getattr(self.hparams, arg) for arg in class_args if arg in inkeys}
    args1.update(other_args)
    return Model(**args1)
|
Flexpert-Design/predict.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys, warnings, argparse, math, tqdm, datetime
|
| 2 |
+
import pytorch_lightning as pl
|
| 3 |
+
import torch
|
| 4 |
+
from pytorch_lightning.trainer import Trainer
|
| 5 |
+
import pytorch_lightning.callbacks as plc
|
| 6 |
+
import pytorch_lightning.loggers as plog
|
| 7 |
+
from model_interface import MInterface
|
| 8 |
+
from data_interface import DInterface
|
| 9 |
+
from src.tools.logger import SetupCallback, BackupCodeCallback
|
| 10 |
+
from shutil import ignore_patterns
|
| 11 |
+
from transformers import AutoTokenizer
|
| 12 |
+
import numpy as np
|
| 13 |
+
import yaml
|
| 14 |
+
import wandb
|
| 15 |
+
warnings.filterwarnings("ignore")
|
| 16 |
+
|
| 17 |
+
def create_parser():
    """Parse the command-line arguments of the inference script.

    Training and evaluation options (learning rate, epochs, loss settings,
    ...) are intentionally not exposed here.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    model_choices = ['StructGNN', 'GraphTrans', 'GVP', 'GCA', 'AlphaDesign', 'ESMIF', 'PiFold', 'ProteinMPNN', 'KWDesign', 'E3PiFold']

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = [
        ('--infer_path', dict(type=str, help='Path where to read the data to be predicted and where to save the predictions.')),

        # Set-up parameters
        ('--res_dir', dict(default='./train/results', type=str)),
        ('--ex_name', dict(default='debug', type=str)),
        ('--check_val_every_n_epoch', dict(default=1, type=int)),
        ('--stage', dict(default='predict', type=str)),  # 'fit', 'test' or 'predict'
        ('--val_check_interval', dict(default=0.5, type=float, help='Validation check interval')),

        ('--dataset', dict(default='PDBInference')),  # AF2DB_dataset, CATH_dataset
        ('--model_name', dict(default='ProteinMPNN', choices=model_choices)),
        ('--seed', dict(default=111, type=int)),

        ('--num_workers', dict(default=12, type=int)),
        ('--pad', dict(default=1024, type=int)),
        ('--min_length', dict(default=40, type=int)),
        ('--data_root', dict(default='./data/')),

        # Training parameters
        ('--augment_eps', dict(default=0.0, type=float, help='noise level')),

        # Model parameters
        ('--use_dist', dict(default=1, type=int)),
        ('--use_product', dict(default=0, type=int)),
        ('--use_pmpnn_checkpoint', dict(default=1, type=int, help='By 1 or 0 decide whether to start with pretrained ProteinMPNN.')),
        ('--checkpoint_path', dict(type=str, default=None, help='Path to the model checkpoint to load weights from')),

        # Dynamics aware parameters
        ('--use_dynamics', dict(default=0, type=int)),
        ('--init_flex_features', dict(default=1, type=int, help="Set to 0 if no flexibility information should be passed on input to the node features h_V")),
    ]

    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    return parser.parse_args()
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
if __name__ == "__main__":
    # Inference entry point: load the trained weights, predict sequences for
    # the inputs and write them as FASTA-like records to predictions.txt.

    args = create_parser()
    args.batch_size = 1
    print('In the predict stage, defaulting batch size to 1.')

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if args.use_dynamics != 0:
        raise ValueError("In the inference script this should be set to 0.")

    # --infer_path has no default; treat a missing value like an empty path so
    # the consistency check below reports it instead of a TypeError.
    infer_path = args.infer_path or ''
    has_infer_path = len(infer_path) > 0
    uses_pdb_inference = args.dataset == 'PDBInference'
    if has_infer_path != uses_pdb_inference:
        raise ValueError("You should only use --infer_path with --dataset 'PDBInference' and vice versa.")

    # Create the output directory only after the arguments were validated.
    if has_infer_path:
        os.makedirs(infer_path, exist_ok=True)

    # Load model weights from checkpoint if provided
    if args.checkpoint_path is not None:
        trained_model_path = args.checkpoint_path
        print(f"Loading model weights from checkpoint passed by argument: {trained_model_path}")
    else:
        with open('configs/Flexpert-Design-inference.yaml', 'r') as f:
            # safe_load builds plain Python objects only.
            config = yaml.safe_load(f)
        trained_model_path = config['pmpnn_model_path']
        print(f"Loading model weights from checkpoint specified in Flexpert-Design-inference.yaml: {trained_model_path}")

    if not os.path.exists(trained_model_path):
        raise FileNotFoundError(f"Checkpoint file not found at {trained_model_path}")
    print("Rewriting the path to the Flexpert-Design trained ProteinMPNN weights in the model interface.")
    args.starting_checkpoint_path = trained_model_path

    pl.seed_everything(args.seed)

    data_module = DInterface(**vars(args))
    data_module.setup(stage='predict')

    model = MInterface(**vars(args))

    trainer_config = {
        'devices': 1,
        'max_epochs': 1,
        'num_nodes': 1,
        "strategy": 'ddp',
        "precision": '32',
        'accelerator': 'gpu',
        'val_check_interval': args.val_check_interval,
        'check_val_every_n_epoch': args.check_val_every_n_epoch
    }
    trainer = Trainer(**trainer_config)

    predictions = trainer.predict(model, data_module)

    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir='./cache_dir/')  # mask token: 32

    serializable_predictions = []
    for pred in predictions:
        logprobs = pred['log_probs'].cpu().numpy()[0]  # [L, 21]
        token_ids = logprobs.argmax(axis=-1)  # [L]

        # decode() separates tokens with spaces; strip them to get one string.
        aa_sequence = ''.join(tokenizer.decode(token_ids, skip_special_tokens=True).split())

        serializable_predictions.append({
            'prediction_id': pred['batch']['title'][0],
            'amino_acid_sequence': aa_sequence
        })

    with open(os.path.join(infer_path, 'predictions.txt'), 'w') as f:
        for pred in serializable_predictions:
            f.write(f'>{pred["prediction_id"]}\n{pred["amino_acid_sequence"]}\n')
|
Flexpert-Design/predict_example/1ah7_A.pdb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Flexpert-Design/predict_example/1ah7_A_instructions.csv
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
0.4214228391647339, 0.37651416659355164, 0.1882496476173401, 0.13774731755256653, 0.11560429632663727, 0.12345632910728455, 0.11075370758771896, 0.09350624680519104, 0.06162628158926964, 0.08504123985767365, 0.05511573329567909, 0.03457929939031601, 0.018956221640110016, 0.05267956107854843, 0.021582268178462982, 0.019682325422763824, 0.005200381390750408, 0.01862833835184574, 0.037708550691604614, 0.02962341532111168, 0.0414130762219429, 0.032966580241918564, 0.04219468683004379, 0.043324653059244156, 0.038419052958488464, 0.06062019616365433, 0.0754077360033989, 0.09575366973876953, 0.09765047580003738, 0.1067374125123024, 0.08417803794145584, 0.09050130844116211, 0.07099245488643646, 0.06242087855935097, 0.046906158328056335, 0.024977944791316986, 0.04039282351732254, 0.04056069627404213, 0.02624698355793953, 0.014836937189102173, 0.033674903213977814, 0.03443623706698418, 0.04525991156697273, 0.05213414505124092, 0.02986733242869377, 0.01742064766585827, 0.03752005845308304, 0.02649688348174095, 0.02672465145587921, 0.03430021554231644, 0.011848143301904202, 0.03361964225769043, 0.027863629162311554, 0.03575271740555763, 0.041042227298021317, 0.08238421380519867, 0.08222152292728424, 0.10173829644918442, 0.11664807796478271, 0.13249793648719788, 0.14384658634662628, 0.1345130354166031, 0.13609081506729126, 0.12496572732925415, 0.10717709362506866, 0.08230947703123093, 0.07971317321062088, 0.07025592774152756, 0.06319792568683624, 0.06464961171150208, 0.04482023045420647, 0.051742203533649445, 0.07986844331026077, 0.09591078013181686, 0.11425718665122986, 0.11205209791660309, 0.10624780505895615, 0.10313349217176437, 0.13002970814704895, 0.13183605670928955, 0.15288424491882324, 0.14854931831359863, 0.13990001380443573, 0.09912189096212387, 0.09130637347698212, 0.07575594633817673, 0.061887726187705994, 0.06014912575483322, 0.0577777624130249, 0.051302842795848846, 0.03530939295887947, 0.040248580276966095, 0.0013590790331363678, 0.015310827642679214, 
0.03272499516606331, 0.02609187364578247, 0.0054176910780370235, 0.05427498370409012, 0.051064278930425644, 0.06116481125354767, 0.06309916824102402, 0.08470715582370758, 0.1002785935997963, 0.10495362430810928, 0.09807638078927994, 0.0662725567817688, 0.06513857841491699, 0.048988863825798035, 0.029838263988494873, 0.025865966454148293, 0.02097484841942787, 0.014891650527715683, 0.024081528186798096, 0.045654937624931335, 0.052093200385570526, 0.017663143575191498, 0.02189275622367859, 0.08543915301561356, 0.03505314886569977, 0.019413039088249207, 0.045589953660964966, 0.06793230772018433, 0.041016142815351486, 0.05003933981060982, 0.053235944360494614, 0.05916681885719299, 0.058036819100379944, 0.060588233172893524, 0.07040168344974518, 0.07345925271511078, 0.08298061043024063, 0.09419634193181992, 0.11146273463964462, 0.1405934989452362, 0.15075145661830902, 0.13765454292297363, 0.13978315889835358, 0.1482282280921936, 0.1423584669828415, 0.10484395921230316, 0.07584157586097717, 0.06757079809904099, 0.10134144872426987, 0.08083963394165039, 0.07369125634431839, 0.05454648658633232, 0.08305331319570541, 0.07765821367502213, 0.06511223316192627, 0.056034114211797714, 0.08081831783056259, 0.08526752144098282, 0.07231731712818146, 0.07028429210186005, 0.08094073832035065, 0.06563611328601837, 0.07806897908449173, 0.0859430730342865, 0.08600828796625137, 0.08605027943849564, 0.08578154444694519, 0.07862624526023865, 0.07963275909423828, 0.06170313060283661, 0.05005127564072609, 0.05146761238574982, 0.05499078333377838, 0.059220947325229645, 0.06969373673200607, 0.05268307402729988, 0.06721088290214539, 0.04827176779508591, 0.029251907020807266, 0.04153018817305565, 0.03697451949119568, 0.025836892426013947, 0.04521643742918968, 0.05554598197340965, 0.06007472425699234, 0.04923863708972931, 0.06502534449100494, 0.04392743483185768, 0.036296453326940536, 0.0436030775308609, 0.05658774450421333, 0.034551363438367844, 0.049478061497211456, 0.059964731335639954, 
0.07313965260982513, 0.062442418187856674, 0.06896451860666275, 0.08025174587965012, 0.08270157873630524, 0.09252781420946121, 0.09688305854797363, 0.11343701928853989, 0.11080081015825272, 0.12969090044498444, 0.10972093790769577, 0.14756494760513306, 0.1637764275074005, 0.18948377668857574, 0.18522633612155914, 0.18577809631824493, 0.20364972949028015, 0.17919236421585083, 0.1657918244600296, 0.1515847146511078, 0.11915461719036102, 0.10438721626996994, 0.10422837734222412, 0.08969437330961227, 0.07429874688386917, 0.07801567018032074, 0.06531910598278046, 0.05813637375831604, 0.04699350893497467, 0.05086237192153931, 0.060301560908555984, 0.04986414313316345, 0.050366174429655075, 0.05464963987469673, 0.05319518595933914, 0.04274186119437218, 0.047863349318504333, 0.036163944751024246, 0.03777360916137695, 0.04280579090118408, 0.04440606012940407, 0.04717888683080673, 0.02609632909297943, 0.05858827754855156, 0.050790246576070786, 0.03004802018404007, 0.04584358632564545, 0.05146845430135727, 0.039567168802022934, 0.03470978885889053, 0.045542243868112564, 0.05142106115818024, 0.05224252864718437, 0.07936417311429977, 0.11145134270191193, 0.14930342137813568, 0.21797531843185425
|
Flexpert-Design/predict_example/compare_seqs.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# read in the predictions.txt file
|
| 2 |
+
# take the sequence from there
|
| 3 |
+
import argparse
|
| 4 |
+
import os
|
| 5 |
+
import biotite.structure.io.pdb as pdb
|
| 6 |
+
from biotite.structure import get_residues
|
| 7 |
+
|
| 8 |
+
def compare_sequences(pdb_code):
    """Print how well the predicted sequence for ``pdb_code`` matches its PDB file.

    Reads ``predict_example/predictions.txt`` (a ``>name`` header line followed
    by a sequence line) and ``predict_example/{pdb_code}.pdb``, converts the
    residue names of the structure to one-letter codes and prints both
    sequences, their lengths, the number of position-wise matches and the
    percent identity computed over the shorter of the two sequences.

    Args:
        pdb_code: Header name in predictions.txt and stem of the PDB file
            name (e.g. ``1ah7_A``).
    """
    # Parse predictions.txt: a '>' line names the current entry and the most
    # recent non-empty line after it is kept as that entry's sequence.
    predicted_seqs = {}
    current_pdb = None
    with open('predict_example/predictions.txt', 'r') as f:
        for line in f:
            stripped = line.strip()
            if line.startswith('>'):
                current_pdb = stripped[1:]  # Remove the '>' character
            elif current_pdb and stripped:
                predicted_seqs[current_pdb] = stripped

    # An unknown pdb_code yields an empty predicted sequence.
    predicted_seq = predicted_seqs.get(pdb_code, "")

    # Read the PDB file
    pdb_file = f'predict_example/{pdb_code}.pdb'
    with open(pdb_file, 'r') as f:
        structure = pdb.PDBFile.read(f)
    atoms = pdb.get_structure(structure)

    # get_residues returns (ids, names); only the residue names are needed.
    residues = get_residues(atoms)[1]
    # Convert three-letter codes to one-letter codes ('X' for anything else)
    aa_dict = {
        'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
        'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
        'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
        'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y'
    }
    pdb_seq = ''.join(aa_dict.get(res, 'X') for res in residues)

    # Compare position by position; zip stops at the shorter sequence.
    match_count = sum(1 for a, b in zip(predicted_seq, pdb_seq) if a == b)
    compared_length = min(len(predicted_seq), len(pdb_seq))
    # Guard the division: an empty sequence (e.g. pdb_code missing from
    # predictions.txt) used to raise ZeroDivisionError here.
    percent_identity = (match_count / compared_length) * 100 if compared_length else 0.0

    # Print the result
    print(f"Predicted sequence: {predicted_seq}")
    print(f"PDB sequence: {pdb_seq}")
    print(f"Sequence length - Predicted: {len(predicted_seq)}, PDB: {len(pdb_seq)}")
    print(f"Matching residues: {match_count}/{compared_length}")
    print(f"Percent identity: {percent_identity:.2f}%")
|
| 53 |
+
|
| 54 |
+
if __name__ == "__main__":
    # Command-line entry point: compare the prediction for a single PDB code.
    parser = argparse.ArgumentParser(description='Compare predicted sequence with PDB sequence')
    # --pdb_code is mandatory: without it args.pdb_code would be None and the
    # comparison would try to open 'predict_example/None.pdb'.
    parser.add_argument('--pdb_code', type=str, required=True, help='PDB code (e.g., 1ah7_A)')
    args = parser.parse_args()

    compare_sequences(args.pdb_code)
|
Flexpert-Design/predict_example/predictions.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
>1ah7_A
|
| 2 |
+
GSSLDKTEVEESTGLRLVNQAIDILKNDKTRVDKEYLDLIEKYKPELQEGIYKAYHSEPYNDNGKFSRHYYNPVVHTSRIPDAVTAAETGSHYYNKAGEYYKKGDYEEAYFYLGIALAYLSDACNPMNASGYTNESFPEGFYEALQKYVCTIAKKYENTTGEPYYNLTGKNPKDHIRGAATKARELFSGIYHERVKEDFEKGKTSEEARLKWRERIEPQLGKLLLFAQRVMAGAIERFFDTAGGL
|
Flexpert-Design/requirements.txt
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Bio
|
| 2 |
+
biotite
|
| 3 |
+
fair-esm
|
| 4 |
+
evaluate
|
| 5 |
+
joblib
|
| 6 |
+
matplotlib
|
| 7 |
+
numpy
|
| 8 |
+
omegaconf
|
| 9 |
+
packaging
|
| 10 |
+
pandas
|
| 11 |
+
python_dateutil
|
| 12 |
+
pytorch_lightning
|
| 13 |
+
PyYAML
|
| 14 |
+
requests
|
| 15 |
+
safetensors
|
| 16 |
+
scikit_learn
|
| 17 |
+
scipy
|
| 18 |
+
torch
|
| 19 |
+
torch_geometric
|
| 20 |
+
torcheval
|
| 21 |
+
torchmetrics
|
| 22 |
+
tqdm
|
| 23 |
+
transformers
|
Flexpert-Design/src/__init__.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) CAIRI AI Lab. All rights reserved
|
| 2 |
+
|
| 3 |
+
import warnings
|
| 4 |
+
from packaging.version import parse
|
| 5 |
+
|
| 6 |
+
from .version import __version__
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def digit_version(version_str: str, length: int = 4):
    """Turn a version string into a tuple of integers suitable for comparison.

    The release numbers are cut or zero-padded to ``length`` entries and two
    more entries are appended that order pre-releases (alpha < beta < rc)
    before the final release and post-releases after it.

    Args:
        version_str (str): The version string.
        length (int): The maximum number of version levels. Default: 4.

    Returns:
        tuple[int]: The version info in digits (integers).
    """
    parsed = parse(version_str)
    assert parsed.release, f'failed to parse version {version_str}'

    # Keep at most `length` release numbers, padding with zeros when shorter.
    digits = list(parsed.release[:length])
    digits += [0] * (length - len(digits))

    if parsed.is_prerelease:
        pre_rank = {'a': -3, 'b': -2, 'rc': -1}
        pre = parsed.pre
        if not pre:
            # `pre` can be None even though is_prerelease is true.
            suffix = [-4, 0]
        else:
            tag = pre[0]
            if tag in pre_rank:
                rank = pre_rank[tag]
            else:
                warnings.warn(f'unknown prerelease version {tag}, '
                              'version checking may go wrong')
                rank = -4
            suffix = [rank, pre[-1]]
    elif parsed.is_postrelease:
        suffix = [1, parsed.post]
    else:
        suffix = [0, 0]

    return tuple(digits + suffix)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
__all__ = ['__version__', 'digit_version']
|
Flexpert-Design/src/datasets/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) CAIRI AI Lab. All rights reserved
|
| 2 |
+
|
| 3 |
+
from .alphafold_dataset import AlphaFoldDataset
|
| 4 |
+
from .cath_dataset import CATHDataset
|
| 5 |
+
from .dataloader import load_data
|
| 6 |
+
from .featurizer import (featurize_AF, featurize_GTrans, featurize_GVP,
|
| 7 |
+
featurize_ProteinMPNN, featurize_Inversefolding)
|
| 8 |
+
from .ts_dataset import TSDataset
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
'AlphaFoldDataset', 'CATHDataset', 'TSDataset',
|
| 12 |
+
'load_data',
|
| 13 |
+
'featurize_AF', 'featurize_GTrans', 'featurize_GVP',
|
| 14 |
+
'featurize_ProteinMPNN', 'featurize_Inversefolding'
|
| 15 |
+
]
|
Flexpert-Design/src/datasets/alphafold_dataset.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import os.path as osp
|
| 3 |
+
import json
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pickle as cPickle
|
| 6 |
+
|
| 7 |
+
import torch.utils.data as data
|
| 8 |
+
from src.datasets.utils import cached_property
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class AlphaFoldDataset(data.Dataset):
    """Map-style dataset over pickled AlphaFold structure records.

    Records are read from ``<path>/<upid>/data_<upid>.pkl``, get a ``'score'``
    entry from ``data_<upid>_score.pkl`` and are partitioned into
    train/valid/test by a JSON split file.

    Args:
        path: Root directory holding one sub-directory per data set.
        upid: Sub-directory name (and pickle file suffix) loaded when
            ``joint_data`` is false.
        mode: ``'train'``, ``'valid'``, ``'test'`` or ``'all'`` (the three
            splits concatenated).
        max_length: Upper sequence-length bound used by ``_data_info``.
        limit_length: If truthy, ``split_clu_l.json`` is used as the split
            file, otherwise ``split_clu.json``.
        joint_data: If truthy, merge every entry of ``path`` whose name
            contains ``'_v2'`` using the split stored under ``<path>/full``.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """

    def __init__(self, path='./', upid='', mode='train', max_length=500, limit_length=1, joint_data=0):

        self.path = path
        self.upid = upid
        self.max_length = max_length
        self.limit_length = limit_length
        self.joint_data = joint_data

        if mode in ('train', 'valid', 'test'):
            self.data = self.cache_data[mode]
        elif mode == 'all':
            splits = self.cache_data
            self.data = splits['train'] + splits['valid'] + splits['test']
        else:
            # An unknown mode used to leave self.data unset and fail below
            # with an AttributeError; report the real problem instead.
            raise ValueError("mode must be 'train', 'valid', 'test' or 'all', got {!r}".format(mode))

        self.lengths = np.array([len(sample['seq']) for sample in self.data])
        self.max_len = np.max(self.lengths)
        self.min_len = np.min(self.lengths)

    def _raw_data(self, path, upid):
        """Load the records of data set ``upid`` and attach their residue scores.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        if not os.path.exists(path):
            # Raising a bare string is itself a TypeError in Python 3.
            raise FileNotFoundError("no such file:{} !!!".format(path))

        path = osp.join(path, upid)
        # NOTE(review): pickle can execute arbitrary code while loading; only
        # point this at trusted files.
        with open(path+'/data_{}.pkl'.format(upid), 'rb') as f:
            data_ = cPickle.load(f)
        with open(path+'/data_{}_score.pkl'.format(upid), 'rb') as f:
            score_ = cPickle.load(f)
        # Records and scores are matched by position.
        for i, sample in enumerate(data_):
            sample['score'] = score_[i]['res_score']
        return data_

    def _data_info(self, data):
        """Return the indices kept by the length filter and a seq -> index map.

        With ``limit_length`` set only records with
        ``30 < len(seq) < max_length`` are kept; otherwise all are kept.
        """
        len_inds = []
        seq2ind = {}
        for ind, temp in enumerate(data):
            # Record fields: 'title', 'seq', 'CA', 'C', 'O', 'N'
            if self.limit_length and not (30 < len(temp['seq']) < self.max_length):
                continue
            len_inds.append(ind)
            seq2ind[temp['seq']] = ind
        return len_inds, seq2ind

    def get_data(self, path, upid, **kwargs):
        """Load data set ``upid`` and split it into train/valid/test lists."""
        data_ = self._raw_data(path, upid)
        path = osp.join(path, upid)

        file_name = 'split_clu_l.json' if self.limit_length else 'split_clu.json'

        split_path = osp.join(path, file_name)
        assert os.path.exists(split_path)
        with open(split_path, 'r') as f:
            split = json.load(f)
        # The split file stores, per fold, the indices into data_.
        return {fold: [data_[i] for i in split[fold]]
                for fold in ('train', 'valid', 'test')}

    def get_full_data(self, path, **kwargs):
        """Load the joint split file stored under ``<path>/full``."""
        file_name = 'split_clu_l.json' if self.limit_length else 'split_clu.json'
        split_path = osp.join(path, 'full', file_name)
        assert os.path.exists(split_path)
        with open(split_path, 'r') as f:
            return json.load(f)

    @cached_property
    def cache_data(self):  # TODO: joint_data
        """Dict with the ``'train'``, ``'valid'`` and ``'test'`` record lists.

        Every test record additionally gets a ``'category'`` entry naming the
        data set it came from.
        """
        path = self.path
        upid = self.upid
        if self.joint_data:
            datanames = [dataname for dataname in os.listdir(path) if ('_v2' in dataname)]
            data_dict = {'train': [], 'valid': [], 'test': []}
            full_inds = self.get_full_data(path)

            for dataname in datanames:
                temp = self._raw_data(path, dataname)
                fold_inds = full_inds[dataname]
                data_dict['train'] += [temp[i] for i in fold_inds['train']]
                data_dict['valid'] += [temp[i] for i in fold_inds['valid']]

                for i in fold_inds['test']:
                    item = temp[i]
                    item['category'] = dataname
                    data_dict['test'].append(item)
        else:
            data_dict = self.get_data(path, upid)
            for item in data_dict['test']:
                item['category'] = upid

        return data_dict

    def change_mode(self, mode):
        """Switch ``self.data`` to another cached split."""
        self.data = self.cache_data[mode]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]
|
Flexpert-Design/src/datasets/atlas_dataset.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import numpy as np
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import random
|
| 6 |
+
import pdb
|
| 7 |
+
import torch.utils.data as data
|
| 8 |
+
from .utils import cached_property
|
| 9 |
+
from transformers import AutoTokenizer
|
| 10 |
+
|
| 11 |
+
class AtlasDataset(data.Dataset):
    """Map-style dataset over protein chains stored as JSON lines.

    Chains are read from ``<path>/chain_set.jsonl`` and assigned to
    train/valid/test according to ``<path>/chain_set_splits.json``. Chains
    containing characters outside the 20-letter amino-acid alphabet or longer
    than ``max_length`` are dropped while loading.

    Args:
        path: Directory containing the JSONL file and the split files.
        split: ``'train'``, ``'valid'`` or ``'test'``; ``'eval'`` and
            ``'predict'`` select the test split.
        max_length: Maximum chain length kept / returned.
        test_name: ``'L100'`` or ``'sc'`` replace the test split by the one in
            ``test_split_L100.json`` / ``test_split_sc.json``.
        data: Optional pre-built list of samples; when given nothing is read
            from ``path`` during construction.
        removeTS: If truthy, chains named in ``<path>/remove.json`` are skipped.
    """

    def __init__(self, path='./', split='train', max_length=500, test_name='All', data = None, removeTS=0):
        self.path = path
        self.mode = split
        self.max_length = max_length
        self.test_name = test_name
        self.removeTS = removeTS
        if self.removeTS:
            with open(self.path+'/remove.json', 'r') as f:
                self.remove = json.load(f)['remove']

        if data is None:
            if self.mode in ['eval', 'predict']:
                self.data = self.cache_data['test']  # This calls the cache_data property
            else:
                self.data = self.cache_data[split]  # This calls the cache_data property
        else:
            self.data = data

        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="./cache_dir/")

    @cached_property
    def cache_data(self):
        """Dict with the ``'train'``, ``'valid'`` and ``'test'`` sample lists.

        Raises:
            FileNotFoundError: If ``self.path`` does not exist.
        """
        alphabet = 'ACDEFGHIKLMNPQRSTVWY'
        alphabet_set = set(alphabet)
        print("path is: ", self.path)

        if not os.path.exists(self.path):
            # Raising a bare string is itself a TypeError in Python 3.
            raise FileNotFoundError("no such file:{} !!!".format(self.path))

        with open(self.path+'/chain_set.jsonl') as f:
            lines = f.readlines()
        data_list = []

        for line in tqdm(lines):
            entry = json.loads(line)

            if self.removeTS and entry['name'] in self.remove:
                continue
            seq = entry['seq']

            for key, val in entry['coords'].items():
                entry['coords'][key] = np.asarray(val)

            # Keep only chains made of the 20 standard letters that fit max_length.
            if set(seq) - alphabet_set or len(seq) > self.max_length:
                continue

            chain_mask = np.ones(len(seq))
            data_list.append({
                'title': entry['name'],
                'seq': entry['seq'],
                'CA': entry['coords']['CA'],
                'C': entry['coords']['C'],
                'O': entry['coords']['O'],
                'N': entry['coords']['N'],
                'chain_mask': chain_mask,
                'chain_encoding': 1*chain_mask
            })

        with open(self.path+'/chain_set_splits.json') as f:
            dataset_splits = json.load(f)

        if self.test_name == 'L100':
            with open(self.path+'/test_split_L100.json') as f:
                test_splits = json.load(f)
            dataset_splits['test'] = test_splits['test']

        if self.test_name == 'sc':
            with open(self.path+'/test_split_sc.json') as f:
                test_splits = json.load(f)
            dataset_splits['test'] = test_splits['test']

        name2set = {}
        name2set.update({name: 'train' for name in dataset_splits['train']})
        name2set.update({name: 'valid' for name in dataset_splits['validation']})
        name2set.update({name: 'test' for name in dataset_splits['test']})

        data_dict = {'train': [], 'valid': [], 'test': []}
        skipped = 0
        for sample in data_list:
            fold = name2set.get(sample['title'])
            if fold is None:
                # Name mismatch between chain_set.jsonl and the split file.
                # This used to drop into pdb.set_trace(), which hangs any
                # non-interactive run; count the chain and move on instead.
                skipped += 1
                continue
            if fold == 'test':
                sample['category'] = 'Unkown'
                sample['score'] = 100.0
            data_dict[fold].append(sample)
        if skipped:
            print("skipped {} chains that are not listed in any split".format(skipped))
        return data_dict

    def change_mode(self, mode):
        """Switch ``self.data`` to another cached split."""
        self.data = self.cache_data[mode]

    def __len__(self):
        return len(self.data)

    def get_item(self, index):
        """Return the stored sample at ``index`` without cropping."""
        return self.data[index]

    def __getitem__(self, index):
        """Return sample ``index``, randomly cropped to ``max_length`` residues."""
        item = self.data[index]
        L = len(item['seq'])
        if L > self.max_length:
            # Crop a shallow copy: writing the crop back into the stored
            # sample would permanently shorten it after the first access.
            item = dict(item)
            # Largest admissible start index of the crop window.
            max_index = L - self.max_length
            # Draw a random start index and slice every per-residue field.
            truncate_index = random.randint(0, max_index)
            window = slice(truncate_index, truncate_index+self.max_length)
            for key in ('seq', 'CA', 'C', 'O', 'N', 'chain_mask', 'chain_encoding'):
                item[key] = item[key][window]
        return item
|
Flexpert-Design/src/datasets/casp_dataset.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch.utils.data as data
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class CASPDataset(data.Dataset):
    """Map-style dataset over the chains in ``<path>/casp15.jsonl``.

    Each JSON line must provide ``name``, ``seq``, ``coords`` (with ``CA``,
    ``C``, ``O`` and ``N``) and ``classification``. Chains whose sequence
    contains a character outside the 20-letter amino-acid alphabet are
    skipped.

    Args:
        path: Directory containing ``casp15.jsonl``.
        split: Accepted but not used by this class.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """

    def __init__(self, path = './', split='test'):
        if not os.path.exists(path):
            # Raising a bare string is itself a TypeError in Python 3.
            raise FileNotFoundError("no such file:{} !!!".format(path))

        with open(os.path.join(path, 'casp15.jsonl')) as f:
            lines = f.readlines()

        alphabet_set = set('ACDEFGHIKLMNPQRSTVWY')

        self.data = []
        for line in lines:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline) in the file.
                continue
            entry = json.loads(line)
            seq = entry['seq']

            # Coordinates are stored as nested lists; expose them as arrays.
            coords = {key: np.asarray(val) for key, val in entry['coords'].items()}

            if set(seq) - alphabet_set:
                continue

            chain_mask = np.ones(len(seq))
            self.data.append({
                'title': entry['name'],
                'seq': seq,
                'CA': coords['CA'],
                'C': coords['C'],
                'O': coords['O'],
                'N': coords['N'],
                'chain_mask': chain_mask,
                'chain_encoding': 1*chain_mask,
                'classification': entry['classification']
            })

    def __len__(self):
        return len(self.data)

    def get_item(self, index):
        """Return the stored sample at ``index``."""
        return self.data[index]

    def __getitem__(self, index):
        return self.data[index]
|
| 53 |
+
|
| 54 |
+
if __name__ == '__main__':
    # Smoke test: print every sample of a CASP15 directory. The directory can
    # be given as the first command-line argument; the historical default is
    # kept for backward compatibility.
    import sys

    casp_dir = sys.argv[1] if len(sys.argv) > 1 else '/gaozhangyang/experiments/OpenCPD/data/casp15/'
    dataset = CASPDataset(casp_dir)
    # Not named `data`: that would rebind the module alias for torch.utils.data.
    for sample in dataset:
        print(sample)
|
Flexpert-Design/src/datasets/cath_dataset.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import numpy as np
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import random
|
| 6 |
+
import torch.utils.data as data
|
| 7 |
+
from .utils import cached_property
|
| 8 |
+
from transformers import AutoTokenizer
|
| 9 |
+
from src.tools.utils import load_yaml_config
|
| 10 |
+
|
| 11 |
+
class CATHDataset(data.Dataset):
    """Map-style dataset over CATH chains stored as JSON lines.

    Chains are read from ``<path>/<data_jsonl_name>`` and assigned to
    train/valid/test according to ``<path>/chain_set_splits.json``. Chains
    containing characters outside the 20-letter amino-acid alphabet or longer
    than ``max_length`` are dropped while loading.

    Args:
        path: Directory containing the JSONL file and the split files.
        split: ``'train'``, ``'valid'`` or ``'test'``; ``'predict'`` currently
            selects the validation split.
        max_length: Maximum chain length kept / returned.
        test_name: ``'L100'`` or ``'sc'`` replace the test split by the one in
            ``test_split_L100.json`` / ``test_split_sc.json``.
        data: Optional pre-built list of samples; when given nothing is read
            from ``path`` during construction.
        removeTS: If truthy, chains named in ``<path>/remove.json`` are skipped.
        version: CATH version, ``4.2`` or ``4.3``.
        data_jsonl_name: Name of the JSONL file below ``path``.
        dynamics_config_path: YAML config whose ``data_jsonl_name`` entry
            identifies the JSONL file that carries B-factors. The default is
            the historical, machine-specific location; pass a local path to
            run elsewhere.
    """

    def __init__(self, path='./', split='train', max_length=500, test_name='All', data = None, removeTS=0, version=4.2, data_jsonl_name='/chain_set.jsonl',
                 dynamics_config_path='/scratch/project/fta-24-31/koubapet/ProteinInvBench/src/models/configs/FlexibilityProtTrans.yaml'):
        self.version = version
        self.path = path
        self.mode = split
        self.max_length = max_length
        self.test_name = test_name
        self.removeTS = removeTS
        self.data_jsonl_name = data_jsonl_name

        # B-factors are attached only when the requested JSONL file is the one
        # named in the flexibility config.
        self.using_dynamics = data_jsonl_name == load_yaml_config(dynamics_config_path)['data_jsonl_name']

        print(self.data_jsonl_name)
        if self.removeTS:
            with open(self.path+'/remove.json', 'r') as f:
                self.remove = json.load(f)['remove']

        if data is None:
            if split == 'predict':
                _split = 'valid'
                print('In predict mode for CATH4.3 using VALIDATION split as the data. Consider switching to TEST set.')
            else:
                _split = split
            self.data = self.cache_data[_split]
        else:
            self.data = data

        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="./cache_dir/")

    @cached_property
    def cache_data(self):
        """Dict with the ``'train'``, ``'valid'`` and ``'test'`` sample lists.

        Raises:
            FileNotFoundError: If ``self.path`` does not exist.
            ValueError: If ``self.version`` is neither 4.2 nor 4.3.
        """
        alphabet = 'ACDEFGHIKLMNPQRSTVWY'
        alphabet_set = set(alphabet)
        print("path is: ", self.path)
        if not os.path.exists(self.path):
            # Raising a bare string is itself a TypeError in Python 3.
            raise FileNotFoundError("no such file:{} !!!".format(self.path))
        if self.version not in (4.2, 4.3):
            # Any other value used to end in an UnboundLocalError further
            # down, after the whole JSONL file had already been parsed.
            raise ValueError("unsupported CATH version: {!r}".format(self.version))

        with open(self.path+'/'+self.data_jsonl_name) as f:
            lines = f.readlines()
        data_list = []
        for line in tqdm(lines):
            entry = json.loads(line)
            if self.removeTS and entry['name'] in self.remove:
                continue
            seq = entry['seq']

            for key, val in entry['coords'].items():
                entry['coords'][key] = np.asarray(val)

            # Keep only chains made of the 20 standard letters that fit max_length.
            if set(seq) - alphabet_set or len(seq) > self.max_length:
                continue

            chain_mask = np.ones(len(seq))
            sample = {
                'title': entry['name'],
                'seq': entry['seq'],
                'CA': entry['coords']['CA'],
                'C': entry['coords']['C'],
                'O': entry['coords']['O'],
                'N': entry['coords']['N'],
                'chain_mask': chain_mask,
                'chain_encoding': 1*chain_mask
            }
            if self.using_dynamics:  # TODO: pass this bool properly
                sample['norm_bfactors'] = entry['bfactor']
            data_list.append(sample)

        # Both supported versions read their splits from the same file name.
        with open(self.path+'/chain_set_splits.json') as f:
            dataset_splits = json.load(f)

        if self.test_name == 'L100':
            with open(self.path+'/test_split_L100.json') as f:
                test_splits = json.load(f)
            dataset_splits['test'] = test_splits['test']

        if self.test_name == 'sc':
            with open(self.path+'/test_split_sc.json') as f:
                test_splits = json.load(f)
            dataset_splits['test'] = test_splits['test']

        name2set = {}
        name2set.update({name: 'train' for name in dataset_splits['train']})
        name2set.update({name: 'valid' for name in dataset_splits['validation']})
        name2set.update({name: 'test' for name in dataset_splits['test']})

        data_dict = {'train': [], 'valid': [], 'test': []}
        for sample in data_list:
            fold = name2set.get(sample['title'])
            if fold is None:
                # Chains that appear in no split are left out.
                continue
            if fold == 'test':
                sample['category'] = 'Unkown'
                sample['score'] = 100.0
            data_dict[fold].append(sample)
        return data_dict

    def change_mode(self, mode):
        """Switch ``self.data`` to another cached split."""
        self.data = self.cache_data[mode]

    def __len__(self):
        return len(self.data)

    def get_item(self, index):
        """Return the stored sample at ``index`` without cropping."""
        return self.data[index]

    def __getitem__(self, index):
        """Return sample ``index``, randomly cropped to ``max_length`` residues."""
        item = self.data[index]
        L = len(item['seq'])
        if L > self.max_length:
            # Crop a shallow copy: writing the crop back into the stored
            # sample would permanently shorten it after the first access.
            item = dict(item)
            # Largest admissible start index of the crop window.
            max_index = L - self.max_length
            # Draw a random start index and slice every per-residue field.
            truncate_index = random.randint(0, max_index)
            window = slice(truncate_index, truncate_index+self.max_length)
            # NOTE(review): 'norm_bfactors' (present when using_dynamics) is
            # not cropped here; presumably it is per-residue too - confirm.
            for key in ('seq', 'CA', 'C', 'O', 'N', 'chain_mask', 'chain_encoding'):
                item[key] = item[key][window]
        return item
|
Flexpert-Design/src/datasets/dataloader.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import random
|
| 3 |
+
import os.path as osp
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.utils.data as data
|
| 7 |
+
import pdb
|
| 8 |
+
|
| 9 |
+
from .cath_dataset import CATHDataset
|
| 10 |
+
from .alphafold_dataset import AlphaFoldDataset
|
| 11 |
+
from .ts_dataset import TSDataset
|
| 12 |
+
from .casp_dataset import CASPDataset
|
| 13 |
+
from .mpnn_dataset import MPNNDataset
|
| 14 |
+
from .featurizer import (featurize_AF, featurize_GTrans, featurize_GVP,
|
| 15 |
+
featurize_ProteinMPNN, featurize_Inversefolding)
|
| 16 |
+
from .fast_dataloader import DataLoaderX
|
| 17 |
+
|
| 18 |
+
class GTransDataLoader(torch.utils.data.DataLoader):
    """``DataLoader`` that also exposes its collate function as ``featurizer``."""

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0,
                 collate_fn=None, **kwargs):
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=collate_fn,
            **kwargs,
        )
        # Keep a handle on the collate function under the name callers use.
        self.featurizer = collate_fn
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class BatchSampler(data.Sampler):
    '''
    From https://github.com/jingraham/neurips19-graph-protein-design.

    A `torch.utils.data.Sampler` which samples batches according to a
    maximum number of graph nodes.

    :param node_counts: array of node counts in the dataset to sample from
    :param max_nodes: the maximum number of nodes in any batch,
                      including batches of a single element
    :param shuffle: if `True`, the sample indices are shuffled before being
                    grouped into batches

    Samples with more than `max_nodes` nodes are dropped. Batches are formed
    once at construction and re-formed only if the batch list is empty.
    '''
    def __init__(self, node_counts, max_nodes=3000, shuffle=True):
        self.node_counts = node_counts
        # Samples that could never fit into a batch are excluded up front.
        self.idx = [i for i in range(len(node_counts))
                    if node_counts[i] <= max_nodes]
        self.shuffle = shuffle
        self.max_nodes = max_nodes
        self._form_batches()

    def _form_batches(self):
        '''Greedily pack `self.idx` (in order) into batches of at most `max_nodes` nodes.'''
        self.batches = []
        if self.shuffle:
            random.shuffle(self.idx)

        # Single pass over the indices; avoids re-slicing the remaining list
        # for every element, which made batching quadratic.
        batch, n_nodes = [], 0
        for i in self.idx:
            count = self.node_counts[i]
            if batch and n_nodes + count > self.max_nodes:
                self.batches.append(batch)
                batch, n_nodes = [], 0
            batch.append(i)
            n_nodes += count
        if batch:
            self.batches.append(batch)

    def __len__(self):
        if not self.batches:
            self._form_batches()
        return len(self.batches)

    def __iter__(self):
        if not self.batches:
            self._form_batches()
        for batch in self.batches:
            yield batch
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class GVPDataLoader(torch.utils.data.DataLoader):
    """``DataLoader`` that batches by node budget and collates with a featurizer.

    Batches are produced by a ``BatchSampler`` built from the length of each
    sample's ``'seq'`` entry, capped at ``max_nodes`` per batch, and collated
    with ``featurizer.collate``.
    """

    def __init__(self, dataset, num_workers=0,
                 featurizer=None, max_nodes=3000, **kwargs):
        # The whole dataset is iterated once here to collect sequence lengths.
        seq_lengths = [len(sample['seq']) for sample in dataset]
        node_budget_sampler = BatchSampler(node_counts=seq_lengths, max_nodes=max_nodes)
        super().__init__(
            dataset,
            batch_sampler=node_budget_sampler,
            num_workers=num_workers,
            collate_fn=featurizer.collate,
            **kwargs,
        )
        self.featurizer = featurizer
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def load_data(data_name, method, batch_size, data_root, pdb_path, split_csv, max_nodes=3000, num_workers=8, removeTS=0, test_casp=False, **kwargs):
    """Build the train / validation / test loaders for a named dataset.

    :param data_name: one of 'CATH4.2', 'TS', 'CATH4.3', 'AlphaFold', 'MPNN',
                      'S350', 'Protherm'
    :param method: model name; selects the collate function. 'GVP',
                   'ProteinMPNN' and 'ESMIF' override the dataset default,
                   any other value keeps it
    :param batch_size: batch size used by all three loaders
    :param data_root: directory containing the per-dataset sub-directories
    :param pdb_path: not used by this function
    :param split_csv: not used by this function
    :param max_nodes: not used by this function
    :param num_workers: worker processes per loader
    :param removeTS: forwarded to `CATHDataset`
    :param test_casp: if true, the test set is replaced by the CASP15 set
    :param kwargs: for 'AlphaFold', the optional keys `upid`, `limit_length`
                   and `joint_data` are forwarded to `AlphaFoldDataset`;
                   other keys are ignored
    :return: `(train_loader, valid_loader, test_loader)`
    :raises ValueError: if `data_name` is not one of the supported names
    """

    def _split(base_set, valid_mode='valid', test_mode='test'):
        # Three shallow copies of one loaded dataset, each switched to its split.
        train, valid, test = (copy.copy(base_set) for _ in range(3))
        valid.change_mode(valid_mode)
        test.change_mode(test_mode)
        return train, valid, test

    if data_name == 'CATH4.2' or data_name == 'TS':
        cath_set = CATHDataset(osp.join(data_root, 'cath4.2'), mode='train', test_name='All', removeTS=removeTS)
        train_set, valid_set, test_set = _split(cath_set)
        if data_name == 'TS':
            test_set = TSDataset(osp.join(data_root, 'ts'))
        collate_fn = featurize_GTrans
    elif data_name == 'CATH4.3':
        cath_set = CATHDataset(osp.join(data_root, 'cath4.3'), mode='train', test_name='All', removeTS=removeTS, version=4.3)
        train_set, valid_set, test_set = _split(cath_set)
        collate_fn = featurize_GTrans
    elif data_name == 'AlphaFold':
        # These options were previously read from undefined names; take them
        # from **kwargs and let the dataset's own defaults apply when absent.
        af_options = {key: kwargs[key]
                      for key in ('upid', 'limit_length', 'joint_data')
                      if key in kwargs}
        af_set = AlphaFoldDataset(osp.join(data_root, 'af2db'), mode='train', **af_options)
        train_set, valid_set, test_set = _split(af_set)
        collate_fn = featurize_AF
    elif data_name == 'MPNN':
        train_set = MPNNDataset(mode='train')
        valid_set = MPNNDataset(mode='valid')
        test_set = MPNNDataset(mode='test')
        collate_fn = featurize_GTrans
    elif data_name == 'S350':
        cath_set = CATHDataset(osp.join(data_root, 's350'), mode='train', test_name='All', removeTS=removeTS, version=4.3)
        # All three splits deliberately stay in 'train' mode for this dataset.
        train_set, valid_set, test_set = _split(cath_set, valid_mode='train', test_mode='train')
        collate_fn = featurize_GTrans
    elif data_name == 'Protherm':
        cath_set = CATHDataset(osp.join(data_root, 'protherm'), mode='train', test_name='All', removeTS=removeTS, version=4.3)
        train_set, valid_set, test_set = _split(cath_set)
        collate_fn = featurize_GTrans
    else:
        raise ValueError('Unsupported data_name: {!r}'.format(data_name))

    if test_casp:
        test_set = CASPDataset(osp.join(data_root, 'casp15'))

    # Method-specific collate functions; every other method keeps the
    # dataset default chosen above.
    if method == 'GVP':
        featurizer = featurize_GVP()
        collate_fn = featurizer.collate
    elif method == 'ProteinMPNN':
        collate_fn = featurize_ProteinMPNN
    elif method == 'ESMIF':
        collate_fn = featurize_Inversefolding

    loader_kwargs = dict(local_rank=0, batch_size=batch_size,
                         num_workers=num_workers, collate_fn=collate_fn)
    # DataLoader only accepts a custom prefetch_factor with worker processes.
    if num_workers > 0:
        loader_kwargs['prefetch_factor'] = 8

    train_loader = DataLoaderX(dataset=train_set, shuffle=True, **loader_kwargs)
    valid_loader = DataLoaderX(dataset=valid_set, shuffle=False, **loader_kwargs)
    test_loader = DataLoaderX(dataset=test_set, shuffle=False, **loader_kwargs)

    return train_loader, valid_loader, test_loader
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def make_cath_loader(test_set, method, batch_size, max_nodes=3000, num_workers=8):
    """Build a non-shuffling test loader for `test_set` suited to `method`.

    :param test_set: dataset to wrap
    :param method: lower-case model name: 'pifold', 'adesign', 'graphtrans',
                   'structgnn', 'gca', 'gvp', 'proteinmpnn' or 'esmif'
    :param batch_size: batch size (not used for 'gvp', which batches by
                       `max_nodes` instead)
    :param max_nodes: node budget per batch, used only for 'gvp'
    :param num_workers: worker processes for the loader
    :return: a `GVPDataLoader` for 'gvp', otherwise a `GTransDataLoader`
    :raises ValueError: if `method` is not one of the supported names
    """
    if method == 'gvp':
        featurizer = featurize_GVP()
        return GVPDataLoader(test_set, num_workers=num_workers, featurizer=featurizer, max_nodes=max_nodes)

    if method in ['pifold', 'adesign', 'graphtrans', 'structgnn', 'gca']:
        collate_fn = featurize_GTrans
    elif method == 'proteinmpnn':
        collate_fn = featurize_ProteinMPNN
    elif method == 'esmif':
        collate_fn = featurize_Inversefolding
    else:
        # Previously fell through to an unbound `test_loader`.
        raise ValueError('Unsupported method: {!r}'.format(method))

    return GTransDataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn)
|
Flexpert-Design/src/datasets/fast_dataloader.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import threading
|
| 2 |
+
import torch
|
| 3 |
+
import queue
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class DataLoaderX(DataLoader):
    """`DataLoader` that prefetches the next batch onto a CUDA device.

    While the caller processes one batch, the next one is fetched and its
    tensor values are copied to device `local_rank` on a side CUDA stream.
    Batches for which the collate function returned `None` are skipped.

    Batches are expected to be dict-like (`.items()` and item assignment);
    only values that are `torch.Tensor` instances are moved.

    When CUDA is not available no stream is created and batches are yielded
    without being moved.

    :param local_rank: CUDA device index used for the stream and the copies
    :param kwargs: forwarded unchanged to `torch.utils.data.DataLoader`
    """

    def __init__(self, local_rank, **kwargs):
        super().__init__(**kwargs)
        self.local_rank = local_rank
        # One side stream per loader so host-to-device copies can overlap compute.
        self.stream = torch.cuda.Stream(local_rank) if torch.cuda.is_available() else None
        self.batch = None

    def __iter__(self):
        self.iter = super().__iter__()
        self.preload()
        return self

    def preload(self):
        """Fetch the next non-`None` batch into `self.batch` and start its copy.

        `self.batch` is set to `None` once the underlying iterator is exhausted.
        """
        # A private sentinel tells "iterator exhausted" apart from "collate
        # returned None"; relying on the iterator's private counters is not
        # possible for single-process iterators and is unreliable while
        # batches are still in flight.
        exhausted = object()
        while True:
            candidate = next(self.iter, exhausted)
            if candidate is exhausted:
                self.batch = None
                return None
            if candidate is not None:
                break
        self.batch = candidate

        if self.stream is None:
            return None

        with torch.cuda.stream(self.stream):  # stage the data on the GPU ahead of use
            for key, val in self.batch.items():
                if isinstance(val, torch.Tensor):
                    self.batch[key] = val.to(
                        device=self.local_rank, non_blocking=True
                    )

    def __next__(self):
        if self.stream is not None:
            # Make sure the asynchronous copies have landed before handing out the batch.
            torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        if batch is None:
            raise StopIteration
        self.preload()
        return batch
|
| 52 |
+
|
Flexpert-Design/src/datasets/featurizer.py
ADDED
|
@@ -0,0 +1,743 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import itertools
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import math
|
| 6 |
+
import torch_geometric
|
| 7 |
+
# import torch_cluster
|
| 8 |
+
from collections.abc import Mapping, Sequence
|
| 9 |
+
from torch_geometric.data import Data, Batch
|
| 10 |
+
from torch.utils.data.dataloader import default_collate
|
| 11 |
+
from transformers import AutoTokenizer
|
| 12 |
+
import pdb
|
| 13 |
+
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir='./cache_dir/') # mask token: 32
|
| 14 |
+
|
| 15 |
+
def _normalize(tensor, dim=-1):
    '''
    Scale `tensor` to unit norm along `dim`; `nan`s produced by the division
    (e.g. for zero vectors) are replaced via `torch.nan_to_num`.
    '''
    length = torch.norm(tensor, dim=dim, keepdim=True)
    unit = tensor / length
    return torch.nan_to_num(unit)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _rbf(D, D_min=0., D_max=20., D_count=16, device='cpu'):
    '''
    From https://github.com/jingraham/neurips19-graph-protein-design

    Gaussian radial-basis embedding of `torch.Tensor` `D` along a new last
    axis: for `D` of shape [...dims] the result has shape [...dims, D_count].
    Centers are `D_count` evenly spaced values in [D_min, D_max]; the width
    is (D_max - D_min) / D_count.
    '''
    centers = torch.linspace(D_min, D_max, D_count, device=device).view([1, -1])
    width = (D_max - D_min) / D_count
    # Distance of every value to every center, in units of the width.
    scaled = (torch.unsqueeze(D, -1) - centers) / width
    return torch.exp(-(scaled ** 2))
|
| 38 |
+
|
| 39 |
+
def shuffle_subset(n, p):
    """Return a permutation of ``range(n)`` that shuffles a random subset.

    The subset size is drawn from Binomial(n, p); the chosen positions are
    permuted among themselves and every other position maps to itself.
    """
    count = np.random.binomial(n, p)
    perm = np.arange(n)
    chosen = np.random.choice(perm, size=count, replace=False)
    # Shuffle a copy so `chosen` still names the slots being overwritten.
    targets = chosen.copy()
    np.random.shuffle(targets)
    perm[chosen] = targets
    return perm
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def featurize_AF(batch, shuffle_fraction=0.):
    """Pack and pad a batch of samples into torch tensors.

    Each sample is a mapping with a ``'seq'`` string over the 20-letter
    alphabet ``ACDEFGHIKLMNPQRSTVWY``, per-residue backbone coordinates under
    ``'N'``, ``'CA'``, ``'C'``, ``'O'`` and a per-residue ``'score'``.
    Residues with any non-finite backbone coordinate are dropped and the
    remaining residues are packed to the front of each row.

    :param batch: non-empty list of samples
    :param shuffle_fraction: if > 0, labels and scores of a random subset of
        positions (see ``shuffle_subset``) are permuted; ``'score'`` must then
        support NumPy fancy indexing
    :return: ``(X, S, score, mask, lengths)`` where ``X`` is a float32 tensor
        [B, L_max, 4, 3] (zero-padded), ``S`` a long tensor [B, L_max],
        ``score`` a float tensor [B, L_max], ``mask`` a float32 tensor
        [B, L_max] (1 for kept residues) and ``lengths`` an int32 NumPy array
        of the original sequence lengths
    :raises ValueError: if ``batch`` is empty or a residue letter is not in
        the alphabet
    """
    alphabet = 'ACDEFGHIKLMNPQRSTVWY'
    B = len(batch)
    lengths = np.array([len(b['seq']) for b in batch], dtype=np.int32)
    L_max = max([len(b['seq']) for b in batch])
    X = np.zeros([B, L_max, 4, 3])
    S = np.zeros([B, L_max], dtype=np.int32)
    score = np.zeros([B, L_max])

    # Build the batch
    for i, b in enumerate(batch):
        x = np.stack([b[c] for c in ['N', 'CA', 'C', 'O']], 1)  # [#atom, 4, 3]

        l = len(b['seq'])
        # Pad with NaN so padded positions are caught by the finiteness mask below.
        X[i, :, :, :] = np.pad(x, [[0, L_max - l], [0, 0], [0, 0]], 'constant', constant_values=(np.nan, ))

        # Convert to labels
        indices = np.asarray([alphabet.index(a) for a in b['seq']], dtype=np.int32)
        if shuffle_fraction > 0.:
            idx_shuffle = shuffle_subset(l, shuffle_fraction)
            S[i, :l] = indices[idx_shuffle]
            score[i, :l] = b['score'][idx_shuffle]
        else:
            S[i, :l] = indices
            score[i, :l] = b['score']

    mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32)  # atom mask
    # np.int32 rather than the np.int alias, which NumPy 1.24 removed.
    numbers = np.sum(mask, axis=1).astype(np.int32)

    # Pack the residues with finite coordinates to the front of every row.
    S_new = np.zeros_like(S)
    score_new = np.zeros_like(score)
    X_new = np.zeros_like(X) + np.nan
    for i, n in enumerate(numbers):
        keep = mask[i] == 1
        X_new[i, :n, ::] = X[i][keep]
        S_new[i, :n] = S[i][keep]
        score_new[i, :n] = score[i][keep]

    X = X_new
    S = S_new
    score = score_new
    isnan = np.isnan(X)
    mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32)
    X[isnan] = 0.

    # Conversion
    S = torch.from_numpy(S).to(dtype=torch.long)
    score = torch.from_numpy(score).float()
    X = torch.from_numpy(X).to(dtype=torch.float32)
    mask = torch.from_numpy(mask).to(dtype=torch.float32)
    return X, S, score, mask, lengths
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def featurize_GTrans(batch):
    """Pack and pad a batch of samples into a dict of torch tensors.

    ``None`` entries are dropped; if nothing remains, ``None`` is returned.
    Sequences are tokenized with the module-level ``tokenizer`` (without
    special tokens). Residues with any non-finite backbone coordinate are
    removed from ``X`` and ``S`` and the rest is packed to the front.

    Returned keys: ``title``, ``X``, ``S``, ``score`` (constant 100.0),
    ``mask``, ``lengths``, ``chain_mask``, ``chain_encoding``.
    """
    samples = [sample for sample in batch if sample is not None]
    n_samples = len(samples)
    if n_samples == 0:
        return None

    seq_lens = [len(sample['seq']) for sample in samples]
    lengths = np.array(seq_lens, dtype=np.int32)
    max_len = max(seq_lens)

    coords = np.zeros([n_samples, max_len, 4, 3])
    tokens = np.zeros([n_samples, max_len], dtype=np.int32)
    score = np.ones([n_samples, max_len]) * 100.0
    # chain_mask: 1 = masked part that must be predicted, 0 = visible part;
    # padding stays at -1 for both chain arrays.
    chain_mask = np.zeros([n_samples, max_len]) - 1
    chain_encoding = np.zeros([n_samples, max_len]) - 1

    # Fill one row per sample.
    for row, sample in enumerate(samples):
        backbone = np.stack([sample[atom] for atom in ['N', 'CA', 'C', 'O']], 1)  # [#atom, 4, 3]
        n_res = len(sample['seq'])
        # NaN padding lets the finiteness mask below flag padded positions.
        coords[row, :, :, :] = np.pad(backbone, [[0, max_len - n_res], [0, 0], [0, 0]],
                                      'constant', constant_values=(np.nan, ))

        tokens[row, :n_res] = np.array(tokenizer.encode(sample['seq'], add_special_tokens=False))
        chain_mask[row, :n_res] = sample['chain_mask']
        chain_encoding[row, :n_res] = sample['chain_encoding']

    # Pack residues with finite coordinates to the front of each row.
    # NOTE(review): chain_mask / chain_encoding are not packed the same way,
    # so they may be misaligned with S when a residue is dropped - confirm
    # whether that is intended.
    valid = np.isfinite(np.sum(coords, (2, 3))).astype(np.float32)  # atom mask
    n_valid = np.sum(valid, axis=1).astype(np.int32)
    packed_tokens = np.zeros_like(tokens)
    packed_coords = np.zeros_like(coords) + np.nan
    for row, count in enumerate(n_valid):
        keep = valid[row] == 1
        packed_coords[row, :count, ::] = coords[row][keep]
        packed_tokens[row, :count] = tokens[row][keep]

    mask = np.isfinite(np.sum(packed_coords, (2, 3))).astype(np.float32)
    packed_coords[np.isnan(packed_coords)] = 0.

    # Conversion to torch.
    return {"title": [sample['title'] for sample in samples],
            "X": torch.from_numpy(packed_coords).to(dtype=torch.float32),
            "S": torch.from_numpy(packed_tokens).to(dtype=torch.long),
            "score": torch.from_numpy(score).float(),
            "mask": torch.from_numpy(mask).to(dtype=torch.float32),
            "lengths": torch.from_numpy(lengths),
            "chain_mask": torch.from_numpy(chain_mask),
            "chain_encoding": torch.from_numpy(chain_encoding)}
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class featurize_GVP:
    """Turn protein samples into `torch_geometric` graphs for a GVP model.

    Each sample becomes a `Data` object with a k-nearest-neighbour graph over
    the CA atoms, scalar/vector node features (dihedrals, orientations,
    side-chain directions) and scalar/vector edge features (RBF distances,
    positional embeddings, unit edge vectors). Sequences are tokenized with
    the module-level `tokenizer`.

    :param num_positional_embeddings: size of the positional edge embedding
    :param top_k: number of neighbours per node in the kNN graph
    :param num_rbf: number of radial basis functions for edge distances
    """

    def __init__(self, num_positional_embeddings=16, top_k=30, num_rbf=16):
        self.top_k = top_k
        self.num_rbf = num_rbf
        self.num_positional_embeddings = num_positional_embeddings

    def featurize(self, batch):
        """Build one `Data` graph per non-`None` sample in `batch`.

        :return: list of `torch_geometric.data.Data` (possibly empty)
        """
        data_all = []
        for b in batch:
            if b is None:
                continue
            coords = torch.tensor(np.stack([b[c] for c in ['N', 'CA', 'C', 'O']], 1))
            seq = torch.tensor(np.array(tokenizer.encode(b['seq'], add_special_tokens=False)))

            # Residues with any non-finite coordinate are flagged and pushed to
            # infinity; the resulting nan/inf features are cleaned up below.
            mask = torch.isfinite(coords.sum(dim=(1, 2)))
            coords[~mask] = np.inf

            X_ca = coords[:, 1].float()
            edge_index = torch_geometric.nn.knn_graph(X_ca, k=self.top_k)

            pos_embeddings = self._positional_embeddings(edge_index)  # [E, 16]
            E_vectors = X_ca[edge_index[0]] - X_ca[edge_index[1]]  # [E, 3]
            rbf = _rbf(E_vectors.norm(dim=-1), D_count=self.num_rbf)  # [E, 16]

            dihedrals = self._dihedrals(coords)  # [n,6]
            orientations = self._orientations(X_ca)  # [n,2,3]
            sidechains = self._sidechains(coords)  # [n,3]

            node_s = dihedrals.float()  # [n,6]
            node_v = torch.cat([orientations, sidechains.unsqueeze(-2)], dim=-2).float()  # [n, 3, 3]
            edge_s = torch.cat([rbf, pos_embeddings], dim=-1).float()  # [E, 32]
            edge_v = _normalize(E_vectors).unsqueeze(-2).float()  # [E, 1, 3]

            node_s, node_v, edge_s, edge_v = map(torch.nan_to_num, (node_s, node_v, edge_s, edge_v))

            data = torch_geometric.data.Data(x=X_ca, seq=seq,
                                             node_s=node_s, node_v=node_v,
                                             edge_s=edge_s, edge_v=edge_v,
                                             edge_index=edge_index, mask=mask)
            data_all.append(data)
        return data_all

    def _positional_embeddings(self, edge_index,
                               num_embeddings=None,
                               period_range=(2, 1000)):
        """Sinusoidal embedding of the index offset between edge endpoints.

        From https://github.com/jingraham/neurips19-graph-protein-design

        :param edge_index: integer tensor [2, E]
        :param num_embeddings: embedding size; defaults to
            `self.num_positional_embeddings`
        :param period_range: accepted for compatibility; not used
        :return: tensor [E, num_embeddings] (cosines followed by sines)
        """
        num_embeddings = num_embeddings or self.num_positional_embeddings
        d = edge_index[0] - edge_index[1]

        frequency = torch.exp(
            torch.arange(0, num_embeddings, 2, dtype=torch.float32)
            * -(np.log(10000.0) / num_embeddings)
        )
        angles = d.unsqueeze(-1) * frequency
        E = torch.cat((torch.cos(angles), torch.sin(angles)), -1)
        return E

    def _dihedrals(self, X, eps=1e-7):
        """Backbone dihedral features as cos/sin pairs, one row of 6 per residue.

        From https://github.com/jingraham/neurips19-graph-protein-design
        """
        # Flatten N, CA, C of every residue into one chain of atoms.
        X = torch.reshape(X[:, :3], [3 * X.shape[0], 3])
        dX = X[1:] - X[:-1]
        U = _normalize(dX, dim=-1)
        u_2 = U[:-2]
        u_1 = U[1:-1]
        u_0 = U[2:]

        # Backbone normals. dim=-1 is explicit: without it torch.cross uses
        # the first dimension of size 3, which is the wrong axis whenever the
        # leading dimension happens to be 3.
        n_2 = _normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = _normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)

        # Angle between normals
        cosD = torch.sum(n_2 * n_1, -1)
        cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
        D = torch.sign(torch.sum(u_2 * n_1, -1)) * torch.acos(cosD)

        # This scheme will remove phi[0], psi[-1], omega[-1]
        D = F.pad(D, [1, 2])
        D = torch.reshape(D, [-1, 3])
        # Lift angle representations to the circle
        D_features = torch.cat([torch.cos(D), torch.sin(D)], 1)
        return D_features

    def _orientations(self, X):
        """Unit vectors to the next and previous point, zero-padded at the ends.

        :return: tensor [n, 2, 3] (forward, backward)
        """
        forward = _normalize(X[1:] - X[:-1])
        backward = _normalize(X[:-1] - X[1:])
        forward = F.pad(forward, [0, 0, 0, 1])
        backward = F.pad(backward, [0, 0, 1, 0])
        return torch.cat([forward.unsqueeze(-2), backward.unsqueeze(-2)], -2)

    def _sidechains(self, X):
        """Direction vector per residue built from its N, CA and C atoms.

        Combines the normalized bisector of the CA->C and CA->N directions
        with their normalized cross product.
        """
        n, origin, c = X[:, 0], X[:, 1], X[:, 2]
        c, n = _normalize(c - origin), _normalize(n - origin)
        bisector = _normalize(c + n)
        # Explicit dim for the same reason as in _dihedrals.
        perp = _normalize(torch.cross(c, n, dim=-1))
        vec = -bisector * math.sqrt(1 / 3) - perp * math.sqrt(2 / 3)
        return vec

    def collate(self, batch):
        """Featurize `batch` and merge the graphs into one `Batch`.

        :return: `None` if no sample survives featurization, otherwise the
            collated batch
        :raises TypeError: if the featurized elements have an unsupported type
        """
        batch = self.featurize(batch)
        if (batch is None) or (len(batch) == 0):
            return None

        elem = batch[0]
        if isinstance(elem, Data):
            return Batch.from_data_list(batch)
        # The branches below mirror a generic collate function; `featurize`
        # as written only ever returns `Data` objects.
        elif isinstance(elem, torch.Tensor):
            return default_collate(batch)
        elif isinstance(elem, float):
            return torch.tensor(batch, dtype=torch.float)
        elif isinstance(elem, int):
            return torch.tensor(batch)
        elif isinstance(elem, str):
            return batch
        elif isinstance(elem, Mapping):
            return {key: self.collate([d[key] for d in batch]) for key in elem}
        elif isinstance(elem, tuple) and hasattr(elem, '_fields'):
            return type(elem)(*(self.collate(s) for s in zip(*batch)))
        elif isinstance(elem, Sequence) and not isinstance(elem, str):
            return [self.collate(s) for s in zip(*batch)]

        raise TypeError('DataLoader found invalid type: {}'.format(type(elem)))
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def featurize_ProteinMPNN(batch, is_testing=False, chain_dict=None, fixed_position_dict=None, omit_AA_dict=None, tied_positions_dict=None, pssm_dict=None, bias_by_res_dict=None):
|
| 298 |
+
""" Pack and pad batch into torch tensors """
|
| 299 |
+
|
| 300 |
+
batch = [one for one in batch if one is not None]
|
| 301 |
+
# print('______________________________________________________')
|
| 302 |
+
# print('______________________________________________________')
|
| 303 |
+
# print('______________________________________________________')
|
| 304 |
+
# print('______________________________________________________')
|
| 305 |
+
# print(batch[0].keys())
|
| 306 |
+
USING_DYNAMICS = True if ('norm_bfactors' in batch[0].keys()) or ('gt_flex' in batch[0].keys()) or ('enm_vals' in batch[0].keys()) or ('original_gt_flex' in batch[0].keys()) or ('eng_mask' in batch[0].keys()) else False
|
| 307 |
+
|
| 308 |
+
alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
|
| 309 |
+
B = len(batch)
|
| 310 |
+
if B==0:
|
| 311 |
+
return None
|
| 312 |
+
lengths = np.array([len(b['seq']) for b in batch], dtype=np.int32) #sum of chain seq lengths
|
| 313 |
+
L_max = max([len(b['seq']) for b in batch])
|
| 314 |
+
X = np.zeros([B, L_max, 4, 3])
|
| 315 |
+
residue_idx = -100*np.ones([B, L_max], dtype=np.int32)
|
| 316 |
+
chain_M = np.zeros([B, L_max], dtype=np.int32) #1.0 for the bits that need to be predicted
|
| 317 |
+
pssm_coef_all = np.zeros([B, L_max], dtype=np.float32) #1.0 for the bits that need to be predicted
|
| 318 |
+
pssm_bias_all = np.zeros([B, L_max, 21], dtype=np.float32) #1.0 for the bits that need to be predicted
|
| 319 |
+
pssm_log_odds_all = 10000.0*np.ones([B, L_max, 21], dtype=np.float32) #1.0 for the bits that need to be predicted
|
| 320 |
+
chain_M_pos = np.zeros([B, L_max], dtype=np.int32) #1.0 for the bits that need to be predicted
|
| 321 |
+
bias_by_res_all = np.zeros([B, L_max, 21], dtype=np.float32)
|
| 322 |
+
chain_encoding_all = np.zeros([B, L_max], dtype=np.int32) #1.0 for the bits that need to be predicted
|
| 323 |
+
S = np.zeros([B, L_max], dtype=np.int32)
|
| 324 |
+
score = np.zeros([B, L_max])
|
| 325 |
+
omit_AA_mask = np.zeros([B, L_max, len(alphabet)], dtype=np.int32)
|
| 326 |
+
# Build the batch
|
| 327 |
+
letter_list_list = []
|
| 328 |
+
visible_list_list = []
|
| 329 |
+
masked_list_list = []
|
| 330 |
+
masked_chain_length_list_list = []
|
| 331 |
+
tied_pos_list_of_lists_list = []
|
| 332 |
+
# shuffle all chains before the main loop
|
| 333 |
+
if USING_DYNAMICS:
|
| 334 |
+
if ('norm_bfactors' in batch[0].keys()):
|
| 335 |
+
b_factors = np.zeros([B, L_max])
|
| 336 |
+
if ('gt_flex' in batch[0].keys()):
|
| 337 |
+
gt_flex = np.zeros([B, L_max])
|
| 338 |
+
if ('enm_vals' in batch[0].keys()):
|
| 339 |
+
enm_vals = np.zeros([B, L_max])
|
| 340 |
+
if ('original_gt_flex' in batch[0].keys()):
|
| 341 |
+
original_gt_flex = np.zeros([B, L_max])
|
| 342 |
+
if ('eng_mask' in batch[0].keys()):
|
| 343 |
+
eng_mask = np.zeros([B, L_max])
|
| 344 |
+
|
| 345 |
+
for i, b in enumerate(batch):
|
| 346 |
+
if chain_dict != None:
|
| 347 |
+
masked_chains, visible_chains = chain_dict[b['name']] #masked_chains a list of chain letters to predict [A, D, F]
|
| 348 |
+
else:
|
| 349 |
+
# masked_chains = [item[-1:] for item in list(b) if item[:10]=='seq_chain_']
|
| 350 |
+
masked_chains = ['']
|
| 351 |
+
visible_chains = []
|
| 352 |
+
# num_chains = b['num_of_chains']
|
| 353 |
+
all_chains = masked_chains + visible_chains
|
| 354 |
+
#random.shuffle(all_chains)
|
| 355 |
+
for i, b in enumerate(batch):
|
| 356 |
+
mask_dict = {}
|
| 357 |
+
a = 0
|
| 358 |
+
x_chain_list = []
|
| 359 |
+
chain_mask_list = []
|
| 360 |
+
chain_seq_list = []
|
| 361 |
+
chain_encoding_list = []
|
| 362 |
+
c = 1
|
| 363 |
+
letter_list = []
|
| 364 |
+
global_idx_start_list = [0]
|
| 365 |
+
visible_list = []
|
| 366 |
+
masked_list = []
|
| 367 |
+
masked_chain_length_list = []
|
| 368 |
+
fixed_position_mask_list = []
|
| 369 |
+
omit_AA_mask_list = []
|
| 370 |
+
pssm_coef_list = []
|
| 371 |
+
pssm_bias_list = []
|
| 372 |
+
pssm_log_odds_list = []
|
| 373 |
+
bias_by_res_list = []
|
| 374 |
+
|
| 375 |
+
if USING_DYNAMICS:
|
| 376 |
+
if ('norm_bfactors' in batch[0].keys()):
|
| 377 |
+
b_factors_list = []
|
| 378 |
+
if ('gt_flex' in batch[0].keys()):
|
| 379 |
+
gt_flex_list = []
|
| 380 |
+
if ('enm_vals' in batch[0].keys()):
|
| 381 |
+
enm_vals_list = []
|
| 382 |
+
if ('original_gt_flex' in batch[0].keys()):
|
| 383 |
+
original_gt_flex_list = []
|
| 384 |
+
if ('eng_mask' in batch[0].keys()):
|
| 385 |
+
eng_mask_list = []
|
| 386 |
+
l0 = 0
|
| 387 |
+
l1 = 0
|
| 388 |
+
for step, letter in enumerate(all_chains):
|
| 389 |
+
if letter in visible_chains:
|
| 390 |
+
letter_list.append(letter)
|
| 391 |
+
visible_list.append(letter)
|
| 392 |
+
chain_seq = b[f'seq_chain_{letter}']
|
| 393 |
+
chain_seq = ''.join([a if a!='-' else 'X' for a in chain_seq])
|
| 394 |
+
chain_length = len(chain_seq)
|
| 395 |
+
global_idx_start_list.append(global_idx_start_list[-1]+chain_length)
|
| 396 |
+
chain_coords = b[f'coords_chain_{letter}'] #this is a dictionary
|
| 397 |
+
chain_mask = np.zeros(chain_length) #0.0 for visible chains
|
| 398 |
+
x_chain = np.stack([chain_coords[c] for c in [f'N_chain_{letter}', f'CA_chain_{letter}', f'C_chain_{letter}', f'O_chain_{letter}']], 1) #[chain_lenght,4,3]
|
| 399 |
+
x_chain_list.append(x_chain)
|
| 400 |
+
chain_mask_list.append(chain_mask)
|
| 401 |
+
chain_seq_list.append(chain_seq)
|
| 402 |
+
chain_encoding_list.append(c*np.ones(np.array(chain_mask).shape[0]))
|
| 403 |
+
l1 += chain_length
|
| 404 |
+
residue_idx[i, l0:l1] = 100*(c-1)+np.arange(l0, l1)
|
| 405 |
+
l0 += chain_length
|
| 406 |
+
c+=1
|
| 407 |
+
fixed_position_mask = np.ones(chain_length)
|
| 408 |
+
fixed_position_mask_list.append(fixed_position_mask)
|
| 409 |
+
omit_AA_mask_temp = np.zeros([chain_length, len(alphabet)], np.int32)
|
| 410 |
+
omit_AA_mask_list.append(omit_AA_mask_temp)
|
| 411 |
+
pssm_coef = np.zeros(chain_length)
|
| 412 |
+
pssm_bias = np.zeros([chain_length, 21])
|
| 413 |
+
pssm_log_odds = 10000.0*np.ones([chain_length, 21])
|
| 414 |
+
pssm_coef_list.append(pssm_coef)
|
| 415 |
+
pssm_bias_list.append(pssm_bias)
|
| 416 |
+
pssm_log_odds_list.append(pssm_log_odds)
|
| 417 |
+
bias_by_res_list.append(np.zeros([chain_length, 21]))
|
| 418 |
+
if letter in masked_chains:
|
| 419 |
+
masked_list.append(letter)
|
| 420 |
+
letter_list.append(letter)
|
| 421 |
+
|
| 422 |
+
if USING_DYNAMICS:
|
| 423 |
+
if ('norm_bfactors' in batch[0].keys()):
|
| 424 |
+
chain_b_factors = b['norm_bfactors']
|
| 425 |
+
b_factors_list.append(chain_b_factors)
|
| 426 |
+
if ('gt_flex' in batch[0].keys()):
|
| 427 |
+
chain_gt_flex = b['gt_flex']
|
| 428 |
+
gt_flex_list.append(chain_gt_flex)
|
| 429 |
+
if ('enm_vals' in batch[0].keys()):
|
| 430 |
+
chain_enm_vals = b['enm_vals']
|
| 431 |
+
enm_vals_list.append(chain_enm_vals)
|
| 432 |
+
if ('original_gt_flex' in batch[0].keys()):
|
| 433 |
+
chain_original_gt_flex = b['original_gt_flex']
|
| 434 |
+
original_gt_flex_list.append(chain_original_gt_flex)
|
| 435 |
+
if ('eng_mask' in batch[0].keys()):
|
| 436 |
+
chain_eng_mask = b['eng_mask']
|
| 437 |
+
eng_mask_list.append(chain_eng_mask)
|
| 438 |
+
|
| 439 |
+
# chain_seq = b[f'seq_chain_{letter}']
|
| 440 |
+
chain_seq = b[f'seq{letter}']
|
| 441 |
+
chain_seq = ''.join([a if a!='-' else 'X' for a in chain_seq])
|
| 442 |
+
chain_length = len(chain_seq)
|
| 443 |
+
global_idx_start_list.append(global_idx_start_list[-1]+chain_length)
|
| 444 |
+
masked_chain_length_list.append(chain_length)
|
| 445 |
+
# chain_coords = b[f'coords_chain_{letter}'] #this is a dictionary
|
| 446 |
+
chain_coords = b
|
| 447 |
+
chain_mask = np.ones(chain_length) #1.0 for masked
|
| 448 |
+
# x_chain = np.stack([chain_coords[c] for c in [f'N_chain_{letter}', f'CA_chain_{letter}', f'C_chain_{letter}', f'O_chain_{letter}']], 1) #[chain_lenght,4,3]
|
| 449 |
+
x_chain = np.stack([chain_coords[c] for c in [f'N', f'CA', f'C', f'O']], 1) #[chain_lenght,4,3]
|
| 450 |
+
x_chain_list.append(x_chain)
|
| 451 |
+
chain_mask_list.append(chain_mask)
|
| 452 |
+
chain_seq_list.append(chain_seq)
|
| 453 |
+
chain_encoding_list.append(c*np.ones(np.array(chain_mask).shape[0]))
|
| 454 |
+
l1 += chain_length
|
| 455 |
+
residue_idx[i, l0:l1] = 100*(c-1)+np.arange(l0, l1)
|
| 456 |
+
l0 += chain_length
|
| 457 |
+
c+=1
|
| 458 |
+
fixed_position_mask = np.ones(chain_length)
|
| 459 |
+
if fixed_position_dict!=None:
|
| 460 |
+
fixed_pos_list = fixed_position_dict[b['name']][letter]
|
| 461 |
+
if fixed_pos_list:
|
| 462 |
+
fixed_position_mask[np.array(fixed_pos_list)-1] = 0.0
|
| 463 |
+
fixed_position_mask_list.append(fixed_position_mask)
|
| 464 |
+
omit_AA_mask_temp = np.zeros([chain_length, len(alphabet)], np.int32)
|
| 465 |
+
if omit_AA_dict!=None:
|
| 466 |
+
for item in omit_AA_dict[b['name']][letter]:
|
| 467 |
+
idx_AA = np.array(item[0])-1
|
| 468 |
+
AA_idx = np.array([np.argwhere(np.array(list(alphabet))== AA)[0][0] for AA in item[1]]).repeat(idx_AA.shape[0])
|
| 469 |
+
idx_ = np.array([[a, b] for a in idx_AA for b in AA_idx])
|
| 470 |
+
omit_AA_mask_temp[idx_[:,0], idx_[:,1]] = 1
|
| 471 |
+
omit_AA_mask_list.append(omit_AA_mask_temp)
|
| 472 |
+
pssm_coef = np.zeros(chain_length)
|
| 473 |
+
pssm_bias = np.zeros([chain_length, 21])
|
| 474 |
+
pssm_log_odds = 10000.0*np.ones([chain_length, 21])
|
| 475 |
+
if pssm_dict:
|
| 476 |
+
if pssm_dict[b['name']][letter]:
|
| 477 |
+
pssm_coef = pssm_dict[b['name']][letter]['pssm_coef']
|
| 478 |
+
pssm_bias = pssm_dict[b['name']][letter]['pssm_bias']
|
| 479 |
+
pssm_log_odds = pssm_dict[b['name']][letter]['pssm_log_odds']
|
| 480 |
+
pssm_coef_list.append(pssm_coef)
|
| 481 |
+
pssm_bias_list.append(pssm_bias)
|
| 482 |
+
pssm_log_odds_list.append(pssm_log_odds)
|
| 483 |
+
if bias_by_res_dict:
|
| 484 |
+
bias_by_res_list.append(bias_by_res_dict[b['name']][letter])
|
| 485 |
+
else:
|
| 486 |
+
bias_by_res_list.append(np.zeros([chain_length, 21]))
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
letter_list_np = np.array(letter_list)
|
| 490 |
+
tied_pos_list_of_lists = []
|
| 491 |
+
tied_beta = np.ones(L_max)
|
| 492 |
+
if tied_positions_dict!=None:
|
| 493 |
+
tied_pos_list = tied_positions_dict[b['name']]
|
| 494 |
+
if tied_pos_list:
|
| 495 |
+
set_chains_tied = set(list(itertools.chain(*[list(item) for item in tied_pos_list])))
|
| 496 |
+
for tied_item in tied_pos_list:
|
| 497 |
+
one_list = []
|
| 498 |
+
for k, v in tied_item.items():
|
| 499 |
+
start_idx = global_idx_start_list[np.argwhere(letter_list_np == k)[0][0]]
|
| 500 |
+
if isinstance(v[0], list):
|
| 501 |
+
for v_count in range(len(v[0])):
|
| 502 |
+
one_list.append(start_idx+v[0][v_count]-1)#make 0 to be the first
|
| 503 |
+
tied_beta[start_idx+v[0][v_count]-1] = v[1][v_count]
|
| 504 |
+
else:
|
| 505 |
+
for v_ in v:
|
| 506 |
+
one_list.append(start_idx+v_-1)#make 0 to be the first
|
| 507 |
+
tied_pos_list_of_lists.append(one_list)
|
| 508 |
+
tied_pos_list_of_lists_list.append(tied_pos_list_of_lists)
|
| 509 |
+
|
| 510 |
+
x = np.concatenate(x_chain_list,0) #[L, 4, 3]
|
| 511 |
+
|
| 512 |
+
if USING_DYNAMICS:
|
| 513 |
+
if ('norm_bfactors' in batch[0].keys()):
|
| 514 |
+
bf = np.concatenate(b_factors_list,0) #[L,]
|
| 515 |
+
if ('gt_flex' in batch[0].keys()):
|
| 516 |
+
gt = np.concatenate(gt_flex_list,0) #[L,]
|
| 517 |
+
if ('enm_vals' in batch[0].keys()):
|
| 518 |
+
enm = np.concatenate(enm_vals_list,0)
|
| 519 |
+
if ('original_gt_flex' in batch[0].keys()):
|
| 520 |
+
orig_gt = np.concatenate(original_gt_flex_list,0)
|
| 521 |
+
if ('eng_mask' in batch[0].keys()):
|
| 522 |
+
eng = np.concatenate(eng_mask_list,0)
|
| 523 |
+
|
| 524 |
+
all_sequence = "".join(chain_seq_list)
|
| 525 |
+
m = np.concatenate(chain_mask_list,0) #[L,], 1.0 for places that need to be predicted
|
| 526 |
+
chain_encoding = np.concatenate(chain_encoding_list,0)
|
| 527 |
+
m_pos = np.concatenate(fixed_position_mask_list,0) #[L,], 1.0 for places that need to be predicted
|
| 528 |
+
|
| 529 |
+
pssm_coef_ = np.concatenate(pssm_coef_list,0) #[L,], 1.0 for places that need to be predicted
|
| 530 |
+
pssm_bias_ = np.concatenate(pssm_bias_list,0) #[L,], 1.0 for places that need to be predicted
|
| 531 |
+
pssm_log_odds_ = np.concatenate(pssm_log_odds_list,0) #[L,], 1.0 for places that need to be predicted
|
| 532 |
+
|
| 533 |
+
bias_by_res_ = np.concatenate(bias_by_res_list, 0) #[L,21], 0.0 for places where AA frequencies don't need to be tweaked
|
| 534 |
+
|
| 535 |
+
l = len(all_sequence)
|
| 536 |
+
x_pad = np.pad(x, [[0, L_max-l], [0,0], [0,0]], 'constant', constant_values=(np.nan, ))
|
| 537 |
+
if USING_DYNAMICS:
|
| 538 |
+
if ('norm_bfactors' in batch[0].keys()):
|
| 539 |
+
bf_pad = np.pad(bf, [[0, L_max-l]], 'constant', constant_values=(np.nan, ))
|
| 540 |
+
if ('gt_flex' in batch[0].keys()):
|
| 541 |
+
gt_pad = np.pad(gt, [[0, L_max-l]], 'constant', constant_values=(np.nan, ))
|
| 542 |
+
if ('enm_vals' in batch[0].keys()):
|
| 543 |
+
enm_pad = np.pad(enm, [[0, L_max-l]], 'constant', constant_values=(np.nan, ))
|
| 544 |
+
if ('original_gt_flex' in batch[0].keys()):
|
| 545 |
+
orig_gt_pad = np.pad(orig_gt, [[0, L_max-l]], 'constant', constant_values=(0, ))
|
| 546 |
+
if ('eng_mask' in batch[0].keys()):
|
| 547 |
+
eng_pad = np.pad(eng, [[0, L_max-l]], 'constant', constant_values=(0, ))
|
| 548 |
+
|
| 549 |
+
X[i,:,:,:] = x_pad
|
| 550 |
+
if USING_DYNAMICS:
|
| 551 |
+
if ('norm_bfactors' in batch[0].keys()):
|
| 552 |
+
b_factors[i, :] = bf_pad
|
| 553 |
+
if ('gt_flex' in batch[0].keys()):
|
| 554 |
+
gt_flex[i, :] = gt_pad[:-1]
|
| 555 |
+
if ('enm_vals' in batch[0].keys()):
|
| 556 |
+
enm_vals[i, :] = enm_pad
|
| 557 |
+
if ('original_gt_flex' in batch[0].keys()):
|
| 558 |
+
original_gt_flex[i, :] = orig_gt_pad[:-1]
|
| 559 |
+
if ('eng_mask' in batch[0].keys()):
|
| 560 |
+
eng_mask[i, :] = eng_pad[:-1]
|
| 561 |
+
|
| 562 |
+
if 'score' in b.keys():
|
| 563 |
+
score[i, :l] = b['score']
|
| 564 |
+
else:
|
| 565 |
+
score[i, :l] = 100.0
|
| 566 |
+
|
| 567 |
+
m_pad = np.pad(m, [[0, L_max-l]], 'constant', constant_values=(0.0, ))
|
| 568 |
+
m_pos_pad = np.pad(m_pos, [[0,L_max-l]], 'constant', constant_values=(0.0, ))
|
| 569 |
+
omit_AA_mask_pad = np.pad(np.concatenate(omit_AA_mask_list,0), [[0,L_max-l], [0, 0]], 'constant', constant_values=(0.0, ))
|
| 570 |
+
chain_M[i,:] = m_pad
|
| 571 |
+
chain_M_pos[i,:] = m_pos_pad
|
| 572 |
+
omit_AA_mask[i,] = omit_AA_mask_pad
|
| 573 |
+
|
| 574 |
+
chain_encoding_pad = np.pad(chain_encoding, [[0,L_max-l]], 'constant', constant_values=(0.0, ))
|
| 575 |
+
chain_encoding_all[i,:] = chain_encoding_pad
|
| 576 |
+
|
| 577 |
+
pssm_coef_pad = np.pad(pssm_coef_, [[0, L_max-l]], 'constant', constant_values=(0.0, ))
|
| 578 |
+
pssm_bias_pad = np.pad(pssm_bias_, [[0, L_max-l], [0,0]], 'constant', constant_values=(0.0, ))
|
| 579 |
+
pssm_log_odds_pad = np.pad(pssm_log_odds_, [[0,L_max-l], [0,0]], 'constant', constant_values=(0.0, ))
|
| 580 |
+
|
| 581 |
+
pssm_coef_all[i,:] = pssm_coef_pad
|
| 582 |
+
pssm_bias_all[i,:] = pssm_bias_pad
|
| 583 |
+
pssm_log_odds_all[i,:] = pssm_log_odds_pad
|
| 584 |
+
|
| 585 |
+
bias_by_res_pad = np.pad(bias_by_res_, [[0,L_max-l], [0,0]], 'constant', constant_values=(0.0, ))
|
| 586 |
+
bias_by_res_all[i,:] = bias_by_res_pad
|
| 587 |
+
|
| 588 |
+
# Convert to labels
|
| 589 |
+
indices = np.array(tokenizer.encode(b['seq'], add_special_tokens=False))
|
| 590 |
+
S[i, :l] = indices
|
| 591 |
+
letter_list_list.append(letter_list)
|
| 592 |
+
visible_list_list.append(visible_list)
|
| 593 |
+
masked_list_list.append(masked_list)
|
| 594 |
+
masked_chain_length_list_list.append(masked_chain_length_list)
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
isnan = np.isnan(X)
|
| 598 |
+
mask = np.isfinite(np.sum(X,(2,3))).astype(np.float32)
|
| 599 |
+
X[isnan] = 0.
|
| 600 |
+
|
| 601 |
+
# Conversion
|
| 602 |
+
pssm_coef_all = torch.from_numpy(pssm_coef_all).to(dtype=torch.float32)
|
| 603 |
+
pssm_bias_all = torch.from_numpy(pssm_bias_all).to(dtype=torch.float32)
|
| 604 |
+
pssm_log_odds_all = torch.from_numpy(pssm_log_odds_all).to(dtype=torch.float32)
|
| 605 |
+
|
| 606 |
+
tied_beta = torch.from_numpy(tied_beta).to(dtype=torch.float32)
|
| 607 |
+
|
| 608 |
+
jumps = ((residue_idx[:,1:]-residue_idx[:,:-1])==1).astype(np.float32)
|
| 609 |
+
bias_by_res_all = torch.from_numpy(bias_by_res_all).to(dtype=torch.float32)
|
| 610 |
+
phi_mask = np.pad(jumps, [[0,0],[1,0]])
|
| 611 |
+
psi_mask = np.pad(jumps, [[0,0],[0,1]])
|
| 612 |
+
omega_mask = np.pad(jumps, [[0,0],[0,1]])
|
| 613 |
+
dihedral_mask = np.concatenate([phi_mask[:,:,None], psi_mask[:,:,None], omega_mask[:,:,None]], -1) #[B,L,3]
|
| 614 |
+
dihedral_mask = torch.from_numpy(dihedral_mask).to(dtype=torch.float32)
|
| 615 |
+
residue_idx = torch.from_numpy(residue_idx).to(dtype=torch.long)
|
| 616 |
+
S = torch.from_numpy(S).to(dtype=torch.long)
|
| 617 |
+
X = torch.from_numpy(X).to(dtype=torch.float32)
|
| 618 |
+
if USING_DYNAMICS:
|
| 619 |
+
if ('norm_bfactors' in batch[0].keys()):
|
| 620 |
+
b_factors = torch.from_numpy(b_factors).to(dtype=torch.float32)
|
| 621 |
+
if ('gt_flex' in batch[0].keys()):
|
| 622 |
+
gt_flex = torch.from_numpy(gt_flex).to(dtype=torch.float32)
|
| 623 |
+
if ('enm_vals' in batch[0].keys()):
|
| 624 |
+
enm_vals = torch.from_numpy(enm_vals).to(dtype=torch.float32)
|
| 625 |
+
if ('original_gt_flex' in batch[0].keys()):
|
| 626 |
+
original_gt_flex = torch.from_numpy(original_gt_flex).to(dtype=torch.float32)
|
| 627 |
+
if ('eng_mask' in batch[0].keys()):
|
| 628 |
+
eng_mask = torch.from_numpy(eng_mask).to(dtype=torch.float32)
|
| 629 |
+
score = torch.from_numpy(score).float()
|
| 630 |
+
mask = torch.from_numpy(mask).to(dtype=torch.float32)
|
| 631 |
+
chain_M = torch.from_numpy(chain_M).to(dtype=torch.float32)
|
| 632 |
+
chain_M_pos = torch.from_numpy(chain_M_pos).to(dtype=torch.float32)
|
| 633 |
+
omit_AA_mask = torch.from_numpy(omit_AA_mask).to(dtype=torch.float32)
|
| 634 |
+
chain_encoding_all = torch.from_numpy(chain_encoding_all).to(dtype=torch.long)
|
| 635 |
+
|
| 636 |
+
if is_testing is False:
|
| 637 |
+
retVal = {"title": [b['title'] for b in batch],
|
| 638 |
+
"X":X,
|
| 639 |
+
"S":S,
|
| 640 |
+
"score": score,
|
| 641 |
+
"mask":mask,
|
| 642 |
+
"lengths":lengths,
|
| 643 |
+
"chain_M":chain_M,
|
| 644 |
+
"chain_M_pos":chain_M_pos,
|
| 645 |
+
"residue_idx":residue_idx,
|
| 646 |
+
"chain_encoding_all":chain_encoding_all}
|
| 647 |
+
if USING_DYNAMICS:
|
| 648 |
+
if ('norm_bfactors' in batch[0].keys()):
|
| 649 |
+
retVal['norm_bfactors'] = b_factors
|
| 650 |
+
if ('gt_flex' in batch[0].keys()):
|
| 651 |
+
retVal['gt_flex'] = gt_flex
|
| 652 |
+
if ('enm_vals' in batch[0].keys()):
|
| 653 |
+
retVal['enm_vals'] = enm_vals
|
| 654 |
+
if ('original_gt_flex' in batch[0].keys()):
|
| 655 |
+
retVal['original_gt_flex'] = original_gt_flex
|
| 656 |
+
if ('eng_mask' in batch[0].keys()):
|
| 657 |
+
retVal['eng_mask'] = eng_mask
|
| 658 |
+
|
| 659 |
+
return retVal
|
| 660 |
+
else:
|
| 661 |
+
retVal = {"title": [b['title'] for b in batch],
|
| 662 |
+
"X":X,
|
| 663 |
+
"S":S,
|
| 664 |
+
"score": score,
|
| 665 |
+
"mask":mask,
|
| 666 |
+
"lengths":lengths,
|
| 667 |
+
"chain_M":chain_M,
|
| 668 |
+
"chain_M_pos":chain_M_pos,
|
| 669 |
+
"residue_idx":residue_idx,
|
| 670 |
+
"chain_encoding_all":chain_encoding_all}
|
| 671 |
+
if USING_DYNAMICS:
|
| 672 |
+
if ('norm_bfactors' in batch[0].keys()):
|
| 673 |
+
retVal['norm_bfactors'] = b_factors
|
| 674 |
+
if ('gt_flex' in batch[0].keys()):
|
| 675 |
+
retVal['gt_flex'] = gt_flex
|
| 676 |
+
if ('enm_vals' in batch[0].keys()):
|
| 677 |
+
retVal['enm_vals'] = enm_vals
|
| 678 |
+
if ('original_gt_flex' in batch[0].keys()):
|
| 679 |
+
retVal['original_gt_flex'] = original_gt_flex
|
| 680 |
+
if ('eng_mask' in batch[0].keys()):
|
| 681 |
+
retVal['eng_mask'] = eng_mask
|
| 682 |
+
return retVal
|
| 683 |
+
|
| 684 |
+
|
| 685 |
+
def featurize_Inversefolding(batch, shuffle_fraction=0.):
    """Pack and pad a batch of samples into torch tensors.

    Args:
        batch: sequence of dict samples. Each sample is read for the keys
            'title', 'seq', 'N', 'CA', 'C', 'chain_mask' and 'chain_encoding'.
        shuffle_fraction: when greater than 0, the encoded sequence of every
            sample is re-indexed with ``shuffle_subset(len(seq), shuffle_fraction)``
            before being stored as labels.

    Returns:
        dict with keys 'title' (list), 'X' (float32 tensor [B, L_max, 3, 3]),
        'S' (long tensor [B, L_max]), 'score' (float tensor [B, L_max], all
        100.0), 'mask' (float32 tensor [B, L_max]), 'lengths' (int32 numpy
        array of the raw sequence lengths), 'chain_mask' and 'chain_encoding'
        (tensors [B, L_max], padded with -1).
    """
    B = len(batch)
    seq_lengths = [len(b['seq']) for b in batch]
    lengths = np.array(seq_lengths, dtype=np.int32)
    L_max = max(seq_lengths)

    X = np.zeros([B, L_max, 3, 3])
    S = np.zeros([B, L_max], dtype=np.int32)
    score = np.ones([B, L_max]) * 100.0
    # Per the original comment: 1 = masked part that must be predicted,
    # 0 = visible part. Padding positions keep the initial value -1.
    chain_mask = np.zeros([B, L_max]) - 1
    chain_encoding = np.zeros([B, L_max]) - 1

    # Build the batch
    for i, b in enumerate(batch):
        # Stack the N, CA and C coordinates along a new atom axis: [l, 3, 3]
        x = np.stack([b[c] for c in ['N', 'CA', 'C']], 1)

        l = len(b['seq'])
        # Pad the residue axis with NaN so padding is detected as "not finite" below.
        x_pad = np.pad(x, [[0, L_max - l], [0, 0], [0, 0]], 'constant', constant_values=(np.nan, ))  # [L_max, 3, 3]
        X[i, :, :, :] = x_pad

        # Convert to labels
        indices = np.array(tokenizer.encode(b['seq'], add_special_tokens=False))
        if shuffle_fraction > 0.:
            idx_shuffle = shuffle_subset(l, shuffle_fraction)
            S[i, :l] = indices[idx_shuffle]
        else:
            S[i, :l] = indices

        chain_mask[i, :l] = b['chain_mask']
        chain_encoding[i, :l] = b['chain_encoding']

    # A residue is valid only if all of its 9 coordinates are finite.
    mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32)
    # np.int was removed in NumPy 1.24; use an explicit integer dtype instead.
    numbers = np.sum(mask, axis=1).astype(np.int64)

    # Left-align the valid residues of every row; the tail stays NaN / 0.
    # NOTE(review): chain_mask and chain_encoding are not compacted the same
    # way, so they can be misaligned with X/S when a sample has residues with
    # non-finite coordinates - confirm whether that is intended.
    S_new = np.zeros_like(S)
    X_new = np.zeros_like(X) + np.nan
    for i, n in enumerate(numbers):
        keep = mask[i] == 1
        X_new[i, :n] = X[i][keep]
        S_new[i, :n] = S[i][keep]

    X = X_new
    S = S_new
    isnan = np.isnan(X)
    mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32)
    X[isnan] = 0.

    # Conversion
    S = torch.from_numpy(S).to(dtype=torch.long)
    score = torch.from_numpy(score).float()
    X = torch.from_numpy(X).to(dtype=torch.float32)
    mask = torch.from_numpy(mask).to(dtype=torch.float32)
    chain_mask = torch.from_numpy(chain_mask)
    chain_encoding = torch.from_numpy(chain_encoding)
    return {"title": [b['title'] for b in batch],
            "X": X,
            "S": S,
            "score": score,
            "mask": mask,
            "lengths": lengths,
            "chain_mask": chain_mask,
            "chain_encoding": chain_encoding}
|
Flexpert-Design/src/datasets/flex_cath_dataset.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import numpy as np
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import random
|
| 6 |
+
import torch.utils.data as data
|
| 7 |
+
from .utils import cached_property
|
| 8 |
+
from transformers import AutoTokenizer
|
| 9 |
+
from src.tools.utils import load_yaml_config
|
| 10 |
+
|
| 11 |
+
class FlexCATHDataset(data.Dataset):
    """Dataset of single-chain backbone samples read from a CATH-style jsonl file.

    Every sample is a dict with the keys 'title', 'seq', 'CA', 'C', 'O', 'N',
    'chain_mask' and 'chain_encoding'. When ``use_dynamics`` is true the
    entries' 'gt_flex' and 'enm_vals' fields are copied into the sample as
    well, plus 'original_gt_flex' and 'eng_mask' when an entry provides them.
    """

    # Keys that __getitem__ crops when a sequence is longer than max_length.
    # NOTE(review): 'original_gt_flex' and 'eng_mask' are not cropped (the
    # original code did not crop them either) - confirm whether they should be.
    _CROPPED_KEYS = ('seq', 'CA', 'C', 'O', 'N', 'chain_mask', 'chain_encoding', 'gt_flex', 'enm_vals')

    def __init__(self, path='./', split='train', max_length=500, test_name='All', data = None, removeTS=0, version=4.3, data_jsonl_name='/chain_set.jsonl', use_dynamics=True):
        """Store the configuration and select the samples of ``split``.

        Args:
            path: directory holding the jsonl file and the split json files.
            split: 'train', 'valid', 'test' or 'predict'; 'predict' is served
                from the validation split.
            max_length: longest sequence kept when loading; also the crop
                length used by ``__getitem__``.
            test_name: 'L100' or 'sc' replaces the test split with the ids in
                test_split_L100.json / test_split_sc.json.
            data: optional ready-made list of samples; when given, nothing is
                read from the jsonl file.
            removeTS: when truthy, entries named in <path>/remove.json are skipped.
            version: CATH version, 4.2 or 4.3.
            data_jsonl_name: name of the jsonl file, appended to ``path``.
            use_dynamics: whether to attach the flexibility fields to samples.
        """
        self.version = version
        self.path = path
        self.mode = split
        self.max_length = max_length
        self.test_name = test_name
        self.removeTS = removeTS
        self.data_jsonl_name = data_jsonl_name

        self.using_dynamics = use_dynamics

        print(self.data_jsonl_name)
        if self.removeTS:
            # Context manager so the file handle is closed deterministically.
            with open(self.path+'/remove.json', 'r') as f:
                self.remove = json.load(f)['remove']

        if data is None:
            if split == 'predict':
                _split = 'valid'
                print('In predict mode for CATH4.3 using VALIDATION split as the data. Consider switching to TEST set.')
            else:
                _split = split
            self.data = self.cache_data[_split]
        else:
            self.data = data

        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="./cache_dir/")

    @cached_property
    def cache_data(self):
        """Parse the jsonl file and return {'train': [...], 'valid': [...], 'test': [...]}.

        Entries containing characters outside the 20-letter alphabet or longer
        than ``max_length`` are dropped, as are entries whose name is not
        listed in any split.

        Raises:
            FileNotFoundError: if ``self.path`` does not exist.
            ValueError: if ``self.version`` is neither 4.2 nor 4.3.
        """
        alphabet_set = set('ACDEFGHIKLMNPQRSTVWY')
        print("path is: ", self.path)
        if not os.path.exists(self.path):
            # Raising a plain string is itself a TypeError in Python 3.
            raise FileNotFoundError("no such file:{} !!!".format(self.path))
        if self.version not in (4.2, 4.3):
            # Previously this surfaced later as an unbound `dataset_splits`.
            raise ValueError("unsupported CATH version: {}".format(self.version))

        with open(self.path+'/'+self.data_jsonl_name) as f:
            lines = f.readlines()

        data_list = []
        for line in tqdm(lines):
            entry = json.loads(line)
            if self.removeTS and entry['name'] in self.remove:
                continue
            seq = entry['seq']
            if set(seq).difference(alphabet_set):
                continue  # sequence contains non-standard residue letters
            if len(seq) > self.max_length:
                continue

            chain_mask = np.ones(len(seq))
            coords = entry['coords']
            sample = {
                'title': entry['name'],
                'seq': seq,
                'CA': np.asarray(coords['CA']),
                'C': np.asarray(coords['C']),
                'O': np.asarray(coords['O']),
                'N': np.asarray(coords['N']),
                'chain_mask': chain_mask,
                'chain_encoding': 1*chain_mask
            }
            if self.using_dynamics:
                sample['gt_flex'] = entry['gt_flex']
                sample['enm_vals'] = entry['enm_vals']
                if 'original_gt_flex' in entry:
                    sample['original_gt_flex'] = entry['original_gt_flex']
                if 'eng_mask' in entry:
                    sample['eng_mask'] = entry['eng_mask']
            data_list.append(sample)

        # Both supported versions read the same split file.
        with open(self.path+'/chain_set_splits.json') as f:
            dataset_splits = json.load(f)

        if self.test_name in ('L100', 'sc'):
            with open(self.path+'/test_split_'+self.test_name+'.json') as f:
                dataset_splits['test'] = json.load(f)['test']

        # Later updates win, so a name listed in several splits ends up in the last one.
        name2set = {}
        name2set.update({name:'train' for name in dataset_splits['train']})
        name2set.update({name:'valid' for name in dataset_splits['validation']})
        name2set.update({name:'test' for name in dataset_splits['test']})

        data_dict = {'train':[],'valid':[],'test':[]}
        for sample in data_list:
            split_name = name2set.get(sample['title'])
            if split_name is None:
                continue  # entry is not assigned to any split
            if split_name == 'test':
                sample['category'] = 'Unkown'
                sample['score'] = 100.0
            data_dict[split_name].append(sample)
        return data_dict

    def change_mode(self, mode):
        """Switch the served samples to the split named ``mode``."""
        self.data = self.cache_data[mode]

    def __len__(self):
        """Return the number of samples in the active split."""
        return len(self.data)

    def get_item(self, index):
        """Return the stored sample at ``index`` without cropping."""
        return self.data[index]

    def __getitem__(self, index):
        """Return the sample at ``index``, randomly cropped to ``max_length``.

        Samples that already fit are returned as stored. Longer samples are
        cropped on a shallow copy, so the stored sample is left intact and a
        new window is drawn on every access.
        """
        item = self.data[index]
        L = len(item['seq'])
        if L <= self.max_length:
            return item

        # Random start such that the window of max_length stays inside the sequence.
        truncate_index = random.randint(0, L - self.max_length)
        window = slice(truncate_index, truncate_index + self.max_length)
        cropped = dict(item)
        for key in self._CROPPED_KEYS:
            # The flexibility keys are absent when use_dynamics is off.
            if key in cropped:
                cropped[key] = cropped[key][window]
        return cropped
|
Flexpert-Design/src/datasets/foldswitchers_dataset.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import numpy as np
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import random
|
| 6 |
+
import pdb
|
| 7 |
+
import torch.utils.data as data
|
| 8 |
+
from .utils import cached_property
|
| 9 |
+
from transformers import AutoTokenizer
|
| 10 |
+
|
| 11 |
+
class FoldswitchersDataset(data.Dataset):
|
| 12 |
+
def __init__(self, path='./', split='train', max_length=500, test_name='All', data = None, removeTS=0):
|
| 13 |
+
self.path = path
|
| 14 |
+
self.mode = split
|
| 15 |
+
self.max_length = max_length
|
| 16 |
+
self.test_name = test_name
|
| 17 |
+
self.removeTS = removeTS
|
| 18 |
+
if self.removeTS:
|
| 19 |
+
self.remove = json.load(open(self.path+'/remove.json', 'r'))['remove']
|
| 20 |
+
|
| 21 |
+
if data is None:
|
| 22 |
+
self.data = self.cache_data[split] #This calls the cache_data property
|
| 23 |
+
else:
|
| 24 |
+
self.data = data
|
| 25 |
+
|
| 26 |
+
self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="./cache_dir/")
|
| 27 |
+
|
| 28 |
+
@cached_property
|
| 29 |
+
def cache_data(self):
|
| 30 |
+
alphabet='ACDEFGHIKLMNPQRSTVWY'
|
| 31 |
+
alphabet_set = set([a for a in alphabet])
|
| 32 |
+
print("path is: ", self.path)
|
| 33 |
+
|
| 34 |
+
if not os.path.exists(self.path):
|
| 35 |
+
raise "no such file:{} !!!".format(self.path)
|
| 36 |
+
else:
|
| 37 |
+
|
| 38 |
+
with open(self.path+'/chain_set.jsonl') as f:
|
| 39 |
+
lines = f.readlines()
|
| 40 |
+
data_list = []
|
| 41 |
+
|
| 42 |
+
for line in tqdm(lines):
|
| 43 |
+
entry = json.loads(line)
|
| 44 |
+
|
| 45 |
+
if self.removeTS and entry['name'] in self.remove:
|
| 46 |
+
continue
|
| 47 |
+
seq = entry['seq']
|
| 48 |
+
|
| 49 |
+
for key, val in entry['coords'].items():
|
| 50 |
+
entry['coords'][key] = np.asarray(val)
|
| 51 |
+
|
| 52 |
+
bad_chars = set([s for s in seq]).difference(alphabet_set)
|
| 53 |
+
|
| 54 |
+
if len(bad_chars) == 0:
|
| 55 |
+
if len(entry['seq']) <= self.max_length:
|
| 56 |
+
chain_length = len(entry['seq'])
|
| 57 |
+
chain_mask = np.ones(chain_length)
|
| 58 |
+
data_list.append({
|
| 59 |
+
'title':entry['name'],
|
| 60 |
+
'seq':entry['seq'],
|
| 61 |
+
'CA':entry['coords']['CA'],
|
| 62 |
+
'C':entry['coords']['C'],
|
| 63 |
+
'O':entry['coords']['O'],
|
| 64 |
+
'N':entry['coords']['N'],
|
| 65 |
+
'chain_mask': chain_mask,
|
| 66 |
+
'chain_encoding': 1*chain_mask
|
| 67 |
+
})
|
| 68 |
+
|
| 69 |
+
with open(self.path+'/chain_set_splits_cleaned.json') as f:
|
| 70 |
+
dataset_splits = json.load(f)
|
| 71 |
+
|
| 72 |
+
if self.test_name == 'L100':
|
| 73 |
+
with open(self.path+'/test_split_L100.json') as f:
|
| 74 |
+
test_splits = json.load(f)
|
| 75 |
+
dataset_splits['test'] = test_splits['test']
|
| 76 |
+
|
| 77 |
+
if self.test_name == 'sc':
|
| 78 |
+
with open(self.path+'/test_split_sc.json') as f:
|
| 79 |
+
test_splits = json.load(f)
|
| 80 |
+
dataset_splits['test'] = test_splits['test']
|
| 81 |
+
|
| 82 |
+
name2set = {}
|
| 83 |
+
name2set.update({name:'train' for name in dataset_splits['train']})
|
| 84 |
+
name2set.update({name:'valid' for name in dataset_splits['validation']})
|
| 85 |
+
name2set.update({name:'test' for name in dataset_splits['test']})
|
| 86 |
+
|
| 87 |
+
data_dict = {'train':[],'valid':[],'test':[]}
|
| 88 |
+
for data in data_list:
|
| 89 |
+
#pdb.set_trace()
|
| 90 |
+
if name2set.get(data['title']): #This was causing the trouble with empty datasets - missmatch of names in the chain_set and chain_set_split
|
| 91 |
+
if name2set[data['title']] == 'train':
|
| 92 |
+
data_dict['train'].append(data)
|
| 93 |
+
|
| 94 |
+
if name2set[data['title']] == 'valid':
|
| 95 |
+
data_dict['valid'].append(data)
|
| 96 |
+
|
| 97 |
+
if name2set[data['title']] == 'test':
|
| 98 |
+
data['category'] = 'Unkown'
|
| 99 |
+
data['score'] = 100.0
|
| 100 |
+
data_dict['test'].append(data)
|
| 101 |
+
return data_dict
|
| 102 |
+
|
| 103 |
+
def change_mode(self, mode):
|
| 104 |
+
self.data = self.cache_data[mode]
|
| 105 |
+
|
| 106 |
+
def __len__(self):
|
| 107 |
+
return len(self.data)
|
| 108 |
+
|
| 109 |
+
def get_item(self, index):
|
| 110 |
+
return self.data[index]
|
| 111 |
+
|
| 112 |
+
def __getitem__(self, index):
    """Return entry ``index``, randomly cropped to ``self.max_length`` residues.

    Entries whose ``'seq'`` is not longer than ``self.max_length`` are returned
    as stored. Longer entries are cropped to a random contiguous window; the
    same window is applied to ``'seq'``, the four backbone coordinate fields,
    ``'chain_mask'`` and ``'chain_encoding'``.

    The crop is written into a shallow copy, so the entry kept in
    ``self.data`` stays full length and every call draws a fresh window.
    (Cropping the stored dict in place would permanently shorten it, making
    all later accesses return the first window ever drawn.)
    """
    item = self.data[index]
    L = len(item['seq'])
    if L <= self.max_length:
        return item

    # Largest start offset that still leaves a full-length window.
    max_index = L - self.max_length
    # Random start offset; random.randint includes both bounds.
    truncate_index = random.randint(0, max_index)
    window = slice(truncate_index, truncate_index + self.max_length)

    # Apply the crop to a copy so the cached entry is left untouched.
    cropped = dict(item)
    for key in ('seq', 'CA', 'C', 'O', 'N', 'chain_mask', 'chain_encoding'):
        cropped[key] = item[key][window]
    return cropped
|
Flexpert-Design/src/datasets/mpnn_dataset.py
ADDED
|
@@ -0,0 +1,492 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import numpy as np
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import torch.utils.data as data
|
| 7 |
+
from Bio.PDB import PDBParser
|
| 8 |
+
import torch
|
| 9 |
+
import random
|
| 10 |
+
import csv
|
| 11 |
+
from dateutil import parser
|
| 12 |
+
from .fast_dataloader import DataLoaderX
|
| 13 |
+
from torch.utils.data import DataLoader
|
| 14 |
+
import time
|
| 15 |
+
|
| 16 |
+
from joblib import Parallel, delayed, cpu_count
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def pmap_multi(pickleable_fn, data, n_jobs=None, verbose=1, desc=None, **kwargs):
    """
    Parallel map using joblib.

    Parameters
    ----------
    pickleable_fn : callable
        Function to map over data.
    data : iterable
        Data over which we want to parallelize the function call. Each
        element is unpacked into positional arguments, so it must itself be
        an iterable (e.g. a tuple of arguments).
    n_jobs : int, optional
        The maximum number of concurrently running jobs. By default, it is one less than
        the number of CPUs, but never less than one.
    verbose: int, optional
        The verbosity level passed on to ``joblib.Parallel``.
    desc : str, optional
        Label for the tqdm progress bar shown while tasks are dispatched.
    kwargs
        Additional keyword arguments for :attr:`pickleable_fn`.

    Returns
    -------
    list
        The i-th element of the list corresponds to the output of
        ``pickleable_fn(*data[i], **kwargs)``.
    """
    if n_jobs is None:
        # cpu_count() - 1 is 0 on a single-core host, and joblib rejects n_jobs == 0.
        n_jobs = max(1, cpu_count() - 1)

    results = Parallel(n_jobs=n_jobs, verbose=verbose, timeout=None)(
        delayed(pickleable_fn)(*d, **kwargs) for d in tqdm(data, desc=desc)
    )
    # The computed list has to be handed back; callers assign the return value.
    return results
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def build_training_clusters(params, debug):
    """Group the chains listed in ``params['LIST']`` into train/valid/test clusters.

    Parameters
    ----------
    params : dict
        Reads ``'VAL'`` and ``'TEST'`` (paths to text files holding one integer
        cluster id per line), ``'LIST'`` (path to a CSV file with a header
        row), ``'RESCUT'`` (upper bound for the float in CSV column 2) and
        ``'DATCUT'`` (latest accepted date for CSV column 1).
    debug : bool
        When true, no validation/test ids are used, only the first 20
        accepted rows are kept, and ``valid`` is the same dict as ``train``.

    Returns
    -------
    tuple of dict
        ``(train, valid, test)``. Each dict maps a cluster id (CSV column 4 as
        ``int``) to a list of ``[column 0, column 3]`` pairs. A row goes to
        ``valid`` if its cluster id is a validation id, else to ``test`` if
        it is a test id, else to ``train``.
    """
    def _read_cluster_ids(path):
        # Read one integer id per line and close the file afterwards.
        with open(path) as handle:
            return {int(line) for line in handle}

    if debug:
        val_ids, test_ids = set(), set()
    else:
        val_ids = _read_cluster_ids(params['VAL'])
        test_ids = _read_cluster_ids(params['TEST'])

    # The cutoff is the same for every row, so parse it once.
    date_cutoff = parser.parse(params['DATCUT'])

    # read & clean list.csv
    with open(params['LIST'], 'r') as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        rows = [[r[0], r[3], int(r[4])] for r in reader
                if float(r[2]) <= params['RESCUT']
                and parser.parse(r[1]) <= date_cutoff]

    if debug:
        rows = rows[:20]

    # compile training, validation and test sets
    train, valid, test = {}, {}, {}
    for r in rows:
        cluster_id = r[2]
        if cluster_id in val_ids:
            target = valid
        elif cluster_id in test_ids:
            target = test
        else:
            target = train
        target.setdefault(cluster_id, []).append(r[:2])

    if debug:
        valid = train
    return train, valid, test
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def loader_pdb(item,params):
    """Load one chain together with a randomly chosen assembly that contains it.

    ``item[0]`` is a ``'<pdbid>_<chid>'`` label; files are looked up under
    ``params['DIR']/pdb/<pdbid[1:3]>/``. ``params['HOMO']`` is the threshold
    above which another chain counts as homologous to ``chid``.

    Returns a dict with ``'seq'``, ``'xyz'``, ``'idx'`` (per-residue chain
    counter), ``'masked'`` (counters of homologous chains) and ``'label'``.
    When the metadata file is missing, or a chain needed by the assembly was
    not loaded, the placeholder ``{'seq': np.zeros(5)}`` is returned instead.
    """
    placeholder_seq_len = 5
    pdbid, chid = item[0].split('_')
    PREFIX = "%s/pdb/%s/%s" % (params['DIR'], pdbid[1:3], pdbid)
    meta_path = PREFIX + ".pt"

    # Without the metadata file there is nothing to assemble.
    if not os.path.isfile(meta_path):
        return {'seq': np.zeros(placeholder_seq_len)}
    meta = torch.load(meta_path)
    asmb_ids = meta['asmb_ids']
    asmb_chains = meta['asmb_chains']
    chids = np.array(meta['chains'])

    # Assemblies whose comma-separated chain list mentions chid.
    asmb_candidates = {
        asmb_id
        for asmb_id, chain_csv in zip(asmb_ids, asmb_chains)
        if chid in chain_csv.split(',')
    }

    # The chain is in no assembly: return it on its own.
    if len(asmb_candidates) < 1:
        chain = torch.load("%s_%s.pt" % (PREFIX, chid))
        n_res = len(chain['seq'])
        return {
            'seq': chain['seq'],
            'xyz': chain['xyz'],
            'idx': torch.zeros(n_res).int(),
            'masked': torch.Tensor([0]).int(),
            'label': item[0],
        }

    # Randomly pick one assembly from the candidates (kept as a 1-element list).
    asmb_i = random.sample(list(asmb_candidates), 1)

    # Rows of the metadata that belong to the selected assembly.
    xform_rows = np.where(np.array(asmb_ids) == asmb_i)[0]

    # Load the chains of those rows.
    # NOTE(review): this walks asmb_chains[i] character by character rather
    # than splitting on ',', so only single-character chain ids can match
    # meta['chains'] here -- confirm whether that is intended.
    chains = {}
    for i in xform_rows:
        for c in asmb_chains[i]:
            if c in meta['chains']:
                chains[c] = torch.load("%s_%s.pt" % (PREFIX, c))

    # Generate the assembly: one coordinate set per (chain, row, transform copy).
    asmb = {}
    for k in xform_rows:
        # k-th transform stack: upper-left 3x3 blocks and last-column vectors
        # (presumably rotation and translation -- TODO confirm).
        xform = meta['asmb_xform%d' % k]
        u = xform[:, :3, :3]
        r = xform[:, :3, 3]

        # Chains of this entry that the k-th row applies to.
        chains_k = set(meta['chains']) & set(asmb_chains[k].split(','))

        for c in chains_k:
            try:
                xyz = chains[c]['xyz']
                xyz_ru = torch.einsum('bij,raj->brai', u, xyz) + r[:, None, None, :]
                asmb.update({(c, k, i): xyz_i for i, xyz_i in enumerate(xyz_ru)})
            except KeyError:
                # A required chain was never loaded: give up on this entry.
                return {'seq': np.zeros(placeholder_seq_len)}

    # Chains whose similarity value to chid exceeds the HOMO threshold.
    seqid = meta['tm'][chids == chid][0, :, 1]
    homo = {ch_j for seqid_j, ch_j in zip(seqid, chids) if seqid_j > params['HOMO']}

    # Stack all chains of the assembly into flat per-residue fields.
    seq = ""
    coord_parts, index_parts, masked = [], [], []
    for counter, (key, coords) in enumerate(asmb.items()):
        chain_id = key[0]
        seq += chains[chain_id]['seq']
        coord_parts.append(coords)
        index_parts.append(torch.full((coords.shape[0],), counter))
        if chain_id in homo:
            masked.append(counter)

    return {
        'seq': seq,
        'xyz': torch.cat(coord_parts, dim=0),
        'idx': torch.cat(index_parts, dim=0),
        'masked': torch.Tensor(masked).int(),
        'label': item[0],
    }
|
| 182 |
+
|
| 183 |
+
def get_pdbs(data, max_length=10000, num_units=1000000):
    """Split a loaded assembly into per-chain sequences and backbone coordinates.

    ``data`` is a dict as produced by ``loader_pdb``. Input without a
    ``'label'`` key, with 352 or more distinct chain indices, or whose
    concatenated sequence is longer than ``max_length`` yields ``None``.
    ``num_units`` is accepted but not used.

    Otherwise the result holds, per kept chain letter ``X``,
    ``'seq_chain_X'`` and ``'coords_chain_X'`` (N/CA/C/O coordinate lists taken
    from atom slots 0-3), followed by ``'name'``, ``'masked_list'``,
    ``'visible_list'``, ``'num_of_chains'`` and ``'seq'``. Chains left with
    fewer than 4 positions after His-tag trimming are dropped.
    """
    # Chain letters: A-Z, a-z, then the strings '0'..'299' (352 names in total).
    letters = [chr(first + k) for first in (ord('A'), ord('a')) for k in range(26)]
    chain_alphabet = letters + [str(number) for number in list(np.arange(300))]

    data = dict(data.items())
    if 'label' not in data:
        return None

    chain_indices = list(np.unique(data['idx']))
    if not len(chain_indices) < 352:
        return None

    # (window of the untrimmed sequence to test, columns of ``res`` to keep).
    # Rules are applied in this order and each one acts on the already
    # trimmed ``res``, so the order matters.
    his_tag_rules = (
        (slice(-6, None), slice(None, -6)),
        (slice(0, 6), slice(6, None)),
        (slice(-7, -1), slice(None, -7)),
        (slice(-8, -2), slice(None, -8)),
        (slice(-9, -3), slice(None, -9)),
        (slice(-10, -4), slice(None, -10)),
        (slice(1, 7), slice(7, None)),
        (slice(2, 8), slice(8, None)),
        (slice(3, 9), slice(9, None)),
        (slice(4, 10), slice(10, None)),
    )

    my_dict = {}
    concat_seq = ''
    mask_list = []
    visible_list = []
    seq_chars = np.array(list(data['seq']))

    for idx in chain_indices:
        letter = chain_alphabet[idx]
        res = np.argwhere(data['idx'] == idx)
        initial_sequence = "".join(list(seq_chars[res][0,]))

        # Cut a run of six histidines found at or near either end.
        for probe, keep in his_tag_rules:
            if initial_sequence[probe] == "HHHHHH":
                res = res[:, keep]

        if res.shape[1] < 4:
            continue

        chain_seq = "".join(list(seq_chars[res][0,]))
        my_dict['seq_chain_' + letter] = chain_seq
        concat_seq += chain_seq
        (mask_list if idx in data['masked'] else visible_list).append(letter)

        all_atoms = np.array(data['xyz'][res,])[0,]  # [L, 14, 3]
        my_dict['coords_chain_' + letter] = {
            'N_chain_' + letter: all_atoms[:, 0, :].tolist(),
            'CA_chain_' + letter: all_atoms[:, 1, :].tolist(),
            'C_chain_' + letter: all_atoms[:, 2, :].tolist(),
            'O_chain_' + letter: all_atoms[:, 3, :].tolist(),
        }

    my_dict['name'] = data['label']
    my_dict['masked_list'] = mask_list
    my_dict['visible_list'] = visible_list
    my_dict['num_of_chains'] = len(mask_list) + len(visible_list)
    my_dict['seq'] = concat_seq
    return my_dict if len(concat_seq) <= max_length else None
|
| 254 |
+
|
| 255 |
+
def safe_iter(ID, split_dict, params, alphabet_set, max_length=1000):
    """Load a random member of cluster ``ID`` and featurize it, or return ``None``.

    One entry of ``split_dict[ID]`` is drawn at random, loaded with
    ``loader_pdb`` and split into chains with ``get_pdbs``. ``None`` is
    returned when that yields nothing, when the sequence contains characters
    outside ``alphabet_set``, when it is longer than ``max_length``, or when
    no chain is left to featurize.

    Visible chains whose sequence equals that of a masked chain are moved to
    the masked group. The chains are then shuffled and concatenated.

    Returns
    -------
    dict or None
        ``'title'``, ``'seq'``, and torch tensors ``'chain_mask'`` (1 for
        residues of masked chains, 0 for visible ones), ``'chain_encoding'``
        (1-based position of the chain in the shuffled order) and the
        backbone coordinates ``'N'``, ``'CA'``, ``'C'``, ``'O'``.
    """
    sel_idx = np.random.randint(0, len(split_dict[ID]))
    out = loader_pdb(split_dict[ID][sel_idx], params)
    entry = get_pdbs(out)
    if entry is None:
        return None

    seq = entry['seq']
    if set(seq).difference(alphabet_set):
        return None
    if len(seq) > max_length:
        return None

    masked_chains = entry['masked_list']
    visible_chains = entry['visible_list']

    masked_temp_dict = {letter: entry[f'seq_chain_{letter}'] for letter in masked_chains}
    visible_temp_dict = {letter: entry[f'seq_chain_{letter}'] for letter in visible_chains}

    # A visible chain identical to a masked one would reveal its sequence, so
    # it is masked too. The loop order is kept because it decides the list
    # order that random.shuffle starts from.
    for vm in masked_temp_dict.values():
        for kv, vv in visible_temp_dict.items():
            if vm == vv:
                if kv not in masked_chains:
                    masked_chains.append(kv)
                if kv in visible_chains:
                    visible_chains.remove(kv)

    all_chains = masked_chains + visible_chains
    if not all_chains:
        # get_pdbs may drop every chain; np.concatenate([]) would raise, so
        # report the entry as unusable like the other rejected cases.
        return None
    random.shuffle(all_chains)

    x_chain_list = []
    chain_mask_list = []
    chain_seq_list = []
    chain_encoding_list = []

    for chain_number, letter in enumerate(all_chains, start=1):
        chain_seq = entry[f'seq_chain_{letter}']
        chain_length = len(chain_seq)
        chain_coords = entry[f'coords_chain_{letter}']  # dict of coordinate lists
        if letter in visible_chains:
            chain_mask = np.zeros(chain_length)  # 0.0 for visible chains
        else:
            chain_mask = np.ones(chain_length)  # 1.0 for masked chains
        atom_keys = [f'N_chain_{letter}', f'CA_chain_{letter}', f'C_chain_{letter}', f'O_chain_{letter}']
        x_chain = np.stack([chain_coords[key] for key in atom_keys], 1)  # [chain_length,4,3]
        x_chain_list.append(x_chain)
        chain_mask_list.append(chain_mask)
        chain_seq_list.append(chain_seq)
        chain_encoding_list.append(chain_number * np.ones(chain_length))

    chain_mask_all = torch.from_numpy(np.concatenate(chain_mask_list))
    chain_encoding_all = torch.from_numpy(np.concatenate(chain_encoding_list))
    x_chain_all = torch.from_numpy(np.concatenate(x_chain_list))

    data = {
        "title": entry['name'],
        "seq": ''.join(chain_seq_list),  # len(seq)=n
        "chain_mask": chain_mask_all,
        "chain_encoding": chain_encoding_all,
        "CA": x_chain_all[:, 1],  # [n,3]
        "C": x_chain_all[:, 2],
        "O": x_chain_all[:, 3],
        "N": x_chain_all[:, 0]}
    return data
|
| 340 |
+
|
| 341 |
+
class MPNNDataset(data.Dataset):
    """Dataset over cluster ids; each item is featurized on access.

    Parameters
    ----------
    data_path : str
        Directory holding ``list.csv``, ``valid_clusters.txt``,
        ``test_clusters.txt`` and the per-entry files read by ``loader_pdb``.
    rescut : float
        Resolution cutoff passed to ``build_training_clusters``.
    split : str
        Which split to serve: ``'train'``, ``'valid'`` or ``'test'``.
    split_cache_path : str
        File in which the computed split dict is stored with ``torch.save``
        and from which it is reloaded when it already exists.
    """

    def __init__(self, data_path='/gaozhangyang/drug_dataset/proteinmpnn_data/pdb_2021aug02', rescut=3.5, split='train',
                 split_cache_path="/gaozhangyang/experiments/OpenCPD/data/mpnn_data/split.pt"):
        self.data_path = data_path
        self.rescut = rescut
        self.params = {
            "LIST"    : f"{self.data_path}/list.csv",
            "VAL"     : f"{self.data_path}/valid_clusters.txt",
            "TEST"    : f"{self.data_path}/test_clusters.txt",
            "DIR"     : f"{self.data_path}",
            "DATCUT"  : "2030-Jan-01",
            "RESCUT"  : self.rescut, #resolution cutoff for PDBs
            "HOMO"    : 0.70 #min seq.id. to detect homo chains
        }

        # Build the split table once and reuse the cached copy afterwards.
        if not os.path.exists(split_cache_path):
            splits = self.cache_split()
            torch.save(splits, split_cache_path)
        else:
            splits = torch.load(split_cache_path)

        # ``split`` names the requested split; it must not be overwritten by
        # the table itself before this lookup.
        self.split_dict = splits[split]
        alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
        self.alphabet_set = set(alphabet)
        self.IDs = list(self.split_dict.keys())

    def cache_split(self,):
        """Compute the train/valid/test cluster tables from ``self.params``."""
        train, valid, test = build_training_clusters(self.params, False)

        return {"train": train, "valid": valid, "test": test}

    @classmethod
    def safe_iter(cls, ID, split_dict, params, alphabet_set, max_length=1000):
        """Load the first member of cluster ``ID`` and featurize it, or return ``None``.

        ``None`` is returned when ``get_pdbs`` yields nothing, when the
        sequence has characters outside ``alphabet_set``, when it is longer
        than ``max_length``, or when no chain is left to featurize.

        The result holds ``'title'`` (entry name followed by the number of
        masked residues), ``'seq'`` and numpy arrays ``'chain_mask'``,
        ``'chain_encoding'``, ``'N'``, ``'CA'``, ``'C'``, ``'O'``.
        """
        # Always take the first cluster member (random choice is disabled here).
        sel_idx = 0
        out = loader_pdb(split_dict[ID][sel_idx], params)
        entry = get_pdbs(out)
        if entry is None:
            return None

        seq = entry['seq']
        if set(seq).difference(alphabet_set):
            return None
        if len(seq) > max_length:
            return None

        masked_chains = entry['masked_list']
        visible_chains = entry['visible_list']

        masked_temp_dict = {letter: entry[f'seq_chain_{letter}'] for letter in masked_chains}
        visible_temp_dict = {letter: entry[f'seq_chain_{letter}'] for letter in visible_chains}

        # A visible chain identical to a masked one is masked too. The loop
        # order is kept because it decides the order random.shuffle starts from.
        for vm in masked_temp_dict.values():
            for kv, vv in visible_temp_dict.items():
                if vm == vv:
                    if kv not in masked_chains:
                        masked_chains.append(kv)
                    if kv in visible_chains:
                        visible_chains.remove(kv)

        all_chains = masked_chains + visible_chains
        if not all_chains:
            # Every chain was dropped upstream; np.concatenate([]) would raise.
            return None
        random.shuffle(all_chains)

        x_chain_list = []
        chain_mask_list = []
        chain_seq_list = []
        chain_encoding_list = []

        for chain_number, letter in enumerate(all_chains, start=1):
            chain_seq = entry[f'seq_chain_{letter}']
            chain_length = len(chain_seq)
            chain_coords = entry[f'coords_chain_{letter}']  # dict of coordinate lists
            if letter in visible_chains:
                chain_mask = np.zeros(chain_length)  # 0.0 for visible chains
            else:
                chain_mask = np.ones(chain_length)  # 1.0 for masked chains
            atom_keys = [f'N_chain_{letter}', f'CA_chain_{letter}', f'C_chain_{letter}', f'O_chain_{letter}']
            x_chain = np.stack([chain_coords[key] for key in atom_keys], 1)  # [chain_length,4,3]
            x_chain_list.append(x_chain)
            chain_mask_list.append(chain_mask)
            chain_seq_list.append(chain_seq)
            chain_encoding_list.append(chain_number * np.ones(chain_length))

        chain_mask_all = np.concatenate(chain_mask_list)
        chain_encoding_all = np.concatenate(chain_encoding_list)
        x_chain_all = np.concatenate(x_chain_list)

        data = {
            "title": entry['name'] + str(int(chain_mask_all.sum())),
            "seq": ''.join(chain_seq_list),  # len(seq)=n
            "chain_mask": chain_mask_all,
            "chain_encoding": chain_encoding_all,
            "CA": x_chain_all[:, 1],  # [n,3]
            "C": x_chain_all[:, 2],
            "O": x_chain_all[:, 3],
            "N": x_chain_all[:, 0]}
        return data

    def preprocess(self):
        """Featurize every cluster id in parallel via ``pmap_multi``."""
        data = pmap_multi(self.safe_iter, [(ID,) for ID in self.IDs], split_dict=self.split_dict, params=self.params, alphabet_set=self.alphabet_set)
        return data

    def __len__(self):
        """Return the number of cluster ids in the selected split."""
        return len(self.IDs)

    def __getitem__(self, index):
        """Featurize the cluster at ``index``; the result may be ``None``."""
        ID = self.IDs[index]
        out = self.safe_iter(ID, split_dict=self.split_dict, params=self.params, alphabet_set=self.alphabet_set)
        return out
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
def collate_fn(batch):
    """Identity collate: hand the list of samples back without merging them."""
    samples = batch
    return samples
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
if __name__ == "__main__":
    # Manual smoke test: iterate the dataset through DataLoaderX and push
    # every tensor field of every sample to the first CUDA device.
    # The instance gets its own name so the MPNNDataset class stays usable.
    dataset = MPNNDataset()
    loader = DataLoaderX(local_rank=0, dataset=dataset, collate_fn=collate_fn, batch_size=4)
    # Alternative loader that was tried here:
    # loader = DataLoader(dataset = dataset, collate_fn=collate_fn, batch_size=4, prefetch_factor=4, num_workers=4)
    for batch in tqdm(loader):
        for one in batch:
            if one is None:
                # Entries that could not be featurized come back as None.
                continue
            for val in one.values():
                if isinstance(val, torch.Tensor):
                    val.to('cuda:0')
        time.sleep(2)
    print()
|
Flexpert-Design/src/datasets/pdb_inference.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import numpy as np
|
| 4 |
+
import random
|
| 5 |
+
import pdb
|
| 6 |
+
import torch.utils.data as data
|
| 7 |
+
from .utils import cached_property
|
| 8 |
+
from transformers import AutoTokenizer
|
| 9 |
+
|
| 10 |
+
#Imports for the PDB parser utils
|
| 11 |
+
import glob
|
| 12 |
+
import json
|
| 13 |
+
import numpy as np
|
| 14 |
+
import gzip
|
| 15 |
+
import re
|
| 16 |
+
import multiprocessing
|
| 17 |
+
import tqdm
|
| 18 |
+
import shutil
|
| 19 |
+
SENTINEL = 1
|
| 20 |
+
import biotite.structure as struc
|
| 21 |
+
import biotite.application.dssp as dssp
|
| 22 |
+
import biotite.structure.io.pdb.file as file
|
| 23 |
+
|
| 24 |
+
class PDBInference(data.Dataset):
|
| 25 |
+
def __init__(self, path='./', max_length=500, *args, **kwargs):
    """Store the folder and length limit, load the data, and create the tokenizer.

    ``path`` is the folder that ``cache_data`` scans; ``max_length`` is the
    longest sequence it keeps. Extra positional and keyword arguments are
    accepted and ignored.
    """
    self.path, self.max_length = path, max_length

    # cache_data reads self.path and self.max_length, so both are set first.
    self.data = self.cache_data #TODO
    self.tokenizer = AutoTokenizer.from_pretrained(
        "facebook/esm2_t33_650M_UR50D",
        cache_dir="./cache_dir/",
    )
|
| 31 |
+
|
| 32 |
+
@cached_property
def cache_data(self):
    """Parse every ``*.pdb`` file in ``self.path`` into a list of entries.

    File names are expected to look like ``PDBCODE_CHAINCODE...pdb``; the
    part between the first ``_`` and the next ``.`` is used as the chain to
    parse. For each ``<name>_instructions.csv`` in the folder, the
    comma-separated floats (plus one trailing ``0.0``) become the ``gt_flex``
    values of the entry called ``<name>``; entries without such a file get
    zeros.

    Entries whose sequence contains characters outside the 20-letter
    alphabet are skipped with a message; entries longer than
    ``self.max_length`` are skipped silently.

    Returns
    -------
    list of dict
        Each with ``'title'``, ``'seq'``, ``'CA'``, ``'C'``, ``'O'``, ``'N'``,
        ``'chain_mask'``, ``'chain_encoding'`` and ``'gt_flex'``.

    Raises
    ------
    FileNotFoundError
        If ``self.path`` does not exist.
    """
    alphabet = 'ACDEFGHIKLMNPQRSTVWY'
    alphabet_set = set(alphabet)
    print("path is: ", self.path)

    if not os.path.exists(self.path):
        # A bare string cannot be raised in Python 3; use a real exception.
        raise FileNotFoundError("no such folder:{} !!!".format(self.path))

    # list all PDBs
    pdb_files = [name for name in os.listdir(self.path) if name.endswith('.pdb')]
    print(f'pdb_files size = {len(pdb_files)}')

    # parse the PDBs into entries shaped like the lines of chain_set.json
    lines = []
    for _pdb in pdb_files:
        _input_chain = _pdb.split('_')[1].split('.')[0]  # assumes naming 'PDBCODE_CHAINCODE_XXX'
        _line = self.parse_PDB(self.path + '/' + _pdb, name=_pdb.split('.')[0], input_chain=_input_chain)
        lines.append(_line[0])

    print(f'lines size = {len(lines)}')

    # Per-entry flexibility instructions, keyed by the file name prefix.
    flex_instructions = {}
    for csv_path in glob.glob(self.path + '/*instructions.csv'):
        with open(csv_path, 'r') as f:
            raw_values = f.read().strip().split(',')
        entry_name = os.path.basename(csv_path).split('_instructions')[0]
        # One trailing 0.0 is appended as padding.
        flex_instructions[entry_name] = [float(i) for i in raw_values] + [0.0]

    data_list = []
    for entry in tqdm.tqdm(lines):
        seq = entry['seq']

        for key, val in entry['coords'].items():
            entry['coords'][key] = np.asarray(val)

        bad_chars = set(seq).difference(alphabet_set)
        try:
            _flex_instructions = flex_instructions[entry['name']]
        except KeyError:
            # NOTE(review): this fallback has len(seq) values while parsed
            # instructions carry one extra padding value -- confirm which
            # length downstream code expects.
            _flex_instructions = [0.0] * len(seq)
            print(f"No flexibility instructions found for {entry['name']}. Passing zeros.")

        if bad_chars:
            print(f'Skipping PDBs with Bad chars, e.g. gaps in the sequence: {entry["name"]}')
            continue
        if len(seq) > self.max_length:
            continue

        chain_mask = np.ones(len(seq))
        data_list.append({
            'title': entry['name'],
            'seq': entry['seq'],
            'CA': entry['coords']['CA'],
            'C': entry['coords']['C'],
            'O': entry['coords']['O'],
            'N': entry['coords']['N'],
            'chain_mask': chain_mask,
            'chain_encoding': 1 * chain_mask,
            'gt_flex': _flex_instructions
        })

    print(f'data_list size = {len(data_list)}')
    return data_list
|
| 104 |
+
|
| 105 |
+
def change_mode(self, mode):
    """Select the data served by this dataset for ``mode``.

    When ``cache_data`` is a dict of splits, the split stored under ``mode``
    is selected. When it is a plain list there are no splits, so the whole
    list is used whatever ``mode`` is; indexing a list with a split name
    would raise ``TypeError``.
    """
    cached = self.cache_data
    self.data = cached[mode] if isinstance(cached, dict) else cached
|
| 107 |
+
|
| 108 |
+
def __len__(self):
    """Return how many entries the currently selected data holds."""
    entries = self.data
    return len(entries)
|
| 110 |
+
|
| 111 |
+
def get_item(self, index):
    """Return the stored entry at ``index`` exactly as it is kept in ``self.data``."""
    entry = self.data[index]
    return entry
|
| 113 |
+
|
| 114 |
+
def __getitem__(self, index):
    """Return entry ``index``, randomly cropped to ``self.max_length`` residues.

    Entries whose ``'seq'`` is not longer than ``self.max_length`` are returned
    as stored. Longer entries are cropped to a random contiguous window; the
    same window is applied to ``'seq'``, the four backbone coordinate fields,
    ``'chain_mask'``, ``'chain_encoding'`` and ``'gt_flex'``.

    The crop is written into a shallow copy, so the entry kept in
    ``self.data`` stays full length and every call draws a fresh window.
    """
    item = self.data[index]
    L = len(item['seq'])
    if L <= self.max_length:
        return item

    # Largest start offset that still leaves a full-length window.
    max_index = L - self.max_length
    # Random start offset; random.randint includes both bounds.
    truncate_index = random.randint(0, max_index)
    window = slice(truncate_index, truncate_index + self.max_length)

    # Apply the crop to a copy so the cached entry is left untouched.
    cropped = dict(item)
    for key in ('seq', 'CA', 'C', 'O', 'N', 'chain_mask', 'chain_encoding', 'gt_flex'):
        cropped[key] = item[key][window]
    return cropped
|
| 132 |
+
|
| 133 |
+
#Code from data_utils on local PC, based on: https://github.com/JoreyYan/zetadesign/blob/master/data/data.py
|
| 134 |
+
def parse_PDB_biounits(self, x, sse, ssedssp, atoms=('N', 'CA', 'C'), chain=None):
    '''
    Parse one chain of a PDB file into coordinates, sequence and
    secondary-structure labels, one entry per residue position.

    input:  x       = PDB filename (only the text before the first newline is used)
            sse     = sequence of secondary-structure labels; one label is
                      consumed for every residue that has a CA atom
            ssedssp = sequence of labels; one label is consumed for every
                      residue that has all requested atoms
            atoms   = atom names to extract (optional)
            chain   = chain id to keep; None keeps every chain.  Records with
                      a blank chain id are always kept.
    output: (length, atoms, coords=(x,y,z)) array, [sequence string],
            array of sse labels, array of ssedssp labels

    Residue numbers are shifted to be zero-based.  Numbers missing between
    the smallest and largest residue number seen are filled with a gap:
    '-' in the sequence and labels, NaN in the coordinates.  Atoms missing
    from a present residue are NaN as well.
    '''

    alpha_1 = list("ARNDCQEGHILKMFPSTWYV-")
    alpha_3 = ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
               'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL', 'GAP']

    aa_3_N = {a: n for n, a in enumerate(alpha_3)}
    aa_N_1 = {n: a for n, a in enumerate(alpha_1)}

    def N_to_AA(idx):
        # Map rows of residue indices to one-letter strings; unknown -> '-'.
        idx = np.array(idx)
        if idx.ndim == 1: idx = idx[None]
        return ["".join([aa_N_1.get(a, "-") for a in y]) for y in idx]

    xyz, seq, min_resn, max_resn = {}, {}, 1e6, -1e6

    pdb_path = x.split('\n')[0]
    with open(pdb_path) as f:
        pdbcontents = f.readlines()
    for line in pdbcontents:

        # Treat selenomethionine as methionine.  The replacement record name
        # is six characters wide so the fixed-column slices below stay aligned.
        if line[:6] == "HETATM" and line[17:17 + 3] == "MSE":
            line = line.replace("HETATM", "ATOM  ")
            line = line.replace("MSE", "MET")

        if line[:4] != "ATOM":
            continue
        ch = line[21:22]
        if not (ch == chain or chain is None or ch == ' '):
            continue

        atom = line[12:12 + 4].strip()
        resi = line[17:17 + 3]
        resn = line[22:22 + 5].strip()

        # Separate names so the filename parameter ``x`` is not shadowed.
        cx, cy, cz = [float(line[i:(i + 8)]) for i in [30, 38, 46]]

        if resn[-1].isalpha():
            # Trailing letter is an insertion code: keep it as the sub-key.
            resa, resn = resn[-1], int(resn[:-1]) - 1
        else:
            resa, resn = "_", int(resn) - 1
        if resn < min_resn:
            min_resn = resn
        if resn > max_resn:
            max_resn = resn

        residue_atoms = xyz.setdefault(resn, {}).setdefault(resa, {})
        # First residue name seen for this (number, insertion code) wins.
        seq.setdefault(resn, {}).setdefault(resa, resi)

        # First occurrence of an atom wins (later duplicates are ignored).
        if atom not in residue_atoms:
            residue_atoms[atom] = np.array([cx, cy, cz])

    # convert to numpy arrays, fill in missing values
    seq_, xyz_, sse_, ssedssp_ = [], [], [], []
    dsspidx = 0
    sseidx = 0

    for resn in range(int(min_resn), int(max_resn + 1)):
        if resn in seq:
            for k in sorted(seq[resn]):
                seq_.append(aa_3_N.get(seq[resn][k], 20))
                try:
                    if 'CA' in xyz[resn][k]:
                        sse_.append(sse[sseidx])
                        sseidx = sseidx + 1
                    else:
                        sse_.append('-')
                except IndexError:
                    # ``sse`` ran out of labels.  NOTE(review): nothing is
                    # appended here, so sse_ ends up shorter than seq_.
                    print('error sse')
        else:
            seq_.append(20)
            sse_.append('-')

        if resn in xyz:
            for k in sorted(xyz[resn]):
                misschianatom = False
                for atom in atoms:
                    if atom in xyz[resn][k]:
                        # some will miss C and O, but sse is normal because sse just depends on CA
                        xyz_.append(xyz[resn][k][atom])
                    else:
                        xyz_.append(np.full(3, np.nan))
                        misschianatom = True
                if misschianatom:
                    ssedssp_.append('-')
                else:
                    try:
                        # if a chain atom is missing, xyz and seq still count
                        # the residue, but dssp has no label for it
                        ssedssp_.append(ssedssp[dsspidx])
                        dsspidx = dsspidx + 1
                    except IndexError:
                        # ``ssedssp`` is exhausted (or empty): no label is
                        # emitted for this residue.
                        pass
        else:
            for atom in atoms:
                xyz_.append(np.full(3, np.nan))
            ssedssp_.append('-')

    return np.array(xyz_).reshape(-1, len(atoms), 3), N_to_AA(np.array(seq_)), np.array(sse_), np.array(ssedssp_)
|
| 263 |
+
|
| 264 |
+
def parse_PDB(self, path_to_pdb, name, input_chain):
    """
    Parse a single chain of one PDB file into a list holding one dictionary.

    make sure every time just input 1 line

    The dictionary carries the chain sequence (``seq`` and
    ``seq_chain_<id>``), the backbone N/CA/C/O coordinates (``coords`` and
    ``coords_chain_<id>``), ``name`` and ``num_chains``.  When
    ``input_chain`` is not among the chain ids in the file, the chain with
    the most amino-acid atoms in the first model is used instead.
    """
    pdb_dict_list = []

    for biounit in [path_to_pdb]:
        my_dict = {}
        s = 0

        letter = input_chain  # Assuming single chain!!

        array_stack = file.PDBFile.read(biounit).get_structure(altloc="all")

        # In case the passed letter is unknown, select one chain from the PDB
        # file based on the dominant protein chain
        if letter not in array_stack.chain_id:
            is_protein = struc.filter_amino_acids(array_stack)
            protein_atoms = array_stack[0][is_protein]
            chain_ids, chain_counts = np.unique(protein_atoms.chain_id, return_counts=True)
            letter = chain_ids[np.argmax(chain_counts)]

        sse1 = struc.annotate_sse(array_stack[0], chain_id=letter).tolist()
        if not sse1:
            # Retry with an empty chain id when the first call returned nothing.
            sse1 = struc.annotate_sse(array_stack[0], chain_id='').tolist()

        ssedssp1 = []  # not annotating dssp for now

        xyz, seq, sse, ssedssp = self.parse_PDB_biounits(
            biounit, sse1, ssedssp1, atoms=['N', 'CA', 'C', 'O'], chain=letter)  # TODO: fix the float error
        ssedssp = sse  # faking it for now

        chain_seq = seq[0]
        assert len(sse) == len(chain_seq)
        assert len(ssedssp) == len(chain_seq)

        if type(xyz) is not str:
            my_dict['seq_chain_' + letter] = chain_seq

            # Atom order along axis 1 follows the ``atoms`` list passed above.
            coords_dict_chain = {
                'N': xyz[:, 0, :].tolist(),
                'CA': xyz[:, 1, :].tolist(),
                'C': xyz[:, 2, :].tolist(),
                'O': xyz[:, 3, :].tolist(),
            }
            my_dict['coords_chain_' + letter] = coords_dict_chain
            my_dict['coords'] = coords_dict_chain
            s += 1

        my_dict['name'] = name
        my_dict['num_chains'] = s
        my_dict['seq'] = my_dict[f'seq_chain_{letter}']
        pdb_dict_list.append(my_dict)

    return pdb_dict_list
|
Flexpert-Design/src/datasets/ts_dataset.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch.utils.data as data
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class TSDataset(data.Dataset):
    """In-memory dataset built from ``ts50.json`` and ``ts500.json``.

    Both files are read from the directory ``path``.  Every record must
    provide ``name``, ``seq`` and ``coords``; ``coords`` is converted to a
    NumPy array and split along its second axis into N (index 0), CA (1),
    C (2) and O (3).  TS50 entries come first, followed by TS500 entries,
    each tagged with its origin under ``category``.
    """

    # File stems loaded, in order; also used as the ``category`` value.
    _CATEGORIES = ('ts50', 'ts500')

    def __init__(self, path = './', split='test'):
        """Load both benchmark files from ``path``.

        ``split`` is accepted but not used by this class.

        Raises:
            FileNotFoundError: if ``path`` does not exist.  (The original
                ``raise "..."`` of a plain string is itself a TypeError in
                Python 3.)
        """
        if not os.path.exists(path):
            raise FileNotFoundError("no such file:{} !!!".format(path))

        # Original note: TS500 has proteins with lengths of 500+,
        # TS50 only contains proteins with lengths less than 500.
        self.data = []
        for category in self._CATEGORIES:
            # ``with`` closes the file handle once the JSON is parsed.
            with open(os.path.join(path, category + '.json')) as f:
                records = json.load(f)
            for temp in records:
                coords = np.array(temp['coords'])
                self.data.append({'title':temp['name'],
                                  'seq':temp['seq'],
                                  'CA':coords[:,1,:],
                                  'C':coords[:,2,:],
                                  'O':coords[:,3,:],
                                  'N':coords[:,0,:],
                                  'category': category
                                  })

    def __len__(self):
        """Return the total number of loaded entries."""
        return len(self.data)

    def get_item(self, index):
        """Return the entry at ``index``."""
        return self.data[index]

    def __getitem__(self, index):
        """Return the entry at ``index``."""
        return self.data[index]
|
Flexpert-Design/src/datasets/utils.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class cached_property(object):
    """
    Descriptor (non-data) for building an attribute on-demand on first use.

    The computed value is stored on the instance under the factory's name.
    Because this descriptor defines no ``__set__``, the instance attribute
    takes precedence on later lookups and the factory is not called again
    for that instance.
    """
    def __init__(self, factory):
        """
        <factory> is called such: factory(instance) to build the attribute.
        """
        self._attr_name = factory.__name__
        self._factory = factory

    def __get__(self, instance, owner):
        # Accessed on the class itself: there is no instance to build the
        # value for or to cache it on, so hand back the descriptor instead
        # of calling factory(None).
        if instance is None:
            return self

        # Build the attribute.
        attr = self._factory(instance)

        # Cache the value; hide ourselves.
        setattr(instance, self._attr_name, attr)
        return attr
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_inds(expected_num, clu_nums, cid2clu, seq2ind):
    # NOTE: this module defines ``get_inds`` a second time further down; at
    # import time that later definition replaces this one.
    #
    # Walks ``clu_nums`` cyclically and, for each not-yet-selected cluster,
    # includes it with probability ~0.5 until at least ``expected_num``
    # indices have been collected.  Returns (selected cluster ids, indices).
    # NOTE(review): the loop does not terminate if the clusters cannot supply
    # ``expected_num`` indices.
    cur_len, cur_idx, query_cids, query_idx = 0, 0, [], []
    while cur_len < expected_num:
        # Each clu_nums element unpacks to (cluster id, second value); the
        # second value is not used in this function.
        cid, l = clu_nums[cur_idx % (len(clu_nums))]
        cur_idx += 1
        # check if this cluster has been selected
        if cid in query_cids:
            continue
        if random.random() > 0.5:
            for seq in cid2clu[cid]:
                # seq2ind: ensure it is in limited lengths
                if seq in seq2ind.keys():
                    query_idx.append(seq2ind[seq])
                    cur_len += 1

            # NOTE(review): nesting reconstructed from a flattened listing --
            # confirm this append belongs inside the random-selection branch.
            query_cids.append(cid)
    return query_cids, query_idx
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def get_num(N, valid_num=100):
    """Split a total of ``N`` items into (train, valid, test) counts.

    Train gets ``int(0.9 * N)``; valid gets the smaller of ``valid_num`` and
    ``int(0.05 * N)``; test receives whatever remains.
    """
    train_n = int(0.9 * N)
    valid_cap = int(0.05 * N)
    valid_n = valid_cap if valid_cap < valid_num else valid_num
    test_n = N - train_n - valid_n
    return train_n, valid_n, test_n
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def get_full_inds(expected_num, clu_nums, cid2clu, full_seq2ind):
    """Randomly pick whole clusters until ``expected_num`` indices are collected.

    Args:
        expected_num: target number of collected indices (summed over all
            datasets).  A cluster is always taken whole, so the total may
            exceed this.
        clu_nums: sequence whose elements unpack to ``(cluster_id, value)``;
            only the cluster id is used here.
        cid2clu: mapping from cluster id to the sequences in that cluster.
        full_seq2ind: mapping ``dataset name -> {sequence: index}``.

    Returns:
        ``(query_cids, query_idx)``: the selected cluster ids in selection
        order, and a dict ``dataset name -> list of indices``.

    Sampling stops early, returning what was collected, when no selectable
    cluster is left.  Previously that situation either raised IndexError
    from ``random.choice`` on an empty list or, when only duplicates of
    already-selected cluster ids remained, looped forever.
    """
    cur_len, query_cids = 0, []
    # build query_idx for each dataset
    query_idx = {dataname: [] for dataname in full_seq2ind}
    seen_cids = set()
    cur_idx_lst = list(range(len(clu_nums)))
    # Candidates found unusable.  They stay in ``cur_idx_lst`` so the random
    # draws are the same as before; they are only counted so the loop can
    # tell when every remaining candidate is unusable.
    rejected = set()
    while cur_len < expected_num and len(rejected) < len(cur_idx_lst):
        cur_idx = random.choice(cur_idx_lst)
        cid, _ = clu_nums[cur_idx]
        # check if this cluster has been selected
        if cid in seen_cids:
            rejected.add(cur_idx)
            continue
        for seq in set(cid2clu[cid]):
            # seq2ind: ensure it is in limited lengths
            for dataname, seq2ind in full_seq2ind.items():
                if seq in seq2ind:
                    query_idx[dataname].append(seq2ind[seq])
                    cur_len += 1
        query_cids.append(cid)
        seen_cids.add(cid)
        cur_idx_lst.remove(cur_idx)
    return query_cids, query_idx
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def get_inds(expected_num, clu_nums, cid2clu, seq2ind):
    """Randomly pick whole clusters until about ``expected_num`` indices are collected.

    Args:
        expected_num: target number of collected indices.
        clu_nums: sequence of ``(cluster_id, size)`` pairs; ``size`` is used
            to decide whether adding the cluster would overshoot the target.
        cid2clu: mapping from cluster id to the sequences in that cluster.
        seq2ind: mapping from sequence to index; sequences missing from it
            are skipped.

    Returns:
        ``(query_cids, query_idx)``: the selected cluster ids in selection
        order, and the collected indices.

    A cluster is skipped when taking it would leave the running count
    further from ``expected_num`` than not taking it.  Sampling stops when
    the target is reached or when no usable cluster is left, so fewer than
    ``expected_num`` indices may be returned.

    Changes from the previous version: the bare ``except: break`` is gone,
    so exhaustion is handled explicitly and genuine errors (e.g. a cluster
    id missing from ``cid2clu``) now propagate instead of silently cutting
    the selection short; and the loop can no longer spin forever when every
    remaining cluster is too big or a duplicate.
    """
    cur_len, query_cids, query_idx = 0, [], []
    seen_cids = set()
    cur_idx_lst = list(range(len(clu_nums)))
    # Candidates found unusable.  A rejection is permanent because
    # ``cur_len`` never decreases, so these indices can never become
    # acceptable later.  They stay in ``cur_idx_lst`` to keep the random
    # draws the same as before; they are only counted for termination.
    rejected = set()
    while cur_len < expected_num and len(rejected) < len(cur_idx_lst):
        cur_idx = random.choice(cur_idx_lst)
        cid, l = clu_nums[cur_idx]
        # check if this cluster has been selected
        if cid in seen_cids:
            rejected.add(cur_idx)
            continue

        # check if this cluster is too big
        pre = abs(expected_num - cur_len)
        aft = abs(cur_len + l - expected_num)
        if pre < aft:
            rejected.add(cur_idx)
            continue

        for seq in cid2clu[cid]:
            # seq2ind: ensure it is in limited lengths
            if seq in seq2ind:
                query_idx.append(seq2ind[seq])
                cur_len += 1
        query_cids.append(cid)
        seen_cids.add(cid)
        cur_idx_lst.remove(cur_idx)
    return query_cids, query_idx
|
Flexpert-Design/src/interface/__init__.py
ADDED
|
File without changes
|
Flexpert-Design/src/interface/data_interface.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import importlib
|
| 3 |
+
import pytorch_lightning as pl
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class DInterface_base(pl.LightningDataModule):
    """Base Lightning data module that builds datasets by name.

    NOTE(review): this class reads ``self.dataset``, ``self.num_workers``
    and ``self.kwargs`` but never assigns them; presumably subclasses
    provide them -- confirm against the concrete data modules.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.batch_size = self.hparams.batch_size
        print("batch_size", self.batch_size)
        self.load_data_module()

    def setup(self, stage=None):
        """Create the datasets required for ``stage`` ('fit', 'test' or None for both)."""
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            self.trainset = self.instancialize(split = 'train')
            self.valset = self.instancialize(split='valid')

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.testset = self.instancialize(split='test')

    def train_dataloader(self):
        """Return the shuffled training DataLoader."""
        loader_args = dict(batch_size=self.batch_size,
                           num_workers=self.num_workers,
                           shuffle=True)
        # DataLoader only accepts an explicit prefetch_factor when worker
        # processes are used; with num_workers == 0 it raises ValueError.
        if self.num_workers > 0:
            loader_args['prefetch_factor'] = 3
        return DataLoader(self.trainset, **loader_args)

    def val_dataloader(self):
        """Return the validation DataLoader (no shuffling)."""
        return DataLoader(self.valset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)

    def test_dataloader(self):
        """Return the test DataLoader (no shuffling)."""
        return DataLoader(self.testset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)

    def load_data_module(self):
        """Import the dataset class named by ``self.dataset`` into ``self.data_module``.

        Raises:
            ValueError: if the module cannot be imported or does not define
                the expected class; the underlying error is chained as the
                cause.
        """
        name = self.dataset
        # Change the `snake_case.py` file name to `CamelCase` class name.
        # Please always name your model file name as `snake_case.py` and
        # class name corresponding `CamelCase`.
        camel_name = ''.join([i.capitalize() for i in name.split('_')])
        try:
            self.data_module = getattr(importlib.import_module(
                '.'+name, package=__package__), camel_name)
        except Exception as err:
            # Keep the ValueError callers may rely on, but preserve the root
            # cause and let KeyboardInterrupt/SystemExit through.
            raise ValueError(
                f'Invalid Dataset File Name or Invalid Class Name data.{name}.{camel_name}') from err

    def instancialize(self, **other_args):
        """ Instancialize a dataset using the corresponding parameters
            from the self.kwargs dictionary. You can also input any args
            to overwrite the corresponding value in self.kwargs.

            ``other_args`` must contain ``split``.
            NOTE(review): this replaces ``self.data_module`` with a
            hard-coded dataset class chosen by ``split``, discarding what
            ``load_data_module`` selected -- confirm that is intended.
        """
        if other_args['split'] == 'train':
            self.data_module = getattr(importlib.import_module(
                '.AF2DB_dataset', package='data'), 'Af2dbDataset')
        else:
            self.data_module = getattr(importlib.import_module(
                '.CASP15_dataset', package='data'), 'CASP15Dataset')

        # Constructor parameter names, without the leading ``self``.
        class_args = list(inspect.signature(self.data_module.__init__).parameters)[1:]
        inkeys = self.kwargs.keys()
        args1 = {}
        for arg in class_args:
            if arg in inkeys:
                args1[arg] = self.kwargs[arg]
        args1.update(other_args)
        return self.data_module(**args1)
|
Flexpert-Design/src/interface/model_interface.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import pytorch_lightning as pl
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import os
|
| 5 |
+
import torch.optim.lr_scheduler as lrs
|
| 6 |
+
import inspect
|
| 7 |
+
|
| 8 |
+
class MInterface_base(pl.LightningModule):
    """Base Lightning module providing optimizer/scheduler setup and model construction helpers.

    ``forward``, ``training_step``, ``validation_step`` and ``load_model``
    are placeholders here; subclasses are expected to supply real
    implementations.
    """

    def __init__(self, model_name=None, loss=None, lr=None, **kargs):
        super().__init__()
        self.save_hyperparameters()
        self.load_model()
        self.configure_loss()
        # Make sure the experiment output directory exists.
        os.makedirs(os.path.join(self.hparams.res_dir, self.hparams.ex_name), exist_ok=True)

    def forward(self, input):
        pass


    def training_step(self, batch, batch_idx, **kwargs):
        pass


    def validation_step(self, batch, batch_idx):
        pass

    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)

    def on_validation_epoch_end(self):
        # Make the Progress Bar leave there
        self.print('')

    def get_schedular(self, optimizer, lr_scheduler='onecycle'):
        """Build the learning-rate scheduler named by ``lr_scheduler``.

        Supported names are 'step', 'cosine' and 'onecycle'; each reads its
        settings from ``self.hparams``.

        Raises:
            ValueError: for any other name.
        """
        if lr_scheduler == 'step':
            scheduler = lrs.StepLR(optimizer,
                                   step_size=self.hparams.lr_decay_steps,
                                   gamma=self.hparams.lr_decay_rate)
        elif lr_scheduler == 'cosine':
            scheduler = lrs.CosineAnnealingLR(optimizer,
                                              T_max=self.hparams.lr_decay_steps,
                                              eta_min=self.hparams.lr_decay_min_lr)
        elif lr_scheduler == 'onecycle':
            scheduler = lrs.OneCycleLR(optimizer, max_lr=self.hparams.lr, steps_per_epoch=self.hparams.steps_per_epoch, epochs=self.hparams.epoch, three_phase=False)
        else:
            raise ValueError('Invalid lr_scheduler type!')

        return scheduler

    def configure_optimizers(self):
        """Return an AdamW optimizer and a scheduler that is stepped every training step."""
        if hasattr(self.hparams, 'weight_decay'):
            weight_decay = self.hparams.weight_decay
        else:
            weight_decay = 0

        optimizer_g = torch.optim.AdamW(self.model.parameters(), lr=self.hparams.lr, weight_decay=weight_decay, betas=(0.9, 0.98), eps=1e-8)

        schecular_g = self.get_schedular(optimizer_g, self.hparams.lr_scheduler)

        return [optimizer_g], [{"scheduler": schecular_g, "interval": "step"}]

    def lr_scheduler_step(self, *args, **kwargs):
        # Step the scheduler directly, ignoring whatever arguments are passed in.
        scheduler = self.lr_schedulers()
        scheduler.step()


    def configure_devices(self):
        # NOTE(review): LightningModule may expose ``device`` as a read-only
        # property, in which case this assignment fails -- TODO confirm
        # against the installed pytorch_lightning version.
        self.device = torch.device(self.hparams.device)

    def configure_loss(self):
        # reduction='none' keeps one loss value per element.
        self.loss_function = nn.CrossEntropyLoss(reduction='none')

    def load_model(self):
        self.model = None

    def instancialize(self, Model, **other_args):
        """ Instancialize a model using the corresponding parameters
            from self.hparams dictionary. You can also input any args
            to overwrite the corresponding value in self.hparams.
        """
        # inspect.getargspec was removed in Python 3.11; getfullargspec
        # returns the same ``args`` list (``self`` is dropped by the slice)
        # and also accepts signatures with annotations or keyword-only
        # parameters.
        class_args = inspect.getfullargspec(Model.__init__).args[1:]
        inkeys = self.hparams.keys()
        args1 = {}
        for arg in class_args:
            if arg in inkeys:
                args1[arg] = getattr(self.hparams, arg)
        args1.update(other_args)
        return Model(**args1)
|
Flexpert-Design/src/interface/pretrain_interface.py
ADDED
|
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from omegaconf import OmegaConf
|
| 3 |
+
from transformers import AutoTokenizer, EsmForMaskedLM
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
class PretrainInterface(torch.nn.Module):
|
| 7 |
+
def __init__(self, name):
|
| 8 |
+
super().__init__()
|
| 9 |
+
self.name = name
|
| 10 |
+
if name == "ESM35M":
|
| 11 |
+
self.esm_dim = 480
|
| 12 |
+
self.tokenizer = AutoTokenizer.from_pretrained("/huyuqi/model_zoom/transformers/models--facebook--esm2_t12_35M_UR50D")
|
| 13 |
+
self.pretrain_model = EsmForMaskedLM.from_pretrained("/huyuqi/model_zoom/transformers/models--facebook--esm2_t12_35M_UR50D")
|
| 14 |
+
if name == "ESM650M":
|
| 15 |
+
self.esm_dim = 1280
|
| 16 |
+
self.tokenizer = AutoTokenizer.from_pretrained("/huyuqi/model_zoom/transformers/models--facebook--esm2_t33_650M_UR50D/snapshots/08e4846e537177426273712802403f7ba8261b6c")
|
| 17 |
+
self.pretrain_model = EsmForMaskedLM.from_pretrained("/huyuqi/model_zoom/transformers/models--facebook--esm2_t33_650M_UR50D/snapshots/08e4846e537177426273712802403f7ba8261b6c")
|
| 18 |
+
if name == "ESM3B":
|
| 19 |
+
self.esm_dim = 2560
|
| 20 |
+
self.tokenizer = AutoTokenizer.from_pretrained("/huyuqi/model_zoom/transformers/models--facebook--esm2_t36_3B_UR50D/snapshots/476b639933c8baad5ad09a60ac1a87f987b656fc")
|
| 21 |
+
self.pretrain_model = EsmForMaskedLM.from_pretrained("/huyuqi/model_zoom/transformers/models--facebook--esm2_t36_3B_UR50D/snapshots/476b639933c8baad5ad09a60ac1a87f987b656fc")
|
| 22 |
+
|
| 23 |
+
if name == "vanilla":
|
| 24 |
+
from step1_VQ.model_interface import MInterface
|
| 25 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMVQ/base/configs/10-18T01-15-36-project.yaml")
|
| 26 |
+
pretrain_args.diffusion = False
|
| 27 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 28 |
+
ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMVQ/base/checkpoints/best-epoch=14-val_loss=0.314.pth', map_location=torch.device('cpu'))
|
| 29 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 30 |
+
self.pretrain_model.load_state_dict(state_dict, strict=False)
|
| 31 |
+
|
| 32 |
+
# if name == "LFQ":
|
| 33 |
+
# from step1_VQ.model_interface import MInterface
|
| 34 |
+
# pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMFVQ/LFQ_seg_linear/configs/10-17T15-46-37-project.yaml")
|
| 35 |
+
# pretrain_args.diffusion = False
|
| 36 |
+
# self.pretrain_model = MInterface(**pretrain_args)
|
| 37 |
+
# ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMFVQ/LFQ_seg_linear/checkpoints/best-epoch=14-val_loss=0.161.pth', map_location=torch.device('cpu'))
|
| 38 |
+
# state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 39 |
+
# self.pretrain_model.load_state_dict(state_dict, strict=False)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
if name == "softgroup-1":
|
| 44 |
+
from step1_VQ.model_interface import MInterface
|
| 45 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftGroup/softgroup-1/configs/12-16T14-57-28-project.yaml")
|
| 46 |
+
pretrain_args.diffusion = False
|
| 47 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 48 |
+
ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftGroup/softgroup-1/checkpoints/best-epoch=13-val_loss=0.111.pth', map_location=torch.device('cpu'))
|
| 49 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 50 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 51 |
+
|
| 52 |
+
if name == "softgroup-2":
|
| 53 |
+
from step1_VQ.model_interface import MInterface
|
| 54 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup-2/configs/10-24T12-51-57-project.yaml")
|
| 55 |
+
pretrain_args.diffusion = False
|
| 56 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 57 |
+
ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup-2/checkpoints/best-epoch=14-val_loss=0.067.pth', map_location=torch.device('cpu'))
|
| 58 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 59 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 60 |
+
|
| 61 |
+
if name == "softgroup-3":
|
| 62 |
+
from step1_VQ.model_interface import MInterface
|
| 63 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup-3/configs/10-25T00-04-15-project.yaml")
|
| 64 |
+
pretrain_args.diffusion = False
|
| 65 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 66 |
+
ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup-3/checkpoints/best-epoch=14-val_loss=0.063.pth', map_location=torch.device('cpu'))
|
| 67 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 68 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 69 |
+
|
| 70 |
+
if name == "softgroup-4":
|
| 71 |
+
from step1_VQ.model_interface import MInterface
|
| 72 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup_32_vectors/configs/10-19T01-03-55-project.yaml")
|
| 73 |
+
pretrain_args.diffusion = False
|
| 74 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 75 |
+
ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup_32_vectors/checkpoints/best-epoch=14-val_loss=0.056.pth', map_location=torch.device('cpu'))
|
| 76 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 77 |
+
self.pretrain_model.load_state_dict(state_dict, strict=False)
|
| 78 |
+
|
| 79 |
+
if name == "softgroup-5":
|
| 80 |
+
from step1_VQ.model_interface import MInterface
|
| 81 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup-5-gzy/configs/10-27T17-15-56-project.yaml")
|
| 82 |
+
pretrain_args.diffusion = False
|
| 83 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 84 |
+
ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup-5-gzy/checkpoints/best-epoch=14-val_loss=0.039.pth', map_location=torch.device('cpu'))
|
| 85 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 86 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 87 |
+
|
| 88 |
+
if name == "softgroup-6":
|
| 89 |
+
from step1_VQ.model_interface import MInterface
|
| 90 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup_128_group/configs/10-28T01-28-50-project.yaml")
|
| 91 |
+
pretrain_args.diffusion = False
|
| 92 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 93 |
+
ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup_128_group/checkpoints/best-epoch=14-val_loss=0.011.pth', map_location=torch.device('cpu'))
|
| 94 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 95 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
if name == "softgroup_128_group":
|
| 100 |
+
from step1_VQ.model_interface import MInterface
|
| 101 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup_128_group/configs/10-28T01-28-50-project.yaml")
|
| 102 |
+
pretrain_args.diffusion = False
|
| 103 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 104 |
+
ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup_128_group/checkpoints/best-epoch=14-val_loss=0.011.pth', map_location=torch.device('cpu'))
|
| 105 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 106 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 107 |
+
|
| 108 |
+
if name == "diff-softgroup-1":
|
| 109 |
+
from step1_VQ.model_interface import MInterface
|
| 110 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/DiffESMSoftGroup/diff-softgroup-rm-dist/configs/12-17T14-19-21-project.yaml")
|
| 111 |
+
pretrain_args.diffusion = True
|
| 112 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 113 |
+
ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/DiffESMSoftGroup/diff-softgroup-rm-dist/checkpoints/best-epoch=12-val_loss=0.496.pth', map_location=torch.device('cpu'))
|
| 114 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 115 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 116 |
+
|
| 117 |
+
if name == "diff-softgroup-4":
|
| 118 |
+
from step1_VQ.model_interface import MInterface
|
| 119 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/DiffESMSoftGroup/diff-softgroup-vq32/configs/12-19T01-54-15-project.yaml")
|
| 120 |
+
pretrain_args.diffusion = True
|
| 121 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 122 |
+
ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/DiffESMSoftGroup/diff-softgroup-vq32/checkpoints/best-epoch=13-val_loss=0.184.pth', map_location=torch.device('cpu'))
|
| 123 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 124 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 125 |
+
|
| 126 |
+
if name == "diff-softgroup-5":
|
| 127 |
+
from step1_VQ.model_interface import MInterface
|
| 128 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/DiffESMSoftGroup/diff-softgroup-vq64/configs/12-19T01-57-07-project.yaml")
|
| 129 |
+
pretrain_args.diffusion = True
|
| 130 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 131 |
+
ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/DiffESMSoftGroup/diff-softgroup-vq64/checkpoints/best-epoch=13-val_loss=0.100.pth', map_location=torch.device('cpu'))
|
| 132 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 133 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 134 |
+
|
| 135 |
+
if name == "diff-softgroup-6":
|
| 136 |
+
from step1_VQ.model_interface import MInterface
|
| 137 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/DiffESMSoftGroup/diff-softgroup-vq128/configs/12-19T10-47-37-project.yaml")
|
| 138 |
+
pretrain_args.diffusion = True
|
| 139 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 140 |
+
ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/DiffESMSoftGroup/diff-softgroup-vq128/checkpoints/best-epoch=13-val_loss=0.081.pth', map_location=torch.device('cpu'))
|
| 141 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 142 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 143 |
+
|
| 144 |
+
if name == 'vanilla-1':
|
| 145 |
+
from step1_VQ.model_interface import MInterface
|
| 146 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMVQ/base/configs/10-18T01-15-37-project.yaml")
|
| 147 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 148 |
+
ckpt = torch.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMVQ/base/checkpoints/best-epoch=14-val_loss=0.314.pth")
|
| 149 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 150 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 151 |
+
|
| 152 |
+
if name == 'soft-1':
|
| 153 |
+
from step1_VQ.model_interface import MInterface
|
| 154 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoft/soft_rerun/configs/12-10T12-38-16-project.yaml")
|
| 155 |
+
pretrain_args.diffusion=False
|
| 156 |
+
pretrain_args.attn_type = 'raw'
|
| 157 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 158 |
+
ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoft/soft_rerun/checkpoints/best-epoch=14-val_loss=0.018.pth")
|
| 159 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 160 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 161 |
+
|
| 162 |
+
if name == 'soft_64_vecs':
|
| 163 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoft/soft_vq_num64/configs/10-19T11-11-58-project.yaml")
|
| 164 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 165 |
+
ckpt = torch.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoft/soft_vq_num64/checkpoints/best-epoch=14-val_loss=8.768.pth")
|
| 166 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 167 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 168 |
+
|
| 169 |
+
if name == 'LFQ':
|
| 170 |
+
from step1_VQ.model_interface import MInterface
|
| 171 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMFVQ/vanilla_L1loss/configs/10-24T01-36-37-project.yaml")
|
| 172 |
+
pretrain_args.diffusion = False
|
| 173 |
+
pretrain_args.attn_type = 'raw'
|
| 174 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 175 |
+
ckpt = torch.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMFVQ/vanilla_L1loss/checkpoints/best-epoch=14-val_loss=11.328.pth")
|
| 176 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 177 |
+
self.pretrain_model.load_state_dict(state_dict, strict=False)
|
| 178 |
+
|
| 179 |
+
if name == 'SCQ-mlp3-vqdim32':
|
| 180 |
+
from step1_VQ.model_interface import MInterface
|
| 181 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional-mlp3-vqdim32/configs/12-22T07-52-47-project.yaml")
|
| 182 |
+
pretrain_args.diffusion = False
|
| 183 |
+
pretrain_args.vq_dim, pretrain_args.condition_layer, pretrain_args.sphere = 32, 3, False
|
| 184 |
+
|
| 185 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 186 |
+
ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional-mlp3-vqdim32/checkpoints/best-epoch=14-val_loss=0.376.pth")
|
| 187 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 188 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 189 |
+
|
| 190 |
+
if name == 'SCQ-mlp3-vqdim32-sphere':
|
| 191 |
+
from step1_VQ.model_interface import MInterface
|
| 192 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-mlp3-vqdim32-sphere/configs/12-22T10-44-46-project.yaml")
|
| 193 |
+
pretrain_args.diffusion = False
|
| 194 |
+
|
| 195 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 196 |
+
ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-mlp3-vqdim32-sphere/checkpoints/best-epoch=14-val_loss=0.454.pth")
|
| 197 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 198 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 199 |
+
|
| 200 |
+
if name == 'SCQ-mlp6-vqdim32-sphere':
|
| 201 |
+
from step1_VQ.model_interface import MInterface
|
| 202 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-mlp6BN-vqdim32-sphere/configs/12-22T18-28-04-project.yaml")
|
| 203 |
+
pretrain_args.diffusion = False
|
| 204 |
+
pretrain_args.attn_type = 'raw'
|
| 205 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 206 |
+
ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-mlp6BN-vqdim32-sphere/checkpoints/best-epoch=14-val_loss=0.148.pth")
|
| 207 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 208 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 209 |
+
|
| 210 |
+
if name == 'SCQ-mlp2-vqdim32':
|
| 211 |
+
from step1_VQ.model_interface import MInterface
|
| 212 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional-mlp2-vqdim32/configs/12-22T00-21-35-project.yaml")
|
| 213 |
+
pretrain_args.diffusion = False
|
| 214 |
+
pretrain_args.vq_dim, pretrain_args.condition_layer, pretrain_args.sphere = 32, 2, False
|
| 215 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 216 |
+
ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional-mlp2-vqdim32/checkpoints/best-epoch=14-val_loss=0.362.pth")
|
| 217 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 218 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 219 |
+
|
| 220 |
+
if name == 'SCQ-mlp2-vqdim32-sphere':
|
| 221 |
+
from step1_VQ.model_interface import MInterface
|
| 222 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional-sphere-vqdim32/configs/12-22T00-06-35-project.yaml")
|
| 223 |
+
pretrain_args.diffusion = False
|
| 224 |
+
pretrain_args.vq_dim, pretrain_args.condition_layer, pretrain_args.sphere = 32, 2, True
|
| 225 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 226 |
+
ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional-sphere-vqdim32/checkpoints/best-epoch=14-val_loss=0.338.pth")
|
| 227 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 228 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 229 |
+
|
| 230 |
+
if name == 'SCQ-mlp2-vqdim16':
|
| 231 |
+
from step1_VQ.model_interface import MInterface
|
| 232 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional/configs/12-21T13-13-11-project.yaml")
|
| 233 |
+
pretrain_args.diffusion = False
|
| 234 |
+
pretrain_args.vq_dim, pretrain_args.condition_layer, pretrain_args.sphere = 16, 2, False
|
| 235 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 236 |
+
ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional/checkpoints/best-epoch=14-val_loss=0.094.pth")
|
| 237 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 238 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 239 |
+
|
| 240 |
+
if name == 'SCQ-mlp2-vqdim16-sphere':
|
| 241 |
+
from step1_VQ.model_interface import MInterface
|
| 242 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional-sphere/configs/12-21T16-38-57-project.yaml")
|
| 243 |
+
pretrain_args.diffusion = False
|
| 244 |
+
pretrain_args.vq_dim, pretrain_args.condition_layer, pretrain_args.sphere = 16, 2, True
|
| 245 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 246 |
+
ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional-sphere/checkpoints/best-epoch=14-val_loss=1.080.pth")
|
| 247 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 248 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 249 |
+
|
| 250 |
+
if name == 'SCQ-vq8-mlp6-vqdim16-sphere':
|
| 251 |
+
from step1_VQ.model_interface import MInterface
|
| 252 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq8-mlp6BN-vqdim32-sphere/configs/12-23T05-15-56-project.yaml")
|
| 253 |
+
pretrain_args.diffusion = False
|
| 254 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 255 |
+
ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq8-mlp6BN-vqdim32-sphere/checkpoints/best-epoch=14-val_loss=0.892.pth")
|
| 256 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 257 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 258 |
+
|
| 259 |
+
if name == 'SCQ-mlp9-vqdim32-sphere':
|
| 260 |
+
from step1_VQ.model_interface import MInterface
|
| 261 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-mlp9BN-vqdim32-sphere/configs/12-23T16-20-07-project.yaml")
|
| 262 |
+
pretrain_args.diffusion = False
|
| 263 |
+
|
| 264 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 265 |
+
ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-mlp9BN-vqdim32-sphere/checkpoints/best-epoch=14-val_loss=0.151.pth")
|
| 266 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 267 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
if name == 'AF2VQ':
|
| 272 |
+
from step3_AF2VQ.model_interface import MInterface
|
| 273 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step3_AF2VQ/results/AF2VQ_softgroup16/configs/12-13T07-59-50-project.yaml")
|
| 274 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 275 |
+
ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step3_AF2VQ/results/AF2VQ_softgroup16/checkpoints/best-epoch=11-val_loss=0.812.pth")
|
| 276 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 277 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 278 |
+
|
| 279 |
+
if name == "ProGLM":
|
| 280 |
+
self.vq_dim=480
|
| 281 |
+
from step2_ProGLM.model.model_interface import MInterface
|
| 282 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Inpainting_representation/results/softgroup_bin_1127/version_4/hparams.yaml")
|
| 283 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 284 |
+
ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Inpainting_representation/results/softgroup_bin_1127/checkpoints/best-epoch=08-valid_acc=0.804.ckpt', map_location=torch.device('cpu'))['state_dict']
|
| 285 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 286 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 287 |
+
|
| 288 |
+
if name == 'ProGLM_softgroup_af2db':
|
| 289 |
+
from step2_ProGLM.model.model_interface import MInterface
|
| 290 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Inpainting_representation/results/softgroup_bin_2/version_3/hparams.yaml")
|
| 291 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 292 |
+
ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Inpainting_representation/results/softgroup_bin_2/checkpoints/best-epoch=13-valid_acc=0.863.ckpt', map_location=torch.device('cpu'))['state_dict']
|
| 293 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 294 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 295 |
+
|
| 296 |
+
if name == 'ProGLM_SoftVQ_cath':
|
| 297 |
+
from step2_ProGLM.model.model_interface import MInterface
|
| 298 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftVQ_epoch15_pad300/configs/12-25T01-20-35-project.yaml")
|
| 299 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 300 |
+
ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftVQ_epoch15_pad300/checkpoints/best-epoch=27-valid_acc=0.001.pth')
|
| 301 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 302 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
if name == 'ProGLM_SoftCVQ_cath':
|
| 306 |
+
from step2_ProGLM.model.model_interface import MInterface
|
| 307 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftCVQ_epoch15_pad300_BCE/configs/12-25T01-42-37-project.yaml")
|
| 308 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 309 |
+
ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftCVQ_epoch15_pad300_BCE/checkpoints/best-epoch=14-valid_acc=0.614.pth')
|
| 310 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 311 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 312 |
+
|
| 313 |
+
if name == 'ProGLM_SoftCVQ_cath_inpaint':
|
| 314 |
+
from step2_ProGLM.model.model_interface import MInterface
|
| 315 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftCVQ_epoch15_pad300_BCE_inpaint/configs/12-25T07-47-52-project.yaml")
|
| 316 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 317 |
+
ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftCVQ_epoch15_pad300_BCE_inpaint/checkpoints/best-epoch=14-valid_acc=0.616.pth')
|
| 318 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 319 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 320 |
+
|
| 321 |
+
if name == 'ProGLM_SoftCVQ_AF2DB':
|
| 322 |
+
from step2_ProGLM.model.model_interface import MInterface
|
| 323 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftCVQ_epoch15_AF2DB/configs/12-25T13-01-12-project.yaml")
|
| 324 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 325 |
+
ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftCVQ_epoch15_AF2DB/checkpoints/best-epoch=14-valid_acc=0.631.pth')
|
| 326 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 327 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 328 |
+
|
| 329 |
+
if name == 'ProGLM_SoftCVQ_ESM1B_CATH':
|
| 330 |
+
from step2_ProGLM.model.model_interface import MInterface
|
| 331 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftCVQ_ESM1B_CATH_lr5e-5/configs/12-25T16-02-35-project.yaml")
|
| 332 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 333 |
+
ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftCVQ_ESM1B_CATH_lr5e-5/checkpoints/best-epoch=14-valid_acc=0.616.pth')
|
| 334 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 335 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 336 |
+
|
| 337 |
+
if name == 'ProGLM_SoftCVQ_CATH':
|
| 338 |
+
from step2_ProGLM.model.model_interface import MInterface
|
| 339 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGPT_SoftCVQ_CATH/configs/12-26T08-13-41-project.yaml")
|
| 340 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 341 |
+
ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGPT_SoftCVQ_CATH/checkpoints/best-epoch=14-gpt_acc=0.758.pth')
|
| 342 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 343 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 344 |
+
|
| 345 |
+
if name == 'ProGLM_SoftCVQ_CATH_epoch10k':
|
| 346 |
+
from step2_ProGLM.model.model_interface import MInterface
|
| 347 |
+
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGPT_SoftCVQ_CATH_epoch1000/configs/12-27T02-36-49-project.yaml")
|
| 348 |
+
self.pretrain_model = MInterface(**pretrain_args)
|
| 349 |
+
ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGPT_SoftCVQ_CATH_epoch10000_resume/checkpoints/best-epoch=1887-gpt_loss=0.181.pth')
|
| 350 |
+
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
|
| 351 |
+
self.pretrain_model.load_state_dict(state_dict)
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
if name == 'GearNet':
|
| 355 |
+
from model.PretrainGearNet import PretrainGearNet_Model
|
| 356 |
+
self.pretrain_model = PretrainGearNet_Model()
|
| 357 |
+
|
| 358 |
+
self.pretrain_model.eval()
|
| 359 |
+
|
| 360 |
+
def get_vq_id(self, seqs, angles, attn_mask):
    """Return the vector-quantisation ids produced by the pretrained model.

    The inputs are pushed through three stages of ``self.pretrain_model.model``:
    the ``input`` embedding, the ``ProteinEnc`` encoder and finally
    ``VQLayer.get_vq``.

    Args:
        seqs: Token tensor; its trailing axis is squeezed away before embedding.
        angles: Angle features forwarded unchanged to the ``input`` stage.
        attn_mask: Attention mask handed to both the encoder and the VQ layer.

    Returns:
        The first element returned by ``VQLayer.get_vq`` (the ids); the second
        element is discarded.
    """
    backbone = self.pretrain_model.model
    embedded = backbone.input(seqs.squeeze(-1), angles)
    encoder_out = backbone.ProteinEnc(embedded, attn_mask, None)
    hidden = encoder_out.last_hidden_state
    # NOTE(review): temperature=1e-5 presumably makes the soft assignment
    # effectively hard -- confirm against VQLayer.get_vq.
    vq_id, _ = backbone.VQLayer.get_vq(hidden, attn_mask, temperature=1e-5)
    return vq_id
|
| 371 |
+
|
| 372 |
+
def forward(self, batch):
    """Run the pretrained model selected by ``self.name`` on ``batch``.

    Args:
        batch: Mapping whose required keys depend on ``self.name``:

            * ``"ESM35M"`` / ``"ESM650M"`` / ``"ESM3B"``: ``seqs``, ``attn_mask``
            * ``"softgroup_128_group"``: ``seqs``, ``angles``, ``attn_mask``
            * ``"ProGLM"``: ``vq_id``, ``attn_mask``, ``seg``, ``pos``
            * ``"GearNet"``: ``seqs`` and ``batch`` (a graph object exposing
              ``['attn_mask']``, ``.batch`` and ``.pos``)

    Returns:
        The branch-specific embedding / id tensor, or ``None`` when
        ``self.name`` matches none of the branches above.
    """
    if self.name in ["ESM35M", "ESM650M", "ESM3B"]:
        seqs, attn_mask = batch['seqs'], batch['attn_mask']
        outputs = self.pretrain_model.model(input_ids=seqs[:, :, 0], attention_mask=attn_mask)
        # NOTE(review): assumes `hidden_states` is a single tensor whose last
        # axis is `self.esm_dim` -- confirm against the wrapped model.
        pretrain_embedding = outputs.hidden_states
        # Flatten batch and length, then keep only the unmasked positions.
        pretrain_embedding = pretrain_embedding.reshape(-1, self.esm_dim)[attn_mask.view(-1) == 1]
        return pretrain_embedding

    if self.name in ["softgroup_128_group"]:
        seqs, angles, attn_mask = batch['seqs'], batch['angles'], batch['attn_mask']
        vq_id = self.pretrain_model.model.get_vqid(seqs[..., 0], angles, attn_mask)
        return vq_id

    if self.name in ["ProGLM"]:
        vq_id, attn_mask, seg, pos = batch['vq_id'], batch['attn_mask'], batch['seg'], batch['pos']
        feat = self.pretrain_model.model.get_feat(vq_id, attn_mask, seg, pos)
        # Flatten batch and length, then keep only the unmasked positions.
        feat = feat.reshape(-1, self.vq_dim)[attn_mask.view(-1) == 1]
        return feat

    if self.name in ["GearNet"]:
        seqs = batch['seqs']
        graph = batch['batch']
        # NOTE(review): the mask is read from the nested graph object, exactly
        # as before; every other branch reads it from the outer dict -- confirm
        # which one is intended.
        attn_mask = graph['attn_mask']

        # Previously `seq_strs` was appended to before it was ever bound, so
        # this branch always raised UnboundLocalError.
        seq_strs = []
        for idx in range(seqs.shape[0]):
            seq_str = self.pretrain_featurizer.ESM_tokenizer.decode(seqs[idx, attn_mask[idx, :].bool(), 0])
            # Collect one flat token list across all sequences in the batch.
            seq_strs.extend(seq_str.split(" "))

        node_index = torch.arange(graph.batch.shape[0], device=graph.batch.device)
        node2graph = graph.batch
        chain_id = torch.ones_like(graph.batch)

        # NOTE(review): the visible part of __init__ stores the GearNet model
        # on `self.pretrain_model`; confirm `pretrain_gearnet_model` and
        # `pretrain_featurizer` are defined elsewhere on this object.
        pretrain_embedding = self.pretrain_gearnet_model(seq_strs, node_index, node2graph, chain_id, graph.pos)
        return pretrain_embedding
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
|
Flexpert-Design/src/models/E3PiFold_model.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from typing import Optional
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
from omegaconf import OmegaConf
|
| 7 |
+
from src.modules.E3PiFold import GaussianEncoder, TransformerEncoderWithPair
|
| 8 |
+
from src.tools import gather_nodes, _dihedrals, _get_rbf, _get_dist, _rbf, _orientations_coarse_gl_tuple
|
| 9 |
+
|
| 10 |
+
class E3PiFold(nn.Module):
    """Sequence predictor built from a Gaussian pair embedder and a pair-aware transformer.

    Per-residue 21-dimensional node features are linearly embedded, combined
    with coordinates by ``GaussianEncoder``, passed through
    ``TransformerEncoderWithPair`` and projected to 33 output classes.
    """

    def __init__(self, config) -> None:
        """Build the sub-modules from ``config``.

        ``config`` must expose: ``embed_dim``, ``kernel_num``,
        ``attention_heads``, ``use_dist``, ``use_product``, ``encoder_layers``,
        ``ffn_embed_dim``, ``emb_dropout``, ``dropout``, ``attention_dropout``,
        ``activation_dropout`` and ``max_seq_len``.
        """
        super().__init__()
        self.node_embed = nn.Linear(21, config.embed_dim)
        self.protein_embedder = GaussianEncoder(
            config.kernel_num,
            config.embed_dim,
            config.attention_heads,
            config.use_dist,
            config.use_product,
        )
        self.encoder = TransformerEncoderWithPair(
            config.encoder_layers,
            config.embed_dim,
            config.ffn_embed_dim,
            config.attention_heads,
            config.emb_dropout,
            config.dropout,
            config.attention_dropout,
            config.activation_dropout,
            config.max_seq_len,
        )
        self.predictor = nn.Linear(config.embed_dim, 33)

    def _full_dist(self, X, mask, top_k=30, eps=1E-6):
        """Return the ``top_k`` smallest pairwise distances per point and their indices.

        Args:
            X: Coordinates; pairwise differences are taken between axis-1
                entries and squared norms are summed over axis 3 of the
                difference tensor.
            mask: 0/1 validity mask broadcast into a pairwise mask.
            top_k: Maximum number of neighbours; clipped to the sequence length.
            eps: Added under the square root for numerical stability.

        Returns:
            ``(D_neighbors, E_idx)`` as produced by ``torch.topk(..., largest=False)``.
        """
        pair_valid = mask.unsqueeze(1) * mask.unsqueeze(2)
        offsets = X.unsqueeze(1) - X.unsqueeze(2)
        # Invalid pairs are pinned to a large constant distance.
        D = (1. - pair_valid) * 10000 + pair_valid * torch.sqrt(torch.sum(offsets ** 2, 3) + eps)

        # Push invalid pairs beyond the largest value in their row so that
        # they are selected last by the smallest-k search.
        row_max = D.max(-1, keepdim=True)[0]
        D_adjust = D + (1. - pair_valid) * (row_max + 1)
        k = min(top_k, D_adjust.shape[-1])
        D_neighbors, E_idx = torch.topk(D_adjust, k, dim=-1, largest=False)
        return D_neighbors, E_idx

    def _get_features(self, batch):
        """Compute node features and store them in ``batch['h_V']``.

        Reads ``batch['X']`` and ``batch['mask']``; the neighbour graph is built
        from atom index 1 of ``X`` with up to 30 neighbours. The batch is
        modified in place and also returned.
        """
        coords = batch['X']
        atom1 = coords[:, :, 1, :]
        _, E_idx = self._full_dist(atom1, batch['mask'], 30)
        V_angles = _dihedrals(coords.float())
        # Only the node-level output is used; the two edge outputs are dropped.
        V_direct, _, _ = _orientations_coarse_gl_tuple(coords.float(), E_idx)
        node_feats = torch.cat([V_angles, V_direct], dim=-1)
        batch['h_V'] = node_feats.to(coords.dtype)
        return batch

    def forward(self, batch):
        """Predict per-position log-probabilities over 33 classes.

        Args:
            batch: Mapping with ``X`` (coordinates; atom index 1 is used),
                ``h_V`` (node features with last dimension 21) and ``mask``
                (0/1 sequence mask).

        Returns:
            ``{'log_probs': tensor}`` where the last axis is log-softmaxed.
        """
        seq_mask = batch['mask']
        coords = batch['X'][:, :, 1]
        node_emb = self.node_embed(batch['h_V'])
        pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
        padding_mask = 1 - seq_mask

        x, graph_attn_bias = self.protein_embedder(coords, node_emb, pair_mask)
        encoder_out = self.encoder(
            x,
            padding_mask=padding_mask,
            attn_mask=graph_attn_bias,
            pair_mask=pair_mask,
        )
        # The encoder is expected to return exactly five values.
        _enc_rep, _pair_rep, _delta_pair, _x_norm, _delta_pair_norm = encoder_out

        # NOTE(review): the predictor consumes the embedder output `x`, not
        # the encoder representation -- confirm this is intentional.
        logits = self.predictor(x)
        return {'log_probs': F.log_softmax(logits, dim=-1)}
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
if __name__ == '__main__':
    # Smoke test: build an E3PiFold model and run one forward pass on random data.
    B, N = 16, 512
    # forward() indexes atom 1 of X, so X needs an atom axis.
    # NOTE(review): 4 atoms per residue is an assumption -- confirm.
    X = torch.randn(B, N, 4, 3)
    # node_embed is Linear(21, embed_dim), so node features must be 21-dimensional.
    H = torch.randn(B, N, 21)
    seq_mask = (torch.ones(B, N) > 0.5).float()

    config = {
        'encoder_layers': 12,
        'kernel_num': 16,
        'embed_dim': 768,
        'ffn_embed_dim': 3072,
        'attention_heads': 8,
        'emb_dropout': 0.1,
        'dropout': 0.1,
        'attention_dropout': 0.1,
        'activation_dropout': 0.0,
        'max_seq_len': 256,
        # E3PiFold.__init__ reads these two keys; they were missing here.
        # NOTE(review): the values are placeholders -- confirm against the
        # configuration used for training.
        'use_dist': True,
        'use_product': True,
    }
    config = OmegaConf.create(config)
    model = E3PiFold(config)
    # forward() takes a single batch mapping, not three positional tensors.
    feat = model({'X': X, 'h_V': H, 'mask': seq_mask})
|
Flexpert-Design/src/models/MemoryESM.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import subprocess
|
| 2 |
+
import os
|
| 3 |
+
from joblib import Parallel, delayed, cpu_count
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import shutil
|
| 9 |
+
from .PretrainESM_model import PretrainESM_Model
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class MemoESM(nn.Module):
|
| 13 |
+
def __init__(self, args):
|
| 14 |
+
super().__init__()
|
| 15 |
+
self.PretrainESM = PretrainESM_Model(args)
|
| 16 |
+
self.tokenizer = self.PretrainESM.tokenizer
|
| 17 |
+
self.memory = {}
|
| 18 |
+
# self.fix_memory = False
|
| 19 |
+
|
| 20 |
+
# def save_memory(self, path):
|
| 21 |
+
# params = {key:val for key,val in self.state_dict().items() if "GNNTuning" in key}
|
| 22 |
+
# torch.save({"params":params,"memory": self.memory}, path)
|
| 23 |
+
|
| 24 |
+
# def load_memory(self, path):
|
| 25 |
+
# data = torch.load(path)
|
| 26 |
+
# self.load_state_dict(data['params'], strict=False)
|
| 27 |
+
# self.memory = data['memory']
|
| 28 |
+
|
| 29 |
+
def clean_input(self, batch, score_cut=0.99):
|
| 30 |
+
'''
|
| 31 |
+
require: batch['pred_ids'], batch['attention_mask'], batch['confs']
|
| 32 |
+
'''
|
| 33 |
+
symbol = "<mask>"
|
| 34 |
+
replace_dict = {"-":symbol,
|
| 35 |
+
".":symbol,
|
| 36 |
+
"<eos>":symbol,
|
| 37 |
+
"<unk>":symbol,
|
| 38 |
+
"<cls>":symbol,
|
| 39 |
+
"<pad>":symbol,
|
| 40 |
+
"<null_1>":symbol,
|
| 41 |
+
"<mask>":symbol,
|
| 42 |
+
"U":symbol,
|
| 43 |
+
"O":symbol}
|
| 44 |
+
|
| 45 |
+
device = batch['pred_ids'].device
|
| 46 |
+
query_seqs = []
|
| 47 |
+
for pred_ids, mask, score in zip(batch['pred_ids'], batch['attention_mask'], batch['confs']):
|
| 48 |
+
seq = self.tokenizer.decode(pred_ids[mask], clean_up_tokenization_spaces=False)
|
| 49 |
+
elements = []
|
| 50 |
+
for idx, x in enumerate(seq.split(" ")):
|
| 51 |
+
symbol = replace_dict.get(x, x)
|
| 52 |
+
if score[idx] < score_cut:
|
| 53 |
+
symbol = "<mask>"
|
| 54 |
+
elements.append(symbol)
|
| 55 |
+
seq = "".join(elements)
|
| 56 |
+
query_seqs.append(seq)
|
| 57 |
+
|
| 58 |
+
results = self.tokenizer.batch_encode_plus(query_seqs, return_tensors="pt", padding=True)
|
| 59 |
+
return query_seqs
|
| 60 |
+
|
| 61 |
+
def initoutput(self, B, maxL, device):
|
| 62 |
+
self.out_pred_ids = torch.zeros(B, maxL, dtype=torch.long, device=device)
|
| 63 |
+
self.out_confs = torch.zeros(B, maxL, dtype=torch.float, device=device)
|
| 64 |
+
self.out_embeds = torch.zeros(B, maxL, 1280, dtype=torch.float, device=device)
|
| 65 |
+
self.titles = [None for i in range(B)]
|
| 66 |
+
|
| 67 |
+
def retrivel(self, titles, num_nodes, device, use_memory):
|
| 68 |
+
# retrieval
|
| 69 |
+
unseen = []
|
| 70 |
+
for idx in range(len(titles)):
|
| 71 |
+
name = titles[idx]
|
| 72 |
+
if (name in self.memory) and use_memory:
|
| 73 |
+
memo_pred_ids = self.memory[name]['pred_ids'].to(device)
|
| 74 |
+
memo_confs = self.memory[name]['confs'].to(device)
|
| 75 |
+
memo_embeds = self.memory[name]['embeds'].to(device)
|
| 76 |
+
|
| 77 |
+
self.out_pred_ids[idx, :num_nodes[idx]] = memo_pred_ids
|
| 78 |
+
self.out_confs[idx, :num_nodes[idx]] = memo_confs
|
| 79 |
+
self.out_embeds[idx, :num_nodes[idx]] = memo_embeds
|
| 80 |
+
self.titles[idx] = name
|
| 81 |
+
else:
|
| 82 |
+
unseen.append(idx)
|
| 83 |
+
return unseen
|
| 84 |
+
|
| 85 |
+
def rebatch(self, unseen, batch):
|
| 86 |
+
unseen_pred_ids = []
|
| 87 |
+
unseen_attention_mask = []
|
| 88 |
+
for i in unseen:
|
| 89 |
+
unseen_pred_ids.append(batch['pred_ids'][i])
|
| 90 |
+
unseen_attention_mask.append(batch['attention_mask'][i])
|
| 91 |
+
unseen_pred_ids = torch.stack(unseen_pred_ids)
|
| 92 |
+
unseen_attention_mask = torch.stack(unseen_attention_mask)
|
| 93 |
+
return {"pred_ids":unseen_pred_ids, "attention_mask":unseen_attention_mask}
|
| 94 |
+
|
| 95 |
+
def save2memory(self, unseen,outputs, titles, unseen_attention_mask):
|
| 96 |
+
# save to memory
|
| 97 |
+
for i in range(len(unseen)):
|
| 98 |
+
name = titles[unseen[i]]
|
| 99 |
+
self.titles[unseen[i]] = name
|
| 100 |
+
mask = unseen_attention_mask[i]
|
| 101 |
+
self.memory[name] = {"pred_ids":outputs['pred_ids'][i][mask].detach().to('cpu'),
|
| 102 |
+
"confs":outputs['confs'][i][mask].detach().to('cpu'),
|
| 103 |
+
"embeds":outputs['embeds'][i][mask].detach().to('cpu')}
|
| 104 |
+
|
| 105 |
+
def update(self, unseen, unseen_attention_mask, num_nodes, outputs):
|
| 106 |
+
# update
|
| 107 |
+
for idx in range(len(unseen)):
|
| 108 |
+
mask = unseen_attention_mask[idx]==1
|
| 109 |
+
self.out_pred_ids[unseen[idx], :num_nodes[unseen[idx]]] = outputs['pred_ids'][idx][mask]
|
| 110 |
+
self.out_confs[unseen[idx], :num_nodes[unseen[idx]]] = outputs['confs'][idx][mask]
|
| 111 |
+
self.out_embeds[unseen[idx], :num_nodes[unseen[idx]]] = outputs['embeds'][idx][mask]
|
| 112 |
+
|
| 113 |
+
@torch.no_grad()
def forward(self, batch, use_memory=False):
    """Return language-model outputs for a batch, reusing cached entries.

    Samples whose title is reported as cached by ``self.retrivel`` are
    served from memory; the remaining ones are re-batched, passed through
    ``self.PretrainESM``, stored in memory and written into the output
    buffers.

    Args:
        batch: mapping with ``'probs'`` (three-dimensional; the first two
            sizes are used as batch size and padded length),
            ``'attention_mask'``, ``'title'`` and ``'pred_ids'``.
        use_memory: forwarded to ``self.retrivel``.

    Returns:
        dict with ``'title'``, ``'pred_ids'``, ``'confs'``, ``'embeds'`` and
        the input ``'attention_mask'``.
    """
    # clean_seqs = self.clean_input(batch)
    probs = batch['probs']
    device = probs.device
    B, maxL, _ = probs.shape
    num_nodes = batch['attention_mask'].sum(dim=-1).tolist()

    self.initoutput(B, maxL, device)
    unseen = self.retrivel(batch['title'], num_nodes, device, use_memory)

    if len(unseen) > 0:
        # Only the cache misses go through the model.
        new_batch = self.rebatch(unseen, batch)
        outputs = self.PretrainESM(new_batch)
        sub_mask = new_batch['attention_mask']

        self.save2memory(unseen, outputs, batch['title'], sub_mask)
        self.update(unseen, sub_mask, num_nodes, outputs)

    return {
        'title': self.titles,
        'pred_ids': self.out_pred_ids,
        'confs': self.out_confs,
        'embeds': self.out_embeds,
        'attention_mask': batch['attention_mask'],
    }
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
if __name__ == '__main__':
    # Ad-hoc demo left over from a sequence-search experiment; it does not
    # exercise the caching model defined in this file.

    # work_space = '/gaozhangyang/experiments/PiFoldV2/data/mmseq_workspace2'
    # target_seqs = ["MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPQTKTYFPHFDLSHGSAQVKGHG", "MVHLTPEEKSAVTALWGKVNVDEVGVEALGRLLVVYPWTQRFFESFGDLSTPDAVMGNPKV",
    # "MVLSPADKTNVKAAWGKVGAGGAEALERMFLSFPQKTYYTYFPHFDLSHGSAQVKGHG"]

    # query_seqs = ["MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKFPHFDLSHGSAQV", "MVHLTPEEKSAVTALWGKVNVDEVGGGRLLVVYPWTQRFFESFGDLSTPDAV",]

    # results = search_seqs(query_seqs, target_seqs, work_space)
    # print(results)

    import biotite.sequence as seq
    import biotite.sequence.align as align

    # Create example query and target protein sequences
    # (these four objects are built but not passed to the call below, which
    # receives plain strings instead)
    query_seq1 = seq.ProteinSequence("MSKXXKAFLNKXXL")
    target_seq1 = seq.ProteinSequence("MSKVKAALNKVLL")
    target_seq2 = seq.ProteinSequence("MSKVKKALNKVLL")
    target_seq3 = seq.ProteinSequence("MSTVAAALKMLLL")

    # NOTE(review): `search_seqs_biotite` is not defined or imported in the
    # visible part of this file; unless it is provided earlier in the module
    # this line raises NameError. TODO confirm and import or remove.
    results = search_seqs_biotite(["MSKXXKAFLNKXXL"], ["MSKVKAALNKVLL", "MSKVKKALNKVLL", "MSTVAAALKMLLL"])

    # Print the alignments
    # (only the header is printed; `results` itself is never displayed)
    print("Query alignments:")
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
|
Flexpert-Design/src/models/MemoryESMIF.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import subprocess
|
| 2 |
+
import os
|
| 3 |
+
from joblib import Parallel, delayed, cpu_count
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from .PretrainESMIF_model import PretrainESMIF_Model
|
| 9 |
+
from torch_scatter import scatter_sum
|
| 10 |
+
|
| 11 |
+
class MemoESMIF(nn.Module):
    """Title-keyed cache wrapped around ``PretrainESMIF_Model``.

    Per-sample encoder features are stored on the CPU in ``self.memory``
    under the sample title. ``forward`` serves cached samples from memory
    and runs the encoder only for the rest.
    """

    def __init__(self):
        super().__init__()
        self.PretrainESMIF = PretrainESMIF_Model()
        self.memory = {}
        # self.fix_memory = False

    # def save_memory(self, path):
    #     params = {key:val for key,val in self.state_dict().items() if "GNNTuning" in key}
    #     torch.save({"params":params,"memory": self.memory}, path)

    # def load_memory(self, path):
    #     data = torch.load(path)
    #     self.load_state_dict(data['params'], strict=False)
    #     self.memory = data['memory']

    def initoutput(self, B, maxL, device):
        """Allocate a zeroed ``(B, maxL, 512)`` feature buffer and empty titles."""
        self.titles = [None] * B
        self.out_embeds = torch.zeros(B, maxL, 512, dtype=torch.float, device=device)

    def retrivel(self, titles, num_nodes, device, use_memory):
        """Fill the output buffer from memory; return the indices that missed.

        A sample is served from memory only when its title is a key of
        ``self.memory`` and ``use_memory`` is truthy.
        """
        missing = []
        for idx, name in enumerate(titles):
            hit = (name in self.memory) and use_memory
            if not hit:
                missing.append(idx)
                continue
            cached = self.memory[name]['embeds'].to(device)
            self.out_embeds[idx, :num_nodes[idx]] = cached
            self.titles[idx] = name
        return missing

    def rebatch(self, unseen, batch):
        """Collect, per unseen sample, the first three entries along axis 1 of its positions."""
        positions = [
            batch['position'][batch['batch_id'] == i][:, :3, :]
            for i in unseen
        ]
        return {"position": positions}

    def save2memory(self, unseen, outputs, titles, num_nodes):
        """Store the valid prefix of each recomputed feature row on the CPU."""
        for row, batch_idx in enumerate(unseen):
            name = titles[batch_idx]
            self.titles[batch_idx] = name
            length = num_nodes[batch_idx]
            self.memory[name] = {"embeds": outputs['feat'][row, :length].detach().to('cpu')}

    def update(self, unseen, num_nodes, outputs):
        """Write recomputed features into the output buffer rows they belong to."""
        for row, batch_idx in enumerate(unseen):
            length = num_nodes[batch_idx]
            self.out_embeds[batch_idx, :length] = outputs['feat'][row, :length]

    @torch.no_grad()
    def forward(self, batch, use_memory=False):
        """Return ``{'title', 'embeds'}`` for the batch, using the cache when allowed.

        ``batch`` must provide ``'position'``, ``'batch_id'`` and ``'title'``.
        """
        # clean_seqs = self.clean_input(batch)
        device = batch['position'].device
        # Per-sample node counts: sum of ones grouped by batch id.
        batch_id = batch['batch_id']
        num_nodes = scatter_sum(torch.ones_like(batch_id), batch_id, dim=0)
        B = num_nodes.shape[0]
        maxL = num_nodes.max()

        self.initoutput(B, maxL, device)
        unseen = self.retrivel(batch['title'], num_nodes, device, use_memory)

        if len(unseen) > 0:
            # Encoder pass restricted to the cache misses.
            new_batch = self.rebatch(unseen, batch)
            outputs = self.PretrainESMIF(new_batch['position'])

            self.save2memory(unseen, outputs, batch['title'], num_nodes)
            self.update(unseen, num_nodes, outputs)

        return {'title': self.titles, 'embeds': self.out_embeds}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
if __name__ == '__main__':
    # Ad-hoc demo left over from a sequence-search experiment; it does not
    # exercise MemoESMIF.

    # work_space = '/gaozhangyang/experiments/PiFoldV2/data/mmseq_workspace2'
    # target_seqs = ["MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPQTKTYFPHFDLSHGSAQVKGHG", "MVHLTPEEKSAVTALWGKVNVDEVGVEALGRLLVVYPWTQRFFESFGDLSTPDAVMGNPKV",
    # "MVLSPADKTNVKAAWGKVGAGGAEALERMFLSFPQKTYYTYFPHFDLSHGSAQVKGHG"]

    # query_seqs = ["MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKFPHFDLSHGSAQV", "MVHLTPEEKSAVTALWGKVNVDEVGGGRLLVVYPWTQRFFESFGDLSTPDAV",]

    # results = search_seqs(query_seqs, target_seqs, work_space)
    # print(results)

    import biotite.sequence as seq
    import biotite.sequence.align as align

    # Create example query and target protein sequences
    # (these four objects are built but not passed to the call below, which
    # receives plain strings instead)
    query_seq1 = seq.ProteinSequence("MSKXXKAFLNKXXL")
    target_seq1 = seq.ProteinSequence("MSKVKAALNKVLL")
    target_seq2 = seq.ProteinSequence("MSKVKKALNKVLL")
    target_seq3 = seq.ProteinSequence("MSTVAAALKMLLL")

    # NOTE(review): `search_seqs_biotite` is neither defined nor imported in
    # this module, so this line raises NameError when the file is executed as
    # a script. TODO import the helper or drop the demo.
    results = search_seqs_biotite(["MSKXXKAFLNKXXL"], ["MSKVKAALNKVLL", "MSKVKKALNKVLL", "MSTVAAALKMLLL"])

    # Print the alignments
    # (only the header is printed; `results` itself is never displayed)
    print("Query alignments:")
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
|
Flexpert-Design/src/models/MemoryPiFold.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from .PretrainPiFold_model import PretrainPiFold_Model
|
| 4 |
+
from torch_scatter import scatter_sum
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
class MemoPiFold_model(nn.Module):
    """Title-keyed cache wrapped around ``PretrainPiFold_Model``.

    For every sample the confidences, node embeddings, residue
    probabilities and edge features are stored on the CPU in
    ``self.memory``. ``forward`` serves cached samples from memory and runs
    the pretrained network only on the remaining graphs.
    """

    def __init__(self, args):
        super().__init__()
        self.PretrainPiFold = PretrainPiFold_Model(args)
        self.memory = {}

    def save_memory(self, path):
        """Save the cache plus the state-dict entries whose key contains ``"GNNTuning"``."""
        # NOTE(review): no submodule named "GNNTuning" is created in this
        # class, so `params` looks like it is always empty here — confirm.
        params = {key: val for key, val in self.state_dict().items() if "GNNTuning" in key}
        torch.save({"params": params, "memory": self.memory}, path)

    def load_memory(self, path):
        """Restore what ``save_memory`` wrote (non-strict parameter load)."""
        data = torch.load(path)
        self.load_state_dict(data['params'], strict=False)
        self.memory = data['memory']

    def initoutput(self, B, max_L, nums, device):
        """Allocate the output buffers for a batch of ``B`` graphs.

        Buffers are filled with ones; ``attention_mask`` is True on the
        first ``nums[k]`` positions of row ``k`` and False elsewhere.
        """
        self.confs = torch.ones(B, max_L, device=device)
        self.embeds = torch.ones(B, max_L, 128, device=device)
        self.probs = torch.ones(B, max_L, 33, device=device)
        self.attention_mask = torch.ones_like(self.confs) == 0
        self.titles = [None for _ in range(B)]
        for row, num in enumerate(nums):
            self.attention_mask[row, :num] = True
        self.edge_feats = []

    def retrivel(self, batch, nums, batch_uid, device, use_memory):
        """Fill the output buffers from memory; return the indices that missed.

        A sample is served from memory only when its title is a key of
        ``self.memory`` and ``use_memory`` is truthy.
        """
        unseen = []
        for idx, name in enumerate(batch['title']):
            if (name in self.memory) and use_memory:
                cached = self.memory[name]
                row = batch_uid[idx]
                length = nums[idx]
                # The original wrapped this assignment in a bare `except:`
                # that re-ran the identical statement; a single assignment
                # has the same outcome without masking the traceback.
                self.confs[row, :length] = cached['conf'].to(device)
                self.embeds[row, :length] = cached['embed'].to(device)
                self.probs[row, :length] = cached['prob'].to(device)
                self.edge_feats.append((row, cached['h_E'].to(device)))
                self.titles[row] = name
            else:
                unseen.append(idx)
        return unseen

    def rebatch(self, unseen, batch_uid, batch_id, batch, shift, nums, device):
        """Build a compact graph batch containing only the unseen graphs.

        Node and edge features of each unseen graph are gathered, and edge
        indices are re-based: the graph's offset in the original batch
        (``shift``) is subtracted and its offset in the new batch is added.

        Returns:
            dict with ``'h_V'``, ``'h_E'``, ``'E_idx'`` and ``'batch_id'``.
        """
        h_V2, h_E2, E_idx2, batch_id2 = [], [], [], []
        shift2 = [0]  # running node offset inside the new batch
        new_id = 0
        for graph_id in batch_uid:
            if graph_id not in unseen:
                continue
            node_mask = batch_id == graph_id
            # An edge belongs to the graph that owns its first endpoint.
            edge_mask = batch_id[batch['E_idx'][0]] == graph_id
            h_V2.append(batch['h_V'][node_mask])
            h_E2.append(batch['h_E'][edge_mask])
            new_E_idx = batch['E_idx'][:, edge_mask]
            new_E_idx = new_E_idx - shift[batch_id[new_E_idx[0]]] + shift2[-1]
            E_idx2.append(new_E_idx)
            n_nodes = int(node_mask.sum())
            batch_id2.append(torch.full((n_nodes,), new_id, dtype=torch.long, device=device))
            shift2.append(shift2[-1] + nums[graph_id])
            new_id += 1

        return {
            "h_V": torch.cat(h_V2),
            'h_E': torch.cat(h_E2),
            'E_idx': torch.cat(E_idx2, dim=-1),
            'batch_id': torch.cat(batch_id2).long(),
        }

    def update_save2memory(self, unseen, batch_id2, E_idx2, batch, pretrain_gnn, max_L):
        """Write fresh network outputs into the buffers and into memory.

        Args:
            unseen: original batch indices, in the order used by ``rebatch``.
            batch_id2: node-to-graph ids of the re-batched graphs.
            E_idx2: edge index of the re-batched graphs.
            batch: the original batch (only ``'title'`` is read).
            pretrain_gnn: output dict of the pretrained network.
            max_L: padded length of the output buffers.
        """
        for new_id in batch_id2.unique():
            target = unseen[int(new_id)]
            edge_mask = batch_id2[E_idx2[0]] == new_id
            title = batch['title'][target]

            # Pad each per-graph result up to the buffer length.
            conf = pretrain_gnn['confs'][new_id]
            conf = F.pad(conf, (0, max_L - len(conf)))
            embed = pretrain_gnn['embeds'][new_id]
            embed = F.pad(embed, (0, 0, 0, max_L - len(embed)))
            prob = pretrain_gnn['probs'][new_id]
            prob = F.pad(prob, (0, 0, 0, max_L - len(prob)))
            edge_feat = pretrain_gnn['h_E'][edge_mask]
            self.edge_feats.append((target, edge_feat))

            self.confs[target] = conf
            self.embeds[target] = embed
            self.probs[target] = prob
            self.titles[target] = title

            attn_mask = self.attention_mask[target]

            # save to memory (valid positions only, on the CPU)
            self.memory[title] = {
                'conf': conf[attn_mask].detach().to('cpu'),
                'embed': embed[attn_mask].detach().to('cpu'),
                'prob': prob[attn_mask].detach().to('cpu'),
                'h_E': edge_feat.detach().to('cpu'),
            }

    @torch.no_grad()
    def forward(self, batch, use_memory=False):
        """Return design outputs for the batch, using the cache when allowed.

        ``batch`` must provide ``'batch_id'``, ``'title'``, ``'h_V'``,
        ``'h_E'`` and ``'E_idx'``.
        """
        batch_id = batch['batch_id']
        batch_uid = batch_id.unique()
        device = batch_id.device

        # Node count per graph and the start offset of each graph.
        nums = scatter_sum(torch.ones_like(batch_id), batch_id)
        shift = torch.cat([torch.zeros(1, device=device), torch.cumsum(nums, dim=0)]).long()
        max_L, B = nums.max(), batch_uid.shape[0]

        self.initoutput(B, max_L, nums, device)
        unseen = self.retrivel(batch, nums, batch_uid, device, use_memory)

        # organize data
        if len(unseen) > 0:
            # rebatch
            new_batch = self.rebatch(unseen, batch_uid, batch_id, batch, shift, nums, device)

            # forward pass
            pretrain_gnn = self.PretrainPiFold(new_batch)

            self.update_save2memory(unseen, pretrain_gnn['batch_id'], pretrain_gnn['E_idx'], batch, pretrain_gnn, max_L)

        # Edge features were appended in cache-hit / cache-miss order; sort
        # them by graph index before concatenating.
        self.edge_feats = sorted(self.edge_feats, key=lambda x: x[0])
        self.edge_feats = torch.cat([one[1] for one in self.edge_feats])

        # Padded positions get id 1.
        pred_ids = self.probs.argmax(dim=-1) * self.attention_mask + (~self.attention_mask) * 1

        return {'title': self.titles,
                'pred_ids': pred_ids,
                'confs': self.confs,
                'embeds': self.embeds,
                'probs': self.probs,
                'attention_mask': self.attention_mask,
                'h_E': self.edge_feats,
                'E_idx': batch['E_idx'],
                'batch_id': batch['batch_id']}

    def _get_features(self, S, score, X, mask, chain_mask, chain_encoding):
        """Delegate feature extraction to the wrapped pretrained model."""
        return self.PretrainPiFold._get_features(S, score, X, mask, chain_mask, chain_encoding)
|
Flexpert-Design/src/models/MemoryTuning.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from .Tuning import GNNTuning_Model
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class MemoTuning(nn.Module):
    """Title-keyed cache wrapped around ``GNNTuning_Model``.

    ``forward`` serves samples whose title is already in ``self.memory``
    from the cache (when allowed) and runs the tuning network on the rest,
    storing the new results on the CPU.
    """

    # Result fields stored per sample in memory, in insertion order.
    _MEMORY_FIELDS = ("pred_ids", "confs", "embeds", "probs", "log_probs", "attention_mask")

    def __init__(self, args, tunning_layers_n, tunning_layers_dim, input_design_dim, input_esm_dim, tunning_dropout, tokenizer, fix_memory=False):
        super().__init__()
        self.args = args
        self.tunning_layers_dim = tunning_layers_dim
        self.GNNTuning = GNNTuning_Model(
            args,
            num_encoder_layers=tunning_layers_n,
            hidden_dim=tunning_layers_dim,
            input_design_dim=input_design_dim,
            input_esm_dim=input_esm_dim,
            dropout=tunning_dropout,
        )
        self.tokenizer = tokenizer
        self.memory = {}

    def save_param_memory(self, path):
        """Write the full state dict together with the cache to ``path``."""
        checkpoint = {"params": self.state_dict(), "memory": self.memory}
        torch.save(checkpoint, path)

    def load_param_memory(self, path):
        """Restore parameters and cache written by ``save_param_memory``."""
        data = torch.load(path)
        self.load_state_dict(data['params'])
        self.memory = data['memory']

    def get_seqs(self, pred_ids_raw, attention_mask):
        """Decode each masked row of token ids into a string without spaces."""
        decoded = []
        for pred_ids, mask in zip(pred_ids_raw, attention_mask):
            text = self.tokenizer.decode(pred_ids[mask], clean_up_tokenization_spaces=False)
            decoded.append(text.replace(" ", ""))
        return decoded

    def initoutput(self, pretrain_design, B, max_L, device):
        """Allocate zeroed output buffers shaped after ``pretrain_design``."""
        self.titles = [None] * B
        self.out_pred_ids = torch.zeros_like(pretrain_design['pred_ids'])
        self.out_confs = torch.zeros_like(pretrain_design['confs'])
        self.out_embeds = torch.zeros(B, max_L, self.tunning_layers_dim, device=device)
        self.out_attention_mask = torch.zeros_like(pretrain_design['attention_mask'])
        self.out_probs = torch.zeros_like(pretrain_design['probs'])
        self.out_log_probs = torch.zeros_like(pretrain_design['probs'])

    def retrivel(self, keys, num_nodes, device, use_memory):
        """Fill the output buffers from memory; return the indices that missed."""
        missing = []
        for idx, key in enumerate(keys):
            if not ((key in self.memory) and use_memory):
                missing.append(idx)
                continue
            entry = self.memory[key]
            length = num_nodes[idx]
            self.out_pred_ids[idx, :length] = entry['pred_ids'].to(device)
            self.out_confs[idx, :length] = entry['confs'].to(device)
            self.out_embeds[idx, :length] = entry['embeds'].to(device)
            self.out_attention_mask[idx, :length] = entry['attention_mask'].to(device)
            self.out_probs[idx, :length] = entry['probs'].to(device)
            self.out_log_probs[idx, :length] = entry['log_probs'].to(device)
            self.titles[idx] = key
        return missing

    def rebatch(self, unseen, batch_id_raw, E_idx_raw, h_E_raw, shift, num_nodes, pretrain_design, pretrain_esm_msa, pretrain_struct, pretrain_esmif, device):
        """Build the input batch for the samples that missed the cache.

        Edge indices are re-based from their offset in the original batch
        (``shift``) to their offset in the new batch. The optional inputs
        are included only when the matching ``self.args`` flag is set
        (``use_LM``, ``use_gearnet``, ``use_esmif``).
        """
        fields = ('pred_ids', 'confs', 'embeds', 'attention_mask')
        design_rows = {name: [] for name in fields}
        esm_rows = {name: [] for name in fields}
        struct_rows = []
        esmif_rows = []
        h_E, E_idx, batch_id = [], [], []

        offset = 0  # node offset of the current sample inside the new batch
        for bid, i in enumerate(unseen):
            # An edge belongs to the sample that owns its first endpoint.
            edge_mask = batch_id_raw[E_idx_raw[0]] == i
            h_E.append(h_E_raw[edge_mask])
            E_idx.append(E_idx_raw[:, edge_mask] - shift[i] + offset)
            batch_id.append(torch.ones(num_nodes[i], device=device).long() * bid)
            offset += num_nodes[i]

            for name in fields:
                design_rows[name].append(pretrain_design[name][i])

            if self.args.use_LM:
                # The sample axis of the language-model tensors is axis 1.
                for name in fields:
                    esm_rows[name].append(pretrain_esm_msa[name][:, i])

            if self.args.use_gearnet:
                struct_rows.append(pretrain_struct['embeds'][:, i])

            if self.args.use_esmif:
                esmif_rows.append(pretrain_esmif['embeds'][i])

        design = {name: torch.stack(design_rows[name]) for name in fields}

        unseen_batch = {
            "pretrain_design": design,
            "h_E": torch.cat(h_E),
            "E_idx": torch.cat(E_idx, dim=1),
            "batch_id": torch.cat(batch_id),
            "attention_mask": design['attention_mask'],
        }

        if self.args.use_LM:
            unseen_batch["pretrain_esm_msa"] = {
                name: torch.stack(esm_rows[name], dim=1) for name in fields
            }

        if self.args.use_gearnet:
            unseen_batch["pretrain_struct"] = {"embeds": torch.stack(struct_rows, dim=1)}

        if self.args.use_esmif:
            unseen_batch["pretrain_esmif"] = {"embeds": torch.stack(esmif_rows, dim=0)}
        return unseen_batch

    def save2memory(self, keys, unseen, num_nodes, unseen_results):
        """Store the valid prefix of every result field on the CPU."""
        for row, batch_idx in enumerate(unseen):
            length = num_nodes[batch_idx]
            self.memory[keys[batch_idx]] = {
                name: unseen_results[name][row][:length].detach().to('cpu')
                for name in self._MEMORY_FIELDS
            }

    def update(self, unseen, num_nodes, unseen_results, keys):
        """Write recomputed results into the output buffers and titles."""
        destinations = (
            (self.out_pred_ids, 'pred_ids'),
            (self.out_confs, 'confs'),
            (self.out_embeds, 'embeds'),
            (self.out_probs, 'probs'),
            (self.out_log_probs, 'log_probs'),
        )
        for row, batch_idx in enumerate(unseen):
            length = num_nodes[batch_idx]
            for buffer, name in destinations:
                buffer[batch_idx, :length] = unseen_results[name][row][:length]
            self.titles[batch_idx] = keys[batch_idx]

    def forward(self, batch, use_memory=False):
        """Return tuned outputs for the batch, using the cache when allowed.

        ``batch`` must provide ``'pretrain_design'``, ``'h_E'``, ``'E_idx'``,
        ``'attention_mask'``, ``'batch_id'`` and ``'title'``, plus the
        optional inputs enabled through ``self.args``.
        """
        self.use_memory = use_memory
        pretrain_design = batch['pretrain_design']
        h_E_raw = batch['h_E']
        E_idx_raw = batch['E_idx']
        batch_id_raw = batch['batch_id']
        device = h_E_raw.device

        pretrain_esm_msa = batch['pretrain_esm_msa'] if self.args.use_LM else None
        pretrain_struct = batch['pretrain_struct'] if self.args.use_gearnet else None
        pretrain_esmif = batch['esm_feat'] if self.args.use_esmif else None

        # Valid length per sample and the start offset of each sample.
        num_nodes = batch['attention_mask'].sum(dim=-1)
        shift = torch.cat([torch.zeros(1, device=device), torch.cumsum(num_nodes, dim=0)]).long()

        B, max_L = num_nodes.shape[0], num_nodes.max()

        self.initoutput(pretrain_design, B, max_L, device)

        # keys = list(zip(design_seqs, *lm_seqs))
        keys = batch['title']
        unseen = self.retrivel(keys, num_nodes, device, use_memory)

        if len(unseen) > 0:
            unseen_batch = self.rebatch(unseen, batch_id_raw, E_idx_raw, h_E_raw, shift, num_nodes, pretrain_design, pretrain_esm_msa, pretrain_struct, pretrain_esmif, device)
            unseen_results = self.GNNTuning(unseen_batch)

            self.save2memory(keys, unseen, num_nodes, unseen_results)
            self.update(unseen, num_nodes, unseen_results, keys)

        return {
            'title': self.titles,
            'pred_ids': self.out_pred_ids,
            'confs': self.out_confs,
            'embeds': self.out_embeds,
            'probs': self.out_probs,
            "log_probs": self.out_log_probs,
            'attention_mask': pretrain_design['attention_mask'],
        }
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
|
Flexpert-Design/src/models/PretrainESMIF_model.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import esm
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch
|
| 4 |
+
from esm.inverse_folding.util import CoordBatchConverter
|
| 5 |
+
|
| 6 |
+
class PretrainESMIF_Model(nn.Module):
    """Frozen ESM-IF1 structure encoder used as a feature extractor."""

    def __init__(self, checkpoint_path="./model_zoo/esmif/esm_if1_gvp4_t16_142M_UR50.pt"):
        """Load the ESM-IF1 model and alphabet from a local checkpoint.

        Args:
            checkpoint_path: path of the ``esm_if1_gvp4_t16_142M_UR50``
                checkpoint file. Defaults to the location the original code
                hard-coded, so existing no-argument callers are unaffected.
        """
        super(PretrainESMIF_Model, self).__init__()
        # /root/.cache/torch/hub/checkpoints
        model_data = torch.load(checkpoint_path)
        self.model, self.alphabet = esm.pretrained.load_model_and_alphabet_core(
            "esm_if1_gvp4_t16_142M_UR50", model_data, None
        )

    def forward(self, coords_list):
        """Encode a list of backbone coordinate tensors.

        Args:
            coords_list: non-empty list of per-protein coordinate tensors;
                the device of the first element is used for batching.

        Returns:
            dict with ``'feat'``: the encoder output with the batch axis
            moved first and the first and last positions along the sequence
            axis removed.
        """
        # The encoder is always run in eval mode and without gradients.
        self.model.eval()
        batch_converter = CoordBatchConverter(self.model.decoder.dictionary)
        batch_coords, confidence, _, _, padding_mask = batch_converter(
            [(coord, None, None) for coord in coords_list],
            device=coords_list[0].device,
        )
        with torch.no_grad():
            encoder_out = self.model.encoder(batch_coords, padding_mask, confidence)

        # The encoder output is permuted to put the batch axis first; the
        # slice drops the first and last sequence positions.
        feat = encoder_out['encoder_out'][0].permute(1, 0, 2)[:, 1:-1]  # 2,1046-2,512

        return {"feat": feat}
|
| 26 |
+
|
| 27 |
+
if __name__ == '__main__':
    # Smoke test: encode two random backbones of different lengths.
    # The constructor takes no required arguments; the previous call passed
    # a stray positional `0.1`, which raised TypeError before anything ran.
    model = PretrainESMIF_Model()
    coords1 = torch.rand(1044, 3, 3)  # N, CA, C
    coords2 = torch.rand(500, 3, 3)
    model([coords1, coords2])
    print()
|
Flexpert-Design/src/models/PretrainESM_model.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import torch
|
| 3 |
+
import math
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from transformers import AutoTokenizer, EsmForMaskedLM # EsmForMaskedLM, 1041 line
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class PretrainESM_Model(nn.Module):
    """Wrapper around the pretrained ESM-2 650M masked language model."""

    def __init__(self, args):
        """Load the ESM-2 tokenizer and masked-LM weights (cached under ``./cache_dir/``)."""
        super(PretrainESM_Model, self).__init__()
        self.args = args
        # {0: '<cls>', 1: '<pad>', 2: '<eos>', 3: '<unk>', 4: 'L', 5: 'A', 6: 'G', 7: 'V', 8: 'S', 9: 'E', 10: 'R', 11: 'T', 12: 'I', 13: 'D', 14: 'P', 15: 'K', 16: 'Q', 17: 'N', 18: 'F', 19: 'Y', 20: 'M', 21: 'H', 22: 'W', 23: 'C', 24: 'X', 25: 'B', 26: 'U', 27: 'Z', 28: 'O', 29: '.', 30: '-', 31: '<null_1>', 32: '<mask>'}
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="./cache_dir/")
        self.model = EsmForMaskedLM.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="./cache_dir/")

    def forward(self, batch):
        """Run the language model and return predictions plus embeddings.

        Args:
            batch: mapping with ``'pred_ids'`` (token ids fed to the model)
                and ``'attention_mask'``.

        Returns:
            dict with ``'pred_ids'`` (argmax token per position), ``'confs'``
            (probability of that token), ``'embeds'`` (final hidden state)
            and the input ``'attention_mask'``.
        """
        # Hidden states must be requested explicitly; the stock Hugging Face
        # model otherwise returns `hidden_states=None`.
        outputs = self.model(
            input_ids=batch['pred_ids'],
            attention_mask=batch['attention_mask'],
            output_hidden_states=True,
        )
        logits = outputs.logits

        prop = logits.softmax(dim=-1)
        confidences, pred_ids = prop.max(dim=-1)

        hidden = outputs.hidden_states
        # The stock model yields one tensor per layer; keep the final one.
        # NOTE(review): the import comment mentions a local edit to
        # EsmForMaskedLM — if that edit already returns a single tensor here,
        # it is passed through unchanged. Confirm which variant is deployed.
        if isinstance(hidden, (tuple, list)):
            hidden = hidden[-1]

        ret = {"pred_ids": pred_ids,
               "confs": confidences,
               "embeds": hidden,
               "attention_mask": batch['attention_mask']}
        return ret
|
| 31 |
+
|
| 32 |
+
if __name__ == '__main__':
    # Manual check: load the ESM-2 tokenizer (downloads into ./cache_dir/ on
    # first use).
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="./cache_dir/")
    # NOTE(review): this only looks the method up and discards it — nothing
    # is called. Presumably a leftover debugger anchor; confirm or remove.
    tokenizer.convert_ids_to_tokens
    print()
|
Flexpert-Design/src/models/PretrainPiFold_model.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os.path as osp
|
| 3 |
+
from src.models.pifold_model import PiFold_Model
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class PretrainPiFold_Model(PiFold_Model):
    """PiFold network initialised from a saved checkpoint and run without gradients."""

    def __init__(self, args, **kwargs):
        """Build the PiFold network and load pretrained weights.

        The checkpoint is read from
        ``<res_dir>/<dataset>/PiFold_<augment_eps>/checkpoint.pth`` when
        ``args.augment_eps > 0`` and from
        ``model_zoo/<dataset>/PiFold/checkpoint.pth`` otherwise. Extra
        keyword arguments are accepted and ignored.
        """
        PiFold_Model.__init__(self, args)
        if args.augment_eps > 0:
            pretrain_pifold_path = osp.join(self.args.res_dir, self.args.dataset, f"PiFold_{args.augment_eps}", "checkpoint.pth")
        else:
            # pretrain_pifold_path = osp.join(self.args.res_dir, self.args.dataset, "PiFold", "checkpoint.pth")
            pretrain_pifold_path = osp.join('model_zoo', self.args.dataset, "PiFold", "checkpoint.pth")
        # Deserialize onto the CPU so a checkpoint saved from a GPU run also
        # loads on CPU-only hosts; load_state_dict copies the values into the
        # module's own parameters wherever those live.
        self.load_state_dict(torch.load(pretrain_pifold_path, map_location='cpu'))

    @torch.no_grad()
    def forward(self, batch):
        """Predict residues for a flat graph batch and return padded per-graph outputs.

        Args:
            batch: mapping with ``'h_V'`` (node features), ``'h_E'`` (edge
                features), ``'E_idx'`` (edge index) and ``'batch_id'``
                (graph id per node).

        Returns:
            dict with ``'pred_ids'`` and ``'attention_mask'`` produced by
            re-tokenizing the decoded predictions, per-graph ``'confs'``,
            ``'embeds'`` and ``'probs'`` padded to the longest graph, plus the
            unchanged ``'E_idx'``/``'batch_id'`` and the encoded edge
            features ``'h_E'``.
        """
        h_V, h_P, P_idx, batch_id = batch['h_V'], batch['h_E'], batch['E_idx'], batch['batch_id']
        device = h_V.device
        h_V = self.W_v(self.norm_nodes(self.node_embedding(h_V)))
        h_P = self.W_e(self.norm_edges(self.edge_embedding(h_P)))

        h_V, h_P = self.encoder(h_V, h_P, P_idx, batch_id)
        _, logits = self.decoder(h_V, batch_id)
        probs = F.softmax(logits, dim=-1)
        conf, pred_id = probs.max(dim=-1)

        # Split the flat node tensors into one entry per graph.
        confs, seqs, embeds, probs2 = [], [], [], []
        for b in batch_id.unique():
            mask = batch_id == b
            # elements = [alphabet[int(id)] for id in pred_id[mask]]
            elements = self.tokenizer.decode(pred_id[mask]).split(" ")
            seqs.append(elements)
            confs.append(conf[mask])
            embeds.append(h_V[mask])
            probs2.append(probs[mask])

        # Longest graph in the batch, as a plain int for F.pad.
        maxL = max(len(one) for one in confs)

        seqs = self.tokenizer(["".join(one) for one in seqs], padding=True, truncation=True, return_tensors='pt', add_special_tokens=False)
        confs = torch.stack([F.pad(one, (0, maxL - len(one))) for one in confs])
        embeds = torch.stack([F.pad(one, (0, 0, 0, maxL - len(one))) for one in embeds])
        # Padded probability rows are filled with the constant 1/33.
        probs2 = torch.stack([F.pad(one, (0, 0, 0, maxL - len(one)), value=1/33) for one in probs2])

        ret = {"pred_ids": seqs['input_ids'].to(device),
               "confs": confs,
               "embeds": embeds,
               "probs": probs2,
               "attention_mask": seqs['attention_mask'].to(device),
               "E_idx": P_idx,
               "batch_id": batch_id,
               "h_E": h_P}
        return ret
|
| 64 |
+
|
Flexpert-Design/src/models/Tuning.py
ADDED
|
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from src.modules.pifold_module import *
|
| 3 |
+
from torch_scatter import scatter_softmax, scatter_log_softmax
|
| 4 |
+
|
| 5 |
+
def positional_encoding(x):
    """Return a sinusoidal positional encoding with the same shape as ``x``.

    The last dimension holds all sine terms followed by all cosine terms
    (concatenated, not interleaved).

    Args:
        x: tensor of shape ``(batch_size, seq_len, hidden_size)``; only its
            size and device are used.

    Returns:
        float32 tensor of shape ``(batch_size, seq_len, hidden_size)`` located
        on ``x.device``.
    """
    batch_size, seq_len, hidden_size = x.size()
    # Build everything on x's device so the result can be added to x directly;
    # creating it on the CPU breaks as soon as x lives on an accelerator.
    device = x.device
    pos = torch.arange(0, seq_len, device=device).float().unsqueeze(1)
    div = torch.exp(
        torch.arange(0, hidden_size, 2, device=device).float()
        * (-torch.log(torch.tensor(10000.0, device=device)) / hidden_size)
    )
    angles = pos * div  # (seq_len, ceil(hidden_size / 2)) via broadcasting
    sin = torch.sin(angles)
    # For odd hidden_size there is one more sine than cosine column, so the
    # concatenated width is exactly hidden_size.
    cos = torch.cos(angles[:, : hidden_size // 2])
    pos_encoding = torch.cat([sin, cos], dim=-1).unsqueeze(0).repeat(batch_size, 1, 1)
    return pos_encoding
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class MSAAttention(nn.Module):
    """Single-head dot-product self-attention over position-encoded inputs."""

    def __init__(self, hidden_dim) -> None:
        """Create the query, key and value projections (all hidden_dim -> hidden_dim)."""
        super().__init__()
        self.MSA_Q = nn.Linear(hidden_dim, hidden_dim)
        self.MSA_K = nn.Linear(hidden_dim, hidden_dim)
        self.MSA_V = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, inputs_embeds):
        """Attend every position to every other position of the same sequence.

        Args:
            inputs_embeds: tensor of shape ``(batch, N, hidden_dim)``.

        Returns:
            Tensor of shape ``(batch, N, hidden_dim)``.
        """
        encoded = inputs_embeds + positional_encoding(inputs_embeds)

        q = self.MSA_Q(encoded)  # (batch, N, hidden_dim)
        k = self.MSA_K(encoded)  # (batch, N, hidden_dim)
        v = self.MSA_V(encoded)  # (batch, N, hidden_dim)

        # NOTE: the scores are not divided by sqrt(hidden_dim) before the softmax.
        scores = torch.bmm(q, k.transpose(1, 2))
        weights = nn.functional.softmax(scores, dim=2)
        return torch.bmm(weights, v)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class GNNTuning_Model(nn.Module):
    """GNN that refines a design model's predictions using optional extra embeddings.

    Per-residue inputs from the design model (and, depending on ``args``, a
    language model, a structure model and ESM-IF) are projected to
    ``hidden_dim``, summed, passed through a stack of ``GeneralGNN`` layers and
    read out as 33-way logits.
    """

    @staticmethod
    def _projection(in_dim, hidden_dim):
        """Return the Linear-ReLU-Linear-ReLU-Linear stack shared by all projections."""
        return nn.Sequential(nn.Linear(in_dim, 512),
                             nn.ReLU(),
                             nn.Linear(512, hidden_dim),
                             nn.ReLU(),
                             nn.Linear(hidden_dim, hidden_dim))

    def __init__(self, args, num_encoder_layers, hidden_dim, input_design_dim, input_esm_dim, input_struct_dim=3072, input_esmif_dim=512, dropout=0.1):
        """Create the encoder stack, embeddings, projections and gates.

        Args:
            args: run configuration; ``use_LM``, ``use_gearnet``, ``use_esmif``
                and ``msa_n`` are read in ``forward``.
            num_encoder_layers: number of ``GeneralGNN`` layers.
            hidden_dim: width of node and edge representations.
            input_design_dim: width of the design model's embeddings.
            input_esm_dim: width of the language model's embeddings.
            input_struct_dim: width of the structure model's embeddings.
            input_esmif_dim: width of the ESM-IF embeddings.
            dropout: dropout rate passed to each ``GeneralGNN`` layer.
        """
        super(GNNTuning_Model, self).__init__()
        self.args = args
        encoder_layers = []
        for _ in range(num_encoder_layers):
            encoder_layers.append(
                GeneralGNN(hidden_dim, hidden_dim*2, dropout=dropout, node_net="AttMLP", edge_net="EdgeMLP", node_context=1, edge_context=0),
            )
        # nn.Sequential is used purely as an ordered container; forward()
        # iterates the layers itself because each one takes several inputs.
        self.encoder_layers = nn.Sequential(*encoder_layers)

        from transformers import AutoTokenizer
        from transformers.models.esm.modeling_esm import EsmModel, EsmEmbeddings
        from transformers.models.esm.configuration_esm import EsmConfig

        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="./cache_dir/")
        config = EsmConfig(attention_probs_dropout_prob=0,
                           hidden_size=hidden_dim,
                           intermediate_size=1280,
                           mask_token_id=32,
                           num_attention_heads=12,
                           num_hidden_layers=3,
                           pad_token_id=1,
                           position_embedding_type="rotary",
                           token_dropout=False,
                           vocab_size=33
                           )

        self.DesignEmbed = EsmEmbeddings(config)
        self.ESMEmbed = EsmEmbeddings(config)
        self.EdgeEmbed = self._projection(416+16+16, hidden_dim)

        self.DesignConf = nn.Sequential(nn.Linear(1, 128),
                                        nn.ReLU(),
                                        nn.Linear(128, 128),
                                        nn.ReLU(),
                                        nn.Linear(128, 1),
                                        nn.Sigmoid())

        # NOTE(review): unlike DesignConf this gate has no final Sigmoid —
        # confirm that is intended.
        self.ESMConf = nn.Sequential(nn.Linear(1, 128),
                                     nn.ReLU(),
                                     nn.Linear(128, 128),
                                     nn.ReLU(),
                                     nn.Linear(128, 1))

        self.DesignProj = self._projection(input_design_dim, hidden_dim)
        self.ESMProj = self._projection(input_esm_dim, hidden_dim)
        self.StructProj = self._projection(input_struct_dim, hidden_dim)
        self.ESMIFProj = self._projection(input_esmif_dim, hidden_dim)

        self.ReadOut = nn.Linear(hidden_dim, 33)

        self.MLP1 = nn.Sequential(nn.Linear(1, 512),
                                  nn.ReLU(),
                                  nn.Linear(512, hidden_dim),
                                  nn.ReLU(),
                                  nn.Linear(hidden_dim, 1),
                                  nn.Sigmoid())

        self.MLP2 = nn.Sequential(nn.Linear(1, 512),
                                  nn.ReLU(),
                                  nn.Linear(512, hidden_dim),
                                  nn.ReLU(),
                                  nn.Linear(hidden_dim, 1),
                                  nn.Sigmoid())

    def fuse(self, mask_select_feat, mask_select_id, gnn_embed=None, esm_embed=None, gearnet_embed=None, esmif_embed=None, gnn_pred_id=None, esm_pred_id=None, confidence=None, confidence_esm=None):
        """Combine the available per-residue sources into one node embedding.

        Args:
            mask_select_feat: callable flattening a padded feature tensor to
                the attended residues.
            mask_select_id: callable flattening a padded id/score tensor to a
                column of attended residues.
            gnn_embed, gnn_pred_id: design-model embeddings and predicted ids.
            esm_embed, esm_pred_id: language-model embeddings and predicted ids.
            gearnet_embed: structure-model embeddings.
            esmif_embed: ESM-IF embeddings.
            confidence: design-model confidences gating the design term.
            confidence_esm: language-model confidences gating the ESM term.

        Returns:
            ``gnn * conf + esm * esm_conf + gearnet + esmif``; every source
            that is ``None`` contributes 0 and every missing confidence
            leaves its term ungated (factor 1.0).
        """
        gnn, esm, gearnet, esmif = 0, 0, 0, 0
        conf, esm_conf = 1.0, 1.0

        if gnn_embed is not None:
            gnn = self.DesignProj(mask_select_feat(gnn_embed))
            # The ids are a column, so the embedding has a singleton axis 1.
            gnn += self.DesignEmbed(mask_select_id(gnn_pred_id)).squeeze(1)

        if esm_embed is not None:
            esm = self.ESMProj(mask_select_feat(esm_embed))
            esm += self.ESMEmbed(mask_select_id(esm_pred_id)).squeeze(1)

        if gearnet_embed is not None:
            gearnet = self.StructProj(mask_select_feat(gearnet_embed))

        if esmif_embed is not None:
            esmif = self.ESMIFProj(mask_select_feat(esmif_embed))

        # Gate each term only when its confidence was actually supplied.
        if confidence is not None:
            conf = self.DesignConf(mask_select_id(confidence))
        if confidence_esm is not None:
            esm_conf = self.ESMConf(mask_select_id(confidence_esm))

        return (gnn*conf + esm*esm_conf + gearnet + esmif)

    def forward(self, batch):
        """Refine the design model's prediction for every attended residue.

        Args:
            batch: dict providing ``pretrain_design``, ``h_E``, ``E_idx``,
                ``attention_mask`` and ``batch_id``, plus
                ``pretrain_esm_msa`` / ``pretrain_struct`` /
                ``pretrain_esmif`` when the matching ``args.use_*`` flag is set.

        Returns:
            dict with the re-tokenized ``pred_ids`` and ``attention_mask``,
            padded ``confs``, ``embeds`` and ``probs``, the untouched ``h_E``,
            ``E_idx`` and ``batch_id``, and ``log_probs`` of shape (B, N, 33).
        """
        pretrain_design, h_E_raw, E_idx, mask_attend, batch_id = batch['pretrain_design'], batch['h_E'], batch['E_idx'], batch['attention_mask'], batch['batch_id']

        use_LM = self.args.use_LM
        use_gearnet = self.args.use_gearnet
        use_esmif = self.args.use_esmif
        pretrain_esm_msa = batch['pretrain_esm_msa'] if use_LM else None
        pretrain_struct = batch['pretrain_struct'] if use_gearnet else None
        pretrain_esmif = batch['pretrain_esmif'] if use_esmif else None

        attend = mask_attend.bool()

        def mask_select_id(x):
            return torch.masked_select(x, attend).reshape(-1, 1)

        def mask_select_feat(x):
            return torch.masked_select(x, attend.unsqueeze(-1)).reshape(-1, x.shape[-1])

        inputs_embeds = 0
        for i in range(self.args.msa_n):
            gnn_embed = pretrain_design['embeds']
            # Language-model inputs exist only when use_LM is set; leaving
            # them as None lets fuse() skip the ESM term and its gate.
            esm_embed = pretrain_esm_msa['embeds'][i] if use_LM else None
            esm_pred_id = pretrain_esm_msa['pred_ids'][i] if use_LM else None
            confidence_esm = pretrain_esm_msa['confs'][i] if use_LM else None
            gearnet_embed = pretrain_struct['embeds'][i] if use_gearnet else None
            esmif_embed = pretrain_esmif['embeds'] if use_esmif else None
            confidence = pretrain_design['confs']
            inputs_embeds += self.fuse(mask_select_feat, mask_select_id, gnn_embed, esm_embed, gearnet_embed, esmif_embed, pretrain_design['pred_ids'], esm_pred_id, confidence, confidence_esm)

        h_V = inputs_embeds
        h_E = self.EdgeEmbed(h_E_raw)

        for layer in self.encoder_layers:
            h_V, h_E = layer(h_V, h_E, E_idx, batch_id)

        logits = self.ReadOut(h_V)

        # Confidence update: blend the refined and the input embeddings with
        # gates driven by how much the top-1 probability changed.
        old_confs = mask_select_id(pretrain_design['confs'])
        confs = torch.softmax(logits, dim=-1).max(dim=-1)[0][:, None]
        h_V = h_V*self.MLP1(confs-old_confs) + inputs_embeds*self.MLP2(old_confs-confs)
        logits = self.ReadOut(h_V)

        B, N = pretrain_design['confs'].shape
        vocab_size = logits.shape[-1]

        # Scatter the flat logits back into a padded (B, N, vocab) tensor;
        # padded positions keep all-zero logits.
        new_logits = torch.zeros(B, N, vocab_size, device=logits.device).reshape(B*N, vocab_size)
        new_logits = new_logits.masked_scatter_(attend.view(-1, 1), logits)
        new_logits = new_logits.reshape(B, N, vocab_size)
        log_probs = torch.log_softmax(new_logits, dim=-1)

        device = logits.device
        seqs, confs, embeds, probs2 = self.to_matrix(h_V, logits, batch_id)

        ret = {"pred_ids": seqs['input_ids'].to(device),
               "confs": confs,
               "embeds": embeds,
               "probs": probs2,
               "attention_mask": seqs['attention_mask'].to(device),
               "h_E": h_E_raw,
               "E_idx": E_idx,
               "batch_id": batch_id,
               "log_probs": log_probs}
        return ret

    def to_matrix(self, h_V, logits, batch_id):
        """Regroup flat per-residue outputs into per-chain padded tensors.

        Args:
            h_V: flat node embeddings, one row per residue.
            logits: flat logits, one row per residue.
            batch_id: chain id of every residue.

        Returns:
            ``(seqs, confs, embeds, probs2)`` where ``seqs`` is the tokenizer
            output for the decoded predictions and the other three are
            stacked per chain and padded to the longest chain.
        """
        probs = F.softmax(logits, dim=-1)
        conf, pred_id = probs.max(dim=-1)

        # One boolean mask per chain, built once and reused below.
        masks = [batch_id == b for b in batch_id.unique()]
        # Plain Python int so the padding widths handed to F.pad are ints.
        max_len = max((int(mask.sum()) for mask in masks), default=0)

        confs = []
        seqs = []
        embeds = []
        probs2 = []
        for mask in masks:
            seqs.append(self.tokenizer.decode(pred_id[mask]).split(" "))
            confs.append(conf[mask])
            embeds.append(h_V[mask])
            probs2.append(probs[mask])

        # Re-tokenize the decoded strings to obtain padded ids and their mask.
        seqs = self.tokenizer(["".join(one) for one in seqs], padding=True, truncation=True, return_tensors='pt', add_special_tokens=False)
        confs = torch.stack([F.pad(one, (0, max_len - len(one))) for one in confs])
        embeds = torch.stack([F.pad(one, (0, 0, 0, max_len - len(one))) for one in embeds])
        # Padded rows get a uniform distribution over the 33 ReadOut classes.
        probs2 = torch.stack([F.pad(one, (0, 0, 0, max_len - len(one)), value=1/33) for one in probs2])
        return seqs, confs, embeds, probs2
|
Flexpert-Design/src/models/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) CAIRI AI Lab. All rights reserved
|
| 2 |
+
|
| 3 |
+
# from .alphadesign_model import AlphaDesign_Model
|
| 4 |
+
# from .esmif_model import GVPTransformerModel as ESMIF_Model
|
| 5 |
+
# from .gca_model import GCA_Model
|
| 6 |
+
# from .graphtrans_model import GraphTrans_Model
|
| 7 |
+
# from .gvp_model import GVP_Model
|
| 8 |
+
# from .pifold_model import PiFold_Model
|
| 9 |
+
from .proteinmpnn_model import ProteinMPNN_Model
|
| 10 |
+
# from .structgnn_model import StructGNN_Model
|
| 11 |
+
# from .kwdesign_model import KWDesign_model
|
| 12 |
+
|
| 13 |
+
# Export only what this package actually imports. Listing the models whose
# imports are commented out above makes `from src.models import *` raise
# AttributeError; re-add a name here when its import is re-enabled.
__all__ = [
    'ProteinMPNN_Model',
]
|
Flexpert-Design/src/models/alphadesign_model.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from src.modules.alphadesign_module import ATDecoder, CNNDecoder, CNNDecoder2, StructureEncoder
|
| 5 |
+
from src.tools.design_utils import gather_nodes, _dihedrals, _rbf, _orientations_coarse_gl
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class AlphaDesign_Model(nn.Module):
    """Graph network that labels every residue of a protein structure.

    ``_get_features`` converts padded coordinates into flat node/edge features
    and ``forward`` encodes them and decodes per-residue log-probabilities.
    """

    def __init__(self, args, **kwargs):
        """Build the feature embeddings, the structure encoder and the decoder(s).

        Args:
            args: configuration providing ``node_features``, ``edge_features``,
                ``hidden_dim``, ``dropout``, ``num_encoder_layers``,
                ``k_neighbors``, ``use_new_feat``, ``use_SGT`` and
                ``autoregressive``.
            **kwargs: accepted for interface compatibility and ignored.
        """
        super(AlphaDesign_Model, self).__init__()
        self.args = args
        node_features = args.node_features
        edge_features = args.edge_features
        hidden_dim = args.hidden_dim
        dropout = args.dropout
        num_encoder_layers = args.num_encoder_layers
        self.top_k = args.k_neighbors
        self.num_rbf = 16
        self.num_positional_embeddings = 16

        if args.use_new_feat:
            node_in, edge_in = 12, 16+7
        else:
            node_in, edge_in = 6, 16+7
        self.node_embedding = nn.Linear(node_in, node_features, bias=True)
        self.edge_embedding = nn.Linear(edge_in, edge_features, bias=True)
        self.norm_nodes = nn.BatchNorm1d(node_features)
        self.norm_edges = nn.BatchNorm1d(edge_features)

        self.W_v = nn.Sequential(
            nn.Linear(node_features, hidden_dim, bias=True),
            nn.LeakyReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim, bias=True),
            nn.LeakyReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim, bias=True)
        )

        self.W_e = nn.Linear(edge_features, hidden_dim, bias=True)
        self.W_f = nn.Linear(edge_features, hidden_dim, bias=True)

        self.encoder = StructureEncoder(hidden_dim, num_encoder_layers, dropout, use_SGT=self.args.use_SGT)

        if args.autoregressive:
            self.decoder = ATDecoder(args, hidden_dim, dropout)
        else:
            self.decoder = CNNDecoder(hidden_dim, hidden_dim)
            self.decoder2 = CNNDecoder2(hidden_dim, hidden_dim)

        self._init_params()

    def forward(self, batch, AT_test = False, return_logit=False):
        """Encode the graph and decode per-residue log-probabilities.

        Args:
            batch: dict with ``_V``, ``_E``, ``E_idx`` and ``batch_id`` as
                produced by ``_get_features``.
            AT_test: when true, use ``self.decoder.sampling`` instead of the
                two-stage decoder.
            return_logit: when true, return ``logits`` instead of ``log_probs0``.

        Returns:
            ``{'log_probs', 'logits'}`` if ``return_logit`` else
            ``{'log_probs', 'log_probs0'}``. ``logits`` and ``log_probs0`` are
            ``None`` on the sampling path, which does not produce them.
        """
        h_V, h_P, P_idx, batch_id = batch['_V'], batch['_E'], batch['E_idx'], batch['batch_id']
        h_V = self.W_v(self.norm_nodes(self.node_embedding(h_V)))
        h_P = self.W_e(self.norm_edges(self.edge_embedding(h_P)))

        h_V = self.encoder(h_V, h_P, P_idx, batch_id)
        log_probs0 = None
        # Defined up front so return_logit=True also works on the sampling path.
        logits = None
        if AT_test:
            log_probs = self.decoder.sampling(h_V, h_P, P_idx, batch_id)
        else:
            # NOTE(review): decoder2 is only created when args.autoregressive
            # is false, so this branch assumes the non-autoregressive setup.
            log_probs0, logits = self.decoder(h_V, batch_id)
            log_probs, logits = self.decoder2(h_V, logits, batch_id)
        if return_logit:
            return {'log_probs': log_probs, 'logits': logits}
        else:
            return {'log_probs': log_probs, 'log_probs0': log_probs0}

    def _init_params(self):
        """Xavier-initialise every parameter that has more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def _full_dist(self, X, mask, top_k=30, eps=1E-6):
        """Return the ``top_k`` nearest neighbours of every position.

        Args:
            X: coordinates of shape ``(B, N, 3)``.
            mask: ``(B, N)`` tensor with 1 for real positions and 0 for padding.
            top_k: number of neighbours to keep (capped at ``N``).
            eps: added under the square root for numerical stability.

        Returns:
            ``(D_neighbors, E_idx)``, both ``(B, N, min(top_k, N))``: the
            neighbour distances in ascending order and their indices.
        """
        mask_2D = torch.unsqueeze(mask,1) * torch.unsqueeze(mask,2)
        dX = torch.unsqueeze(X,1) - torch.unsqueeze(X,2)
        # Pairs involving padding get a large distance so they sort last.
        D = (1. - mask_2D)*10000 + mask_2D* torch.sqrt(torch.sum(dX**2, 3) + eps)

        D_max, _ = torch.max(D, -1, keepdim=True)
        D_adjust = D + (1. - mask_2D) * (D_max+1)
        D_neighbors, E_idx = torch.topk(D_adjust, min(top_k, D_adjust.shape[-1]), dim=-1, largest=False)
        return D_neighbors, E_idx

    def _get_features(self, batch):
        """Turn a padded batch into flat graph features, updating ``batch`` in place.

        Args:
            batch: dict with ``S``, ``score`` (may be ``None``), ``X`` of shape
                ``(B, N, atoms, 3)`` and ``mask`` of shape ``(B, N)``.

        Returns:
            The same dict with ``X``, ``S``, ``score`` and ``mask`` flattened
            to the unmasked residues and ``_V``, ``_E``, ``E_idx`` and
            ``batch_id`` added.
        """
        S, score, X, mask = batch['S'], batch['score'], batch['X'], batch['mask']
        mask_bool = (mask==1)

        B, N, _,_ = X.shape
        # Neighbours are searched on atom index 1 (named X_ca, presumably C-alpha).
        X_ca = X[:,:,1,:]
        D_neighbors, E_idx = self._full_dist(X_ca, mask, self.top_k)

        # sequence
        S = torch.masked_select(S, mask_bool)
        if score is not None:
            score = torch.masked_select(score, mask_bool)

        # node feature
        _V = _dihedrals(X)
        if not self.args.use_new_feat:
            _V = _V[...,:6]
        _V = torch.masked_select(_V, mask_bool.unsqueeze(-1)).reshape(-1,_V.shape[-1])

        # edge feature
        _E = torch.cat((_rbf(D_neighbors, self.num_rbf), _orientations_coarse_gl(X, E_idx)), -1) # e.g. [4,387,387,23]
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1) # mask of the first-order neighbour nodes: 1 = node exists, 0 = node does not exist
        mask_attend = (mask.unsqueeze(-1) * mask_attend) == 1 # the node's own mask times its neighbours' mask
        _E = torch.masked_select(_E, mask_attend.unsqueeze(-1)).reshape(-1,_E.shape[-1])

        # edge index: shift per-chain indices by the number of real residues
        # in all preceding chains so they address the flattened node list
        shift = mask.sum(dim=1).cumsum(dim=0) - mask.sum(dim=1)
        src = shift.view(B,1,1) + E_idx
        src = torch.masked_select(src, mask_attend).view(1,-1)
        dst = shift.view(B,1,1) + torch.arange(0, N, device=src.device).view(1,-1,1).expand_as(mask_attend)
        dst = torch.masked_select(dst, mask_attend).view(1,-1)
        E_idx = torch.cat((dst, src), dim=0).long()

        # 3D point
        sparse_idx = mask.nonzero()
        X = X[sparse_idx[:,0],sparse_idx[:,1],:,:]
        batch_id = sparse_idx[:,0]

        mask = torch.masked_select(mask, mask_bool)

        batch.update({'X':X,
                      'S':S,
                      'score':score,
                      '_V':_V,
                      '_E':_E,
                      'E_idx':E_idx,
                      'batch_id': batch_id,
                      'mask':mask})

        return batch
|
Flexpert-Design/src/models/anm_prottrans.py
ADDED
|
@@ -0,0 +1,677 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#import dependencies
|
| 2 |
+
import os.path
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
|
| 10 |
+
import re
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import copy
|
| 14 |
+
import pdb
|
| 15 |
+
|
| 16 |
+
import transformers, datasets
|
| 17 |
+
from transformers.modeling_outputs import TokenClassifierOutput, BaseModelOutputWithPastAndCrossAttentions
|
| 18 |
+
from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack
|
| 19 |
+
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
|
| 20 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
| 21 |
+
from transformers import TrainingArguments, Trainer, set_seed
|
| 22 |
+
from safetensors import safe_open
|
| 23 |
+
|
| 24 |
+
#DataCollator
|
| 25 |
+
from transformers.data.data_collator import DataCollatorMixin
|
| 26 |
+
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
| 27 |
+
from transformers.utils import PaddingStrategy
|
| 28 |
+
|
| 29 |
+
import random
|
| 30 |
+
import warnings
|
| 31 |
+
from collections.abc import Mapping
|
| 32 |
+
from dataclasses import dataclass
|
| 33 |
+
from random import randint
|
| 34 |
+
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
|
| 35 |
+
|
| 36 |
+
from evaluate import load
|
| 37 |
+
from datasets import Dataset
|
| 38 |
+
|
| 39 |
+
from tqdm import tqdm
|
| 40 |
+
import random
|
| 41 |
+
|
| 42 |
+
from scipy import stats
|
| 43 |
+
from sklearn.metrics import accuracy_score
|
| 44 |
+
|
| 45 |
+
import matplotlib.pyplot as plt
|
| 46 |
+
|
| 47 |
+
from Bio import SeqIO
|
| 48 |
+
from io import StringIO
|
| 49 |
+
import requests
|
| 50 |
+
import tempfile
|
| 51 |
+
|
| 52 |
+
from sklearn.model_selection import train_test_split
|
| 53 |
+
import csv
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
#### UTILS
|
| 57 |
+
|
| 58 |
+
class LoRAConfig:
    """Fixed hyper-parameters consumed by ``modify_with_lora``.

    ``lora_modules`` and ``lora_layers`` are regular expressions that are
    matched in full (``re.fullmatch``) against module names and child-layer
    names respectively; see https://www.w3schools.com/python/python_regex.asp
    for a regex reference.
    """

    def __init__(self):
        # Rank of the additive low-rank update and scale of its random init.
        self.lora_rank: int = 4
        self.lora_init_scale: float = 0.01
        # Which modules, and which layers inside them, get wrapped.
        self.lora_modules: str = ".*SelfAttention|.*EncDecAttention"
        self.lora_layers: str = "q|k|v|o"
        # Regex naming the parameters meant to stay trainable.
        self.trainable_param_names: str = ".*layer_norm.*|.*lora_[ab].*"
        # Rank of the multiplicative (scaling) update.
        self.lora_scaling_rank: int = 1
|
| 68 |
+
|
| 69 |
+
class LoRALinear(nn.Module):
    """Linear layer that reuses an existing layer's weights and adds LoRA terms.

    The effective weight is
    ``W * (multi_lora_b @ multi_lora_a) / scaling_rank + (lora_b @ lora_a) / rank``,
    where the multiplicative part exists only if ``scaling_rank`` is non-zero
    and the additive part only if ``rank`` is positive.
    """

    def __init__(self, linear_layer, rank, scaling_rank, init_scale):
        """Wrap ``linear_layer`` and create the low-rank parameters.

        Args:
            linear_layer: the ``nn.Linear`` whose ``weight`` and ``bias`` are shared.
            rank: rank of the additive update; 0 disables it.
            scaling_rank: rank of the multiplicative update; 0 disables it.
            init_scale: scale of the random initialisation. A negative value
                also randomises the ``*_b`` factors; otherwise ``lora_b``
                starts at zero and ``multi_lora_b`` at one.
        """
        super().__init__()
        n_in, n_out = linear_layer.in_features, linear_layer.out_features
        self.in_features = n_in
        self.out_features = n_out
        self.rank = rank
        self.scaling_rank = scaling_rank
        self.weight = linear_layer.weight
        self.bias = linear_layer.bias

        if self.rank > 0:
            self.lora_a = nn.Parameter(torch.randn(rank, n_in) * init_scale)
            if init_scale < 0:
                b_init = torch.randn(n_out, rank) * init_scale
            else:
                b_init = torch.zeros(n_out, rank)
            self.lora_b = nn.Parameter(b_init)

        if self.scaling_rank:
            a_noise = torch.randn(self.scaling_rank, n_in) * init_scale
            self.multi_lora_a = nn.Parameter(torch.ones(self.scaling_rank, n_in) + a_noise)
            if init_scale < 0:
                b_noise = torch.randn(n_out, self.scaling_rank) * init_scale
                b_scale = torch.ones(n_out, self.scaling_rank) + b_noise
            else:
                b_scale = torch.ones(n_out, self.scaling_rank)
            self.multi_lora_b = nn.Parameter(b_scale)

    def forward(self, input):
        """Apply the adapted linear map to ``input``."""
        if self.scaling_rank == 1 and self.rank == 0:
            # Parsimonious path for IA3 / pure LoRA scaling: rescale the
            # activations instead of materialising a full weight matrix.
            scaled = input * self.multi_lora_a.flatten() if self.multi_lora_a.requires_grad else input
            hidden = F.linear(scaled, self.weight, self.bias)
            if self.multi_lora_b.requires_grad:
                hidden = hidden * self.multi_lora_b.flatten()
            return hidden

        # General path: build the adapted weight (scaling first, then adding).
        adapted = self.weight
        if self.scaling_rank:
            adapted = (adapted * (self.multi_lora_b @ self.multi_lora_a)) / self.scaling_rank
        if self.rank:
            adapted = adapted + (self.lora_b @ self.lora_a) / self.rank
        return F.linear(input, adapted, self.bias)

    def extra_repr(self):
        """Return the layer summary shown in the module's ``repr``."""
        has_bias = self.bias is not None
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={has_bias}, rank={self.rank}, scaling_rank={self.scaling_rank}"
        )
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def modify_with_lora(transformer, config):
    """Swap selected linear layers of ``transformer`` for ``LoRALinear``.

    Every module whose qualified name fully matches ``config.lora_modules``
    is inspected; each of its direct children whose name fully matches
    ``config.lora_layers`` is replaced in place by a ``LoRALinear`` built from
    ``config.lora_rank``, ``config.lora_scaling_rank`` and
    ``config.lora_init_scale``.

    Returns:
        The same ``transformer`` object, modified in place.

    Raises:
        AssertionError: if a selected child is not an ``nn.Linear``.
    """
    # Snapshot the module list first: children are replaced while walking.
    for module_name, module in list(transformer.named_modules()):
        if not re.fullmatch(config.lora_modules, module_name):
            continue
        for child_name, child in list(module.named_children()):
            if not re.fullmatch(config.lora_layers, child_name):
                continue
            assert isinstance(
                child, nn.Linear
            ), f"LoRA can only be applied to torch.nn.Linear, but {child} is {type(child)}."
            adapted = LoRALinear(child, config.lora_rank, config.lora_scaling_rank, config.lora_init_scale)
            setattr(module, child_name, adapted)
    return transformer
|
| 136 |
+
|
| 137 |
+
class ClassConfig:
    """Settings for the prediction head attached to the T5 encoder.

    Args:
        dropout: stored as ``dropout_rate``.
        num_labels: number of outputs per token.
        add_pearson_loss: flag stored on the config.
        add_sse_loss: flag stored on the config.
        adaptor_architecture: name of the ENM adaptor; the model in this file
            accepts ``'attention'``, ``'direct'``, ``'conv'`` and
            ``'no-adaptor'``.
        enm_embed_dim: embedding width handed to the ENM adaptors.
        enm_att_heads: number of heads handed to the attention adaptor.
        kernel_size: kernel size handed to the convolutional adaptor.
        num_layers: number of layers handed to the convolutional adaptor.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        dropout=0.2,
        num_labels=1,
        add_pearson_loss=False,
        add_sse_loss=False,
        adaptor_architecture=None,
        enm_embed_dim=512,
        enm_att_heads=8,
        kernel_size=3,
        num_layers=2,
        **kwargs,
    ):
        # Head-level settings.
        self.dropout_rate = dropout
        self.num_labels = num_labels

        # Loss switches.
        self.add_pearson_loss = add_pearson_loss
        self.add_sse_loss = add_sse_loss

        # Adaptor selection and its hyper-parameters.
        self.adaptor_architecture = adaptor_architecture
        self.enm_embed_dim = enm_embed_dim
        self.enm_att_heads = enm_att_heads
        self.kernel_size = kernel_size
        self.num_layers = num_layers
|
| 148 |
+
|
| 149 |
+
class ENMAdaptedAttentionClassifier(nn.Module):
    """Head that runs self-attention over ENM values before classifying.

    Each scalar ENM value is embedded, passed through multi-head
    self-attention, layer-normalised, projected to the sequence embedding
    width and concatenated with the sequence embedding; a final linear layer
    produces the logits.

    Args:
        seq_embedding_dim: width of the sequence embedding.
        out_dim: number of outputs per position.
        enm_embed_dim: width of the ENM embedding / attention.
        num_att_heads: number of attention heads.
    """

    def __init__(self, seq_embedding_dim, out_dim, enm_embed_dim, num_att_heads):
        super().__init__()
        self.embedding = nn.Linear(1, enm_embed_dim)
        self.enm_attention = nn.MultiheadAttention(enm_embed_dim, num_att_heads)
        self.layer_norm = nn.LayerNorm(enm_embed_dim)
        self.enm_adaptor = nn.Linear(enm_embed_dim, seq_embedding_dim)
        self.adapted_classifier = nn.Linear(2 * seq_embedding_dim, out_dim)

    def forward(self, seq_embedding, enm_input, attention_mask=None):
        """Return logits for every position.

        Args:
            seq_embedding: sequence embedding whose last dimension is
                ``seq_embedding_dim``.
            enm_input: ENM values; treated as a 2-D (batch, length) tensor.
            attention_mask: accepted so that every adaptor head can be called
                as ``classifier(seq_embedding, enm_input, attention_mask)``;
                it is not used by this head.
        """
        # NOTE(review): padded positions still take part in the ENM
        # self-attention because the mask is ignored here.
        del attention_mask

        # MultiheadAttention is used in its default (length, batch, embed) layout.
        enm_input = enm_input.transpose(0, 1).unsqueeze(-1)
        enm_embedded = self.embedding(enm_input)
        enm_att, _ = self.enm_attention(enm_embedded, enm_embedded, enm_embedded)

        # Back to (batch, length, embed); the residual term is the raw ENM
        # value, broadcast across the embedding dimension.
        enm_att = enm_att.transpose(0, 1)
        enm_att = self.layer_norm(enm_att + enm_input.transpose(0, 1))

        enm_embedding = self.enm_adaptor(enm_att)
        combined_embedding = torch.cat((seq_embedding, enm_embedding), dim=-1)
        return self.adapted_classifier(combined_embedding)
|
| 169 |
+
|
| 170 |
+
class ENMAdaptedConvClassifier(nn.Module):
    """Head that summarises ENM values with a 1-D convolution stack.

    The ENM profile goes through ``num_layers`` ``Conv1d`` + ``ReLU`` stages,
    the result is averaged over the channel dimension to a single feature per
    position, and that feature is concatenated with the sequence embedding
    before the final linear layer.

    Args:
        seq_embedding_dim: width of the sequence embedding.
        out_dim: number of outputs per position.
        kernel_size: convolution kernel size. Padding is
            ``(kernel_size - 1) // 2``, which keeps the length unchanged for
            odd kernel sizes.
        enm_embedding_dim: number of convolution channels.
        num_layers: total number of convolution layers.
    """

    def __init__(self, seq_embedding_dim, out_dim, kernel_size, enm_embedding_dim, num_layers):
        super().__init__()
        padding = (kernel_size - 1) // 2
        # ``conv1`` stays a direct attribute as well as the first element of
        # ``conv_net`` so parameter names remain unchanged.
        self.conv1 = nn.Conv1d(1, enm_embedding_dim, kernel_size=kernel_size, padding=padding)
        layers = [self.conv1, nn.ReLU()]
        for _ in range(num_layers - 1):
            layers.append(nn.Conv1d(enm_embedding_dim, enm_embedding_dim, kernel_size=kernel_size, padding=padding))
            layers.append(nn.ReLU())
        self.conv_net = nn.Sequential(*layers)
        self.adapted_classifier = nn.Linear(seq_embedding_dim + 1, out_dim)

    def forward(self, seq_embedding, enm_input, attention_mask=None):
        """Return logits for every position.

        Args:
            seq_embedding: sequence embedding whose last dimension is
                ``seq_embedding_dim``.
            enm_input: ENM values, treated as (batch, length); NaNs are
                replaced with ``0.0``.
            attention_mask: mask with the same leading shape as ``enm_input``;
                positions where it is 0 get a zero ENM feature.

        Raises:
            ValueError: if ``attention_mask`` is ``None``.
        """
        if attention_mask is None:
            raise ValueError('We actually want to provide the mask.')

        enm_input = torch.nan_to_num(enm_input, nan=0.0)
        # (batch, length) -> (batch, 1, length) for Conv1d, then back to
        # (batch, length, channels).
        conv_out = self.conv_net(enm_input.unsqueeze(1))
        enm_embedding = conv_out.transpose(1, 2)

        # Zero the features of padded positions.
        mask = attention_mask.unsqueeze(-1).float()
        enm_embedding = enm_embedding * mask

        # Collapse the channel dimension into one feature per position.
        enm_embedding = enm_embedding.mean(dim=-1, keepdim=True)

        combined_embedding = torch.cat((seq_embedding, enm_embedding), dim=-1)
        return self.adapted_classifier(combined_embedding)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class ENMAdaptedDirectClassifier(nn.Module):
    """Head that appends the raw ENM value to every sequence embedding.

    Args:
        seq_embedding_dim: width of the sequence embedding.
        out_dim: number of outputs per position.
    """

    def __init__(self, seq_embedding_dim, out_dim):
        super().__init__()
        self.adapted_classifier = nn.Linear(seq_embedding_dim + 1, out_dim)

    def forward(self, seq_embedding, enm_input, attention_mask=None):
        """Return logits for every position.

        Args:
            seq_embedding: sequence embedding whose last dimension is
                ``seq_embedding_dim``.
            enm_input: ENM values with one fewer dimension than
                ``seq_embedding``; a trailing feature axis is added here.
            attention_mask: accepted so that every adaptor head can be called
                as ``classifier(seq_embedding, enm_input, attention_mask)``;
                it is not used by this head.
        """
        del attention_mask
        enm_feature = enm_input.unsqueeze(-1)
        combined_embedding = torch.cat((seq_embedding, enm_feature), dim=-1)
        return self.adapted_classifier(combined_embedding)
|
| 218 |
+
|
| 219 |
+
class ENMNoAdaptorClassifier(nn.Module):
    """Plain linear head that ignores the ENM values.

    Args:
        seq_embedding_dim: width of the sequence embedding.
        out_dim: number of outputs per position.
    """

    def __init__(self, seq_embedding_dim, out_dim):
        super().__init__()
        self.adapted_classifier = nn.Linear(seq_embedding_dim, out_dim)

    def forward(self, seq_embedding, enm_input, attention_mask=None):
        """Return logits computed from ``seq_embedding`` alone.

        ``enm_input`` and ``attention_mask`` are accepted only so that this
        head can be called like the other adaptor heads, i.e.
        ``classifier(seq_embedding, enm_input, attention_mask)``.
        """
        del enm_input, attention_mask  # intentionally unused
        return self.adapted_classifier(seq_embedding)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class T5EncoderForTokenClassification(T5PreTrainedModel):
    """T5 encoder followed by an ENM-aware per-token prediction head.

    The head is chosen from ``class_config.adaptor_architecture``
    (``'attention'``, ``'direct'``, ``'conv'`` or ``'no-adaptor'``).
    ``new_embedding`` is a bias-free linear layer with the dimensions of the
    token embedding matrix; it is created here without meaningful weights and
    is expected to be filled in by the caller.
    """

    def __init__(self, config: T5Config, class_config):
        super().__init__(config)
        self.num_labels = class_config.num_labels
        self.config = config
        self.add_pearson_loss = class_config.add_pearson_loss
        self.add_sse_loss = class_config.add_sse_loss
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        original_embedding = self.encoder.embed_tokens
        in_dim, out_dim = tuple(original_embedding.weight.shape)
        # TODO: pass the correct weights. The embedding matrix and the linear
        # weight are transposed with respect to each other.
        self.new_embedding = nn.Linear(in_dim, out_dim, bias=False)
        # Keep the original placement on GPU hosts, but do not crash at
        # construction time when no CUDA device exists.
        if torch.cuda.is_available():
            self.new_embedding = self.new_embedding.to('cuda:0')
        print("Initialized new_embedding layer - without weights yet!")

        self.dropout = nn.Dropout(class_config.dropout_rate)
        architecture = class_config.adaptor_architecture
        if architecture == 'attention':
            self.classifier = ENMAdaptedAttentionClassifier(
                config.hidden_size,
                class_config.num_labels,
                class_config.enm_embed_dim,
                class_config.enm_att_heads,
            )
        elif architecture == 'direct':
            self.classifier = ENMAdaptedDirectClassifier(config.hidden_size, class_config.num_labels)
        elif architecture == 'conv':
            self.classifier = ENMAdaptedConvClassifier(
                config.hidden_size,
                class_config.num_labels,
                class_config.kernel_size,
                class_config.enm_embed_dim,
                class_config.num_layers,
            )
        elif architecture == 'no-adaptor':
            self.classifier = ENMNoAdaptorClassifier(config.hidden_size, class_config.num_labels)
        else:
            raise ValueError('Only attention, direct, conv and no-adaptor architectures are supported for the adaptor.')

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def parallelize(self, device_map=None):
        """Spread the encoder blocks over devices and move the head along.

        Args:
            device_map: block-to-device mapping; when ``None`` one is built
                with ``get_device_map`` over all visible CUDA devices.
        """
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.classifier = self.classifier.to(self.encoder.first_device)
        self.model_parallel = True

    def deparallelize(self):
        """Undo ``parallelize`` and move the encoder back to the CPU."""
        self.encoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        """Return the shared token embedding."""
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        """Replace the shared token embedding on the model and its encoder."""
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)

    def get_encoder(self):
        """Return the encoder stack."""
        return self.encoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # NOTE(review): other methods of this class address encoder blocks as
        # ``self.encoder.block``; confirm that ``self.encoder.layer`` exists.
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(
        self,
        enm_vals=None,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode the input and predict one value per token.

        Exactly one of ``inputs_embeds`` / ``input_ids`` is used;
        ``inputs_embeds`` wins when both are given. ``labels`` is accepted
        but no loss is computed here.

        Returns:
            A ``TokenClassifierOutput`` without a loss, or -- when
            ``return_dict`` resolves to false -- a tuple starting with the
            logits.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is
                provided.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None and input_ids is None:
            raise ValueError('Either input_ids or inputs_embeds has to be provided.')
        if inputs_embeds is not None:
            # Embeddings take precedence over token ids.
            input_ids = None

        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        # TODO: check everything is in eval mode so the dropout below is off.
        sequence_output = self.dropout(sequence_output)
        # TODO: check the enm_vals are padded properly and that the sequence
        # limit (in the transformer) is indeed 512.
        logits = self.classifier(sequence_output, enm_vals, attention_mask)

        # The loss is computed outside the model, so there is never one to
        # report; the name still has to exist for the tuple branch below.
        loss = None
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
| 347 |
+
|
| 348 |
+
class ENMAdaptedTrainer(Trainer):
    """``Trainer`` computing a masked mean-squared-error loss per token."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the MSE between logits and labels over valid positions.

        A position counts as valid when its ``attention_mask`` entry is 1 and
        its label differs from ``-100``.

        Args:
            model: model called as ``model(**inputs)``; its output must
                expose ``logits``.
            inputs: batch mapping containing ``labels`` and
                ``attention_mask``.
            return_outputs: when true, return ``(loss, outputs)``.
            num_items_in_batch: accepted for compatibility with ``Trainer``
                versions that pass it to ``compute_loss``; unused.
        """
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get('logits')
        mask = inputs.get('attention_mask')

        flat_labels = labels.view(-1)
        flat_logits = logits.view(-1)
        # Attended positions that also carry a real (non -100) label.
        valid = (mask.view(-1) == 1) & (flat_labels != -100)

        loss = MSELoss()(flat_logits[valid], flat_labels[valid])
        return (loss, outputs) if return_outputs else loss
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def PT5_classification_model(half_precision, class_config):
    """Build the LoRA-adapted ProtT5 token-level model and its tokenizer.

    Steps: load the pretrained ProtT5 encoder (full or half precision), wrap
    it in ``T5EncoderForTokenClassification``, inject LoRA layers, freeze the
    embedding and the encoder, then re-enable the parameters whose names
    match ``LoRAConfig.trainable_param_names``. The trainable parameter count
    is printed before and after the LoRA modification.

    Args:
        half_precision: load the fp16 encoder-only checkpoint (GPU only).
        class_config: configuration of the prediction head.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        ValueError: if ``half_precision`` is requested without CUDA.
    """
    if half_precision and not torch.cuda.is_available():
        raise ValueError('Half precision can be run on GPU only.')

    # Half-precision checkpoint (thanks to @pawel-rezo for pointing it out).
    if half_precision:
        tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
        checkpoint = T5EncoderModel.from_pretrained(
            "Rostlab/prot_t5_xl_half_uniref50-enc", torch_dtype=torch.float16
        ).to(torch.device('cuda'))
    else:
        checkpoint = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
        tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50")

    def count_trainable(network):
        return sum([np.prod(p.size()) for p in network.parameters() if p.requires_grad])

    # New model with the checkpoint's dimensions, reusing its weights.
    model = T5EncoderForTokenClassification(checkpoint.config, class_config)
    model.shared = checkpoint.shared
    model.encoder = checkpoint.encoder
    del checkpoint

    print("ProtT5_Classfier\nTrainable Parameter: " + str(count_trainable(model)))

    # Inject the LoRA layers.
    config = LoRAConfig()
    model = modify_with_lora(model, config)

    # Freeze embeddings and encoder ...
    for frozen_part in (model.shared, model.encoder):
        for param in frozen_part.parameters():
            param.requires_grad = False

    # ... then unfreeze what the LoRA config marks as trainable.
    for param_name, param in model.named_parameters():
        if re.fullmatch(config.trainable_param_names, param_name):
            param.requires_grad = True

    print("ProtT5_LoRA_Classfier\nTrainable Parameter: " + str(count_trainable(model)) + "\n")

    return model, tokenizer
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
@dataclass
class DataCollatorForTokenRegression(DataCollatorMixin):
    """
    Collator that pads tokenizer inputs, per-token float labels and ENM values.

    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
            Padding strategy forwarded to `tokenizer.pad`:

            - `True` or `'longest'` (default): pad to the longest sequence in the batch (or no padding if only a
              single sequence is provided).
            - `'max_length'`: pad to `max_length`, or to the maximum acceptable input length for the model if that
              argument is not provided.
            - `False` or `'do_not_pad'`: no padding (i.e., can output a batch with sequences of different lengths).
        max_length (`int`, *optional*):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (`int`, *optional*):
            If set will pad the sequence to a multiple of the provided value.
        label_pad_token_id (`int`, *optional*, defaults to -100):
            The value used when padding the labels.
        return_tensors (`str`):
            The type of Tensor to return. Only `torch_call` is defined in this class.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"

    def torch_call(self, features):
        """Collate ``features`` (a list of dicts) into one padded batch.

        ``enm_vals`` are padded with ``0.0`` to the longest ``enm_vals`` entry
        of the batch. Labels, when present under ``"label"`` or ``"labels"``,
        are padded with ``label_pad_token_id`` to the padded input length on
        the tokenizer's padding side and returned as a float tensor.
        """
        first = features[0]
        label_name = "label" if "label" in first else "labels"
        labels = [feature[label_name] for feature in features] if label_name in first else None

        # The tokenizer only pads what it knows about.
        held_back = (label_name, 'enm_vals')
        tokenizer_features = [
            {key: value for key, value in feature.items() if key not in held_back} for feature in features
        ]

        batch = self.tokenizer.pad(
            tokenizer_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        enm_tensors = [torch.tensor(feature['enm_vals'], dtype=torch.float) for feature in features]
        batch['enm_vals'] = torch.nn.utils.rnn.pad_sequence(enm_tensors, batch_first=True, padding_value=0.0)

        if labels is None:
            return batch

        sequence_length = batch["input_ids"].shape[1]
        pad_on_right = self.tokenizer.padding_side == "right"

        padded_labels = []
        for label in labels:
            values = label.tolist() if isinstance(label, torch.Tensor) else list(label)
            filler = [self.label_pad_token_id] * (sequence_length - len(label))
            padded_labels.append(values + filler if pad_on_right else filler + values)

        batch[label_name] = torch.tensor(padded_labels, dtype=torch.float)
        return batch
|
| 496 |
+
|
| 497 |
+
def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
|
| 498 |
+
"""Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
|
| 499 |
+
# Tensorize if necessary.
|
| 500 |
+
if isinstance(examples[0], (list, tuple, np.ndarray)):
|
| 501 |
+
examples = [torch.tensor(e, dtype=torch.long) for e in examples]
|
| 502 |
+
|
| 503 |
+
length_of_first = examples[0].size(0)
|
| 504 |
+
|
| 505 |
+
# Check if padding is necessary.
|
| 506 |
+
|
| 507 |
+
are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
|
| 508 |
+
if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
|
| 509 |
+
return torch.stack(examples, dim=0)
|
| 510 |
+
|
| 511 |
+
# If yes, check if we have a `pad_token`.
|
| 512 |
+
if tokenizer._pad_token is None:
|
| 513 |
+
raise ValueError(
|
| 514 |
+
"You are attempting to pad samples but the tokenizer you are using"
|
| 515 |
+
f" ({tokenizer.__class__.__name__}) does not have a pad token."
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
# Creating the full tensor and filling it with our data.
|
| 519 |
+
max_length = max(x.size(0) for x in examples)
|
| 520 |
+
if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
| 521 |
+
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
| 522 |
+
result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
|
| 523 |
+
for i, example in enumerate(examples):
|
| 524 |
+
if tokenizer.padding_side == "right":
|
| 525 |
+
result[i, : example.shape[0]] = example
|
| 526 |
+
else:
|
| 527 |
+
result[i, -example.shape[0] :] = example
|
| 528 |
+
return result
|
| 529 |
+
|
| 530 |
+
def tolist(x):
    """Return ``x`` as a Python list.

    Lists are returned unchanged. Objects exposing ``numpy()`` are converted
    through it first (this covers TF tensors without importing TF); anything
    else must provide ``tolist()`` itself.
    """
    if isinstance(x, list):
        return x
    array_like = x.numpy() if hasattr(x, "numpy") else x
    return array_like.tolist()
|
| 536 |
+
|
| 537 |
+
#### END OF UTILS
|
| 538 |
+
|
| 539 |
+
def do_topology_split(df, split_path):
    """Split ``df`` into train / validation / test frames by entry name.

    Args:
        df: data frame with a ``name`` column.
        split_path: path to a JSON file with ``train``, ``validation`` and
            ``test`` lists of names.

    Returns:
        ``(train_df, valid_df, test_df)`` -- the rows of ``df`` whose ``name``
        appears in the corresponding list.
    """
    import json

    with open(split_path, 'r') as handle:
        splits = json.load(handle)

    names = df['name']
    return tuple(df[names.isin(splits[part])] for part in ('train', 'validation', 'test'))
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
class ANMAwareFlexibilityProtTrans(nn.Module):
    """Differentiable flexibility predictor driven by sequence-design logits.

    Design logits over a 33-token vocabulary are turned into one-hot vectors
    with the straight-through Gumbel-Softmax (argmax in the forward pass,
    Gumbel-Softmax gradient in the backward pass), translated to the
    128-token ProtTrans vocabulary, embedded through the model's linear
    ``new_embedding`` layer and passed, together with the ANM/ENM values, to
    the fine-tuned token-level model.

    Args:
        gumbel_temperature: ``tau`` used by the Gumbel-Softmax.
        **kwargs: forwarded to ``load_finetuned_model``.
    """

    def __init__(self, gumbel_temperature, **kwargs):
        super().__init__()

        model, tokenizer = self.load_finetuned_model(**kwargs)
        self.model = model
        self.tokenizer = tokenizer
        self.device = self._default_device()
        self.model.to(self.device)
        self.model.eval()
        self.gumbel_temperature = gumbel_temperature
        # Straight-through Gumbel-Softmax; ``hard=True`` is passed at call time.
        # https://pytorch.org/docs/stable/generated/torch.nn.functional.gumbel_softmax.html
        self.logit_transform = nn.functional.gumbel_softmax

        self.conversion_tensor = self.construct_pmpnn_t5_conversion_tensor()

    @staticmethod
    def _default_device():
        """Return the CUDA device when one is available, otherwise the CPU."""
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def construct_pmpnn_t5_conversion_tensor(self):
        """
        Creates tensor which takes the onehot encodings in the proteinmpnn vocabulary and maps them to ProtTrans vocabulary.

        The result has shape (128, 33): row ``t`` holds ones in the columns of
        the source tokens that translate to ProtTrans token ``t``; rows 28-127
        are all zeros.
        """
        # Position in this tuple = ProtTrans token id; value = source token
        # id(s) that map onto it.
        source_ids_per_target = (
            [0, 1], 2, [3, 29, 30, 31, 32], 5, 4, 6, 7, 8, 10, 9, 13, 11, 12, 14,
            15, 18, 16, 17, 19, 20, 21, 22, 23, 24, 25, 28, 26, 27,
        )
        conversion = torch.zeros((len(source_ids_per_target) + 100, 33))
        for target_id, source_ids in enumerate(source_ids_per_target):
            conversion[target_id, source_ids] = 1.0
        return conversion.to(self._default_device()).float()

    def load_finetuned_model(self, checkpoint_path, half_precision, **kwargs):
        """Build the model, load the fine-tuned weights and set ``new_embedding``.

        Args:
            checkpoint_path: file readable by ``torch.load`` holding a state dict.
            half_precision: forwarded to ``PT5_classification_model``.
            **kwargs: forwarded to ``ClassConfig``.

        Returns:
            ``(model, tokenizer)`` with the model in eval mode.
        """
        class_config = ClassConfig(**kwargs)
        model, tokenizer = PT5_classification_model(half_precision=half_precision, class_config=class_config)

        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        map_location = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        state_dict = torch.load(checkpoint_path, map_location=map_location)
        model.load_state_dict(state_dict, strict=False)
        model.eval()

        # The linear layer must reproduce the embedding lookup for one-hot
        # inputs, hence the transposed embedding matrix.
        original_embedding = model.encoder.embed_tokens
        model.new_embedding.weight = nn.Parameter(original_embedding.weight.T)
        print('Set the weights for the new embedding layer!')
        return model, tokenizer

    def translate_to_model_vocab(self, batch_one_hot, trail_idcs):
        """Translate one-hot tokens to the ProtTrans vocabulary.

        Args:
            batch_one_hot: tensor of shape (batch, 33, length).
            trail_idcs: per-sample position that receives source token ``2``;
                indices beyond the padded length are skipped.

        Returns:
            Tensor of shape (batch, length + 1, 128).
        """
        # Append one empty position at the end of every sequence.
        batch_one_hot = F.pad(batch_one_hot, (0, 1, 0, 0, 0, 0), 'constant', 0)

        # TODO: verify that the gradients are OK after the masked_scatter operation.
        # Select, per sample, the whole vocabulary column at the trailing position.
        mask = torch.zeros_like(batch_one_hot, dtype=torch.bool)
        for sample, trail_idx in enumerate(trail_idcs):
            if trail_idx < batch_one_hot.size(2):  # Ensure index is within bounds
                mask[sample, :, trail_idx] = True

        # One-hot encoding of source token 2 at every position.
        token_2 = torch.zeros_like(batch_one_hot)
        token_2[:, 2, :] = 1

        # Overwrite the selected columns in place.
        batch_one_hot.masked_scatter_(mask, token_2[mask])

        T5_translation = torch.einsum('ej,ijk->iek', self.conversion_tensor, batch_one_hot)
        return T5_translation.permute(0, 2, 1)

    def forward(self, pmpnn_logits, anm_input, trail_idcs, attention_mask, sampled_pmpnn_sequence=None, alphabet=None):
        """Predict flexibility from design logits.

        Args:
            pmpnn_logits: logits with the vocabulary on dim 1, e.g.
                32x33x395 (batch_size x vocab size x seq length).
            anm_input: ANM values; one ``0`` is appended along the last dim.
            trail_idcs: per-sample index of the trailing token position.
            attention_mask: mask; one ``1`` is appended along the last dim.
            sampled_pmpnn_sequence: not supported yet, must be ``None``.
            alphabet: not supported yet, must be ``None``.

        Returns:
            Dict with ``predicted_flex`` (model logits), ``enm_vals`` (padded
            ANM input) and ``input_ids`` (translated one-hot inputs).

        Raises:
            NotImplementedError: if ``sampled_pmpnn_sequence`` or ``alphabet``
                is given.
        """
        if sampled_pmpnn_sequence is not None or alphabet is not None:
            raise NotImplementedError(
                'Only the logits path is implemented: sampled_pmpnn_sequence and alphabet must be None.'
            )

        # Make room for the trailing token appended in translate_to_model_vocab.
        anm_input = F.pad(anm_input, (0, 1, 0, 0), 'constant', 0)
        attention_mask = F.pad(attention_mask, (0, 1, 0, 0), 'constant', 1)

        batch_one_hot = self.logit_transform(pmpnn_logits, tau=self.gumbel_temperature, hard=True, dim=1)
        inputs = self.translate_to_model_vocab(batch_one_hot, trail_idcs)

        # A linear layer replaces the embedding lookup so gradients can flow
        # through the one-hot inputs.
        inputs_embeds = self.model.new_embedding(inputs)
        outputs = self.model(enm_vals=anm_input, inputs_embeds=inputs_embeds, attention_mask=attention_mask)

        predicted_flex = outputs.logits
        return {'predicted_flex': predicted_flex, 'enm_vals': anm_input, 'input_ids': inputs}
|