Honzus24 committed on
Commit 7968cb0 · 1 Parent(s): 99ed7c2

initial commit

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +0 -34
  2. .gitignore +3 -0
  3. .gitmodules +3 -0
  4. .gradio/certificate.pem +31 -0
  5. Flexpert-Design/README.md +69 -0
  6. Flexpert-Design/configs/ANMAwareFlexibilityProtTrans.yaml +12 -0
  7. Flexpert-Design/configs/Flexpert-Design-inference.yaml +1 -0
  8. Flexpert-Design/configs/ProteinMPNN.py +14 -0
  9. Flexpert-Design/configs/ProteinMPNN.yaml +13 -0
  10. Flexpert-Design/data_interface.py +205 -0
  11. Flexpert-Design/data_utils.py +535 -0
  12. Flexpert-Design/download-cath-data.sh +17 -0
  13. Flexpert-Design/model_interface.py +631 -0
  14. Flexpert-Design/predict.py +148 -0
  15. Flexpert-Design/predict_example/1ah7_A.pdb +0 -0
  16. Flexpert-Design/predict_example/1ah7_A_instructions.csv +1 -0
  17. Flexpert-Design/predict_example/compare_seqs.py +59 -0
  18. Flexpert-Design/predict_example/predictions.txt +2 -0
  19. Flexpert-Design/requirements.txt +23 -0
  20. Flexpert-Design/src/__init__.py +49 -0
  21. Flexpert-Design/src/datasets/__init__.py +15 -0
  22. Flexpert-Design/src/datasets/alphafold_dataset.py +112 -0
  23. Flexpert-Design/src/datasets/atlas_dataset.py +133 -0
  24. Flexpert-Design/src/datasets/casp_dataset.py +57 -0
  25. Flexpert-Design/src/datasets/cath_dataset.py +141 -0
  26. Flexpert-Design/src/datasets/dataloader.py +161 -0
  27. Flexpert-Design/src/datasets/fast_dataloader.py +52 -0
  28. Flexpert-Design/src/datasets/featurizer.py +743 -0
  29. Flexpert-Design/src/datasets/flex_cath_dataset.py +155 -0
  30. Flexpert-Design/src/datasets/foldswitchers_dataset.py +128 -0
  31. Flexpert-Design/src/datasets/mpnn_dataset.py +492 -0
  32. Flexpert-Design/src/datasets/pdb_inference.py +329 -0
  33. Flexpert-Design/src/datasets/ts_dataset.py +47 -0
  34. Flexpert-Design/src/datasets/utils.py +99 -0
  35. Flexpert-Design/src/interface/__init__.py +0 -0
  36. Flexpert-Design/src/interface/data_interface.py +66 -0
  37. Flexpert-Design/src/interface/model_interface.py +89 -0
  38. Flexpert-Design/src/interface/pretrain_interface.py +405 -0
  39. Flexpert-Design/src/models/E3PiFold_model.py +90 -0
  40. Flexpert-Design/src/models/MemoryESM.py +164 -0
  41. Flexpert-Design/src/models/MemoryESMIF.py +116 -0
  42. Flexpert-Design/src/models/MemoryPiFold.py +143 -0
  43. Flexpert-Design/src/models/MemoryTuning.py +213 -0
  44. Flexpert-Design/src/models/PretrainESMIF_model.py +32 -0
  45. Flexpert-Design/src/models/PretrainESM_model.py +35 -0
  46. Flexpert-Design/src/models/PretrainPiFold_model.py +64 -0
  47. Flexpert-Design/src/models/Tuning.py +275 -0
  48. Flexpert-Design/src/models/__init__.py +16 -0
  49. Flexpert-Design/src/models/alphadesign_model.py +138 -0
  50. Flexpert-Design/src/models/anm_prottrans.py +677 -0
.gitattributes CHANGED
@@ -1,35 +1 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
  *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
+ models/weights/
+ Flexpert-Design/
+ data/atlas/
.gitmodules ADDED
@@ -0,0 +1,3 @@
+ [submodule "gradio_molecule3d"]
+ path = gradio_molecule3d
+ url = https://github.com/Honzus/gradio_molecule3d
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
+ -----BEGIN CERTIFICATE-----
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+ -----END CERTIFICATE-----
Flexpert-Design/README.md ADDED
@@ -0,0 +1,69 @@
+ # Flexpert-Design
+
+ In this directory we provide the code to train and run inference with Flexpert-Design. To expedite the release of the codebase, this part of the code was not thoroughly curated and contains redundant files and code. The codebase might be revised in the future, but more likely it will be completely rewritten as part of a future project with an improved model.
+
+ ## Environment
+
+ Tested with Python 3.9. For other versions, the environment might need to be adapted.
+
+ Assuming you have already installed the environment for Flexpert-3D and Flexpert-Seq, install the additional dependencies for Flexpert-Design using the `requirements.txt` file in this directory.
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ## Inference
+
+ This example illustrates how to run inference with the trained model (trained weights are provided inside the train/results directory, so you do not necessarily need to train the model again).
+
+ Place the PDB files you want to predict in the `predict_example` directory. The files are expected to be named `PDBCODE_CHAINID.pdb`; an example file `1ah7_A.pdb` is provided. For each PDB file in that folder, add the flexibility instructions you want the ProteinMPNN model to consider in the `PDBCODE_CHAINID_instructions.csv` file; an example file `1ah7_A_instructions.csv` is provided. Then run the following command to run inference.
+
+ ```bash
+ python3 predict.py \
+     --infer_path predict_example/
+ ```
+
+ The output will be in the `predict_example/predictions.txt` file.
+
+ The original sequence and the regenerated sequence can be compared using the following script.
+
+ ```bash
+ python3 predict_example/compare_seqs.py \
+     --pdb_code 1ah7_A
+ ```
+
+ ## Training
+
+ First make sure you have the Flexpert-3D model weights in the `Flexpert/models/weights` directory. Alternatively, run the following script to download the weights.
+
+ ```bash
+ . ../download_flexpert_weights.sh
+ ```
+
+ Download the training data:
+
+ ```bash
+ . ../download-cath-data.sh
+ ```
+
+ Then run the following command to train the model.
+
+ ```bash
+ export HF_HOME=./HF_cache
+ python3 train.py \
+     --batch_size 4 \
+     --model_name 'ProteinMPNN' \
+     --stage 'fit' \
+     --dataset FLEX_CATH4.3 \
+     --ex_name training-reproduction \
+     --offline 0 \
+     --gpus 1 \
+     --epoch 11 \
+     --use_dynamics 1 \
+     --flex_loss_coeff 0.8 \
+     --init_flex_features 1 \
+     --grad_normalization 0 \
+     --loss_fn MSE \
+     --use_pmpnn_checkpoint 1
+ ```
Flexpert-Design/configs/ANMAwareFlexibilityProtTrans.yaml ADDED
@@ -0,0 +1,12 @@
+ checkpoint_path: ../models/weights/flexpert_3d_weights.bin
+ data_jsonl_name: /NEWcath_ANM_gt_flex_annotated.jsonl # /debug_cath_ANM_gt_flex_annotated.jsonl # /cath_ANM_gt_flex_annotated.jsonl
+ half_precision: False # mixed_precision
+ gumbel_temperature: 0.2
+ num_labels: 1
+ add_pearson_loss: False
+ add_sse_loss: False
+ adaptor_architecture: 'conv'
+ enm_embed_dim: 128
+ enm_att_heads: 8
+ num_layers: 3
+ kernel_size: 5
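Aside: `data_interface.py` later in this commit reads this file via `load_yaml_config` from `src.tools.utils`, which is not part of this 50-file view. A minimal sketch of such a helper, assuming it is a thin wrapper around PyYAML (an assumption, not the repo's actual implementation):

```python
# Hypothetical stand-in for src.tools.utils.load_yaml_config (not shown in this commit view).
import yaml  # PyYAML

def load_yaml_config(path):
    """Parse a YAML config file into a plain Python dict."""
    with open(path) as f:
        return yaml.safe_load(f)

# Usage mirroring the call in data_interface.py below:
cfg = load_yaml_config('configs/ANMAwareFlexibilityProtTrans.yaml')
print(cfg['data_jsonl_name'])  # -> /NEWcath_ANM_gt_flex_annotated.jsonl
```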
Flexpert-Design/configs/Flexpert-Design-inference.yaml ADDED
@@ -0,0 +1 @@
+ pmpnn_model_path: 'train/results/MSEloss_flex_cath_coeff_0.8/checkpoints/last.ckpt'
Flexpert-Design/configs/ProteinMPNN.py ADDED
@@ -0,0 +1,14 @@
+ method = 'ProteinMPNN'
+ hidden = 128
+ k_neighbors = 30
+ num_letters = 33
+ num_encoder_layers = 3
+ num_decoder_layers = 3
+ vocab = 33
+ dropout = 0.1
+ smoothing = 0.1
+ batch_size = 8
+ lr = 0.001
+ proteinmpnn_type = 0
+ patience = 100
+ epoch = 100
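`configs/ProteinMPNN.py` keeps hyperparameters as plain module-level attributes. How the training code actually consumes this module is not shown in this view; a hedged sketch of one common pattern (collecting the public attributes into a dict) follows:

```python
# Illustrative only: turn a module-level config like configs/ProteinMPNN.py into a dict.
# Assumes configs/ is importable as a package (e.g. it contains an __init__.py).
import importlib

def config_module_to_dict(module_name):
    cfg = importlib.import_module(module_name)
    # Drop dunder attributes; what remains are the hyperparameters.
    return {k: v for k, v in vars(cfg).items() if not k.startswith('__')}

params = config_module_to_dict('configs.ProteinMPNN')
print(params['hidden'], params['k_neighbors'])  # -> 128 30
```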
Flexpert-Design/configs/ProteinMPNN.yaml ADDED
@@ -0,0 +1,13 @@
+ augment_eps: 0.0
+ num_encoder_layers: 3
+ hidden_dim: 128
+ hidden: 128
+ k_neighbors: 30
+ num_letters: 33
+ num_decoder_layers: 3
+ vocab: 33
+ dropout: 0.1
+ smoothing: 0.1
+ proteinmpnn_type: 0
+ init_flex_features: 1
+ starting_checkpoint_path: 'vanilla_mpnn_weights/best-epoch=99-recovery=0.485.ckpt'
Flexpert-Design/data_interface.py ADDED
@@ -0,0 +1,205 @@
+ import inspect
+ from torch.utils.data import DataLoader
+ from src.interface.data_interface import DInterface_base
+ import torch
+ import os.path as osp
+ from src.tools.utils import cuda
+ import pdb
+ from src.tools.utils import load_yaml_config
+
+ class MyDataLoader(DataLoader):
+     def __init__(self, dataset, model_name, batch_size=64, num_workers=8, *args, **kwargs):
+         super().__init__(dataset, batch_size=batch_size, num_workers=num_workers, *args, **kwargs)
+         self.pretrain_device = 'cuda:0'
+         self.model_name = model_name
+
+     def __iter__(self):
+         for batch in super().__iter__():
+             # Process the batch here before yielding it
+             try:
+                 self.pretrain_device = f'cuda:{torch.distributed.get_rank()}'
+             except:
+                 self.pretrain_device = 'cuda:0'
+
+             stream = torch.cuda.Stream(
+                 self.pretrain_device
+             )
+             with torch.cuda.stream(stream):
+                 if self.model_name == 'GVP':
+                     batch = batch.cuda(non_blocking=True, device=self.pretrain_device)
+                     yield batch
+                 else:
+                     # Move every tensor in the batch dict to the target device
+                     for key, val in batch.items():
+                         if type(val) == torch.Tensor:
+                             batch[key] = batch[key].cuda(non_blocking=True, device=self.pretrain_device)
+
+                     yield batch
+
+
+ class DInterface(DInterface_base):
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+         self.save_hyperparameters()
+         self.load_data_module()
+
+     def setup(self, stage=None):
+         from src.datasets.featurizer import (featurize_AF, featurize_GTrans, featurize_GVP,
+                                              featurize_ProteinMPNN, featurize_Inversefolding)
+         if self.hparams.model_name in ['AlphaDesign', 'PiFold', 'KWDesign', 'GraphTrans', 'StructGNN', 'GCA', 'E3PiFold']:
+             self.collate_fn = featurize_GTrans
+         elif self.hparams.model_name == 'GVP':
+             featurizer = featurize_GVP()
+             self.collate_fn = featurizer.collate
+         elif self.hparams.model_name == 'ProteinMPNN':
+             self.collate_fn = featurize_ProteinMPNN
+         elif self.hparams.model_name == 'ESMIF':
+             self.collate_fn = featurize_Inversefolding
+
+         # Assign train/val datasets for use in dataloaders
+         if stage == 'fit' or stage is None:
+             self.trainset = self.instancialize(split='train')
+             self.valset = self.instancialize(split='valid')
+
+         # Assign test dataset for use in dataloader(s)
+         if stage == 'test' or stage is None:
+             self.testset = self.instancialize(split='test')
+
+         if stage in ['predict', 'eval']:
+             self.predictset = self.instancialize(split='predict')
+
+     def train_dataloader(self):
+         return MyDataLoader(self.trainset, model_name=self.hparams.model_name, batch_size=self.batch_size, num_workers=self.hparams.num_workers, shuffle=True, prefetch_factor=8, pin_memory=True, collate_fn=self.collate_fn)
+
+     def val_dataloader(self):
+         return MyDataLoader(self.valset, model_name=self.hparams.model_name, batch_size=self.batch_size, num_workers=self.hparams.num_workers, shuffle=False, pin_memory=True, collate_fn=self.collate_fn)
+
+     def test_dataloader(self):
+         return MyDataLoader(self.testset, model_name=self.hparams.model_name, batch_size=self.batch_size, num_workers=self.hparams.num_workers, shuffle=False, pin_memory=True, collate_fn=self.collate_fn)
+
+     def predict_dataloader(self):
+         return MyDataLoader(self.predictset, model_name=self.hparams.model_name, batch_size=self.batch_size, num_workers=self.hparams.num_workers, shuffle=False, pin_memory=True, collate_fn=self.collate_fn)
+
+     def load_data_module(self):
+         name = self.hparams.dataset
+         if name == 'AF2DB':
+             from src.datasets.AF2DB_dataset_lmdb import Af2dbDataset
+             self.data_module = Af2dbDataset
+
+         if name == 'TS':
+             from src.datasets.ts_dataset import TSDataset
+             self.data_module = TSDataset
+             self.hparams['path'] = osp.join(self.hparams.data_root, 'ts')
+
+         if name == 'CASP15':
+             from src.datasets.casp_dataset import CASPDataset
+             self.data_module = CASPDataset
+             self.hparams['path'] = osp.join(self.hparams.data_root, 'casp15')
+
+         if name == 'CATH4.2':
+             from src.datasets.cath_dataset import CATHDataset
+             self.data_module = CATHDataset
+             self.hparams['version'] = 4.2
+             self.hparams['path'] = osp.join(self.hparams.data_root, 'cath4.2')
+
+         if name == 'CATH4.3':
+             from src.datasets.cath_dataset import CATHDataset
+             self.data_module = CATHDataset
+             self.hparams['version'] = 4.3
+             self.hparams['path'] = osp.join(self.hparams.data_root, 'cath4.3')
+
+         if name == 'MPNN':
+             from src.datasets.mpnn_dataset import MPNNDataset
+             self.data_module = MPNNDataset
+
+         if name == 'FOLDSWITCHERS_1':
+             from src.datasets.foldswitchers_dataset import FoldswitchersDataset
+             self.data_module = FoldswitchersDataset
+             self.hparams['path'] = osp.join(self.hparams.data_root, 'fold_switchers/fold_1')
+
+         if name == 'FOLDSWITCHERS_2':
+             from src.datasets.foldswitchers_dataset import FoldswitchersDataset
+             self.data_module = FoldswitchersDataset
+             self.hparams['path'] = osp.join(self.hparams.data_root, 'fold_switchers/fold_2')
+
+         if name == 'PDBInference':
+             from src.datasets.pdb_inference import PDBInference
+             self.data_module = PDBInference
+             self.hparams['path'] = osp.join(self.hparams.infer_path)
+
+         if name == 'ATLAS_DIST_1':
+             from src.datasets.atlas_dataset import AtlasDataset
+             self.data_module = AtlasDataset
+             self.hparams['path'] = osp.join(self.hparams.data_root, 'atlas/distant-frame-pairs_NO_SUPERPOSITION/frames_1')
+
+         if name == 'ATLAS_DIST_2':
+             from src.datasets.atlas_dataset import AtlasDataset
+             self.data_module = AtlasDataset
+             self.hparams['path'] = osp.join(self.hparams.data_root, 'atlas/distant-frame-pairs_NO_SUPERPOSITION/frames_2')
+
+         if name == 'ATLAS_CLUSTER_1':
+             from src.datasets.atlas_dataset import AtlasDataset
+             self.data_module = AtlasDataset
+             self.hparams['path'] = osp.join(self.hparams.data_root, 'atlas/cluster-representatives/frames_1')
+
+         if name == 'ATLAS_CLUSTER_2':
+             from src.datasets.atlas_dataset import AtlasDataset
+             self.data_module = AtlasDataset
+             self.hparams['path'] = osp.join(self.hparams.data_root, 'atlas/cluster-representatives/frames_2')
+
+         if name == 'ATLAS_PDB':
+             from src.datasets.atlas_dataset import AtlasDataset
+             self.data_module = AtlasDataset
+             self.hparams['path'] = osp.join(self.hparams.data_root, '../atlas_pdb_inference/')
+
+         if name == 'ATLAS_FULL_MINIMIZED':
+             from src.datasets.atlas_dataset import AtlasDataset
+             self.data_module = AtlasDataset
+             self.hparams['path'] = osp.join(self.hparams.data_root, '../atlas_eval_proteinmpnn/atlas_full/minimized_PDBs/pdbs/')
+
+         if name == 'ATLAS_FULL_REFOLDED':
+             from src.datasets.atlas_dataset import AtlasDataset
+             self.data_module = AtlasDataset
+             self.hparams['path'] = osp.join(self.hparams.data_root, '../atlas_eval_proteinmpnn/atlas_full/refolded_PDBs/pdbs/')
+
+         if name == 'ATLAS_FULL_CRYSTAL':
+             from src.datasets.atlas_dataset import AtlasDataset
+             self.data_module = AtlasDataset
+             self.hparams['path'] = osp.join(self.hparams.data_root, '../atlas_eval_proteinmpnn/atlas_full/crystal_PDBs/pdbs/')
+
+         if name == 'FLEX_CATH4.3':
+             from src.datasets.flex_cath_dataset import FlexCATHDataset
+             self.data_module = FlexCATHDataset
+             self.hparams['version'] = 4.3
+             self.hparams['path'] = osp.join(self.hparams.data_root, 'cath4.3')
+
+     def instancialize(self, **other_args):
+         """Instantiate the dataset class using the corresponding parameters
+         from the self.hparams dictionary. Any explicitly passed args
+         overwrite the corresponding values from hparams.
+         """
+         class_args = list(inspect.signature(self.data_module.__init__).parameters)[1:]
+         inkeys = self.hparams.keys()
+         args1 = {}
+         for arg in class_args:
+             if arg in inkeys:
+                 args1[arg] = self.hparams[arg]
+         args1.update(other_args)
+
+         # if self.hparams['test_engineering'] and self.hparams['use_dynamics']:
+         #     args1['data_jsonl_name'] = self.hparams['test_eng_data_path']
+         # elif self.hparams['use_dynamics']:
+         if self.hparams['use_dynamics']:
+             args1['data_jsonl_name'] = load_yaml_config('configs/ANMAwareFlexibilityProtTrans.yaml')['data_jsonl_name']
+         return self.data_module(**args1)  # This ends up in the dataset class's __init__
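The `instancialize` method above is a signature-filtering factory: it inspects the dataset class's `__init__` and forwards only the matching entries of `self.hparams`. The same pattern in isolation, with toy names (not from the repo):

```python
import inspect

def instantiate_with_matching_kwargs(cls, hparams, **overrides):
    """Call cls(...) with only those hparams that cls.__init__ actually accepts."""
    accepted = list(inspect.signature(cls.__init__).parameters)[1:]  # skip 'self'
    kwargs = {k: hparams[k] for k in accepted if k in hparams}
    kwargs.update(overrides)  # explicit arguments win over hparams
    return cls(**kwargs)

class ToyDataset:
    def __init__(self, path, split='train'):
        self.path, self.split = path, split

ds = instantiate_with_matching_kwargs(ToyDataset, {'path': 'data/cath4.3', 'lr': 1e-3}, split='valid')
print(ds.path, ds.split)  # data/cath4.3 valid  (the unrelated 'lr' key is ignored)
```

This keeps one flat hparams namespace usable across dataset classes with different constructors.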
Flexpert-Design/data_utils.py ADDED
@@ -0,0 +1,535 @@
+ # From https://github.com/JoreyYan/zetadesign/blob/master/data/data.py
+ import glob
+ import json
+ import numpy as np
+ import gzip
+ import re
+ import multiprocessing
+ import tqdm
+ import shutil
+ SENTINEL = 1
+ import biotite.structure as struc
+ import biotite.application.dssp as dssp
+ import biotite.structure.io.pdb.file as file
+
+
+ def parse_PDB_biounits(x, sse, ssedssp, atoms=['N', 'CA', 'C'], chain=None):
+     '''
+     input:  x = PDB filename
+             atoms = atoms to extract (optional)
+     output: (length, atoms, coords=(x,y,z)), sequence
+     '''
+
+     alpha_1 = list("ARNDCQEGHILKMFPSTWYV-")
+     states = len(alpha_1)
+     alpha_3 = ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
+                'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL', 'GAP']
+
+     aa_1_N = {a: n for n, a in enumerate(alpha_1)}
+     aa_3_N = {a: n for n, a in enumerate(alpha_3)}
+     aa_N_1 = {n: a for n, a in enumerate(alpha_1)}
+     aa_1_3 = {a: b for a, b in zip(alpha_1, alpha_3)}
+     aa_3_1 = {b: a for a, b in zip(alpha_1, alpha_3)}
+
+     def AA_to_N(x):
+         # ["ARND"] -> [[0,1,2,3]]
+         x = np.array(x)
+         if x.ndim == 0: x = x[None]
+         return [[aa_1_N.get(a, states - 1) for a in y] for y in x]
+
+     def N_to_AA(x):
+         # [[0,1,2,3]] -> ["ARND"]
+         x = np.array(x)
+         if x.ndim == 1: x = x[None]
+         return ["".join([aa_N_1.get(a, "-") for a in y]) for y in x]
+
+     xyz, seq, plddts, min_resn, max_resn = {}, {}, [], 1e6, -1e6
+
+     pdbcontents = x.split('\n')[0]
+     with open(pdbcontents) as f:
+         pdbcontents = f.readlines()
+     for line in pdbcontents:
+         # Treat selenomethionine (MSE) records as regular methionine ATOM records
+         if line[:6] == "HETATM" and line[17:17 + 3] == "MSE":
+             line = line.replace("HETATM", "ATOM  ")
+             line = line.replace("MSE", "MET")
+
+         if line[:4] == "ATOM":
+             ch = line[21:22]
+             if ch == chain or chain is None or ch == ' ':
+                 atom = line[12:12 + 4].strip()
+                 resi = line[17:17 + 3]
+                 resn = line[22:22 + 5].strip()
+                 plddt = line[60:60 + 6].strip()
+
+                 x, y, z = [float(line[i:(i + 8)]) for i in [30, 38, 46]]
+
+                 if resn[-1].isalpha():
+                     resa, resn = resn[-1], int(resn[:-1]) - 1  # same position: use the last atoms
+                 else:
+                     resa, resn = "_", int(resn) - 1
+                 if resn < min_resn:
+                     min_resn = resn
+                 if resn > max_resn:
+                     max_resn = resn
+
+                 if resn not in xyz:
+                     xyz[resn] = {}
+                 if resa not in xyz[resn]:
+                     xyz[resn][resa] = {}
+                 if resn not in seq:
+                     seq[resn] = {}
+                 if resa not in seq[resn]:
+                     seq[resn][resa] = resi
+
+                 if atom not in xyz[resn][resa]:
+                     xyz[resn][resa][atom] = np.array([x, y, z])
+
+     # convert to numpy arrays, fill in missing values
+     seq_, xyz_, sse_, ssedssp_ = [], [], [], []
+     dsspidx = 0
+     sseidx = 0
+     for resn in range(int(min_resn), int(max_resn + 1)):
+         if resn in seq:
+             for k in sorted(seq[resn]):
+                 seq_.append(aa_3_N.get(seq[resn][k], 20))
+                 try:
+                     if 'CA' in xyz[resn][k]:
+                         sse_.append(sse[sseidx])
+                         sseidx = sseidx + 1
+                     else:
+                         sse_.append('-')
+                 except:
+                     print('error sse')
+         else:
+             seq_.append(20)
+             sse_.append('-')
+
+         misschianatom = False
+         if resn in xyz:
+             for k in sorted(xyz[resn]):
+                 for atom in atoms:
+                     if atom in xyz[resn][k]:
+                         xyz_.append(xyz[resn][k][atom])  # some residues miss C and O, but sse is fine because it depends only on CA
+                     else:
+                         xyz_.append(np.full(3, np.nan))
+                         misschianatom = True
+             if misschianatom:
+                 ssedssp_.append('-')
+                 misschianatom = False
+             else:
+                 try:
+                     ssedssp_.append(ssedssp[dsspidx])  # if a chain atom is missing, xyz and seq are fine, but dssp misses this residue
+                     dsspidx = dsspidx + 1
+                 except:
+                     print(dsspidx)
+         else:
+             for atom in atoms:
+                 xyz_.append(np.full(3, np.nan))
+             ssedssp_.append('-')
+
+     return np.array(xyz_).reshape(-1, len(atoms), 3), N_to_AA(np.array(seq_)), np.array(sse_), np.array(ssedssp_)
+
+
+ def parse_PDB(path_to_pdb, name, input_chain_list=None):
+     """
+     make sure to input only one line at a time
+     """
+     c = 0
+     pdb_dict_list = []
+
+     if input_chain_list:
+         chain_alphabet = input_chain_list
+     else:
+         init_alphabet = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T',
+                          'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
+                          'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
+         extra_alphabet = [str(item) for item in list(np.arange(300))]
+         chain_alphabet = init_alphabet + extra_alphabet
+
+     biounit_names = [path_to_pdb]
+     for biounit in biounit_names:
+         my_dict = {}
+         s = 0
+         concat_seq = ''
+
+         for letter in chain_alphabet:
+             PDBFile = file.PDBFile.read(biounit)
+             array_stack = PDBFile.get_structure(altloc="all")
+
+             sse1 = struc.annotate_sse(array_stack[0], chain_id=letter).tolist()
+             if len(sse1) == 0:
+                 sse1 = struc.annotate_sse(array_stack[0], chain_id='').tolist()
+             # ssedssp1 = dssp.DsspApp.annotate_sse(array_stack).tolist()
+             ssedssp1 = []  # not annotating dssp for now
+
+             xyz, seq, _, _ = parse_PDB_biounits(biounit, sse1, ssedssp1, atoms=['N', 'CA', 'C', 'O'], chain=letter)  # TODO: fix the float error
+
+             if type(xyz) != str:
+                 concat_seq += seq[0]
+                 my_dict['seq_chain_' + letter] = seq[0]
+
+                 coords_dict_chain = {}
+                 coords_dict_chain['N'] = xyz[:, 0, :].tolist()
+                 coords_dict_chain['CA'] = xyz[:, 1, :].tolist()
+                 coords_dict_chain['C'] = xyz[:, 2, :].tolist()
+                 coords_dict_chain['O'] = xyz[:, 3, :].tolist()
+                 my_dict['coords_chain_' + letter] = coords_dict_chain
+                 s += 1
+         my_dict['name'] = name
+         my_dict['num_of_chains'] = s
+         my_dict['seq'] = concat_seq
+         if s <= len(chain_alphabet):
+             pdb_dict_list.append(my_dict)
+             c += 1
+     return pdb_dict_list
+
+
+ def parse_pdb_split_chain(pdbgzFile):
+     with open(pdbgzFile) as f:
+         lines = f.readlines()
+
+     pattern = re.compile(r'ATOM\s+\d+\s*\w+\s*[A-Z]{3,4}\s*(\w)\s*.+\n', re.MULTILINE)
+     match = list(set(list(pattern.findall(lines[0]))))
+
+     name = pdbgzFile.split('/')[-1]
+     pdb_data = parse_PDB(pdbgzFile, name, match)
+
+     return pdb_data
+
+
+ def parse_pdb_split_chain_af(pdbgzFile):
+     with gzip.open(pdbgzFile, 'rb') as pdbF:
+         try:
+             pdbcontent = pdbF.read()
+         except:
+             print(pdbgzFile)
+
+     pdbcontent = pdbcontent.decode()
+
+     pattern = re.compile(r'ATOM\s+\d+\s*\w+\s*[A-Z]{3,4}\s*(\w)\s*.+\n', re.MULTILINE)
+     match = list(set(list(pattern.findall(pdbcontent))))
+
+     name = pdbgzFile.split('/')[-1].split('.')[0]
+     pdb_data = parse_PDB('/media/junyu/data/perotin/aftest080_1000/' + pdbgzFile.split('/')[-1].split('.')[0] + '.pdb', name, match)
+
+     return pdb_data
+
+
+ def parse_pdb_split_chain_af_3dcnn(pdbgzFile):
+     with gzip.open(pdbgzFile, 'rb') as pdbF:
+         try:
+             pdbcontent = pdbF.read()
+         except:
+             print(pdbgzFile)
+
+     pdbcontent = pdbcontent.decode()
+
+     pattern = re.compile(r'ATOM\s+\d+\s*\w+\s*[A-Z]{3,4}\s*(\w)\s*.+\n', re.MULTILINE)
+     match = list(set(list(pattern.findall(pdbcontent))))
+
+     name = pdbgzFile.split('/')[-1].split('.')[0]
+     namelist = []
+     for chain in match:
+         namelist.append(name + '__' + chain)
+
+     return namelist
+
+
+ def run_net(files_path, output_path):
+     """
+     input is a directory of PDB files; converts them to jsonl
+     """
+     list = glob.glob(files_path + '*.pdb')
+     data = []
+     for i in tqdm.tqdm(list):
+         data_chains = parse_pdb_split_chain(i)
+         data.append(data_chains[0])
+
+     print('we want to write now')
+     with open(output_path, 'w') as f:
+         for entry in data:
+             f.write(json.dumps(entry) + '\n')
+     print('finished')
+
+
+ def run_netbyondif(filelist, output_path):
+     with open(filelist) as f:
+         lines = f.readlines()
+     data = []
+     nums_dict = {1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0, 7: 0, 8: 0, 9: 0, 10: 0}
+
+     for i in tqdm.tqdm(lines):
+         data_chains, match = parse_pdb_split_chain(i.split('"')[1])
+
+         for chian in data_chains:
+             for i in match:
+                 meanplddt = round(float(np.asarray(chian['plddts_chain_' + i]).mean()), 2)
+                 data.append({'name': chian['name'], 'lens': len(chian['seq']), 'meanplddt': meanplddt})
+                 # Bin chains by mean pLDDT decade (10-19 -> bin 1, ..., 90-100 -> bins 9/10)
+                 bin_idx = int(meanplddt / 10)
+                 if bin_idx in nums_dict:
+                     nums_dict[bin_idx] = nums_dict[bin_idx] + 1
+                 else:
+                     print(chian['name'])
+
+     output_pathindex = output_path + filelist.split('/')[-1].split('.')[0] + '_detail.jsonl'
+     print('we want to write now')
+     with open(output_pathindex, 'w') as f:
+         for entry in data:
+             f.write(json.dumps(entry) + '\n')
+     print('finished')
+
+
+ def list_of_groups(list_info, per_list_len):
+     '''
+     :param list_info: the list to split
+     :param per_list_len: the length of each sub-list
+     :return: list of sub-lists; any remainder forms the last (shorter) sub-list
+     '''
+     list_of_group = zip(*(iter(list_info),) * per_list_len)
+     end_list = [list(i) for i in list_of_group]  # i is a tuple
+     count = len(list_info) % per_list_len
+     end_list.append(list_info[-count:]) if count != 0 else end_list
+     return end_list
+
+
+ def count(filelist):
+     with open(filelist) as f:
+         lines = f.readlines()
+     plddts = []
+
+     for i in tqdm.tqdm(lines):
+         pl = json.loads(i)['meanplddt']
+         plddts.append(int(pl / 10))
+
+     for i in range(10):
+         print('counts ' + str(i), plddts.count(i))
+
+
+ def run_net_aftest(files_path, output_path):
+     """
+     input is a directory of pdb.gz files
+     """
+     with open(files_path) as f:
+         lines = f.readlines()
+     data = []
+     for i in tqdm.tqdm(lines):
+         data_chains = parse_pdb_split_chain_af('/media/junyu/data/point_cloud/' + i.split('"')[1])
+         for chian in data_chains:
+             data.append(chian)
+
+     output_pathindex = output_path + str(80) + 'bigthanclass_1000.text'
+     print('we want to write now')
+     with open(output_pathindex, 'w') as f:
+         for entry in data:
+             f.write(entry + '\n')
+
+
+ # if __name__ == "__main__":
+ #     files_path = '/media/junyu/data/perotin/chain_set/AFDATA/details/80bigthanclass_1000.jsonl'
+ #     output_path = '/media/junyu/data/perotin/chain_set/'
+ #     # run_net_aftest(files_path, output_path)
+ #     fakedata = '//home/oem/pdb-tools/pdbtools/fixed/'
+ #     run_net(fakedata, output_path + 'tim184.jsonl')
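`list_of_groups` above relies on the `zip(*(iter(lst),) * n)` idiom: zipping n references to a single iterator consumes it n items at a time, and the incomplete tail that `zip` drops is re-appended as a final shorter chunk. A standalone check of that behavior, equivalent to the repo's function:

```python
def chunk_list(items, size):
    """Split items into size-length chunks; the remainder becomes the last, shorter chunk."""
    groups = [list(t) for t in zip(*(iter(items),) * size)]  # zip drops the incomplete tail
    rest = len(items) % size
    if rest:
        groups.append(items[-rest:])
    return groups

print(chunk_list(list(range(7)), 3))  # [[0, 1, 2], [3, 4, 5], [6]]
```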
Flexpert-Design/download-cath-data.sh ADDED
@@ -0,0 +1,17 @@
+ #!/bin/bash
+ echo "Downloading CATH data..."
+
+ # Create data directory if it doesn't exist
+ mkdir -p ../data/
+
+ # Set file information
+ URL="https://data.ciirc.cvut.cz/public/projects/2025Flexpert/cath4.3/"
+ OUTPUT_DIR="../data/cath4.3"
+
+ # Download directory recursively
+ wget --no-check-certificate -r -np -nH --cut-dirs=3 --reject "index.html*" \
+     --directory-prefix=${OUTPUT_DIR} ${URL}
+
+ echo "CATH data download completed."
@@ -0,0 +1,631 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys; sys.path.append('/huyuqi/xmyu/DiffSDS')
2
+ import inspect
3
+ import torch
4
+ from src.tools.utils import cuda
5
+ import torch.nn as nn
6
+ import os
7
+ from torcheval.metrics.text import Perplexity
8
+ from src.interface.model_interface import MInterface_base
9
+ import math
10
+ import torch.nn.functional as F
11
+ from omegaconf import OmegaConf
12
+ from src.tools.utils import load_yaml_config
13
+ import torchmetrics
14
+
15
+ class MInterface(MInterface_base):
16
+ def __init__(self, model_name=None, loss=None, lr=None, **kwargs):
17
+ super().__init__()
18
+ self.save_hyperparameters()
19
+ self.load_model()
20
+ self.use_dynamics = kwargs.get('use_dynamics', 0)
21
+ self.flex_loss_coeff = torch.Tensor([kwargs.get('flex_loss_coeff', 0)]).to('cuda:0').to(torch.float)
22
+ self.flex_loss_coeff.requires_grad = False
23
+ if self.use_dynamics:
24
+ self.load_flex_predictor()
25
+ self.flex_loss_type = kwargs.get('loss_fn', 0)
26
+ if self.flex_loss_type == 'MSE':
27
+ self.flex_loss_fn = nn.MSELoss(reduction='none')
28
+ elif self.flex_loss_type == 'L1':
29
+ self.flex_loss_fn = nn.L1Loss(reduction='none')
30
+ elif self.flex_loss_type == 'DPO':
31
+ self.flex_loss_fn = ...
32
+ else:
33
+ raise ValueError(f"Not recognized type of loss function {self.flex_loss_type}")
34
+ self.cross_entropy = nn.NLLLoss(reduction='none')
35
+ os.makedirs(os.path.join(self.hparams.res_dir, self.hparams.ex_name), exist_ok=True)
36
+
37
+ self.control_sum_recovery = 0
38
+ self.control_sum_batch_sizes = 0
39
+
40
+ self.grad_normalization = kwargs.get('grad_normalization', 0)
41
+ self.use_pmpnn_checkpoint = kwargs.get('use_pmpnn_checkpoint',0)
42
+
43
+ if self.use_pmpnn_checkpoint:
44
+ print('Loading pmpnn checkpoint from {}'.format(self.model.pmpnn_init_weights_path))
45
+ state_dict = torch.load(self.model.pmpnn_init_weights_path)['state_dict'] #['module']
46
+ state_dict = {key: value for key, value in state_dict.items() if 'model.' in key[:6]}
47
+ state_dict = {key.replace("model.", ""): value for key, value in state_dict.items()}
48
+ self.model.load_state_dict(state_dict)
49
+
50
+ self.MSE = nn.MSELoss(reduction='none')
51
+ self.automatic_optimization = False
52
+
53
+ if self.hparams.use_dynamics:
54
+ self.pearson = torchmetrics.PearsonCorrCoef()
55
+ self.spearman = torchmetrics.SpearmanCorrCoef()
56
+ self.validation_step_outputs = []
57
+ self.test_step_outputs = []
58
+
59
+ #### setting forward hook
60
+
61
+ # def forward_hook(module, input, output):
62
+ # def check_nan(tensor):
63
+ # if isinstance(tensor, torch.Tensor):
64
+ # if torch.isnan(tensor).any():
65
+ # print(f"NaN detected in the output of {type(module).__name__}")
66
+ # print(f"Tensor shape: {tensor.shape}")
67
+ # print(f"Tensor stats: mean={tensor.mean()}, std={tensor.std()}, min={tensor.min()}, max={tensor.max()}, all={torch.isnan(tensor).all()}")
68
+ # elif isinstance(tensor, tuple):
69
+ # for i, t in enumerate(tensor):
70
+ # if isinstance(t, torch.Tensor):
71
+ # if torch.isnan(t).any():
72
+ # print(f"NaN detected in the output[{i}] of {type(module).__name__}")
73
+ # print(f"Tensor shape: {t.shape}")
74
+ # print(f"Tensor stats: mean={t.mean()}, std={t.std()}, min={t.min()}, max={t.max()}, all={torch.isnan(tensor).all()}")
75
+
76
+ # if isinstance(output, tuple):
77
+ # for i, out in enumerate(output):
78
+ # check_nan(out)
79
+ # else:
80
+ # check_nan(output)
81
+
82
+ # for name, module in self.model.named_modules():
83
+ # module.register_forward_hook(forward_hook)
84
+
85
+ # for name, module in self.flex_model.named_modules():
86
+ # module.register_forward_hook(forward_hook)
87
+
88
+ ####
89
+
90
+ def forward(self, batch, mode='train', temperature=1.0):
91
+ if self.hparams.augment_eps>0:
92
+ batch['X'] = batch['X'] + self.hparams.augment_eps * torch.randn_like(batch['X'])
93
+
94
+ batch = self.model._get_features(batch)
95
+ results = self.model(batch)
96
+
97
+ log_probs, mask = results['log_probs'], batch['mask']
98
+ if len(log_probs.shape) == 3:
99
+ if self.hparams.use_dynamics:
100
+ loss = self.combined_flex_aware_loss(batch, pred_log_probs=log_probs)
101
+ #loss = loss_dict['combined_loss']
102
+ else:
103
+ loss = self.cross_entropy(log_probs.permute(0,2,1), batch['S'])
104
+ loss = (loss*mask).sum()/(mask.sum())
105
+ elif len(log_probs.shape) == 2:
106
+ if self.hparams.model_name == 'GVP':
107
+ loss = self.cross_entropy(log_probs, batch.seq)
108
+ else:
109
+ loss = self.cross_entropy(log_probs, batch['S'])
110
+
111
+ if self.hparams.model_name == 'AlphaDesign':
112
+ loss += self.cross_entropy(results['log_probs0'], batch['S'])
113
+ loss = (loss*mask).sum()/(mask.sum())
114
+
115
+ cmp = log_probs.argmax(dim=-1)==batch['S']
116
+ recovery = (cmp*mask).sum()/(mask.sum())
117
+
118
+ if mode == 'predict':
119
+ return {'original_sequence':batch['S'],'correct_positions': cmp, 'mask':mask,'loss':loss, 'recovery':recovery, 'title':batch['title'], 'log_probs': log_probs, 'batch':batch} #, 'gt_bfactors': batch['norm_bfactors'], 'batch':batch}
120
+ elif mode == 'eval':
121
+ return {'original_sequence':batch['S'],'correct_positions': cmp, 'mask':mask,'loss':loss, 'recovery':recovery, 'title':batch['title'], 'log_probs': log_probs, 'batch':batch}
122
+ else:
123
+ return loss, recovery
124
+
125
+ def avgCorrelations(self, preds, gts, masks):
126
+ pearson_R = 0
127
+ spearman_R = 0
128
+ valid_datapoints = 0
129
+ for pred, gt, mask in zip(preds, gts, masks):
130
+ dpR = self.pearson(pred[torch.where(mask)], gt[torch.where(mask)])
131
+ if torch.isnan(dpR):
132
+ continue
133
+ else:
134
+ pearson_R += dpR
135
+ spearman_R += self.spearman(pred[torch.where(mask)], gt[torch.where(mask)])
136
+ valid_datapoints += 1
137
+ return pearson_R/valid_datapoints, spearman_R/valid_datapoints
138
+
139
+ def temperature_schedular(self, batch_idx):
140
+ total_steps = self.hparams.steps_per_epoch*self.hparams.epoch
141
+
142
+ initial_lr = 1.0
143
+ circle_steps = total_steps//100
144
+ x = batch_idx / total_steps
145
+ threshold = 0.48
146
+ if x<threshold:
147
+ linear_decay = 1 - 2*x
148
+ else:
149
+ K = 1 - 2*threshold
150
+ linear_decay = K - K*(x-threshold)/(1-threshold)
151
+
152
+ new_lr = (1+math.cos(batch_idx/circle_steps*math.pi))/2*linear_decay*initial_lr
153
+
154
+ return new_lr
155
+
156
+ # def get_grad_norm(self):
157
+ # total_norm = 0
158
+ # parameters = [p for p in self.parameters() if p.grad is not None and p.requires_grad]
159
+ # for p in parameters:
160
+ # param_norm = p.grad.detach().data.norm(2)
161
+ # total_norm += param_norm.item() ** 2
162
+ # total_norm = total_norm ** 0.5
163
+ # return total_norm
164
+
165
+ #https://lightning.ai/docs/pytorch/1.9.0/notebooks/lightning_examples/basic-gan.html
166
+ def training_step(self, batch, batch_idx, **kwargs):
167
+ if self.use_dynamics:
168
+ raw_loss, recovery = self(batch)
169
+ if type(raw_loss) == dict:
170
+ flex_loss = raw_loss['flex_loss']
171
+ seq_loss = raw_loss['seq_loss']
172
+ opt = self.optimizers()
173
+ opt.zero_grad()
174
+
175
+ _params_for_optimization = [p for p in self.model.parameters() if p.requires_grad]
176
+ _params_for_optimization += [p for p in self.flex_model.parameters() if p.requires_grad]
177
+
178
+ grads_flex = torch.autograd.grad(flex_loss, _params_for_optimization, create_graph=True)
179
+ grads_seq = torch.autograd.grad(seq_loss, _params_for_optimization, create_graph=True)
180
+ if self.grad_normalization:
181
+ norm_grads_flex = [g / (g.norm() + 1e-10) for g in grads_flex]
182
+ norm_grads_seq = [g / (g.norm() + 1e-10) for g in grads_seq]
183
+ else:
184
+ norm_grads_flex = grads_flex
185
+ norm_grads_seq = grads_seq
186
+
187
+ combined_grads = [self.flex_loss_coeff * gflex + (1-self.flex_loss_coeff) * gseq for gflex, gseq in zip(norm_grads_flex, norm_grads_seq)]
188
+
189
+ #maybe track the angle between the gradients?
190
+ self.log_dict({'flex_grad_norm':torch.mean(torch.tensor([g.detach().norm() for g in norm_grads_flex])), 'seq_grad_norm': torch.mean(torch.tensor([g.detach().norm() for g in norm_grads_seq])), 'combined_grad_norm': torch.mean(torch.tensor([g.detach().norm() for g in combined_grads]))}, on_step=True, on_epoch=False, prog_bar=True)
191
+
192
+
193
+ for param, grad in zip(_params_for_optimization, combined_grads):
194
+ if param.grad is None:
195
+ param.grad = grad.detach()
196
+ else:
197
+ param.grad += grad.detach()
198
+
199
+ self.clip_gradients(opt, gradient_clip_val=1., gradient_clip_algorithm="norm")
200
+ opt.step()
201
+
202
+ # Update learning rate
203
+ sch = self.lr_schedulers()
204
+ if sch is not None:
205
+ sch.step()
206
+
207
+ loss = flex_loss + seq_loss
208
+
209
+ self.log_dict({'train_flex_loss':flex_loss, 'train_seq_loss':seq_loss}, on_step=True, on_epoch=False, prog_bar=True)
210
+
211
+ # Log the current learning rate
212
+ if sch is not None:
213
+ current_lr = sch.get_last_lr()[0]
214
+ self.log('learning_rate', current_lr, on_step=True, on_epoch=False, prog_bar=True)
215
+ else:
216
+ loss = raw_loss
217
+ self.log('loss', loss, on_step=True, on_epoch=True, prog_bar=True)
218
+ return loss
219
+ else:
220
+ raw_loss, recovery = self(batch)
221
+ if type(raw_loss) == dict:
222
+ loss = raw_loss['combined_loss']
223
+ _ = raw_loss.pop('pred_flex')
224
+ # _ = raw_loss.pop('gt_bfactors')
225
+ _ = raw_loss.pop('gt_flex')
226
+ _ = raw_loss.pop('flex_mask')
227
+
228
+ self.log_dict(raw_loss, on_step=True, on_epoch=True, prog_bar=True)
229
+ else:
230
+ loss = raw_loss
231
+ self.log('loss', loss, on_step=True, on_epoch=True, prog_bar=True)
232
+ return loss
233
+
234
+ def validation_step(self, batch, batch_idx):
235
+ raw_loss, recovery = self(batch)
236
+ if type(raw_loss) == dict:
237
+ loss = raw_loss['flex_loss']+raw_loss['seq_loss'] #raw_loss['combined_loss']
238
+ raw_loss['recovery'] = recovery
239
+ pred_flex = raw_loss.pop('pred_flex')
240
+ gt_flex = batch['gt_flex']
241
+
242
+ flex_mask = raw_loss.pop('flex_mask')
243
+ #epoch_metric_ingredients = {'pred_bfactors':pred_bfactors, 'gt_bfactors':gt_bfactors, 'flex_mask':flex_mask}
244
+ epoch_metric_ingredients = {'pred_flex': pred_flex,'gt_flex':gt_flex, 'flex_mask':flex_mask}
245
+ self.validation_step_outputs.append(epoch_metric_ingredients)
246
+ self.log_dict({ "val_combined_loss":loss,
247
+ "val_seq_loss":raw_loss['seq_loss'],
248
+ "val_flex_loss":raw_loss['flex_loss'],
249
+ "recovery": recovery})
250
+ else:
251
+ loss = raw_loss
252
+ self.log_dict({"val_loss":loss,
253
+ "recovery": recovery})
254
+ #if there is issue with validation metrics - see the test_step below
255
+ return self.log_dict
256
+
257
+ def on_validation_epoch_end(self):
258
+ if self.hparams.use_dynamics:
259
+ # all_preds = [b['pred_bfactors'] for b in self.validation_step_outputs]
260
+ # all_gts = [b['gt_bfactors'] for b in self.validation_step_outputs]
261
+ all_preds = [b['pred_flex'] for b in self.validation_step_outputs]
262
+ all_gts = [b['gt_flex'] for b in self.validation_step_outputs]
263
+ all_masks = [b['flex_mask'] for b in self.validation_step_outputs]
264
+
265
+ max_seq_length = max([pred.size()[1] for pred in all_preds])
266
+
267
+ for set_of_tensors in [all_preds, all_gts, all_masks]:
268
+ for i in range(len(set_of_tensors)):
269
+ set_of_tensors[i] = F.pad(set_of_tensors[i], (0, max_seq_length - set_of_tensors[i].shape[1],0,0), value=float(0))
270
+ all_preds = torch.cat(all_preds, dim=0)
271
+ all_gts = torch.cat(all_gts, dim=0)
272
+ all_masks = torch.cat(all_masks, dim=0)
273
+ # print(all_preds.shape, all_gts.shape, all_masks.shape)
274
+ # do something with all preds
275
+
276
+ # pearson_R = self.pearson(all_preds[torch.where(all_masks)], all_gts[torch.where(all_masks)])
277
+ pearson_R, spearman_R = self.avgCorrelations(all_preds, all_gts, all_masks)
278
+ # try:
279
+ # spearman_R = self.spearman(all_preds[torch.where(all_masks)], all_gts[torch.where(all_masks)])
280
+ # except IndexError:
281
+ # spearman_R = pearson_R
282
+ self.log_dict({"val_pearson_R":pearson_R, "val_spearman_R":spearman_R})
283
+ self.validation_step_outputs.clear() # free memory
284
+ return super().on_validation_epoch_end()
285
+
286
+ def on_test_epoch_end(self):
287
+ import pickle #use pickle to save the self.test_step_outputs to a file
288
+ with open(f'rebuttal_experiments/test_step_outputs_{self.hparams.starting_checkpoint_path.split("/")[-3]}_initFF{self.hparams.init_flex_features}_{self.hparams.test_eng_data_path.split("/")[-1][:-5]}.pkl', 'wb') as f:
289
+ pickle.dump(self.test_step_outputs, f)
290
+ if self.hparams.test_engineering and self.hparams.use_dynamics:
291
+ all_preds = [b['pred_flex'] for b in self.test_step_outputs]
292
+ all_eng_gts = [b['gt_flex'] for b in self.test_step_outputs]
293
+ all_masks = [b['flex_mask'] for b in self.test_step_outputs]
294
+ all_eng_masks = [b['eng_mask'] for b in self.test_step_outputs]
295
+ all_original_gt_flex = [b['original_gt_flex'] for b in self.test_step_outputs]
296
+
297
+ avg_sequence_recovery = sum([b['sequence_recovery'] for b in self.test_step_outputs]) / len(self.test_step_outputs)
298
+ avg_sequence_recovery = avg_sequence_recovery.cpu().tolist()
299
+ max_seq_length = max([pred.size()[1] for pred in all_preds])
300
+
301
+ _pred_flex_pool = []
302
+ _eng_gt_flex_pool = []
303
+ _original_gt_flex_pool = []
304
+ _original_gt_flex_ranks_pool = []
305
+ _eng_gt_flex_ranks_pool = []
306
+ _pred_flex_ranks_pool = []
307
+
308
+
309
+ import numpy as np
310
+ for eng_mask, flex_mask, original_gt_flex, eng_gt_flex, pred_flex in zip(all_eng_masks, all_masks, all_original_gt_flex, all_eng_gts, all_preds):
311
+                 # select only the values where the engineering mask is 1 and the flex mask is 1
+                 _original_gt_flex = original_gt_flex[eng_mask == 1]
+                 _eng_gt_flex = eng_gt_flex[eng_mask == 1]
+                 _pred_flex = pred_flex[eng_mask == 1]
+                 _pred_flex_pool.append(_pred_flex.cpu().numpy())
+                 _eng_gt_flex_pool.append(_eng_gt_flex.cpu().numpy())
+                 _original_gt_flex_pool.append(_original_gt_flex.cpu().numpy())
+
+                 # double argsort yields per-residue ranks; NaNs are zeroed first so they rank lowest
+                 _original_gt_flex_ranks = torch.argsort(torch.argsort(torch.nan_to_num(original_gt_flex, nan=0)))[eng_mask == 1].cpu().numpy()
+                 _eng_gt_flex_ranks = torch.argsort(torch.argsort(torch.nan_to_num(eng_gt_flex, nan=0)))[eng_mask == 1].cpu().numpy()
+                 _pred_flex_ranks = torch.argsort(torch.argsort(torch.nan_to_num(pred_flex, nan=0)))[eng_mask == 1].cpu().numpy()
+
+                 _original_gt_flex_ranks_pool.append(_original_gt_flex_ranks)
+                 _eng_gt_flex_ranks_pool.append(_eng_gt_flex_ranks)
+                 _pred_flex_ranks_pool.append(_pred_flex_ranks)
+
+             import matplotlib.pyplot as plt
+             import os
+
+             # # Create 'paper_figures' folder if it doesn't exist
+             # if not os.path.exists('paper_figures'):
+             #     os.makedirs('paper_figures')
+
+             # pool the numpy arrays in the lists into one numpy array
+             _pred_flex_pool = np.concatenate(_pred_flex_pool)
+             _eng_gt_flex_pool = np.concatenate(_eng_gt_flex_pool)
+             _original_gt_flex_pool = np.concatenate(_original_gt_flex_pool)
+
+             ############################################################################
+             all_gt_seqs = [b['gt_seq'] for b in self.test_step_outputs]
+             all_pred_logprobs = [b['pred_logprobs'] for b in self.test_step_outputs]
+             _gt_seq_pool = []
+             _pred_seq_pool = []
+             _outside_eng_region_pred_seq_pool = []
+             _outside_eng_region_gt_seq_pool = []
+             for eng_mask, gt_seq, pred_logprobs in zip(all_eng_masks, all_gt_seqs, all_pred_logprobs):
+                 # select only the values where the engineering mask is 1
+                 # NOTE: flex_mask is not re-derived here; it still holds its value from the previous loop
+                 _outside_eng_region_pred_seq_pool.append(torch.argmax(pred_logprobs[(eng_mask == 0) & (flex_mask == 1)], dim=1).cpu().numpy())
+                 _outside_eng_region_gt_seq_pool.append(gt_seq[(eng_mask == 0) & (flex_mask == 1)].cpu().numpy())
+
+                 _pred_seq = torch.argmax(pred_logprobs[eng_mask == 1], dim=1)
+                 _gt_seq = gt_seq[eng_mask == 1]
+
+                 # create the numpy arrays and add them to the pools
+                 _gt_seq_pool.append(_gt_seq.cpu().numpy())
+                 _pred_seq_pool.append(_pred_seq.cpu().numpy())
+             _gt_seq_pool = np.concatenate(_gt_seq_pool)
+             _pred_seq_pool = np.concatenate(_pred_seq_pool)
+             _outside_eng_region_pred_seq_pool = np.concatenate(_outside_eng_region_pred_seq_pool)
+             _outside_eng_region_gt_seq_pool = np.concatenate(_outside_eng_region_gt_seq_pool)
+             # output these pools together with the other pools to a json file
+             import json
+             with open(f'paper_figures/pools_{self.hparams.starting_checkpoint_path.split("/")[-3]}_initFF{self.hparams.init_flex_features}_{self.hparams.test_eng_data_path.split("/")[-1][:-5]}.json', 'w') as f:
+                 json.dump({
+                     '_pred_flex_pool': _pred_flex_pool.tolist(),
+                     '_eng_gt_flex_pool': _eng_gt_flex_pool.tolist(),
+                     '_original_gt_flex_pool': _original_gt_flex_pool.tolist(),
+                     '_pred_seq_pool': _pred_seq_pool.tolist(),
+                     '_gt_seq_pool': _gt_seq_pool.tolist(),
+                     '_sequence_recovery': avg_sequence_recovery,
+                     '_outside_eng_region_pred_seq_pool': _outside_eng_region_pred_seq_pool.tolist(),
+                     '_outside_eng_region_gt_seq_pool': _outside_eng_region_gt_seq_pool.tolist()
+                 }, f)
+
+             ############################################################################
+
+             self.test_step_outputs.clear()
+         else:
+             # all_preds = [b['pred_bfactors'] for b in self.test_step_outputs]
+             # all_gts = [b['gt_bfactors'] for b in self.test_step_outputs]
+             all_preds = [b['pred_flex'] for b in self.test_step_outputs]
+             all_gts = [b['gt_flex'] for b in self.test_step_outputs]
+             all_masks = [b['flex_mask'] for b in self.test_step_outputs]
+
+             max_seq_length = max([pred.size()[1] for pred in all_preds])
+
+             # right-pad every tensor to the longest sequence in the epoch so they can be concatenated
+             for set_of_tensors in [all_preds, all_gts, all_masks]:
+                 for i in range(len(set_of_tensors)):
+                     set_of_tensors[i] = F.pad(set_of_tensors[i], (0, max_seq_length - set_of_tensors[i].shape[1], 0, 0), value=float(0))
+
+             all_preds = torch.cat(all_preds, dim=0)
+             all_gts = torch.cat(all_gts, dim=0)
+             all_masks = torch.cat(all_masks, dim=0)
+             # print(all_preds.shape, all_gts.shape, all_masks.shape)
+             # pearson_R = self.pearson(all_preds[torch.where(all_masks)], all_gts[torch.where(all_masks)])
+             pearson_R, spearman_R = self.avgCorrelations(all_preds, all_gts, all_masks)
+             try:
+                 spearman_R = self.spearman(all_preds[torch.where(all_masks)], all_gts[torch.where(all_masks)])
+             except IndexError:
+                 spearman_R = pearson_R
+             self.log_dict({"test_pearson_R": pearson_R, "test_spearman_R": spearman_R})
+             self.test_step_outputs.clear()  # free memory
+         return super().on_test_epoch_end()
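A note on the rank pooling above: `torch.argsort(torch.argsort(x))` is the standard double-argsort idiom, mapping each element to its rank within the tensor (NaNs are zeroed first so they sort low). A minimal standalone sketch, not part of the repository:

    import torch
    x = torch.tensor([0.3, float('nan'), 0.1, 0.7])
    ranks = torch.argsort(torch.argsort(torch.nan_to_num(x, nan=0)))
    print(ranks)  # tensor([2, 0, 1, 3]) - the zeroed NaN ranks lowest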
+
+     def test_step(self, batch, batch_idx):
+         # Here we just reuse the validation_step for testing
+         # return self.validation_step(batch, batch_idx)
+
+         raw_loss, recovery = self(batch)
+         if type(raw_loss) == dict:
+             # loss = raw_loss['combined_loss']
+             loss = raw_loss['flex_loss'] + raw_loss['seq_loss']  # raw_loss['combined_loss']
+             raw_loss['recovery'] = recovery
+             # pred_bfactors = raw_loss.pop('pred_bfactors')
+             pred_flex = raw_loss.pop('pred_flex')
+             # gt_bfactors = raw_loss.pop('gt_bfactors')
+             gt_flex = raw_loss.pop('gt_flex')
+             flex_mask = raw_loss.pop('flex_mask')
+             epoch_metric_ingredients = {'pred_flex': pred_flex, 'gt_flex': gt_flex, 'flex_mask': flex_mask}
+
+             if self.hparams.test_engineering and self.hparams.use_dynamics:
+                 eng_mask = raw_loss.pop('eng_mask')
+                 original_gt_flex = raw_loss.pop('original_gt_flex')
+                 epoch_metric_ingredients['eng_mask'] = eng_mask
+                 epoch_metric_ingredients['original_gt_flex'] = original_gt_flex
+                 epoch_metric_ingredients['gt_seq'] = raw_loss['gt_seq']
+                 epoch_metric_ingredients['pred_logprobs'] = raw_loss['pred_logprobs']
+                 epoch_metric_ingredients['sequence_recovery'] = raw_loss['recovery']
+                 epoch_metric_ingredients['id'] = batch['title']
+
+             self.test_step_outputs.append(epoch_metric_ingredients)
+             out_dict = {"val_combined_loss": loss,
+                         "val_seq_loss": raw_loss['seq_loss'],
+                         "val_flex_loss": raw_loss['flex_loss'],
+                         "recovery": recovery}
+         else:
+             out_dict = {"val_loss": raw_loss, "recovery": recovery}
+         self.log_dict(out_dict, on_step=True, on_epoch=True, sync_dist=True)
+         # print(out_dict)  # this print statement was masking the issue - ultimately fixed by setting 'on_step=True' above
+         # Below: validation of the correctness of the logging above
+         self.control_sum_batch_sizes += len(batch['X'])
+         self.control_sum_recovery += len(batch['X']) * recovery
+         return out_dict
+
+     def predict_step(self, batch, batch_idx):
+         predict_out = self(batch, mode=self.hparams.stage)
+         return predict_out
+
+     def combined_flex_aware_loss(self, batch, pred_log_probs):
+
+         _mask = batch['mask']
+
+         gt_seq = batch['S']
+         gt_flex = batch['gt_flex']
+         anm_input = batch['enm_vals']  # TODO: manage the loading of the anm input
+
+         # index of the first padding token in each sequence
+         trail_idcs = torch.argmax((batch['S'] == 0).int(), dim=1)
+         trail_idcs[trail_idcs == 0] = batch['S'].shape[1]  # for sequences without padding
+
+         # # #TODO: test on one example - remove later
+         # # trail_idcs = trail_idcs[0].unsqueeze(0)
+
+         # # # ###########################################################################
+         # # # #### TODO: change back to precomputed GT_FLEX once debugged ###############
+         # dl_gtseq = batch['S']
+         # dl_anm = batch['enm_vals']
+
+         # attention_mask = torch.zeros_like(batch['mask'])
+         # for i in range(attention_mask.size(0)):
+         #     attention_mask[i, :trail_idcs[i]] = 1
+
+         # dl_predflex_bs4 = self.flex_model(None, dl_anm, trail_idcs, attention_mask=attention_mask, sampled_pmpnn_sequence=dl_gtseq, alphabet='pmpnn')  # ['predicted_flex'][:,:-1,0]
+         # dl_predflex_bs1 = self.flex_model(None, dl_anm[0].unsqueeze(0), trail_idcs[0].unsqueeze(0), attention_mask=attention_mask[0].unsqueeze(0), sampled_pmpnn_sequence=dl_gtseq[0].unsqueeze(0), alphabet='pmpnn')  # ['predicted_flex'][:,:-1,0]
+
+         # testseq = 'MKKAVINGEQIRSISDLHQTLKKELALPEYYGENLDALWDCLTGWVEYPLVLEWRQFEQSKQLTENGAESVLQVFREAKAEGADITIILS'
+         # tokenizer_predflex_bs4 = self.flex_model(None, dl_anm[0,:90].unsqueeze(0), trail_idcs[0].unsqueeze(0), attention_mask=attention_mask[0,:90].unsqueeze(0), sampled_pmpnn_sequence=testseq, alphabet='aa')  # ['predicted_flex'][:,:-1,0]
+         # import pdb; pdb.set_trace()
+         # input_ids_predflex_bs4 = self.flex_model(dl_gtseq, dl_anm, trail_idcs, attention_mask=attention_mask, sampled_pmpnn_sequence=None, alphabet='aa')  # ['predicted_flex'][:,:-1,0]
+         # gt_flex = batch['gt_flex']
+         # # ####
+         # import pdb; pdb.set_trace()  # check the mask and the gt_flex vs. the on-the-fly computed gt_flex
+         # #TODO: here fix the mask for the prottrans and clean this,
+         # # the mask should have all 1s where there is sequence or eos token
+
+         # attention_mask = ...
+         # if self.hparams.get_gt_flex_onthefly:
+
+         #     cache_keys = list(batch['title'])
+
+         #     # Check if all cache_keys are in self.gt_flex_cache
+         #     all_keys_in_cache = all(cache_key in self.model.gt_flex_cache for cache_key in cache_keys)
+
+         #     if not all_keys_in_cache:
+         #         gt_flex = self.flex_model(None, anm_input, trail_idcs, attention_mask=attention_mask, sampled_pmpnn_sequence=gt_seq, alphabet='pmpnn')['predicted_flex'][:,:-1,0]
+         #         for key, val in zip(cache_keys, gt_flex):
+         #             #TODO: does this iterate correctly?
+         #             self.model.gt_flex_cache[key] = val
+         #     else:
+         #         retrieved_gt_flexs = []
+         #         for key in cache_keys:
+         #             _gt_flex = self.model.gt_flex_cache[key]
+         #             retrieved_gt_flexs.append(_gt_flex)
+         #         gt_flex = torch.cat(retrieved_gt_flexs, dim=0)  # TODO: is the concat correct?
+         # else:
+         #     raise NotImplementedError('The precomputed data were not reliable.')
+         #     gt_flex = batch['gt_flex']
+         # ###########################################################################
+
+         attention_mask = torch.zeros_like(batch['mask'])
+         for i in range(attention_mask.size(0)):
+             attention_mask[i, :trail_idcs[i]] = 1
+
+         # Original sequence loss
+         seq_loss = self.cross_entropy(pred_log_probs.permute(0, 2, 1), gt_seq)
+         seq_loss = (seq_loss * _mask).sum() / (_mask.sum())
+         # New dynamics-aware loss
+         flex_model_input = pred_log_probs.permute(0, 2, 1)
+         pred_flex = self.flex_model(flex_model_input, anm_input, trail_idcs, attention_mask=attention_mask)['predicted_flex'][:, :-1, 0]
+         # check here that the loss function is working properly (with the masking and all)
+         # import pdb; pdb.set_trace()
+         _filter_nans_mask = ~torch.isnan(pred_flex) & ~torch.isnan(gt_flex)
+         flex_loss = self.flex_loss_fn(pred_flex[_filter_nans_mask] * _mask[_filter_nans_mask], gt_flex[_filter_nans_mask] * _mask[_filter_nans_mask])
+         _flex_mask = _mask * _filter_nans_mask
+         _flex_mask = _flex_mask.int()
+         flex_loss = flex_loss.sum() / _flex_mask.sum()
+
+         retVal = {'seq_loss': seq_loss, 'flex_loss': flex_loss, 'pred_flex': pred_flex, 'flex_mask': _flex_mask, 'gt_flex': gt_flex}
+         if self.hparams.test_engineering and self.hparams.use_dynamics:
+             retVal['eng_mask'] = batch['eng_mask']
+             retVal['original_gt_flex'] = batch['original_gt_flex']
+             retVal['gt_seq'] = batch['S']
+             retVal['pred_logprobs'] = pred_log_probs
+         return retVal
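The flex loss above follows the usual masked-mean pattern: compute an elementwise loss with no reduction, zero out invalid positions, and normalize by the count of valid positions. A minimal standalone sketch (shapes and values made up for illustration):

    import torch
    import torch.nn.functional as F

    pred = torch.randn(2, 10)
    target = torch.randn(2, 10)
    mask = (torch.rand(2, 10) > 0.2).float()   # 1 = valid residue, 0 = padding/NaN
    loss = F.mse_loss(pred, target, reduction='none')
    loss = (loss * mask).sum() / mask.sum()    # mean over valid positions only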
+
+
+     def configure_loss(self):
+         def loss_function(pred_angle, angles, pred_seq, seqs, seq_loss_mask, angle_loss_mask):
+             angle_loss = self.MSE(torch.cat([angles[..., :1], torch.sin(angles[..., 1:3]), torch.cos(angles[..., 1:3])], dim=-1),
+                                   torch.cat([pred_angle[..., :1], torch.sin(pred_angle[..., 1:3]), torch.cos(pred_angle[..., 1:3])], dim=-1))
+
+             angle_loss = angle_loss[angle_loss_mask].sum(dim=-1).mean()
+             logits = pred_seq.permute(0, 2, 1)
+             seq_loss = self.cross_entropy(logits, seqs)
+             seq_loss = seq_loss[seq_loss_mask].mean()
+
+             metric = Perplexity()
+             metric.update(pred_seq[seq_loss_mask][None, ...].cpu(), seqs[seq_loss_mask][None, ...].cpu())
+             perp = metric.compute()
+
+             return {"angle_loss": angle_loss, "seq_loss": seq_loss, "perp": perp}
+
+         self.loss_function = loss_function
+
+     def load_model(self):
+         params = OmegaConf.load(f'configs/{self.hparams.model_name}.yaml')
+         params.update(self.hparams)
+
+         if self.hparams.model_name == 'GraphTrans':
+             from src.models.graphtrans_model import GraphTrans_Model
+             self.model = GraphTrans_Model(params)
+
+         if self.hparams.model_name == 'StructGNN':
+             from src.models.structgnn_model import StructGNN_Model
+             self.model = StructGNN_Model(params)
+
+         if self.hparams.model_name == 'GVP':
+             from src.models.gvp_model import GVP_Model
+             self.model = GVP_Model(params)
+
+         if self.hparams.model_name == 'GCA':
+             from src.models.gca_model import GCA_Model
+             self.model = GCA_Model(params)
+
+         if self.hparams.model_name == 'AlphaDesign':
+             from src.models.alphadesign_model import AlphaDesign_Model
+             self.model = AlphaDesign_Model(params)
+
+         if self.hparams.model_name == 'ProteinMPNN':
+             from src.models.proteinmpnn_model import ProteinMPNN_Model
+             self.model = ProteinMPNN_Model(params)
+
+         if self.hparams.model_name == 'ESMIF':
+             pass
+
+         if self.hparams.model_name == 'PiFold':
+             from src.models.pifold_model import PiFold_Model
+             self.model = PiFold_Model(params)
+
+         if self.hparams.model_name == 'KWDesign':
+             from src.models.kwdesign_model import KWDesign_model  # Design_Model
+             self.model = KWDesign_model(params)  # Design_Model(params) - this would have required significantly changing the constructor of Design_Model
+
+         if self.hparams.model_name == 'E3PiFold':
+             from src.models.E3PiFold_model import E3PiFold
+             self.model = E3PiFold(params)
+
+     def load_flex_predictor(self):
+         from src.models.anm_prottrans import ANMAwareFlexibilityProtTrans
+         flex_params = load_yaml_config(f'configs/ANMAwareFlexibilityProtTrans.yaml')
+         # flex_params_dict = OmegaConf.to_container(flex_params, resolve=True)
+         self.flex_model = ANMAwareFlexibilityProtTrans(**flex_params)
+
+         # consider turning on the gradients for debug purposes
+         self.flex_model.eval()
+         for params in self.flex_model.parameters():
+             params.requires_grad = False
+
+         # also pass it to ProteinMPNN:
+         # self.model.flex_model = self.flex_model
+
+     def instancialize(self, Model, **other_args):
+         """ Instancialize a model using the corresponding parameters
+             from the self.hparams dictionary. You can also input any args
+             to overwrite the corresponding values in self.hparams.
+         """
+         # getfullargspec: the older getargspec was removed in Python 3.11
+         class_args = inspect.getfullargspec(Model.__init__).args[1:]
+         inkeys = self.hparams.keys()
+         args1 = {}
+         for arg in class_args:
+             if arg in inkeys:
+                 args1[arg] = getattr(self.hparams, arg)
+         args1.update(other_args)
+         return Model(**args1)
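`instancialize` filters `self.hparams` down to the target constructor's signature via introspection, so unrelated hyperparameters are silently ignored. A standalone sketch of the same idiom (the `Toy` class and `hparams` dict are made up):

    import inspect

    class Toy:
        def __init__(self, hidden_dim, dropout=0.1):
            self.hidden_dim, self.dropout = hidden_dim, dropout

    hparams = {'hidden_dim': 128, 'lr': 1e-3, 'dropout': 0.2}  # 'lr' is ignored
    wanted = inspect.getfullargspec(Toy.__init__).args[1:]
    model = Toy(**{k: hparams[k] for k in wanted if k in hparams})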
Flexpert-Design/predict.py ADDED
@@ -0,0 +1,148 @@
+ import os, sys, warnings, argparse, math, tqdm, datetime
+ import pytorch_lightning as pl
+ import torch
+ from pytorch_lightning.trainer import Trainer
+ import pytorch_lightning.callbacks as plc
+ import pytorch_lightning.loggers as plog
+ from model_interface import MInterface
+ from data_interface import DInterface
+ from src.tools.logger import SetupCallback, BackupCodeCallback
+ from shutil import ignore_patterns
+ from transformers import AutoTokenizer
+ import numpy as np
+ import yaml
+ import wandb
+ warnings.filterwarnings("ignore")
+
+ def create_parser():
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument('--infer_path', type=str, help='Path where to read the data to be predicted and where to save the predictions.')
+
+     # Set-up parameters
+     parser.add_argument('--res_dir', default='./train/results', type=str)
+     parser.add_argument('--ex_name', default='debug', type=str)
+     parser.add_argument('--check_val_every_n_epoch', default=1, type=int)
+     parser.add_argument('--stage', default='predict', type=str)  # 'fit', 'test' or 'predict'
+     parser.add_argument('--val_check_interval', default=0.5, type=float, help='Validation check interval')
+
+     parser.add_argument('--dataset', default='PDBInference')  # AF2DB_dataset, CATH_dataset
+     parser.add_argument('--model_name', default='ProteinMPNN', choices=['StructGNN', 'GraphTrans', 'GVP', 'GCA', 'AlphaDesign', 'ESMIF', 'PiFold', 'ProteinMPNN', 'KWDesign', 'E3PiFold'])
+     # parser.add_argument('--lr', default=4e-4, type=float, help='Learning rate')
+     # parser.add_argument('--lr_scheduler', default='onecycle')
+     # parser.add_argument('--offline', default=1, type=int)
+     parser.add_argument('--seed', default=111, type=int)
+
+     parser.add_argument('--num_workers', default=12, type=int)
+     parser.add_argument('--pad', default=1024, type=int)
+     parser.add_argument('--min_length', default=40, type=int)
+     parser.add_argument('--data_root', default='./data/')
+
+     # Training parameters
+     # parser.add_argument('--epoch', default=10, type=int, help='end epoch')
+     parser.add_argument('--augment_eps', default=0.0, type=float, help='noise level')
+     # parser.add_argument('--gpus', default=1, type=int, help='how many GPUs to train on')
+     # parser.add_argument('--weight_decay', default=0.0, type=float, help='Weight decay for optimizer')
+
+     # # Eval parameters
+     # parser.add_argument('--eval_sequences_sampled', default=1, type=int, help='How many sequences to sample in evaluation.')
+     # parser.add_argument('--eval_sequences_temperature', default=0, type=float, help='What temperature to use for the sampling in evaluation.')
+     # parser.add_argument('--eval_output_dir', default=None, type=str, help='Where to save the evaluation output.')
+
+     # Model parameters
+     parser.add_argument('--use_dist', default=1, type=int)
+     parser.add_argument('--use_product', default=0, type=int)
+     parser.add_argument('--use_pmpnn_checkpoint', default=1, type=int, help='By 1 or 0 decide whether to start with pretrained ProteinMPNN.')
+     parser.add_argument('--checkpoint_path', type=str, default=None, help='Path to the model checkpoint to load weights from')
+
+     # Dynamics-aware parameters
+     parser.add_argument('--use_dynamics', default=0, type=int)
+     # parser.add_argument('--flex_loss_coeff', default=0.5, type=float)
+     # parser.add_argument('--get_gt_flex_onthefly', default=0, type=int, help='Flag to get ground truth flexibility on-the-fly (with subsequent caching)')
+     parser.add_argument('--init_flex_features', default=1, type=int, help="Set to 0 if no flexibility information should be passed on input to the node features h_V")
+     # parser.add_argument('--loss_fn', default='MSE', type=str, help='Define what loss to use. Choose MSE, L1 or DPO.')
+     # parser.add_argument('--grad_normalization', default=1, type=int, help="Set to 0 if the gradients of the seq and flex losses should not be normalized.")
+     # parser.add_argument('--test_engineering', default=0, type=int, help="In this main.py should be set to 0 to not overwrite the training dataset.")
+
+     args = parser.parse_args()
+     return args
+
+
+ if __name__ == "__main__":
+
+     args = create_parser()
+     args.batch_size = 1
+     print('In the predict stage, defaulting batch size to 1.')
+
+     assert args.use_dynamics == 0, "In the inference script this should be set to 0."
+
+     # --infer_path and --dataset 'PDBInference' must be used together (and only together);
+     # checked before the makedirs call so a missing --infer_path fails with a clear error
+     if bool(args.infer_path) != (args.dataset == 'PDBInference'):
+         raise ValueError("You should only use --infer_path with --dataset 'PDBInference' and vice versa.")
+
+     if not os.path.exists(args.infer_path):
+         os.makedirs(args.infer_path)
+
+     # Load model weights from checkpoint if provided
+     if args.checkpoint_path is not None:
+         trained_model_path = args.checkpoint_path
+         print(f"Loading model weights from checkpoint passed by argument: {trained_model_path}")
+     else:
+         with open('configs/Flexpert-Design-inference.yaml', 'r') as f:
+             config = yaml.load(f, Loader=yaml.FullLoader)
+         trained_model_path = config['pmpnn_model_path']
+         print(f"Loading model weights from checkpoint specified in Flexpert-Design-inference.yaml: {trained_model_path}")
+
+     if os.path.exists(trained_model_path):
+         print(f"Rewriting the path to the Flexpert-Design trained ProteinMPNN weights in the model interface.")
+         args.starting_checkpoint_path = trained_model_path
+     else:
+         raise FileNotFoundError(f"Checkpoint file not found at {trained_model_path}")
+
+     pl.seed_everything(args.seed)
+
+     data_module = DInterface(**vars(args))
+
+     data_module.setup(stage='predict')
+
+     model = MInterface(**vars(args))
+
+     trainer_config = {
+         'devices': 1,
+         'max_epochs': 1,
+         'num_nodes': 1,
+         'strategy': 'ddp',
+         'precision': '32',
+         'accelerator': 'gpu',
+         'val_check_interval': args.val_check_interval,
+         'check_val_every_n_epoch': args.check_val_every_n_epoch
+     }
+
+     trainer = Trainer(**trainer_config)
+
+     predictions = trainer.predict(model, data_module)
+
+     tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir='./cache_dir/')  # mask token: 32
+
+     serializable_predictions = []
+     for pred_idx, pred in enumerate(predictions):
+         logprobs = pred['log_probs'].cpu().numpy()[0]  # [L, 21]
+         pmpnn_alphabet_tokens_argmax = logprobs.argmax(axis=-1)  # [L]
+
+         aa_sequence = ''.join(tokenizer.decode(pmpnn_alphabet_tokens_argmax, skip_special_tokens=True).split())
+
+         # Get the probability of the predicted sequence (currently informative only, not written out)
+         seq_probs = np.exp(logprobs.max(axis=-1))  # [L]
+         avg_prob = float(np.mean(seq_probs))
+
+         serializable_predictions.append({
+             'prediction_id': pred['batch']['title'][0],
+             'amino_acid_sequence': aa_sequence
+         })
+
+     with open(f'{args.infer_path}/predictions.txt', 'w') as f:
+         for pred in serializable_predictions:
+             f.write(f'>{pred["prediction_id"]}\n{pred["amino_acid_sequence"]}\n')
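The write-out above assumes the model's output logits are ordered consistently with the ESM-2 tokenizer vocabulary, so an argmax over the last axis can be decoded directly. A standalone sketch of the argmax-and-decode step (the logits array here is random, for illustration only):

    import numpy as np
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    logprobs = np.random.rand(5, len(tokenizer))  # fake [L, vocab] scores
    token_ids = logprobs.argmax(axis=-1)          # [L]
    seq = ''.join(tokenizer.decode(token_ids, skip_special_tokens=True).split())
    print(seq)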
Flexpert-Design/predict_example/1ah7_A.pdb ADDED
The diff for this file is too large to render. See raw diff
 
Flexpert-Design/predict_example/1ah7_A_instructions.csv ADDED
@@ -0,0 +1 @@
+ 0.4214228391647339, 0.37651416659355164, 0.1882496476173401, 0.13774731755256653, 0.11560429632663727, 0.12345632910728455, 0.11075370758771896, 0.09350624680519104, 0.06162628158926964, 0.08504123985767365, 0.05511573329567909, 0.03457929939031601, 0.018956221640110016, 0.05267956107854843, 0.021582268178462982, 0.019682325422763824, 0.005200381390750408, 0.01862833835184574, 0.037708550691604614, 0.02962341532111168, 0.0414130762219429, 0.032966580241918564, 0.04219468683004379, 0.043324653059244156, 0.038419052958488464, 0.06062019616365433, 0.0754077360033989, 0.09575366973876953, 0.09765047580003738, 0.1067374125123024, 0.08417803794145584, 0.09050130844116211, 0.07099245488643646, 0.06242087855935097, 0.046906158328056335, 0.024977944791316986, 0.04039282351732254, 0.04056069627404213, 0.02624698355793953, 0.014836937189102173, 0.033674903213977814, 0.03443623706698418, 0.04525991156697273, 0.05213414505124092, 0.02986733242869377, 0.01742064766585827, 0.03752005845308304, 0.02649688348174095, 0.02672465145587921, 0.03430021554231644, 0.011848143301904202, 0.03361964225769043, 0.027863629162311554, 0.03575271740555763, 0.041042227298021317, 0.08238421380519867, 0.08222152292728424, 0.10173829644918442, 0.11664807796478271, 0.13249793648719788, 0.14384658634662628, 0.1345130354166031, 0.13609081506729126, 0.12496572732925415, 0.10717709362506866, 0.08230947703123093, 0.07971317321062088, 0.07025592774152756, 0.06319792568683624, 0.06464961171150208, 0.04482023045420647, 0.051742203533649445, 0.07986844331026077, 0.09591078013181686, 0.11425718665122986, 0.11205209791660309, 0.10624780505895615, 0.10313349217176437, 0.13002970814704895, 0.13183605670928955, 0.15288424491882324, 0.14854931831359863, 0.13990001380443573, 0.09912189096212387, 0.09130637347698212, 0.07575594633817673, 0.061887726187705994, 0.06014912575483322, 0.0577777624130249, 0.051302842795848846, 0.03530939295887947, 0.040248580276966095, 0.0013590790331363678, 0.015310827642679214, 0.03272499516606331, 0.02609187364578247, 0.0054176910780370235, 0.05427498370409012, 0.051064278930425644, 0.06116481125354767, 0.06309916824102402, 0.08470715582370758, 0.1002785935997963, 0.10495362430810928, 0.09807638078927994, 0.0662725567817688, 0.06513857841491699, 0.048988863825798035, 0.029838263988494873, 0.025865966454148293, 0.02097484841942787, 0.014891650527715683, 0.024081528186798096, 0.045654937624931335, 0.052093200385570526, 0.017663143575191498, 0.02189275622367859, 0.08543915301561356, 0.03505314886569977, 0.019413039088249207, 0.045589953660964966, 0.06793230772018433, 0.041016142815351486, 0.05003933981060982, 0.053235944360494614, 0.05916681885719299, 0.058036819100379944, 0.060588233172893524, 0.07040168344974518, 0.07345925271511078, 0.08298061043024063, 0.09419634193181992, 0.11146273463964462, 0.1405934989452362, 0.15075145661830902, 0.13765454292297363, 0.13978315889835358, 0.1482282280921936, 0.1423584669828415, 0.10484395921230316, 0.07584157586097717, 0.06757079809904099, 0.10134144872426987, 0.08083963394165039, 0.07369125634431839, 0.05454648658633232, 0.08305331319570541, 0.07765821367502213, 0.06511223316192627, 0.056034114211797714, 0.08081831783056259, 0.08526752144098282, 0.07231731712818146, 0.07028429210186005, 0.08094073832035065, 0.06563611328601837, 0.07806897908449173, 0.0859430730342865, 0.08600828796625137, 0.08605027943849564, 0.08578154444694519, 0.07862624526023865, 0.07963275909423828, 0.06170313060283661, 0.05005127564072609, 0.05146761238574982, 0.05499078333377838, 
0.059220947325229645, 0.06969373673200607, 0.05268307402729988, 0.06721088290214539, 0.04827176779508591, 0.029251907020807266, 0.04153018817305565, 0.03697451949119568, 0.025836892426013947, 0.04521643742918968, 0.05554598197340965, 0.06007472425699234, 0.04923863708972931, 0.06502534449100494, 0.04392743483185768, 0.036296453326940536, 0.0436030775308609, 0.05658774450421333, 0.034551363438367844, 0.049478061497211456, 0.059964731335639954, 0.07313965260982513, 0.062442418187856674, 0.06896451860666275, 0.08025174587965012, 0.08270157873630524, 0.09252781420946121, 0.09688305854797363, 0.11343701928853989, 0.11080081015825272, 0.12969090044498444, 0.10972093790769577, 0.14756494760513306, 0.1637764275074005, 0.18948377668857574, 0.18522633612155914, 0.18577809631824493, 0.20364972949028015, 0.17919236421585083, 0.1657918244600296, 0.1515847146511078, 0.11915461719036102, 0.10438721626996994, 0.10422837734222412, 0.08969437330961227, 0.07429874688386917, 0.07801567018032074, 0.06531910598278046, 0.05813637375831604, 0.04699350893497467, 0.05086237192153931, 0.060301560908555984, 0.04986414313316345, 0.050366174429655075, 0.05464963987469673, 0.05319518595933914, 0.04274186119437218, 0.047863349318504333, 0.036163944751024246, 0.03777360916137695, 0.04280579090118408, 0.04440606012940407, 0.04717888683080673, 0.02609632909297943, 0.05858827754855156, 0.050790246576070786, 0.03004802018404007, 0.04584358632564545, 0.05146845430135727, 0.039567168802022934, 0.03470978885889053, 0.045542243868112564, 0.05142106115818024, 0.05224252864718437, 0.07936417311429977, 0.11145134270191193, 0.14930342137813568, 0.21797531843185425
Flexpert-Design/predict_example/compare_seqs.py ADDED
@@ -0,0 +1,59 @@
+ # Read in the predictions.txt file and take the predicted sequence from there.
+ import argparse
+ import os
+ import biotite.structure.io.pdb as pdb
+ from biotite.structure import get_residues
+
+ def compare_sequences(pdb_code):
+     # Read the predicted sequence from predictions.txt
+     with open('predict_example/predictions.txt', 'r') as f:
+         predictions = f.readlines()
+     # Extract the sequences (header lines start with '>')
+     predicted_seqs = {}
+     current_pdb = None
+
+     for line in predictions:
+         if line.startswith('>'):
+             current_pdb = line.strip()[1:]  # Remove the '>' character
+         elif current_pdb and line.strip():
+             predicted_seqs[current_pdb] = line.strip()
+
+     # Use the provided pdb_code to get the corresponding sequence
+     predicted_seq = predicted_seqs.get(pdb_code, "")
+
+     # Read the PDB file
+     pdb_file = f'predict_example/{pdb_code}.pdb'
+     with open(pdb_file, 'r') as f:
+         structure = pdb.PDBFile.read(f)
+     atoms = pdb.get_structure(structure)
+
+     # Get residue names from the structure
+     residues = get_residues(atoms)[1]
+     # Convert three-letter codes to one-letter codes
+     aa_dict = {
+         'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
+         'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
+         'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
+         'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y'
+     }
+     pdb_seq = ''.join([aa_dict.get(res, 'X') for res in residues])
+
+     # Compare the two sequences position by position (identity over the shorter length)
+     match_count = sum(1 for a, b in zip(predicted_seq, pdb_seq) if a == b)
+     total_length = max(len(predicted_seq), len(pdb_seq))
+     percent_identity = (match_count / min(len(predicted_seq), len(pdb_seq))) * 100
+
+     # Print the result
+     print(f"Predicted sequence: {predicted_seq}")
+     print(f"PDB sequence: {pdb_seq}")
+     print(f"Sequence length - Predicted: {len(predicted_seq)}, PDB: {len(pdb_seq)}")
+     print(f"Matching residues: {match_count}/{min(len(predicted_seq), len(pdb_seq))}")
+     print(f"Percent identity: {percent_identity:.2f}%")
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description='Compare predicted sequence with PDB sequence')
+     parser.add_argument('--pdb_code', type=str, help='PDB code (e.g., 1ah7_A)')
+     args = parser.parse_args()
+
+     compare_sequences(args.pdb_code)
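Assuming the example files above, the script would be run from the `Flexpert-Design` directory, since the paths inside it are relative to that root:

    python predict_example/compare_seqs.py --pdb_code 1ah7_A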
Flexpert-Design/predict_example/predictions.txt ADDED
@@ -0,0 +1,2 @@
+ >1ah7_A
+ GSSLDKTEVEESTGLRLVNQAIDILKNDKTRVDKEYLDLIEKYKPELQEGIYKAYHSEPYNDNGKFSRHYYNPVVHTSRIPDAVTAAETGSHYYNKAGEYYKKGDYEEAYFYLGIALAYLSDACNPMNASGYTNESFPEGFYEALQKYVCTIAKKYENTTGEPYYNLTGKNPKDHIRGAATKARELFSGIYHERVKEDFEKGKTSEEARLKWRERIEPQLGKLLLFAQRVMAGAIERFFDTAGGL
Flexpert-Design/requirements.txt ADDED
@@ -0,0 +1,23 @@
+ Bio
+ biotite
+ fair-esm
+ evaluate
+ joblib
+ matplotlib
+ numpy
+ omegaconf
+ packaging
+ pandas
+ python_dateutil
+ pytorch_lightning
+ PyYAML
+ requests
+ safetensors
+ scikit_learn
+ scipy
+ torch
+ torch_geometric
+ torcheval
+ torchmetrics
+ tqdm
+ transformers
Flexpert-Design/src/__init__.py ADDED
@@ -0,0 +1,49 @@
+ # Copyright (c) CAIRI AI Lab. All rights reserved
+
+ import warnings
+ from packaging.version import parse
+
+ from .version import __version__
+
+
+ def digit_version(version_str: str, length: int = 4):
+     """Convert a version string into a tuple of integers.
+
+     This method is usually used for comparing two versions. For pre-release
+     versions: alpha < beta < rc.
+
+     Args:
+         version_str (str): The version string.
+         length (int): The maximum number of version levels. Default: 4.
+
+     Returns:
+         tuple[int]: The version info in digits (integers).
+     """
+     version = parse(version_str)
+     assert version.release, f'failed to parse version {version_str}'
+     release = list(version.release)
+     release = release[:length]
+     if len(release) < length:
+         release = release + [0] * (length - len(release))
+     if version.is_prerelease:
+         mapping = {'a': -3, 'b': -2, 'rc': -1}
+         val = -4
+         # version.pre can be None
+         if version.pre:
+             if version.pre[0] not in mapping:
+                 warnings.warn(f'unknown prerelease version {version.pre[0]}, '
+                               'version checking may go wrong')
+             else:
+                 val = mapping[version.pre[0]]
+             release.extend([val, version.pre[-1]])
+         else:
+             release.extend([val, 0])
+
+     elif version.is_postrelease:
+         release.extend([1, version.post])
+     else:
+         release.extend([0, 0])
+     return tuple(release)
+
+
+ __all__ = ['__version__', 'digit_version']
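A quick illustration of the ordering `digit_version` produces (values follow directly from the function above; pre-releases compare below the final release):

    print(digit_version('1.2.3'))     # (1, 2, 3, 0, 0, 0)
    print(digit_version('1.2.3rc1'))  # (1, 2, 3, 0, -1, 1)
    print(digit_version('1.2.3rc1') < digit_version('1.2.3'))  # True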
Flexpert-Design/src/datasets/__init__.py ADDED
@@ -0,0 +1,15 @@
+ # Copyright (c) CAIRI AI Lab. All rights reserved
+
+ from .alphafold_dataset import AlphaFoldDataset
+ from .cath_dataset import CATHDataset
+ from .dataloader import load_data
+ from .featurizer import (featurize_AF, featurize_GTrans, featurize_GVP,
+                          featurize_ProteinMPNN, featurize_Inversefolding)
+ from .ts_dataset import TSDataset
+
+ __all__ = [
+     'AlphaFoldDataset', 'CATHDataset', 'TSDataset',
+     'load_data',
+     'featurize_AF', 'featurize_GTrans', 'featurize_GVP',
+     'featurize_ProteinMPNN', 'featurize_Inversefolding'
+ ]
Flexpert-Design/src/datasets/alphafold_dataset.py ADDED
@@ -0,0 +1,112 @@
+ import os
+ import os.path as osp
+ import json
+ import numpy as np
+ import pickle as cPickle
+
+ import torch.utils.data as data
+ from src.datasets.utils import cached_property
+
+
+ class AlphaFoldDataset(data.Dataset):
+     def __init__(self, path='./', upid='', mode='train', max_length=500, limit_length=1, joint_data=0):
+
+         self.path = path
+         self.upid = upid
+         self.max_length = max_length
+         self.limit_length = limit_length
+         self.joint_data = joint_data
+
+         if mode in ['train', 'valid', 'test']:
+             self.data = self.cache_data[mode]
+
+         if mode == 'all':
+             self.data = self.cache_data['train'] + self.cache_data['valid'] + self.cache_data['test']
+
+         self.lengths = np.array([len(sample['seq']) for sample in self.data])
+         self.max_len = np.max(self.lengths)
+         self.min_len = np.min(self.lengths)
+
+     def _raw_data(self, path, upid):
+         if not os.path.exists(path):
+             # raising a plain string is a TypeError in Python 3; use a proper exception
+             raise FileNotFoundError("no such file: {} !!!".format(path))
+         else:
+             path = osp.join(path, upid)
+             data_ = cPickle.load(open(path + '/data_{}.pkl'.format(upid), 'rb'))
+             score_ = cPickle.load(open(path + '/data_{}_score.pkl'.format(upid), 'rb'))
+             for i in range(len(data_)):
+                 data_[i]['score'] = score_[i]['res_score']
+             return data_
+
+     def _data_info(self, data):
+         len_inds = []
+         seq2ind = {}
+         for ind, temp in enumerate(data):
+             if self.limit_length:
+                 if 30 < len(temp['seq']) and len(temp['seq']) < self.max_length:
+                     # 'title', 'seq', 'CA', 'C', 'O', 'N'
+                     len_inds.append(ind)
+                     seq2ind[temp['seq']] = ind
+             else:
+                 len_inds.append(ind)
+                 seq2ind[temp['seq']] = ind
+         return len_inds, seq2ind
+
+     def get_data(self, path, upid, **kwargs):
+         data_ = self._raw_data(path, upid)
+         path = osp.join(path, upid)
+
+         file_name = 'split_clu_l.json' if self.limit_length else 'split_clu.json'
+
+         assert os.path.exists(osp.join(path, file_name))
+         split = json.load(open(osp.join(path, file_name), 'r'))
+         data_dict = {'train': [data_[i] for i in split['train']],
+                      'valid': [data_[i] for i in split['valid']],
+                      'test': [data_[i] for i in split['test']]}
+         return data_dict
+
+     def get_full_data(self, path, **kwargs):
+         datanames = [dataname for dataname in os.listdir(path) if ('_v2' in dataname)]
+         file_name = 'split_clu_l.json' if self.limit_length else 'split_clu.json'
+         assert os.path.exists(osp.join(path, 'full', file_name))
+         split = json.load(open(osp.join(path, 'full', file_name), 'r'))
+         return split
+
+     @cached_property
+     def cache_data(self):  # TODO: joint_data
+         path = self.path
+         upid = self.upid
+         if self.joint_data:
+             datanames = [dataname for dataname in os.listdir(path) if ('_v2' in dataname)]
+             data_dict = {'train': [], 'valid': [], 'test': []}
+             full_inds = self.get_full_data(path)
+
+             for dataname in datanames:
+                 temp = self._raw_data(path, dataname)
+                 train_idx, valid_idx, test_idx = map(lambda fold: full_inds[dataname][fold], ['train', 'valid', 'test'])
+                 data_dict['train'] += [temp[i] for i in train_idx]
+                 data_dict['valid'] += [temp[i] for i in valid_idx]
+
+                 data_test = []
+                 for i in test_idx:
+                     item = temp[i]
+                     item['category'] = dataname
+                     data_test.append(temp[i])
+
+                 data_dict['test'] += data_test
+
+         else:
+             data_dict = self.get_data(path, upid)
+             for item in data_dict['test']:
+                 item['category'] = upid
+
+         return data_dict
+
+     def change_mode(self, mode):
+         self.data = self.cache_data[mode]
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, index):
+         return self.data[index]
Flexpert-Design/src/datasets/atlas_dataset.py ADDED
@@ -0,0 +1,133 @@
+ import os
+ import json
+ import numpy as np
+ from tqdm import tqdm
+ import random
+ import pdb
+ import torch.utils.data as data
+ from .utils import cached_property
+ from transformers import AutoTokenizer
+
+ class AtlasDataset(data.Dataset):
+     def __init__(self, path='./', split='train', max_length=500, test_name='All', data=None, removeTS=0):
+         self.path = path
+         self.mode = split
+         self.max_length = max_length
+         self.test_name = test_name
+         self.removeTS = removeTS
+         if self.removeTS:
+             self.remove = json.load(open(self.path + '/remove.json', 'r'))['remove']
+
+         if data is None:
+             if self.mode in ['eval', 'predict']:
+                 self.data = self.cache_data['test']  # this calls the cache_data property
+             else:
+                 self.data = self.cache_data[split]  # this calls the cache_data property
+         else:
+             self.data = data
+
+         self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="./cache_dir/")
+
+     @cached_property
+     def cache_data(self):
+         alphabet = 'ACDEFGHIKLMNPQRSTVWY'
+         alphabet_set = set([a for a in alphabet])
+         print("path is: ", self.path)
+
+         if not os.path.exists(self.path):
+             raise FileNotFoundError("no such file: {} !!!".format(self.path))
+         else:
+
+             with open(self.path + '/chain_set.jsonl') as f:
+                 lines = f.readlines()
+             data_list = []
+
+             for line in tqdm(lines):
+                 entry = json.loads(line)
+
+                 if self.removeTS and entry['name'] in self.remove:
+                     continue
+                 seq = entry['seq']
+
+                 for key, val in entry['coords'].items():
+                     entry['coords'][key] = np.asarray(val)
+
+                 bad_chars = set([s for s in seq]).difference(alphabet_set)
+
+                 if len(bad_chars) == 0:
+                     if len(entry['seq']) <= self.max_length:
+                         chain_length = len(entry['seq'])
+                         chain_mask = np.ones(chain_length)
+                         data_list.append({
+                             'title': entry['name'],
+                             'seq': entry['seq'],
+                             'CA': entry['coords']['CA'],
+                             'C': entry['coords']['C'],
+                             'O': entry['coords']['O'],
+                             'N': entry['coords']['N'],
+                             'chain_mask': chain_mask,
+                             'chain_encoding': 1 * chain_mask
+                         })
+
+             with open(self.path + '/chain_set_splits.json') as f:
+                 dataset_splits = json.load(f)
+
+             if self.test_name == 'L100':
+                 with open(self.path + '/test_split_L100.json') as f:
+                     test_splits = json.load(f)
+                 dataset_splits['test'] = test_splits['test']
+
+             if self.test_name == 'sc':
+                 with open(self.path + '/test_split_sc.json') as f:
+                     test_splits = json.load(f)
+                 dataset_splits['test'] = test_splits['test']
+
+             name2set = {}
+             name2set.update({name: 'train' for name in dataset_splits['train']})
+             name2set.update({name: 'valid' for name in dataset_splits['validation']})
+             name2set.update({name: 'test' for name in dataset_splits['test']})
+
+             data_dict = {'train': [], 'valid': [], 'test': []}
+             for data in data_list:
+                 # raise ValueError("only 1015 sequences get loaded to the predict set!!! why not the whole 1390??? FIX!")
+                 if name2set.get(data['title']):  # this was causing the trouble with empty datasets - mismatch of names between chain_set and chain_set_splits
+                     if name2set[data['title']] == 'train':
+                         data_dict['train'].append(data)
+
+                     if name2set[data['title']] == 'valid':
+                         data_dict['valid'].append(data)
+
+                     if name2set[data['title']] == 'test':
+                         data['category'] = 'Unkown'
+                         data['score'] = 100.0
+                         data_dict['test'].append(data)
+                 else:
+                     import pdb; pdb.set_trace()  # NOTE: debugging trap left in place; fires whenever a chain name is missing from the splits
+             return data_dict
+
+     def change_mode(self, mode):
+         self.data = self.cache_data[mode]
+
+     def __len__(self):
+         return len(self.data)
+
+     def get_item(self, index):
+         return self.data[index]
+
+     def __getitem__(self, index):
+         item = self.data[index]
+         L = len(item['seq'])
+         if L > self.max_length:
+             # compute the maximum truncation start index
+             max_index = L - self.max_length
+             # pick a random truncation start index
+             truncate_index = random.randint(0, max_index)
+             # truncate to a window of max_length
+             item['seq'] = item['seq'][truncate_index:truncate_index + self.max_length]
+             item['CA'] = item['CA'][truncate_index:truncate_index + self.max_length]
+             item['C'] = item['C'][truncate_index:truncate_index + self.max_length]
+             item['O'] = item['O'][truncate_index:truncate_index + self.max_length]
+             item['N'] = item['N'][truncate_index:truncate_index + self.max_length]
+             item['chain_mask'] = item['chain_mask'][truncate_index:truncate_index + self.max_length]
+             item['chain_encoding'] = item['chain_encoding'][truncate_index:truncate_index + self.max_length]
+         return item
Flexpert-Design/src/datasets/casp_dataset.py ADDED
@@ -0,0 +1,57 @@
+ import os
+ import json
+ import numpy as np
+ import torch.utils.data as data
+
+
+ class CASPDataset(data.Dataset):
+     def __init__(self, path='./', split='test'):
+         if not os.path.exists(path):
+             raise FileNotFoundError("no such file: {} !!!".format(path))
+         else:
+             with open(os.path.join(path, 'casp15.jsonl')) as f:
+                 lines = f.readlines()
+
+         # casp15_data = json.load(open(path+'casp15.json', 'r'))
+
+         alphabet = 'ACDEFGHIKLMNPQRSTVWY'
+         alphabet_set = set([a for a in alphabet])
+
+         self.data = []
+         for line in lines:
+             entry = json.loads(line)
+             seq = entry['seq']
+
+             for key, val in entry['coords'].items():
+                 entry['coords'][key] = np.asarray(val)
+
+             bad_chars = set([s for s in seq]).difference(alphabet_set)
+
+             if len(bad_chars) == 0:
+                 chain_length = len(entry['seq'])
+                 chain_mask = np.ones(chain_length)
+                 self.data.append({
+                     'title': entry['name'],
+                     'seq': entry['seq'],
+                     'CA': entry['coords']['CA'],
+                     'C': entry['coords']['C'],
+                     'O': entry['coords']['O'],
+                     'N': entry['coords']['N'],
+                     'chain_mask': chain_mask,
+                     'chain_encoding': 1 * chain_mask,
+                     'classification': entry['classification']
+                 })
+
+     def __len__(self):
+         return len(self.data)
+
+     def get_item(self, index):
+         return self.data[index]
+
+     def __getitem__(self, index):
+         return self.data[index]
+
+ if __name__ == '__main__':
+     dataset = CASPDataset('/gaozhangyang/experiments/OpenCPD/data/casp15/')
+     for data in dataset:
+         print(data)
Flexpert-Design/src/datasets/cath_dataset.py ADDED
@@ -0,0 +1,141 @@
+ import os
+ import json
+ import numpy as np
+ from tqdm import tqdm
+ import random
+ import torch.utils.data as data
+ from .utils import cached_property
+ from transformers import AutoTokenizer
+ from src.tools.utils import load_yaml_config
+
+ class CATHDataset(data.Dataset):
+     def __init__(self, path='./', split='train', max_length=500, test_name='All', data=None, removeTS=0, version=4.2, data_jsonl_name='/chain_set.jsonl'):
+         self.version = version
+         self.path = path
+         self.mode = split
+         self.max_length = max_length
+         self.test_name = test_name
+         self.removeTS = removeTS
+         self.data_jsonl_name = data_jsonl_name
+
+         # NOTE: hardcoded absolute config path inherited from the authors' cluster setup
+         self.using_dynamics = data_jsonl_name == load_yaml_config('/scratch/project/fta-24-31/koubapet/ProteinInvBench/src/models/configs/FlexibilityProtTrans.yaml')['data_jsonl_name']
+
+         print(self.data_jsonl_name)
+         if self.removeTS:
+             self.remove = json.load(open(self.path + '/remove.json', 'r'))['remove']
+
+         if data is None:
+             if split == 'predict':
+                 _split = 'valid'
+                 print('In predict mode for CATH4.3 using VALIDATION split as the data. Consider switching to TEST set.')
+             else:
+                 _split = split
+             self.data = self.cache_data[_split]
+         else:
+             self.data = data
+
+         self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="./cache_dir/")
+
+     @cached_property
+     def cache_data(self):
+         alphabet = 'ACDEFGHIKLMNPQRSTVWY'
+         alphabet_set = set([a for a in alphabet])
+         print("path is: ", self.path)
+         if not os.path.exists(self.path):
+             raise FileNotFoundError("no such file: {} !!!".format(self.path))
+         else:
+             with open(self.path + '/' + self.data_jsonl_name) as f:
+                 lines = f.readlines()
+             data_list = []
+             for line in tqdm(lines):
+                 entry = json.loads(line)
+                 if self.removeTS and entry['name'] in self.remove:
+                     continue
+                 seq = entry['seq']
+
+                 for key, val in entry['coords'].items():
+                     entry['coords'][key] = np.asarray(val)
+
+                 bad_chars = set([s for s in seq]).difference(alphabet_set)
+
+                 if len(bad_chars) == 0:
+                     if len(entry['seq']) <= self.max_length:
+                         chain_length = len(entry['seq'])
+                         chain_mask = np.ones(chain_length)
+                         data_list.append({
+                             'title': entry['name'],
+                             'seq': entry['seq'],
+                             'CA': entry['coords']['CA'],
+                             'C': entry['coords']['C'],
+                             'O': entry['coords']['O'],
+                             'N': entry['coords']['N'],
+                             'chain_mask': chain_mask,
+                             'chain_encoding': 1 * chain_mask
+                         })
+                         if self.using_dynamics:  # TODO: pass this bool properly
+                             data_list[-1]['norm_bfactors'] = entry['bfactor']
+
+             if self.version == 4.2:
+                 with open(self.path + '/chain_set_splits.json') as f:
+                     dataset_splits = json.load(f)
+
+             if self.version == 4.3:
+                 with open(self.path + '/chain_set_splits.json') as f:
+                     dataset_splits = json.load(f)
+
+             if self.test_name == 'L100':
+                 with open(self.path + '/test_split_L100.json') as f:
+                     test_splits = json.load(f)
+                 dataset_splits['test'] = test_splits['test']
+
+             if self.test_name == 'sc':
+                 with open(self.path + '/test_split_sc.json') as f:
+                     test_splits = json.load(f)
+                 dataset_splits['test'] = test_splits['test']
+
+             name2set = {}
+             name2set.update({name: 'train' for name in dataset_splits['train']})
+             name2set.update({name: 'valid' for name in dataset_splits['validation']})
+             name2set.update({name: 'test' for name in dataset_splits['test']})
+
+             data_dict = {'train': [], 'valid': [], 'test': []}
+             for data in data_list:
+                 if name2set.get(data['title']):
+                     if name2set[data['title']] == 'train':
+                         data_dict['train'].append(data)
+
+                     if name2set[data['title']] == 'valid':
+                         data_dict['valid'].append(data)
+
+                     if name2set[data['title']] == 'test':
+                         data['category'] = 'Unkown'
+                         data['score'] = 100.0
+                         data_dict['test'].append(data)
+             return data_dict
+
+     def change_mode(self, mode):
+         self.data = self.cache_data[mode]
+
+     def __len__(self):
+         return len(self.data)
+
+     def get_item(self, index):
+         return self.data[index]
+
+     def __getitem__(self, index):
+         item = self.data[index]
+         L = len(item['seq'])
+         if L > self.max_length:
+             # compute the maximum truncation start index
+             max_index = L - self.max_length
+             # pick a random truncation start index
+             truncate_index = random.randint(0, max_index)
+             # truncate to a window of max_length
+             item['seq'] = item['seq'][truncate_index:truncate_index + self.max_length]
+             item['CA'] = item['CA'][truncate_index:truncate_index + self.max_length]
+             item['C'] = item['C'][truncate_index:truncate_index + self.max_length]
+             item['O'] = item['O'][truncate_index:truncate_index + self.max_length]
+             item['N'] = item['N'][truncate_index:truncate_index + self.max_length]
+             item['chain_mask'] = item['chain_mask'][truncate_index:truncate_index + self.max_length]
+             item['chain_encoding'] = item['chain_encoding'][truncate_index:truncate_index + self.max_length]
+         return item
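`__getitem__` above crops over-long chains to a random `max_length` window, so different epochs can see different parts of the same protein. A standalone sketch of the cropping logic:

    import random
    seq, max_length = 'A' * 700, 500
    start = random.randint(0, len(seq) - max_length)
    crop = seq[start:start + max_length]
    assert len(crop) == max_length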
Flexpert-Design/src/datasets/dataloader.py ADDED
@@ -0,0 +1,161 @@
+ import copy
+ import random
+ import os.path as osp
+
+ import torch
+ import torch.utils.data as data
+ import pdb
+
+ from .cath_dataset import CATHDataset
+ from .alphafold_dataset import AlphaFoldDataset
+ from .ts_dataset import TSDataset
+ from .casp_dataset import CASPDataset
+ from .mpnn_dataset import MPNNDataset
+ from .featurizer import (featurize_AF, featurize_GTrans, featurize_GVP,
+                          featurize_ProteinMPNN, featurize_Inversefolding)
+ from .fast_dataloader import DataLoaderX
+
+ class GTransDataLoader(torch.utils.data.DataLoader):
+     def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0,
+                  collate_fn=None, **kwargs):
+         super(GTransDataLoader, self).__init__(dataset, batch_size, shuffle, sampler, batch_sampler, num_workers, collate_fn, **kwargs)
+         self.featurizer = collate_fn
+
+
+ class BatchSampler(data.Sampler):
+     '''
+     From https://github.com/jingraham/neurips19-graph-protein-design.
+
+     A `torch.utils.data.Sampler` which samples batches according to a
+     maximum number of graph nodes.
+
+     :param node_counts: array of node counts in the dataset to sample from
+     :param max_nodes: the maximum number of nodes in any batch,
+                       including batches of a single element
+     :param shuffle: if `True`, batches in shuffled order
+     '''
+     def __init__(self, node_counts, max_nodes=3000, shuffle=True):
+         self.node_counts = node_counts
+         self.idx = [i for i in range(len(node_counts))
+                     if node_counts[i] <= max_nodes]
+         self.shuffle = shuffle
+         self.max_nodes = max_nodes
+         self._form_batches()
+
+     def _form_batches(self):
+         self.batches = []
+         if self.shuffle: random.shuffle(self.idx)
+         idx = self.idx
+         while idx:
+             batch = []
+             n_nodes = 0
+             while idx and n_nodes + self.node_counts[idx[0]] <= self.max_nodes:
+                 next_idx, idx = idx[0], idx[1:]
+                 n_nodes += self.node_counts[next_idx]
+                 batch.append(next_idx)
+             self.batches.append(batch)
+
+     def __len__(self):
+         if not self.batches: self._form_batches()
+         return len(self.batches)
+
+     def __iter__(self):
+         if not self.batches:
+             self._form_batches()
+         for batch in self.batches:
+             yield batch
+
+
+ class GVPDataLoader(torch.utils.data.DataLoader):
+     def __init__(self, dataset, num_workers=0,
+                  featurizer=None, max_nodes=3000, **kwargs):
+         super(GVPDataLoader, self).__init__(dataset,
+                                             batch_sampler=BatchSampler(node_counts=[len(data['seq']) for data in dataset], max_nodes=max_nodes),
+                                             num_workers=num_workers,
+                                             collate_fn=featurizer.collate,
+                                             **kwargs)
+         self.featurizer = featurizer
+
+
+ def load_data(data_name, method, batch_size, data_root, pdb_path, split_csv, max_nodes=3000, num_workers=8, removeTS=0, test_casp=False, **kwargs):
+     if data_name == 'CATH4.2' or data_name == 'TS':
+         cath_set = CATHDataset(osp.join(data_root, 'cath4.2'), mode='train', test_name='All', removeTS=removeTS)
+         train_set, valid_set, test_set = map(lambda x: copy.copy(x), [cath_set] * 3)
+         valid_set.change_mode('valid')
+         test_set.change_mode('test')
+         if data_name == 'TS':
+             test_set = TSDataset(osp.join(data_root, 'ts'))
+
+         collate_fn = featurize_GTrans
+     elif data_name == 'CATH4.3':
+         cath_set = CATHDataset(osp.join(data_root, 'cath4.3'), mode='train', test_name='All', removeTS=removeTS, version=4.3)
+         train_set, valid_set, test_set = map(lambda x: copy.copy(x), [cath_set] * 3)
+         valid_set.change_mode('valid')
+         test_set.change_mode('test')
+
+         collate_fn = featurize_GTrans
+     elif data_name == 'AlphaFold':
+         # NOTE: upid, limit_length and joint_data are expected to be supplied via **kwargs
+         af_set = AlphaFoldDataset(osp.join(data_root, 'af2db'), upid=upid, mode='train', limit_length=limit_length, joint_data=joint_data)
+         train_set, valid_set, test_set = map(lambda x: copy.copy(x), [af_set] * 3)
+         valid_set.change_mode('valid')
+         test_set.change_mode('test')
+         collate_fn = featurize_AF
+     elif data_name == 'MPNN':
+         train_set = MPNNDataset(mode='train')
+         valid_set = MPNNDataset(mode='valid')
+         test_set = MPNNDataset(mode='test')
+         collate_fn = featurize_GTrans
+
+     elif data_name == 'S350':
+         cath_set = CATHDataset(osp.join(data_root, 's350'), mode='train', test_name='All', removeTS=removeTS, version=4.3)
+         train_set, valid_set, test_set = map(lambda x: copy.copy(x), [cath_set] * 3)
+         valid_set.change_mode('train')
+         test_set.change_mode('train')
+
+         collate_fn = featurize_GTrans
+
+     elif data_name == 'Protherm':
+         cath_set = CATHDataset(osp.join(data_root, 'protherm'), mode='train', test_name='All', removeTS=removeTS, version=4.3)
+         train_set, valid_set, test_set = map(lambda x: copy.copy(x), [cath_set] * 3)
+         valid_set.change_mode('valid')
+         test_set.change_mode('test')
+
+         collate_fn = featurize_GTrans
+     if test_casp:
+         test_set = CASPDataset(osp.join(data_root, 'casp15'))
+
+     if method in ['AlphaDesign', 'PiFold', 'KWDesign', 'GraphTrans', 'StructGNN']:
+         pass
+     elif method == 'GVP':
+         featurizer = featurize_GVP()
+         collate_fn = featurizer.collate
+     elif method == 'ProteinMPNN':
+         collate_fn = featurize_ProteinMPNN
+     elif method == 'ESMIF':
+         collate_fn = featurize_Inversefolding
+
+     # train_set.data = train_set.data[:100]
+     # valid_set.data = valid_set.data[:100]
+     # test_set.data = test_set.data[:100]
+     # pdb.set_trace()  # debug leftover; left active it would halt load_data on every call
+     train_loader = DataLoaderX(local_rank=0, dataset=train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=collate_fn, prefetch_factor=8)
+     valid_loader = DataLoaderX(local_rank=0, dataset=valid_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn, prefetch_factor=8)
+     test_loader = DataLoaderX(local_rank=0, dataset=test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn, prefetch_factor=8)
+
+     return train_loader, valid_loader, test_loader
+
+
+ def make_cath_loader(test_set, method, batch_size, max_nodes=3000, num_workers=8):
+     if method in ['pifold', 'adesign', 'graphtrans', 'structgnn', 'gca']:
+         collate_fn = featurize_GTrans
+         test_loader = GTransDataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn)
+     elif method == 'gvp':
+         featurizer = featurize_GVP()
+         test_loader = GVPDataLoader(test_set, num_workers=num_workers, featurizer=featurizer, max_nodes=max_nodes)
+     elif method == 'proteinmpnn':
+         collate_fn = featurize_ProteinMPNN
+         test_loader = GTransDataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn)
+     elif method == 'esmif':
+         collate_fn = featurize_Inversefolding
+         test_loader = GTransDataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn)
+     return test_loader
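`BatchSampler` packs variable-length proteins greedily until the per-batch residue budget `max_nodes` would be exceeded; chains longer than the budget are dropped up front. A standalone sketch with made-up lengths:

    lengths = [120, 480, 90, 300, 2900, 150]
    sampler = BatchSampler(node_counts=lengths, max_nodes=3000, shuffle=False)
    for batch in sampler:
        print(batch, sum(lengths[i] for i in batch))  # each sum stays <= 3000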
Flexpert-Design/src/datasets/fast_dataloader.py ADDED
@@ -0,0 +1,52 @@
+ import threading
+ import torch
+ import queue
+ from torch.utils.data import DataLoader
+
+
+ class DataLoaderX(DataLoader):
+     def __init__(self, local_rank, **kwargs):
+         super().__init__(**kwargs)
+         self.stream = torch.cuda.Stream(
+             local_rank
+         )  # create a new cuda stream in each process
+         self.local_rank = local_rank
+         # self.custom_collect_fn = custom_collect_fn
+
+     def __iter__(self):
+         self.iter = super().__iter__()
+         self.preload()
+         return self
+
+     def preload(self):
+         while True:
+             # fetch the next value
+             self.batch = next(self.iter, None)
+             if self.batch is not None:
+                 break
+             if self.iter._send_idx == len(self.iter):
+                 break
+
+         if (self.batch is None):
+             return None
+
+         with torch.cuda.stream(self.stream):  # pre-stage the batch on the GPU
+             for key, val in self.batch.items():
+                 if type(val) == torch.Tensor:
+                     self.batch[key] = val.to(
+                         device=self.local_rank, non_blocking=True
+                     )
+
+     def __next__(self):
+         torch.cuda.current_stream().wait_stream(
+             self.stream
+         )  # wait for the tensors to land on the GPU
+         batch = self.batch
+         # batch = self.custom_collect_fn(self.batch)
+         if batch is None:
+             raise StopIteration
+         self.preload()
+         return batch
+
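Design note: `DataLoaderX` overlaps host-to-device copies with compute by staging each upcoming batch on a dedicated CUDA stream in `preload`, and `__next__` synchronizes with `wait_stream` before handing the batch over. The `non_blocking=True` copies only pay off when the source tensors sit in pinned memory, so `pin_memory=True` on the underlying `DataLoader` is presumably expected here.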
Flexpert-Design/src/datasets/featurizer.py ADDED
@@ -0,0 +1,743 @@
1
+ import torch
2
+ import numpy as np
3
+ import itertools
4
+ import torch.nn.functional as F
5
+ import math
6
+ import torch_geometric
7
+ # import torch_cluster
8
+ from collections.abc import Mapping, Sequence
9
+ from torch_geometric.data import Data, Batch
10
+ from torch.utils.data.dataloader import default_collate
11
+ from transformers import AutoTokenizer
12
+ import pdb
13
+ tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir='./cache_dir/') # mask token: 32
14
+
15
+ def _normalize(tensor, dim=-1):
16
+ '''
17
+ Normalizes a `torch.Tensor` along dimension `dim` without `nan`s.
18
+ '''
19
+ return torch.nan_to_num(
20
+ torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True)))
21
+
22
+
23
+ def _rbf(D, D_min=0., D_max=20., D_count=16, device='cpu'):
24
+ '''
25
+ From https://github.com/jingraham/neurips19-graph-protein-design
26
+
27
+ Returns an RBF embedding of `torch.Tensor` `D` along a new axis=-1.
28
+ That is, if `D` has shape [...dims], then the returned tensor will have
29
+ shape [...dims, D_count].
30
+ '''
31
+ D_mu = torch.linspace(D_min, D_max, D_count, device=device)
32
+ D_mu = D_mu.view([1, -1])
33
+ D_sigma = (D_max - D_min) / D_count
34
+ D_expand = torch.unsqueeze(D, -1)
35
+
36
+ RBF = torch.exp(-((D_expand - D_mu) / D_sigma) ** 2)
37
+ return RBF
38
+
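_rbf spreads each scalar distance over D_count Gaussian basis functions with evenly spaced centers in [D_min, D_max], a soft binning of distances. A quick shape check (a sketch):

    D = torch.rand(10, 30) * 20.0   # e.g. pairwise CA-CA distances in Angstroms
    emb = _rbf(D, D_count=16)       # -> torch.Size([10, 30, 16])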
39
+ def shuffle_subset(n, p):
40
+ n_shuffle = np.random.binomial(n, p)
41
+ ix = np.arange(n)
42
+ ix_subset = np.random.choice(ix, size=n_shuffle, replace=False)
43
+ ix_subset_shuffled = np.copy(ix_subset)
44
+ np.random.shuffle(ix_subset_shuffled)
45
+ ix[ix_subset] = ix_subset_shuffled
46
+ return ix
47
+
48
+
49
+ def featurize_AF(batch, shuffle_fraction=0.):
50
+ """ Pack and pad batch into torch tensors """
51
+ alphabet = 'ACDEFGHIKLMNPQRSTVWY'
52
+ B = len(batch)
53
+ lengths = np.array([len(b['seq']) for b in batch], dtype=np.int32)
54
+ L_max = max([len(b['seq']) for b in batch])
55
+ X = np.zeros([B, L_max, 4, 3])
56
+ S = np.zeros([B, L_max], dtype=np.int32)
57
+ score = np.zeros([B, L_max])
58
+
59
+ # Build the batch
60
+ for i, b in enumerate(batch):
61
+ x = np.stack([b[c] for c in ['N', 'CA', 'C', 'O']], 1) # [#atom, 4, 3]
62
+
63
+ l = len(b['seq'])
64
+ x_pad = np.pad(x, [[0,L_max-l], [0,0], [0,0]], 'constant', constant_values=(np.nan, )) # [#atom, 4, 3]
65
+ X[i,:,:,:] = x_pad
66
+
67
+ # Convert to labels
68
+ indices = np.asarray([alphabet.index(a) for a in b['seq']], dtype=np.int32)
69
+ if shuffle_fraction > 0.:
70
+ idx_shuffle = shuffle_subset(l, shuffle_fraction)
71
+ S[i, :l] = indices[idx_shuffle]
72
+ score[i,:l] = b['score'][idx_shuffle]
73
+ else:
74
+ S[i, :l] = indices
75
+ score[i,:l] = b['score']
76
+
77
+ mask = np.isfinite(np.sum(X,(2,3))).astype(np.float32) # atom mask
78
+ numbers = np.sum(mask, axis=1).astype(np.int32) # np.int was removed in NumPy 1.24
79
+ S_new = np.zeros_like(S)
80
+ score_new = np.zeros_like(score)
81
+ X_new = np.zeros_like(X)+np.nan
82
+ for i, n in enumerate(numbers):
83
+ X_new[i,:n,::] = X[i][mask[i]==1]
84
+ S_new[i,:n] = S[i][mask[i]==1]
85
+ score_new[i,:n] = score[i][mask[i]==1]
86
+
87
+ X = X_new
88
+ S = S_new
89
+ score = score_new
90
+ isnan = np.isnan(X)
91
+ mask = np.isfinite(np.sum(X,(2,3))).astype(np.float32)
92
+ X[isnan] = 0.
93
+ # Conversion
94
+ S = torch.from_numpy(S).to(dtype=torch.long)
95
+ score = torch.from_numpy(score).float()
96
+ X = torch.from_numpy(X).to(dtype=torch.float32)
97
+ mask = torch.from_numpy(mask).to(dtype=torch.float32)
98
+ return X, S, score, mask, lengths
99
+
100
+
101
+ def featurize_GTrans(batch):
102
+ """ Pack and pad batch into torch tensors """
103
+ # alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
104
+ batch = [one for one in batch if one is not None]
105
+ B = len(batch)
106
+ if B==0:
107
+ return None
108
+ lengths = np.array([len(b['seq']) for b in batch], dtype=np.int32)
109
+ L_max = max([len(b['seq']) for b in batch])
110
+ X = np.zeros([B, L_max, 4, 3])
111
+ S = np.zeros([B, L_max], dtype=np.int32)
112
+ score = np.ones([B, L_max]) * 100.0
113
+ chain_mask = np.zeros([B, L_max])-1 # 1: masked positions to be predicted, 0: visible positions
114
+ chain_encoding = np.zeros([B, L_max])-1
115
+
116
+
117
+ # Build the batch
118
+ for i, b in enumerate(batch):
119
+ x = np.stack([b[c] for c in ['N', 'CA', 'C', 'O']], 1) # [#atom, 4, 3]
120
+
121
+ l = len(b['seq'])
122
+ x_pad = np.pad(x, [[0,L_max-l], [0,0], [0,0]], 'constant', constant_values=(np.nan, )) # [#atom, 4, 3]
123
+ X[i,:,:,:] = x_pad
124
+
125
+ # Convert to labels
126
+ indices = np.array(tokenizer.encode(b['seq'], add_special_tokens=False))
127
+ # indices = np.asarray([alphabet.index(a) for a in b['seq']], dtype=np.int32)
128
+
129
+
130
+ S[i, :l] = indices
131
+ chain_mask[i,:l] = b['chain_mask']
132
+ chain_encoding[i,:l] = b['chain_encoding']
133
+
134
+ mask = np.isfinite(np.sum(X,(2,3))).astype(np.float32) # atom mask
135
+ numbers = np.sum(mask, axis=1).astype(np.int32)
136
+ S_new = np.zeros_like(S)
137
+ X_new = np.zeros_like(X)+np.nan
138
+ for i, n in enumerate(numbers):
139
+ X_new[i,:n,::] = X[i][mask[i]==1]
140
+ S_new[i,:n] = S[i][mask[i]==1]
141
+
142
+ X = X_new
143
+ S = S_new
144
+ isnan = np.isnan(X)
145
+ mask = np.isfinite(np.sum(X,(2,3))).astype(np.float32)
146
+ X[isnan] = 0.
147
+ # Conversion
148
+ S = torch.from_numpy(S).to(dtype=torch.long)
149
+ score = torch.from_numpy(score).float()
150
+ X = torch.from_numpy(X).to(dtype=torch.float32)
151
+ mask = torch.from_numpy(mask).to(dtype=torch.float32)
152
+ lengths = torch.from_numpy(lengths)
153
+ chain_mask = torch.from_numpy(chain_mask)
154
+ chain_encoding = torch.from_numpy(chain_encoding)
155
+
156
+ return {"title": [b['title'] for b in batch],
157
+ "X":X,
158
+ "S":S,
159
+ "score": score,
160
+ "mask":mask,
161
+ "lengths":lengths,
162
+ "chain_mask":chain_mask,
163
+ "chain_encoding":chain_encoding}
164
+
165
+
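featurize_GTrans is written to be plugged in directly as a DataLoader collate_fn; items are expected to be dicts carrying 'title', 'seq', backbone coordinate arrays 'N'/'CA'/'C'/'O', 'chain_mask' and 'chain_encoding', as produced by the dataset classes in this diff. A sketch (test_set stands for any such dataset):

    from torch.utils.data import DataLoader

    loader = DataLoader(test_set, batch_size=8, shuffle=False, collate_fn=featurize_GTrans)
    batch = next(iter(loader))
    # batch['X']: [B, L_max, 4, 3] padded backbone coords
    # batch['S']: [B, L_max] ESM-2 token ids, batch['mask']: [B, L_max] validity mask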
166
+ class featurize_GVP:
167
+ def __init__(self, num_positional_embeddings=16, top_k=30, num_rbf=16):
168
+ self.top_k = top_k
169
+ self.num_rbf = num_rbf
170
+ self.num_positional_embeddings = num_positional_embeddings
171
+ # self.letter_to_num = {'C': 4, 'D': 3, 'S': 15, 'Q': 5, 'K': 11, 'I': 9,
172
+ # 'P': 14, 'T': 16, 'F': 13, 'A': 0, 'G': 7, 'H': 8,
173
+ # 'E': 6, 'L': 10, 'R': 1, 'W': 17, 'V': 19,
174
+ # 'N': 2, 'Y': 18, 'M': 12}
175
+ # self.num_to_letter = {v:k for k, v in self.letter_to_num.items()}
176
+
177
+ def featurize(self, batch):
178
+ data_all = []
179
+ for b in batch:
180
+ if b is None:
181
+ continue
182
+ coords = torch.tensor(np.stack([b[c] for c in ['N', 'CA', 'C', 'O']], 1))
183
+ seq = torch.tensor(np.array(tokenizer.encode(b['seq'], add_special_tokens=False)))
184
+
185
+ mask = torch.isfinite(coords.sum(dim=(1,2)))
186
+ coords[~mask] = np.inf
187
+
188
+ X_ca = coords[:, 1].float()
189
+ edge_index = torch_geometric.nn.knn_graph(X_ca, k=self.top_k)#torch_cluster.knn_graph(X_ca, k=self.top_k)
190
+
191
+ pos_embeddings = self._positional_embeddings(edge_index) # [E, 16]
192
+ E_vectors = X_ca[edge_index[0]] - X_ca[edge_index[1]] # [E, 3]
193
+ rbf = _rbf(E_vectors.norm(dim=-1), D_count=self.num_rbf) # [E, 16]
194
+
195
+ dihedrals = self._dihedrals(coords) # [n,6]
196
+ orientations = self._orientations(X_ca) # [n,2,3]
197
+ sidechains = self._sidechains(coords) # [n,3]
198
+
199
+ node_s = dihedrals.float() # [n,6]
200
+
201
+ node_v = torch.cat([orientations, sidechains.unsqueeze(-2)], dim=-2).float() # [n, 3, 3]
202
+
203
+ edge_s = torch.cat([rbf, pos_embeddings], dim=-1).float() # [E, 32]
204
+ edge_v = _normalize(E_vectors).unsqueeze(-2).float() # [E, 1, 3]
205
+
206
+ node_s, node_v, edge_s, edge_v = map(torch.nan_to_num,(node_s, node_v, edge_s, edge_v))
207
+
208
+ data = torch_geometric.data.Data(x=X_ca, seq=seq,
209
+ node_s=node_s, node_v=node_v,
210
+ edge_s=edge_s, edge_v=edge_v,
211
+ edge_index=edge_index, mask=mask)
212
+ data_all.append(data)
213
+ return data_all
214
+
215
+ def _positional_embeddings(self, edge_index,
216
+ num_embeddings=None,
217
+ period_range=[2, 1000]):
218
+ # From https://github.com/jingraham/neurips19-graph-protein-design
219
+ num_embeddings = num_embeddings or self.num_positional_embeddings
220
+ d = edge_index[0] - edge_index[1]
221
+
222
+ frequency = torch.exp(
223
+ torch.arange(0, num_embeddings, 2, dtype=torch.float32)
224
+ * -(np.log(10000.0) / num_embeddings)
225
+ )
226
+ angles = d.unsqueeze(-1) * frequency
227
+ E = torch.cat((torch.cos(angles), torch.sin(angles)), -1)
228
+ return E
229
+
230
+ def _dihedrals(self, X, eps=1e-7):
231
+ # From https://github.com/jingraham/neurips19-graph-protein-design
232
+
233
+ X = torch.reshape(X[:, :3], [3*X.shape[0], 3])
234
+ dX = X[1:] - X[:-1]
235
+ U = _normalize(dX, dim=-1)
236
+ u_2 = U[:-2]
237
+ u_1 = U[1:-1]
238
+ u_0 = U[2:]
239
+
240
+ # Backbone normals
241
+ n_2 = _normalize(torch.cross(u_2, u_1, dim=-1), dim=-1) # explicit dim: torch.cross without dim is deprecated
+ n_1 = _normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)
243
+
244
+ # Angle between normals
245
+ cosD = torch.sum(n_2 * n_1, -1)
246
+ cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
247
+ D = torch.sign(torch.sum(u_2 * n_1, -1)) * torch.acos(cosD)
248
+
249
+ # This scheme will remove phi[0], psi[-1], omega[-1]
250
+ D = F.pad(D, [1, 2])
251
+ D = torch.reshape(D, [-1, 3])
252
+ # Lift angle representations to the circle
253
+ D_features = torch.cat([torch.cos(D), torch.sin(D)], 1)
254
+ return D_features
255
+
256
+ def _orientations(self, X):
257
+ forward = _normalize(X[1:] - X[:-1])
258
+ backward = _normalize(X[:-1] - X[1:])
259
+ forward = F.pad(forward, [0, 0, 0, 1])
260
+ backward = F.pad(backward, [0, 0, 1, 0])
261
+ return torch.cat([forward.unsqueeze(-2), backward.unsqueeze(-2)], -2)
262
+
263
+ def _sidechains(self, X):
264
+ n, origin, c = X[:, 0], X[:, 1], X[:, 2]
265
+ c, n = _normalize(c - origin), _normalize(n - origin)
266
+ bisector = _normalize(c + n)
267
+ perp = _normalize(torch.cross(c, n, dim=-1))
268
+ vec = -bisector * math.sqrt(1 / 3) - perp * math.sqrt(2 / 3)
269
+ return vec
270
+
271
+ def collate(self, batch):
272
+ batch = self.featurize(batch)
273
+ if (batch is None) or (len(batch)==0):
274
+ return None
275
+
276
+ elem = batch[0]
277
+ if isinstance(elem, Data):
278
+ return Batch.from_data_list(batch)
279
+ elif isinstance(elem, torch.Tensor):
280
+ return default_collate(batch)
281
+ elif isinstance(elem, float):
282
+ return torch.tensor(batch, dtype=torch.float)
283
+ elif isinstance(elem, int):
284
+ return torch.tensor(batch)
285
+ elif isinstance(elem, str):
286
+ return batch
287
+ elif isinstance(elem, Mapping):
288
+ return {key: self.collate([d[key] for d in batch]) for key in elem}
289
+ elif isinstance(elem, tuple) and hasattr(elem, '_fields'):
290
+ return type(elem)(*(self.collate(s) for s in zip(*batch)))
291
+ elif isinstance(elem, Sequence) and not isinstance(elem, str):
292
+ return [self.collate(s) for s in zip(*batch)]
293
+
294
+ raise TypeError('DataLoader found invalid type: {}'.format(type(elem)))
295
+
296
+
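Unlike the function-style featurizers, featurize_GVP is a class whose bound collate method builds a torch_geometric Batch of k-NN residue graphs. A sketch of wiring it to a plain DataLoader (test_set as above):

    from torch.utils.data import DataLoader

    featurizer = featurize_GVP(top_k=30, num_rbf=16)
    loader = DataLoader(test_set, batch_size=4, collate_fn=featurizer.collate)
    graph_batch = next(iter(loader))  # torch_geometric Batch with node_s/node_v/edge_s/edge_v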
297
+ def featurize_ProteinMPNN(batch, is_testing=False, chain_dict=None, fixed_position_dict=None, omit_AA_dict=None, tied_positions_dict=None, pssm_dict=None, bias_by_res_dict=None):
298
+ """ Pack and pad batch into torch tensors """
299
+
300
+ batch = [one for one in batch if one is not None]
301
+ # dynamics-aware mode: active when any per-residue flexibility annotation is present
+ DYNAMICS_KEYS = ('norm_bfactors', 'gt_flex', 'enm_vals', 'original_gt_flex', 'eng_mask')
+ USING_DYNAMICS = bool(batch) and any(key in batch[0] for key in DYNAMICS_KEYS)
307
+
308
+ alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
309
+ B = len(batch)
310
+ if B==0:
311
+ return None
312
+ lengths = np.array([len(b['seq']) for b in batch], dtype=np.int32) #sum of chain seq lengths
313
+ L_max = max([len(b['seq']) for b in batch])
314
+ X = np.zeros([B, L_max, 4, 3])
315
+ residue_idx = -100*np.ones([B, L_max], dtype=np.int32)
316
+ chain_M = np.zeros([B, L_max], dtype=np.int32) #1.0 for the bits that need to be predicted
317
+ pssm_coef_all = np.zeros([B, L_max], dtype=np.float32) # per-residue PSSM weighting coefficient
+ pssm_bias_all = np.zeros([B, L_max, 21], dtype=np.float32) # per-residue amino-acid bias
+ pssm_log_odds_all = 10000.0*np.ones([B, L_max, 21], dtype=np.float32) # per-residue log-odds
320
+ chain_M_pos = np.zeros([B, L_max], dtype=np.int32) #1.0 for the bits that need to be predicted
321
+ bias_by_res_all = np.zeros([B, L_max, 21], dtype=np.float32)
322
+ chain_encoding_all = np.zeros([B, L_max], dtype=np.int32) # integer chain id per residue
323
+ S = np.zeros([B, L_max], dtype=np.int32)
324
+ score = np.zeros([B, L_max])
325
+ omit_AA_mask = np.zeros([B, L_max, len(alphabet)], dtype=np.int32)
326
+ # Build the batch
327
+ letter_list_list = []
328
+ visible_list_list = []
329
+ masked_list_list = []
330
+ masked_chain_length_list_list = []
331
+ tied_pos_list_of_lists_list = []
332
+ # shuffle all chains before the main loop
333
+ if USING_DYNAMICS:
334
+ if ('norm_bfactors' in batch[0].keys()):
335
+ b_factors = np.zeros([B, L_max])
336
+ if ('gt_flex' in batch[0].keys()):
337
+ gt_flex = np.zeros([B, L_max])
338
+ if ('enm_vals' in batch[0].keys()):
339
+ enm_vals = np.zeros([B, L_max])
340
+ if ('original_gt_flex' in batch[0].keys()):
341
+ original_gt_flex = np.zeros([B, L_max])
342
+ if ('eng_mask' in batch[0].keys()):
343
+ eng_mask = np.zeros([B, L_max])
344
+
345
+ for i, b in enumerate(batch):
346
+ if chain_dict != None:
347
+ masked_chains, visible_chains = chain_dict[b['name']] #masked_chains a list of chain letters to predict [A, D, F]
348
+ else:
349
+ # masked_chains = [item[-1:] for item in list(b) if item[:10]=='seq_chain_']
350
+ masked_chains = ['']
351
+ visible_chains = []
352
+ # num_chains = b['num_of_chains']
353
+ all_chains = masked_chains + visible_chains
354
+ #random.shuffle(all_chains)
355
+ for i, b in enumerate(batch):
356
+ mask_dict = {}
357
+ a = 0
358
+ x_chain_list = []
359
+ chain_mask_list = []
360
+ chain_seq_list = []
361
+ chain_encoding_list = []
362
+ c = 1
363
+ letter_list = []
364
+ global_idx_start_list = [0]
365
+ visible_list = []
366
+ masked_list = []
367
+ masked_chain_length_list = []
368
+ fixed_position_mask_list = []
369
+ omit_AA_mask_list = []
370
+ pssm_coef_list = []
371
+ pssm_bias_list = []
372
+ pssm_log_odds_list = []
373
+ bias_by_res_list = []
374
+
375
+ if USING_DYNAMICS:
376
+ if ('norm_bfactors' in batch[0].keys()):
377
+ b_factors_list = []
378
+ if ('gt_flex' in batch[0].keys()):
379
+ gt_flex_list = []
380
+ if ('enm_vals' in batch[0].keys()):
381
+ enm_vals_list = []
382
+ if ('original_gt_flex' in batch[0].keys()):
383
+ original_gt_flex_list = []
384
+ if ('eng_mask' in batch[0].keys()):
385
+ eng_mask_list = []
386
+ l0 = 0
387
+ l1 = 0
388
+ for step, letter in enumerate(all_chains):
389
+ if letter in visible_chains:
390
+ letter_list.append(letter)
391
+ visible_list.append(letter)
392
+ chain_seq = b[f'seq_chain_{letter}']
393
+ chain_seq = ''.join([a if a!='-' else 'X' for a in chain_seq])
394
+ chain_length = len(chain_seq)
395
+ global_idx_start_list.append(global_idx_start_list[-1]+chain_length)
396
+ chain_coords = b[f'coords_chain_{letter}'] #this is a dictionary
397
+ chain_mask = np.zeros(chain_length) #0.0 for visible chains
398
+ x_chain = np.stack([chain_coords[c] for c in [f'N_chain_{letter}', f'CA_chain_{letter}', f'C_chain_{letter}', f'O_chain_{letter}']], 1) #[chain_length,4,3]
399
+ x_chain_list.append(x_chain)
400
+ chain_mask_list.append(chain_mask)
401
+ chain_seq_list.append(chain_seq)
402
+ chain_encoding_list.append(c*np.ones(np.array(chain_mask).shape[0]))
403
+ l1 += chain_length
404
+ residue_idx[i, l0:l1] = 100*(c-1)+np.arange(l0, l1)
405
+ l0 += chain_length
406
+ c+=1
407
+ fixed_position_mask = np.ones(chain_length)
408
+ fixed_position_mask_list.append(fixed_position_mask)
409
+ omit_AA_mask_temp = np.zeros([chain_length, len(alphabet)], np.int32)
410
+ omit_AA_mask_list.append(omit_AA_mask_temp)
411
+ pssm_coef = np.zeros(chain_length)
412
+ pssm_bias = np.zeros([chain_length, 21])
413
+ pssm_log_odds = 10000.0*np.ones([chain_length, 21])
414
+ pssm_coef_list.append(pssm_coef)
415
+ pssm_bias_list.append(pssm_bias)
416
+ pssm_log_odds_list.append(pssm_log_odds)
417
+ bias_by_res_list.append(np.zeros([chain_length, 21]))
418
+ if letter in masked_chains:
419
+ masked_list.append(letter)
420
+ letter_list.append(letter)
421
+
422
+ if USING_DYNAMICS:
423
+ if ('norm_bfactors' in batch[0].keys()):
424
+ chain_b_factors = b['norm_bfactors']
425
+ b_factors_list.append(chain_b_factors)
426
+ if ('gt_flex' in batch[0].keys()):
427
+ chain_gt_flex = b['gt_flex']
428
+ gt_flex_list.append(chain_gt_flex)
429
+ if ('enm_vals' in batch[0].keys()):
430
+ chain_enm_vals = b['enm_vals']
431
+ enm_vals_list.append(chain_enm_vals)
432
+ if ('original_gt_flex' in batch[0].keys()):
433
+ chain_original_gt_flex = b['original_gt_flex']
434
+ original_gt_flex_list.append(chain_original_gt_flex)
435
+ if ('eng_mask' in batch[0].keys()):
436
+ chain_eng_mask = b['eng_mask']
437
+ eng_mask_list.append(chain_eng_mask)
438
+
439
+ # chain_seq = b[f'seq_chain_{letter}']
440
+ chain_seq = b[f'seq{letter}']
441
+ chain_seq = ''.join([a if a!='-' else 'X' for a in chain_seq])
442
+ chain_length = len(chain_seq)
443
+ global_idx_start_list.append(global_idx_start_list[-1]+chain_length)
444
+ masked_chain_length_list.append(chain_length)
445
+ # chain_coords = b[f'coords_chain_{letter}'] #this is a dictionary
446
+ chain_coords = b
447
+ chain_mask = np.ones(chain_length) #1.0 for masked
448
+ # x_chain = np.stack([chain_coords[c] for c in [f'N_chain_{letter}', f'CA_chain_{letter}', f'C_chain_{letter}', f'O_chain_{letter}']], 1) #[chain_lenght,4,3]
449
+ x_chain = np.stack([chain_coords[c] for c in ['N', 'CA', 'C', 'O']], 1) #[chain_length,4,3]
450
+ x_chain_list.append(x_chain)
451
+ chain_mask_list.append(chain_mask)
452
+ chain_seq_list.append(chain_seq)
453
+ chain_encoding_list.append(c*np.ones(np.array(chain_mask).shape[0]))
454
+ l1 += chain_length
455
+ residue_idx[i, l0:l1] = 100*(c-1)+np.arange(l0, l1)
456
+ l0 += chain_length
457
+ c+=1
458
+ fixed_position_mask = np.ones(chain_length)
459
+ if fixed_position_dict!=None:
460
+ fixed_pos_list = fixed_position_dict[b['name']][letter]
461
+ if fixed_pos_list:
462
+ fixed_position_mask[np.array(fixed_pos_list)-1] = 0.0
463
+ fixed_position_mask_list.append(fixed_position_mask)
464
+ omit_AA_mask_temp = np.zeros([chain_length, len(alphabet)], np.int32)
465
+ if omit_AA_dict!=None:
466
+ for item in omit_AA_dict[b['name']][letter]:
467
+ idx_AA = np.array(item[0])-1
468
+ AA_idx = np.array([np.argwhere(np.array(list(alphabet))== AA)[0][0] for AA in item[1]]).repeat(idx_AA.shape[0])
469
+ idx_ = np.array([[a, b] for a in idx_AA for b in AA_idx])
470
+ omit_AA_mask_temp[idx_[:,0], idx_[:,1]] = 1
471
+ omit_AA_mask_list.append(omit_AA_mask_temp)
472
+ pssm_coef = np.zeros(chain_length)
473
+ pssm_bias = np.zeros([chain_length, 21])
474
+ pssm_log_odds = 10000.0*np.ones([chain_length, 21])
475
+ if pssm_dict:
476
+ if pssm_dict[b['name']][letter]:
477
+ pssm_coef = pssm_dict[b['name']][letter]['pssm_coef']
478
+ pssm_bias = pssm_dict[b['name']][letter]['pssm_bias']
479
+ pssm_log_odds = pssm_dict[b['name']][letter]['pssm_log_odds']
480
+ pssm_coef_list.append(pssm_coef)
481
+ pssm_bias_list.append(pssm_bias)
482
+ pssm_log_odds_list.append(pssm_log_odds)
483
+ if bias_by_res_dict:
484
+ bias_by_res_list.append(bias_by_res_dict[b['name']][letter])
485
+ else:
486
+ bias_by_res_list.append(np.zeros([chain_length, 21]))
487
+
488
+
489
+ letter_list_np = np.array(letter_list)
490
+ tied_pos_list_of_lists = []
491
+ tied_beta = np.ones(L_max)
492
+ if tied_positions_dict!=None:
493
+ tied_pos_list = tied_positions_dict[b['name']]
494
+ if tied_pos_list:
495
+ set_chains_tied = set(list(itertools.chain(*[list(item) for item in tied_pos_list])))
496
+ for tied_item in tied_pos_list:
497
+ one_list = []
498
+ for k, v in tied_item.items():
499
+ start_idx = global_idx_start_list[np.argwhere(letter_list_np == k)[0][0]]
500
+ if isinstance(v[0], list):
501
+ for v_count in range(len(v[0])):
502
+ one_list.append(start_idx+v[0][v_count]-1)#make 0 to be the first
503
+ tied_beta[start_idx+v[0][v_count]-1] = v[1][v_count]
504
+ else:
505
+ for v_ in v:
506
+ one_list.append(start_idx+v_-1)#make 0 to be the first
507
+ tied_pos_list_of_lists.append(one_list)
508
+ tied_pos_list_of_lists_list.append(tied_pos_list_of_lists)
509
+
510
+ x = np.concatenate(x_chain_list,0) #[L, 4, 3]
511
+
512
+ if USING_DYNAMICS:
513
+ if ('norm_bfactors' in batch[0].keys()):
514
+ bf = np.concatenate(b_factors_list,0) #[L,]
515
+ if ('gt_flex' in batch[0].keys()):
516
+ gt = np.concatenate(gt_flex_list,0) #[L,]
517
+ if ('enm_vals' in batch[0].keys()):
518
+ enm = np.concatenate(enm_vals_list,0)
519
+ if ('original_gt_flex' in batch[0].keys()):
520
+ orig_gt = np.concatenate(original_gt_flex_list,0)
521
+ if ('eng_mask' in batch[0].keys()):
522
+ eng = np.concatenate(eng_mask_list,0)
523
+
524
+ all_sequence = "".join(chain_seq_list)
525
+ m = np.concatenate(chain_mask_list,0) #[L,], 1.0 for places that need to be predicted
526
+ chain_encoding = np.concatenate(chain_encoding_list,0)
527
+ m_pos = np.concatenate(fixed_position_mask_list,0) #[L,], 1.0 for places that need to be predicted
528
+
529
+ pssm_coef_ = np.concatenate(pssm_coef_list,0) #[L,]
+ pssm_bias_ = np.concatenate(pssm_bias_list,0) #[L,21]
+ pssm_log_odds_ = np.concatenate(pssm_log_odds_list,0) #[L,21]
532
+
533
+ bias_by_res_ = np.concatenate(bias_by_res_list, 0) #[L,21], 0.0 for places where AA frequencies don't need to be tweaked
534
+
535
+ l = len(all_sequence)
536
+ x_pad = np.pad(x, [[0, L_max-l], [0,0], [0,0]], 'constant', constant_values=(np.nan, ))
537
+ if USING_DYNAMICS:
538
+ if ('norm_bfactors' in batch[0].keys()):
539
+ bf_pad = np.pad(bf, [[0, L_max-l]], 'constant', constant_values=(np.nan, ))
540
+ if ('gt_flex' in batch[0].keys()):
541
+ gt_pad = np.pad(gt, [[0, L_max-l]], 'constant', constant_values=(np.nan, ))
542
+ if ('enm_vals' in batch[0].keys()):
543
+ enm_pad = np.pad(enm, [[0, L_max-l]], 'constant', constant_values=(np.nan, ))
544
+ if ('original_gt_flex' in batch[0].keys()):
545
+ orig_gt_pad = np.pad(orig_gt, [[0, L_max-l]], 'constant', constant_values=(0, ))
546
+ if ('eng_mask' in batch[0].keys()):
547
+ eng_pad = np.pad(eng, [[0, L_max-l]], 'constant', constant_values=(0, ))
548
+
549
+ X[i,:,:,:] = x_pad
550
+ if USING_DYNAMICS:
551
+ if ('norm_bfactors' in batch[0].keys()):
552
+ b_factors[i, :] = bf_pad
553
+ if ('gt_flex' in batch[0].keys()):
554
+ gt_flex[i, :] = gt_pad[:-1]
555
+ if ('enm_vals' in batch[0].keys()):
556
+ enm_vals[i, :] = enm_pad
557
+ if ('original_gt_flex' in batch[0].keys()):
558
+ original_gt_flex[i, :] = orig_gt_pad[:-1]
559
+ if ('eng_mask' in batch[0].keys()):
560
+ eng_mask[i, :] = eng_pad[:-1]
561
+
562
+ if 'score' in b.keys():
563
+ score[i, :l] = b['score']
564
+ else:
565
+ score[i, :l] = 100.0
566
+
567
+ m_pad = np.pad(m, [[0, L_max-l]], 'constant', constant_values=(0.0, ))
568
+ m_pos_pad = np.pad(m_pos, [[0,L_max-l]], 'constant', constant_values=(0.0, ))
569
+ omit_AA_mask_pad = np.pad(np.concatenate(omit_AA_mask_list,0), [[0,L_max-l], [0, 0]], 'constant', constant_values=(0.0, ))
570
+ chain_M[i,:] = m_pad
571
+ chain_M_pos[i,:] = m_pos_pad
572
+ omit_AA_mask[i,] = omit_AA_mask_pad
573
+
574
+ chain_encoding_pad = np.pad(chain_encoding, [[0,L_max-l]], 'constant', constant_values=(0.0, ))
575
+ chain_encoding_all[i,:] = chain_encoding_pad
576
+
577
+ pssm_coef_pad = np.pad(pssm_coef_, [[0, L_max-l]], 'constant', constant_values=(0.0, ))
578
+ pssm_bias_pad = np.pad(pssm_bias_, [[0, L_max-l], [0,0]], 'constant', constant_values=(0.0, ))
579
+ pssm_log_odds_pad = np.pad(pssm_log_odds_, [[0,L_max-l], [0,0]], 'constant', constant_values=(0.0, ))
580
+
581
+ pssm_coef_all[i,:] = pssm_coef_pad
582
+ pssm_bias_all[i,:] = pssm_bias_pad
583
+ pssm_log_odds_all[i,:] = pssm_log_odds_pad
584
+
585
+ bias_by_res_pad = np.pad(bias_by_res_, [[0,L_max-l], [0,0]], 'constant', constant_values=(0.0, ))
586
+ bias_by_res_all[i,:] = bias_by_res_pad
587
+
588
+ # Convert to labels
589
+ indices = np.array(tokenizer.encode(b['seq'], add_special_tokens=False))
590
+ S[i, :l] = indices
591
+ letter_list_list.append(letter_list)
592
+ visible_list_list.append(visible_list)
593
+ masked_list_list.append(masked_list)
594
+ masked_chain_length_list_list.append(masked_chain_length_list)
595
+
596
+
597
+ isnan = np.isnan(X)
598
+ mask = np.isfinite(np.sum(X,(2,3))).astype(np.float32)
599
+ X[isnan] = 0.
600
+
601
+ # Conversion
602
+ pssm_coef_all = torch.from_numpy(pssm_coef_all).to(dtype=torch.float32)
603
+ pssm_bias_all = torch.from_numpy(pssm_bias_all).to(dtype=torch.float32)
604
+ pssm_log_odds_all = torch.from_numpy(pssm_log_odds_all).to(dtype=torch.float32)
605
+
606
+ tied_beta = torch.from_numpy(tied_beta).to(dtype=torch.float32)
607
+
608
+ jumps = ((residue_idx[:,1:]-residue_idx[:,:-1])==1).astype(np.float32)
609
+ bias_by_res_all = torch.from_numpy(bias_by_res_all).to(dtype=torch.float32)
610
+ phi_mask = np.pad(jumps, [[0,0],[1,0]])
611
+ psi_mask = np.pad(jumps, [[0,0],[0,1]])
612
+ omega_mask = np.pad(jumps, [[0,0],[0,1]])
613
+ dihedral_mask = np.concatenate([phi_mask[:,:,None], psi_mask[:,:,None], omega_mask[:,:,None]], -1) #[B,L,3]
614
+ dihedral_mask = torch.from_numpy(dihedral_mask).to(dtype=torch.float32)
615
+ residue_idx = torch.from_numpy(residue_idx).to(dtype=torch.long)
616
+ S = torch.from_numpy(S).to(dtype=torch.long)
617
+ X = torch.from_numpy(X).to(dtype=torch.float32)
618
+ if USING_DYNAMICS:
619
+ if ('norm_bfactors' in batch[0].keys()):
620
+ b_factors = torch.from_numpy(b_factors).to(dtype=torch.float32)
621
+ if ('gt_flex' in batch[0].keys()):
622
+ gt_flex = torch.from_numpy(gt_flex).to(dtype=torch.float32)
623
+ if ('enm_vals' in batch[0].keys()):
624
+ enm_vals = torch.from_numpy(enm_vals).to(dtype=torch.float32)
625
+ if ('original_gt_flex' in batch[0].keys()):
626
+ original_gt_flex = torch.from_numpy(original_gt_flex).to(dtype=torch.float32)
627
+ if ('eng_mask' in batch[0].keys()):
628
+ eng_mask = torch.from_numpy(eng_mask).to(dtype=torch.float32)
629
+ score = torch.from_numpy(score).float()
630
+ mask = torch.from_numpy(mask).to(dtype=torch.float32)
631
+ chain_M = torch.from_numpy(chain_M).to(dtype=torch.float32)
632
+ chain_M_pos = torch.from_numpy(chain_M_pos).to(dtype=torch.float32)
633
+ omit_AA_mask = torch.from_numpy(omit_AA_mask).to(dtype=torch.float32)
634
+ chain_encoding_all = torch.from_numpy(chain_encoding_all).to(dtype=torch.long)
635
+
636
+ # identical payloads are returned for training and evaluation; the is_testing
+ # flag is kept only for call-site compatibility
+ retVal = {"title": [b['title'] for b in batch],
+ "X":X,
+ "S":S,
+ "score": score,
+ "mask":mask,
+ "lengths":lengths,
+ "chain_M":chain_M,
+ "chain_M_pos":chain_M_pos,
+ "residue_idx":residue_idx,
+ "chain_encoding_all":chain_encoding_all}
+ if USING_DYNAMICS:
+ if ('norm_bfactors' in batch[0].keys()):
+ retVal['norm_bfactors'] = b_factors
+ if ('gt_flex' in batch[0].keys()):
+ retVal['gt_flex'] = gt_flex
+ if ('enm_vals' in batch[0].keys()):
+ retVal['enm_vals'] = enm_vals
+ if ('original_gt_flex' in batch[0].keys()):
+ retVal['original_gt_flex'] = original_gt_flex
+ if ('eng_mask' in batch[0].keys()):
+ retVal['eng_mask'] = eng_mask
+ return retVal
683
+
684
+
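A sketch of calling featurize_ProteinMPNN directly on single-chain items such as those yielded by FlexCATHDataset; when items carry dynamics annotations ('gt_flex', 'enm_vals', 'norm_bfactors', ...), the padded tensors are attached to the returned dict (dataset is a placeholder):

    batch = featurize_ProteinMPNN([dataset[0], dataset[1]])
    # batch['chain_M'] marks designable positions; batch['residue_idx'] offsets chains by 100
    flex_targets = batch.get('gt_flex')   # present only in dynamics-aware mode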
685
+ def featurize_Inversefolding(batch, shuffle_fraction=0.):
686
+ """ Pack and pad batch into torch tensors """
687
+ alphabet = 'ACDEFGHIKLMNPQRSTVWY'
688
+ B = len(batch)
689
+ lengths = np.array([len(b['seq']) for b in batch], dtype=np.int32)
690
+ L_max = max([len(b['seq']) for b in batch])
691
+ X = np.zeros([B, L_max, 3, 3])
692
+ S = np.zeros([B, L_max], dtype=np.int32)
693
+ score = np.ones([B, L_max]) * 100.0
694
+ chain_mask = np.zeros([B, L_max])-1 # 1: masked positions to be predicted, 0: visible positions
695
+ chain_encoding = np.zeros([B, L_max])-1
696
+
697
+ # Build the batch
698
+ for i, b in enumerate(batch):
699
+ x = np.stack([b[c] for c in ['N', 'CA', 'C']], 1) # [L, 3, 3]
700
+
701
+ l = len(b['seq'])
702
+ x_pad = np.pad(x, [[0,L_max-l], [0,0], [0,0]], 'constant', constant_values=(np.nan, )) # [#atom, 3, 3]
703
+ X[i,:,:,:] = x_pad
704
+
705
+ # Convert to labels
706
+ indices = np.array(tokenizer.encode(b['seq'], add_special_tokens=False))
707
+ if shuffle_fraction > 0.:
708
+ idx_shuffle = shuffle_subset(l, shuffle_fraction)
709
+ S[i, :l] = indices[idx_shuffle]
710
+ else:
711
+ S[i, :l] = indices
712
+
713
+ chain_mask[i,:l] = b['chain_mask']
714
+ chain_encoding[i,:l] = b['chain_encoding']
715
+
716
+ mask = np.isfinite(np.sum(X,(2,3))).astype(np.float32) # atom mask
717
+ numbers = np.sum(mask, axis=1).astype(np.int32) # np.int was removed in NumPy 1.24
718
+ S_new = np.zeros_like(S)
719
+ X_new = np.zeros_like(X)+np.nan
720
+ for i, n in enumerate(numbers):
721
+ X_new[i,:n,::] = X[i][mask[i]==1]
722
+ S_new[i,:n] = S[i][mask[i]==1]
723
+
724
+ X = X_new
725
+ S = S_new
726
+ isnan = np.isnan(X)
727
+ mask = np.isfinite(np.sum(X,(2,3))).astype(np.float32)
728
+ X[isnan] = 0.
729
+ # Conversion
730
+ S = torch.from_numpy(S).to(dtype=torch.long)
731
+ score = torch.from_numpy(score).float()
732
+ X = torch.from_numpy(X).to(dtype=torch.float32)
733
+ mask = torch.from_numpy(mask).to(dtype=torch.float32)
734
+ chain_mask = torch.from_numpy(chain_mask)
735
+ chain_encoding = torch.from_numpy(chain_encoding)
736
+ return {"title": [b['title'] for b in batch],
737
+ "X":X,
738
+ "S":S,
739
+ "score": score,
740
+ "mask":mask,
741
+ "lengths":lengths,
742
+ "chain_mask":chain_mask,
743
+ "chain_encoding":chain_encoding}
Flexpert-Design/src/datasets/flex_cath_dataset.py ADDED
@@ -0,0 +1,155 @@
1
+ import os
2
+ import json
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ import random
6
+ import torch.utils.data as data
7
+ from .utils import cached_property
8
+ from transformers import AutoTokenizer
9
+ from src.tools.utils import load_yaml_config
10
+
11
+ class FlexCATHDataset(data.Dataset):
12
+ def __init__(self, path='./', split='train', max_length=500, test_name='All', data = None, removeTS=0, version=4.3, data_jsonl_name='/chain_set.jsonl', use_dynamics=True):
13
+ self.version = version
14
+ self.path = path
15
+ self.mode = split
16
+ self.max_length = max_length
17
+ self.test_name = test_name
18
+ self.removeTS = removeTS
19
+ self.data_jsonl_name = data_jsonl_name
20
+
21
+ self.using_dynamics = use_dynamics
22
+
23
+ print(self.data_jsonl_name)
24
+ if self.removeTS:
25
+ self.remove = json.load(open(self.path+'/remove.json', 'r'))['remove']
26
+
27
+ if data is None:
28
+ if split == 'predict':
29
+ _split = 'valid'
30
+ print('In predict mode for CATH4.3 using VALIDATION split as the data. Consider switching to TEST set.')
31
+ else:
32
+ _split = split
33
+ self.data = self.cache_data[_split]
34
+ else:
35
+ self.data = data
36
+
37
+ self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="./cache_dir/")
38
+
39
+ @cached_property
40
+ def cache_data(self):
41
+ alphabet='ACDEFGHIKLMNPQRSTVWY'
42
+ alphabet_set = set([a for a in alphabet])
43
+ print("path is: ", self.path)
44
+ if not os.path.exists(self.path):
45
+ raise FileNotFoundError("no such file: {}".format(self.path)) # raising a bare str is a TypeError in Python 3
46
+ else:
47
+ with open(self.path+'/'+self.data_jsonl_name) as f:
48
+ lines = f.readlines()
49
+ data_list = []
50
+ for line in tqdm(lines):
51
+ entry = json.loads(line)
52
+ if self.removeTS and entry['name'] in self.remove:
53
+ continue
54
+ seq = entry['seq']
55
+
56
+ for key, val in entry['coords'].items():
57
+ entry['coords'][key] = np.asarray(val)
58
+
59
+ bad_chars = set([s for s in seq]).difference(alphabet_set)
60
+
61
+ if len(bad_chars) == 0:
62
+ if len(entry['seq']) <= self.max_length:
63
+ chain_length = len(entry['seq'])
64
+ chain_mask = np.ones(chain_length)
65
+ data_list.append({
66
+ 'title':entry['name'],
67
+ 'seq':entry['seq'],
68
+ 'CA':entry['coords']['CA'],
69
+ 'C':entry['coords']['C'],
70
+ 'O':entry['coords']['O'],
71
+ 'N':entry['coords']['N'],
72
+ 'chain_mask': chain_mask,
73
+ 'chain_encoding': 1*chain_mask
74
+ })
75
+ if self.using_dynamics:
76
+ data_list[-1]['gt_flex'] = entry['gt_flex']
77
+ data_list[-1]['enm_vals'] = entry['enm_vals']
78
+ if 'original_gt_flex' in entry:
79
+ data_list[-1]['original_gt_flex'] = entry['original_gt_flex']
80
+ if 'eng_mask' in entry:
81
+ data_list[-1]['eng_mask'] = entry['eng_mask']
82
+ # else:
83
+ # import pdb; pdb.set_trace()
84
+ # print("Bad chars found in sequence: ", bad_chars)
85
+
86
+ if self.version==4.2:
87
+ with open(self.path+'/chain_set_splits.json') as f:
88
+ dataset_splits = json.load(f)
89
+
90
+ if self.version==4.3:
91
+ with open(self.path+'/chain_set_splits.json') as f:
92
+ dataset_splits = json.load(f)
93
+ # _dataset_splits = json.load(f)
94
+ # dataset_splits = {k: _dataset_splits['train'] for k,_ in _dataset_splits.items()}
95
+ # print("TODO: FIX THIS BACK!!!")
96
+ # import pdb; pdb.set_trace()
97
+
98
+ if self.test_name == 'L100':
99
+ with open(self.path+'/test_split_L100.json') as f:
100
+ test_splits = json.load(f)
101
+ dataset_splits['test'] = test_splits['test']
102
+
103
+ if self.test_name == 'sc':
104
+ with open(self.path+'/test_split_sc.json') as f:
105
+ test_splits = json.load(f)
106
+ dataset_splits['test'] = test_splits['test']
107
+
108
+ name2set = {}
109
+ name2set.update({name:'train' for name in dataset_splits['train']})
110
+ name2set.update({name:'valid' for name in dataset_splits['validation']})
111
+ name2set.update({name:'test' for name in dataset_splits['test']})
112
+
113
+ data_dict = {'train':[],'valid':[],'test':[]}
114
+ for data in data_list:
115
+ if name2set.get(data['title']):
116
+ if name2set[data['title']] == 'train':
117
+ data_dict['train'].append(data)
118
+
119
+ if name2set[data['title']] == 'valid':
120
+ data_dict['valid'].append(data)
121
+
122
+ if name2set[data['title']] == 'test':
123
+ data['category'] = 'Unknown'
124
+ data['score'] = 100.0
125
+ data_dict['test'].append(data)
126
+ return data_dict
127
+
128
+ def change_mode(self, mode):
129
+ self.data = self.cache_data[mode]
130
+
131
+ def __len__(self):
132
+ return len(self.data)
133
+
134
+ def get_item(self, index):
135
+ return self.data[index]
136
+
137
+ def __getitem__(self, index):
138
+ item = self.data[index]
139
+ L = len(item['seq'])
140
+ if L>self.max_length:
141
+ # sample a random crop window of max_length residues
+ max_index = L - self.max_length
+ truncate_index = random.randint(0, max_index)
+ # crop the sequence, coordinates and per-residue annotations with the same window
+ for key in ('seq', 'CA', 'C', 'O', 'N', 'chain_mask', 'chain_encoding', 'gt_flex', 'enm_vals'):
+ if key in item: # dynamics fields are absent when use_dynamics=False
+ item[key] = item[key][truncate_index:truncate_index+self.max_length]
155
+ return item
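A usage sketch for FlexCATHDataset (the path is hypothetical); __getitem__ randomly crops long chains to max_length residues and crops the per-residue flexibility annotations with the same window:

    train_set = FlexCATHDataset(path='./data/cath4.3', split='train',
                                data_jsonl_name='chain_set.jsonl', use_dynamics=True)
    item = train_set[0]
    # item['seq'] capped at max_length; item['CA'] is [L, 3];
    # item['gt_flex'] and item['enm_vals'] are per-residue flexibility targets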
Flexpert-Design/src/datasets/foldswitchers_dataset.py ADDED
@@ -0,0 +1,128 @@
1
+ import os
2
+ import json
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ import random
6
+ import pdb
7
+ import torch.utils.data as data
8
+ from .utils import cached_property
9
+ from transformers import AutoTokenizer
10
+
11
+ class FoldswitchersDataset(data.Dataset):
12
+ def __init__(self, path='./', split='train', max_length=500, test_name='All', data = None, removeTS=0):
13
+ self.path = path
14
+ self.mode = split
15
+ self.max_length = max_length
16
+ self.test_name = test_name
17
+ self.removeTS = removeTS
18
+ if self.removeTS:
19
+ self.remove = json.load(open(self.path+'/remove.json', 'r'))['remove']
20
+
21
+ if data is None:
22
+ self.data = self.cache_data[split] #This calls the cache_data property
23
+ else:
24
+ self.data = data
25
+
26
+ self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="./cache_dir/")
27
+
28
+ @cached_property
29
+ def cache_data(self):
30
+ alphabet='ACDEFGHIKLMNPQRSTVWY'
31
+ alphabet_set = set([a for a in alphabet])
32
+ print("path is: ", self.path)
33
+
34
+ if not os.path.exists(self.path):
35
+ raise FileNotFoundError("no such file: {}".format(self.path)) # raising a bare str is a TypeError in Python 3
36
+ else:
37
+
38
+ with open(self.path+'/chain_set.jsonl') as f:
39
+ lines = f.readlines()
40
+ data_list = []
41
+
42
+ for line in tqdm(lines):
43
+ entry = json.loads(line)
44
+
45
+ if self.removeTS and entry['name'] in self.remove:
46
+ continue
47
+ seq = entry['seq']
48
+
49
+ for key, val in entry['coords'].items():
50
+ entry['coords'][key] = np.asarray(val)
51
+
52
+ bad_chars = set([s for s in seq]).difference(alphabet_set)
53
+
54
+ if len(bad_chars) == 0:
55
+ if len(entry['seq']) <= self.max_length:
56
+ chain_length = len(entry['seq'])
57
+ chain_mask = np.ones(chain_length)
58
+ data_list.append({
59
+ 'title':entry['name'],
60
+ 'seq':entry['seq'],
61
+ 'CA':entry['coords']['CA'],
62
+ 'C':entry['coords']['C'],
63
+ 'O':entry['coords']['O'],
64
+ 'N':entry['coords']['N'],
65
+ 'chain_mask': chain_mask,
66
+ 'chain_encoding': 1*chain_mask
67
+ })
68
+
69
+ with open(self.path+'/chain_set_splits_cleaned.json') as f:
70
+ dataset_splits = json.load(f)
71
+
72
+ if self.test_name == 'L100':
73
+ with open(self.path+'/test_split_L100.json') as f:
74
+ test_splits = json.load(f)
75
+ dataset_splits['test'] = test_splits['test']
76
+
77
+ if self.test_name == 'sc':
78
+ with open(self.path+'/test_split_sc.json') as f:
79
+ test_splits = json.load(f)
80
+ dataset_splits['test'] = test_splits['test']
81
+
82
+ name2set = {}
83
+ name2set.update({name:'train' for name in dataset_splits['train']})
84
+ name2set.update({name:'valid' for name in dataset_splits['validation']})
85
+ name2set.update({name:'test' for name in dataset_splits['test']})
86
+
87
+ data_dict = {'train':[],'valid':[],'test':[]}
88
+ for data in data_list:
89
+ #pdb.set_trace()
90
+ if name2set.get(data['title']): # guards against empty datasets caused by a mismatch of names between chain_set and chain_set_splits
91
+ if name2set[data['title']] == 'train':
92
+ data_dict['train'].append(data)
93
+
94
+ if name2set[data['title']] == 'valid':
95
+ data_dict['valid'].append(data)
96
+
97
+ if name2set[data['title']] == 'test':
98
+ data['category'] = 'Unknown'
99
+ data['score'] = 100.0
100
+ data_dict['test'].append(data)
101
+ return data_dict
102
+
103
+ def change_mode(self, mode):
104
+ self.data = self.cache_data[mode]
105
+
106
+ def __len__(self):
107
+ return len(self.data)
108
+
109
+ def get_item(self, index):
110
+ return self.data[index]
111
+
112
+ def __getitem__(self, index):
113
+ item = self.data[index]
114
+ L = len(item['seq'])
115
+ if L>self.max_length:
116
+ # sample a random crop window of max_length residues
+ max_index = L - self.max_length
+ truncate_index = random.randint(0, max_index)
+ # crop the sequence, coordinates and masks with the same window
+ for key in ('seq', 'CA', 'C', 'O', 'N', 'chain_mask', 'chain_encoding'):
+ item[key] = item[key][truncate_index:truncate_index+self.max_length]
128
+ return item
Flexpert-Design/src/datasets/mpnn_dataset.py ADDED
@@ -0,0 +1,492 @@
1
+ import os
2
+ import json
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ import pandas as pd
6
+ import torch.utils.data as data
7
+ from Bio.PDB import PDBParser
8
+ import torch
9
+ import random
10
+ import csv
11
+ from dateutil import parser
12
+ from .fast_dataloader import DataLoaderX
13
+ from torch.utils.data import DataLoader
14
+ import time
15
+
16
+ from joblib import Parallel, delayed, cpu_count
17
+ from tqdm import tqdm
18
+
19
+
20
+ def pmap_multi(pickleable_fn, data, n_jobs=None, verbose=1, desc=None, **kwargs):
21
+ """
22
+
23
+ Parallel map using joblib.
24
+
25
+ Parameters
26
+ ----------
27
+ pickleable_fn : callable
28
+ Function to map over data.
29
+ data : iterable
30
+ Data over which we want to parallelize the function call.
31
+ n_jobs : int, optional
32
+ The maximum number of concurrently running jobs. By default, it is one less than
33
+ the number of CPUs.
34
+ verbose: int, optional
35
+ The verbosity level. If nonzero, the function prints the progress messages.
36
+ The frequency of the messages increases with the verbosity level. If above 10,
37
+ it reports all iterations. If above 50, it sends the output to stdout.
38
+ kwargs
39
+ Additional arguments for :attr:`pickleable_fn`.
40
+
41
+ Returns
42
+ -------
43
+ list
44
+ The i-th element of the list corresponds to the output of applying
45
+ :attr:`pickleable_fn` to :attr:`data[i]`.
46
+ """
47
+ if n_jobs is None:
48
+ n_jobs = cpu_count() - 1
49
+
50
+ results = Parallel(n_jobs=n_jobs, verbose=verbose, timeout=None)(
51
+ delayed(pickleable_fn)(*d, **kwargs) for i, d in tqdm(enumerate(data),desc=desc)
52
+ )
+ return results
53
+
54
+
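With the return statement restored above, pmap_multi behaves like a progress-bar-wrapped joblib map over argument tuples. A toy example (the function here is illustrative, not from this repo):

    def square(x):
        return x * x

    results = pmap_multi(square, [(i,) for i in range(8)], n_jobs=2)
    # results == [0, 1, 4, 9, 16, 25, 36, 49]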
55
+
56
+ def build_training_clusters(params, debug):
57
+ val_ids = set([int(l) for l in open(params['VAL']).readlines()])
58
+ test_ids = set([int(l) for l in open(params['TEST']).readlines()])
59
+
60
+ if debug:
61
+ val_ids = []
62
+ test_ids = []
63
+
64
+ # read & clean list.csv
65
+ with open(params['LIST'], 'r') as f:
66
+ reader = csv.reader(f)
67
+ next(reader)
68
+ rows = [[r[0],r[3],int(r[4])] for r in reader
69
+ if float(r[2])<=params['RESCUT'] and
70
+ parser.parse(r[1])<=parser.parse(params['DATCUT'])]
71
+
72
+ # compile training and validation sets
73
+ train = {}
74
+ valid = {}
75
+ test = {}
76
+
77
+ if debug:
78
+ rows = rows[:20]
79
+ for r in rows:
80
+ if r[2] in val_ids:
81
+ if r[2] in valid.keys():
82
+ valid[r[2]].append(r[:2])
83
+ else:
84
+ valid[r[2]] = [r[:2]]
85
+ elif r[2] in test_ids:
86
+ if r[2] in test.keys():
87
+ test[r[2]].append(r[:2])
88
+ else:
89
+ test[r[2]] = [r[:2]]
90
+ else:
91
+ if r[2] in train.keys():
92
+ train[r[2]].append(r[:2])
93
+ else:
94
+ train[r[2]] = [r[:2]]
95
+ if debug:
96
+ valid=train
97
+ return train, valid, test
98
+
99
+
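A sketch of the params dict build_training_clusters expects, mirroring the keys set up in MPNNDataset below (paths are hypothetical); each returned dict maps a cluster id to a list of [chain_id, hash] pairs drawn from list.csv:

    params = {
        'LIST':   'pdb_2021aug02/list.csv',
        'VAL':    'pdb_2021aug02/valid_clusters.txt',
        'TEST':   'pdb_2021aug02/test_clusters.txt',
        'RESCUT': 3.5,
        'DATCUT': '2030-Jan-01',
    }
    train, valid, test = build_training_clusters(params, debug=False)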
100
+ def loader_pdb(item,params):
101
+
102
+ pdbid,chid = item[0].split('_')
103
+ PREFIX = "%s/pdb/%s/%s"%(params['DIR'],pdbid[1:3],pdbid)
104
+
105
+ # load metadata
106
+ if not os.path.isfile(PREFIX+".pt"):
107
+ return {'seq': np.zeros(5)}
108
+ meta = torch.load(PREFIX+".pt")
109
+ asmb_ids = meta['asmb_ids']
110
+ asmb_chains = meta['asmb_chains']
111
+ chids = np.array(meta['chains'])
112
+
113
+ # find candidate assemblies which contain chid chain
114
+ asmb_candidates = set([a for a,b in zip(asmb_ids,asmb_chains)
115
+ if chid in b.split(',')])
116
+
117
+ # if the chain is missing from all the assemblies,
118
+ # then return this chain alone
119
+ if len(asmb_candidates)<1:
120
+ chain = torch.load("%s_%s.pt"%(PREFIX,chid))
121
+ L = len(chain['seq'])
122
+ return {'seq' : chain['seq'],
123
+ 'xyz' : chain['xyz'],
124
+ 'idx' : torch.zeros(L).int(),
125
+ 'masked' : torch.Tensor([0]).int(),
126
+ 'label' : item[0]}
127
+
128
+ # randomly pick one assembly from candidates
129
+ asmb_i = random.sample(list(asmb_candidates), 1)
130
+
131
+ # indices of selected transforms
132
+ idx = np.where(np.array(asmb_ids)==asmb_i)[0]
133
+
134
+ # load relevant chains
135
+ chains = {c:torch.load("%s_%s.pt"%(PREFIX,c))
136
+ for i in idx for c in asmb_chains[i]
137
+ if c in meta['chains']}
138
+
139
+ # generate assembly
140
+ asmb = {}
141
+ for k in idx:
142
+
143
+ # pick k-th xform
144
+ xform = meta['asmb_xform%d'%k]
145
+ u = xform[:,:3,:3]
146
+ r = xform[:,:3,3]
147
+
148
+ # select chains which k-th xform should be applied to
149
+ s1 = set(meta['chains'])
150
+ s2 = set(asmb_chains[k].split(','))
151
+ chains_k = s1&s2
152
+
153
+ # transform selected chains
154
+ for c in chains_k:
155
+ try:
156
+ xyz = chains[c]['xyz']
157
+ xyz_ru = torch.einsum('bij,raj->brai', u, xyz) + r[:,None,None,:]
158
+ asmb.update({(c,k,i):xyz_i for i,xyz_i in enumerate(xyz_ru)})
159
+ except KeyError:
160
+ return {'seq': np.zeros(5)}
161
+
162
+ # select chains which share considerable similarity to chid
163
+ seqid = meta['tm'][chids==chid][0,:,1]
164
+ homo = set([ch_j for seqid_j,ch_j in zip(seqid,chids)
165
+ if seqid_j>params['HOMO']])
166
+ # stack all chains in the assembly together
167
+ seq,xyz,idx,masked = "",[],[],[]
168
+ seq_list = []
169
+ for counter,(k,v) in enumerate(asmb.items()):
170
+ seq += chains[k[0]]['seq']
171
+ seq_list.append(chains[k[0]]['seq'])
172
+ xyz.append(v)
173
+ idx.append(torch.full((v.shape[0],),counter))
174
+ if k[0] in homo:
175
+ masked.append(counter)
176
+
177
+ return {'seq' : seq,
178
+ 'xyz' : torch.cat(xyz,dim=0),
179
+ 'idx' : torch.cat(idx,dim=0),
180
+ 'masked' : torch.Tensor(masked).int(),
181
+ 'label' : item[0]}
182
+
183
+ def get_pdbs(data, max_length=10000, num_units=1000000):
184
+ init_alphabet = ['A', 'B', 'C', 'D', 'E', 'F', 'G','H', 'I', 'J','K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T','U', 'V','W','X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g','h', 'i', 'j','k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't','u', 'v','w','x', 'y', 'z']
185
+ extra_alphabet = [str(item) for item in list(np.arange(300))]
186
+ chain_alphabet = init_alphabet + extra_alphabet
187
+ c = 0
188
+ c1 = 0
189
+
190
+
191
+ data = {k:v for k,v in data.items()}
192
+ c1 += 1
193
+ if 'label' in list(data):
194
+ my_dict = {}
195
+ s = 0
196
+ concat_seq = ''
197
+ concat_N = []
198
+ concat_CA = []
199
+ concat_C = []
200
+ concat_O = []
201
+ concat_mask = []
202
+ coords_dict = {}
203
+ mask_list = []
204
+ visible_list = []
205
+ if len(list(np.unique(data['idx']))) < 352:
206
+ for idx in list(np.unique(data['idx'])):
207
+ letter = chain_alphabet[idx]
208
+ res = np.argwhere(data['idx']==idx)
209
+ initial_sequence= "".join(list(np.array(list(data['seq']))[res][0,]))
210
+ if initial_sequence[-6:] == "HHHHHH":
211
+ res = res[:,:-6]
212
+ if initial_sequence[0:6] == "HHHHHH":
213
+ res = res[:,6:]
214
+ if initial_sequence[-7:-1] == "HHHHHH":
215
+ res = res[:,:-7]
216
+ if initial_sequence[-8:-2] == "HHHHHH":
217
+ res = res[:,:-8]
218
+ if initial_sequence[-9:-3] == "HHHHHH":
219
+ res = res[:,:-9]
220
+ if initial_sequence[-10:-4] == "HHHHHH":
221
+ res = res[:,:-10]
222
+ if initial_sequence[1:7] == "HHHHHH":
223
+ res = res[:,7:]
224
+ if initial_sequence[2:8] == "HHHHHH":
225
+ res = res[:,8:]
226
+ if initial_sequence[3:9] == "HHHHHH":
227
+ res = res[:,9:]
228
+ if initial_sequence[4:10] == "HHHHHH":
229
+ res = res[:,10:]
230
+ if res.shape[1] < 4:
231
+ pass
232
+ else:
233
+ my_dict['seq_chain_'+letter]= "".join(list(np.array(list(data['seq']))[res][0,]))
234
+ concat_seq += my_dict['seq_chain_'+letter]
235
+ if idx in data['masked']:
236
+ mask_list.append(letter)
237
+ else:
238
+ visible_list.append(letter)
239
+ coords_dict_chain = {}
240
+ all_atoms = np.array(data['xyz'][res,])[0,] #[L, 14, 3]
241
+ coords_dict_chain['N_chain_'+letter]=all_atoms[:,0,:].tolist()
242
+ coords_dict_chain['CA_chain_'+letter]=all_atoms[:,1,:].tolist()
243
+ coords_dict_chain['C_chain_'+letter]=all_atoms[:,2,:].tolist()
244
+ coords_dict_chain['O_chain_'+letter]=all_atoms[:,3,:].tolist()
245
+ my_dict['coords_chain_'+letter]=coords_dict_chain
246
+ my_dict['name']= data['label']
247
+ my_dict['masked_list']= mask_list
248
+ my_dict['visible_list']= visible_list
249
+ my_dict['num_of_chains'] = len(mask_list) + len(visible_list)
250
+ my_dict['seq'] = concat_seq
251
+ if len(concat_seq) <= max_length:
252
+ return my_dict
253
+ return None
254
+
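The cascade of "HHHHHH" string checks in get_pdbs trims polyhistidine expression tags sitting within ten residues of either terminus. An equivalent, more compact helper might look like this (a sketch only, keeping the (1, N) residue-index layout used above; not a tested drop-in replacement):

    def trim_his_tag(res, seq, window=10):
        head = seq[:window].find('HHHHHH')
        if head != -1:
            res = res[:, head + 6:]          # drop everything up to the end of the tag
        tail = seq[-window:].rfind('HHHHHH')
        if tail != -1:
            res = res[:, :-(window - tail)]  # drop from the start of the tag onward
        return res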
255
+ def safe_iter(ID, split_dict, params, alphabet_set, max_length=1000):
256
+ sel_idx = np.random.randint(0, len(split_dict[ID]))
257
+ out = loader_pdb(split_dict[ID][sel_idx], params)
258
+ entry = get_pdbs(out)
259
+ if entry is None:
260
+ return None
261
+
262
+ seq = entry['seq']
263
+ bad_chars = set([s for s in seq]).difference(alphabet_set)
264
+ if len(bad_chars) != 0:
265
+ return None
266
+
267
+ if len(entry['seq']) > max_length:
268
+ return None
269
+
270
+ masked_chains = entry['masked_list']
271
+ visible_chains = entry['visible_list']
272
+
273
+ all_chains = masked_chains + visible_chains
274
+ visible_temp_dict = {}
275
+ masked_temp_dict = {}
276
+
277
+ for step, letter in enumerate(all_chains):
278
+ chain_seq = entry[f'seq_chain_{letter}']
279
+ if letter in visible_chains:
280
+ visible_temp_dict[letter] = chain_seq
281
+ elif letter in masked_chains:
282
+ masked_temp_dict[letter] = chain_seq
283
+
284
+ for km, vm in masked_temp_dict.items():
285
+ for kv, vv in visible_temp_dict.items():
286
+ if vm == vv:
287
+ if kv not in masked_chains:
288
+ masked_chains.append(kv)
289
+ if kv in visible_chains:
290
+ visible_chains.remove(kv)
291
+
292
+ all_chains = masked_chains + visible_chains
293
+ random.shuffle(all_chains)
294
+
295
+
296
+ x_chain_list = []
297
+ chain_mask_list = []
298
+ chain_seq_list = []
299
+ chain_encoding_list = []
300
+ c = 1
301
+
302
+ for step, letter in enumerate(all_chains):
303
+ if letter in visible_chains:
304
+ chain_seq = entry[f'seq_chain_{letter}']
305
+ chain_length = len(chain_seq)
306
+ chain_coords = entry[f'coords_chain_{letter}'] #this is a dictionary
307
+ chain_mask = np.zeros(chain_length) #0.0 for visible chains
308
+ x_chain = np.stack([chain_coords[c] for c in [f'N_chain_{letter}', f'CA_chain_{letter}', f'C_chain_{letter}', f'O_chain_{letter}']], 1) #[chain_length,4,3]
309
+ x_chain_list.append(x_chain)
310
+ chain_mask_list.append(chain_mask)
311
+ chain_seq_list.append(chain_seq)
312
+ chain_encoding_list.append(c*np.ones(np.array(chain_mask).shape[0]))
313
+ c+=1
314
+ elif letter in masked_chains:
315
+ chain_seq = entry[f'seq_chain_{letter}']
316
+ chain_length = len(chain_seq)
317
+ chain_coords = entry[f'coords_chain_{letter}'] #this is a dictionary
318
+ chain_mask = np.ones(chain_length) #1.0 for masked (designable) chains
+ x_chain = np.stack([chain_coords[c] for c in [f'N_chain_{letter}', f'CA_chain_{letter}', f'C_chain_{letter}', f'O_chain_{letter}']], 1) #[chain_length,4,3]
320
+ x_chain_list.append(x_chain)
321
+ chain_mask_list.append(chain_mask)
322
+ chain_seq_list.append(chain_seq)
323
+ chain_encoding_list.append(c*np.ones(np.array(chain_mask).shape[0]))
324
+ c+=1
325
+
326
+ chain_mask_all = torch.from_numpy(np.concatenate(chain_mask_list))
327
+ chain_encoding_all = torch.from_numpy(np.concatenate(chain_encoding_list))
328
+ x_chain_all = torch.from_numpy(np.concatenate(x_chain_list))
329
+
330
+ data = {
331
+ "title":entry['name'],
332
+ "seq":''.join(chain_seq_list), #len(seq)=n
333
+ "chain_mask":chain_mask_all,
334
+ "chain_encoding":chain_encoding_all,
335
+ "CA":x_chain_all[:,1], # [n,3]
336
+ "C":x_chain_all[:,2],
337
+ "O":x_chain_all[:,3],
338
+ "N":x_chain_all[:,0]} # [n,]
339
+ return data
340
+
341
+ class MPNNDataset(data.Dataset):
342
+ def __init__(self, data_path='/gaozhangyang/drug_dataset/proteinmpnn_data/pdb_2021aug02', rescut=3.5, split='train'):
343
+ self.data_path = data_path
344
+ self.rescut = rescut
345
+ self.params = {
346
+ "LIST" : f"{self.data_path}/list.csv",
347
+ "VAL" : f"{self.data_path}/valid_clusters.txt",
348
+ "TEST" : f"{self.data_path}/test_clusters.txt",
349
+ "DIR" : f"{self.data_path}",
350
+ "DATCUT" : "2030-Jan-01",
351
+ "RESCUT" : self.rescut, #resolution cutoff for PDBs
352
+ "HOMO" : 0.70 #min seq.id. to detect homo chains
353
+ }
354
+
355
+ if not os.path.exists("/gaozhangyang/experiments/OpenCPD/data/mpnn_data/split.pt"):
356
+ train, valid, test = build_training_clusters(self.params, False)
357
+ splits = {"train": train, "valid": valid, "test": test}
+ torch.save(splits, "/gaozhangyang/experiments/OpenCPD/data/mpnn_data/split.pt")
+ else:
+ splits = torch.load("/gaozhangyang/experiments/OpenCPD/data/mpnn_data/split.pt")
+
+ self.split_dict = splits[split] # select by split name; avoid shadowing the `split` argument
363
+ alphabet='ACDEFGHIKLMNPQRSTVWYX'
364
+ self.alphabet_set = set([a for a in alphabet])
365
+ self.IDs = list(self.split_dict.keys())
366
+ # self.data = self.preprocess()
367
+
368
+ def cache_split(self,):
369
+ train, valid, test = build_training_clusters(self.params, False)
370
+
371
+ return {"train": train, "valid":valid, "test":test}
372
+
373
+ @classmethod
374
+ def safe_iter(cls, ID, split_dict, params, alphabet_set, max_length=1000):
375
+ # sel_idx = np.random.randint(0, len(split_dict[ID]))
376
+ sel_idx = 0
377
+ out = loader_pdb(split_dict[ID][sel_idx], params)
378
+ entry = get_pdbs(out)
379
+ if entry is None:
380
+ return None
381
+
382
+ seq = entry['seq']
383
+ bad_chars = set([s for s in seq]).difference(alphabet_set)
384
+ if len(bad_chars) != 0:
385
+ return None
386
+
387
+ if len(entry['seq']) > max_length:
388
+ return None
389
+
390
+ masked_chains = entry['masked_list']
391
+ visible_chains = entry['visible_list']
392
+
393
+ all_chains = masked_chains + visible_chains
394
+ visible_temp_dict = {}
395
+ masked_temp_dict = {}
396
+
397
+ for step, letter in enumerate(all_chains):
398
+ chain_seq = entry[f'seq_chain_{letter}']
399
+ if letter in visible_chains:
400
+ visible_temp_dict[letter] = chain_seq
401
+ elif letter in masked_chains:
402
+ masked_temp_dict[letter] = chain_seq
403
+
404
+ for km, vm in masked_temp_dict.items():
405
+ for kv, vv in visible_temp_dict.items():
406
+ if vm == vv:
407
+ if kv not in masked_chains:
408
+ masked_chains.append(kv)
409
+ if kv in visible_chains:
410
+ visible_chains.remove(kv)
411
+
412
+ all_chains = masked_chains + visible_chains
413
+ random.shuffle(all_chains)
414
+
415
+
416
+ x_chain_list = []
417
+ chain_mask_list = []
418
+ chain_seq_list = []
419
+ chain_encoding_list = []
420
+ c = 1
421
+
422
+ for step, letter in enumerate(all_chains):
423
+ if letter in visible_chains:
424
+ chain_seq = entry[f'seq_chain_{letter}']
425
+ chain_length = len(chain_seq)
426
+ chain_coords = entry[f'coords_chain_{letter}'] #this is a dictionary
427
+ chain_mask = np.zeros(chain_length) #0.0 for visible chains
428
+ x_chain = np.stack([chain_coords[c] for c in [f'N_chain_{letter}', f'CA_chain_{letter}', f'C_chain_{letter}', f'O_chain_{letter}']], 1) #[chain_length,4,3]
429
+ x_chain_list.append(x_chain)
430
+ chain_mask_list.append(chain_mask)
431
+ chain_seq_list.append(chain_seq)
432
+ chain_encoding_list.append(c*np.ones(np.array(chain_mask).shape[0]))
433
+ c+=1
434
+ elif letter in masked_chains:
435
+ chain_seq = entry[f'seq_chain_{letter}']
436
+ chain_length = len(chain_seq)
437
+ chain_coords = entry[f'coords_chain_{letter}'] #this is a dictionary
438
+ chain_mask = np.ones(chain_length) #0.0 for visible chains
439
+ x_chain = np.stack([chain_coords[c] for c in [f'N_chain_{letter}', f'CA_chain_{letter}', f'C_chain_{letter}', f'O_chain_{letter}']], 1) #[chain_lenght,4,3]
440
+ x_chain_list.append(x_chain)
441
+ chain_mask_list.append(chain_mask)
442
+ chain_seq_list.append(chain_seq)
443
+ chain_encoding_list.append(c*np.ones(np.array(chain_mask).shape[0]))
444
+ c+=1
445
+
446
+ chain_mask_all = np.concatenate(chain_mask_list)
447
+ chain_encoding_all = np.concatenate(chain_encoding_list)
448
+ x_chain_all = np.concatenate(x_chain_list)
449
+
450
+ data = {
451
+ "title":entry['name']+str(int(chain_mask_all.sum())),
452
+ "seq":''.join(chain_seq_list), #len(seq)=n
453
+ "chain_mask":chain_mask_all,
454
+ "chain_encoding":chain_encoding_all,
455
+ "CA":x_chain_all[:,1], # [n,3]
456
+ "C":x_chain_all[:,2],
457
+ "O":x_chain_all[:,3],
458
+ "N":x_chain_all[:,0]} # [n,]
459
+ return data
460
+
461
+
462
+
463
+ def preprocess(self):
464
+ data = pmap_multi(self.safe_iter, [(ID,) for ID in self.IDs], split_dict=self.split_dict, params=self.params, alphabet_set=self.alphabet_set)
465
+ return data
466
+
467
+ def __len__(self):
468
+ # return len(self.data)
469
+ return len(self.IDs)
470
+
471
+ def __getitem__(self, index):
472
+ ID = self.IDs[index]
473
+ out = self.safe_iter(ID, split_dict=self.split_dict, params=self.params, alphabet_set=self.alphabet_set)
474
+ return out
475
+
476
+
477
+ def collate_fn(batch):
478
+ return batch
479
+
480
+
481
+ if __name__ == "__main__":
482
+ MPNNDataset = MPNNDataset()
483
+ loader = DataLoaderX(local_rank=0, dataset = MPNNDataset, collate_fn=collate_fn, batch_size=4)
484
+ # loader = DataLoader(dataset = MPNNDataset, collate_fn=collate_fn, batch_size=4, prefetch_factor=4, num_workers=4)
485
+ for batch in tqdm(loader):
486
+ for one in batch:
487
+ if one is not None:
488
+ for key, val in one.items():
489
+ if type(val) == torch.Tensor:
490
+ result = val.to('cuda:0')
491
+ time.sleep(2)
492
+ print()
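A note on the conventions above, illustrated with a minimal sketch (the toy chain lengths are hypothetical): `chain_mask` is 1.0 on chains to be designed (masked) and 0.0 on fixed context chains, while `chain_encoding` numbers the chains so they remain distinguishable after concatenation into one sequence.

import numpy as np

# Toy two-chain complex: design chain A (masked), keep chain B as context (visible).
len_A, len_B = 3, 2
chain_mask = np.concatenate([np.ones(len_A), np.zeros(len_B)])         # [1,1,1,0,0]
chain_encoding = np.concatenate([1*np.ones(len_A), 2*np.ones(len_B)])  # [1,1,1,2,2]
assert int(chain_mask.sum()) == len_A  # count of designed residues, appended to the title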
Flexpert-Design/src/datasets/pdb_inference.py ADDED
@@ -0,0 +1,329 @@
+ import os
+ import json
+ import numpy as np
+ import random
+ import pdb
+ import torch.utils.data as data
+ from .utils import cached_property
+ from transformers import AutoTokenizer
+
+ # Imports for the PDB parser utils
+ import glob
+ import gzip
+ import re
+ import multiprocessing
+ import tqdm
+ import shutil
+ SENTINEL = 1
+ import biotite.structure as struc
+ import biotite.application.dssp as dssp
+ import biotite.structure.io.pdb.file as file
+
+ class PDBInference(data.Dataset):
+     def __init__(self, path='./', max_length=500, *args, **kwargs):
+         self.path = path
+         self.max_length = max_length
+
+         self.data = self.cache_data #TODO
+         self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="./cache_dir/")
+
+     @cached_property
+     def cache_data(self):
+         alphabet = 'ACDEFGHIKLMNPQRSTVWY'
+         alphabet_set = set([a for a in alphabet])
+         print("path is: ", self.path)
+
+         if not os.path.exists(self.path):
+             raise FileNotFoundError("no such folder: {} !!!".format(self.path)) # the original raised a bare string, which is invalid in Python 3
+
+         # list all PDBs
+         pdb_files = []
+         _files = os.listdir(self.path)
+         for _file in _files:
+             if _file.endswith('.pdb'):
+                 pdb_files.append(_file)
+         print(f'pdb_files size = {len(pdb_files)}')
+         # parse the PDBs into lines as if they came from chain_set.json
+         lines = []
+         for _pdb in pdb_files:
+             _input_chain = _pdb.split('_')[1].split('.')[0] # ASSUMING NAMING 'PDBCODE_CHAINCODE_XXX'
+             _line = self.parse_PDB(self.path+'/'+_pdb, name=_pdb.split('.')[0], input_chain=_input_chain) # an input chain list could be parsed here as well
+             #pdb.set_trace()
+             lines.append(_line[0])
+
+         print(f'lines size = {len(lines)}')
+         data_list = []
+
+         flex_instructions = {}
+         flexibility_files = glob.glob(self.path + '/*instructions.csv')
+         for fpath in flexibility_files: # renamed from `file` to avoid shadowing the biotite import
+             with open(fpath, 'r') as f:
+                 flexibility_instructions_parsed = f.read().strip().split(',')
+                 flexibility_instructions_parsed = [float(i) for i in flexibility_instructions_parsed] + [0.0] # add the padding here
+             flex_instructions[fpath.split('/')[-1].split('_instructions')[0]] = flexibility_instructions_parsed
+
+         for line in tqdm.tqdm(lines):
+             entry = line
+
+             seq = entry['seq']
+
+             for key, val in entry['coords'].items():
+                 entry['coords'][key] = np.asarray(val)
+
+             bad_chars = set([s for s in seq]).difference(alphabet_set)
+             try:
+                 _flex_instructions = flex_instructions[entry['name']]
+             except KeyError:
+                 _flex_instructions = [0.0] * len(seq)
+                 print(f"No flexibility instructions found for {entry['name']}. Passing zeros.")
+
+             if len(bad_chars) == 0:
+                 if len(entry['seq']) <= self.max_length:
+                     chain_length = len(entry['seq'])
+                     chain_mask = np.ones(chain_length)
+                     data_list.append({
+                         'title': entry['name'],
+                         'seq': entry['seq'],
+                         'CA': entry['coords']['CA'],
+                         'C': entry['coords']['C'],
+                         'O': entry['coords']['O'],
+                         'N': entry['coords']['N'],
+                         'chain_mask': chain_mask,
+                         'chain_encoding': 1*chain_mask,
+                         'gt_flex': _flex_instructions
+                     })
+             else:
+                 print(f'Skipping PDBs with bad chars, e.g. gaps in the sequence: {entry["name"]}')
+
+         #data_dict = {'train':[],'valid':data_list,'test':data_list}
+         print(f'data_list size = {len(data_list)}')
+         return data_list #data_dict
+
+     def change_mode(self, mode):
+         # NOTE: only applicable if cache_data returns the dict variant commented out above
+         self.data = self.cache_data[mode]
+
+     def __len__(self):
+         return len(self.data)
+
+     def get_item(self, index):
+         return self.data[index]
+
+     def __getitem__(self, index):
+         item = self.data[index]
+         L = len(item['seq'])
+         if L > self.max_length:
+             # compute the maximum truncation start index
+             max_index = L - self.max_length
+             # sample a random truncation start
+             truncate_index = random.randint(0, max_index)
+             # truncate all per-residue fields consistently
+             item['seq'] = item['seq'][truncate_index:truncate_index+self.max_length]
+             item['CA'] = item['CA'][truncate_index:truncate_index+self.max_length]
+             item['C'] = item['C'][truncate_index:truncate_index+self.max_length]
+             item['O'] = item['O'][truncate_index:truncate_index+self.max_length]
+             item['N'] = item['N'][truncate_index:truncate_index+self.max_length]
+             item['chain_mask'] = item['chain_mask'][truncate_index:truncate_index+self.max_length]
+             item['chain_encoding'] = item['chain_encoding'][truncate_index:truncate_index+self.max_length]
+             item['gt_flex'] = item['gt_flex'][truncate_index:truncate_index+self.max_length]
+         return item
+
+     # Code from data_utils on local PC, based on: https://github.com/JoreyYan/zetadesign/blob/master/data/data.py
+     def parse_PDB_biounits(self, x, sse, ssedssp, atoms=['N', 'CA', 'C'], chain=None):
+         '''
+         input:  x = PDB filename
+                 atoms = atoms to extract (optional)
+         output: (length, atoms, coords=(x,y,z)), sequence
+         '''
+
+         alpha_1 = list("ARNDCQEGHILKMFPSTWYV-")
+         states = len(alpha_1)
+         alpha_3 = ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
+                    'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL', 'GAP']
+
+         aa_1_N = {a: n for n, a in enumerate(alpha_1)}
+         aa_3_N = {a: n for n, a in enumerate(alpha_3)}
+         aa_N_1 = {n: a for n, a in enumerate(alpha_1)}
+         aa_1_3 = {a: b for a, b in zip(alpha_1, alpha_3)}
+         aa_3_1 = {b: a for a, b in zip(alpha_1, alpha_3)}
+
+         def AA_to_N(x):
+             x = np.array(x)
+             if x.ndim == 0: x = x[None]
+             return [[aa_1_N.get(a, states - 1) for a in y] for y in x]
+
+         def N_to_AA(x):
+             x = np.array(x)
+             if x.ndim == 1: x = x[None]
+             return ["".join([aa_N_1.get(a, "-") for a in y]) for y in x]
+
+         xyz, seq, plddts, min_resn, max_resn = {}, {}, [], 1e6, -1e6
+
+         pdbcontents = x.split('\n')[0]
+         with open(pdbcontents) as f:
+             pdbcontents = f.readlines()
+         for line in pdbcontents:
+
+             if line[:6] == "HETATM" and line[17:17 + 3] == "MSE":
+                 line = line.replace("HETATM", "ATOM  ")
+                 line = line.replace("MSE", "MET")
+
+             if line[:4] == "ATOM":
+                 ch = line[21:22]
+                 if ch == chain or chain is None or ch == ' ':
+                     atom = line[12:12 + 4].strip()
+                     resi = line[17:17 + 3]
+                     resn = line[22:22 + 5].strip()
+                     plddt = line[60:60 + 6].strip()
+
+                     x, y, z = [float(line[i:(i + 8)]) for i in [30, 38, 46]]
+
+                     if resn[-1].isalpha():
+                         resa, resn = resn[-1], int(resn[:-1]) - 1 # insertion code at the same position: keep the last atoms
+                     else:
+                         resa, resn = "_", int(resn) - 1
+                     # resn = int(resn)
+                     if resn < min_resn:
+                         min_resn = resn
+                     if resn > max_resn:
+                         max_resn = resn
+
+                     if resn not in xyz:
+                         xyz[resn] = {}
+                     if resa not in xyz[resn]:
+                         xyz[resn][resa] = {}
+                     if resn not in seq:
+                         seq[resn] = {}
+
+                     if resa not in seq[resn]:
+                         seq[resn][resa] = resi
+
+                     if atom not in xyz[resn][resa]:
+                         xyz[resn][resa][atom] = np.array([x, y, z])
+
+         # convert to numpy arrays, fill in missing values
+         seq_, xyz_, sse_, ssedssp_ = [], [], [], []
+         dsspidx = 0
+         sseidx = 0
+
+         for resn in range(int(min_resn), int(max_resn + 1)):
+             if resn in seq:
+                 for k in sorted(seq[resn]):
+                     seq_.append(aa_3_N.get(seq[resn][k], 20))
+                     try:
+                         if 'CA' in xyz[resn][k]:
+                             sse_.append(sse[sseidx])
+                             sseidx = sseidx + 1
+                         else:
+                             sse_.append('-')
+                     except:
+                         print('error sse')
+             else:
+                 seq_.append(20)
+                 sse_.append('-')
+
+             miss_chain_atom = False
+             if resn in xyz:
+                 for k in sorted(xyz[resn]):
+                     for atom in atoms:
+                         if atom in xyz[resn][k]:
+                             xyz_.append(xyz[resn][k][atom]) # some residues miss C or O, but sse stays valid because it depends only on CA
+                         else:
+                             xyz_.append(np.full(3, np.nan))
+                             miss_chain_atom = True
+                 if miss_chain_atom:
+                     ssedssp_.append('-')
+                     miss_chain_atom = False
+                 else:
+                     try:
+                         ssedssp_.append(ssedssp[dsspidx]) # if a chain atom is missing, xyz and seq are still ok, but dssp skips this residue
+                         dsspidx = dsspidx + 1
+                     except:
+                         pass
+                     #print(dsspidx)
+             else:
+                 for atom in atoms:
+                     xyz_.append(np.full(3, np.nan))
+                 ssedssp_.append('-')
+
+         return np.array(xyz_).reshape(-1, len(atoms), 3), N_to_AA(np.array(seq_)), np.array(sse_), np.array(ssedssp_)
+
+     def parse_PDB(self, path_to_pdb, name, input_chain):
+         """
+         make sure every call parses just one entry
+         """
+         c = 0
+         pdb_dict_list = []
+
+         biounit_names = [path_to_pdb]
+         for biounit in biounit_names:
+             my_dict = {}
+             s = 0
+             concat_seq = ''
+
+             letter = input_chain # assuming a single chain!
+
+             PDBFile = file.PDBFile.read(biounit)
+             array_stack = PDBFile.get_structure(altloc="all")
+
+             # In case the passed letter is unknown, select the dominant protein chain from the PDB file
+             if letter not in array_stack.chain_id:
+                 is_protein = struc.filter_amino_acids(array_stack)
+                 protein_atoms = array_stack[0][is_protein]
+                 chain_ids, chain_counts = np.unique(protein_atoms.chain_id, return_counts=True)
+                 dominant_chain_id = chain_ids[np.argmax(chain_counts)]
+                 letter = dominant_chain_id
+
+             sse1 = struc.annotate_sse(array_stack[0], chain_id=letter).tolist()
+             if len(sse1) == 0:
+                 sse1 = struc.annotate_sse(array_stack[0], chain_id='').tolist()
+
+             ssedssp1 = [] # not annotating dssp for now
+
+             xyz, seq, sse, ssedssp = self.parse_PDB_biounits(biounit, sse1, ssedssp1, atoms=['N', 'CA', 'C', 'O'], chain=letter) #TODO: fix the float error
+             ssedssp = sse # faking it for now
+
+             assert len(sse) == len(seq[0])
+             assert len(ssedssp) == len(seq[0])
+
+             if type(xyz) != str:
+                 concat_seq += seq[0]
+                 my_dict['seq_chain_' + letter] = seq[0]
+
+                 coords_dict_chain = {}
+                 coords_dict_chain['N'] = xyz[:, 0, :].tolist()
+                 coords_dict_chain['CA'] = xyz[:, 1, :].tolist()
+                 coords_dict_chain['C'] = xyz[:, 2, :].tolist()
+                 coords_dict_chain['O'] = xyz[:, 3, :].tolist()
+                 my_dict['coords_chain_' + letter] = coords_dict_chain
+                 my_dict['coords'] = coords_dict_chain
+                 s += 1
+
+             # if s > 1:
+             #     raise NotImplementedError('Inference so far implemented only for single chain proteins')
+
+             my_dict['name'] = name
+             my_dict['num_chains'] = s
+             my_dict['seq'] = my_dict[f'seq_chain_{letter}'] #concat_seq
+             # if s <= len(chain_alphabet):
+             #     pdb_dict_list.append(my_dict)
+             #     c += 1
+             pdb_dict_list.append(my_dict)
+         return pdb_dict_list
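To feed `PDBInference`, the inference folder holds `PDBCODE_CHAINCODE_XXX.pdb` files plus optional `*_instructions.csv` files (matched by the glob above) carrying one comma-separated flexibility value per residue. A minimal sketch for writing such a file; the name and length are hypothetical and must match the corresponding PDB:

# Write per-residue flexibility instructions for a 130-residue chain 1ah7_A.
seq_len = 130  # must equal the number of residues in 1ah7_A.pdb
instructions = [0.5] * seq_len  # e.g. uniformly request moderate flexibility
with open("1ah7_A_instructions.csv", "w") as f:
    f.write(",".join(str(v) for v in instructions))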
Flexpert-Design/src/datasets/ts_dataset.py ADDED
@@ -0,0 +1,47 @@
+ import os
+ import json
+ import numpy as np
+ import torch.utils.data as data
+
+
+ class TSDataset(data.Dataset):
+     def __init__(self, path='./', split='test'):
+         if not os.path.exists(path):
+             raise FileNotFoundError("no such file: {} !!!".format(path)) # the original raised a bare string, which is invalid in Python 3
+         ts50_data = json.load(open(path+'/ts50.json'))
+         ts500_data = json.load(open(path+'/ts500.json'))
+
+         # TS500 has proteins with lengths of 500+
+         # TS50 only contains proteins with lengths less than 500
+         self.data = []
+         for temp in ts50_data:
+             coords = np.array(temp['coords'])
+             self.data.append({'title': temp['name'],
+                               'seq': temp['seq'],
+                               'CA': coords[:,1,:],
+                               'C': coords[:,2,:],
+                               'O': coords[:,3,:],
+                               'N': coords[:,0,:],
+                               'category': 'ts50'
+                               })
+
+         for temp in ts500_data:
+             coords = np.array(temp['coords'])
+             self.data.append({'title': temp['name'],
+                               'seq': temp['seq'],
+                               'CA': coords[:,1,:],
+                               'C': coords[:,2,:],
+                               'O': coords[:,3,:],
+                               'N': coords[:,0,:],
+                               'category': 'ts500'
+                               })
+
+     def __len__(self):
+         return len(self.data)
+
+     def get_item(self, index):
+         return self.data[index]
+
+     def __getitem__(self, index):
+         return self.data[index]
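The coords arrays above are stored as [L, 4, 3] with atoms ordered N, CA, C, O, which is why the indices 0..3 map to the four backbone atoms. A usage sketch, assuming the JSON files live under a hypothetical ./data/ts folder:

dataset = TSDataset(path='./data/ts')
sample = dataset[0]
print(sample['title'], sample['category'], sample['CA'].shape)  # CA coordinates: (L, 3)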
Flexpert-Design/src/datasets/utils.py ADDED
@@ -0,0 +1,99 @@
+ import random
+
+
+ class cached_property(object):
+     """
+     Descriptor (non-data) for building an attribute on-demand on first use.
+     """
+     def __init__(self, factory):
+         """
+         <factory> is called as factory(instance) to build the attribute.
+         """
+         self._attr_name = factory.__name__
+         self._factory = factory
+
+     def __get__(self, instance, owner):
+         # Build the attribute.
+         attr = self._factory(instance)
+
+         # Cache the value; hide ourselves.
+         setattr(instance, self._attr_name, attr)
+         return attr
+
+
+ def get_inds(expected_num, clu_nums, cid2clu, seq2ind):
+     cur_len, cur_idx, query_cids, query_idx = 0, 0, [], []
+     while cur_len < expected_num:
+         cid, l = clu_nums[cur_idx % (len(clu_nums))]
+         cur_idx += 1
+         # check if this cluster has been selected
+         if cid in query_cids:
+             continue
+         if random.random() > 0.5:
+             for seq in cid2clu[cid]:
+                 # seq2ind: ensure it is within the length limits
+                 if seq in seq2ind.keys():
+                     query_idx.append(seq2ind[seq])
+                     cur_len += 1
+
+             query_cids.append(cid)
+     return query_cids, query_idx
+
+
+ def get_num(N, valid_num=100):
+     train_n, valid_n = int(0.9 * N), min(valid_num, int(0.05 * N))
+     test_n = N - train_n - valid_n
+     return train_n, valid_n, test_n
+
+
+ def get_full_inds(expected_num, clu_nums, cid2clu, full_seq2ind):
+     cur_len, cur_idx, query_cids, query_idx = 0, 0, [], {}
+     # build query_idx for each dataset
+     for dataname in full_seq2ind.keys():
+         if dataname not in query_idx.keys():
+             query_idx[dataname] = []
+     cur_idx_lst = list(range(len(clu_nums)))
+     while cur_len < expected_num:
+         cur_idx = random.choice(cur_idx_lst)
+         cid, l = clu_nums[cur_idx]
+         # check if this cluster has been selected
+         if cid in query_cids:
+             continue
+         for seq in set(cid2clu[cid]):
+             # seq2ind: ensure it is within the length limits
+             for dataname in full_seq2ind.keys():
+                 if seq in full_seq2ind[dataname].keys():
+                     query_idx[dataname].append(full_seq2ind[dataname][seq])
+                     cur_len += 1
+         query_cids.append(cid)
+         cur_idx_lst.remove(cur_idx)
+     return query_cids, query_idx
+
+
+ def get_inds(expected_num, clu_nums, cid2clu, seq2ind):  # NOTE: this definition overrides the simpler get_inds above
+     cur_len, query_cids, query_idx = 0, [], []
+     cur_idx_lst = list(range(len(clu_nums)))
+     while cur_len < expected_num:
+         try:
+             cur_idx = random.choice(cur_idx_lst)
+             cid, l = clu_nums[cur_idx]
+             # check if this cluster has been selected
+             if cid in query_cids:
+                 continue
+
+             # check if this cluster is too big
+             pre = abs(expected_num - cur_len)
+             aft = abs(cur_len + l - expected_num)
+             if pre < aft:
+                 continue
+
+             for seq in cid2clu[cid]:
+                 # seq2ind: ensure it is within the length limits
+                 if seq in seq2ind.keys():
+                     query_idx.append(seq2ind[seq])
+                     cur_len += 1
+             query_cids.append(cid)
+             cur_idx_lst.remove(cur_idx)
+         except:
+             break
+     return query_cids, query_idx
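The `cached_property` descriptor above computes the value once and then replaces itself with a plain instance attribute, so later accesses skip the factory entirely (this is how `PDBInference.cache_data` avoids re-parsing). A minimal sketch:

class Example:
    @cached_property
    def expensive(self):
        print("computing...")
        return 42

e = Example()
e.expensive  # prints "computing..." once; 42 is stored in e.__dict__
e.expensive  # plain attribute lookup now; the factory is not called again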
Flexpert-Design/src/interface/__init__.py ADDED
File without changes
Flexpert-Design/src/interface/data_interface.py ADDED
@@ -0,0 +1,66 @@
+ import inspect
+ import importlib
+ import pytorch_lightning as pl
+ from torch.utils.data import DataLoader
+
+
+ class DInterface_base(pl.LightningDataModule):
+     def __init__(self, **kwargs):
+         super().__init__()
+         self.save_hyperparameters()
+         self.batch_size = self.hparams.batch_size
+         print("batch_size", self.batch_size)
+         self.load_data_module()
+
+     def setup(self, stage=None):
+         # Assign train/val datasets for use in dataloaders
+         if stage == 'fit' or stage is None:
+             self.trainset = self.instancialize(split='train')
+             self.valset = self.instancialize(split='valid')
+
+         # Assign test dataset for use in dataloader(s)
+         if stage == 'test' or stage is None:
+             self.testset = self.instancialize(split='test')
+
+     def train_dataloader(self):
+         return DataLoader(self.trainset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, prefetch_factor=3)
+
+     def val_dataloader(self):
+         return DataLoader(self.valset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
+
+     def test_dataloader(self):
+         return DataLoader(self.testset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
+
+     def load_data_module(self):
+         name = self.dataset
+         # Convert the `snake_case.py` file name to a `CamelCase` class name.
+         # Please always name your dataset file `snake_case.py` and the
+         # corresponding class `CamelCase`.
+         camel_name = ''.join([i.capitalize() for i in name.split('_')])
+         try:
+             self.data_module = getattr(importlib.import_module(
+                 '.'+name, package=__package__), camel_name)
+         except:
+             raise ValueError(
+                 f'Invalid Dataset File Name or Invalid Class Name data.{name}.{camel_name}')
+
+     def instancialize(self, **other_args):
+         """ Instantiate a dataset using the corresponding parameters
+         from the self.hparams dictionary. You can also pass any args
+         to overwrite the corresponding values in self.kwargs.
+         """
+         if other_args['split'] == 'train':
+             self.data_module = getattr(importlib.import_module(
+                 '.AF2DB_dataset', package='data'), 'Af2dbDataset')
+         else:
+             self.data_module = getattr(importlib.import_module(
+                 '.CASP15_dataset', package='data'), 'CASP15Dataset')
+
+         class_args = list(inspect.signature(self.data_module.__init__).parameters)[1:]
+         inkeys = self.kwargs.keys()
+         args1 = {}
+         for arg in class_args:
+             if arg in inkeys:
+                 args1[arg] = self.kwargs[arg]
+         args1.update(other_args)
+         return self.data_module(**args1)
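The file-to-class convention enforced by `load_data_module` in one line, using a dataset module from this package as the example (the resulting class name follows the convention above):

name = 'flex_cath_dataset'
camel_name = ''.join([i.capitalize() for i in name.split('_')])
assert camel_name == 'FlexCathDataset'  # resolved from src/datasets/flex_cath_dataset.py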
Flexpert-Design/src/interface/model_interface.py ADDED
@@ -0,0 +1,89 @@
+ import torch
+ import pytorch_lightning as pl
+ import torch.nn as nn
+ import os
+ import torch.optim.lr_scheduler as lrs
+ import inspect
+
+ class MInterface_base(pl.LightningModule):
+     def __init__(self, model_name=None, loss=None, lr=None, **kwargs):
+         super().__init__()
+         self.save_hyperparameters()
+         self.load_model()
+         self.configure_loss()
+         os.makedirs(os.path.join(self.hparams.res_dir, self.hparams.ex_name), exist_ok=True)
+
+     def forward(self, input):
+         pass
+
+     def training_step(self, batch, batch_idx, **kwargs):
+         pass
+
+     def validation_step(self, batch, batch_idx):
+         pass
+
+     def test_step(self, batch, batch_idx):
+         # Here we just reuse the validation_step for testing
+         return self.validation_step(batch, batch_idx)
+
+     def on_validation_epoch_end(self):
+         # Keep the progress bar on screen
+         self.print('')
+
+     def get_schedular(self, optimizer, lr_scheduler='onecycle'):
+         if lr_scheduler == 'step':
+             scheduler = lrs.StepLR(optimizer,
+                                    step_size=self.hparams.lr_decay_steps,
+                                    gamma=self.hparams.lr_decay_rate)
+         elif lr_scheduler == 'cosine':
+             scheduler = lrs.CosineAnnealingLR(optimizer,
+                                               T_max=self.hparams.lr_decay_steps,
+                                               eta_min=self.hparams.lr_decay_min_lr)
+         elif lr_scheduler == 'onecycle':
+             scheduler = lrs.OneCycleLR(optimizer, max_lr=self.hparams.lr, steps_per_epoch=self.hparams.steps_per_epoch, epochs=self.hparams.epoch, three_phase=False)
+         else:
+             raise ValueError('Invalid lr_scheduler type!')
+
+         return scheduler
+
+     def configure_optimizers(self):
+         if hasattr(self.hparams, 'weight_decay'):
+             weight_decay = self.hparams.weight_decay
+         else:
+             weight_decay = 0
+
+         optimizer_g = torch.optim.AdamW(self.model.parameters(), lr=self.hparams.lr, weight_decay=weight_decay, betas=(0.9, 0.98), eps=1e-8)
+
+         scheduler_g = self.get_schedular(optimizer_g, self.hparams.lr_scheduler)
+
+         return [optimizer_g], [{"scheduler": scheduler_g, "interval": "step"}]
+
+     def lr_scheduler_step(self, *args, **kwargs):
+         scheduler = self.lr_schedulers()
+         scheduler.step()
+
+     def configure_devices(self):
+         self.device = torch.device(self.hparams.device)
+
+     def configure_loss(self):
+         self.loss_function = nn.CrossEntropyLoss(reduction='none')
+
+     def load_model(self):
+         self.model = None
+
+     def instancialize(self, Model, **other_args):
+         """ Instantiate a model using the corresponding parameters
+         from the self.hparams dictionary. You can also pass any args
+         to overwrite the corresponding values in self.hparams.
+         """
+         class_args = inspect.getfullargspec(Model.__init__).args[1:]  # getargspec was removed in Python 3.11
+         inkeys = self.hparams.keys()
+         args1 = {}
+         for arg in class_args:
+             if arg in inkeys:
+                 args1[arg] = getattr(self.hparams, arg)
+         args1.update(other_args)
+         return Model(**args1)
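`instancialize` inspects the constructor signature and forwards only the hyperparameters the class accepts, so one flat hparams namespace can drive many model classes. A minimal standalone sketch of the same filtering logic (ToyModel and the hparams values are hypothetical):

import inspect

class ToyModel:
    def __init__(self, hidden_dim, dropout=0.1):
        self.hidden_dim, self.dropout = hidden_dim, dropout

hparams = {'hidden_dim': 128, 'dropout': 0.2, 'lr': 1e-3}  # 'lr' is filtered out
class_args = inspect.getfullargspec(ToyModel.__init__).args[1:]
kwargs = {k: hparams[k] for k in class_args if k in hparams}
model = ToyModel(**kwargs)  # equivalent to ToyModel(hidden_dim=128, dropout=0.2)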
Flexpert-Design/src/interface/pretrain_interface.py ADDED
@@ -0,0 +1,405 @@
+ import torch
+ from omegaconf import OmegaConf
+ from transformers import AutoTokenizer, EsmForMaskedLM
+ import torch.nn.functional as F
+
+ class PretrainInterface(torch.nn.Module):
+     def __init__(self, name):
+         super().__init__()
+         self.name = name
+         if name == "ESM35M":
+             self.esm_dim = 480
+             self.tokenizer = AutoTokenizer.from_pretrained("/huyuqi/model_zoom/transformers/models--facebook--esm2_t12_35M_UR50D")
+             self.pretrain_model = EsmForMaskedLM.from_pretrained("/huyuqi/model_zoom/transformers/models--facebook--esm2_t12_35M_UR50D")
+         if name == "ESM650M":
+             self.esm_dim = 1280
+             self.tokenizer = AutoTokenizer.from_pretrained("/huyuqi/model_zoom/transformers/models--facebook--esm2_t33_650M_UR50D/snapshots/08e4846e537177426273712802403f7ba8261b6c")
+             self.pretrain_model = EsmForMaskedLM.from_pretrained("/huyuqi/model_zoom/transformers/models--facebook--esm2_t33_650M_UR50D/snapshots/08e4846e537177426273712802403f7ba8261b6c")
+         if name == "ESM3B":
+             self.esm_dim = 2560
+             self.tokenizer = AutoTokenizer.from_pretrained("/huyuqi/model_zoom/transformers/models--facebook--esm2_t36_3B_UR50D/snapshots/476b639933c8baad5ad09a60ac1a87f987b656fc")
+             self.pretrain_model = EsmForMaskedLM.from_pretrained("/huyuqi/model_zoom/transformers/models--facebook--esm2_t36_3B_UR50D/snapshots/476b639933c8baad5ad09a60ac1a87f987b656fc")
+
+         if name == "vanilla":
+             from step1_VQ.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMVQ/base/configs/10-18T01-15-36-project.yaml")
+             pretrain_args.diffusion = False
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMVQ/base/checkpoints/best-epoch=14-val_loss=0.314.pth', map_location=torch.device('cpu'))
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict, strict=False)
+
+         # if name == "LFQ":
+         #     from step1_VQ.model_interface import MInterface
+         #     pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMFVQ/LFQ_seg_linear/configs/10-17T15-46-37-project.yaml")
+         #     pretrain_args.diffusion = False
+         #     self.pretrain_model = MInterface(**pretrain_args)
+         #     ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMFVQ/LFQ_seg_linear/checkpoints/best-epoch=14-val_loss=0.161.pth', map_location=torch.device('cpu'))
+         #     state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+         #     self.pretrain_model.load_state_dict(state_dict, strict=False)
+
+         if name == "softgroup-1":
+             from step1_VQ.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftGroup/softgroup-1/configs/12-16T14-57-28-project.yaml")
+             pretrain_args.diffusion = False
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftGroup/softgroup-1/checkpoints/best-epoch=13-val_loss=0.111.pth', map_location=torch.device('cpu'))
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == "softgroup-2":
+             from step1_VQ.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup-2/configs/10-24T12-51-57-project.yaml")
+             pretrain_args.diffusion = False
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup-2/checkpoints/best-epoch=14-val_loss=0.067.pth', map_location=torch.device('cpu'))
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == "softgroup-3":
+             from step1_VQ.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup-3/configs/10-25T00-04-15-project.yaml")
+             pretrain_args.diffusion = False
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup-3/checkpoints/best-epoch=14-val_loss=0.063.pth', map_location=torch.device('cpu'))
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == "softgroup-4":
+             from step1_VQ.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup_32_vectors/configs/10-19T01-03-55-project.yaml")
+             pretrain_args.diffusion = False
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup_32_vectors/checkpoints/best-epoch=14-val_loss=0.056.pth', map_location=torch.device('cpu'))
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict, strict=False)
+
+         if name == "softgroup-5":
+             from step1_VQ.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup-5-gzy/configs/10-27T17-15-56-project.yaml")
+             pretrain_args.diffusion = False
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup-5-gzy/checkpoints/best-epoch=14-val_loss=0.039.pth', map_location=torch.device('cpu'))
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == "softgroup-6":
+             from step1_VQ.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup_128_group/configs/10-28T01-28-50-project.yaml")
+             pretrain_args.diffusion = False
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup_128_group/checkpoints/best-epoch=14-val_loss=0.011.pth', map_location=torch.device('cpu'))
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == "softgroup_128_group": # same config/checkpoint as softgroup-6
+             from step1_VQ.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup_128_group/configs/10-28T01-28-50-project.yaml")
+             pretrain_args.diffusion = False
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup_128_group/checkpoints/best-epoch=14-val_loss=0.011.pth', map_location=torch.device('cpu'))
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == "diff-softgroup-1":
+             from step1_VQ.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/DiffESMSoftGroup/diff-softgroup-rm-dist/configs/12-17T14-19-21-project.yaml")
+             pretrain_args.diffusion = True
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/DiffESMSoftGroup/diff-softgroup-rm-dist/checkpoints/best-epoch=12-val_loss=0.496.pth', map_location=torch.device('cpu'))
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == "diff-softgroup-4":
+             from step1_VQ.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/DiffESMSoftGroup/diff-softgroup-vq32/configs/12-19T01-54-15-project.yaml")
+             pretrain_args.diffusion = True
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/DiffESMSoftGroup/diff-softgroup-vq32/checkpoints/best-epoch=13-val_loss=0.184.pth', map_location=torch.device('cpu'))
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == "diff-softgroup-5":
+             from step1_VQ.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/DiffESMSoftGroup/diff-softgroup-vq64/configs/12-19T01-57-07-project.yaml")
+             pretrain_args.diffusion = True
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/DiffESMSoftGroup/diff-softgroup-vq64/checkpoints/best-epoch=13-val_loss=0.100.pth', map_location=torch.device('cpu'))
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == "diff-softgroup-6":
+             from step1_VQ.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/DiffESMSoftGroup/diff-softgroup-vq128/configs/12-19T10-47-37-project.yaml")
+             pretrain_args.diffusion = True
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/DiffESMSoftGroup/diff-softgroup-vq128/checkpoints/best-epoch=13-val_loss=0.081.pth', map_location=torch.device('cpu'))
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == 'vanilla-1':
+             from step1_VQ.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMVQ/base/configs/10-18T01-15-37-project.yaml")
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMVQ/base/checkpoints/best-epoch=14-val_loss=0.314.pth")
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == 'soft-1':
+             from step1_VQ.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoft/soft_rerun/configs/12-10T12-38-16-project.yaml")
+             pretrain_args.diffusion = False
+             pretrain_args.attn_type = 'raw'
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoft/soft_rerun/checkpoints/best-epoch=14-val_loss=0.018.pth")
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == 'soft_64_vecs':
+             from step1_VQ.model_interface import MInterface # import added; the original omitted it in this branch
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoft/soft_vq_num64/configs/10-19T11-11-58-project.yaml")
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoft/soft_vq_num64/checkpoints/best-epoch=14-val_loss=8.768.pth")
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == 'LFQ':
+             from step1_VQ.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMFVQ/vanilla_L1loss/configs/10-24T01-36-37-project.yaml")
+             pretrain_args.diffusion = False
+             pretrain_args.attn_type = 'raw'
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMFVQ/vanilla_L1loss/checkpoints/best-epoch=14-val_loss=11.328.pth")
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict, strict=False)
+
+         if name == 'SCQ-mlp3-vqdim32':
+             from step1_VQ.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional-mlp3-vqdim32/configs/12-22T07-52-47-project.yaml")
+             pretrain_args.diffusion = False
+             pretrain_args.vq_dim, pretrain_args.condition_layer, pretrain_args.sphere = 32, 3, False
+
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional-mlp3-vqdim32/checkpoints/best-epoch=14-val_loss=0.376.pth")
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == 'SCQ-mlp3-vqdim32-sphere':
+             from step1_VQ.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-mlp3-vqdim32-sphere/configs/12-22T10-44-46-project.yaml")
+             pretrain_args.diffusion = False
+
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-mlp3-vqdim32-sphere/checkpoints/best-epoch=14-val_loss=0.454.pth")
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == 'SCQ-mlp6-vqdim32-sphere':
+             from step1_VQ.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-mlp6BN-vqdim32-sphere/configs/12-22T18-28-04-project.yaml")
+             pretrain_args.diffusion = False
+             pretrain_args.attn_type = 'raw'
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-mlp6BN-vqdim32-sphere/checkpoints/best-epoch=14-val_loss=0.148.pth")
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == 'SCQ-mlp2-vqdim32':
+             from step1_VQ.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional-mlp2-vqdim32/configs/12-22T00-21-35-project.yaml")
+             pretrain_args.diffusion = False
+             pretrain_args.vq_dim, pretrain_args.condition_layer, pretrain_args.sphere = 32, 2, False
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional-mlp2-vqdim32/checkpoints/best-epoch=14-val_loss=0.362.pth")
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == 'SCQ-mlp2-vqdim32-sphere':
+             from step1_VQ.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional-sphere-vqdim32/configs/12-22T00-06-35-project.yaml")
+             pretrain_args.diffusion = False
+             pretrain_args.vq_dim, pretrain_args.condition_layer, pretrain_args.sphere = 32, 2, True
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional-sphere-vqdim32/checkpoints/best-epoch=14-val_loss=0.338.pth")
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == 'SCQ-mlp2-vqdim16':
+             from step1_VQ.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional/configs/12-21T13-13-11-project.yaml")
+             pretrain_args.diffusion = False
+             pretrain_args.vq_dim, pretrain_args.condition_layer, pretrain_args.sphere = 16, 2, False
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional/checkpoints/best-epoch=14-val_loss=0.094.pth")
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == 'SCQ-mlp2-vqdim16-sphere':
+             from step1_VQ.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional-sphere/configs/12-21T16-38-57-project.yaml")
+             pretrain_args.diffusion = False
+             pretrain_args.vq_dim, pretrain_args.condition_layer, pretrain_args.sphere = 16, 2, True
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional-sphere/checkpoints/best-epoch=14-val_loss=1.080.pth")
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == 'SCQ-vq8-mlp6-vqdim16-sphere':
+             from step1_VQ.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq8-mlp6BN-vqdim32-sphere/configs/12-23T05-15-56-project.yaml")
+             pretrain_args.diffusion = False
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq8-mlp6BN-vqdim32-sphere/checkpoints/best-epoch=14-val_loss=0.892.pth")
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == 'SCQ-mlp9-vqdim32-sphere':
+             from step1_VQ.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-mlp9BN-vqdim32-sphere/configs/12-23T16-20-07-project.yaml")
+             pretrain_args.diffusion = False
+
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-mlp9BN-vqdim32-sphere/checkpoints/best-epoch=14-val_loss=0.151.pth")
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == 'AF2VQ':
+             from step3_AF2VQ.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step3_AF2VQ/results/AF2VQ_softgroup16/configs/12-13T07-59-50-project.yaml")
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step3_AF2VQ/results/AF2VQ_softgroup16/checkpoints/best-epoch=11-val_loss=0.812.pth")
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == "ProGLM":
+             self.vq_dim = 480
+             from step2_ProGLM.model.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Inpainting_representation/results/softgroup_bin_1127/version_4/hparams.yaml")
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Inpainting_representation/results/softgroup_bin_1127/checkpoints/best-epoch=08-valid_acc=0.804.ckpt', map_location=torch.device('cpu'))['state_dict']
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == 'ProGLM_softgroup_af2db':
+             from step2_ProGLM.model.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Inpainting_representation/results/softgroup_bin_2/version_3/hparams.yaml")
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Inpainting_representation/results/softgroup_bin_2/checkpoints/best-epoch=13-valid_acc=0.863.ckpt', map_location=torch.device('cpu'))['state_dict']
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == 'ProGLM_SoftVQ_cath':
+             from step2_ProGLM.model.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftVQ_epoch15_pad300/configs/12-25T01-20-35-project.yaml")
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftVQ_epoch15_pad300/checkpoints/best-epoch=27-valid_acc=0.001.pth')
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == 'ProGLM_SoftCVQ_cath':
+             from step2_ProGLM.model.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftCVQ_epoch15_pad300_BCE/configs/12-25T01-42-37-project.yaml")
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftCVQ_epoch15_pad300_BCE/checkpoints/best-epoch=14-valid_acc=0.614.pth')
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == 'ProGLM_SoftCVQ_cath_inpaint':
+             from step2_ProGLM.model.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftCVQ_epoch15_pad300_BCE_inpaint/configs/12-25T07-47-52-project.yaml")
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftCVQ_epoch15_pad300_BCE_inpaint/checkpoints/best-epoch=14-valid_acc=0.616.pth')
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == 'ProGLM_SoftCVQ_AF2DB':
+             from step2_ProGLM.model.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftCVQ_epoch15_AF2DB/configs/12-25T13-01-12-project.yaml")
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftCVQ_epoch15_AF2DB/checkpoints/best-epoch=14-valid_acc=0.631.pth')
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == 'ProGLM_SoftCVQ_ESM1B_CATH':
+             from step2_ProGLM.model.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftCVQ_ESM1B_CATH_lr5e-5/configs/12-25T16-02-35-project.yaml")
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftCVQ_ESM1B_CATH_lr5e-5/checkpoints/best-epoch=14-valid_acc=0.616.pth')
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == 'ProGLM_SoftCVQ_CATH':
+             from step2_ProGLM.model.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGPT_SoftCVQ_CATH/configs/12-26T08-13-41-project.yaml")
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGPT_SoftCVQ_CATH/checkpoints/best-epoch=14-gpt_acc=0.758.pth')
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == 'ProGLM_SoftCVQ_CATH_epoch10k':
+             from step2_ProGLM.model.model_interface import MInterface
+             pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGPT_SoftCVQ_CATH_epoch1000/configs/12-27T02-36-49-project.yaml")
+             self.pretrain_model = MInterface(**pretrain_args)
+             ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGPT_SoftCVQ_CATH_epoch10000_resume/checkpoints/best-epoch=1887-gpt_loss=0.181.pth')
+             state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
+             self.pretrain_model.load_state_dict(state_dict)
+
+         if name == 'GearNet':
+             from model.PretrainGearNet import PretrainGearNet_Model
+             self.pretrain_model = PretrainGearNet_Model()
+
+         self.pretrain_model.eval()
+
+     def get_vq_id(self, seqs, angles, attn_mask):
+         # if ('softgroup' in self.name) or ('LFQ' in self.name):
+         #     h_input = self.pretrain_model.model.input(seqs.squeeze(-1), angles)
+         #     h_enc = self.pretrain_model.model.ProteinEnc(h_input, attn_mask, None).last_hidden_state
+         #     vq_id, e_enc = self.pretrain_model.model.VQLayer.get_vq(h_enc, attn_mask, temperature=1e-5)
+         #     return F.pad(vq_id, [0,1,0,0])
+
+         h_input = self.pretrain_model.model.input(seqs.squeeze(-1), angles)
+         h_enc = self.pretrain_model.model.ProteinEnc(h_input, attn_mask, None).last_hidden_state
+         vq_id, e_enc = self.pretrain_model.model.VQLayer.get_vq(h_enc, attn_mask, temperature=1e-5)
+         return vq_id
+
+     def forward(self, batch):
+         if self.name in ["ESM35M", "ESM650M", "ESM3B"]:
+             seqs, attn_mask = batch['seqs'], batch['attn_mask']
+             outputs = self.pretrain_model.model(input_ids=seqs[:, :, 0], attention_mask=attn_mask)
+             pretrain_embedding = outputs.hidden_states
+             pretrain_embedding = pretrain_embedding.reshape(-1, self.esm_dim)[attn_mask.view(-1) == 1]
+             return pretrain_embedding
+         if self.name in ["softgroup_128_group"]:
+             seqs, angles, attn_mask = batch['seqs'], batch['angles'], batch['attn_mask']
+             vq_id = self.pretrain_model.model.get_vqid(seqs[..., 0], angles, attn_mask)
+             return vq_id
+         if self.name in ["ProGLM"]:
+             vq_id, attn_mask, seg, pos = batch['vq_id'], batch['attn_mask'], batch['seg'], batch['pos']
+             feat = self.pretrain_model.model.get_feat(vq_id, attn_mask, seg, pos)
+             feat = feat.reshape(-1, self.vq_dim)[attn_mask.view(-1) == 1]
+             return feat
+         if self.name in ["GearNet"]:
+             seqs = batch['seqs']
+             batch = batch['batch']
+             attn_mask = batch['attn_mask']
+             seq_strs = []  # initialization added; the original appended to an undefined list
+             for idx in range(seqs.shape[0]):
+                 seq_str = self.pretrain_featurizer.ESM_tokenizer.decode(seqs[idx, attn_mask[idx, :].bool(), 0])  # NOTE: assumes a featurizer attribute providing the ESM tokenizer
+                 seq_strs.append(seq_str.split(" "))
+             seq_strs = sum(seq_strs, [])
+             node_index = torch.arange(batch.batch.shape[0], device=batch.batch.device)
+             node2graph = batch.batch
+             chain_id = torch.ones_like(batch.batch)
+
+             pretrain_embedding = self.pretrain_model(seq_strs, node_index, node2graph, chain_id, batch.pos)  # was self.pretrain_gearnet_model, which is never defined
+             return pretrain_embedding
+
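Every branch above repeats one pattern: load an OmegaConf config, build an `MInterface`, load a checkpoint, and strip the `_forward_module.` prefix that the Lightning/DeepSpeed wrapper prepends to parameter names. A hedged sketch of that pattern factored into a helper (the path handling mirrors the branches; `strict` defaults are a choice, not the repo's API):

import torch

def load_pretrained(model, ckpt_path, strict=True):
    """Load a checkpoint, stripping the '_forward_module.' wrapper prefix."""
    ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
    if 'state_dict' in ckpt:  # .ckpt files wrap the weights; .pth dumps above do not
        ckpt = ckpt['state_dict']
    state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
    model.load_state_dict(state_dict, strict=strict)
    return model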
Flexpert-Design/src/models/E3PiFold_model.py ADDED
@@ -0,0 +1,90 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Optional
+ from torch import Tensor
+ from omegaconf import OmegaConf
+ from src.modules.E3PiFold import GaussianEncoder, TransformerEncoderWithPair
+ from src.tools import gather_nodes, _dihedrals, _get_rbf, _get_dist, _rbf, _orientations_coarse_gl_tuple
+
+ class E3PiFold(nn.Module):
+     def __init__(self, config) -> None:
+         super().__init__()
+         self.node_embed = nn.Linear(21, config.embed_dim)
+         self.protein_embedder = GaussianEncoder(config.kernel_num, config.embed_dim, config.attention_heads, config.use_dist, config.use_product)
+
+         self.encoder = TransformerEncoderWithPair(
+             config.encoder_layers,
+             config.embed_dim,
+             config.ffn_embed_dim,
+             config.attention_heads,
+             config.emb_dropout,
+             config.dropout,
+             config.attention_dropout,
+             config.activation_dropout,
+             config.max_seq_len,
+         )
+         self.predictor = nn.Linear(config.embed_dim, 33)
+
+     def _full_dist(self, X, mask, top_k=30, eps=1E-6):
+         mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2)
+         dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
+         D = (1. - mask_2D) * 10000 + mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps)
+
+         D_max, _ = torch.max(D, -1, keepdim=True)
+         D_adjust = D + (1. - mask_2D) * (D_max + 1)
+         D_neighbors, E_idx = torch.topk(D_adjust, min(top_k, D_adjust.shape[-1]), dim=-1, largest=False)
+         return D_neighbors, E_idx
+
+     def _get_features(self, batch):
+         X = batch['X']
+         X_ca = X[:, :, 1, :]
+         D_neighbors, E_idx = self._full_dist(X_ca, batch['mask'], 30)
+         V_angles = _dihedrals(X.float())
+         V_direct, E_direct, E_angles = _orientations_coarse_gl_tuple(X.float(), E_idx)
+         h_V = torch.cat([V_angles, V_direct], dim=-1).to(X.dtype)
+         batch['h_V'] = h_V
+         return batch
+
+     def forward(self, batch):
+         '''
+         expects batch['X'], batch['h_V'], batch['mask']
+         '''
+         X = batch['X'][:, :, 1]
+         H = self.node_embed(batch['h_V'])
+         seq_mask = batch['mask']
+         pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
+         padding_mask = 1 - seq_mask
+         x, graph_attn_bias = self.protein_embedder(X, H, pair_mask)
+         (
+             encoder_rep,
+             encoder_pair_rep,
+             delta_encoder_pair_rep,
+             x_norm,
+             delta_encoder_pair_rep_norm,
+         ) = self.encoder(x, padding_mask=padding_mask, attn_mask=graph_attn_bias, pair_mask=pair_mask)
+         # predict from the encoder output; the original read the pre-encoder
+         # features `x`, which left the transformer encoder output unused
+         logits = self.predictor(encoder_rep)
+         log_probs = F.log_softmax(logits, dim=-1)
+
+         return {'log_probs': log_probs}
+
+
+ if __name__ == '__main__':
+     B, N = 4, 128
+     X = torch.randn(B, N, 4, 3)                    # backbone atom coordinates
+     seq_mask = (torch.ones(B, N) > 0.5).float()
+
+     config = {'encoder_layers': 12,
+               'kernel_num': 16,
+               'embed_dim': 768,
+               'ffn_embed_dim': 3072,
+               'attention_heads': 8,
+               'emb_dropout': 0.1,
+               'dropout': 0.1,
+               'attention_dropout': 0.1,
+               'activation_dropout': 0.0,
+               'max_seq_len': 256,
+               # the original test config was missing the two flags GaussianEncoder expects
+               'use_dist': True,
+               'use_product': True}
+     config = OmegaConf.create(config)
+     model = E3PiFold(config)
+     # the model consumes a batch dict, not positional tensors as in the original test
+     batch = model._get_features({'X': X, 'mask': seq_mask})
+     out = model(batch)
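
The `_full_dist` helper above drives the k-NN graph construction: masked positions are pushed to a large distance before the top-k, so padded residues never become neighbours of real ones. A minimal standalone sketch of the same masking logic (plain PyTorch, toy shapes; not the project's feature pipeline):

    import torch

    def knn_neighbors(X_ca, mask, top_k=3, eps=1e-6):
        # X_ca: (B, N, 3) CA coordinates; mask: (B, N) with 1 = real residue
        mask_2D = mask.unsqueeze(1) * mask.unsqueeze(2)
        dX = X_ca.unsqueeze(1) - X_ca.unsqueeze(2)
        D = (1. - mask_2D) * 10000 + mask_2D * torch.sqrt((dX ** 2).sum(-1) + eps)
        D_max, _ = D.max(-1, keepdim=True)
        D_adjust = D + (1. - mask_2D) * (D_max + 1)   # masked pairs sort last
        D_neighbors, E_idx = torch.topk(D_adjust, min(top_k, D_adjust.shape[-1]), dim=-1, largest=False)
        return D_neighbors, E_idx

    X = torch.randn(1, 6, 3)
    mask = torch.tensor([[1., 1., 1., 1., 0., 0.]])   # last two positions are padding
    _, E_idx = knn_neighbors(X, mask)
    assert (E_idx[0, :4] <= 3).all()                  # real residues only pick real neighbours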
Flexpert-Design/src/models/MemoryESM.py ADDED
@@ -0,0 +1,164 @@
+ import subprocess
+ import os
+ from joblib import Parallel, delayed, cpu_count
+ from tqdm import tqdm
+ import pandas as pd
+ import torch
+ import torch.nn as nn
+ import shutil
+ from .PretrainESM_model import PretrainESM_Model
+
+
+ class MemoESM(nn.Module):
+     def __init__(self, args):
+         super().__init__()
+         self.PretrainESM = PretrainESM_Model(args)
+         self.tokenizer = self.PretrainESM.tokenizer
+         self.memory = {}
+         # self.fix_memory = False
+
+     # def save_memory(self, path):
+     #     params = {key: val for key, val in self.state_dict().items() if "GNNTuning" in key}
+     #     torch.save({"params": params, "memory": self.memory}, path)
+
+     # def load_memory(self, path):
+     #     data = torch.load(path)
+     #     self.load_state_dict(data['params'], strict=False)
+     #     self.memory = data['memory']
+
+     def clean_input(self, batch, score_cut=0.99):
+         '''
+         requires: batch['pred_ids'], batch['attention_mask'], batch['confs']
+         '''
+         mask_symbol = "<mask>"
+         replace_dict = {"-": mask_symbol,
+                         ".": mask_symbol,
+                         "<eos>": mask_symbol,
+                         "<unk>": mask_symbol,
+                         "<cls>": mask_symbol,
+                         "<pad>": mask_symbol,
+                         "<null_1>": mask_symbol,
+                         "<mask>": mask_symbol,
+                         "U": mask_symbol,
+                         "O": mask_symbol}
+
+         device = batch['pred_ids'].device
+         query_seqs = []
+         for pred_ids, mask, score in zip(batch['pred_ids'], batch['attention_mask'], batch['confs']):
+             seq = self.tokenizer.decode(pred_ids[mask], clean_up_tokenization_spaces=False)
+             elements = []
+             for idx, x in enumerate(seq.split(" ")):
+                 # the original reused the name `symbol` here, shadowing the constant above
+                 token = replace_dict.get(x, x)
+                 if score[idx] < score_cut:
+                     token = "<mask>"
+                 elements.append(token)
+             seq = "".join(elements)
+             query_seqs.append(seq)
+
+         results = self.tokenizer.batch_encode_plus(query_seqs, return_tensors="pt", padding=True)  # tokenised batch (currently unused)
+         return query_seqs
+
+     def initoutput(self, B, maxL, device):
+         self.out_pred_ids = torch.zeros(B, maxL, dtype=torch.long, device=device)
+         self.out_confs = torch.zeros(B, maxL, dtype=torch.float, device=device)
+         self.out_embeds = torch.zeros(B, maxL, 1280, dtype=torch.float, device=device)
+         self.titles = [None for i in range(B)]
+
+     def retrivel(self, titles, num_nodes, device, use_memory):
+         # retrieval: copy cached results into the output buffers, collect cache misses
+         unseen = []
+         for idx in range(len(titles)):
+             name = titles[idx]
+             if (name in self.memory) and use_memory:
+                 memo_pred_ids = self.memory[name]['pred_ids'].to(device)
+                 memo_confs = self.memory[name]['confs'].to(device)
+                 memo_embeds = self.memory[name]['embeds'].to(device)
+
+                 self.out_pred_ids[idx, :num_nodes[idx]] = memo_pred_ids
+                 self.out_confs[idx, :num_nodes[idx]] = memo_confs
+                 self.out_embeds[idx, :num_nodes[idx]] = memo_embeds
+                 self.titles[idx] = name
+             else:
+                 unseen.append(idx)
+         return unseen
+
+     def rebatch(self, unseen, batch):
+         unseen_pred_ids = []
+         unseen_attention_mask = []
+         for i in unseen:
+             unseen_pred_ids.append(batch['pred_ids'][i])
+             unseen_attention_mask.append(batch['attention_mask'][i])
+         unseen_pred_ids = torch.stack(unseen_pred_ids)
+         unseen_attention_mask = torch.stack(unseen_attention_mask)
+         return {"pred_ids": unseen_pred_ids, "attention_mask": unseen_attention_mask}
+
+     def save2memory(self, unseen, outputs, titles, unseen_attention_mask):
+         # save freshly computed results to memory
+         for i in range(len(unseen)):
+             name = titles[unseen[i]]
+             self.titles[unseen[i]] = name
+             mask = unseen_attention_mask[i]
+             self.memory[name] = {"pred_ids": outputs['pred_ids'][i][mask].detach().to('cpu'),
+                                  "confs": outputs['confs'][i][mask].detach().to('cpu'),
+                                  "embeds": outputs['embeds'][i][mask].detach().to('cpu')}
+
+     def update(self, unseen, unseen_attention_mask, num_nodes, outputs):
+         # write fresh results into the output buffers
+         for idx in range(len(unseen)):
+             mask = unseen_attention_mask[idx] == 1
+             self.out_pred_ids[unseen[idx], :num_nodes[unseen[idx]]] = outputs['pred_ids'][idx][mask]
+             self.out_confs[unseen[idx], :num_nodes[unseen[idx]]] = outputs['confs'][idx][mask]
+             self.out_embeds[unseen[idx], :num_nodes[unseen[idx]]] = outputs['embeds'][idx][mask]
+
+     @torch.no_grad()
+     def forward(self, batch, use_memory=False):
+         # debatch
+         # clean_seqs = self.clean_input(batch)
+         device = batch['probs'].device
+         B, maxL, _ = batch['probs'].shape
+         num_nodes = batch['attention_mask'].sum(dim=-1).tolist()
+         self.initoutput(B, maxL, device)
+         unseen = self.retrivel(batch['title'], num_nodes, device, use_memory)
+
+         if len(unseen) > 0:
+             # batch forward on the cache misses only
+             new_batch = self.rebatch(unseen, batch)
+             outputs = self.PretrainESM(new_batch)
+
+             self.save2memory(unseen, outputs, batch['title'], new_batch['attention_mask'])
+             self.update(unseen, new_batch['attention_mask'], num_nodes, outputs)
+
+         return {'title': self.titles, 'pred_ids': self.out_pred_ids, 'confs': self.out_confs, 'embeds': self.out_embeds, 'attention_mask': batch['attention_mask']}
+
+
+ if __name__ == '__main__':
+
+     # work_space = '/gaozhangyang/experiments/PiFoldV2/data/mmseq_workspace2'
+     # target_seqs = ["MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPQTKTYFPHFDLSHGSAQVKGHG", "MVHLTPEEKSAVTALWGKVNVDEVGVEALGRLLVVYPWTQRFFESFGDLSTPDAVMGNPKV",
+     #                "MVLSPADKTNVKAAWGKVGAGGAEALERMFLSFPQKTYYTYFPHFDLSHGSAQVKGHG"]
+     # query_seqs = ["MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKFPHFDLSHGSAQV", "MVHLTPEEKSAVTALWGKVNVDEVGGGRLLVVYPWTQRFFESFGDLSTPDAV"]
+     # results = search_seqs(query_seqs, target_seqs, work_space)
+     # print(results)
+
+     import biotite.sequence as seq
+     import biotite.sequence.align as align
+
+     # Example query and target protein sequences
+     query_seq1 = seq.ProteinSequence("MSKXXKAFLNKXXL")
+     target_seq1 = seq.ProteinSequence("MSKVKAALNKVLL")
+     target_seq2 = seq.ProteinSequence("MSKVKKALNKVLL")
+     target_seq3 = seq.ProteinSequence("MSTVAAALKMLLL")
+
+     # NOTE: search_seqs_biotite is not defined in this module; the call below is
+     # leftover scaffolding and would raise a NameError if executed.
+     # results = search_seqs_biotite(["MSKXXKAFLNKXXL"], ["MSKVKAALNKVLL", "MSKVKKALNKVLL", "MSTVAAALKMLLL"])
+
+     print("Query alignments:")
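
The retrieve → forward-on-misses → write-back flow in MemoESM is easier to see without the ESM machinery. A minimal sketch of the same memoisation pattern; `expensive_forward` is a hypothetical stand-in for the PretrainESM call:

    import torch

    memory = {}

    def expensive_forward(titles):
        # stand-in for the pretrained-model forward pass
        return {t: torch.randn(4) for t in titles}

    def cached_forward(titles, use_memory=True):
        out = {}
        unseen = [t for t in titles if not (use_memory and t in memory)]  # retrieval
        for t in set(titles) - set(unseen):
            out[t] = memory[t]
        if unseen:
            fresh = expensive_forward(unseen)          # batch forward on misses only
            for t, v in fresh.items():
                memory[t] = v.detach().cpu()           # save to memory
                out[t] = v
        return out

    cached_forward(["1ah7_A", "2xyz_B"])   # both computed
    cached_forward(["1ah7_A"])             # served from the cache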
Flexpert-Design/src/models/MemoryESMIF.py ADDED
@@ -0,0 +1,116 @@
+ import subprocess
+ import os
+ from joblib import Parallel, delayed, cpu_count
+ from tqdm import tqdm
+ import pandas as pd
+ import torch
+ import torch.nn as nn
+ from .PretrainESMIF_model import PretrainESMIF_Model
+ from torch_scatter import scatter_sum
+
+ class MemoESMIF(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.PretrainESMIF = PretrainESMIF_Model()
+         self.memory = {}
+         # self.fix_memory = False
+
+     # def save_memory(self, path):
+     #     params = {key: val for key, val in self.state_dict().items() if "GNNTuning" in key}
+     #     torch.save({"params": params, "memory": self.memory}, path)
+
+     # def load_memory(self, path):
+     #     data = torch.load(path)
+     #     self.load_state_dict(data['params'], strict=False)
+     #     self.memory = data['memory']
+
+     def initoutput(self, B, maxL, device):
+         self.out_embeds = torch.zeros(B, maxL, 512, dtype=torch.float, device=device)
+         self.titles = [None for i in range(B)]
+
+     def retrivel(self, titles, num_nodes, device, use_memory):
+         # retrieval: serve cached embeddings, collect cache misses
+         unseen = []
+         for idx in range(len(titles)):
+             name = titles[idx]
+             if (name in self.memory) and use_memory:
+                 memo_embeds = self.memory[name]['embeds'].to(device)
+                 self.out_embeds[idx, :num_nodes[idx]] = memo_embeds
+                 self.titles[idx] = name
+             else:
+                 unseen.append(idx)
+         return unseen
+
+     def rebatch(self, unseen, batch):
+         unseen_position = []
+         for i in unseen:
+             mask = batch['batch_id'] == i
+             unseen_position.append(batch['position'][mask][:, :3, :])
+         return {"position": unseen_position}
+
+     def save2memory(self, unseen, outputs, titles, num_nodes):
+         # save freshly computed embeddings to memory
+         for i in range(len(unseen)):
+             name = titles[unseen[i]]
+             self.titles[unseen[i]] = name
+             num = num_nodes[unseen[i]]
+             self.memory[name] = {"embeds": outputs['feat'][i, :num].detach().to('cpu')}
+
+     def update(self, unseen, num_nodes, outputs):
+         # write fresh embeddings into the output buffer
+         for idx in range(len(unseen)):
+             num = num_nodes[unseen[idx]]
+             self.out_embeds[unseen[idx], :num_nodes[unseen[idx]]] = outputs['feat'][idx, :num]
+
+     @torch.no_grad()
+     def forward(self, batch, use_memory=False):
+         # debatch
+         device = batch['position'].device
+         num_nodes = scatter_sum(torch.ones_like(batch['batch_id']), batch['batch_id'], dim=0)
+         B, maxL = num_nodes.shape[0], num_nodes.max()
+         self.initoutput(B, maxL, device)
+         unseen = self.retrivel(batch['title'], num_nodes, device, use_memory)
+
+         if len(unseen) > 0:
+             # batch forward on the cache misses only
+             new_batch = self.rebatch(unseen, batch)
+             outputs = self.PretrainESMIF(new_batch['position'])
+
+             self.save2memory(unseen, outputs, batch['title'], num_nodes)
+             self.update(unseen, num_nodes, outputs)
+
+         return {'title': self.titles, 'embeds': self.out_embeds}
+
+
+ if __name__ == '__main__':
+
+     # work_space = '/gaozhangyang/experiments/PiFoldV2/data/mmseq_workspace2'
+     # target_seqs = ["MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPQTKTYFPHFDLSHGSAQVKGHG", "MVHLTPEEKSAVTALWGKVNVDEVGVEALGRLLVVYPWTQRFFESFGDLSTPDAVMGNPKV",
+     #                "MVLSPADKTNVKAAWGKVGAGGAEALERMFLSFPQKTYYTYFPHFDLSHGSAQVKGHG"]
+     # query_seqs = ["MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKFPHFDLSHGSAQV", "MVHLTPEEKSAVTALWGKVNVDEVGGGRLLVVYPWTQRFFESFGDLSTPDAV"]
+     # results = search_seqs(query_seqs, target_seqs, work_space)
+     # print(results)
+
+     import biotite.sequence as seq
+     import biotite.sequence.align as align
+
+     # Example query and target protein sequences
+     query_seq1 = seq.ProteinSequence("MSKXXKAFLNKXXL")
+     target_seq1 = seq.ProteinSequence("MSKVKAALNKVLL")
+     target_seq2 = seq.ProteinSequence("MSKVKKALNKVLL")
+     target_seq3 = seq.ProteinSequence("MSTVAAALKMLLL")
+
+     # NOTE: search_seqs_biotite is not defined in this module; the call below is
+     # leftover scaffolding and would raise a NameError if executed.
+     # results = search_seqs_biotite(["MSKXXKAFLNKXXL"], ["MSKVKAALNKVLL", "MSKVKKALNKVLL", "MSTVAAALKMLLL"])
+
+     print("Query alignments:")
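
`forward` above derives per-protein lengths from the flat `batch_id` vector via `scatter_sum`; for a quick sanity check, `torch.bincount` computes the same thing without torch_scatter:

    import torch
    from torch_scatter import scatter_sum

    batch_id = torch.tensor([0, 0, 0, 1, 1, 2])
    num_nodes = scatter_sum(torch.ones_like(batch_id), batch_id, dim=0)
    assert torch.equal(num_nodes, torch.bincount(batch_id))   # tensor([3, 2, 1])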
Flexpert-Design/src/models/MemoryPiFold.py ADDED
@@ -0,0 +1,143 @@
+ import torch
+ import torch.nn as nn
+ from .PretrainPiFold_model import PretrainPiFold_Model
+ from torch_scatter import scatter_sum
+ import torch.nn.functional as F
+
+ class MemoPiFold_model(nn.Module):
+     def __init__(self, args):
+         super().__init__()
+         self.PretrainPiFold = PretrainPiFold_Model(args)
+         self.memory = {}
+
+     def save_memory(self, path):
+         params = {key: val for key, val in self.state_dict().items() if "GNNTuning" in key}
+         torch.save({"params": params, "memory": self.memory}, path)
+
+     def load_memory(self, path):
+         data = torch.load(path)
+         self.load_state_dict(data['params'], strict=False)
+         self.memory = data['memory']
+
+     def initoutput(self, B, max_L, nums, device):
+         self.confs = torch.ones(B, max_L, device=device)
+         self.embeds = torch.ones(B, max_L, 128, device=device)
+         self.probs = torch.ones(B, max_L, 33, device=device)
+         self.attention_mask = torch.ones_like(self.confs) == 0
+         self.titles = [None for i in range(B)]
+         for id, num in enumerate(nums):
+             self.attention_mask[id, :num] = True
+         self.edge_feats = []
+
+     def retrivel(self, batch, nums, batch_uid, device, use_memory):
+         # retrieval
+         unseen = []
+
+         for idx, name in enumerate(batch['title']):
+             if (name in self.memory) and use_memory:
+                 # the original wrapped the next line in a try/except whose handler
+                 # merely repeated the same statement, so it is inlined here
+                 self.confs[batch_uid[idx], :nums[idx]] = self.memory[name]['conf'].to(device)
+                 self.embeds[batch_uid[idx], :nums[idx]] = self.memory[name]['embed'].to(device)
+                 self.probs[batch_uid[idx], :nums[idx]] = self.memory[name]['prob'].to(device)
+                 self.edge_feats.append((batch_uid[idx], self.memory[name]['h_E'].to(device)))
+                 self.titles[batch_uid[idx]] = name
+             else:
+                 unseen.append(idx)
+         return unseen
+
+     def rebatch(self, unseen, batch_uid, batch_id, batch, shift, nums, device):
+         h_V2, h_E2, E_idx2, batch_id2 = [], [], [], []
+         shift2 = [0]
+         idx = 0
+         for id in batch_uid:
+             if id not in unseen:
+                 continue
+             node_mask = batch_id == id
+             edge_mask = batch_id[batch['E_idx'][0]] == id
+             h_V2.append(batch['h_V'][node_mask])
+             h_E2.append(batch['h_E'][edge_mask])
+             new_E_idx = batch['E_idx'][:, edge_mask]
+             new_E_idx = new_E_idx - shift[batch_id[new_E_idx[0]]] + shift2[-1]
+             E_idx2.append(new_E_idx)
+             new_batch_id = torch.ones(node_mask.sum().long(), device=device) * idx
+             batch_id2.append(new_batch_id)
+             shift2.append(shift2[-1] + nums[id])
+             idx += 1
+
+         h_V2 = torch.cat(h_V2)
+         h_E2 = torch.cat(h_E2)
+         E_idx2 = torch.cat(E_idx2, dim=-1)
+         batch_id2 = torch.cat(batch_id2).long()
+         return {"h_V": h_V2, 'h_E': h_E2, 'E_idx': E_idx2, 'batch_id': batch_id2}
+
+     def update_save2memory(self, unseen, batch_id2, E_idx2, batch, pretrain_gnn, max_L):
+         for id in batch_id2.unique():
+             node_mask = batch_id2 == id
+             edge_mask = batch_id2[E_idx2[0]] == id
+             title = batch['title'][unseen[int(id)]]
+             conf = pretrain_gnn['confs'][id]
+             conf = F.pad(conf, (0, max_L - len(conf)))
+             embed = pretrain_gnn['embeds'][id]
+             embed = F.pad(embed, (0, 0, 0, max_L - len(embed)))
+             prob = pretrain_gnn['probs'][id]
+             prob = F.pad(prob, (0, 0, 0, max_L - len(prob)))
+             self.edge_feats.append((unseen[int(id)], pretrain_gnn['h_E'][edge_mask]))
+
+             self.confs[unseen[int(id)]] = conf
+             self.embeds[unseen[int(id)]] = embed
+             self.probs[unseen[int(id)]] = prob
+             self.titles[unseen[int(id)]] = title
+
+             attn_mask = self.attention_mask[unseen[int(id)]]
+
+             # save to memory
+             self.memory[title] = {'conf': conf[attn_mask].detach().to('cpu'),
+                                   'embed': embed[attn_mask].detach().to('cpu'),
+                                   'prob': prob[attn_mask].detach().to('cpu'),
+                                   'h_E': pretrain_gnn['h_E'][edge_mask].detach().to('cpu')}
+
+     @torch.no_grad()
+     def forward(self, batch, use_memory=False):
+         batch_id = batch['batch_id']
+         batch_uid = batch_id.unique()
+         device = batch_id.device
+
+         nums = scatter_sum(torch.ones_like(batch_id), batch_id)
+         shift = torch.cat([torch.zeros(1, device=device), torch.cumsum(nums, dim=0)]).long()
+         max_L, B = nums.max(), batch_uid.shape[0]
+
+         self.initoutput(B, max_L, nums, device)
+         unseen = self.retrivel(batch, nums, batch_uid, device, use_memory)
+
+         # organize data
+         if len(unseen) > 0:
+             # rebatch the cache misses
+             new_batch = self.rebatch(unseen, batch_uid, batch_id, batch, shift, nums, device)
+
+             # forward pass
+             pretrain_gnn = self.PretrainPiFold(new_batch)
+
+             self.update_save2memory(unseen, pretrain_gnn['batch_id'], pretrain_gnn['E_idx'], batch, pretrain_gnn, max_L)
+
+         self.edge_feats = sorted(self.edge_feats, key=lambda x: x[0])
+         self.edge_feats = torch.cat([one[1] for one in self.edge_feats])
+
+         pred_ids = self.probs.argmax(dim=-1) * self.attention_mask + (~self.attention_mask) * 1
+
+         return {'title': self.titles,
+                 'pred_ids': pred_ids,
+                 'confs': self.confs,
+                 'embeds': self.embeds,
+                 'probs': self.probs,
+                 'attention_mask': self.attention_mask,
+                 'h_E': self.edge_feats,
+                 'E_idx': batch['E_idx'],
+                 'batch_id': batch['batch_id']}
+
+     def _get_features(self, S, score, X, mask, chain_mask, chain_encoding):
+         return self.PretrainPiFold._get_features(S, score, X, mask, chain_mask, chain_encoding)
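
The `rebatch` method above re-indexes edges when cached proteins are dropped from a batch: each edge endpoint moves from its old global node offset to a new, compacted one. A toy sketch of that index arithmetic under the same conventions (E_idx is 2 x num_edges; nodes of protein i start at shift[i]):

    import torch

    # two proteins: protein 0 has 3 nodes (offset 0), protein 1 has 2 nodes (offset 3)
    nums = torch.tensor([3, 2])
    shift = torch.cat([torch.zeros(1, dtype=torch.long), nums.cumsum(0)])   # tensor([0, 3, 5])
    E_idx = torch.tensor([[0, 1, 3, 4],
                          [1, 2, 4, 3]])
    batch_id = torch.tensor([0, 0, 0, 1, 1])

    # keep only protein 1 (a cache miss) and compact its node ids to start at 0
    keep = 1
    edge_mask = batch_id[E_idx[0]] == keep
    new_E_idx = E_idx[:, edge_mask] - shift[keep]    # tensor([[0, 1], [1, 0]])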
Flexpert-Design/src/models/MemoryTuning.py ADDED
@@ -0,0 +1,213 @@
+ import torch
+ import torch.nn as nn
+ from .Tuning import GNNTuning_Model
+
+
+ class MemoTuning(nn.Module):
+     def __init__(self, args, tunning_layers_n, tunning_layers_dim, input_design_dim, input_esm_dim, tunning_dropout, tokenizer, fix_memory=False):
+         super().__init__()
+         self.args = args
+         self.tunning_layers_dim = tunning_layers_dim
+         self.GNNTuning = GNNTuning_Model(args, num_encoder_layers=tunning_layers_n, hidden_dim=tunning_layers_dim, input_design_dim=input_design_dim, input_esm_dim=input_esm_dim, dropout=tunning_dropout)
+         self.tokenizer = tokenizer
+         self.memory = {}
+
+     def save_param_memory(self, path):
+         torch.save({"params": self.state_dict(), "memory": self.memory}, path)
+
+     def load_param_memory(self, path):
+         data = torch.load(path)
+         self.load_state_dict(data['params'])
+         self.memory = data['memory']
+
+     def get_seqs(self, pred_ids_raw, attention_mask):
+         query_seqs = []
+         for pred_ids, mask in zip(pred_ids_raw, attention_mask):
+             seq = self.tokenizer.decode(pred_ids[mask], clean_up_tokenization_spaces=False)
+             seq = "".join(seq.split(" "))
+             query_seqs.append(seq)
+         return query_seqs
+
+     def initoutput(self, pretrain_design, B, max_L, device):
+         # initialize output buffers
+         self.out_pred_ids = torch.zeros_like(pretrain_design['pred_ids'])
+         self.out_confs = torch.zeros_like(pretrain_design['confs'])
+         self.out_embeds = torch.zeros(B, max_L, self.tunning_layers_dim, device=device)
+         self.out_attention_mask = torch.zeros_like(pretrain_design['attention_mask'])
+         self.out_probs = torch.zeros_like(pretrain_design['probs'])
+         self.out_log_probs = torch.zeros_like(pretrain_design['probs'])
+         self.titles = [None for i in range(B)]
+
+     def retrivel(self, keys, num_nodes, device, use_memory):
+         unseen = []
+         for idx in range(len(keys)):
+             key = keys[idx]
+             if (key in self.memory) and use_memory:
+                 self.out_pred_ids[idx, :num_nodes[idx]] = self.memory[key]['pred_ids'].to(device)
+                 self.out_confs[idx, :num_nodes[idx]] = self.memory[key]['confs'].to(device)
+                 self.out_embeds[idx, :num_nodes[idx]] = self.memory[key]['embeds'].to(device)
+                 self.out_attention_mask[idx, :num_nodes[idx]] = self.memory[key]['attention_mask'].to(device)
+                 self.out_probs[idx, :num_nodes[idx]] = self.memory[key]['probs'].to(device)
+                 self.out_log_probs[idx, :num_nodes[idx]] = self.memory[key]['log_probs'].to(device)
+                 self.titles[idx] = key
+             else:
+                 unseen.append(idx)
+         return unseen
+
+     def rebatch(self, unseen, batch_id_raw, E_idx_raw, h_E_raw, shift, num_nodes, pretrain_design, pretrain_esm_msa, pretrain_struct, pretrain_esmif, device):
+         unseen_design_pred_ids = []
+         unseen_design_confs = []
+         unseen_design_embeds = []
+         unseen_design_attention_mask = []
+
+         unseen_esm_pred_ids = []
+         unseen_esm_confs = []
+         unseen_esm_embeds = []
+         unseen_esm_attention_mask = []
+         unseen_struct_embeds = []
+         unseen_esmif_embeds = []
+         h_E = []
+         E_idx = []
+         batch_id = []
+
+         new_shift = 0
+         for bid, i in enumerate(unseen):
+             edge_mask = batch_id_raw[E_idx_raw[0]] == i
+             h_E.append(h_E_raw[edge_mask])
+             E_idx.append(E_idx_raw[:, edge_mask] - shift[i] + new_shift)
+             batch_id.append(torch.ones(num_nodes[i], device=device).long() * bid)
+             new_shift += num_nodes[i]
+
+             unseen_design_pred_ids.append(pretrain_design['pred_ids'][i])
+             unseen_design_confs.append(pretrain_design['confs'][i])
+             unseen_design_embeds.append(pretrain_design['embeds'][i])
+             unseen_design_attention_mask.append(pretrain_design['attention_mask'][i])
+
+             if self.args.use_LM:
+                 unseen_esm_pred_ids.append(pretrain_esm_msa['pred_ids'][:, i])
+                 unseen_esm_confs.append(pretrain_esm_msa['confs'][:, i])
+                 unseen_esm_embeds.append(pretrain_esm_msa['embeds'][:, i])
+                 unseen_esm_attention_mask.append(pretrain_esm_msa['attention_mask'][:, i])
+
+             if self.args.use_gearnet:
+                 unseen_struct_embeds.append(pretrain_struct['embeds'][:, i])
+
+             if self.args.use_esmif:
+                 unseen_esmif_embeds.append(pretrain_esmif['embeds'][i])
+
+         unseen_design_pred_ids = torch.stack(unseen_design_pred_ids)
+         unseen_design_confs = torch.stack(unseen_design_confs)
+         unseen_design_embeds = torch.stack(unseen_design_embeds)
+         unseen_design_attention_mask = torch.stack(unseen_design_attention_mask)
+
+         if self.args.use_LM:
+             unseen_esm_pred_ids = torch.stack(unseen_esm_pred_ids, dim=1)
+             unseen_esm_confs = torch.stack(unseen_esm_confs, dim=1)
+             unseen_esm_embeds = torch.stack(unseen_esm_embeds, dim=1)
+             unseen_esm_attention_mask = torch.stack(unseen_esm_attention_mask, dim=1)
+
+         if self.args.use_gearnet:
+             unseen_struct_embeds = torch.stack(unseen_struct_embeds, dim=1)
+
+         if self.args.use_esmif:
+             unseen_esmif_embeds = torch.stack(unseen_esmif_embeds, dim=0)
+
+         unseen_batch = {"pretrain_design":
+                             {"pred_ids": unseen_design_pred_ids,
+                              "confs": unseen_design_confs,
+                              "embeds": unseen_design_embeds,
+                              "attention_mask": unseen_design_attention_mask},
+                         "h_E": torch.cat(h_E),
+                         "E_idx": torch.cat(E_idx, dim=1),
+                         "batch_id": torch.cat(batch_id),
+                         "attention_mask": unseen_design_attention_mask
+                         }
+
+         if self.args.use_LM:
+             unseen_batch["pretrain_esm_msa"] = {"pred_ids": unseen_esm_pred_ids,
+                                                 "confs": unseen_esm_confs,
+                                                 "embeds": unseen_esm_embeds,
+                                                 "attention_mask": unseen_esm_attention_mask}
+
+         if self.args.use_gearnet:
+             unseen_batch["pretrain_struct"] = {"embeds": unseen_struct_embeds}
+
+         if self.args.use_esmif:
+             unseen_batch["pretrain_esmif"] = {"embeds": unseen_esmif_embeds}
+         return unseen_batch
+
+     def save2memory(self, keys, unseen, num_nodes, unseen_results):
+         # save to memory
+         for i in range(len(unseen)):
+             key = keys[unseen[i]]
+             num = num_nodes[unseen[i]]
+             self.memory[key] = {"pred_ids": unseen_results['pred_ids'][i][:num].detach().to('cpu'),
+                                 "confs": unseen_results['confs'][i][:num].detach().to('cpu'),
+                                 "embeds": unseen_results['embeds'][i][:num].detach().to('cpu'),
+                                 "probs": unseen_results['probs'][i][:num].detach().to('cpu'),
+                                 "log_probs": unseen_results['log_probs'][i][:num].detach().to('cpu'),
+                                 "attention_mask": unseen_results['attention_mask'][i][:num].detach().to('cpu')}
+
+     def update(self, unseen, num_nodes, unseen_results, keys):
+         # update the output buffers with the fresh results
+         for i in range(len(unseen)):
+             num = num_nodes[unseen[i]]
+             self.out_pred_ids[unseen[i], :num] = unseen_results['pred_ids'][i][:num]
+             self.out_confs[unseen[i], :num] = unseen_results['confs'][i][:num]
+             self.out_embeds[unseen[i], :num] = unseen_results['embeds'][i][:num]
+             self.out_probs[unseen[i], :num] = unseen_results['probs'][i][:num]
+             self.out_log_probs[unseen[i], :num] = unseen_results['log_probs'][i][:num]
+             self.titles[unseen[i]] = keys[unseen[i]]
+
+     def forward(self, batch, use_memory=False):
+         self.use_memory = use_memory
+         pretrain_design, h_E_raw, E_idx_raw, mask_attend, batch_id_raw = batch['pretrain_design'], batch['h_E'], batch['E_idx'], batch['attention_mask'], batch['batch_id']
+         device = h_E_raw.device
+
+         pretrain_esm_msa = None
+         if self.args.use_LM:
+             pretrain_esm_msa = batch['pretrain_esm_msa']
+
+         pretrain_struct = None
+         if self.args.use_gearnet:
+             pretrain_struct = batch['pretrain_struct']
+
+         pretrain_esmif = None
+         if self.args.use_esmif:
+             pretrain_esmif = batch['esm_feat']
+
+         num_nodes = batch['attention_mask'].sum(dim=-1)
+         shift = torch.cat([torch.zeros(1, device=device), torch.cumsum(num_nodes, dim=0)]).long()
+
+         B, max_L = num_nodes.shape[0], num_nodes.max()
+
+         self.initoutput(pretrain_design, B, max_L, device)
+
+         # keys = list(zip(design_seqs, *lm_seqs))
+         keys = batch['title']
+         unseen = self.retrivel(keys, num_nodes, device, use_memory)
+
+         if len(unseen) > 0:
+             unseen_batch = self.rebatch(unseen, batch_id_raw, E_idx_raw, h_E_raw, shift, num_nodes, pretrain_design, pretrain_esm_msa, pretrain_struct, pretrain_esmif, device)
+             unseen_results = self.GNNTuning(unseen_batch)
+
+             self.save2memory(keys, unseen, num_nodes, unseen_results)
+             self.update(unseen, num_nodes, unseen_results, keys)
+
+         return {'title': self.titles, 'pred_ids': self.out_pred_ids, 'confs': self.out_confs, 'embeds': self.out_embeds, 'probs': self.out_probs, "log_probs": self.out_log_probs, 'attention_mask': pretrain_design['attention_mask']}
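
The `initoutput`/`update` pair above maintains fixed-size (B, maxL, ...) buffers and writes each variable-length result into the first num_nodes[i] rows, leaving the tail zero-padded. The slicing pattern in isolation (toy shapes):

    import torch

    num_nodes = torch.tensor([3, 5])
    B, maxL, D = 2, int(num_nodes.max()), 4
    out_embeds = torch.zeros(B, maxL, D)

    results = [torch.randn(int(n), D) for n in num_nodes]   # per-protein outputs
    for i, r in enumerate(results):
        out_embeds[i, :num_nodes[i]] = r                    # rows past num_nodes[i] stay zero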
Flexpert-Design/src/models/PretrainESMIF_model.py ADDED
@@ -0,0 +1,32 @@
+ import esm
+ import torch.nn as nn
+ import torch
+ from esm.inverse_folding.util import CoordBatchConverter
+
+ class PretrainESMIF_Model(nn.Module):
+     def __init__(self):
+         super(PretrainESMIF_Model, self).__init__()
+         # weights are loaded locally rather than from /root/.cache/torch/hub/checkpoints
+         model_data = torch.load("./model_zoo/esmif/esm_if1_gvp4_t16_142M_UR50.pt")
+         self.model, self.alphabet = esm.pretrained.load_model_and_alphabet_core("esm_if1_gvp4_t16_142M_UR50", model_data, None)
+
+     def forward(self, coords_list):
+         self.model.eval()
+         batch_converter = CoordBatchConverter(self.model.decoder.dictionary)
+         batch_coords, confidence, _, _, padding_mask = (
+             batch_converter([(coord, None, None) for coord in coords_list], device=coords_list[0].device)
+         )
+         with torch.no_grad():
+             encoder_out = self.model.encoder(batch_coords, padding_mask, confidence)
+
+         feat = encoder_out['encoder_out'][0].permute(1, 0, 2)[:, 1:-1]             # (B, L, 512) with BOS/EOS positions stripped
+         attention_mask = encoder_out['encoder_padding_mask'][0][:, 1:-1] == False  # (B, L); computed but not returned
+
+         return {"feat": feat}
+
+ if __name__ == '__main__':
+     model = PretrainESMIF_Model()  # the original passed an argument the constructor does not accept
+     coords1 = torch.rand(1044, 3, 3)  # N, CA, C
+     coords2 = torch.rand(500, 3, 3)
+     model([coords1, coords2])
+     print()
Flexpert-Design/src/models/PretrainESM_model.py ADDED
@@ -0,0 +1,35 @@
+ import time
+ import torch
+ import math
+ import torch.nn as nn
+ from transformers import AutoTokenizer, EsmForMaskedLM
+
+
+ class PretrainESM_Model(nn.Module):
+     def __init__(self, args):
+         """ Graph labeling network """
+         super(PretrainESM_Model, self).__init__()
+         self.args = args
+         # vocab: {0: '<cls>', 1: '<pad>', 2: '<eos>', 3: '<unk>', 4: 'L', 5: 'A', 6: 'G', 7: 'V', 8: 'S', 9: 'E', 10: 'R', 11: 'T', 12: 'I', 13: 'D', 14: 'P', 15: 'K', 16: 'Q', 17: 'N', 18: 'F', 19: 'Y', 20: 'M', 21: 'H', 22: 'W', 23: 'C', 24: 'X', 25: 'B', 26: 'U', 27: 'Z', 28: 'O', 29: '.', 30: '-', 31: '<null_1>', 32: '<mask>'}
+         self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="./cache_dir/")
+         self.model = EsmForMaskedLM.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="./cache_dir/")
+
+     def forward(self, batch):
+         # output_hidden_states is required; without it outputs.hidden_states is None
+         outputs = self.model(input_ids=batch['pred_ids'], attention_mask=batch['attention_mask'], output_hidden_states=True)
+         logits = outputs.logits
+
+         prop = logits.softmax(dim=-1)
+         confidences, pred_ids = prop.max(dim=-1)
+
+         ret = {"pred_ids": pred_ids,
+                "confs": confidences,
+                "embeds": outputs.hidden_states[-1],
+                "attention_mask": batch['attention_mask']}
+         return ret
+
+ if __name__ == '__main__':
+     tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="./cache_dir/")
+     print(tokenizer.convert_ids_to_tokens(list(range(33))))
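
A minimal smoke test for the wrapper above (shapes only; labels and sequence are arbitrary, and the 650M checkpoint is downloaded on first use). The batch keys follow the forward signature:

    import torch

    model = PretrainESM_Model(args=None)   # args is stored but not read in this module
    enc = model.tokenizer("MKTAYIAKQR", return_tensors="pt", add_special_tokens=False)
    batch = {"pred_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}
    with torch.no_grad():
        out = model(batch)
    print(out["pred_ids"].shape, out["embeds"].shape)   # (1, 10) and (1, 10, 1280)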
Flexpert-Design/src/models/PretrainPiFold_model.py ADDED
@@ -0,0 +1,64 @@
+ import torch
+ import os.path as osp
+ from src.models.pifold_model import PiFold_Model
+ import torch.nn.functional as F
+
+
+ class PretrainPiFold_Model(PiFold_Model):
+     def __init__(self, args, **kwargs):
+         """ Graph labeling network """
+         PiFold_Model.__init__(self, args)
+         if args.augment_eps > 0:
+             pretrain_pifold_path = osp.join(self.args.res_dir, self.args.dataset, f"PiFold_{args.augment_eps}", "checkpoint.pth")
+         else:
+             # pretrain_pifold_path = osp.join(self.args.res_dir, self.args.dataset, "PiFold", "checkpoint.pth")
+             pretrain_pifold_path = osp.join('model_zoo', self.args.dataset, "PiFold", "checkpoint.pth")
+         self.load_state_dict(torch.load(pretrain_pifold_path))
+
+     @torch.no_grad()
+     def forward(self, batch):
+         h_V, h_P, P_idx, batch_id = batch['h_V'], batch['h_E'], batch['E_idx'], batch['batch_id']
+         device = h_V.device
+         h_V = self.W_v(self.norm_nodes(self.node_embedding(h_V)))
+         h_P = self.W_e(self.norm_edges(self.edge_embedding(h_P)))
+
+         h_V, h_P = self.encoder(h_V, h_P, P_idx, batch_id)
+         log_probs, logits = self.decoder(h_V, batch_id)
+         probs = F.softmax(logits, dim=-1)
+         conf, pred_id = probs.max(dim=-1)
+
+         maxL = 0
+         for b in batch_id.unique():
+             mask = batch_id == b
+             L = mask.sum()
+             if L > maxL:
+                 maxL = L
+
+         confs = []
+         seqs = []
+         embeds = []
+         probs2 = []
+         for b in batch_id.unique():
+             mask = batch_id == b
+             # elements = [alphabet[int(id)] for id in pred_id[mask]]
+             elements = self.tokenizer.decode(pred_id[mask]).split(" ")
+             seqs.append(elements)
+             confs.append(conf[mask])
+             embeds.append(h_V[mask])
+             probs2.append(probs[mask])
+
+         seqs = self.tokenizer(["".join(one) for one in seqs], padding=True, truncation=True, return_tensors='pt', add_special_tokens=False)
+         confs = torch.stack([F.pad(one, (0, maxL - len(one))) for one in confs])
+         embeds = torch.stack([F.pad(one, (0, 0, 0, maxL - len(one))) for one in embeds])
+         probs2 = torch.stack([F.pad(one, (0, 0, 0, maxL - len(one)), value=1 / 33) for one in probs2])
+
+         ret = {"pred_ids": seqs['input_ids'].to(device),
+                "confs": confs,
+                "embeds": embeds,
+                "probs": probs2,
+                "attention_mask": seqs['attention_mask'].to(device),
+                "E_idx": P_idx,
+                "batch_id": batch_id,
+                "h_E": h_P}
+         return ret
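
The debatching at the end of `forward` relies on `F.pad` with different pad tuples depending on tensor rank; the convention (last dimension first) is easy to get wrong, so here it is on toy tensors:

    import torch
    import torch.nn.functional as F

    maxL = 5
    conf = torch.ones(3)       # (L,)
    emb = torch.ones(3, 8)     # (L, D)
    conf_p = F.pad(conf, (0, maxL - len(conf)))        # pads the only (last) dim -> (5,)
    emb_p = F.pad(emb, (0, 0, 0, maxL - len(emb)))     # (0,0) keeps D, (0,2) pads L -> (5, 8)
    assert conf_p.shape == (5,) and emb_p.shape == (5, 8)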
Flexpert-Design/src/models/Tuning.py ADDED
@@ -0,0 +1,275 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from src.modules.pifold_module import *
+ from torch_scatter import scatter_softmax, scatter_log_softmax
+
+ def positional_encoding(x):
+     batch_size, seq_len, hidden_size = x.size()
+     pos = torch.arange(0, seq_len).float().unsqueeze(1).repeat(1, hidden_size // 2)
+     div = torch.exp(torch.arange(0, hidden_size, 2).float() * (-torch.log(torch.tensor(10000.0)) / hidden_size))
+     sin = torch.sin(pos * div)
+     cos = torch.cos(pos * div)
+     pos_encoding = torch.cat([sin, cos], dim=-1).unsqueeze(0).repeat(batch_size, 1, 1)
+     # keep the encoding on the same device as the input (the original left it on CPU)
+     return pos_encoding.to(x.device)
+
+
+ class MSAAttention(nn.Module):
+     def __init__(self, hidden_dim) -> None:
+         super().__init__()
+         self.MSA_Q = nn.Linear(hidden_dim, hidden_dim)
+         self.MSA_K = nn.Linear(hidden_dim, hidden_dim)
+         self.MSA_V = nn.Linear(hidden_dim, hidden_dim)
+
+     def forward(self, inputs_embeds):
+         pos_enc = positional_encoding(inputs_embeds)
+         inputs_embeds = inputs_embeds + pos_enc
+
+         query = self.MSA_Q(inputs_embeds)  # shape: [batch, N, hidden_dim]
+         key = self.MSA_K(inputs_embeds)    # shape: [batch, N, hidden_dim]
+         value = self.MSA_V(inputs_embeds)  # shape: [batch, N, hidden_dim]
+         attn_scores = torch.bmm(query, key.transpose(1, 2))
+         attn_weights = nn.functional.softmax(attn_scores, dim=2)
+
+         attn_output = torch.bmm(attn_weights, value)
+         return attn_output
+
+
+ class GNNTuning_Model(nn.Module):
+     def __init__(self, args, num_encoder_layers, hidden_dim, input_design_dim, input_esm_dim, input_struct_dim=3072, input_esmif_dim=512, dropout=0.1):
+         super(GNNTuning_Model, self).__init__()
+         self.args = args
+         encoder_layers = []
+         for i in range(num_encoder_layers):
+             encoder_layers.append(
+                 GeneralGNN(hidden_dim, hidden_dim * 2, dropout=dropout, node_net="AttMLP", edge_net="EdgeMLP", node_context=1, edge_context=0),
+             )
+         self.encoder_layers = nn.Sequential(*encoder_layers)
+
+         from transformers import AutoTokenizer
+         from transformers.models.esm.modeling_esm import EsmModel, EsmEmbeddings
+         from transformers.models.esm.configuration_esm import EsmConfig
+
+         self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="./cache_dir/")
+         config = EsmConfig(attention_probs_dropout_prob=0,
+                            hidden_size=hidden_dim,
+                            intermediate_size=1280,
+                            mask_token_id=32,
+                            num_attention_heads=12,
+                            num_hidden_layers=3,
+                            pad_token_id=1,
+                            position_embedding_type="rotary",
+                            token_dropout=False,
+                            vocab_size=33
+                            )
+
+         self.DesignEmbed = EsmEmbeddings(config)
+         self.ESMEmbed = EsmEmbeddings(config)
+         self.EdgeEmbed = nn.Sequential(nn.Linear(416 + 16 + 16, 512),
+                                        nn.ReLU(),
+                                        nn.Linear(512, hidden_dim),
+                                        nn.ReLU(),
+                                        nn.Linear(hidden_dim, hidden_dim))
+
+         self.DesignConf = nn.Sequential(nn.Linear(1, 128),
+                                         nn.ReLU(),
+                                         nn.Linear(128, 128),
+                                         nn.ReLU(),
+                                         nn.Linear(128, 1),
+                                         nn.Sigmoid())
+
+         self.ESMConf = nn.Sequential(nn.Linear(1, 128),
+                                      nn.ReLU(),
+                                      nn.Linear(128, 128),
+                                      nn.ReLU(),
+                                      nn.Linear(128, 1))
+
+         self.DesignProj = nn.Sequential(nn.Linear(input_design_dim, 512),
+                                         nn.ReLU(),
+                                         nn.Linear(512, hidden_dim),
+                                         nn.ReLU(),
+                                         nn.Linear(hidden_dim, hidden_dim))
+
+         self.ESMProj = nn.Sequential(nn.Linear(input_esm_dim, 512),
+                                      nn.ReLU(),
+                                      nn.Linear(512, hidden_dim),
+                                      nn.ReLU(),
+                                      nn.Linear(hidden_dim, hidden_dim))
+
+         self.StructProj = nn.Sequential(nn.Linear(input_struct_dim, 512),
+                                         nn.ReLU(),
+                                         nn.Linear(512, hidden_dim),
+                                         nn.ReLU(),
+                                         nn.Linear(hidden_dim, hidden_dim))
+
+         self.ESMIFProj = nn.Sequential(nn.Linear(input_esmif_dim, 512),
+                                        nn.ReLU(),
+                                        nn.Linear(512, hidden_dim),
+                                        nn.ReLU(),
+                                        nn.Linear(hidden_dim, hidden_dim))
+
+         self.ReadOut = nn.Linear(hidden_dim, 33)
+         # self.TimeEmbed = nn.Embedding(20, hidden_dim)
+         # self.ProbEmbed = nn.Sequential(nn.Linear(33, 512),
+         #                                nn.ReLU(),
+         #                                nn.Linear(512, hidden_dim),
+         #                                nn.ReLU(),
+         #                                nn.Linear(hidden_dim, hidden_dim))
+
+         self.MLP1 = nn.Sequential(nn.Linear(1, 512),
+                                   nn.ReLU(),
+                                   nn.Linear(512, hidden_dim),
+                                   nn.ReLU(),
+                                   nn.Linear(hidden_dim, 1),
+                                   nn.Sigmoid())
+
+         self.MLP2 = nn.Sequential(nn.Linear(1, 512),
+                                   nn.ReLU(),
+                                   nn.Linear(512, hidden_dim),
+                                   nn.ReLU(),
+                                   nn.Linear(hidden_dim, 1),
+                                   nn.Sigmoid())
+
+     # def embed_gnn(self, pretrain_gnn, mask_select_id, mask_select_feat):
+     #     gnn_embed = self.DesignEmbed(mask_select_id(pretrain_gnn['pred_ids'])).squeeze()
+     #     gnn_conf = self.DesignConf(mask_select_id(pretrain_gnn['confs']))
+     #     gnn_proj = self.DesignProj(mask_select_feat(pretrain_gnn['embeds']))
+     #     if self.args.use_confembed:
+     #         return gnn_embed * F.sigmoid(gnn_conf) + gnn_proj
+     #     else:
+     #         return gnn_embed + gnn_proj
+
+     # def embed_esm(self, pretrain_esm, mask_select_id, mask_select_feat):
+     #     esm_embed = self.ESMEmbed(mask_select_id(pretrain_esm['pred_ids'])).squeeze()
+     #     esm_conf = self.ESMConf(mask_select_id(pretrain_esm['confs']))
+     #     esm_proj = self.ESMProj(mask_select_feat(pretrain_esm['embeds']))
+     #     if self.args.use_confembed:
+     #         return esm_embed * F.sigmoid(esm_conf) + esm_proj
+     #     else:
+     #         return esm_embed + esm_proj
+
+     # def embed_struct(self, pretrain_struct, mask_select_feat):
+     #     struct_proj = self.StructProj(mask_select_feat(pretrain_struct['embeds']))
+     #     return struct_proj
+
+     # def embed_esmif(self, pretrain_esmif, mask_select_feat):
+     #     struct_proj = self.ESMIFProj(mask_select_feat(pretrain_esmif['embeds']))
+     #     return struct_proj
+
+     def fuse(self, mask_select_feat, mask_select_id, gnn_embed=None, esm_embed=None, gearnet_embed=None, esmif_embed=None, gnn_pred_id=None, esm_pred_id=None, confidence=None, confidence_esm=None):
+         gnn, esm, gearnet, esmif, conf, esm_conf = 0, 0, 0, 0, 1.0, 0.0
+         if gnn_embed is not None:
+             gnn = self.DesignProj(mask_select_feat(gnn_embed))
+             gnn += self.DesignEmbed(mask_select_id(gnn_pred_id)).squeeze()
+
+         if esm_embed is not None:
+             esm = self.ESMProj(mask_select_feat(esm_embed))
+             esm += self.ESMEmbed(mask_select_id(esm_pred_id)).squeeze()
+
+         if gearnet_embed is not None:
+             gearnet = self.StructProj(mask_select_feat(gearnet_embed))
+
+         if esmif_embed is not None:
+             esmif = self.ESMIFProj(mask_select_feat(esmif_embed))
+
+         # the original tested `if conf is not None`, which was always true since
+         # conf is initialised to 1.0; test the actual inputs instead
+         if confidence is not None:
+             conf = self.DesignConf(mask_select_id(confidence))
+         if confidence_esm is not None:
+             esm_conf = self.ESMConf(mask_select_id(confidence_esm))
+
+         return (gnn * conf + esm * esm_conf + gearnet + esmif)
+
+     def forward(self, batch):
+         pretrain_design, h_E_raw, E_idx, mask_attend, batch_id = batch['pretrain_design'], batch['h_E'], batch['E_idx'], batch['attention_mask'], batch['batch_id']
+
+         if self.args.use_LM:
+             pretrain_esm_msa = batch['pretrain_esm_msa']
+
+         if self.args.use_gearnet:
+             pretrain_struct = batch['pretrain_struct']
+
+         if self.args.use_esmif:
+             pretrain_esmif = batch['pretrain_esmif']
+
+         mask_select_id = lambda x: torch.masked_select(x, mask_attend.bool()).reshape(-1, 1)
+         mask_select_feat = lambda x: torch.masked_select(x, mask_attend.bool().unsqueeze(-1)).reshape(-1, x.shape[-1])
+
+         inputs_embeds = 0
+         for i in range(self.args.msa_n):
+             gnn_embed = pretrain_design['embeds']
+             esm_embed = pretrain_esm_msa['embeds'][i] if self.args.use_LM else None
+             gearnet_embed = pretrain_struct['embeds'][i] if self.args.use_gearnet else None
+             esmif_embed = pretrain_esmif['embeds'] if self.args.use_esmif else None
+             confidence = pretrain_design['confs']
+             # guard the LM-dependent inputs; the original indexed pretrain_esm_msa unconditionally
+             confidence_esm = pretrain_esm_msa['confs'][i] if self.args.use_LM else None
+             esm_pred_id = pretrain_esm_msa['pred_ids'][i] if self.args.use_LM else None
+             inputs_embeds += self.fuse(mask_select_feat, mask_select_id, gnn_embed, esm_embed, gearnet_embed, esmif_embed, pretrain_design['pred_ids'], esm_pred_id, confidence, confidence_esm)
+
+         h_V = inputs_embeds
+         h_E = self.EdgeEmbed(h_E_raw)
+
+         for layer in self.encoder_layers:
+             h_V, h_E = layer(h_V, h_E, E_idx, batch_id)
+
+         logits = self.ReadOut(h_V)
+
+         # confidence update
+         old_confs = mask_select_id(pretrain_design['confs'])
+         confs = torch.softmax(logits, dim=-1).max(dim=-1)[0][:, None]
+         h_V = h_V * self.MLP1(confs - old_confs) + inputs_embeds * self.MLP2(old_confs - confs)
+         logits = self.ReadOut(h_V)
+
+         B, N = pretrain_design['confs'].shape
+         vocab_size = logits.shape[-1]
+
+         new_logits = torch.zeros(B, N, vocab_size, device=logits.device).reshape(B * N, vocab_size)
+         new_logits = new_logits.masked_scatter_(mask_attend.bool().view(-1, 1), logits)
+         new_logits = new_logits.reshape(B, N, vocab_size)
+         log_probs = torch.log_softmax(new_logits, dim=-1)
+
+         device = logits.device
+         seqs, confs, embeds, probs2 = self.to_matrix(h_V, logits, batch_id)
+
+         ret = {"pred_ids": seqs['input_ids'].to(device),
+                "confs": confs,
+                "embeds": embeds,
+                "probs": probs2,
+                "attention_mask": seqs['attention_mask'].to(device),
+                "h_E": h_E_raw,
+                "E_idx": E_idx,
+                "batch_id": batch_id,
+                "log_probs": log_probs}
+         return ret
+
+     def to_matrix(self, h_V, logits, batch_id):
+         probs = F.softmax(logits, dim=-1)
+         conf, pred_id = probs.max(dim=-1)
+
+         maxL = 0
+         for b in batch_id.unique():
+             mask = batch_id == b
+             L = mask.sum()
+             if L > maxL:
+                 maxL = L
+
+         confs = []
+         seqs = []
+         embeds = []
+         probs2 = []
+         for b in batch_id.unique():
+             mask = batch_id == b
+             # elements = [alphabet[int(id)] for id in pred_id[mask]]
+             elements = self.tokenizer.decode(pred_id[mask]).split(" ")
+             seqs.append(elements)
+             confs.append(conf[mask])
+             embeds.append(h_V[mask])
+             probs2.append(probs[mask])
+
+         seqs = self.tokenizer(["".join(one) for one in seqs], padding=True, truncation=True, return_tensors='pt', add_special_tokens=False)
+         confs = torch.stack([F.pad(one, (0, maxL - len(one))) for one in confs])
+         embeds = torch.stack([F.pad(one, (0, 0, 0, maxL - len(one))) for one in embeds])
+         probs2 = torch.stack([F.pad(one, (0, 0, 0, maxL - len(one)), value=1 / 33) for one in probs2])
+         return seqs, confs, embeds, probs2
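
`forward` above flattens node features over the attention mask and later scatters the per-node logits back into a dense (B, N, V) tensor via `masked_scatter_`. The round trip in isolation:

    import torch

    B, N, V = 2, 4, 3
    mask = torch.tensor([[1, 1, 0, 0],
                         [1, 1, 1, 0]]).bool()
    flat_logits = torch.randn(int(mask.sum()), V)    # one row per real (unpadded) node

    dense = torch.zeros(B * N, V)
    dense = dense.masked_scatter_(mask.view(-1, 1), flat_logits).reshape(B, N, V)
    assert torch.allclose(dense[mask], flat_logits)  # real positions recovered; padding stays zero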
Flexpert-Design/src/models/__init__.py ADDED
@@ -0,0 +1,16 @@
+ # Copyright (c) CAIRI AI Lab. All rights reserved
+
+ # from .alphadesign_model import AlphaDesign_Model
+ # from .esmif_model import GVPTransformerModel as ESMIF_Model
+ # from .gca_model import GCA_Model
+ # from .graphtrans_model import GraphTrans_Model
+ # from .gvp_model import GVP_Model
+ # from .pifold_model import PiFold_Model
+ from .proteinmpnn_model import ProteinMPNN_Model
+ # from .structgnn_model import StructGNN_Model
+ # from .kwdesign_model import KWDesign_model
+
+ # only the imported model is exported; the other entries ('AlphaDesign_Model',
+ # 'ESMIF_Model', 'GCA_Model', 'GraphTrans_Model', 'GVP_Model', 'PiFold_Model',
+ # 'StructGNN_Model', 'KWDesign_model') would break `from src.models import *`
+ # while their imports above stay commented out
+ __all__ = [
+     'ProteinMPNN_Model',
+ ]
Flexpert-Design/src/models/alphadesign_model.py ADDED
@@ -0,0 +1,138 @@
+ import torch
+ import torch.nn as nn
+
+ from src.modules.alphadesign_module import ATDecoder, CNNDecoder, CNNDecoder2, StructureEncoder
+ from src.tools.design_utils import gather_nodes, _dihedrals, _rbf, _orientations_coarse_gl
+
+
+ class AlphaDesign_Model(nn.Module):
+     def __init__(self, args, **kwargs):
+         """ Graph labeling network """
+         super(AlphaDesign_Model, self).__init__()
+         self.args = args
+         node_features = args.node_features
+         edge_features = args.edge_features
+         hidden_dim = args.hidden_dim
+         dropout = args.dropout
+         num_encoder_layers = args.num_encoder_layers
+         self.top_k = args.k_neighbors
+         self.num_rbf = 16
+         self.num_positional_embeddings = 16
+
+         if args.use_new_feat:
+             node_in, edge_in = 12, 16 + 7
+         else:
+             node_in, edge_in = 6, 16 + 7
+         self.node_embedding = nn.Linear(node_in, node_features, bias=True)
+         self.edge_embedding = nn.Linear(edge_in, edge_features, bias=True)
+         self.norm_nodes = nn.BatchNorm1d(node_features)
+         self.norm_edges = nn.BatchNorm1d(edge_features)
+
+         self.W_v = nn.Sequential(
+             nn.Linear(node_features, hidden_dim, bias=True),
+             nn.LeakyReLU(),
+             nn.BatchNorm1d(hidden_dim),
+             nn.Linear(hidden_dim, hidden_dim, bias=True),
+             nn.LeakyReLU(),
+             nn.BatchNorm1d(hidden_dim),
+             nn.Linear(hidden_dim, hidden_dim, bias=True)
+         )
+
+         self.W_e = nn.Linear(edge_features, hidden_dim, bias=True)
+         self.W_f = nn.Linear(edge_features, hidden_dim, bias=True)
+
+         self.encoder = StructureEncoder(hidden_dim, num_encoder_layers, dropout, use_SGT=self.args.use_SGT)
+
+         if args.autoregressive:
+             self.decoder = ATDecoder(args, hidden_dim, dropout)
+         else:
+             self.decoder = CNNDecoder(hidden_dim, hidden_dim)
+             self.decoder2 = CNNDecoder2(hidden_dim, hidden_dim)
+
+         # self.chain_embed = nn.Embedding(2, 16)
+         self._init_params()
+
+     def forward(self, batch, AT_test=False, return_logit=False):
+         h_V, h_P, P_idx, batch_id = batch['_V'], batch['_E'], batch['E_idx'], batch['batch_id']
+         h_V = self.W_v(self.norm_nodes(self.node_embedding(h_V)))
+         h_P = self.W_e(self.norm_edges(self.edge_embedding(h_P)))
+
+         h_V = self.encoder(h_V, h_P, P_idx, batch_id)
+         log_probs0 = None
+         if AT_test:
+             log_probs = self.decoder.sampling(h_V, h_P, P_idx, batch_id)
+         else:
+             log_probs0, logits = self.decoder(h_V, batch_id)
+             log_probs, logits = self.decoder2(h_V, logits, batch_id)
+         if return_logit:
+             return {'log_probs': log_probs, 'logits': logits}
+         else:
+             return {'log_probs': log_probs, 'log_probs0': log_probs0}
+
+     def _init_params(self):
+         for p in self.parameters():
+             if p.dim() > 1:
+                 nn.init.xavier_uniform_(p)
+
+     def _full_dist(self, X, mask, top_k=30, eps=1E-6):
+         mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2)
+         dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
+         D = (1. - mask_2D) * 10000 + mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps)
+
+         D_max, _ = torch.max(D, -1, keepdim=True)
+         D_adjust = D + (1. - mask_2D) * (D_max + 1)
+         D_neighbors, E_idx = torch.topk(D_adjust, min(top_k, D_adjust.shape[-1]), dim=-1, largest=False)
+         return D_neighbors, E_idx
+
+     def _get_features(self, batch):
+         S, score, X, mask = batch['S'], batch['score'], batch['X'], batch['mask']
+         mask_bool = (mask == 1)
+
+         B, N, _, _ = X.shape
+         X_ca = X[:, :, 1, :]
+         D_neighbors, E_idx = self._full_dist(X_ca, mask, self.top_k)
+
+         # sequence
+         S = torch.masked_select(S, mask_bool)
+         if score is not None:
+             score = torch.masked_select(score, mask_bool)
+
+         # node feature
+         _V = _dihedrals(X)
+         if not self.args.use_new_feat:
+             _V = _V[..., :6]
+         _V = torch.masked_select(_V, mask_bool.unsqueeze(-1)).reshape(-1, _V.shape[-1])
+
+         # edge feature
+         _E = torch.cat((_rbf(D_neighbors, self.num_rbf), _orientations_coarse_gl(X, E_idx)), -1)  # [4,387,387,23]
+         mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)  # mask of first-order neighbours: 1 = node present, 0 = node absent
+         mask_attend = (mask.unsqueeze(-1) * mask_attend) == 1  # own mask * neighbour mask
+         _E = torch.masked_select(_E, mask_attend.unsqueeze(-1)).reshape(-1, _E.shape[-1])
+
+         # edge index
+         shift = mask.sum(dim=1).cumsum(dim=0) - mask.sum(dim=1)
+         src = shift.view(B, 1, 1) + E_idx
+         src = torch.masked_select(src, mask_attend).view(1, -1)
+         dst = shift.view(B, 1, 1) + torch.arange(0, N, device=src.device).view(1, -1, 1).expand_as(mask_attend)
+         dst = torch.masked_select(dst, mask_attend).view(1, -1)
+         E_idx = torch.cat((dst, src), dim=0).long()
+
+         # 3D points
+         sparse_idx = mask.nonzero()
+         X = X[sparse_idx[:, 0], sparse_idx[:, 1], :, :]
+         batch_id = sparse_idx[:, 0]
+
+         mask = torch.masked_select(mask, mask_bool)
+
+         batch.update({'X': X,
+                       'S': S,
+                       'score': score,
+                       '_V': _V,
+                       '_E': _E,
+                       'E_idx': E_idx,
+                       'batch_id': batch_id,
+                       'mask': mask})
+
+         return batch
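
`_get_features` above flattens a padded batch into one big graph: node rows of protein b are offset by the number of real residues in proteins 0..b-1. The offset computation in isolation (toy mask):

    import torch

    mask = torch.tensor([[1., 1., 1., 0.],
                         [1., 1., 0., 0.]])
    # real residues preceding each protein: running total minus own count
    shift = mask.sum(dim=1).cumsum(dim=0) - mask.sum(dim=1)   # tensor([0., 3.])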
Flexpert-Design/src/models/anm_prottrans.py ADDED
@@ -0,0 +1,677 @@
+ # import dependencies
+ import os.path
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+ from torch.utils.data import DataLoader
+
+ import re
+ import numpy as np
+ import pandas as pd
+ import copy
+ import pdb
+
+ import transformers, datasets
+ from transformers.modeling_outputs import TokenClassifierOutput, BaseModelOutputWithPastAndCrossAttentions
+ from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
+ from transformers import T5EncoderModel, T5Tokenizer
+ from transformers import TrainingArguments, Trainer, set_seed
+ from safetensors import safe_open
+
+ # DataCollator
+ from transformers.data.data_collator import DataCollatorMixin
+ from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+ from transformers.utils import PaddingStrategy
+
+ import random
+ import warnings
+ from collections.abc import Mapping
+ from dataclasses import dataclass
+ from random import randint
+ from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
+
+ from evaluate import load
+ from datasets import Dataset
+
+ from tqdm import tqdm
+
+ from scipy import stats
+ from sklearn.metrics import accuracy_score
+
+ import matplotlib.pyplot as plt
+
+ from Bio import SeqIO
+ from io import StringIO
+ import requests
+ import tempfile
+
+ from sklearn.model_selection import train_test_split
+ import csv
+
+
+ #### UTILS
+
+ class LoRAConfig:
+     def __init__(self):
+         self.lora_rank = 4
+         self.lora_init_scale = 0.01
+         self.lora_modules = ".*SelfAttention|.*EncDecAttention"
+         self.lora_layers = "q|k|v|o"
+         self.trainable_param_names = ".*layer_norm.*|.*lora_[ab].*"
+         self.lora_scaling_rank = 1
+         # lora_modules and lora_layers are specified with regular expressions
+         # see https://www.w3schools.com/python/python_regex.asp for reference
+
+ class LoRALinear(nn.Module):
+     def __init__(self, linear_layer, rank, scaling_rank, init_scale):
+         super().__init__()
+         self.in_features = linear_layer.in_features
+         self.out_features = linear_layer.out_features
+         self.rank = rank
+         self.scaling_rank = scaling_rank
+         self.weight = linear_layer.weight
+         self.bias = linear_layer.bias
+         if self.rank > 0:
+             self.lora_a = nn.Parameter(torch.randn(rank, linear_layer.in_features) * init_scale)
+             if init_scale < 0:
+                 self.lora_b = nn.Parameter(torch.randn(linear_layer.out_features, rank) * init_scale)
+             else:
+                 self.lora_b = nn.Parameter(torch.zeros(linear_layer.out_features, rank))
+         if self.scaling_rank:
+             self.multi_lora_a = nn.Parameter(
+                 torch.ones(self.scaling_rank, linear_layer.in_features)
+                 + torch.randn(self.scaling_rank, linear_layer.in_features) * init_scale
+             )
+             if init_scale < 0:
+                 self.multi_lora_b = nn.Parameter(
+                     torch.ones(linear_layer.out_features, self.scaling_rank)
+                     + torch.randn(linear_layer.out_features, self.scaling_rank) * init_scale
+                 )
+             else:
+                 self.multi_lora_b = nn.Parameter(torch.ones(linear_layer.out_features, self.scaling_rank))
+
+     def forward(self, input):
+         if self.scaling_rank == 1 and self.rank == 0:
+             # parsimonious implementation for ia3 and lora scaling
+             if self.multi_lora_a.requires_grad:
+                 hidden = F.linear((input * self.multi_lora_a.flatten()), self.weight, self.bias)
+             else:
+                 hidden = F.linear(input, self.weight, self.bias)
+             if self.multi_lora_b.requires_grad:
+                 hidden = hidden * self.multi_lora_b.flatten()
+             return hidden
+         else:
+             # general implementation for lora (adding and scaling)
+             weight = self.weight
+             if self.scaling_rank:
+                 weight = weight * torch.matmul(self.multi_lora_b, self.multi_lora_a) / self.scaling_rank
+             if self.rank:
+                 weight = weight + torch.matmul(self.lora_b, self.lora_a) / self.rank
+             return F.linear(input, weight, self.bias)
+
+     def extra_repr(self):
+         return "in_features={}, out_features={}, bias={}, rank={}, scaling_rank={}".format(
+             self.in_features, self.out_features, self.bias is not None, self.rank, self.scaling_rank
+         )
+
+
+ def modify_with_lora(transformer, config):
+     for m_name, module in dict(transformer.named_modules()).items():
+         if re.fullmatch(config.lora_modules, m_name):
+             for c_name, layer in dict(module.named_children()).items():
+                 if re.fullmatch(config.lora_layers, c_name):
+                     assert isinstance(
+                         layer, nn.Linear
+                     ), f"LoRA can only be applied to torch.nn.Linear, but {layer} is {type(layer)}."
+                     setattr(
+                         module,
+                         c_name,
+                         LoRALinear(layer, config.lora_rank, config.lora_scaling_rank, config.lora_init_scale),
+                     )
+     return transformer
+
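
A minimal sketch of applying `modify_with_lora` above to a toy module. The default regexes in LoRAConfig target T5-style attention blocks, so the toy module deliberately uses matching names (`SelfAttention`, `q`/`v`); these names are chosen for the demo, not taken from the real model:

    import torch.nn as nn

    class ToyAttention(nn.Module):
        def __init__(self, d=16):
            super().__init__()
            self.q = nn.Linear(d, d)   # matched by the "q|k|v|o" layer regex
            self.v = nn.Linear(d, d)

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.SelfAttention = ToyAttention()   # matched by ".*SelfAttention"

    config = LoRAConfig()
    model = modify_with_lora(ToyModel(), config)
    print(model.SelfAttention.q)   # now a LoRALinear wrapping the original weight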
137
+ class ClassConfig:
+     def __init__(self, dropout=0.2, num_labels=1, add_pearson_loss=False, add_sse_loss=False, adaptor_architecture=None, enm_embed_dim=512, enm_att_heads=8, kernel_size=3, num_layers=2, **kwargs):
+         self.dropout_rate = dropout
+         self.num_labels = num_labels
+         self.add_pearson_loss = add_pearson_loss
+         self.add_sse_loss = add_sse_loss
+         self.adaptor_architecture = adaptor_architecture
+         self.enm_embed_dim = enm_embed_dim
+         self.enm_att_heads = enm_att_heads
+         self.kernel_size = kernel_size
+         self.num_layers = num_layers
+
+ class ENMAdaptedAttentionClassifier(nn.Module):
+     def __init__(self, seq_embedding_dim, out_dim, enm_embed_dim, num_att_heads):
+         super(ENMAdaptedAttentionClassifier, self).__init__()
+         self.embedding = nn.Linear(1, enm_embed_dim)
+         self.enm_attention = nn.MultiheadAttention(enm_embed_dim, num_att_heads)
+         self.layer_norm = nn.LayerNorm(enm_embed_dim)
+         self.enm_adaptor = nn.Linear(enm_embed_dim, seq_embedding_dim)
+         self.adapted_classifier = nn.Linear(2 * seq_embedding_dim, out_dim)
+
+     def forward(self, seq_embedding, enm_input, attention_mask=None):
+         _ = attention_mask  # unused; accepted so all adaptors share one call signature
+         enm_input = enm_input.transpose(0, 1)  # Transpose to shape (N, B) for MultiheadAttention
+         enm_input = enm_input.unsqueeze(-1)  # Add a dimension for the embedding -> (N, B, 1)
+         enm_input_embedded = self.embedding(enm_input)  # (N, B, E)
+         enm_att, _ = self.enm_attention(enm_input_embedded, enm_input_embedded, enm_input_embedded)
+         enm_att = enm_att.transpose(0, 1)  # Transpose back to shape (B, N, E)
+         # The residual broadcasts the raw (B, N, 1) input over the embedding dimension
+         enm_att = self.layer_norm(enm_att + enm_input.transpose(0, 1))
+         enm_embedding = self.enm_adaptor(enm_att)
+         combined_embedding = torch.cat((seq_embedding, enm_embedding), dim=-1)
+         logits = self.adapted_classifier(combined_embedding)
+         return logits
+
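A shape walk-through of the attention adaptor, given the class above and hypothetical toy sizes (batch B=2, length N=5, sequence embedding D=32, ENM embedding E=16):

import torch

clf = ENMAdaptedAttentionClassifier(seq_embedding_dim=32, out_dim=1, enm_embed_dim=16, num_att_heads=4)
seq_emb = torch.randn(2, 5, 32)  # per-residue ProtT5 embeddings (B, N, D)
enm = torch.randn(2, 5)          # per-residue ENM flexibility profile (B, N)
print(clf(seq_emb, enm).shape)   # torch.Size([2, 5, 1])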
+ class ENMAdaptedConvClassifier(nn.Module):
+     def __init__(self, seq_embedding_dim, out_dim, kernel_size, enm_embedding_dim, num_layers):
+         super(ENMAdaptedConvClassifier, self).__init__()
+         layers = []
+         self.conv1 = nn.Conv1d(1, enm_embedding_dim, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
+         layers.append(self.conv1)
+         layers.append(nn.ReLU())
+         for i in range(num_layers - 1):
+             layers.append(nn.Conv1d(enm_embedding_dim, enm_embedding_dim, kernel_size=kernel_size, padding=(kernel_size - 1) // 2))
+             layers.append(nn.ReLU())
+         self.conv_net = nn.Sequential(*layers)
+         self.adapted_classifier = nn.Linear(seq_embedding_dim + 1, out_dim)
+
+     def forward(self, seq_embedding, enm_input, attention_mask=None):
+         enm_input = torch.nan_to_num(enm_input, nan=0.0)
+         enm_input = enm_input.unsqueeze(1)        # (B, 1, N)
+         conv_out = self.conv_net(enm_input)       # (B, E, N)
+         enm_embedding = conv_out.transpose(1, 2)  # (B, N, E)
+
+         if attention_mask is not None:
+             # Use attention_mask to ignore padded elements
+             mask = attention_mask.unsqueeze(-1).float()
+             enm_embedding = enm_embedding * mask
+             # Collapse the channel dimension to a single per-residue value -> (B, N, 1)
+             enm_embedding = enm_embedding.mean(dim=-1).unsqueeze(-1)
+             # enm_embedding = enm_embedding.sum(dim=2) / mask.sum(dim=2).clamp(min=1e-9)
+         else:
+             raise ValueError('We actually want to provide the mask.')
+             enm_embedding = torch.mean(enm_embedding, dim=1)  # unreachable after the raise
+
+         # enm_embedding = enm_embedding.unsqueeze(1).expand(-1, seq_embedding.size(1), -1)
+         combined_embedding = torch.cat((seq_embedding, enm_embedding), dim=-1)
+         logits = self.adapted_classifier(combined_embedding)
+         return logits
+
+
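The conv adaptor can be checked the same way (toy sizes again; the classifier consumes D+1 features because the conv channels are averaged back to a single value per residue):

import torch

clf = ENMAdaptedConvClassifier(seq_embedding_dim=32, out_dim=1, kernel_size=3, enm_embedding_dim=8, num_layers=2)
seq_emb = torch.randn(2, 7, 32)
enm = torch.randn(2, 7)
mask = torch.ones(2, 7)
print(clf(seq_emb, enm, attention_mask=mask).shape)  # torch.Size([2, 7, 1])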
+ class ENMAdaptedDirectClassifier(nn.Module):
+     def __init__(self, seq_embedding_dim, out_dim):
+         super(ENMAdaptedDirectClassifier, self).__init__()
+         self.adapted_classifier = nn.Linear(seq_embedding_dim + 1, out_dim)
+
+     def forward(self, seq_embedding, enm_input, attention_mask=None):
+         _ = attention_mask  # unused; accepted so all adaptors share one call signature
+         enm_input = enm_input.unsqueeze(-1)
+         combined_embedding = torch.cat((seq_embedding, enm_input), dim=-1)
+         logits = self.adapted_classifier(combined_embedding)
+         return logits
+
+ class ENMNoAdaptorClassifier(nn.Module):
+     def __init__(self, seq_embedding_dim, out_dim):
+         super(ENMNoAdaptorClassifier, self).__init__()
+         self.adapted_classifier = nn.Linear(seq_embedding_dim, out_dim)
+
+     def forward(self, seq_embedding, enm_input, attention_mask=None):
+         _ = (enm_input, attention_mask)  # ignoring the ENM input entirely
+         logits = self.adapted_classifier(seq_embedding)
+         return logits
+
+
+ class T5EncoderForTokenClassification(T5PreTrainedModel):
+
+     def __init__(self, config: T5Config, class_config):
+         super().__init__(config)
+         self.num_labels = class_config.num_labels
+         self.config = config
+         self.add_pearson_loss = class_config.add_pearson_loss
+         self.add_sse_loss = class_config.add_sse_loss
+         self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+         encoder_config = copy.deepcopy(config)
+         encoder_config.use_cache = False
+         encoder_config.is_encoder_decoder = False
+         self.encoder = T5Stack(encoder_config, self.shared)
+         # self.encoder = CustomT5Stack(encoder_config, self.shared)
+
+         original_embedding = self.encoder.embed_tokens
+         in_dim, out_dim = tuple(original_embedding.weight.shape)
+         # TODO: pass the correct weights! Careful: the embedding layer and the linear layer are mutually "transposed".
+         # The weights are set later, in load_finetuned_model, via original_embedding.weight.T.
+         self.new_embedding = nn.Linear(in_dim, out_dim, bias=False).to('cuda:0')
+         print("Initialized new_embedding layer - without weights yet!")
+
+         self.dropout = nn.Dropout(class_config.dropout_rate)
+         if class_config.adaptor_architecture == 'attention':
+             self.classifier = ENMAdaptedAttentionClassifier(config.hidden_size, class_config.num_labels, class_config.enm_embed_dim, class_config.enm_att_heads)
+         elif class_config.adaptor_architecture == 'direct':
+             self.classifier = ENMAdaptedDirectClassifier(config.hidden_size, class_config.num_labels)
+         elif class_config.adaptor_architecture == 'conv':
+             self.classifier = ENMAdaptedConvClassifier(config.hidden_size, class_config.num_labels, class_config.kernel_size, class_config.enm_embed_dim, class_config.num_layers)
+         elif class_config.adaptor_architecture == 'no-adaptor':
+             self.classifier = ENMNoAdaptorClassifier(config.hidden_size, class_config.num_labels)
+         else:
+             raise ValueError('Only attention, direct, conv and no-adaptor architectures are supported for the adaptor.')
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+         # Model parallel
+         self.model_parallel = False
+         self.device_map = None
+
+     def parallelize(self, device_map=None):
+         self.device_map = (
+             get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
+             if device_map is None
+             else device_map
+         )
+         assert_device_map(self.device_map, len(self.encoder.block))
+         self.encoder.parallelize(self.device_map)
+         self.classifier = self.classifier.to(self.encoder.first_device)
+         self.model_parallel = True
+
+     def deparallelize(self):
+         self.encoder.deparallelize()
+         self.encoder = self.encoder.to("cpu")
+         self.model_parallel = False
+         self.device_map = None
+         torch.cuda.empty_cache()
+
+     def get_input_embeddings(self):
+         return self.shared
+
+     def set_input_embeddings(self, new_embeddings):
+         self.shared = new_embeddings
+         self.encoder.set_input_embeddings(new_embeddings)
+
+     def get_encoder(self):
+         return self.encoder
+
+     def _prune_heads(self, heads_to_prune):
+         """
+         Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+         class PreTrainedModel
+         """
+         for layer, heads in heads_to_prune.items():
+             self.encoder.layer[layer].attention.prune_heads(heads)
+
+     def forward(
+         self,
+         enm_vals=None,
+         input_ids=None,
+         attention_mask=None,
+         head_mask=None,
+         inputs_embeds=None,
+         labels=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         if inputs_embeds is not None:
+             outputs = self.encoder(
+                 input_ids=None,
+                 attention_mask=attention_mask,
+                 inputs_embeds=inputs_embeds,
+                 head_mask=head_mask,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+             )
+         elif input_ids is not None:
+             outputs = self.encoder(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 inputs_embeds=None,
+                 head_mask=head_mask,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+             )
+         else:
+             raise ValueError("Either input_ids or inputs_embeds has to be provided.")
+         sequence_output = outputs[0]
+         # TODO: CHECK EVERYTHING IS IN EVAL MODE and the dropout below is OFF
+         sequence_output = self.dropout(sequence_output)
+         # TODO: check the enm_vals are padded properly and check that the sequence limit (in the transformer) is indeed 512
+
+         logits = self.classifier(sequence_output, enm_vals, attention_mask)
+         # The loss is computed by the Trainer (see ENMAdaptedTrainer below), not here.
+         loss = None
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return TokenClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
348
+ class ENMAdaptedTrainer(Trainer):
349
+ def compute_loss(self, model, inputs, return_outputs=False):
350
+ labels = inputs.get("labels")
351
+ #enm_vals = inputs.get("enm_vals")
352
+
353
+ outputs = model(**inputs)
354
+ logits = outputs.get('logits')
355
+ mask = inputs.get('attention_mask')
356
+ loss_fct = MSELoss()
357
+
358
+ active_loss = mask.view(-1) == 1
359
+ active_logits = logits.view(-1)
360
+ active_labels = torch.where(active_loss, labels.view(-1), torch.tensor(-100).type_as(labels))
361
+ valid_logits=active_logits[active_labels!=-100]
362
+ valid_labels=active_labels[active_labels!=-100]
363
+
364
+ loss = loss_fct(valid_labels, valid_logits)
365
+ return (loss, outputs) if return_outputs else loss
366
+
367
+
368
+
369
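The label masking in compute_loss can be verified in isolation (a standalone sketch with made-up values; padded positions and -100 labels are excluded from the MSE):

import torch
from torch.nn import MSELoss

logits = torch.tensor([[0.5, 0.2, 0.9], [0.1, 0.4, 0.0]])
labels = torch.tensor([[0.4, 0.0, -100.0], [0.2, -100.0, -100.0]])
mask = torch.tensor([[1, 1, 0], [1, 0, 0]])

active = mask.view(-1) == 1
flat_labels = torch.where(active, labels.view(-1), torch.tensor(-100.0))
valid = flat_labels != -100
loss = MSELoss()(logits.view(-1)[valid], flat_labels[valid])
print(loss)  # MSE over the three unpadded, labelled positions only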
+ def PT5_classification_model(half_precision, class_config):
+     # Load PT5 and tokenizer
+     # possible to load the half precision model (thanks to @pawel-rezo for pointing that out)
+     if not half_precision:
+         model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
+         tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50")
+     elif half_precision and torch.cuda.is_available():
+         tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
+         model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc", torch_dtype=torch.float16).to(torch.device('cuda'))
+     else:
+         raise ValueError('Half precision can be run on GPU only.')
+
+     # Create new classifier model with PT5 dimensions
+     class_model = T5EncoderForTokenClassification(model.config, class_config)
+
+     # Set encoder and embedding weights to checkpoint weights
+     class_model.shared = model.shared
+     class_model.encoder = model.encoder
+
+     # Delete the checkpoint model
+     model = class_model
+     del class_model
+
+     # Print number of trainable parameters
+     model_parameters = filter(lambda p: p.requires_grad, model.parameters())
+     params = sum([np.prod(p.size()) for p in model_parameters])
+     print("ProtT5_Classifier\nTrainable Parameter: " + str(params))
+
+     # Add model modification LoRA
+     config = LoRAConfig()
+
+     # Add LoRA layers
+     model = modify_with_lora(model, config)
+
+     # Freeze embeddings and encoder (except LoRA)
+     for (param_name, param) in model.shared.named_parameters():
+         param.requires_grad = False
+     for (param_name, param) in model.encoder.named_parameters():
+         param.requires_grad = False
+
+     for (param_name, param) in model.named_parameters():
+         if re.fullmatch(config.trainable_param_names, param_name):
+             param.requires_grad = True
+
+     # Print trainable parameters
+     model_parameters = filter(lambda p: p.requires_grad, model.parameters())
+     params = sum([np.prod(p.size()) for p in model_parameters])
+     print("ProtT5_LoRA_Classifier\nTrainable Parameter: " + str(params) + "\n")
+
+     return model, tokenizer
+
+
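Typical usage is then a sketch like the following (the conv settings mirror the ClassConfig defaults, the Rostlab checkpoint is downloaded on first use, and a CUDA device is assumed by the model code):

class_config = ClassConfig(adaptor_architecture='conv', enm_embed_dim=512, kernel_size=3, num_layers=2)
model, tokenizer = PT5_classification_model(half_precision=False, class_config=class_config)
# Only the LoRA and layer-norm parameters remain trainable, as the printout confirms.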
+ @dataclass
+ class DataCollatorForTokenRegression(DataCollatorMixin):
+     """
+     Data collator that will dynamically pad the inputs received, as well as the labels.
+     Args:
+         tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
+             The tokenizer used for encoding the data.
+         padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+             Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
+             among:
+             - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
+               sequence is provided).
+             - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+               acceptable input length for the model if that argument is not provided.
+             - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
+         max_length (`int`, *optional*):
+             Maximum length of the returned list and optionally padding length (see above).
+         pad_to_multiple_of (`int`, *optional*):
+             If set will pad the sequence to a multiple of the provided value.
+             This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
+             7.5 (Volta).
+         label_pad_token_id (`int`, *optional*, defaults to -100):
+             The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
+         return_tensors (`str`):
+             The type of Tensor to return. Allowable values are "np", "pt" and "tf".
+     """
+
+     tokenizer: PreTrainedTokenizerBase
+     padding: Union[bool, str, PaddingStrategy] = True
+     max_length: Optional[int] = None
+     pad_to_multiple_of: Optional[int] = None
+     label_pad_token_id: int = -100
+     return_tensors: str = "pt"
+
+     def torch_call(self, features):
+         label_name = "label" if "label" in features[0].keys() else "labels"
+         labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
+
+         no_labels_features = [{k: v for k, v in feature.items() if k != label_name and k != 'enm_vals'} for feature in features]
+
+         batch = self.tokenizer.pad(
+             no_labels_features,
+             padding=self.padding,
+             max_length=self.max_length,
+             pad_to_multiple_of=self.pad_to_multiple_of,
+             return_tensors="pt",
+         )
+
+         # Pad the per-residue ENM values with zeros up to the batch maximum length
+         batch['enm_vals'] = torch.nn.utils.rnn.pad_sequence([torch.tensor(feature['enm_vals'], dtype=torch.float) for feature in features], batch_first=True, padding_value=0.0)
+         if labels is None:
+             return batch
+
+         sequence_length = batch["input_ids"].shape[1]
+         padding_side = self.tokenizer.padding_side
+
+         def to_list(tensor_or_iterable):
+             if isinstance(tensor_or_iterable, torch.Tensor):
+                 return tensor_or_iterable.tolist()
+             return list(tensor_or_iterable)
+
+         if padding_side == "right":
+             batch[label_name] = [
+                 to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
+             ]
+         else:
+             batch[label_name] = [
+                 [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
+             ]
+
+         batch[label_name] = torch.tensor(batch[label_name], dtype=torch.float)
+         return batch
+
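A sketch of feeding the collator (the feature fields match what torch_call expects; the three-residue sequence and its values are made up):

import torch
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50")
collator = DataCollatorForTokenRegression(tokenizer)
enc = tokenizer(" ".join("MKT"))  # residues are space-separated for ProtT5
features = [{"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"],
             "labels": [0.1, 0.3, 0.2], "enm_vals": [0.5, 0.7, 0.6]}]
batch = collator(features)  # DataCollatorMixin dispatches to torch_call for "pt"
print(batch["input_ids"].shape, batch["enm_vals"].shape, batch["labels"].shape)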
+ def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
+     """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
+     # Tensorize if necessary.
+     if isinstance(examples[0], (list, tuple, np.ndarray)):
+         examples = [torch.tensor(e, dtype=torch.long) for e in examples]
+
+     length_of_first = examples[0].size(0)
+
+     # Check if padding is necessary.
+     are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
+     if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
+         return torch.stack(examples, dim=0)
+
+     # If not, check if we have a `pad_token`.
+     if tokenizer._pad_token is None:
+         raise ValueError(
+             "You are attempting to pad samples but the tokenizer you are using"
+             f" ({tokenizer.__class__.__name__}) does not have a pad token."
+         )
+
+     # Creating the full tensor and filling it with our data.
+     max_length = max(x.size(0) for x in examples)
+     if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+         max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+     result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
+     for i, example in enumerate(examples):
+         if tokenizer.padding_side == "right":
+             result[i, : example.shape[0]] = example
+         else:
+             result[i, -example.shape[0]:] = example
+     return result
+
+ def tolist(x):
+     if isinstance(x, list):
+         return x
+     elif hasattr(x, "numpy"):  # Checks for TF tensors without needing the import
+         x = x.numpy()
+     return x.tolist()
+
+ #### END OF UTILS
+
+ def do_topology_split(df, split_path):
+     import json
+     with open(split_path, 'r') as f:
+         splits = json.load(f)
+
+     # split the dataframe according to the splits
+     train_df = df[df['name'].isin(splits['train'])]
+     valid_df = df[df['name'].isin(splits['validation'])]
+     test_df = df[df['name'].isin(splits['test'])]
+     return train_df, valid_df, test_df
+
+
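do_topology_split expects a JSON file with "train"/"validation"/"test" lists of chain names; a self-contained sketch with hypothetical names:

import json
import pandas as pd

with open("splits.json", "w") as f:
    json.dump({"train": ["1abcA"], "validation": ["2defB"], "test": ["3ghiC"]}, f)

df = pd.DataFrame({"name": ["1abcA", "2defB", "3ghiC"], "seq": ["MKT", "GAV", "LLP"]})
train_df, valid_df, test_df = do_topology_split(df, "splits.json")
print(len(train_df), len(valid_df), len(test_df))  # 1 1 1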
+ class ANMAwareFlexibilityProtTrans(nn.Module):
+     def __init__(self, gumbel_temperature, **kwargs):
+         super(ANMAwareFlexibilityProtTrans, self).__init__()
+
+         model, tokenizer = self.load_finetuned_model(**kwargs)
+         self.model = model
+         self.tokenizer = tokenizer
+         self.device = torch.device('cuda')
+         self.model.to(self.device)
+         self.model.eval()
+         self.gumbel_temperature = gumbel_temperature
+         # Use the Straight-Through Gumbel-Softmax: the forward pass takes an argmax (hard one-hot),
+         # while the backward pass approximates the gradient of argmax by the gradient of the Gumbel-Softmax.
+         # https://pytorch.org/docs/stable/generated/torch.nn.functional.gumbel_softmax.html (set hard=True for the straight-through trick)
+         self.logit_transform = nn.functional.gumbel_softmax
+
+         self.conversion_tensor = self.construct_pmpnn_t5_conversion_tensor()
+
+     def construct_pmpnn_t5_conversion_tensor(self):
+         """
+         Creates a tensor which takes the one-hot encodings in the ProteinMPNN vocabulary and maps them to the ProtTrans vocabulary.
+         """
+         _one_hots = []
+         for idx in [[0, 1], 2, [3, 29, 30, 31, 32], 5, 4, 6, 7, 8, 10, 9, 13, 11, 12, 14, 15, 18, 16, 17, 19, 20, 21, 22, 23, 24, 25, 28, 26, 27]:
+             if isinstance(idx, int):
+                 _oh = F.one_hot(torch.tensor([idx]), 33)
+             else:
+                 _sohs = []
+                 for subidx in idx:
+                     _soh = F.one_hot(torch.tensor([subidx]), 33)
+                     _sohs.append(_soh)
+                 _oh = torch.sum(torch.stack(_sohs), dim=0)
+             _one_hots.append(_oh)
+         # Pad with zero rows up to the ProtTrans vocabulary size of 128
+         _one_hots.extend([torch.zeros((1, 33)) for _ in range(100)])
+         return torch.cat(_one_hots, dim=0).to(torch.device('cuda')).float()
+
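A quick check of the mapping (the method never reads self, so it can be called unbound here; a CUDA device is required, as the module assumes throughout):

import torch
import torch.nn.functional as F

conv = ANMAwareFlexibilityProtTrans.construct_pmpnn_t5_conversion_tensor(None)  # (128, 33)
oh = F.one_hot(torch.tensor([[0]]), 33).permute(0, 2, 1).float().cuda()  # ProteinMPNN token 0, shape (1, 33, 1)
t5 = torch.einsum('ej,ijk->iek', conv, oh)  # (1, 128, 1)
print(t5[0, :, 0].nonzero().flatten())  # the ProtT5 index that ProteinMPNN token 0 maps to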
+     def load_finetuned_model(self, checkpoint_path, half_precision, **kwargs):
+         class_config = ClassConfig(**kwargs)
+         model, tokenizer = PT5_classification_model(half_precision=half_precision, class_config=class_config)
+
+         state_dict = torch.load(checkpoint_path, map_location='cuda:0')
+         model.load_state_dict(state_dict, strict=False)
+         model.eval()
+
+         # The new embedding layer is linear, so it takes the transpose of the original embedding weights
+         original_embedding = model.encoder.embed_tokens
+         model.new_embedding.weight = nn.Parameter(original_embedding.weight.T)
+         print('Set the weights for the new embedding layer!')
+         return model, tokenizer
+
+     def translate_to_model_vocab(self, batch_one_hot, trail_idcs):
+         # Pad the batch_one_hot tensor with zeros along the last dimension
+         batch_one_hot = F.pad(batch_one_hot, (0, 1, 0, 0, 0, 0), 'constant', 0)
+
+         # TODO: verify that the gradients are OK after the masked_scatter_ operation
+         # Create a mask for the '2' token
+         mask = torch.zeros_like(batch_one_hot, dtype=torch.bool)
+         for i, trail_idx in enumerate(trail_idcs):
+             if trail_idx < batch_one_hot.size(2):  # Ensure index is within bounds
+                 mask[i, :, trail_idx] = True
+
+         # Create a tensor with '2' in the one-hot encoding
+         token_2 = torch.zeros_like(batch_one_hot)
+         token_2[:, 2, :] = 1  # Assuming '2' corresponds to index 2 in the one-hot encoding
+
+         # Use masked_scatter_ to modify the tensor in-place while preserving gradients
+         batch_one_hot.masked_scatter_(mask, token_2[mask])
+
+         T5_translation = torch.einsum('ej,ijk->iek', self.conversion_tensor, batch_one_hot)
+         T5_translation = T5_translation.permute(0, 2, 1)
+         return T5_translation
+
+     def forward(self, pmpnn_logits, anm_input, trail_idcs, attention_mask, sampled_pmpnn_sequence=None, alphabet=None):
+         # pmpnn_logits batch example: 32 x 33 x 395 (batch_size x ProteinMPNN vocab size x seq length)
+
+         # Pad the ANM values and the attention mask for the trailing end-of-sequence token
+         anm_input = F.pad(anm_input, (0, 1, 0, 0), 'constant', 0)
+         attention_mask = F.pad(attention_mask, (0, 1, 0, 0), 'constant', 1)
+
+         if sampled_pmpnn_sequence is None:
+             if alphabet is None:
+                 batch_one_hot = self.logit_transform(pmpnn_logits, tau=self.gumbel_temperature, hard=True, dim=1)
+                 batch_token_ids = self.translate_to_model_vocab(batch_one_hot, trail_idcs)
+                 inputs = batch_token_ids
+         # elif alphabet == 'aa':
+         #     batch_one_hot = ...  # TODO: one-hot encode the pmpnn tokens
+         #     batch_token_ids = self.translate_to_model_vocab(batch_one_hot, trail_idcs)
+         #     input_ids = ...  # TODO: argmax to get the tokens from the one-hot encodings
+         #     outputs = self.model(input_ids=input_ids, enm_vals=anm_input, attention_mask=attention_mask)
+         #     predicted_flex = outputs.logits
+         #     return {'predicted_flex': predicted_flex, 'enm_vals': anm_input, 'input_ids': input_ids}
+         # elif alphabet is None:
+         #     raise ValueError('need to specify what alphabet is used to encode sampled_pmpnn_sequence!')
+         # elif alphabet == 'pmpnn':
+         #     # Convert sampled_pmpnn_sequence to one-hot encoding
+         #     batch_one_hot = F.one_hot(sampled_pmpnn_sequence, num_classes=33).float().permute(0, 2, 1)
+         #     batch_token_ids = self.translate_to_model_vocab(batch_one_hot, trail_idcs)
+         #     inputs = batch_token_ids
+         # elif alphabet == 'pt5':
+         #     inputs = F.one_hot(sampled_pmpnn_sequence, num_classes=128).float()
+         # elif alphabet == 'aa':
+         #     tokens = self.tokenizer(" ".join(sampled_pmpnn_sequence))
+         #     input_ids = torch.tensor(tokens['input_ids']).cuda().unsqueeze(0)
+         #     outputs = self.model(input_ids=input_ids, enm_vals=anm_input, attention_mask=attention_mask)
+         #     predicted_flex = outputs.logits
+         #     return {'predicted_flex': predicted_flex, 'enm_vals': anm_input, 'input_ids': input_ids}
+
+         # Pass the (soft) one-hot tokens through the linear embedding so gradients can keep flowing
+         inputs_embeds = self.model.new_embedding(inputs)
+         outputs = self.model(enm_vals=anm_input, inputs_embeds=inputs_embeds, attention_mask=attention_mask)
+
+         predicted_flex = outputs.logits
+         return {'predicted_flex': predicted_flex, 'enm_vals': anm_input, 'input_ids': inputs}
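The straight-through behaviour of the Gumbel-Softmax used above can be seen on toy logits (hard=True returns exact one-hot samples in the forward pass while gradients flow through the soft relaxation):

import torch
import torch.nn.functional as F

logits = torch.randn(2, 33, 5, requires_grad=True)  # (B, vocab, length), as for pmpnn_logits
one_hot = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=1)
print(one_hot.sum(dim=1))       # exactly 1.0 at every position: hard one-hot samples
one_hot.sum().backward()
print(logits.grad is not None)  # True: gradients reach the logits via the soft sample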