Jacob Gershon committed on
Commit
59a9ccf
0 Parent(s):
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +34 -0
  2. LICENSE +21 -0
  3. README.md +17 -0
  4. app.py +404 -0
  5. examples/aa_weights.json +22 -0
  6. examples/binder_design.sh +16 -0
  7. examples/loop_design.sh +15 -0
  8. examples/motif_scaffolding.sh +14 -0
  9. examples/out/design_000000.pdb +0 -0
  10. examples/out/design_000000.trb +0 -0
  11. examples/partial_diffusion.sh +15 -0
  12. examples/pdbs/G12D_manual_mut.pdb +0 -0
  13. examples/pdbs/cd86.pdb +0 -0
  14. examples/pdbs/rsv5_5tpn.pdb +0 -0
  15. examples/secondary_structure.sh +21 -0
  16. examples/secondary_structure_bias.sh +15 -0
  17. examples/secondary_structure_from_pdb.sh +21 -0
  18. examples/symmetric_design.sh +16 -0
  19. examples/weighted_sequence.sh +15 -0
  20. examples/weighted_sequence_json.sh +16 -0
  21. model/.ipynb_checkpoints/RoseTTAFoldModel-checkpoint.py +140 -0
  22. model/Attention_module.py +411 -0
  23. model/AuxiliaryPredictor.py +92 -0
  24. model/Embeddings.py +307 -0
  25. model/RoseTTAFoldModel.py +140 -0
  26. model/SE3_network.py +83 -0
  27. model/Track_module.py +476 -0
  28. model/__pycache__/Attention_module.cpython-310.pyc +0 -0
  29. model/__pycache__/AuxiliaryPredictor.cpython-310.pyc +0 -0
  30. model/__pycache__/Embeddings.cpython-310.pyc +0 -0
  31. model/__pycache__/RoseTTAFoldModel.cpython-310.pyc +0 -0
  32. model/__pycache__/SE3_network.cpython-310.pyc +0 -0
  33. model/__pycache__/Track_module.cpython-310.pyc +0 -0
  34. model/__pycache__/ab_tools.cpython-310.pyc +0 -0
  35. model/__pycache__/apply_masks.cpython-310.pyc +0 -0
  36. model/__pycache__/arguments.cpython-310.pyc +0 -0
  37. model/__pycache__/chemical.cpython-310.pyc +0 -0
  38. model/__pycache__/data_loader.cpython-310.pyc +0 -0
  39. model/__pycache__/diffusion.cpython-310.pyc +0 -0
  40. model/__pycache__/kinematics.cpython-310.pyc +0 -0
  41. model/__pycache__/loss.cpython-310.pyc +0 -0
  42. model/__pycache__/mask_generator.cpython-310.pyc +0 -0
  43. model/__pycache__/parsers.cpython-310.pyc +0 -0
  44. model/__pycache__/scheduler.cpython-310.pyc +0 -0
  45. model/__pycache__/scoring.cpython-310.pyc +0 -0
  46. model/__pycache__/train_multi_deep.cpython-310.pyc +0 -0
  47. model/__pycache__/train_multi_deep_selfcond_nostruc.cpython-310.pyc +0 -0
  48. model/__pycache__/util.cpython-310.pyc +0 -0
  49. model/__pycache__/util_module.cpython-310.pyc +0 -0
  50. model/apply_masks.py +196 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 RosettaCommons
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: PROTEIN GENERATOR
3
+ emoji: 🧪
4
+ thumbnail: http://files.ipd.uw.edu/pub/sequence_diffusion/figs/diffusion_landscape.png
5
+ colorFrom: blue
6
+ colorTo: purple
7
+ sdk: gradio
8
+ sdk_version: 3.24.1
9
+ app_file: app.py
10
+ pinned: false
11
+ ---
12
+
13
+ ![fig](./utils/figs/thumbnail_github.png)
14
+
15
+ ## Code Accessibility
16
+
17
+ To download code and for more details please visit the [github](https://github.com/RosettaCommons/protein_generator)!
app.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os,sys
2
+
3
+ # install environment goods
4
+ #os.system("pip -q install dgl -f https://data.dgl.ai/wheels/cu113/repo.html")
5
+ os.system('pip install dgl==1.0.2+cu116 -f https://data.dgl.ai/wheels/cu116/repo.html')
6
+ #os.system('pip install gradio')
7
+ os.environ["DGLBACKEND"] = "pytorch"
8
+ #os.system(f'pip install -r ./PROTEIN_GENERATOR/requirements.txt')
9
+ print('Modules installed')
10
+
11
+ os.system('pip install --force gradio==3.28.3')
12
+
13
+ os.environ["DGLBACKEND"] = "pytorch"
14
+
15
+ if not os.path.exists('./SEQDIFF_230205_dssp_hotspots_25mask_EQtasks_mod30.pt'):
16
+ print('Downloading model weights 1')
17
+ os.system('wget http://files.ipd.uw.edu/pub/sequence_diffusion/checkpoints/SEQDIFF_230205_dssp_hotspots_25mask_EQtasks_mod30.pt')
18
+ print('Successfully Downloaded')
19
+
20
+ if not os.path.exists('./SEQDIFF_221219_equalTASKS_nostrSELFCOND_mod30.pt'):
21
+ print('Downloading model weights 2')
22
+ os.system('wget http://files.ipd.uw.edu/pub/sequence_diffusion/checkpoints/SEQDIFF_221219_equalTASKS_nostrSELFCOND_mod30.pt')
23
+ print('Successfully Downloaded')
24
+
25
+ import numpy as np
26
+ import gradio as gr
27
+ import py3Dmol
28
+ from io import StringIO
29
+ import json
30
+ import secrets
31
+ import copy
32
+ import matplotlib.pyplot as plt
33
+ from utils.sampler import HuggingFace_sampler
34
+
35
+ plt.rcParams.update({'font.size': 13})
36
+
37
+ with open('./tmp/args.json','r') as f:
38
+ args = json.load(f)
39
+
40
+ # manually set checkpoint to load
41
+ args['checkpoint'] = None
42
+ args['dump_trb'] = False
43
+ args['dump_args'] = True
44
+ args['save_best_plddt'] = True
45
+ args['T'] = 25
46
+ args['strand_bias'] = 0.0
47
+ args['loop_bias'] = 0.0
48
+ args['helix_bias'] = 0.0
49
+
50
+
51
+
52
def protein_diffusion_model(sequence, seq_len, helix_bias, strand_bias, loop_bias,
                            secondary_structure, aa_bias, aa_bias_potential,
                            #target_charge, target_ph, charge_potential,
                            num_steps, noise, hydrophobic_target_score, hydrophobic_potential):
    """Run one sequence-diffusion trajectory and stream intermediate results.

    This is a generator: it yields once per diffusion step and once more with
    the final outputs, so the UI can update while sampling is in progress.

    Args:
        sequence: Starting sequence. When empty or None, ``seq_len`` is used
            to request a fully diffused protein instead.
        seq_len: Protein length, only used when ``sequence`` is empty.
        helix_bias, strand_bias, loop_bias: Secondary-structure bias values
            forwarded to the sampler unchanged.
        secondary_structure: Optional per-residue string. ``S`` is rewritten
            to ``E``; the string is truncated or padded (by repeating its
            last character) to the protein length.
        aa_bias: Optional amino-acid composition string (e.g. ``W0.2``).
        aa_bias_potential: Scale of the aa_bias potential; defaults to 3.
        num_steps: Number of diffusion steps ``T``; defaults to 20.
        noise: Name of the noise distribution; matched by substring against
            ``normal``, ``gmm2`` and ``gmm3``.
        hydrophobic_target_score: Optional hydrophobic target score.
        hydrophobic_potential: Scale of the hydrophobic potential; defaults
            to 3.

    Yields:
        Tuples ``(output_seq, output_pdb, viewer_html, plddt_figure)`` where
        ``output_pdb`` is the path handed to ``display_pdb``.
    """
    empty = ['', None]

    dssp_checkpoint = './SEQDIFF_230205_dssp_hotspots_25mask_EQtasks_mod30.pt'
    og_checkpoint = './SEQDIFF_221219_equalTASKS_nostrSELFCOND_mod30.pt'

    # work on a private copy so concurrent requests do not share settings
    model_args = copy.deepcopy(args)

    # make sampler
    S = HuggingFace_sampler(args=model_args)

    # get random prefix
    S.out_prefix = './tmp/' + secrets.token_hex(nbytes=10).upper()

    # set args
    S.args['checkpoint'] = None
    S.args['dump_trb'] = False
    S.args['dump_args'] = True
    S.args['save_best_plddt'] = True
    S.args['T'] = 20
    S.args['strand_bias'] = 0.0
    S.args['loop_bias'] = 0.0
    S.args['helix_bias'] = 0.0
    S.args['potentials'] = None
    S.args['potential_scale'] = None
    S.args['aa_composition'] = None

    # get sequence if entered; ambiguous one-letter codes are resolved by
    # picking one of their possible residues at random
    alt_aa_dict = {'B': ['D', 'N'], 'J': ['I', 'L'], 'U': ['C'], 'Z': ['E', 'Q'], 'O': ['K']}
    if sequence not in empty:
        L = len(sequence)
        aa_seq = []
        for aa in sequence.upper():
            if aa in alt_aa_dict:
                aa_seq.append(np.random.choice(alt_aa_dict[aa]))
            else:
                aa_seq.append(aa)

        S.args['sequence'] = aa_seq
    else:
        S.args['contigs'] = [f'{seq_len}']
        L = int(seq_len)

    if secondary_structure in empty:
        secondary_structure = None
    else:
        secondary_structure = ''.join(['E' if x == 'S' else x for x in secondary_structure])
        if L < len(secondary_structure):
            # Truncate to the protein length L. The original sliced with
            # len(sequence), which is 0 (or a TypeError for None) whenever the
            # length slider is used instead of a custom sequence.
            secondary_structure = secondary_structure[:L]
        elif L > len(secondary_structure):
            # pad by repeating the last specified state
            dseq = L - len(secondary_structure)
            secondary_structure += secondary_structure[-1] * dseq

    # potentials
    potential_list = []
    potential_bias_list = []

    if aa_bias not in empty:
        potential_list.append('aa_bias')
        S.args['aa_composition'] = aa_bias
        if aa_bias_potential in empty:
            aa_bias_potential = 3
        potential_bias_list.append(str(aa_bias_potential))

    # NOTE: the charge potential (target_charge / target_ph / charge_potential)
    # is currently disabled, matching the commented-out parameters above.

    if hydrophobic_target_score not in empty:
        potential_list.append('hydrophobic')
        S.args['hydrophobic_score'] = float(hydrophobic_target_score)
        if hydrophobic_potential in empty:
            hydrophobic_potential = 3
        potential_bias_list.append(str(hydrophobic_potential))

    if len(potential_list) > 0:
        S.args['potentials'] = ','.join(potential_list)
        S.args['potential_scale'] = ','.join(potential_bias_list)

    # secondary structure conditioning; bias values are passed through as-is
    S.args['secondary_structure'] = secondary_structure
    S.args['helix_bias'] = helix_bias
    S.args['strand_bias'] = strand_bias
    S.args['loop_bias'] = loop_bias

    # set T
    if num_steps in empty:
        S.args['T'] = 20
    else:
        S.args['T'] = int(num_steps)

    # noise: the selected label is stored verbatim and the mixture
    # parameters are chosen by substring match
    if 'normal' in noise:
        S.args['sample_distribution'] = noise
        S.args['sample_distribution_gmm_means'] = [0]
        S.args['sample_distribution_gmm_variances'] = [1]
    elif 'gmm2' in noise:
        S.args['sample_distribution'] = noise
        S.args['sample_distribution_gmm_means'] = [-1, 1]
        S.args['sample_distribution_gmm_variances'] = [1, 1]
    elif 'gmm3' in noise:
        S.args['sample_distribution'] = noise
        S.args['sample_distribution_gmm_means'] = [-1, 0, 1]
        S.args['sample_distribution_gmm_variances'] = [1, 1, 1]

    # the dssp checkpoint is only needed when any secondary-structure
    # conditioning is requested
    if secondary_structure not in empty or helix_bias + strand_bias + loop_bias > 0:
        S.args['checkpoint'] = dssp_checkpoint
        S.args['d_t1d'] = 29
        print('using dssp checkpoint')
    else:
        S.args['checkpoint'] = og_checkpoint
        S.args['d_t1d'] = 24
        print('using og checkpoint')

    for k, v in S.args.items():
        print(f"{k} --> {v}")

    # init S
    S.model_init()
    S.diffuser_init()
    S.setup()

    # sampling loop: stream every intermediate step to the caller
    plddt_data = []
    for j in range(S.max_t):
        output_seq, output_pdb, plddt = S.take_step_get_outputs(j)
        plddt_data.append(plddt)
        yield output_seq, output_pdb, display_pdb(output_pdb), get_plddt_plot(plddt_data, S.max_t)

    output_seq, output_pdb, plddt = S.get_outputs()

    yield output_seq, output_pdb, display_pdb(output_pdb), get_plddt_plot(plddt_data, S.max_t)
202
+
203
def get_plddt_plot(plddt_data, max_t):
    """Plot the confidence values collected so far against the step number.

    Args:
        plddt_data: Sequence of values, one per completed diffusion step.
        max_t: Total number of steps; fixes the x-axis tick positions.

    Returns:
        The matplotlib figure holding the plot.
    """
    steps = list(range(1, len(plddt_data) + 1))
    figure, axis = plt.subplots(figsize=(15, 6))
    axis.plot(steps, plddt_data, color='#661dbf', linewidth=3, marker='o')

    # one x tick per diffusion step, y ticks every 0.1 up to 1.0
    axis.set_xticks(list(range(1, max_t + 1)))
    axis.set_yticks([tick / 10 for tick in range(1, 11)])
    axis.set_ylim([0, 1])

    axis.set_ylabel('model confidence (plddt)')
    axis.set_xlabel('diffusion steps (t)')
    return figure
213
+
214
def display_pdb(path_to_pdb):
    """Build an HTML iframe that shows a PDB file in a py3Dmol viewer.

    Args:
        path_to_pdb: Path of the PDB file to read.

    Returns:
        An HTML string containing an ``<iframe>`` whose ``srcdoc`` embeds the
        py3Dmol viewer markup.
    """
    # use a context manager so the handle is closed on every call; this
    # function runs once per diffusion step
    with open(path_to_pdb, "r") as handle:
        pdb = handle.read()

    view = py3Dmol.view(width=500, height=500)
    view.addModel(pdb, "pdb")
    # cartoon coloured by the b-factor column on a 0-1 gradient
    view.setStyle({'model': -1}, {"cartoon": {'colorscheme': {'prop': 'b', 'gradient': 'roygb', 'min': 0, 'max': 1}}})
    view.zoomTo()

    # srcdoc below is delimited by single quotes, so none may remain inside
    output = view._make_html().replace("'", '"')
    x = f"""<!DOCTYPE html><html><center> {output} </center></html>"""  # do not use ' in this input

    return f"""<iframe height="500px" width="100%" name="result" allow="midi; geolocation; microphone; camera;
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms
    allow-scripts allow-same-origin allow-popups
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
    allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
233
+
234
+ '''
235
+
236
+ return f"""<iframe style="width: 100%; height:700px" name="result" allow="midi; geolocation; microphone; camera;
237
+ display-capture; encrypted-media;" sandbox="allow-modals allow-forms
238
+ allow-scripts allow-same-origin allow-popups
239
+ allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
240
+ allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
241
+ '''
242
+
243
def toggle_seq_input(choice):
    """Switch between the length input and the custom-sequence input.

    Returns a pair of Gradio updates (first, second); exactly one is made
    visible and both values are reset. Any other ``choice`` returns None.
    """
    if choice not in ("protein length", "custom sequence"):
        return None

    use_length = choice == "protein length"
    return (gr.update(visible=use_length, value=None),
            gr.update(visible=not use_length, value=None))
248
+
249
def toggle_secondary_structure(choice):
    """Switch between the three bias sliders and the explicit text input.

    Returns four Gradio updates: the first three share one visibility, the
    fourth has the opposite one; all values are reset. Any other ``choice``
    returns None.
    """
    if choice == "sliders":
        sliders_visible = True
    elif choice == "explicit":
        sliders_visible = False
    else:
        return None

    slider_updates = [gr.update(visible=sliders_visible, value=None) for _ in range(3)]
    return (*slider_updates, gr.update(visible=not sliders_visible, value=None))
254
+
255
+ # Define the Gradio interface
256
+ with gr.Blocks(theme='ParityError/Interstellar') as demo:
257
+
258
+ gr.Markdown(f"""# Protein Generation via Diffusion in Sequence Space""")
259
+
260
+ with gr.Row():
261
+ with gr.Column(min_width=500):
262
+ gr.Markdown(f"""
263
+ ## How does it work?\n
264
+ --- [PREPRINT](https://biorxiv.org/content/10.1101/2023.05.08.539766v1) ---
265
+
266
+ Protein sequence and structure co-generation is a long-standing problem in the field of protein design. By implementing [ddpm](https://arxiv.org/abs/2006.11239) style diffusion over protein sequence space we generate protein sequence and structure pairs. Starting with [RoseTTAFold](https://www.science.org/doi/10.1126/science.abj8754), a protein structure prediction network, we finetuned it to predict sequence and structure given a partially noised sequence. By applying losses to both the predicted sequence and structure the model is forced to generate meaningful pairs. Diffusing in sequence space makes it easy to implement potentials to guide the diffusive process toward particular amino acid composition, net charge, and more! Furthermore, you can sample proteins from a family of sequences or even train a small sequence-to-function classifier to guide generation toward desired sequences.
267
+ ![fig1](http://files.ipd.uw.edu/pub/sequence_diffusion/figs/diffusion_landscape.png)
268
+
269
+ ## How to use it?\n
270
+ A user can either design a custom input sequence to diffuse from or specify a length below. To scaffold a sequence use the following format where X represent residues to diffuse: XXXXXXXXSCIENCESCIENCEXXXXXXXXXXXXXXXXXXX. You can even design a protein with your name XXXXXXXXXXXXNAMEHEREXXXXXXXXXXXXX!
271
+
272
+ ### Acknowledgements\n
273
+ Thank you to Simon Dürr and the Hugging Face team for setting us up with a community GPU grant!
274
+ """)
275
+
276
+ gr.Markdown("""
277
+ ## Model in Action
278
+ ![gif1](http://files.ipd.uw.edu/pub/sequence_diffusion/figs/seqdiff_anim_720p.gif)
279
+ """)
280
+
281
+ with gr.Row().style(equal_height=False):
282
+ with gr.Column():
283
+ gr.Markdown("""## INPUTS""")
284
+ gr.Markdown("""#### Start Sequence
285
+ Specify the protein length for complete unconditional generation, or scaffold a motif (or your name) using the custom sequence input""")
286
+ seq_opt = gr.Radio(["protein length","custom sequence"], label="How would you like to specify the starting sequence?", value='protein length')
287
+
288
+ sequence = gr.Textbox(label="custom sequence", lines=1, placeholder='AMINO ACIDS: A,C,D,E,F,G,H,I,K,L,M,N,P,Q,R,S,T,V,W,Y\n MASK TOKEN: X', visible=False)
289
+ seq_len = gr.Slider(minimum=5.0, maximum=250.0, label="protein length", value=100, visible=True)
290
+
291
+ seq_opt.change(fn=toggle_seq_input,
292
+ inputs=[seq_opt],
293
+ outputs=[seq_len, sequence],
294
+ queue=False)
295
+
296
+ gr.Markdown("""### Optional Parameters""")
297
+ with gr.Accordion(label='Secondary Structure',open=True):
298
+ gr.Markdown("""Try changing the sliders or inputing explicit secondary structure conditioning for each residue""")
299
+ sec_str_opt = gr.Radio(["sliders","explicit"], label="How would you like to specify secondary structure?", value='sliders')
300
+
301
+ secondary_structure = gr.Textbox(label="secondary structure", lines=1, placeholder='HELIX = H STRAND = S LOOP = L MASK = X(must be the same length as input sequence)', visible=False)
302
+
303
+ with gr.Column():
304
+ helix_bias = gr.Slider(minimum=0.0, maximum=0.05, label="helix bias", visible=True)
305
+ strand_bias = gr.Slider(minimum=0.0, maximum=0.05, label="strand bias", visible=True)
306
+ loop_bias = gr.Slider(minimum=0.0, maximum=0.20, label="loop bias", visible=True)
307
+
308
+ sec_str_opt.change(fn=toggle_secondary_structure,
309
+ inputs=[sec_str_opt],
310
+ outputs=[helix_bias,strand_bias,loop_bias,secondary_structure],
311
+ queue=False)
312
+
313
+ with gr.Accordion(label='Amino Acid Compositional Bias',open=False):
314
+ gr.Markdown("""Bias sequence composition for particular amino acids by specifying the one letter code followed by the fraction to bias. This can be input as a list for example: W0.2,E0.1""")
315
+ with gr.Row():
316
+ aa_bias = gr.Textbox(label="aa bias", lines=1, placeholder='specify one letter AA and fraction to bias, for example W0.1 or M0.1,K0.1' )
317
+ aa_bias_potential = gr.Textbox(label="aa bias scale", lines=1, placeholder='AA Bias potential scale (recomended range 1.0-5.0)')
318
+
319
+ '''
320
+ with gr.Accordion(label='Charge Bias',open=False):
321
+ gr.Markdown("""Bias for a specified net charge at a particular pH using the boxes below""")
322
+ with gr.Row():
323
+ target_charge = gr.Textbox(label="net charge", lines=1, placeholder='net charge to target')
324
+ target_ph = gr.Textbox(label="pH", lines=1, placeholder='pH at which net charge is desired')
325
+ charge_potential = gr.Textbox(label="charge potential scale", lines=1, placeholder='charge potential scale (recomended range 1.0-5.0)')
326
+ '''
327
+
328
+ with gr.Accordion(label='Hydrophobic Bias',open=False):
329
+ gr.Markdown("""Bias for or against hydrophobic composition, to get more soluble proteins, bias away with a negative target score (ex. -5)""")
330
+ with gr.Row():
331
+ hydrophobic_target_score = gr.Textbox(label="hydrophobic score", lines=1, placeholder='hydrophobic score to target (negative score is good for solublility)')
332
+ hydrophobic_potential = gr.Textbox(label="hydrophobic potential scale", lines=1, placeholder='hydrophobic potential scale (recomended range 1.0-2.0)')
333
+
334
+ with gr.Accordion(label='Diffusion Params',open=False):
335
+ gr.Markdown("""Increasing T to more steps can be helpful for harder design challenges, sampling from different distributions can change the sequence and structural composition""")
336
+ with gr.Row():
337
+ num_steps = gr.Textbox(label="T", lines=1, placeholder='number of diffusion steps (25 or less will speed things up)')
338
+ noise = gr.Dropdown(['normal','gmm2 [-1,1]','gmm3 [-1,0,1]'], label='noise type', value='normal')
339
+
340
+ btn = gr.Button("GENERATE")
341
+
342
+ #with gr.Row():
343
+ with gr.Column():
344
+ gr.Markdown("""## OUTPUTS""")
345
+ gr.Markdown("""#### Confidence score for generated structure at each timestep""")
346
+ plddt_plot = gr.Plot(label='plddt at step t')
347
+ gr.Markdown("""#### Output protein sequnece""")
348
+ output_seq = gr.Textbox(label="sequence")
349
+ gr.Markdown("""#### Download PDB file""")
350
+ output_pdb = gr.File(label="PDB file")
351
+ gr.Markdown("""#### Structure viewer""")
352
+ output_viewer = gr.HTML()
353
+
354
+ gr.Markdown("""### Don't know where to get started? Click on an example below to try it out!""")
355
+ gr.Examples(
356
+ [["","125",0.0,0.0,0.2,"","","","20","normal",'',''],
357
+ ["","100",0.0,0.0,0.0,"","W0.2","2","20","normal",'',''],
358
+ ["","100",0.0,0.0,0.0,"XXHHHHHHHHHXXXXXXXHHHHHHHHHXXXXXXXHHHHHHHHXXXXSSSSSSSSSSSXXXXXXXXSSSSSSSSSSSSXXXXXXXSSSSSSSSSXXXXXXX","","","25","normal",'',''],
359
+ ["XXXXXXXXXXXXXXXXXXXXXXXXXIPDXXXXXXXXXXXXXXXXXXXXXXPEPSEQXXXXXXXXXXXXXXXXXXXXXXXXXXIPDXXXXXXXXXXXXXXXXXXX","",0.0,0.0,0.0,"","","","25","normal",'','']],
360
+ inputs=[sequence,
361
+ seq_len,
362
+ helix_bias,
363
+ strand_bias,
364
+ loop_bias,
365
+ secondary_structure,
366
+ aa_bias,
367
+ aa_bias_potential,
368
+ #target_charge,
369
+ #target_ph,
370
+ #charge_potential,
371
+ num_steps,
372
+ noise,
373
+ hydrophobic_target_score,
374
+ hydrophobic_potential],
375
+ outputs=[output_seq,
376
+ output_pdb,
377
+ output_viewer,
378
+ plddt_plot],
379
+ fn=protein_diffusion_model,
380
+ )
381
+ btn.click(protein_diffusion_model,
382
+ [sequence,
383
+ seq_len,
384
+ helix_bias,
385
+ strand_bias,
386
+ loop_bias,
387
+ secondary_structure,
388
+ aa_bias,
389
+ aa_bias_potential,
390
+ #target_charge,
391
+ #target_ph,
392
+ #charge_potential,
393
+ num_steps,
394
+ noise,
395
+ hydrophobic_target_score,
396
+ hydrophobic_potential],
397
+ [output_seq,
398
+ output_pdb,
399
+ output_viewer,
400
+ plddt_plot])
401
+
402
+ demo.queue()
403
+ demo.launch(debug=True)
404
+
examples/aa_weights.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "A": 0,
3
+ "R": 0,
4
+ "N": 0,
5
+ "D": 0,
6
+ "C": 0,
7
+ "Q": 0,
8
+ "E": 0,
9
+ "G": 0,
10
+ "H": 0,
11
+ "I": 0,
12
+ "L": 0,
13
+ "K": 0,
14
+ "M": 0,
15
+ "F": 0,
16
+ "P": 0,
17
+ "S": 0,
18
+ "T": 0,
19
+ "W": 0,
20
+ "Y": 0,
21
+ "V": 0
22
+ }
examples/binder_design.sh ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #SBATCH -J seq_diff
3
+ #SBATCH -p gpu
4
+ #SBATCH --mem=8g
5
+ #SBATCH --gres=gpu:a6000:1
6
+ #SBATCH -o ./out/slurm/slurm_%j.out
7
+
8
+ source activate /software/conda/envs/SE3nv
9
+
10
+ srun python ../inference.py \
11
+ --num_designs 10 \
12
+ --out out/binder_design \
13
+ --pdb pdbs/cd86.pdb \
14
+ --T 25 --save_best_plddt \
15
+ --contigs B1-110,0 25-75 \
16
+ --hotspots B40,B32,B87,B96,B30
examples/loop_design.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #SBATCH -J seq_diff
3
+ #SBATCH -p gpu
4
+ #SBATCH --mem=8g
5
+ #SBATCH --gres=gpu:a6000:1
6
+ #SBATCH -o ./out/slurm/slurm_%j.out
7
+
8
+ source activate /software/conda/envs/SE3nv
9
+
10
+ srun python ../inference.py \
11
+ --num_designs 10 \
12
+ --pdb pdbs/G12D_manual_mut.pdb \
13
+ --out out/ab_loop \
14
+ --contigs A2-176,0 C7-16,0 H2-95,12-15,H111-116,0 L1-45,10-12,L56-107 \
15
+ --T 25 --save_best_plddt --loop_design
examples/motif_scaffolding.sh ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #SBATCH -J seq_diff
3
+ #SBATCH -p gpu
4
+ #SBATCH --mem=8g
5
+ #SBATCH --gres=gpu:a6000:1
6
+ #SBATCH -o ./out/slurm/slurm_%j.out
7
+
8
+ source activate /software/conda/envs/SE3nv
9
+
10
+ srun python ../inference.py \
11
+ --num_designs 10 \
12
+ --out out/design \
13
+ --pdb pdbs/rsv5_5tpn.pdb \
14
+ --contigs 0-25,A163-181,25-30 --T 25 --save_best_plddt
examples/out/design_000000.pdb ADDED
The diff for this file is too large to render. See raw diff
 
examples/out/design_000000.trb ADDED
Binary file (3.51 kB). View file
 
examples/partial_diffusion.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #SBATCH -J seq_diff
3
+ #SBATCH -p gpu
4
+ #SBATCH --mem=8g
5
+ #SBATCH --gres=gpu:a6000:1
6
+ #SBATCH -o ./out/slurm/slurm_%j.out
7
+
8
+ source activate /software/conda/envs/SE3nv
9
+
10
+ srun python ../inference.py \
11
+ --num_designs 10 \
12
+ --pdb out/design_000.pdb \
13
+ --trb out/design_000.trb \
14
+ --out out/partial_diffusion_design \
15
+ --contigs 0 --sampling_temp 0.3 --T 50 --save_best_plddt
examples/pdbs/G12D_manual_mut.pdb ADDED
The diff for this file is too large to render. See raw diff
 
examples/pdbs/cd86.pdb ADDED
The diff for this file is too large to render. See raw diff
 
examples/pdbs/rsv5_5tpn.pdb ADDED
The diff for this file is too large to render. See raw diff
 
examples/secondary_structure.sh ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #SBATCH -J seq_diff
3
+ #SBATCH -p gpu
4
+ #SBATCH --mem=8g
5
+ #SBATCH --gres=gpu:a6000:1
6
+ #SBATCH -o ./out/slurm/slurm_%j.out
7
+
8
+ source activate /software/conda/envs/SE3nv
9
+
10
+ srun python ../inference.py \
11
+ --num_designs 10 \
12
+ --out out/design \
13
+ --contigs 100 \
14
+ --T 25 --save_best_plddt \
15
+ --secondary_structure XXXXXHHHHXXXLLLXXXXXXXXXXHHHHXXXLLLXXXXXXXXXXHHHHXXXLLLXXXXXXXXXXHHHHXXXLLLXXXXXXXXXXHHHHXXXLLLXXXXX
16
+
17
+ # FOR SECONDARY STRUCTURE:
18
+ # X - mask
19
+ # H - helix
20
+ # E - strand
21
+ # L - loop
examples/secondary_structure_bias.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #SBATCH -J seq_diff
3
+ #SBATCH -p gpu
4
+ #SBATCH --mem=8g
5
+ #SBATCH --gres=gpu:a6000:1
6
+ #SBATCH -o ./out/slurm/slurm_%j.out
7
+
8
+ source activate /software/conda/envs/SE3nv
9
+
10
+ srun python ../inference.py \
11
+ --num_designs 10 \
12
+ --out out/design \
13
+ --contigs 100 \
14
+ --T 25 --save_best_plddt \
15
+ --helix_bias 0.01 --strand_bias 0.01 --loop_bias 0.0
examples/secondary_structure_from_pdb.sh ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #SBATCH -J seq_diff
3
+ #SBATCH -p gpu
4
+ #SBATCH --mem=8g
5
+ #SBATCH --gres=gpu:a6000:1
6
+ #SBATCH -o ./out/slurm/slurm_%j.out
7
+
8
+ source activate /software/conda/envs/SE3nv
9
+
10
+ srun python ../inference.py \
11
+ --num_designs 10 \
12
+ --out out/design \
13
+ --contigs 110 \
14
+ --T 25 --save_best_plddt \
15
+ --dssp_pdb ./pdbs/cd86.pdb
16
+
17
+ # FOR SECONDARY STRUCTURE:
18
+ # X - mask
19
+ # H - helix
20
+ # E - strand
21
+ # L - loop
examples/symmetric_design.sh ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #SBATCH -J seq_diff
3
+ #SBATCH -p gpu
4
+ #SBATCH --mem=8g
5
+ #SBATCH --gres=gpu:a6000:1
6
+ #SBATCH -o ./out/slurm/slurm_%j.out
7
+
8
+ source activate /software/conda/envs/SE3nv
9
+
10
+ srun python ../inference.py \
11
+ --num_designs 10 \
12
+ --out out/symmetric_design \
13
+ --contigs 25,0 25,0 25,0 \
14
+ --T 50 \
15
+ --save_best_plddt \
16
+ --symmetry 3
examples/weighted_sequence.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #SBATCH -J seq_diff
3
+ #SBATCH -p gpu
4
+ #SBATCH --mem=8g
5
+ #SBATCH --gres=gpu:a6000:1
6
+ #SBATCH -o ./out/slurm/slurm_%j.out
7
+
8
+ source activate /software/conda/envs/SE3nv
9
+
10
+ srun python ../inference.py \
11
+ --num_designs 10 \
12
+ --out out/design \
13
+ --contigs 100 \
14
+ --T 25 --save_best_plddt \
15
+ --aa_composition W0.2 --potential_scale 1.75
examples/weighted_sequence_json.sh ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
#SBATCH -J seq_diff
#SBATCH -p gpu
#SBATCH --mem=8g
#SBATCH --gres=gpu:a6000:1
#SBATCH -o ./out/slurm/slurm_%j.out

source activate /software/conda/envs/SE3nv

# Design 75-residue proteins with per-residue amino-acid weights read from
# aa_weights.json. --add_weight_every_n is given exactly once with its value;
# the original repeated the flag with no value right before --T.
srun python ../inference.py \
    --num_designs 10 \
    --out out/design \
    --contigs 75 \
    --aa_weights_json aa_weights.json \
    --add_weight_every_n 5 \
    --T 25 --save_best_plddt
model/.ipynb_checkpoints/RoseTTAFoldModel-checkpoint.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from Embeddings import MSA_emb, Extra_emb, Templ_emb, Recycling
4
+ from Track_module import IterativeSimulator
5
+ from AuxiliaryPredictor import DistanceNetwork, MaskedTokenNetwork, ExpResolvedNetwork, LDDTNetwork
6
+ from util import INIT_CRDS
7
+ from opt_einsum import contract as einsum
8
+ from icecream import ic
9
+
10
class RoseTTAFoldModule(nn.Module):
    """Top-level network: input embeddings, recycling, iterative simulator and
    auxiliary prediction heads.

    NOTE(review): this file sits under ``.ipynb_checkpoints`` -- it looks like a
    notebook autosave copy; confirm it is meant to be tracked.
    """

    # NOTE(review): the SE3_param_* defaults are shared dict objects. They are
    # only read here; confirm that downstream modules never mutate them.
    def __init__(self, n_extra_block=4, n_main_block=8, n_ref_block=4,
                 d_msa=256, d_msa_full=64, d_pair=128, d_templ=64,
                 n_head_msa=8, n_head_pair=4, n_head_templ=4,
                 d_hidden=32, d_hidden_templ=64,
                 p_drop=0.15, d_t1d=24, d_t2d=44,
                 SE3_param_full={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
                 SE3_param_topk={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
                 ):
        super(RoseTTAFoldModule, self).__init__()
        #
        # Input Embeddings
        d_state = SE3_param_topk['l0_out_features']
        self.latent_emb = MSA_emb(d_msa=d_msa, d_pair=d_pair, d_state=d_state, p_drop=p_drop)
        self.full_emb = Extra_emb(d_msa=d_msa_full, d_init=25, p_drop=p_drop)
        self.templ_emb = Templ_emb(d_pair=d_pair, d_templ=d_templ, d_state=d_state,
                                   n_head=n_head_templ,
                                   d_hidden=d_hidden_templ, p_drop=0.25, d_t1d=d_t1d, d_t2d=d_t2d)
        # Update inputs with outputs from previous round
        self.recycle = Recycling(d_msa=d_msa, d_pair=d_pair, d_state=d_state)
        #
        self.simulator = IterativeSimulator(n_extra_block=n_extra_block,
                                            n_main_block=n_main_block,
                                            n_ref_block=n_ref_block,
                                            d_msa=d_msa, d_msa_full=d_msa_full,
                                            d_pair=d_pair, d_hidden=d_hidden,
                                            n_head_msa=n_head_msa,
                                            n_head_pair=n_head_pair,
                                            SE3_param_full=SE3_param_full,
                                            SE3_param_topk=SE3_param_topk,
                                            p_drop=p_drop)
        ##
        # Auxiliary prediction heads
        self.c6d_pred = DistanceNetwork(d_pair, p_drop=p_drop)
        self.aa_pred = MaskedTokenNetwork(d_msa, p_drop=p_drop)
        self.lddt_pred = LDDTNetwork(d_state)

        self.exp_pred = ExpResolvedNetwork(d_msa, d_state)

    def forward(self, msa_latent, msa_full, seq, xyz, idx,
                seq1hot=None, t1d=None, t2d=None, xyz_t=None, alpha_t=None,
                msa_prev=None, pair_prev=None, state_prev=None,
                return_raw=False, return_full=False,
                use_checkpoint=False, return_infer=False):
        """Run one pass of the network.

        Recycled inputs (``msa_prev``, ``pair_prev``, ``state_prev``) default
        to zeros shaped like the freshly computed embeddings.

        Returns, depending on the flags:
          * ``return_raw``:   (msa[:,0], pair, xyz, state, alpha_s[-1])
          * ``return_infer``: (logits, logits_aa, logits_exp, xyz, lddt,
                               msa[:,0], pair, state, alpha_s[-1])
          * otherwise:        (logits, logits_aa, logits_exp, xyz, alpha_s, lddt)
            where xyz holds the structures of all simulator iterations.

        ``return_full`` is accepted but not used in this method.
        """
        B, _, L = msa_latent.shape[:3]

        # Get embeddings
        idx = idx.long()
        msa_latent, pair, state = self.latent_emb(msa_latent, seq, idx, seq1hot=seq1hot)
        msa_full = self.full_emb(msa_full, seq, idx, seq1hot=seq1hot)

        # Do recycling. Identity checks are used on purpose: `== None` on a
        # tensor goes through tensor comparison instead of testing for absence.
        if msa_prev is None:
            msa_prev = torch.zeros_like(msa_latent[:,0])
        if pair_prev is None:
            pair_prev = torch.zeros_like(pair)
        if state_prev is None:
            state_prev = torch.zeros_like(state)

        msa_recycle, pair_recycle, state_recycle = self.recycle(seq, msa_prev, pair_prev, xyz, state_prev)
        msa_latent[:,0] = msa_latent[:,0] + msa_recycle.reshape(B,L,-1)
        pair = pair + pair_recycle
        state = state + state_recycle

        # add template embedding
        pair, state = self.templ_emb(t1d, t2d, alpha_t, xyz_t, pair, state, use_checkpoint=use_checkpoint)

        # Predict coordinates from given inputs (only the first three atoms
        # of each residue in xyz are passed on)
        msa, pair, R, T, alpha_s, state = self.simulator(seq, msa_latent, msa_full.type(torch.float32), pair, xyz[:,:,:3],
                                                         state, idx, use_checkpoint=use_checkpoint)

        if return_raw:
            # get last structure
            xyz = einsum('bnij,bnaj->bnai', R[-1], xyz[:,:,:3]-xyz[:,:,1].unsqueeze(-2)) + T[-1].unsqueeze(-2)
            return msa[:,0], pair, xyz, state, alpha_s[-1]

        # predict masked amino acids
        logits_aa = self.aa_pred(msa)

        # predict distogram & orientograms
        logits = self.c6d_pred(pair)

        # Predict LDDT
        lddt = self.lddt_pred(state)

        # predict experimentally resolved or not
        logits_exp = self.exp_pred(msa[:,0], state)

        if return_infer:
            # get last structure
            xyz = einsum('bnij,bnaj->bnai', R[-1], xyz[:,:,:3]-xyz[:,:,1].unsqueeze(-2)) + T[-1].unsqueeze(-2)
            return logits, logits_aa, logits_exp, xyz, lddt, msa[:,0], pair, state, alpha_s[-1]

        # get all intermediate bb structures
        xyz = einsum('rbnij,bnaj->rbnai', R, xyz[:,:,:3]-xyz[:,:,1].unsqueeze(-2)) + T.unsqueeze(-2)

        return logits, logits_aa, logits_exp, xyz, alpha_s, lddt
model/Attention_module.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from opt_einsum import contract as einsum
6
+ from util_module import init_lecun_normal
7
+ from icecream import ic
8
+
9
class FeedForwardLayer(nn.Module):
    """LayerNorm followed by a two-layer ReLU MLP (hidden width d_model*r_ff).

    The output projection is zero-initialised, so a freshly built layer
    returns zeros.
    """

    def __init__(self, d_model, r_ff, p_drop=0.1):
        super().__init__()
        d_inner = d_model*r_ff
        self.norm = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(d_model, d_inner)
        self.dropout = nn.Dropout(p_drop)
        self.linear2 = nn.Linear(d_inner, d_model)

        self.reset_parameter()

    def reset_parameter(self):
        # layer feeding the ReLU: He (kaiming normal) weights, zero bias
        nn.init.kaiming_normal_(self.linear1.weight, nonlinearity='relu')
        nn.init.zeros_(self.linear1.bias)

        # layer feeding the residual connection: all zeros
        for param in (self.linear2.weight, self.linear2.bias):
            nn.init.zeros_(param)

    def forward(self, src):
        hidden = self.linear1(self.norm(src))
        hidden = self.dropout(F.relu_(hidden))
        return self.linear2(hidden)
32
+
33
class Attention(nn.Module):
    """Multi-head dot-product attention with separate query and key/value inputs.

    ``p_drop`` is accepted but not used by this class.
    """

    def __init__(self, d_query, d_key, n_head, d_hidden, d_out, p_drop=0.1):
        super().__init__()
        d_inner = n_head*d_hidden
        self.h = n_head
        self.dim = d_hidden
        # bias-free projections into the per-head space
        self.to_q = nn.Linear(d_query, d_inner, bias=False)
        self.to_k = nn.Linear(d_key, d_inner, bias=False)
        self.to_v = nn.Linear(d_key, d_inner, bias=False)
        # output projection and 1/sqrt(d_hidden) scaling factor
        self.to_out = nn.Linear(d_inner, d_out)
        self.scaling = 1/math.sqrt(d_hidden)

        self.reset_parameter()

    def reset_parameter(self):
        # q/k/v projections: Glorot (Xavier) uniform
        for proj in (self.to_q, self.to_k, self.to_v):
            nn.init.xavier_uniform_(proj.weight)

        # output projection sits right before the residual connection:
        # zero-initialise so the residual update starts as the identity
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def forward(self, query, key, value):
        n_query = query.shape[1]
        n_batch, n_key = key.shape[:2]
        heads = (self.h, self.dim)

        q = self.to_q(query).reshape(n_batch, n_query, *heads) * self.scaling
        k = self.to_k(key).reshape(n_batch, n_key, *heads)
        v = self.to_v(value).reshape(n_batch, n_key, *heads)

        # softmax over the key axis
        weights = F.softmax(einsum('bqhd,bkhd->bhqk', q, k), dim=-1)
        mixed = einsum('bhqk,bkhd->bqhd', weights, v)

        return self.to_out(mixed.reshape(n_batch, n_query, self.h*self.dim))
78
+
79
class AttentionWithBias(nn.Module):
    """Gated multi-head self-attention whose logits receive an additive bias
    projected from a second input."""

    def __init__(self, d_in=256, d_bias=128, n_head=8, d_hidden=32):
        super().__init__()
        d_inner = n_head*d_hidden
        self.norm_in = nn.LayerNorm(d_in)
        self.norm_bias = nn.LayerNorm(d_bias)
        #
        self.to_q = nn.Linear(d_in, d_inner, bias=False)
        self.to_k = nn.Linear(d_in, d_inner, bias=False)
        self.to_v = nn.Linear(d_in, d_inner, bias=False)
        self.to_b = nn.Linear(d_bias, n_head, bias=False)  # one bias value per head
        self.to_g = nn.Linear(d_in, d_inner)               # gate
        self.to_out = nn.Linear(d_inner, d_in)

        self.scaling = 1/math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

        self.reset_parameter()

    def reset_parameter(self):
        # q/k/v projections: Glorot (Xavier) uniform
        for proj in (self.to_q, self.to_k, self.to_v):
            nn.init.xavier_uniform_(proj.weight)

        # bias projection: initialised by the project's init_lecun_normal helper
        self.to_b = init_lecun_normal(self.to_b)

        # gate: zero weights and unit bias, i.e. a mostly open gate at the start
        nn.init.zeros_(self.to_g.weight)
        nn.init.ones_(self.to_g.bias)

        # output projection precedes the residual connection: zero-initialise
        # so the residual update starts as the identity
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def forward(self, x, bias):
        n_batch, n_res = x.shape[:2]
        heads = (n_batch, n_res, self.h, self.dim)

        x = self.norm_in(x)
        bias = self.norm_bias(bias)

        q = self.to_q(x).reshape(heads)
        k = self.to_k(x).reshape(heads)
        v = self.to_v(x).reshape(heads)
        head_bias = self.to_b(bias) # (B, L, L, h)
        gate = torch.sigmoid(self.to_g(x))

        # the 1/sqrt(d) factor is applied on the key side
        logits = einsum('bqhd,bkhd->bqkh', q, k * self.scaling) + head_bias
        weights = F.softmax(logits, dim=-2)  # normalise over keys

        mixed = einsum('bqkh,bkhd->bqhd', weights, v).reshape(n_batch, n_res, -1)
        return self.to_out(gate * mixed)
137
+
138
+ # MSA Attention (row/column) from AlphaFold architecture
139
class SequenceWeight(nn.Module):
    """Per-position, per-head weights over MSA rows, obtained by attending
    from the first row (``msa[:,0]``) to every row."""

    def __init__(self, d_msa, n_head, d_hidden, p_drop=0.1):
        super().__init__()
        self.h = n_head
        self.dim = d_hidden
        self.scale = 1.0 / math.sqrt(self.dim)

        d_inner = n_head*d_hidden
        self.to_query = nn.Linear(d_msa, d_inner)
        self.to_key = nn.Linear(d_msa, d_inner)
        self.dropout = nn.Dropout(p_drop)

        self.reset_parameter()

    def reset_parameter(self):
        # query/key projection weights: Glorot (Xavier) uniform
        for proj in (self.to_query, self.to_key):
            nn.init.xavier_uniform_(proj.weight)

    def forward(self, msa):
        n_batch, n_seq, n_res = msa.shape[:3]

        first_row = msa[:,0]
        q = self.to_query(first_row).view(n_batch, 1, n_res, self.h, self.dim) * self.scale
        k = self.to_key(msa).view(n_batch, n_seq, n_res, self.h, self.dim)

        # softmax over the MSA-row axis (dim=1)
        weights = F.softmax(einsum('bqihd,bkihd->bkihq', q, k), dim=1)
        return self.dropout(weights)
169
+
170
class MSARowAttentionWithBias(nn.Module):
    """Gated attention over MSA positions, with one attention map shared by
    all rows and an additive bias projected from the pair features."""

    def __init__(self, d_msa=256, d_pair=128, n_head=8, d_hidden=32):
        super().__init__()
        d_inner = n_head*d_hidden
        self.norm_msa = nn.LayerNorm(d_msa)
        self.norm_pair = nn.LayerNorm(d_pair)
        #
        self.seq_weight = SequenceWeight(d_msa, n_head, d_hidden, p_drop=0.1)
        self.to_q = nn.Linear(d_msa, d_inner, bias=False)
        self.to_k = nn.Linear(d_msa, d_inner, bias=False)
        self.to_v = nn.Linear(d_msa, d_inner, bias=False)
        self.to_b = nn.Linear(d_pair, n_head, bias=False)
        self.to_g = nn.Linear(d_msa, d_inner)
        self.to_out = nn.Linear(d_inner, d_msa)

        self.scaling = 1/math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

        self.reset_parameter()

    def reset_parameter(self):
        # q/k/v projections: Glorot (Xavier) uniform
        for proj in (self.to_q, self.to_k, self.to_v):
            nn.init.xavier_uniform_(proj.weight)

        # bias projection: initialised by the project's init_lecun_normal helper
        self.to_b = init_lecun_normal(self.to_b)

        # gate: zero weights and unit bias, i.e. a mostly open gate at the start
        nn.init.zeros_(self.to_g.weight)
        nn.init.ones_(self.to_g.bias)

        # output projection precedes the residual connection: zero-initialise
        # so the residual update starts as the identity
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def forward(self, msa, pair): # TODO: make this as tied-attention
        n_batch, n_seq, n_res = msa.shape[:3]
        heads = (n_batch, n_seq, n_res, self.h, self.dim)

        msa = self.norm_msa(msa)
        pair = self.norm_pair(pair)

        # computed first: SequenceWeight applies dropout
        row_weight = self.seq_weight(msa) # (B, N, L, h, 1)
        q = self.to_q(msa).reshape(heads)
        k = self.to_k(msa).reshape(heads)
        v = self.to_v(msa).reshape(heads)
        head_bias = self.to_b(pair) # (B, L, L, h)
        gate = torch.sigmoid(self.to_g(msa))

        # weight each row's query, then contract over rows to get one shared map
        q = q * row_weight.expand(-1, -1, -1, -1, self.dim)
        logits = einsum('bsqhd,bskhd->bqkh', q, k * self.scaling) + head_bias
        weights = F.softmax(logits, dim=-2)  # normalise over keys

        mixed = einsum('bqkh,bskhd->bsqhd', weights, v).reshape(n_batch, n_seq, n_res, -1)
        return self.to_out(gate * mixed)
231
+
232
class MSAColAttention(nn.Module):
    """Gated multi-head attention across MSA rows, computed independently for
    every position."""

    def __init__(self, d_msa=256, n_head=8, d_hidden=32):
        super().__init__()
        d_inner = n_head*d_hidden
        self.norm_msa = nn.LayerNorm(d_msa)
        #
        self.to_q = nn.Linear(d_msa, d_inner, bias=False)
        self.to_k = nn.Linear(d_msa, d_inner, bias=False)
        self.to_v = nn.Linear(d_msa, d_inner, bias=False)
        self.to_g = nn.Linear(d_msa, d_inner)
        self.to_out = nn.Linear(d_inner, d_msa)

        self.scaling = 1/math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

        self.reset_parameter()

    def reset_parameter(self):
        # q/k/v projections: Glorot (Xavier) uniform
        for proj in (self.to_q, self.to_k, self.to_v):
            nn.init.xavier_uniform_(proj.weight)

        # gate: zero weights and unit bias, i.e. a mostly open gate at the start
        nn.init.zeros_(self.to_g.weight)
        nn.init.ones_(self.to_g.bias)

        # output projection precedes the residual connection: zero-initialise
        # so the residual update starts as the identity
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def forward(self, msa):
        n_batch, n_seq, n_res = msa.shape[:3]
        heads = (n_batch, n_seq, n_res, self.h, self.dim)

        msa = self.norm_msa(msa)

        q = self.to_q(msa).reshape(heads) * self.scaling
        k = self.to_k(msa).reshape(heads)
        v = self.to_v(msa).reshape(heads)
        gate = torch.sigmoid(self.to_g(msa))

        # attention between rows (q, k index the row axis) at each position i
        weights = F.softmax(einsum('bqihd,bkihd->bihqk', q, k), dim=-1)
        mixed = einsum('bihqk,bkihd->bqihd', weights, v).reshape(n_batch, n_seq, n_res, -1)

        return self.to_out(gate * mixed)
282
+
283
class MSAColGlobalAttention(nn.Module):
    """Gated attention across MSA rows that uses one row-averaged query per
    position and a single key/value projection shared by all heads."""

    def __init__(self, d_msa=64, n_head=8, d_hidden=8):
        super().__init__()
        d_inner = n_head*d_hidden
        self.norm_msa = nn.LayerNorm(d_msa)
        #
        self.to_q = nn.Linear(d_msa, d_inner, bias=False)
        self.to_k = nn.Linear(d_msa, d_hidden, bias=False)  # shared by all heads
        self.to_v = nn.Linear(d_msa, d_hidden, bias=False)  # shared by all heads
        self.to_g = nn.Linear(d_msa, d_inner)
        self.to_out = nn.Linear(d_inner, d_msa)

        self.scaling = 1/math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

        self.reset_parameter()

    def reset_parameter(self):
        # q/k/v projections: Glorot (Xavier) uniform
        for proj in (self.to_q, self.to_k, self.to_v):
            nn.init.xavier_uniform_(proj.weight)

        # gate: zero weights and unit bias, i.e. a mostly open gate at the start
        nn.init.zeros_(self.to_g.weight)
        nn.init.ones_(self.to_g.bias)

        # output projection precedes the residual connection: zero-initialise
        # so the residual update starts as the identity
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def forward(self, msa):
        n_batch, n_seq, n_res = msa.shape[:3]

        msa = self.norm_msa(msa)

        # average the per-row queries into one query per position
        q = self.to_q(msa).reshape(n_batch, n_seq, n_res, self.h, self.dim)
        q = q.mean(dim=1) * self.scaling      # (B, L, h, dim)
        k = self.to_k(msa)                    # (B, N, L, dim)
        v = self.to_v(msa)                    # (B, N, L, dim)
        gate = torch.sigmoid(self.to_g(msa))  # (B, N, L, h*dim)

        weights = F.softmax(einsum('bihd,bkid->bihk', q, k), dim=-1)  # (B, L, h, N)

        # one pooled vector per position, broadcast over rows by the gate
        pooled = einsum('bihk,bkid->bihd', weights, v).reshape(n_batch, 1, n_res, -1)  # (B, 1, L, h*dim)
        return self.to_out(gate * pooled)
334
+
335
+ # Instead of triangle attention, use Tied axail attention with bias from coordinates..?
336
class BiasedAxialAttention(nn.Module):
    """Gated, tied axial attention over pair features with an additive bias.

    With ``is_row=True`` the two residue axes of ``pair`` and ``bias`` are
    swapped before attention and swapped back afterwards.
    ``p_drop`` is accepted but not used by this class.
    """

    def __init__(self, d_pair, d_bias, n_head, d_hidden, p_drop=0.1, is_row=True):
        super(BiasedAxialAttention, self).__init__()
        #
        self.is_row = is_row
        self.norm_pair = nn.LayerNorm(d_pair)
        self.norm_bias = nn.LayerNorm(d_bias)

        self.to_q = nn.Linear(d_pair, n_head*d_hidden, bias=False)
        self.to_k = nn.Linear(d_pair, n_head*d_hidden, bias=False)
        self.to_v = nn.Linear(d_pair, n_head*d_hidden, bias=False)
        self.to_b = nn.Linear(d_bias, n_head, bias=False)
        self.to_g = nn.Linear(d_pair, n_head*d_hidden)
        self.to_out = nn.Linear(n_head*d_hidden, d_pair)

        self.scaling = 1/math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

        # initialize all parameters properly
        self.reset_parameter()

    def reset_parameter(self):
        # query/key/value projection: Glorot uniform / Xavier uniform
        nn.init.xavier_uniform_(self.to_q.weight)
        nn.init.xavier_uniform_(self.to_k.weight)
        nn.init.xavier_uniform_(self.to_v.weight)

        # bias projection: initialised by the project's init_lecun_normal helper
        self.to_b = init_lecun_normal(self.to_b)

        # gating: zero weights, one biases (mostly open gate at the beginning)
        nn.init.zeros_(self.to_g.weight)
        nn.init.ones_(self.to_g.bias)

        # to_out sits right before the residual connection: zero-initialise so
        # the residual update starts as the identity
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def forward(self, pair, bias, same_chain = None):
        """Apply tied axial attention to ``pair``.

        pair: (B, L, L, d_pair)
        bias: projected to (B, L, L, h) and added to the attention logits
        same_chain: optional mask; attention weights where it is False are
            multiplied by 1.1 after the softmax.
            NOTE(review): assumed to be a bool tensor matching the leading
            dims of the attention map, e.g. (B, L, L) -- confirm with callers.
        """
        # pair: (B, L, L, d_pair)
        B, L = pair.shape[:2]

        if self.is_row:
            pair = pair.permute(0,2,1,3)
            bias = bias.permute(0,2,1,3)

        pair = self.norm_pair(pair)
        bias = self.norm_bias(bias)

        query = self.to_q(pair).reshape(B, L, L, self.h, self.dim)
        key = self.to_k(pair).reshape(B, L, L, self.h, self.dim)
        value = self.to_v(pair).reshape(B, L, L, self.h, self.dim)
        bias = self.to_b(bias) # (B, L, L, h)
        gate = torch.sigmoid(self.to_g(pair)) # (B, L, L, h*dim)

        query = query * self.scaling
        key = key / math.sqrt(L) # normalize for tied attention
        attn = einsum('bnihk,bnjhk->bijh', query, key) # tied attention
        attn = attn + bias # apply bias
        attn = F.softmax(attn, dim=-2) # (B, L, L, h)

        if same_chain is not None:
            # Scale the selected weights out of place: softmax keeps its output
            # for the backward pass, so an in-place `attn[mask] *= 1.1` makes
            # backward() fail. Debug dumps of the tensors were removed too.
            inter_chain = ~same_chain
            pad = (1,) * (attn.dim() - inter_chain.dim())
            inter_chain = inter_chain.reshape(tuple(inter_chain.shape) + pad)
            attn = torch.where(inter_chain, attn * 1.1, attn)

        out = einsum('bijh,bkjhd->bikhd', attn, value).reshape(B, L, L, -1)
        out = gate * out

        out = self.to_out(out)
        if self.is_row:
            out = out.permute(0,2,1,3)
        return out
411
+
model/AuxiliaryPredictor.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
class DistanceNetwork(nn.Module):
    """Linear heads predicting distance/omega (symmetrised) and theta/phi
    logits from pair features.

    ``p_drop`` is accepted but not used by this class.
    """

    def __init__(self, n_feat, p_drop=0.1):
        super().__init__()
        #
        self.proj_symm = nn.Linear(n_feat, 37*2)    # dist (37) + omega (37)
        self.proj_asymm = nn.Linear(n_feat, 37+19)  # theta (37) + phi (19)

        self.reset_parameter()

    def reset_parameter(self):
        # final logit layers start at zero
        for proj in (self.proj_symm, self.proj_asymm):
            nn.init.zeros_(proj.weight)
        for proj in (self.proj_symm, self.proj_asymm):
            nn.init.zeros_(proj.bias)

    def forward(self, x):
        # x: pair info (B, L, L, C)

        # theta, phi: non-symmetric
        asymm = self.proj_asymm(x)
        theta, phi = asymm[:,:,:,:37], asymm[:,:,:,37:]

        # dist, omega: symmetrised by adding the transpose over the residue axes
        symm = self.proj_symm(x)
        symm = symm + symm.permute(0,2,1,3)
        dist, omega = symm[:,:,:,:37], symm[:,:,:,37:]

        # move the bin axis in front of the two residue axes
        logits_dist, logits_omega, logits_theta, logits_phi = (
            t.permute(0,3,1,2) for t in (dist, omega, theta, phi)
        )
        return logits_dist, logits_omega, logits_theta, logits_phi
35
+
36
class MaskedTokenNetwork(nn.Module):
    """Linear head producing 21 logits per MSA entry.

    ``p_drop`` is accepted but not used by this class.
    """

    def __init__(self, n_feat, p_drop=0.1):
        super().__init__()
        self.proj = nn.Linear(n_feat, 21)

        self.reset_parameter()

    def reset_parameter(self):
        # final logit layer starts at zero
        for param in (self.proj.weight, self.proj.bias):
            nn.init.zeros_(param)

    def forward(self, x):
        n_batch, n_seq, n_res = x.shape[:3]
        # (B, N, L, 21) -> (B, 21, N, L) -> (B, 21, N*L)
        per_token = self.proj(x).permute(0,3,1,2)
        return per_token.reshape(n_batch, -1, n_seq*n_res)
52
+
53
class LDDTNetwork(nn.Module):
    """Linear head producing ``n_bin_lddt`` logits per position."""

    def __init__(self, n_feat, n_bin_lddt=50):
        super().__init__()
        self.proj = nn.Linear(n_feat, n_bin_lddt)

        self.reset_parameter()

    def reset_parameter(self):
        # final logit layer starts at zero
        for param in (self.proj.weight, self.proj.bias):
            nn.init.zeros_(param)

    def forward(self, x):
        # (B, L, n_bin) -> (B, n_bin, L)
        return self.proj(x).permute(0,2,1)
68
+
69
class ExpResolvedNetwork(nn.Module):
    """Linear head producing one logit per position from the layer-normalised
    sequence and state features.

    ``p_drop`` is accepted but not used by this class.
    """

    def __init__(self, d_msa, d_state, p_drop=0.1):
        super().__init__()
        self.norm_msa = nn.LayerNorm(d_msa)
        self.norm_state = nn.LayerNorm(d_state)
        self.proj = nn.Linear(d_msa+d_state, 1)

        self.reset_parameter()

    def reset_parameter(self):
        # final logit layer starts at zero
        for param in (self.proj.weight, self.proj.bias):
            nn.init.zeros_(param)

    def forward(self, seq, state):
        n_batch, n_res = seq.shape[:2]

        # normalise each input separately, then concatenate on the feature axis
        features = torch.cat((self.norm_msa(seq), self.norm_state(state)), dim=-1)
        return self.proj(features).reshape(n_batch, n_res)
90
+
91
+
92
+
model/Embeddings.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from opt_einsum import contract as einsum
5
+ import torch.utils.checkpoint as checkpoint
6
+ from util import get_tips
7
+ from util_module import Dropout, create_custom_forward, rbf, init_lecun_normal
8
+ from Attention_module import Attention, FeedForwardLayer, AttentionWithBias
9
+ from Track_module import PairStr2Pair
10
+ from icecream import ic
11
+
12
+ # Module contains classes and functions to generate initial embeddings
13
+
14
class PositionalEncoding2D(nn.Module):
    """Adds a learned embedding of the binned residue-index separation to the
    pair features, followed by dropout."""

    def __init__(self, d_model, minpos=-32, maxpos=32, p_drop=0.1):
        super().__init__()
        self.minpos = minpos
        self.maxpos = maxpos
        self.nbin = abs(minpos)+maxpos+1
        self.emb = nn.Embedding(self.nbin, d_model)
        self.drop = nn.Dropout(p_drop)

    def forward(self, x, idx):
        # bin boundaries minpos .. maxpos-1
        boundaries = torch.arange(self.minpos, self.maxpos, device=x.device)
        # separation[b, i, j] = idx[b, j] - idx[b, i]   (B, L, L)
        separation = idx[:,None,:] - idx[:,:,None]
        bin_index = torch.bucketize(separation, boundaries).long() # (B, L, L)
        # add relative positional encoding, (B, L, L, d_model)
        return self.drop(x + self.emb(bin_index))
32
+
33
class MSA_emb(nn.Module):
    """Builds the initial seed-MSA, pair and state embeddings."""

    def __init__(self, d_msa=256, d_pair=128, d_state=32, d_init=22+22+2+2,
                 minpos=-32, maxpos=32, p_drop=0.1):
        super().__init__()
        self.emb = nn.Linear(d_init, d_msa)         # general MSA features
        self.emb_q = nn.Embedding(22, d_msa)        # query sequence, added to the MSA embedding
        self.emb_left = nn.Embedding(22, d_pair)    # query sequence, used for the pair embedding
        self.emb_right = nn.Embedding(22, d_pair)   # query sequence, used for the pair embedding
        self.emb_state = nn.Embedding(22, d_state)  # query sequence, used for the state embedding
        self.drop = nn.Dropout(p_drop)
        self.pos = PositionalEncoding2D(d_pair, minpos=minpos, maxpos=maxpos, p_drop=p_drop)

        self.reset_parameter()

    def reset_parameter(self):
        # every embedding goes through the project's init_lecun_normal helper
        for name in ('emb', 'emb_q', 'emb_left', 'emb_right', 'emb_state'):
            setattr(self, name, init_lecun_normal(getattr(self, name)))

        nn.init.zeros_(self.emb.bias)

    def forward(self, msa, seq, idx, seq1hot=None):
        # Inputs:
        #   - msa: Input MSA (B, N, L, d_init)
        #   - seq: Input Sequence (B, L)
        #   - idx: Residue index
        #   - seq1hot: optional; when given, the pair embedding is computed as
        #     seq1hot @ embedding weights instead of an index lookup
        # Outputs:
        #   - msa: Initial MSA embedding (B, N, L, d_msa)
        #   - pair: Initial Pair embedding (B, L, L, d_pair)
        #   - state: Initial state embedding
        n_seq = msa.shape[1]

        # MSA embedding: projected features plus the query embedding on every row
        msa = self.emb(msa)                          # (B, N, L, d_model)
        seq = seq.long()
        query_emb = self.emb_q(seq).unsqueeze(1)     # (B, 1, L, d_model)
        msa = self.drop(msa + query_emb.expand(-1, n_seq, -1, -1))

        # pair embedding: outer sum of the left/right embeddings
        if seq1hot is None:
            left = self.emb_left(seq)[:,None]        # (B, 1, L, d_pair)
            right = self.emb_right(seq)[:,:,None]    # (B, L, 1, d_pair)
        else:
            left = (seq1hot @ self.emb_left.weight)[:,None]     # (B, 1, L, d_pair)
            right = (seq1hot @ self.emb_right.weight)[:,:,None] # (B, L, 1, d_pair)
        pair = self.pos(left + right, idx)           # (B, L, L, d_pair), relative position added

        # state embedding
        state = self.drop(self.emb_state(seq))

        return msa, pair, state
91
+
92
class Extra_emb(nn.Module):
    """Builds the initial embedding of the extra (full) MSA."""

    def __init__(self, d_msa=256, d_init=22+1+2, p_drop=0.1):
        super().__init__()
        self.emb = nn.Linear(d_init, d_msa)   # general MSA features
        self.emb_q = nn.Embedding(22, d_msa)  # query sequence
        self.drop = nn.Dropout(p_drop)

        self.reset_parameter()

    def reset_parameter(self):
        # only the linear layer is re-initialised here
        self.emb = init_lecun_normal(self.emb)
        nn.init.zeros_(self.emb.bias)

    def forward(self, msa, seq, idx, seq1hot=None):
        # Inputs:
        #   - msa: Input MSA (B, N, L, d_init)
        #   - seq: Input Sequence (B, L)
        #   - idx: Residue index (not used in this method)
        #   - seq1hot: optional; when given, the query embedding is computed as
        #     seq1hot @ embedding weights instead of an index lookup
        # Outputs:
        #   - msa: Initial MSA embedding (B, N, L, d_msa)
        n_seq = msa.shape[1]
        msa = self.emb(msa)  # (B, N, L, d_model)

        if seq1hot is None:
            query_emb = self.emb_q(seq)
        else:
            query_emb = seq1hot @ self.emb_q.weight
        query_emb = query_emb.unsqueeze(1)  # (B, 1, L, d_model)

        # add the query embedding to every MSA row
        return self.drop(msa + query_emb.expand(-1, n_seq, -1, -1))
122
+
123
class TemplatePairStack(nn.Module):
    """Processes template pairwise features with a stack of PairStr2Pair
    blocks (structure-biased attention), followed by a LayerNorm."""

    def __init__(self, n_block=2, d_templ=64, n_head=4, d_hidden=16, p_drop=0.25):
        super().__init__()
        self.n_block = n_block
        self.block = nn.ModuleList([
            PairStr2Pair(d_pair=d_templ, n_head=n_head, d_hidden=d_hidden, p_drop=p_drop)
            for _ in range(n_block)
        ])
        self.norm = nn.LayerNorm(d_templ)

    def forward(self, templ, rbf_feat, use_checkpoint=False):
        n_batch, n_templ, n_res = templ.shape[:3]
        # fold the template axis into the batch axis
        templ = templ.reshape(n_batch*n_templ, n_res, n_res, -1)

        for i_block in range(self.n_block):
            layer = self.block[i_block]
            if use_checkpoint:
                templ = checkpoint.checkpoint(create_custom_forward(layer), templ, rbf_feat)
            else:
                templ = layer(templ, rbf_feat)

        # restore the separate template axis
        return self.norm(templ).reshape(n_batch, n_templ, n_res, n_res, -1)
142
+
143
class TemplateTorsionStack(nn.Module):
    """Updates template 1D (torsion) features with attention biased by the
    template pair features concatenated with ``rbf_feat``.

    NOTE(review): ``proj_pair`` takes d_templ+36 inputs, so ``rbf_feat``
    presumably has 36 channels -- confirm with the caller.
    """

    def __init__(self, n_block=2, d_templ=64, n_head=4, d_hidden=16, p_drop=0.15):
        super(TemplateTorsionStack, self).__init__()
        self.n_block=n_block
        self.proj_pair = nn.Linear(d_templ+36, d_templ)
        proc_s = [AttentionWithBias(d_in=d_templ, d_bias=d_templ,
                                    n_head=n_head, d_hidden=d_hidden) for i in range(n_block)]
        self.row_attn = nn.ModuleList(proc_s)
        proc_s = [FeedForwardLayer(d_templ, 4, p_drop=p_drop) for i in range(n_block)]
        self.ff = nn.ModuleList(proc_s)
        self.norm = nn.LayerNorm(d_templ)

        # reset_parameter() was defined but never called, so proj_pair kept
        # PyTorch's default init, unlike the other modules in this file.
        self.reset_parameter()

    def reset_parameter(self):
        self.proj_pair = init_lecun_normal(self.proj_pair)
        nn.init.zeros_(self.proj_pair.bias)

    def forward(self, tors, pair, rbf_feat, use_checkpoint=False):
        B, T, L = tors.shape[:3]
        # fold the template axis into the batch axis
        tors = tors.reshape(B*T, L, -1)
        pair = pair.reshape(B*T, L, L, -1)
        # append rbf_feat to the pair features and project back to d_templ
        pair = torch.cat((pair, rbf_feat), dim=-1)
        pair = self.proj_pair(pair)

        for i_block in range(self.n_block):
            # residual attention update (optionally checkpointed), then
            # residual feed-forward update
            if use_checkpoint:
                tors = tors + checkpoint.checkpoint(create_custom_forward(self.row_attn[i_block]), tors, pair)
            else:
                tors = tors + self.row_attn[i_block](tors, pair)
            tors = tors + self.ff[i_block](tors)
        # restore the separate template axis
        return self.norm(tors).reshape(B, T, L, -1)
173
+
174
class Templ_emb(nn.Module):
    """Mix template information into the query pair and state features.

    Feature layout assumed by the defaults:
      t2d (d_t2d = 43 + 1):
        - 37 distogram bins + 6 orientations (43)
        - mask for missing/unaligned positions (1)
      t1d (d_t1d = 21 + 1 + 1 + 1):
        - tiled amino-acid sequence, 20 standard aa + gap (21)
        - sequence confidence (1)
        - global time step (1)
        - structure confidence (1)
    """

    def __init__(self, d_t1d=21+1+1+1, d_t2d=43+1, d_tor=30, d_pair=128, d_state=32,
                 n_block=2, d_templ=64,
                 n_head=4, d_hidden=16, p_drop=0.25):
        super().__init__()
        # --- 2D (pair) track ---
        # Input is t2d concatenated with t1d tiled along rows and columns.
        self.emb = nn.Linear(d_t1d * 2 + d_t2d, d_templ)
        self.templ_stack = TemplatePairStack(n_block=n_block, d_templ=d_templ, n_head=n_head,
                                             d_hidden=d_hidden, p_drop=p_drop)
        self.attn = Attention(d_pair, d_templ, n_head, d_hidden, d_pair, p_drop=p_drop)

        # --- 1D (torsion) track ---
        self.emb_t1d = nn.Linear(d_t1d + d_tor, d_templ)
        self.proj_t1d = nn.Linear(d_templ, d_templ)
        self.attn_tor = Attention(d_state, d_templ, n_head, d_hidden, d_state, p_drop=p_drop)

        self.reset_parameter()

    def reset_parameter(self):
        """Initialise the embedding layers; every bias is zeroed."""
        self.emb = init_lecun_normal(self.emb)
        nn.init.zeros_(self.emb.bias)

        # emb_t1d feeds a ReLU, hence the He (kaiming) initialisation.
        nn.init.kaiming_normal_(self.emb_t1d.weight, nonlinearity='relu')
        nn.init.zeros_(self.emb_t1d.bias)

        self.proj_t1d = init_lecun_normal(self.proj_t1d)
        nn.init.zeros_(self.proj_t1d.bias)

    def _attend(self, attn, query, memory, use_checkpoint):
        """Attend from ``query`` to ``memory`` (used as both key and value)."""
        if use_checkpoint:
            return checkpoint.checkpoint(create_custom_forward(attn), query, memory, memory)
        return attn(query, memory, memory)

    def forward(self, t1d, t2d, alpha_t, xyz_t, pair, state, use_checkpoint=False):
        """Add template-derived updates to ``pair`` and ``state``.

        Args:
            t1d: 1D template features, shape (B, T, L, d_t1d).
            t2d: 2D template features, shape (B, T, L, L, d_t2d).
            alpha_t: torsion features concatenated onto ``t1d`` along the last
                axis (d_tor channels expected by ``emb_t1d``).
            xyz_t: template coordinates, reshaped here to (B*T, L, n_atom, 3).
            pair: query pair features, (B, L, L, d_pair).
            state: query state features, (B, L, d_state).
            use_checkpoint: enable gradient checkpointing in sub-modules.

        Returns:
            Tuple ``(pair, state)`` with the template updates added residually.
        """
        n_batch, n_templ, n_res, _ = t1d.shape

        # Build 2D template features: (B, T, L, L, d_t2d + 2*d_t1d).
        row_feat = t1d.unsqueeze(3).expand(-1, -1, -1, n_res, -1)
        col_feat = t1d.unsqueeze(2).expand(-1, -1, n_res, -1, -1)
        templ = self.emb(torch.cat((t2d, row_feat, col_feat), -1))  # (B, T, L, L, d_templ)

        # Distance features between atom slot 1 of every residue pair
        # (presumably C-alpha -- TODO confirm atom ordering).
        coords = xyz_t.reshape(n_batch * n_templ, n_res, -1, 3)
        anchor = coords[:, :, 1]
        rbf_feat = rbf(torch.cdist(anchor, anchor))
        templ = self.templ_stack(templ, rbf_feat, use_checkpoint=use_checkpoint)

        # Build 1D template features from t1d + torsions.
        tors_feat = torch.cat((t1d, alpha_t), dim=-1)
        tors_feat = self.proj_t1d(F.relu_(self.emb_t1d(tors_feat)))

        # Each residue's query state attends over its T template entries.
        state_q = state.reshape(n_batch * n_res, 1, -1)
        tors_mem = tors_feat.permute(0, 2, 1, 3).reshape(n_batch * n_res, n_templ, -1)
        state_upd = self._attend(self.attn_tor, state_q, tors_mem, use_checkpoint)
        state_upd = state_upd.reshape(n_batch, n_res, -1)
        state = state_q.reshape(n_batch, n_res, -1) + state_upd

        # Template pointwise attention: each residue pair attends over templates.
        pair_q = pair.reshape(n_batch * n_res * n_res, 1, -1)
        templ_mem = templ.permute(0, 2, 3, 1, 4).reshape(n_batch * n_res * n_res, n_templ, -1)
        pair_upd = self._attend(self.attn, pair_q, templ_mem, use_checkpoint)
        pair_upd = pair_upd.reshape(n_batch, n_res, n_res, -1)
        pair = pair_q.reshape(n_batch, n_res, n_res, -1) + pair_upd

        return pair, state
268
+
269
class Recycling(nn.Module):
    """Turn the previous iteration's outputs into inputs for the next one."""

    def __init__(self, d_msa=256, d_pair=128, d_state=32):
        super().__init__()
        # 36 distance-feature channels plus the state of both residues in a pair.
        self.proj_dist = nn.Linear(36 + d_state * 2, d_pair)
        self.norm_state = nn.LayerNorm(d_state)
        self.norm_pair = nn.LayerNorm(d_pair)
        self.norm_msa = nn.LayerNorm(d_msa)

        self.reset_parameter()

    def reset_parameter(self):
        """LeCun-normal weights and zero bias for the distance projection."""
        self.proj_dist = init_lecun_normal(self.proj_dist)
        nn.init.zeros_(self.proj_dist.bias)

    def forward(self, seq, msa, pair, xyz, state):
        """Normalise recycled features and add a structure-derived pair term.

        Args:
            seq: accepted for interface compatibility; not used here.
            msa: recycled MSA features (last dim d_msa).
            pair: recycled pair features, (B, L, L, d_pair).
            xyz: coordinates; atom slots 0/1/2 are read as N/Ca/C.
            state: recycled state features, (B, L, d_state).

        Returns:
            Tuple ``(msa, pair, state)`` of the processed features.
        """
        n_res = pair.shape[1]
        state = self.norm_state(state)

        # Broadcast the per-residue state to both sides of every residue pair.
        state_i = state.unsqueeze(2).expand(-1, -1, n_res, -1)
        state_j = state.unsqueeze(1).expand(-1, n_res, -1, -1)

        # Backbone anchor atoms.
        atom_n, atom_ca, atom_c = xyz[:, :, 0], xyz[:, :, 1], xyz[:, :, 2]

        # Rebuild a virtual Cb from the N, Ca, C frame.
        b = atom_ca - atom_n
        c = atom_c - atom_ca
        a = torch.cross(b, c, dim=-1)
        Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + atom_ca

        pair_geom = torch.cat((rbf(torch.cdist(Cb, Cb)), state_i, state_j), dim=-1)
        pair_geom = self.proj_dist(pair_geom)

        pair = pair_geom + self.norm_pair(pair)
        msa = self.norm_msa(msa)
        return msa, pair, state
307
+
model/RoseTTAFoldModel.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from Embeddings import MSA_emb, Extra_emb, Templ_emb, Recycling
4
+ from Track_module import IterativeSimulator
5
+ from AuxiliaryPredictor import DistanceNetwork, MaskedTokenNetwork, ExpResolvedNetwork, LDDTNetwork
6
+ from util import INIT_CRDS
7
+ from opt_einsum import contract as einsum
8
+ from icecream import ic
9
+
10
class RoseTTAFoldModule(nn.Module):
    """Top-level network: input embeddings, recycling, template mixing,
    the iterative three-track simulator, and the auxiliary prediction heads.
    """

    def __init__(self, n_extra_block=4, n_main_block=8, n_ref_block=4,
                 d_msa=256, d_msa_full=64, d_pair=128, d_templ=64,
                 n_head_msa=8, n_head_pair=4, n_head_templ=4,
                 d_hidden=32, d_hidden_templ=64,
                 p_drop=0.15, d_t1d=24, d_t2d=44,
                 SE3_param_full={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
                 SE3_param_topk={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
                 ):
        # NOTE: the dict defaults are only read (never modified) in this class.
        super().__init__()

        # Input embeddings. The state width follows the top-k SE(3) output width.
        d_state = SE3_param_topk['l0_out_features']
        self.latent_emb = MSA_emb(d_msa=d_msa, d_pair=d_pair, d_state=d_state, p_drop=p_drop)
        self.full_emb = Extra_emb(d_msa=d_msa_full, d_init=25, p_drop=p_drop)
        # The template embedding uses a fixed dropout of 0.25, independent of p_drop.
        self.templ_emb = Templ_emb(d_pair=d_pair, d_templ=d_templ, d_state=d_state,
                                   n_head=n_head_templ,
                                   d_hidden=d_hidden_templ, p_drop=0.25, d_t1d=d_t1d, d_t2d=d_t2d)

        # Update inputs with outputs from the previous round.
        self.recycle = Recycling(d_msa=d_msa, d_pair=d_pair, d_state=d_state)

        self.simulator = IterativeSimulator(n_extra_block=n_extra_block,
                                            n_main_block=n_main_block,
                                            n_ref_block=n_ref_block,
                                            d_msa=d_msa, d_msa_full=d_msa_full,
                                            d_pair=d_pair, d_hidden=d_hidden,
                                            n_head_msa=n_head_msa,
                                            n_head_pair=n_head_pair,
                                            SE3_param_full=SE3_param_full,
                                            SE3_param_topk=SE3_param_topk,
                                            p_drop=p_drop)

        # Auxiliary prediction heads.
        self.c6d_pred = DistanceNetwork(d_pair, p_drop=p_drop)
        self.aa_pred = MaskedTokenNetwork(d_msa, p_drop=p_drop)
        self.lddt_pred = LDDTNetwork(d_state)

        self.exp_pred = ExpResolvedNetwork(d_msa, d_state)

    def forward(self, msa_latent, msa_full, seq, xyz, idx,
                seq1hot=None, t1d=None, t2d=None, xyz_t=None, alpha_t=None,
                msa_prev=None, pair_prev=None, state_prev=None,
                return_raw=False, return_full=False,
                use_checkpoint=False, return_infer=False):
        """Run one pass of the network.

        Args:
            msa_latent: seed MSA features; first three dims are (B, N, L).
            msa_full: extra MSA features.
            seq: query sequence.
            xyz: input coordinates; atom slot 1 is used as the residue origin
                and the first three atom slots as the backbone.
            idx: residue indices (cast to ``long`` here).
            seq1hot: optional one-hot sequence forwarded to the embeddings.
            t1d, t2d, xyz_t, alpha_t: template inputs forwarded to ``Templ_emb``.
            msa_prev, pair_prev, state_prev: recycled features from a previous
                pass; zeros of the matching shape are used when ``None``.
            return_raw: return only the features needed for recycling.
            return_full: accepted for interface compatibility; not used here.
            use_checkpoint: enable gradient checkpointing in sub-modules.
            return_infer: return head outputs plus the final structure and
                features.

        Returns:
            * ``return_raw``: ``(msa[:,0], pair, xyz, state, alpha_s[-1])``
            * ``return_infer``: ``(logits, logits_aa, logits_exp, xyz, lddt,
              msa[:,0], pair, state, alpha_s[-1])``
            * otherwise: ``(logits, logits_aa, logits_exp, xyz, alpha_s, lddt)``
              where ``xyz`` holds the backbone of every intermediate block.
        """
        B, _, L = msa_latent.shape[:3]

        # Get embeddings.
        idx = idx.long()
        msa_latent, pair, state = self.latent_emb(msa_latent, seq, idx, seq1hot=seq1hot)
        msa_full = self.full_emb(msa_full, seq, idx, seq1hot=seq1hot)

        # Recycling. Identity checks are used because these are tensors:
        # `tensor == None` relies on comparison fallback rather than identity.
        if msa_prev is None:
            msa_prev = torch.zeros_like(msa_latent[:,0])
        if pair_prev is None:
            pair_prev = torch.zeros_like(pair)
        if state_prev is None:
            state_prev = torch.zeros_like(state)

        msa_recycle, pair_recycle, state_recycle = self.recycle(seq, msa_prev, pair_prev, xyz, state_prev)
        # Only the query row (first sequence) receives the recycled MSA features.
        msa_latent[:,0] = msa_latent[:,0] + msa_recycle.reshape(B,L,-1)
        pair = pair + pair_recycle
        state = state + state_recycle

        # Add template embedding.
        pair, state = self.templ_emb(t1d, t2d, alpha_t, xyz_t, pair, state, use_checkpoint=use_checkpoint)

        # Predict coordinates from the given inputs.
        msa, pair, R, T, alpha_s, state = self.simulator(seq, msa_latent, msa_full.type(torch.float32), pair, xyz[:,:,:3],
                                                         state, idx, use_checkpoint=use_checkpoint)

        # Input backbone expressed relative to atom slot 1 of each residue;
        # the predicted rotations/translations are applied to these offsets.
        bb_local = xyz[:,:,:3] - xyz[:,:,1].unsqueeze(-2)

        if return_raw:
            # Structure from the last block only.
            xyz = einsum('bnij,bnaj->bnai', R[-1], bb_local) + T[-1].unsqueeze(-2)
            return msa[:,0], pair, xyz, state, alpha_s[-1]

        # Predict masked amino acids.
        logits_aa = self.aa_pred(msa)

        # Predict distogram & orientograms.
        logits = self.c6d_pred(pair)

        # Predict LDDT.
        lddt = self.lddt_pred(state)

        # Predict experimentally resolved or not.
        logits_exp = self.exp_pred(msa[:,0], state)

        if return_infer:
            # Structure from the last block only.
            xyz = einsum('bnij,bnaj->bnai', R[-1], bb_local) + T[-1].unsqueeze(-2)
            return logits, logits_aa, logits_exp, xyz, lddt, msa[:,0], pair, state, alpha_s[-1]

        # All intermediate backbone structures (leading axis = block index).
        xyz = einsum('rbnij,bnaj->rbnai', R, bb_local) + T.unsqueeze(-2)

        return logits, logits_aa, logits_exp, xyz, alpha_s, lddt
model/SE3_network.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ #from equivariant_attention.modules import get_basis_and_r, GSE3Res, GNormBias
5
+ #from equivariant_attention.modules import GConvSE3, GNormSE3
6
+ #from equivariant_attention.fibers import Fiber
7
+
8
+ from util_module import init_lecun_normal_param
9
+ from se3_transformer.model import SE3Transformer
10
+ from se3_transformer.model.fiber import Fiber
11
+
12
class SE3TransformerWrapper(nn.Module):
    """SE(3) equivariant GCN with attention.

    Thin wrapper that builds the input/hidden/output fibers for an
    ``SE3Transformer`` and applies a custom parameter initialisation.
    Degree-1 (vector) inputs/outputs are optional: pass ``l1_in_features=0``
    and/or ``l1_out_features=0`` to omit them.
    """

    def __init__(self, num_layers=2, num_channels=32, num_degrees=3, n_heads=4, div=4,
                 l0_in_features=32, l0_out_features=32,
                 l1_in_features=3, l1_out_features=2,
                 num_edge_features=32):
        super().__init__()
        # Remembered so forward() knows whether to pass degree-1 node features.
        self.l1_in = l1_in_features

        fiber_edge = Fiber({0: num_edge_features})

        # Degree-1 channels are only declared when requested.
        in_degrees = {0: l0_in_features}
        if l1_in_features > 0:
            in_degrees[1] = l1_in_features
        out_degrees = {0: l0_out_features}
        if l1_out_features > 0:
            out_degrees[1] = l1_out_features

        fiber_in = Fiber(in_degrees)
        fiber_hidden = Fiber.create(num_degrees, num_channels)
        fiber_out = Fiber(out_degrees)

        self.se3 = SE3Transformer(num_layers=num_layers,
                                  fiber_in=fiber_in,
                                  fiber_hidden=fiber_hidden,
                                  fiber_out=fiber_out,
                                  num_heads=n_heads,
                                  channels_div=div,
                                  fiber_edge=fiber_edge,
                                  use_layer_norm=True)

        self.reset_parameter()

    def reset_parameter(self):
        """Apply the project's initialisation scheme to the SE(3) network.

        * biases -> zero
        * other 1-D parameters -> left untouched
        * non-radial weights -> ``init_lecun_normal_param``
        * radial-function weights -> He (kaiming) normal, because they
          precede a ReLU, except the ``net.6`` layer, which is zeroed
        * self-interaction kernels of the last graph module -> zero
        """
        for name, param in self.se3.named_parameters():
            if "bias" in name:
                nn.init.zeros_(param)
            elif len(param.shape) == 1:
                continue
            elif "radial_func" not in name:
                # Result intentionally not rebound: rebinding the loop
                # variable never affected the module's parameter.
                init_lecun_normal_param(param)
            elif "net.6" in name:
                nn.init.zeros_(param)
            else:
                nn.init.kaiming_normal_(param, nonlinearity='relu')

        # Zero-initialise the last layer. A degree is only present when the
        # output fiber declares it (e.g. no '1' when l1_out_features == 0),
        # so guard the lookup instead of indexing unconditionally.
        last_kernels = self.se3.graph_modules[-1].to_kernel_self
        for degree in ('0', '1'):
            if degree in last_kernels:
                nn.init.zeros_(last_kernels[degree])

    def forward(self, G, type_0_features, type_1_features=None, edge_features=None):
        """Run the SE(3) transformer on graph ``G``.

        Args:
            G: graph object passed straight through to ``SE3Transformer``.
            type_0_features: degree-0 (scalar) node features.
            type_1_features: degree-1 (vector) node features; only used when
                the wrapper was built with ``l1_in_features > 0``.
            edge_features: degree-0 edge features.

        Returns:
            Whatever ``SE3Transformer`` returns for the given inputs.
        """
        node_features = {'0': type_0_features}
        if self.l1_in > 0:
            node_features['1'] = type_1_features
        edge_features = {'0': edge_features}
        return self.se3(G, node_features, edge_features)
model/Track_module.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from opt_einsum import contract as einsum
5
+ import torch.utils.checkpoint as checkpoint
6
+ from util import cross_product_matrix
7
+ from util_module import *
8
+ from Attention_module import *
9
+ from SE3_network import SE3TransformerWrapper
10
+ from icecream import ic
11
+
12
+ # Components for three-track blocks
13
+ # 1. MSA -> MSA update (biased attention. bias from pair & structure)
14
+ # 2. Pair -> Pair update (biased attention. bias from structure)
15
+ # 3. MSA -> Pair update (extract coevolution signal)
16
+ # 4. Str -> Str update (node from MSA, edge from Pair)
17
+
18
+ # Update MSA with biased self-attention. bias from Pair & Str
19
class MSAPairStr2MSA(nn.Module):
    """MSA update via biased self-attention; the bias comes from pair features
    combined with structure-derived distance features.
    """

    def __init__(self, d_msa=256, d_pair=128, n_head=8, d_state=16,
                 d_hidden=32, p_drop=0.15, use_global_attn=False):
        super().__init__()
        self.norm_pair = nn.LayerNorm(d_pair)
        # Pair features are joined with 36 distance-feature channels.
        self.proj_pair = nn.Linear(d_pair + 36, d_pair)
        self.norm_state = nn.LayerNorm(d_state)
        self.proj_state = nn.Linear(d_state, d_msa)
        self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop)
        self.row_attn = MSARowAttentionWithBias(d_msa=d_msa, d_pair=d_pair,
                                                n_head=n_head, d_hidden=d_hidden)
        col_attn_cls = MSAColGlobalAttention if use_global_attn else MSAColAttention
        self.col_attn = col_attn_cls(d_msa=d_msa, n_head=n_head, d_hidden=d_hidden)
        self.ff = FeedForwardLayer(d_msa, 4, p_drop=p_drop)

        # Do proper initialization
        self.reset_parameter()

    def reset_parameter(self):
        """LeCun-normal weights and zero biases for the two projections."""
        self.proj_pair = init_lecun_normal(self.proj_pair)
        self.proj_state = init_lecun_normal(self.proj_state)

        for layer in (self.proj_pair, self.proj_state):
            nn.init.zeros_(layer.bias)

    def forward(self, msa, pair, rbf_feat, state):
        '''
        Inputs:
            - msa: MSA feature (B, N, L, d_msa)
            - pair: Pair feature (B, L, L, d_pair)
            - rbf_feat: distance feature calculated from xyz coordinates (B, L, L, 36)
            - state: node features from the SE(3)-Transformer layer (B, L, d_state)
        Output:
            - msa: Updated MSA feature (B, N, L, d_msa)
        '''
        n_batch, _, n_res = msa.shape[:3]

        # Attention bias: normalised pair features + coordinate information.
        bias = torch.cat((self.norm_pair(pair), rbf_feat), dim=-1)
        bias = self.proj_pair(bias)  # (B, L, L, d_pair)

        # Feed the structure state back into the query sequence (MSA row 0).
        state_msa = self.proj_state(self.norm_state(state)).reshape(n_batch, 1, n_res, -1)
        query_row = torch.tensor([0,], device=state_msa.device)
        msa = msa.index_add(1, query_row, state_msa.type(torch.float32))

        # Row attention (with dropout), column attention, feed-forward; all residual.
        msa = msa + self.drop_row(self.row_attn(msa, bias))
        msa = msa + self.col_attn(msa)
        msa = msa + self.ff(msa)

        return msa
78
+
79
class PairStr2Pair(nn.Module):
    """Pair update via row- and column-wise axial attention, biased by
    structure-derived distance features.
    """

    def __init__(self, d_pair=128, n_head=4, d_hidden=32, d_rbf=36, p_drop=0.15):
        super().__init__()

        # Two-layer embedding of the distance features into the pair width.
        self.emb_rbf = nn.Linear(d_rbf, d_hidden)
        self.proj_rbf = nn.Linear(d_hidden, d_pair)

        self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop)
        self.drop_col = Dropout(broadcast_dim=2, p_drop=p_drop)

        self.row_attn = BiasedAxialAttention(d_pair, d_pair, n_head, d_hidden, p_drop=p_drop, is_row=True)
        self.col_attn = BiasedAxialAttention(d_pair, d_pair, n_head, d_hidden, p_drop=p_drop, is_row=False)

        self.ff = FeedForwardLayer(d_pair, 2)

        self.reset_parameter()

    def reset_parameter(self):
        """He init for the layer feeding the ReLU, LeCun for the projection."""
        nn.init.kaiming_normal_(self.emb_rbf.weight, nonlinearity='relu')
        nn.init.zeros_(self.emb_rbf.bias)

        self.proj_rbf = init_lecun_normal(self.proj_rbf)
        nn.init.zeros_(self.proj_rbf.bias)

    def forward(self, pair, rbf_feat):
        """Return ``pair`` after biased row/column attention and feed-forward.

        Args:
            pair: pair features (last dim d_pair).
            rbf_feat: distance features (last dim d_rbf), used as attention bias.
        """
        hidden = F.relu_(self.emb_rbf(rbf_feat))
        bias = self.proj_rbf(hidden)

        row_update = self.drop_row(self.row_attn(pair, bias))
        pair = pair + row_update
        col_update = self.drop_col(self.col_attn(pair, bias))
        pair = pair + col_update
        pair = pair + self.ff(pair)
        return pair
112
+
113
class MSA2Pair(nn.Module):
    """Update pair features with an outer product over MSA rows
    (extracts the coevolution signal).
    """

    def __init__(self, d_msa=256, d_pair=128, d_hidden=32, p_drop=0.15):
        # NOTE: p_drop is accepted but not used by this module.
        super().__init__()
        self.norm = nn.LayerNorm(d_msa)
        self.proj_left = nn.Linear(d_msa, d_hidden)
        self.proj_right = nn.Linear(d_msa, d_hidden)
        self.proj_out = nn.Linear(d_hidden * d_hidden, d_pair)

        self.reset_parameter()

    def reset_parameter(self):
        """LeCun-normal input projections; zero-initialised output projection."""
        self.proj_left = init_lecun_normal(self.proj_left)
        self.proj_right = init_lecun_normal(self.proj_right)
        nn.init.zeros_(self.proj_left.bias)
        nn.init.zeros_(self.proj_right.bias)

        # Zero output weights/bias: the module contributes nothing at init.
        for tensor in (self.proj_out.weight, self.proj_out.bias):
            nn.init.zeros_(tensor)

    def forward(self, msa, pair):
        """Add the MSA-derived outer-product update to ``pair``.

        Args:
            msa: MSA features, (B, N, L, d_msa).
            pair: pair features, (B, L, L, d_pair).
        """
        n_batch, n_seq, n_res = msa.shape[:3]
        normed = self.norm(msa)
        lhs = self.proj_left(normed)
        # Dividing one side by N makes the sum over sequences an average.
        rhs = self.proj_right(normed) / float(n_seq)

        outer = einsum('bsli,bsmj->blmij', lhs, rhs)
        update = self.proj_out(outer.reshape(n_batch, n_res, n_res, -1))

        return pair + update
146
+
147
class SCPred(nn.Module):
    """Small residual MLP mapping query-sequence and state features to
    20 values per residue, returned as (B, L, 10, 2).
    """

    def __init__(self, d_msa=256, d_state=32, d_hidden=128, p_drop=0.15):
        # NOTE: p_drop is accepted but not used by this module.
        super().__init__()
        self.norm_s0 = nn.LayerNorm(d_msa)
        self.norm_si = nn.LayerNorm(d_state)
        self.linear_s0 = nn.Linear(d_msa, d_hidden)
        self.linear_si = nn.Linear(d_state, d_hidden)

        # ResNet layers
        self.linear_1 = nn.Linear(d_hidden, d_hidden)
        self.linear_2 = nn.Linear(d_hidden, d_hidden)
        self.linear_3 = nn.Linear(d_hidden, d_hidden)
        self.linear_4 = nn.Linear(d_hidden, d_hidden)

        # Final outputs
        self.linear_out = nn.Linear(d_hidden, 20)

        self.reset_parameter()

    def reset_parameter(self):
        """Initialise all linear layers; every bias ends up zero."""
        # Plain projections: LeCun-normal weights.
        self.linear_s0 = init_lecun_normal(self.linear_s0)
        self.linear_si = init_lecun_normal(self.linear_si)
        self.linear_out = init_lecun_normal(self.linear_out)
        for layer in (self.linear_s0, self.linear_si, self.linear_out):
            nn.init.zeros_(layer.bias)

        # Layers right before a ReLU: He initializer (kaiming normal).
        for layer in (self.linear_1, self.linear_3):
            nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
            nn.init.zeros_(layer.bias)

        # Layers right before a residual connection: zero initialize.
        for layer in (self.linear_2, self.linear_4):
            nn.init.zeros_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, seq, state):
        '''
        Predict side-chain torsion angles along with backbone torsions
        Inputs:
            - seq: hidden embeddings corresponding to query sequence (B, L, d_msa)
            - state: state feature (output l0 feature) from previous SE3 layer (B, L, d_state)
        Outputs:
            - si: predicted torsion-angle features, 10 (cos, sin)-style pairs
              per residue (B, L, 10, 2)
        '''
        n_batch, n_res = seq.shape[:2]
        si = self.linear_s0(self.norm_s0(seq)) + self.linear_si(self.norm_si(state))

        # Two residual blocks. F.relu_ works in place, so the skip path also
        # sees the rectified activations (same as the original formulation).
        residual_blocks = (
            (self.linear_1, self.linear_2),
            (self.linear_3, self.linear_4),
        )
        for inner, outer in residual_blocks:
            si = si + outer(F.relu_(inner(F.relu_(si))))

        si = self.linear_out(F.relu_(si))
        return si.view(n_batch, n_res, 10, 2)
206
+
207
+
208
class Str2Str(nn.Module):
    """Structure update: an SE(3)-Transformer predicts a per-residue rigid-body
    update (rotation + translation) and new state features; a side-chain
    predictor then produces torsion outputs from the new state.
    """

    def __init__(self, d_msa=256, d_pair=128, d_state=16,
                 SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}, p_drop=0.1):
        # NOTE: SE3_param is only read / unpacked here, never modified.
        super(Str2Str, self).__init__()

        # initial node & pair feature process
        self.norm_msa = nn.LayerNorm(d_msa)
        self.norm_pair = nn.LayerNorm(d_pair)
        self.norm_state = nn.LayerNorm(d_state)

        self.embed_x = nn.Linear(d_msa+d_state, SE3_param['l0_in_features'])
        self.embed_e1 = nn.Linear(d_pair, SE3_param['num_edge_features'])
        # Edge features + 36 distance channels + 1 sequence-separation channel.
        self.embed_e2 = nn.Linear(SE3_param['num_edge_features']+36+1, SE3_param['num_edge_features'])

        self.norm_node = nn.LayerNorm(SE3_param['l0_in_features'])
        self.norm_edge1 = nn.LayerNorm(SE3_param['num_edge_features'])
        self.norm_edge2 = nn.LayerNorm(SE3_param['num_edge_features'])

        self.se3 = SE3TransformerWrapper(**SE3_param)
        self.sc_predictor = SCPred(d_msa=d_msa, d_state=SE3_param['l0_out_features'],
                                   p_drop=p_drop)

        self.reset_parameter()

    def reset_parameter(self):
        """LeCun-normal weights and zero biases for the embedding layers."""
        # initialize weights to normal distribution
        self.embed_x = init_lecun_normal(self.embed_x)
        self.embed_e1 = init_lecun_normal(self.embed_e1)
        self.embed_e2 = init_lecun_normal(self.embed_e2)

        # initialize bias to zeros
        nn.init.zeros_(self.embed_x.bias)
        nn.init.zeros_(self.embed_e1.bias)
        nn.init.zeros_(self.embed_e2.bias)

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, msa, pair, R_in, T_in, xyz, state, idx, top_k=64, eps=1e-5):
        """Predict updated rotations, translations, state and torsions.

        Autocast is disabled for this method, so every floating-point input
        is explicitly cast to float32 first.

        Args:
            msa: MSA features (B, N, L, d_msa); only row 0 is used.
            pair: pair features (B, L, L, d_pair).
            R_in: current per-residue rotations (B, L, 3, 3).
            T_in: current per-residue translations (B, L, 3).
            xyz: current coordinates; atom slot 1 is the residue origin.
            state: state features (B, L, d_state).
            idx: residue indices.
            top_k: number of neighbours for the graph; 0 builds a full graph.
            eps: unused; kept for interface compatibility.

        Returns:
            Tuple ``(Ri, Ti, state, alpha)``: updated rotations, updated
            translations, new state features and side-chain predictor output.
        """
        B, N, L = msa.shape[:3]

        state = state.type(torch.float32)
        # Fix: this cast used to be assigned to a misspelled name (`mas`),
        # so `msa` itself was never converted to float32.
        msa = msa.type(torch.float32)
        pair = pair.type(torch.float32)
        R_in = R_in.type(torch.float32)
        T_in = T_in.type(torch.float32)
        xyz = xyz.type(torch.float32)

        # process msa & pair features
        node = self.norm_msa(msa[:,0])
        pair = self.norm_pair(pair)
        state = self.norm_state(state)

        node = torch.cat((node, state), dim=-1)
        node = self.norm_node(self.embed_x(node))
        pair = self.norm_edge1(self.embed_e1(pair))

        neighbor = get_seqsep(idx)
        rbf_feat = rbf(torch.cdist(xyz[:,:,1], xyz[:,:,1]))
        pair = torch.cat((pair, rbf_feat, neighbor), dim=-1)
        pair = self.norm_edge2(self.embed_e2(pair))

        # define graph
        if top_k != 0:
            G, edge_feats = make_topk_graph(xyz[:,:,1,:], pair, idx, top_k=top_k)
        else:
            G, edge_feats = make_full_graph(xyz[:,:,1,:], pair, idx, top_k=top_k)
        # Degree-1 inputs: atom positions relative to atom slot 1.
        l1_feats = xyz - xyz[:,:,1,:].unsqueeze(2)
        l1_feats = l1_feats.reshape(B*L, -1, 3)

        # apply SE(3) Transformer & update coordinates
        shift = self.se3(G, node.reshape(B*L, -1, 1), l1_feats, edge_feats)

        state = shift['0'].reshape(B, L, -1) # (B, L, C)

        offset = shift['1'].reshape(B, L, 2, 3)
        delTi = offset[:,:,0,:] / 10.0 # translation
        R = offset[:,:,1,:] / 100.0 # rotation

        # Interpret R as the vector part of a quaternion whose scalar part is 1,
        # normalise it, and expand it into a rotation matrix.
        Qnorm = torch.sqrt( 1 + torch.sum(R*R, dim=-1) )
        qA, qB, qC, qD = 1/Qnorm, R[:,:,0]/Qnorm, R[:,:,1]/Qnorm, R[:,:,2]/Qnorm

        delRi = torch.zeros((B,L,3,3), device=xyz.device)
        delRi[:,:,0,0] = qA*qA+qB*qB-qC*qC-qD*qD
        delRi[:,:,0,1] = 2*qB*qC - 2*qA*qD
        delRi[:,:,0,2] = 2*qB*qD + 2*qA*qC
        delRi[:,:,1,0] = 2*qB*qC + 2*qA*qD
        delRi[:,:,1,1] = qA*qA-qB*qB+qC*qC-qD*qD
        delRi[:,:,1,2] = 2*qC*qD - 2*qA*qB
        delRi[:,:,2,0] = 2*qB*qD - 2*qA*qC
        delRi[:,:,2,1] = 2*qC*qD + 2*qA*qB
        delRi[:,:,2,2] = qA*qA-qB*qB-qC*qC+qD*qD

        # Compose the update with the incoming frame; the translation update
        # is added directly (it is not rotated by delRi).
        Ri = einsum('bnij,bnjk->bnik', delRi, R_in)
        Ti = delTi + T_in

        alpha = self.sc_predictor(msa[:,0], state)
        return Ri, Ti, state, alpha
323
+
324
class IterBlock(nn.Module):
    """One three-track block: MSA update, MSA->pair update, pair update,
    then a structure update on a full graph (``top_k=0``).
    """

    def __init__(self, d_msa=256, d_pair=128,
                 n_head_msa=8, n_head_pair=4,
                 use_global_attn=False,
                 d_hidden=32, d_hidden_msa=None, p_drop=0.15,
                 SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}):
        # NOTE: SE3_param is only read / passed on here, never modified.
        super(IterBlock, self).__init__()
        # The MSA track uses the common hidden width unless told otherwise.
        if d_hidden_msa is None:
            d_hidden_msa = d_hidden

        self.msa2msa = MSAPairStr2MSA(d_msa=d_msa, d_pair=d_pair,
                                      n_head=n_head_msa,
                                      d_state=SE3_param['l0_out_features'],
                                      use_global_attn=use_global_attn,
                                      d_hidden=d_hidden_msa, p_drop=p_drop)
        # The outer-product track runs at half the hidden width.
        self.msa2pair = MSA2Pair(d_msa=d_msa, d_pair=d_pair,
                                 d_hidden=d_hidden//2, p_drop=p_drop)
        self.pair2pair = PairStr2Pair(d_pair=d_pair, n_head=n_head_pair,
                                      d_hidden=d_hidden, p_drop=p_drop)
        self.str2str = Str2Str(d_msa=d_msa, d_pair=d_pair,
                               d_state=SE3_param['l0_out_features'],
                               SE3_param=SE3_param,
                               p_drop=p_drop)

    def forward(self, msa, pair, R_in, T_in, xyz, state, idx, use_checkpoint=False):
        """Run one block.

        Args:
            msa: MSA features.
            pair: pair features.
            R_in, T_in: current per-residue rotations / translations.
            xyz: current coordinates; atom slot 1 is used for distances.
            state: state features.
            idx: residue indices.
            use_checkpoint: wrap every sub-module in gradient checkpointing.

        Returns:
            Tuple ``(msa, pair, R, T, state, alpha)``.
        """
        # Distance features are computed once and shared by the MSA and pair tracks.
        rbf_feat = rbf(torch.cdist(xyz[:,:,1,:], xyz[:,:,1,:]))
        if use_checkpoint:
            msa = checkpoint.checkpoint(create_custom_forward(self.msa2msa), msa, pair, rbf_feat, state)
            pair = checkpoint.checkpoint(create_custom_forward(self.msa2pair), msa, pair)
            pair = checkpoint.checkpoint(create_custom_forward(self.pair2pair), pair, rbf_feat)
            R, T, state, alpha = checkpoint.checkpoint(create_custom_forward(self.str2str, top_k=0), msa, pair, R_in, T_in, xyz, state, idx)
        else:
            msa = self.msa2msa(msa, pair, rbf_feat, state)
            pair = self.msa2pair(msa, pair)
            pair = self.pair2pair(pair, rbf_feat)
            R, T, state, alpha = self.str2str(msa, pair, R_in, T_in, xyz, state, idx, top_k=0)

        return msa, pair, R, T, state, alpha
363
+
364
class IterativeSimulator(nn.Module):
    """Chain of three-track blocks (extra-MSA blocks, then main blocks)
    followed by structure-only refinement steps on a top-k graph.
    """

    def __init__(self, n_extra_block=4, n_main_block=12, n_ref_block=4,
                 d_msa=256, d_msa_full=64, d_pair=128, d_hidden=32,
                 n_head_msa=8, n_head_pair=4,
                 SE3_param_full={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
                 SE3_param_topk={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
                 p_drop=0.15):
        super().__init__()
        self.n_extra_block = n_extra_block
        self.n_main_block = n_main_block
        self.n_ref_block = n_ref_block

        d_state_full = SE3_param_full['l0_out_features']
        d_state_topk = SE3_param_topk['l0_out_features']

        # Settings shared by every IterBlock (extra and main).
        shared = dict(d_pair=d_pair,
                      n_head_msa=n_head_msa,
                      n_head_pair=n_head_pair,
                      d_hidden=d_hidden,
                      p_drop=p_drop,
                      SE3_param=SE3_param_full)

        self.proj_state = nn.Linear(d_state_topk, d_state_full)

        # Update with extra sequences
        if n_extra_block > 0:
            extra_layers = [
                IterBlock(d_msa=d_msa_full, d_hidden_msa=8, use_global_attn=True, **shared)
                for _ in range(n_extra_block)
            ]
            self.extra_block = nn.ModuleList(extra_layers)

        # Update with seed sequences
        if n_main_block > 0:
            main_layers = [
                IterBlock(d_msa=d_msa, use_global_attn=False, **shared)
                for _ in range(n_main_block)
            ]
            self.main_block = nn.ModuleList(main_layers)

        self.proj_state2 = nn.Linear(d_state_full, d_state_topk)

        # Final SE(3) refinement
        if n_ref_block > 0:
            self.str_refiner = Str2Str(d_msa=d_msa, d_pair=d_pair,
                                       d_state=d_state_topk,
                                       SE3_param=SE3_param_topk,
                                       p_drop=p_drop)

        self.reset_parameter()

    def reset_parameter(self):
        """LeCun-normal weights and zero biases for both state projections."""
        self.proj_state = init_lecun_normal(self.proj_state)
        nn.init.zeros_(self.proj_state.bias)
        self.proj_state2 = init_lecun_normal(self.proj_state2)
        nn.init.zeros_(self.proj_state2.bias)

    def forward(self, seq, msa, msa_full, pair, xyz_in, state, idx, use_checkpoint=False):
        """Iteratively refine features and structure.

        Args:
            seq: query sequence (B, L); not used inside this method.
            msa: seed MSA embeddings (B, N, L, d_msa).
            msa_full: extra MSA embeddings (B, N, L, d_msa_full).
            pair: initial residue pair embeddings (B, L, L, d_pair).
            xyz_in: initial backbone coordinates (B, L, n_atom, 3).
            state: initial state features (B, L, d_state).
            idx: residue index.
            use_checkpoint: forwarded to the extra and main blocks.

        Returns:
            Tuple ``(msa, pair, R_s, T_s, alpha_s, state)`` where ``R_s``,
            ``T_s`` and ``alpha_s`` stack the per-block outputs along a new
            leading axis.
        """
        n_batch, n_res = pair.shape[:2]

        # Start from identity rotations; translations are the atom-slot-1
        # positions, and coordinates are re-expressed relative to them.
        R_in = torch.eye(3, device=xyz_in.device).reshape(1, 1, 3, 3).expand(n_batch, n_res, -1, -1)
        T_in = xyz_in[:, :, 1].clone()
        xyz_in = xyz_in - T_in.unsqueeze(-2)

        state = self.proj_state(state)

        rotations, translations, torsions = [], [], []

        def current_backbone(rot, trans):
            # Place the local backbone with the current rigid-body frames.
            return einsum('bnij,bnaj->bnai', rot, xyz_in) + trans.unsqueeze(-2)

        # --- blocks operating on the extra MSA ---
        for i_m in range(self.n_extra_block):
            # Detach the incoming frames (for stability).
            R_in, T_in = R_in.detach(), T_in.detach()
            xyz = current_backbone(R_in, T_in)

            msa_full, pair, R_in, T_in, state, alpha = self.extra_block[i_m](
                msa_full, pair, R_in, T_in, xyz, state, idx,
                use_checkpoint=use_checkpoint)
            rotations.append(R_in)
            translations.append(T_in)
            torsions.append(alpha)

        # --- blocks operating on the seed MSA ---
        for i_m in range(self.n_main_block):
            R_in, T_in = R_in.detach(), T_in.detach()
            xyz = current_backbone(R_in, T_in)

            msa, pair, R_in, T_in, state, alpha = self.main_block[i_m](
                msa, pair, R_in, T_in, xyz, state, idx,
                use_checkpoint=use_checkpoint)
            rotations.append(R_in)
            translations.append(T_in)
            torsions.append(alpha)

        # --- structure-only refinement with a top-k graph ---
        state = self.proj_state2(state)
        for _ in range(self.n_ref_block):
            R_in, T_in = R_in.detach(), T_in.detach()
            xyz = current_backbone(R_in, T_in)
            R_in, T_in, state, alpha = self.str_refiner(msa, pair, R_in, T_in, xyz, state, idx, top_k=64)
            rotations.append(R_in)
            translations.append(T_in)
            torsions.append(alpha)

        R_s = torch.stack(rotations, dim=0)
        T_s = torch.stack(translations, dim=0)
        alpha_s = torch.stack(torsions, dim=0)

        return msa, pair, R_s, T_s, alpha_s, state
model/__pycache__/Attention_module.cpython-310.pyc ADDED
Binary file (10.8 kB). View file
 
model/__pycache__/AuxiliaryPredictor.cpython-310.pyc ADDED
Binary file (3.53 kB). View file
 
model/__pycache__/Embeddings.cpython-310.pyc ADDED
Binary file (9.57 kB). View file
 
model/__pycache__/RoseTTAFoldModel.cpython-310.pyc ADDED
Binary file (3.29 kB). View file
 
model/__pycache__/SE3_network.cpython-310.pyc ADDED
Binary file (2.27 kB). View file
 
model/__pycache__/Track_module.cpython-310.pyc ADDED
Binary file (13.7 kB). View file
 
model/__pycache__/ab_tools.cpython-310.pyc ADDED
Binary file (17.5 kB). View file
 
model/__pycache__/apply_masks.cpython-310.pyc ADDED
Binary file (4.18 kB). View file
 
model/__pycache__/arguments.cpython-310.pyc ADDED
Binary file (12.1 kB). View file
 
model/__pycache__/chemical.cpython-310.pyc ADDED
Binary file (19.8 kB). View file
 
model/__pycache__/data_loader.cpython-310.pyc ADDED
Binary file (47.8 kB). View file
 
model/__pycache__/diffusion.cpython-310.pyc ADDED
Binary file (7.34 kB). View file
 
model/__pycache__/kinematics.cpython-310.pyc ADDED
Binary file (8.97 kB). View file
 
model/__pycache__/loss.cpython-310.pyc ADDED
Binary file (17.4 kB). View file
 
model/__pycache__/mask_generator.cpython-310.pyc ADDED
Binary file (3.01 kB). View file
 
model/__pycache__/parsers.cpython-310.pyc ADDED
Binary file (4.74 kB). View file
 
model/__pycache__/scheduler.cpython-310.pyc ADDED
Binary file (4.89 kB). View file
 
model/__pycache__/scoring.cpython-310.pyc ADDED
Binary file (13.9 kB). View file
 
model/__pycache__/train_multi_deep.cpython-310.pyc ADDED
Binary file (25.7 kB). View file
 
model/__pycache__/train_multi_deep_selfcond_nostruc.cpython-310.pyc ADDED
Binary file (37.1 kB). View file
 
model/__pycache__/util.cpython-310.pyc ADDED
Binary file (14.7 kB). View file
 
model/__pycache__/util_module.cpython-310.pyc ADDED
Binary file (9.66 kB). View file
 
model/apply_masks.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys, os
2
+ import torch
3
+ from icecream import ic
4
+ import random
5
+ import numpy as np
6
+ from kinematics import get_init_xyz
7
+ sys.path.append('../')
8
+ from utils.calc_dssp import annotate_sse
9
+
10
+ ic.configureOutput(includeContext=True)
11
+
12
def _append_zero_channels(t1d, n_channels, device):
    """Return ``t1d`` with ``n_channels`` zero-filled feature channels appended, moved to ``device``."""
    # Allocate the padding on t1d's own device so the concatenation also works
    # when t1d does not live on the default (CPU) device.
    padding = torch.zeros(
        (t1d.shape[0], t1d.shape[1], n_channels),
        dtype=torch.float32,
        device=t1d.device,
    )
    return torch.cat((t1d, padding), -1).to(device)


def mask_inputs(seq,
                msa_masked,
                msa_full,
                xyz_t,
                t1d,
                mask_msa,
                input_seq_mask=None,
                input_str_mask=None,
                input_floating_mask=None,
                input_t1dconf_mask=None,
                loss_seq_mask=None,
                loss_str_mask=None,
                loss_str_mask_2d=None,
                dssp=False,
                hotspots=False,
                diffuser=None,
                t=None,
                freeze_seq_emb=False,
                mutate_seq=False,
                no_clamp_seq=False,
                norm_input=False,
                contacts=None,
                frac_provide_dssp=0.5,
                dssp_mask_percentage=(0, 100),
                frac_provide_contacts=0.5,
                struc_cond=False):
    """
    Diffuse the sequence/MSA features and mask the template structure of one
    set of model inputs.

    Shapes below are taken from the original documentation of this function
    and are not re-checked here.

    Parameters:
        seq (torch.tensor, required): (I,L) integer sequence. Only seq[0] is
            one-hot encoded and diffused; the argmax of the diffused sequence
            is repeated seq.shape[0] times and returned as the new seq.

        msa_masked (torch.tensor, required): (I,N_short,L,48). MODIFIED IN PLACE.

        msa_full (torch.tensor, required): (I,N_long,L,25). Truncated along
            dim 1 to msa_masked.shape[1]; the truncated slice is a view, so the
            caller's tensor is written to as well.

        xyz_t (torch.tensor): (T,L,27,3) template crds BEFORE they go into get_init_xyz

        t1d (torch.tensor, required): (I,L,22) this is the t1d before the extra
            confidence / timestep channels are appended.

        mask_msa (torch.tensor, required): mask used for loss scoring. Entries
            at positions where loss_seq_mask is False are set to False IN PLACE.

        input_seq_mask (torch.tensor, required): Shape (L) boolean tensor;
            False positions are the diffused ones. Passed to diffuser.q_sample
            as ``mask``.

        input_str_mask (torch.tensor, required): Shape (L) boolean tensor;
            template structure is masked (set to NaN) at False positions.

        loss_seq_mask (torch.tensor, required): Shape (L) boolean tensor used
            to restrict mask_msa.

        input_floating_mask, input_t1dconf_mask, loss_str_mask,
        loss_str_mask_2d, freeze_seq_emb, mutate_seq, no_clamp_seq, norm_input:
            accepted for interface compatibility; not used by this function.

        dssp (bool): if True, append 4 secondary-structure channels computed by
            annotate_sse from xyz_t[0,:,1,:].

        hotspots (bool): if True, append 1 channel holding ``contacts``.

        contacts (torch.tensor): Shape (L) hotspot feature. ``None`` is treated
            as "no contacts" (all zeros).

        diffuser (required): diffuser object; must provide q_sample(x, t, mask=...),
            num_timesteps and (when struc_cond is True) alphas_cumprod.

        t (required): time step

        frac_provide_dssp (float): the dssp features are replaced by the mask
            token when np.random.rand() exceeds this value.

        dssp_mask_percentage (sequence of 2 ints): inclusive range from which
            the percentage_mask handed to annotate_sse is drawn.

        frac_provide_contacts (float): the contacts are zeroed when
            np.random.rand() exceeds this value.

        struc_cond (bool): if True, a residue with str_mask False only has its
            template coordinates removed when a uniform random draw exceeds
            diffuser.alphas_cumprod[t]; otherwise all such residues are removed.

    Returns:
        dict with keys 'seq', 'msa_masked', 'msa_full', 'xyz_t', 't1d',
        'mask_msa' and 'seq_diffused'.

    Raises:
        ValueError: if any of input_seq_mask, input_str_mask, loss_seq_mask,
            diffuser or t is None. Raised before any input is modified.

    NOTE (original author): in the MSA, the order is 20aa, 1x unknown, 1x mask token.
    For the t1d, this has 20aa, 1x unknown, and 1x template conf.
    NOTE(review): the original note also said masked regions are set to fixed
    tokens (22 in the MSA, 21 in t1d); the code below writes diffused values
    instead -- confirm which description is current.
    """
    # All of these default to None in the signature but are dereferenced
    # unconditionally below. Fail fast, before the in-place writes to
    # msa_masked / msa_full / mask_msa, instead of half-way through.
    required = {
        'input_seq_mask': input_seq_mask,
        'input_str_mask': input_str_mask,
        'loss_seq_mask': loss_seq_mask,
        'diffuser': diffuser,
        't': t,
    }
    missing = [name for name, value in required.items() if value is None]
    if missing:
        raise ValueError(
            f"mask_inputs requires non-None values for: {', '.join(missing)}"
        )

    seq_mask = input_seq_mask
    str_mask = input_str_mask

    ########################
    ###sequence diffusion###
    ########################
    # One-hot encode the first sequence and rescale from {0,1} to {-1,+1}
    # before handing it to the diffuser.
    x_0 = torch.nn.functional.one_hot(seq[0, ...], num_classes=22).float() * 2 - 1
    seq_diffused = diffuser.q_sample(x_0, t, mask=seq_mask)

    seq_tokens = torch.argmax(seq_diffused, axis=-1).to(device=seq.device)
    seq = seq_tokens.repeat(seq.shape[0], 1)

    ###################
    ###msa diffusion###
    ###################

    ### msa_masked ###
    _, N, L, _ = msa_masked.shape
    msa_masked[:, 0, :, :22] = seq_diffused

    # Remaining MSA rows are diffused with the same per-position mask.
    x_0_msa = msa_masked[0, 1:, :, :22].float() * 2 - 1
    msa_seq_mask = seq_mask.unsqueeze(0).repeat(N - 1, 1)
    msa_diffused = diffuser.q_sample(x_0_msa, torch.tensor([t]), mask=msa_seq_mask)

    msa_masked[:, 1:, :, :22] = torch.clone(msa_diffused)

    # index 44/45 is insertion/deletion
    # index 43 is the masked token NOTE check this
    # index 42 is the unknown token
    msa_masked[:, 0, :, 22:44] = seq_diffused
    msa_masked[:, 1:, :, 22:44] = msa_diffused

    # insertion/deletion stuff
    msa_masked[:, 0, ~seq_mask, 44:46] = 0

    ### msa_full ###
    ################
    # make msa_full same size as msa_masked
    msa_full = msa_full[:, :msa_masked.shape[1], :, :]
    msa_full[:, 0, :, :22] = seq_diffused
    msa_full[:, 1:, :, :22] = msa_diffused

    ### t1d ###
    ###########
    # NOTE: adjusting t1d last dim (confidence) from sequence mask
    # New channel 22 (given a 22-channel input) carries the timestep feature.
    t1d = _append_zero_channels(t1d, 1, seq.device)
    t1d[:, :, :21] = seq_diffused[..., :21]

    # set diffused conf to 0 and everything else to 1
    t1d[:, ~seq_mask, 21] = 0.0
    t1d[:, seq_mask, 21] = 1.0

    # Only the first entry along dim 0 receives the timestep feature.
    t1d[:1, :, 22] = 1 - t / diffuser.num_timesteps

    # Structure confidence channel (index 23): 1 where structure is provided.
    t1d = _append_zero_channels(t1d, 1, seq.device)
    t1d[:, ~str_mask, 23] = 0.0
    t1d[:, str_mask, 23] = 1.0

    if dssp:
        print(f'adding dssp {frac_provide_dssp} of time')
        t1d = _append_zero_channels(t1d, 4, seq.device)
        # mask some percentage of dssp info in range
        # dssp_mask_percentage[0], dssp_mask_percentage[1]
        percentage_mask = random.randint(dssp_mask_percentage[0], dssp_mask_percentage[1])
        # Kept in its own local so the boolean ``dssp`` flag is not overwritten.
        # Atom index 1 is presumably CA -- TODO confirm against annotate_sse.
        sse = annotate_sse(np.array(xyz_t[0, :, 1, :].squeeze()), percentage_mask=percentage_mask)
        if np.random.rand() > frac_provide_dssp:
            print('masking dssp')
            sse[...] = 0  # replace with mask token
            sse[:, -1] = 1
        t1d[..., 24:] = sse

    if hotspots:
        print(f"adding hotspots {frac_provide_contacts} of time")
        t1d = _append_zero_channels(t1d, 1, seq.device)
        # mask all contacts some fraction of the time
        hide_contacts = np.random.rand() > frac_provide_contacts
        if hide_contacts:
            print('masking contacts')
        # A missing contacts tensor is handled like the masked case rather
        # than crashing on ``t1d[..., -1] = None``.
        if hide_contacts or contacts is None:
            contacts = torch.zeros(L)
        t1d[..., -1] = contacts

    ### xyz_t ###
    #############
    xyz_t = get_init_xyz(xyz_t[None])
    xyz_t = xyz_t[0]
    # Sequence masking: drop every atom from index 3 onwards
    # (presumably everything beyond the first three backbone atoms -- TODO confirm).
    xyz_t[:, :, 3:, :] = float('nan')
    # Structure masking
    if struc_cond:
        print("non-autoregressive structure conditioning")
        r = diffuser.alphas_cumprod[t]
        xyz_mask = (torch.rand(xyz_t.shape[1]) > r).to(torch.bool).to(seq.device)
        # Only residues whose structure is not provided may be removed.
        xyz_mask = torch.logical_and(xyz_mask, ~str_mask)
        xyz_t[:, xyz_mask, :, :] = float('nan')
    else:
        xyz_t[:, ~str_mask, :, :] = float('nan')

    ### mask_msa ###
    ################
    # NOTE: this is for loss scoring
    mask_msa[:, :, ~loss_seq_mask] = False

    out = dict(
        seq=seq,
        msa_masked=msa_masked,
        msa_full=msa_full,
        xyz_t=xyz_t,
        t1d=t1d,
        mask_msa=mask_msa,
        seq_diffused=seq_diffused,
    )

    return out