Simon Duerr committed
Commit 8c639ec
0 Parent(s)
CODEOWNERS ADDED
@@ -0,0 +1,2 @@
# Global owner
* @alexechu
LICENSE ADDED
@@ -0,0 +1,29 @@
BSD 3-Clause License

Copyright (c) 2022, Alex Chu
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
ProteinMPNN ADDED
@@ -0,0 +1 @@
Subproject commit 8907e6671bfbfc92303b5f79c4b5e6ce47cdef57
app.py ADDED
@@ -0,0 +1,498 @@
import gradio as gr

import re
import urllib.request

import tempfile

from output_helpers import viewer_html, output_html, load_js, get_js

import json
import os
import shlex
import subprocess
from datetime import datetime

from einops import repeat
import torch

from core import data
from core import utils
import models
import sampling

# from draw_samples import draw_and_save_samples, parse_resample_idx_string


def draw_and_save_samples(
    model,
    samples_per_len=8,
    lengths=range(50, 512),
    save_dir="./",
    mode="backbone",
    **sampling_kwargs,
):
    device = model.device
    sample_files = []
    if mode == "backbone":
        total_sampling_time = 0
        for l in lengths:
            prot_lens = torch.ones(samples_per_len).long() * l
            seq_mask = model.make_seq_mask_for_sampling(prot_lens=prot_lens)
            aux = sampling.draw_backbone_samples(
                model,
                seq_mask=seq_mask,
                pdb_save_path=f"{save_dir}/len{format(l, '03d')}_samp",
                return_aux=True,
                return_sampling_runtime=True,
                **sampling_kwargs,
            )
            total_sampling_time += aux["runtime"]
            sample_files += [
                f"{save_dir}/len{format(l, '03d')}_samp{i}.pdb"
                for i in range(samples_per_len)
            ]
        return sample_files
    elif mode == "allatom":
        total_sampling_time = 0
        for l in lengths:
            prot_lens = torch.ones(samples_per_len).long() * l
            seq_mask = model.make_seq_mask_for_sampling(prot_lens=prot_lens)
            aux = sampling.draw_allatom_samples(
                model,
                seq_mask=seq_mask,
                pdb_save_path=f"{save_dir}/len{format(l, '03d')}",
                return_aux=True,
                **sampling_kwargs,
            )
            total_sampling_time += aux["runtime"]
            sample_files += [
                f"{save_dir}/len{format(l, '03d')}_samp{i}.pdb"
                for i in range(samples_per_len)
            ]
        return sample_files


def parse_idx_string(idx_str):
    """Parse a string like '0,2-5,7' into a list of indices; dashed spans are end-exclusive."""
    spans = idx_str.split(",")
    idxs = []
    for s in spans:
        if "-" in s:
            start, stop = s.split("-")
            idxs.extend(list(range(int(start), int(stop))))
        else:
            idxs.append(int(s))
    return idxs

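As an illustration (not part of the committed file), the parser behaves as follows; the dashed spans are end-exclusive because they map directly onto Python's range:

# Usage sketch, assuming parse_idx_string from above is in scope:
assert parse_idx_string("0,2-5,7") == [0, 2, 3, 4, 7]  # '2-5' expands to 2,3,4
assert parse_idx_string("10") == [10]
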
def changemode(m):
    if m == "unconditional":
        return (
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=True),
        )
    else:
        return (
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
        )


def fileselection(val):
    if val == "upload":
        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
    else:
        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)

def update_structuresel(pdb, radio_val):
    pdb_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdb")

    representations = [
        {
            "model": 0,
            "chain": "",
            "resname": "",
            "style": "cartoon",
            "color": "whiteCarbon",
            "residue_range": "",
            "around": 0,
            "byres": False,
            "visible": False,
        }
    ]

    def viewer_iframe(path):
        # Embed the structure viewer in a sandboxed iframe (shared by all branches below).
        return f"""<iframe style="width: 100%; height: 930px" name="result" allow="midi; geolocation; microphone; camera;
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms
    allow-scripts allow-same-origin allow-popups
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
    allowpaymentrequest="" frameborder="0" srcdoc='{viewer_html(path, representations=representations)}'></iframe>"""

    if radio_val == "PDB":
        if len(pdb) != 4:
            return gr.update(open=True), gr.update(), gr.update(value="", visible=False)
        else:
            urllib.request.urlretrieve(
                f"http://files.rcsb.org/download/{pdb.lower()}.pdb1",
                pdb_file.name,
            )
            return (
                gr.update(open=False),
                gr.update(value=pdb_file.name),
                gr.update(value=viewer_iframe(pdb_file.name), visible=True),
            )
    elif radio_val == "AF2 EBI DB":  # must match the Radio choice defined in the UI below
        # UniProt accession pattern.
        if re.match("[OPQ][0-9][A-Z0-9]{3}[0-9]|[A-NR-Z][0-9]([A-Z][A-Z0-9]{2}[0-9]){1,2}", pdb) is not None:
            urllib.request.urlretrieve(
                f"https://alphafold.ebi.ac.uk/files/AF-{pdb}-F1-model_v2.pdb",
                pdb_file.name,
            )
            return (
                gr.update(open=False),
                gr.update(value=pdb_file.name),
                gr.update(value=viewer_iframe(pdb_file.name), visible=True),
            )
        else:
            # Keep the accordion open, leave the file path unchanged, and show the error.
            return gr.update(open=True), gr.update(), gr.update(value="regex not matched", visible=True)
    else:
        return (
            gr.update(open=False),
            gr.update(value=f"{pdb.name}"),
            gr.update(value=viewer_iframe(pdb.name), visible=True),
        )

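The UniProt accession pattern used above can be exercised in isolation (an illustrative sketch; the pattern string is copied verbatim from the function):

import re

UNIPROT_RE = "[OPQ][0-9][A-Z0-9]{3}[0-9]|[A-NR-Z][0-9]([A-Z][A-Z0-9]{2}[0-9]){1,2}"

assert re.match(UNIPROT_RE, "P12345") is not None  # classic 6-character accession
assert re.match(UNIPROT_RE, "A0A0B4") is not None  # newer-style accession
assert re.match(UNIPROT_RE, "1ABC") is None        # a PDB code does not match
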
from Bio.PDB import PDBParser, cealign
from Bio.PDB.PDBIO import PDBIO


class dotdict(dict):
    """dot.notation access to dictionary attributes"""

    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

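Because dotdict aliases attribute access to dict.get, missing keys read as None instead of raising AttributeError, which the sampling code below relies on. A quick sketch (assumes the dotdict class above):

args = dotdict({"minlen": 50})
args.maxlen = 60          # same as args["maxlen"] = 60
print(args.minlen)        # 50
print(args.input_pdb)     # None, not an AttributeError
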
def protpardelle(path_to_file, m, resample_idx, modeltype, minlen, maxlen, steplen, perlen):
    # Set up params, arguments, sampling config
    ####################

    args = {}
    args["model_checkpoint"] = "checkpoints"  # path to denoiser model weights and config
    args["mpnnpath"] = "checkpoints/minimpnn_state_dict.pth"  # path to minimpnn model weights
    args["modeldir"] = None  # model base directory, ex 'training_logs/other/lemon-shape-51'
    args["modelepoch"] = None  # model epoch, ex 1000

    args["type"] = modeltype  # type of model
    if m == "conditional":
        args["param"] = None  # which sampling param to vary
        args["paramval"] = None  # which param val to use
        args["parampath"] = None  # path to json file with params; use param/paramval or parampath, not both
        args["perlen"] = int(perlen)  # how many samples per sequence length
        args["minlen"] = None  # minimum sequence length
        args["maxlen"] = None  # maximum sequence length, not inclusive
        args["steplen"] = int(steplen)  # how frequently to select sequence length; for steplen 2: 50, 52, 54, ...
        args["num_lens"] = None  # if steplen not provided, how many random lengths to sample at
        args["targetdir"] = "."  # directory to save results
        args["input_pdb"] = path_to_file  # PDB file to condition on
        # Indices from the PDB file to resample: zero-indexed, comma-delimited,
        # dashes allowed, e.g. 0,2-5,7. The [1:-1] strips the surrounding
        # brackets from the textbox value set by the sequence viewer.
        args["resample_idxs"] = resample_idx[1:-1]
    else:
        args["param"] = "n_steps"  # which sampling param to vary
        args["paramval"] = "100"  # which param val to use
        args["parampath"] = None
        args["perlen"] = int(perlen)
        args["minlen"] = int(minlen)
        args["maxlen"] = int(maxlen) + 1  # +1 so the slider's maximum is inclusive
        args["steplen"] = int(steplen)
        args["num_lens"] = None
        args["targetdir"] = "."
        args["resample_idxs"] = None

    args = dotdict(args)
    is_test_run = False
    seed = 0
    samples_per_len = args.perlen
    min_len = args.minlen
    max_len = args.maxlen
    len_step_size = args.steplen
    device = "cuda:0"

    # Set the default sampling config
    if args.type == "backbone":
        sampling_config = sampling.default_backbone_sampling_config()
    elif args.type == "allatom":
        sampling_config = sampling.default_allatom_sampling_config()

    sampling_kwargs = vars(sampling_config)

    # Parse conditioning inputs. In unconditional mode args has no "input_pdb"
    # key, so dotdict returns None and this block is skipped.
    input_pdb_len = None
    if args.input_pdb:
        input_feats = utils.load_feats_from_pdb(args.input_pdb, protein_only=True)
        input_pdb_len = input_feats["aatype"].shape[0]
        if args.resample_idxs:
            print(
                f"Warning: when sampling conditionally, the input pdb length ({input_pdb_len} residues) is used automatically for the sampling lengths."
            )
            resample_idxs = parse_idx_string(args.resample_idxs)
        else:
            resample_idxs = list(range(input_pdb_len))
        cond_idxs = [i for i in range(input_pdb_len) if i not in resample_idxs]
        to_batch_size = lambda x: repeat(x, "... -> b ...", b=samples_per_len).to(device)

        # For unconditional model, center coords on whole structure
        centered_coords = data.apply_random_se3(
            input_feats["atom_positions"],
            atom_mask=input_feats["atom_mask"],
            translation_scale=0.0,
        )
        cond_kwargs = {}
        cond_kwargs["gt_coords"] = to_batch_size(centered_coords)
        cond_kwargs["gt_cond_atom_mask"] = to_batch_size(input_feats["atom_mask"])
        cond_kwargs["gt_cond_atom_mask"][:, resample_idxs] = 0
        cond_kwargs["gt_aatype"] = to_batch_size(input_feats["aatype"])
        cond_kwargs["gt_cond_seq_mask"] = torch.zeros_like(cond_kwargs["gt_aatype"])
        cond_kwargs["gt_cond_seq_mask"][:, cond_idxs] = 1
        sampling_kwargs.update(cond_kwargs)

        print("input_pdb_len", input_pdb_len)

    # Determine lengths to sample at
    if min_len is not None and max_len is not None:
        if len_step_size is not None:
            sampling_lengths = range(min_len, max_len, len_step_size)
        else:
            sampling_lengths = list(torch.randint(min_len, max_len, size=(args.num_lens,)))
    elif input_pdb_len is not None:
        sampling_lengths = [input_pdb_len]
    else:
        raise Exception("Need to provide a set of protein lengths or an input pdb.")

    total_num_samples = len(list(sampling_lengths)) * samples_per_len

    model_directory = args.modeldir
    epoch = args.modelepoch
    base_dir = args.targetdir

    date_string = datetime.now().strftime("%y-%m-%d-%H-%M-%S")
    if is_test_run:
        date_string = f"test-{date_string}"

    # Update sampling config with arguments
    if args.param:
        var_param = args.param
        var_value = args.paramval
        sampling_kwargs[var_param] = (
            None
            if var_value == "None"
            else int(var_value)
            if var_param == "n_steps"
            else float(var_value)
        )
    elif args.parampath:
        with open(args.parampath) as f:
            var_params = json.loads(f.read())
        sampling_kwargs.update(var_params)

    # This is only used for the readme; keep s_min and s_max as params instead of struct_noise_schedule
    sampling_kwargs_readme = list(sampling_kwargs.items())

    print("Base directory:", base_dir)
    save_dir = f"{base_dir}/samples/{date_string}"
    save_init_dir = f"{base_dir}/samples_inits/{date_string}"

    # Make dirs if they do not exist
    if not os.path.exists(save_dir):
        subprocess.run(shlex.split(f"mkdir -p {save_dir}"))
    if not os.path.exists(save_init_dir):
        subprocess.run(shlex.split(f"mkdir -p {save_init_dir}"))

    print("Samples saved to:", save_dir)
    torch.manual_seed(seed)

    # Load model
    if args.type == "backbone":
        if args.model_checkpoint:
            checkpoint = f"{args.model_checkpoint}/backbone_state_dict.pth"
            cfg_path = f"{args.model_checkpoint}/backbone.yml"
        else:
            checkpoint = f"{model_directory}/checkpoints/epoch{epoch}_training_state.pth"
            cfg_path = f"{model_directory}/configs/backbone.yml"
        cfg = utils.load_config(cfg_path)
        weights = torch.load(checkpoint, map_location=device)["model_state_dict"]
        model = models.Protpardelle(cfg, device=device)
        model.load_state_dict(weights)
        model.to(device)
        model.eval()
        model.device = device
    elif args.type == "allatom":
        if args.model_checkpoint:
            checkpoint = f"{args.model_checkpoint}/allatom_state_dict.pth"
            cfg_path = f"{args.model_checkpoint}/allatom.yml"
        else:
            checkpoint = f"{model_directory}/checkpoints/epoch{epoch}_training_state.pth"
            cfg_path = f"{model_directory}/configs/allatom.yml"
        config = utils.load_config(cfg_path)
        weights = torch.load(checkpoint, map_location=device)["model_state_dict"]
        model = models.Protpardelle(config, device=device)
        model.load_state_dict(weights)
        model.load_minimpnn(args.mpnnpath)
        model.to(device)
        model.eval()
        model.device = device

    with open(save_dir + "/run_parameters.txt", "w") as f:
        f.write(f"Sampling run for {date_string}\n")
        f.write(f"Random seed {seed}\n")
        f.write(f"Model checkpoint: {checkpoint}\n")
        f.write(
            f"{samples_per_len} samples per length from {min_len}:{max_len}:{len_step_size}\n"
        )
        f.write("Sampling params:\n")
        for k, v in sampling_kwargs_readme:
            f.write(f"{k}\t{v}\n")

    # Draw samples
    output_files = draw_and_save_samples(
        model,
        samples_per_len=samples_per_len,
        lengths=sampling_lengths,
        save_dir=save_dir,
        mode=args.type,
        **sampling_kwargs,
    )

    return output_files

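The conditioning block above encodes what is held fixed purely through masks: gt_cond_atom_mask zeroes every residue slated for resampling, while gt_cond_seq_mask marks the complementary, conditioned positions. A self-contained sketch of that masking logic with toy shapes in place of real PDB features (illustrative, not part of the committed file):

import torch
from einops import repeat

n_res, n_atoms, batch = 6, 37, 2
atom_mask = torch.ones(n_res, n_atoms)
resample_idxs = [1, 2, 3]  # residues to regenerate
cond_idxs = [i for i in range(n_res) if i not in resample_idxs]

to_batch = lambda x: repeat(x, "... -> b ...", b=batch)
gt_cond_atom_mask = to_batch(atom_mask).clone()
gt_cond_atom_mask[:, resample_idxs] = 0  # no atom conditioning where we resample

gt_cond_seq_mask = torch.zeros(batch, n_res)
gt_cond_seq_mask[:, cond_idxs] = 1  # sequence is pinned only on kept residues
print(gt_cond_atom_mask[0, :, 0], gt_cond_seq_mask[0])
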
def api_predict(pdb_content, m, resample_idx, modeltype, minlen, maxlen, steplen, perlen):
    if m == "conditional":
        tempPDB = tempfile.NamedTemporaryFile(delete=False, suffix=".pdb")
        tempPDB.write(pdb_content.encode())
        tempPDB.close()
        path_to_file = tempPDB.name
    else:
        path_to_file = None

    try:
        designs = protpardelle(path_to_file, m, resample_idx, modeltype, minlen, maxlen, steplen, perlen)
    except Exception as e:
        print(e)
        raise gr.Error(str(e))

    # Load each design as a string
    design_str = []
    for d in designs:
        with open(d, "r") as f:
            design_str.append(f.read())

    results = list(zip(designs, design_str))
    return json.dumps(results)

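Because a hidden button further down is wired to api_name="protpardelle", this function is callable programmatically. A hedged sketch using the gradio_client package; the Space URL is a placeholder, and the argument order mirrors api_predict's inputs list:

from gradio_client import Client

client = Client("https://example-protpardelle.hf.space")  # placeholder URL
result = client.predict(
    "",               # pdb_content (empty for unconditional mode)
    "unconditional",  # m
    "",               # resample_idx
    "allatom",        # modeltype
    50, 60, 1, 2,     # minlen, maxlen, steplen, perlen
    api_name="/protpardelle",
)
print(result)  # JSON list of (filename, pdb_string) pairs
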
def predict(pdb_radio, path_to_file, m, resample_idx, modeltype, minlen, maxlen, steplen, perlen):
    print("running predict")
    try:
        designs = protpardelle(path_to_file, m, resample_idx, modeltype, minlen, maxlen, steplen, perlen)
    except Exception as e:
        print(e)
        # Errors are surfaced in the UI via gr.Error.
        raise gr.Error(str(e))

    parser = PDBParser()
    aligner = cealign.CEAligner()
    io = PDBIO()
    aligned_designs = []
    metrics = []
    if m == "conditional":
        # Align each design to the conditioning structure and record the RMSD.
        ref = parser.get_structure("ref", path_to_file)
        aligner.set_reference(ref)
        for d in designs:
            design = parser.get_structure("design", d)
            aligner.align(design)
            metrics.append({"rms": f"{aligner.rms:.1f}", "len": len(list(design[0].get_residues()))})
            io.set_structure(design)
            io.save(d.replace(".pdb", "_al.pdb"))
            aligned_designs.append(d.replace(".pdb", "_al.pdb"))
    else:
        for d in designs:
            design = parser.get_structure("design", d)
            metrics.append({"len": len(list(design[0].get_residues()))})
        aligned_designs = designs

    output_view = f"""<iframe style="width: 100%; height: 900px" name="result" allow="midi; geolocation; microphone; camera;
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms
    allow-scripts allow-same-origin allow-popups
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
    allowpaymentrequest="" frameborder="0" srcdoc='{output_html(path_to_file, aligned_designs, metrics, resample_idx=resample_idx, mode=m)}'></iframe>"""

    return gr.update(open=False), gr.update(value=output_view, visible=True)

protpardelleDemo = gr.Blocks()

with protpardelleDemo:
    gr.Markdown("# Protpardelle")
    gr.Markdown("""An all-atom protein generative model.
    Alexander E. Chu, Lucy Cheng, Gina El Nesr, Minkai Xu, Po-Ssu Huang
    doi: https://doi.org/10.1101/2023.05.24.542194""")

    with gr.Accordion(label="Input options", open=True) as input_accordion:
        model = gr.Dropdown(["backbone", "allatom"], value="allatom", label="What to sample?")

        m = gr.Radio(["unconditional", "conditional"], value="unconditional", label="Choose a mode")

        # Unconditional
        with gr.Group(visible=True) as uncond:
            gr.Markdown("Unconditional sampling")
            # length = gr.Slider(minimum=0, maximum=200, step=1, value=50, label="length")
            # param = gr.Dropdown(["length", "param"], value="length", label="Which sampling param to vary?")
            # paramval = gr.Dropdown(["nsteps"], label="paramval", info="Which param val to use?")

        # Conditional
        with gr.Group(visible=False) as cond:
            with gr.Accordion(label="Structure to condition on", open=True) as input_accordion:
                pdb_radio = gr.Radio(["PDB", "AF2 EBI DB", "upload"], value="PDB", label="Source of the structure")
                pdbcode = gr.Textbox(label="PDB code, or UniProt accession to retrieve from the AlphaFold2 database", visible=True)
                pdbfile = gr.File(label="PDB File", visible=False)
                btn_load = gr.Button("Load PDB")
                pdb_radio.change(fileselection, inputs=pdb_radio, outputs=[pdbcode, pdbfile, btn_load])

            pdb_html = gr.HTML("", visible=False)

            path_to_file = gr.Textbox(label="Path to file", visible=False)
            resample_idxs = gr.Textbox(
                label="Cond Idxs",
                interactive=False,
                info="Zero-indexed list of indices to condition on; select in the sequence viewer above",
            )
            btn_load.click(update_structuresel, inputs=[pdbcode, pdb_radio], outputs=[input_accordion, path_to_file, pdb_html])
            pdbfile.change(update_structuresel, inputs=[pdbfile, pdb_radio], outputs=[input_accordion, path_to_file, pdb_html])

    with gr.Accordion(label="Sizes", open=True) as size_uncond:
        with gr.Row():
            minlen = gr.Slider(minimum=2, maximum=200, value=50, step=1, label="minlen", info="Minimum sequence length")
            maxlen = gr.Slider(minimum=3, maximum=200, value=60, step=1, label="maxlen", info="Maximum sequence length")
            steplen = gr.Slider(minimum=1, maximum=50, step=1, value=1, label="steplen", info="How frequently to select sequence length?")
            perlen = gr.Slider(minimum=1, maximum=200, step=1, value=2, label="perlen", info="How many samples per sequence length?")

    btn_conditional = gr.Button("Run conditional", visible=False)
    btn_unconditional = gr.Button("Run unconditional")
    m.change(changemode, inputs=m, outputs=[uncond, cond, btn_unconditional, btn_conditional, size_uncond])
    out = gr.HTML("", visible=True)

    btn_unconditional.click(predict, inputs=[pdb_radio, path_to_file, m, resample_idxs, model, minlen, maxlen, steplen, perlen], outputs=[input_accordion, out])

    # The conditional button only runs JS to collect the selected indices;
    # the resulting change on resample_idxs then triggers predict below.
    btn_conditional.click(fn=None, inputs=[resample_idxs], outputs=[resample_idxs], _js=get_js)

    out_text = gr.Textbox(label="Output", visible=False)
    # Hidden button for the named API route
    pdb_content = gr.Textbox(label="PDB Content", visible=False)
    btn_api = gr.Button("Run API", visible=False)
    btn_api.click(api_predict, inputs=[pdb_content, m, resample_idxs, model, minlen, maxlen, steplen, perlen], outputs=[out_text], api_name="protpardelle")

    resample_idxs.change(predict, inputs=[pdb_radio, path_to_file, m, resample_idxs, model, minlen, maxlen, steplen, perlen], outputs=[input_accordion, out])

    protpardelleDemo.load(None, None, None, _js=load_js)

protpardelleDemo.queue()
protpardelleDemo.launch(allowed_paths=["samples"], share=True)
checkpoints/allatom.yml ADDED
@@ -0,0 +1,69 @@
train:
  home_dir: '/home/duerr/phd/08_Code/protpardelle-final'
  seed: 0
  checkpoint: ['', 0]
  batch_size: 32
  max_epochs: 10000
  eval_freq: 7200 # seconds
  checkpoint_freq: 50
  checkpoints: []
  lr: 0.0001
  warmup_steps: 1000
  decay_steps: 2_000_000
  clip_grad_norm: True
  grad_clip_val: 1.0
  weight_decay: 0.0
  n_eval_samples: 8
  sample_length_range: [50, 512]
  sc_num_seqs: 4
  eval_loss_t: [0.1, 0.3, 0.5, 0.7, 0.9]
  self_cond_train_prob: 0.9
  subsample_eval_set: 0.05
  crop_conditional: False

data:
  pdb_path: 'datasets/ingraham_cath_dataset'
  fixed_size: 512
  n_aatype_tokens: 21
  se3_data_augment: True
  sigma_data: 10.0

diffusion:
  training:
    function: 'lognormal'
    psigma_mean: -1.0
    psigma_std: 1.5
  sampling:
    function: 'uniform'
    s_min: 0.001
    s_max: 80

model:
  task: 'allatom' # 'backbone', 'allatom', 'seqdes', 'codesign'
  pretrained_modules: [] # 'struct_model', 'mpnn_model'
  struct_model_checkpoint: ''
  mpnn_model_checkpoint: ''
  crop_conditional: False
  dummy_fill_masked_atoms: False
  struct_model:
    arch: 'uvit'
    n_atoms: 37
    n_channel: 256
    noise_cond_mult: 4
    uvit:
      patch_size: 1
      n_layers: 6
      n_heads: 8
      dim_head: 32
      n_filt_per_layer: []
      n_blocks_per_layer: 2
      cat_pwd_to_conv: False
      conv_skip_connection: False
      position_embedding_type: 'rotary'
  mpnn_model:
    use_self_conditioning: True
    label_smoothing: 0.1
    n_channel: 128
    n_layers: 3
    n_neighbors: 32
    noise_cond_mult: 4
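
app.py reads these YAML files through utils.load_config, whose source is not part of this commit. As a rough stand-in sketch of how such a nested config could be loaded into attribute-accessible namespaces (an assumption about its behavior, not the project's actual loader):

import yaml
from types import SimpleNamespace

def load_config_sketch(path):
    # Recursively wrap nested dicts so cfg.model.struct_model.n_atoms works.
    def wrap(node):
        if isinstance(node, dict):
            return SimpleNamespace(**{k: wrap(v) for k, v in node.items()})
        return node
    with open(path) as f:
        return wrap(yaml.safe_load(f))

cfg = load_config_sketch("checkpoints/allatom.yml")
print(cfg.model.task, cfg.data.fixed_size)  # allatom 512
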
checkpoints/allatom_state_dict.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c854ce05b3b1b28c45f58ebf6e5cfba5a45b389ea2aa58a6ce25649d90da238f
size 87550006
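
The .pth entries in this commit are Git LFS pointer files like the one above, not the tensors themselves; torch.load on an unfetched pointer fails confusingly. A small guard one might use before loading (illustrative, not part of the commit):

def looks_like_lfs_pointer(path):
    # LFS pointers are tiny text files starting with the spec line shown above.
    with open(path, "rb") as f:
        head = f.read(64)
    return head.startswith(b"version https://git-lfs.github.com/spec/v1")

if looks_like_lfs_pointer("checkpoints/allatom_state_dict.pth"):
    raise RuntimeError("Run `git lfs pull` to fetch the real checkpoint first.")
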
checkpoints/backbone.yml ADDED
@@ -0,0 +1,69 @@
train:
  home_dir: '/home/duerr/phd/08_Code/protpardelle-final'
  seed: 0
  checkpoint: ['', 0]
  batch_size: 32
  max_epochs: 10000
  eval_freq: 7200 # seconds
  checkpoint_freq: 50
  checkpoints: []
  lr: 0.0001
  warmup_steps: 1000
  decay_steps: 2_000_000
  clip_grad_norm: True
  grad_clip_val: 1.0
  weight_decay: 0.0
  n_eval_samples: 8
  sample_length_range: [50, 512]
  sc_num_seqs: 4
  eval_loss_t: [0.1, 0.3, 0.5, 0.7, 0.9]
  self_cond_train_prob: 0.9
  subsample_eval_set: 0.05
  crop_conditional: False

data:
  pdb_path: 'datasets/ingraham_cath_dataset'
  fixed_size: 384
  n_aatype_tokens: 21
  se3_data_augment: True
  sigma_data: 10.0

diffusion:
  training:
    function: 'lognormal'
    psigma_mean: -1.2
    psigma_std: 1.2
  sampling:
    function: 'uniform'
    s_min: 0.001
    s_max: 80

model:
  task: 'backbone' # 'backbone', 'allatom', 'seqdes', 'codesign'
  pretrained_modules: [] # 'struct_model', 'mpnn_model'
  struct_model_checkpoint: ''
  mpnn_model_checkpoint: ''
  crop_conditional: False
  dummy_fill_masked_atoms: False
  struct_model:
    arch: 'uvit'
    n_atoms: 37 # keep same shapes, just zero out sidechains
    n_channel: 256
    noise_cond_mult: 4
    uvit:
      patch_size: 1
      n_layers: 6
      n_heads: 8
      dim_head: 32
      n_filt_per_layer: []
      n_blocks_per_layer: 2
      cat_pwd_to_conv: False
      conv_skip_connection: False
      position_embedding_type: 'absolute_residx'
  mpnn_model:
    use_self_conditioning: True
    label_smoothing: 0.1
    n_channel: 128
    n_layers: 3
    n_neighbors: 32
    noise_cond_mult: 4
checkpoints/backbone_state_dict.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2bcbdcca2419beb8f07cc1d43ee4d8c53d7e4ce21b4a144b88218af00ed3b2b9
size 87548437
checkpoints/minimpnn_state_dict.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:86be202225b3769976ef9bcec75029f4352d670d0107db560eec3d35eeacca9f
size 100570633
configs/allatom.yml ADDED
@@ -0,0 +1,69 @@
train:
  home_dir: '/home/duerr/phd/08_Code/protpardelle-final'
  seed: 0
  checkpoint: ['', 0]
  batch_size: 32
  max_epochs: 10000
  eval_freq: 7200 # seconds
  checkpoint_freq: 50
  checkpoints: []
  lr: 0.0001
  warmup_steps: 1000
  decay_steps: 2_000_000
  clip_grad_norm: True
  grad_clip_val: 1.0
  weight_decay: 0.0
  n_eval_samples: 8
  sample_length_range: [50, 512]
  sc_num_seqs: 4
  eval_loss_t: [0.1, 0.3, 0.5, 0.7, 0.9]
  self_cond_train_prob: 0.9
  subsample_eval_set: 0.05
  crop_conditional: False

data:
  pdb_path: 'datasets/ingraham_cath_dataset'
  fixed_size: 512
  n_aatype_tokens: 21
  se3_data_augment: True
  sigma_data: 10.0

diffusion:
  training:
    function: 'lognormal'
    psigma_mean: -1.0
    psigma_std: 1.5
  sampling:
    function: 'uniform'
    s_min: 0.001
    s_max: 80

model:
  task: 'allatom' # 'backbone', 'allatom', 'seqdes', 'codesign'
  pretrained_modules: [] # 'struct_model', 'mpnn_model'
  struct_model_checkpoint: ''
  mpnn_model_checkpoint: ''
  crop_conditional: False
  dummy_fill_masked_atoms: False
  struct_model:
    arch: 'uvit'
    n_atoms: 37
    n_channel: 256
    noise_cond_mult: 4
    uvit:
      patch_size: 1
      n_layers: 6
      n_heads: 8
      dim_head: 32
      n_filt_per_layer: []
      n_blocks_per_layer: 2
      cat_pwd_to_conv: False
      conv_skip_connection: False
      position_embedding_type: 'rotary'
  mpnn_model:
    use_self_conditioning: True
    label_smoothing: 0.1
    n_channel: 128
    n_layers: 3
    n_neighbors: 32
    noise_cond_mult: 4
configs/backbone.yml ADDED
@@ -0,0 +1,69 @@
train:
  home_dir: '/scratch/users/alexechu'
  seed: 0
  checkpoint: ['', 0]
  batch_size: 32
  max_epochs: 10000
  eval_freq: 7200 # seconds
  checkpoint_freq: 50
  checkpoints: []
  lr: 0.0001
  warmup_steps: 1000
  decay_steps: 2_000_000
  clip_grad_norm: True
  grad_clip_val: 1.0
  weight_decay: 0.0
  n_eval_samples: 8
  sample_length_range: [50, 512]
  sc_num_seqs: 4
  eval_loss_t: [0.1, 0.3, 0.5, 0.7, 0.9]
  self_cond_train_prob: 0.9
  subsample_eval_set: 0.05
  crop_conditional: False

data:
  pdb_path: 'datasets/ingraham_cath_dataset'
  fixed_size: 384
  n_aatype_tokens: 21
  se3_data_augment: True
  sigma_data: 10.0

diffusion:
  training:
    function: 'lognormal'
    psigma_mean: -1.2
    psigma_std: 1.2
  sampling:
    function: 'uniform'
    s_min: 0.001
    s_max: 80

model:
  task: 'backbone' # 'backbone', 'allatom', 'seqdes', 'codesign'
  pretrained_modules: [] # 'struct_model', 'mpnn_model'
  struct_model_checkpoint: ''
  mpnn_model_checkpoint: ''
  crop_conditional: False
  dummy_fill_masked_atoms: False
  struct_model:
    arch: 'uvit'
    n_atoms: 37 # keep same shapes, just zero out sidechains
    n_channel: 256
    noise_cond_mult: 4
    uvit:
      patch_size: 1
      n_layers: 6
      n_heads: 8
      dim_head: 32
      n_filt_per_layer: []
      n_blocks_per_layer: 2
      cat_pwd_to_conv: False
      conv_skip_connection: False
      position_embedding_type: 'absolute_residx'
  mpnn_model:
    use_self_conditioning: True
    label_smoothing: 0.1
    n_channel: 128
    n_layers: 3
    n_neighbors: 32
    noise_cond_mult: 4
configs/seqdes.yml ADDED
@@ -0,0 +1,74 @@
train:
  home_dir: '/scratch/users/alexechu'
  seed: 0
  checkpoint: ['', 0]
  batch_size: 32
  max_epochs: 10000
  eval_freq: 3600 # seconds
  checkpoint_freq: 20
  checkpoints: []
  lr: 0.0001
  warmup_steps: 1000
  decay_steps: 400_000
  clip_grad_norm: True
  grad_clip_val: 1.0
  weight_decay: 0.0
  n_eval_samples: 8
  sample_length_range: [50, 512]
  sc_num_seqs: 4
  eval_loss_t: [0.1, 0.3, 0.5, 0.7, 0.9]
  self_cond_train_prob: 0.9
  dgram_loss_weight: False
  subsample_eval_set: 0.1
  crop_conditional: False

data:
  pdb_path: 'datasets/ingraham_cath_dataset'
  fixed_size: 512
  n_aatype_tokens: 21
  se3_data_augment: True
  sigma_data: 10.0

diffusion:
  training:
    function: 'mpnn'
    psigma_mean: -1.2
    psigma_std: 1.2
    time_power: 30.0
    constant_val: 0.02
  sampling:
    function: 'uniform'
    s_min: 0.001
    s_max: 60

model:
  task: 'seqdes' # 'backbone', 'allatom', 'seqdes', 'codesign'
  pretrained_modules: ['struct_model'] # 'struct_model', 'mpnn_model'
  struct_model_checkpoint: 'protpardelle/checkpoints/allatom_state_dict.pth'
  mpnn_model_checkpoint: ''
  crop_conditional: False
  dummy_fill_masked_atoms: False
  debug_mpnn: True
  struct_model:
    arch: 'uvit'
    n_channel: 256
    n_atoms: 37
    noise_cond_mult: 4
    uvit:
      patch_size: 1
      n_layers: 6
      n_heads: 8
      dim_head: 32
      n_filt_per_layer: [] # None or [] for vanilla trf
      n_blocks_per_layer: 2
      cat_pwd_to_conv: False
      conv_skip_connection: False # n layers must == 1
      position_embedding_type: 'rotary'
  mpnn_model:
    use_self_conditioning: False
    label_smoothing: 0.0
    n_channel: 128
    n_layers: 3
    n_neighbors: 32
    noise_cond_mult: 4
core/__init__.py ADDED
File without changes
core/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (169 Bytes).
core/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (169 Bytes).
core/__pycache__/data.cpython-38.pyc ADDED
Binary file (6.74 kB).
core/__pycache__/data.cpython-39.pyc ADDED
Binary file (6.66 kB).
core/__pycache__/protein.cpython-38.pyc ADDED
Binary file (7.97 kB).
core/__pycache__/protein.cpython-39.pyc ADDED
Binary file (7.94 kB).
core/__pycache__/protein_mpnn.cpython-38.pyc ADDED
Binary file (53.5 kB).
core/__pycache__/protein_mpnn.cpython-39.pyc ADDED
Binary file (53.3 kB).
core/__pycache__/residue_constants.cpython-38.pyc ADDED
Binary file (21.2 kB).
core/__pycache__/residue_constants.cpython-39.pyc ADDED
Binary file (24 kB).
core/__pycache__/utils.cpython-38.pyc ADDED
Binary file (30.3 kB).
core/__pycache__/utils.cpython-39.pyc ADDED
Binary file (30.1 kB).
core/data.py ADDED
@@ -0,0 +1,271 @@
"""
https://github.com/ProteinDesignLab/protpardelle
License: MIT
Author: Alex Chu

Dataloader from PDB files.
"""
import copy
import pickle
import json
import numpy as np
import torch
from einops import rearrange, repeat  # used by get_masked_coords_array and recentering below
from torch.utils import data

from core import utils
from core import protein
from core import residue_constants


FEATURES_1D = (
    "coords_in",
    "torsions_in",
    "b_factors",
    "atom_positions",
    "aatype",
    "atom_mask",
    "residue_index",
    "chain_index",
)
FEATURES_FLOAT = (
    "coords_in",
    "torsions_in",
    "b_factors",
    "atom_positions",
    "atom_mask",
    "seq_mask",
)
FEATURES_LONG = ("aatype", "residue_index", "chain_index", "orig_size")


def make_fixed_size_1d(data, fixed_size=128):
    """Crop (random window) or zero-pad along the first axis to fixed_size."""
    data_len = data.shape[0]
    if data_len >= fixed_size:
        extra_len = data_len - fixed_size
        start_idx = np.random.choice(np.arange(extra_len + 1))
        new_data = data[start_idx : (start_idx + fixed_size)]
        mask = torch.ones(fixed_size)
    if data_len < fixed_size:
        pad_size = fixed_size - data_len
        extra_shape = data.shape[1:]
        new_data = torch.cat([data, torch.zeros(pad_size, *extra_shape)], 0)
        mask = torch.cat([torch.ones(data_len), torch.zeros(pad_size)], 0)
    return new_data, mask

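Both branches in a quick sketch (illustrative; assumes make_fixed_size_1d above):

import torch

x = torch.arange(10).float().unsqueeze(-1)         # (10, 1) toy feature
padded, pad_mask = make_fixed_size_1d(x, fixed_size=16)
print(padded.shape, pad_mask.sum())                # torch.Size([16, 1]) tensor(10.)
cropped, crop_mask = make_fixed_size_1d(x, fixed_size=4)
print(cropped.shape, crop_mask.sum())              # torch.Size([4, 1]) tensor(4.)
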
def apply_random_se3(coords_in, atom_mask=None, translation_scale=1.0):
    # Unbatched. Center on the mean of CA coords, then apply a random rotation
    # and a random translation.
    coords_mean = coords_in[:, 1:2].mean(-3, keepdim=True)
    coords_in -= coords_mean
    random_rot, _ = torch.linalg.qr(torch.randn(3, 3))
    coords_in = coords_in @ random_rot
    random_trans = torch.randn_like(coords_mean) * translation_scale
    coords_in += random_trans
    if atom_mask is not None:
        coords_in = coords_in * atom_mask[..., None]
    return coords_in

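The rotation above comes from a QR factorization of a Gaussian matrix, so it is orthonormal (its determinant may be -1, i.e. a reflection); either way pairwise distances are preserved, which this sketch checks (illustrative; assumes apply_random_se3 above):

import torch

coords = torch.randn(20, 37, 3)             # toy (n_res, n_atoms, 3) coordinates
moved = apply_random_se3(coords.clone(), translation_scale=1.0)
d0 = torch.cdist(coords[:, 1], coords[:, 1])  # CA-CA distances before
d1 = torch.cdist(moved[:, 1], moved[:, 1])    # and after the transform
print(torch.allclose(d0, d1, atol=1e-4))      # True
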
def get_masked_coords_array(coords, atom_mask):
    ma_mask = repeat(1 - atom_mask[..., None].cpu().numpy(), "... 1 -> ... 3")
    return np.ma.array(coords.cpu().numpy(), mask=ma_mask)


def make_crop_cond_mask_and_recenter_coords(
    atom_mask,
    atom_coords,
    contiguous_prob=0.05,
    discontiguous_prob=0.9,
    sidechain_only_prob=0.8,
    max_span_len=10,
    max_discontiguous_res=8,
    dist_threshold=8.0,
    recenter_coords=True,
):
    b, n, a = atom_mask.shape
    device = atom_mask.device
    seq_mask = atom_mask[..., 1]
    n_res = seq_mask.sum(-1)
    masks = []

    for i, nr in enumerate(n_res):
        nr = nr.int().item()
        mask = torch.zeros((n, a), device=device)
        conditioning_type = torch.distributions.Categorical(
            torch.tensor(
                [
                    contiguous_prob,
                    discontiguous_prob,
                    1.0 - contiguous_prob - discontiguous_prob,
                ]
            )
        ).sample()
        conditioning_type = ["contiguous", "discontiguous", "none"][conditioning_type]

        if conditioning_type == "contiguous":
            span_len = torch.randint(1, min(max_span_len, nr), (1,), device=device).item()
            span_start = torch.randint(0, nr - span_len, (1,), device=device)
            mask[span_start : span_start + span_len, :] = 1
        elif conditioning_type == "discontiguous":
            # Extract CB atom coordinates for the i-th example
            cb_atoms = atom_coords[i, :, 3]
            # Pairwise distances between CB atoms
            cb_distances = torch.cdist(cb_atoms, cb_atoms)
            close_mask = cb_distances <= dist_threshold  # mask for selecting close CB atoms

            random_residue = torch.randint(0, nr, (1,), device=device).squeeze()
            cb_dist_i = cb_distances[random_residue] + 1e3 * (1 - seq_mask[i])
            close_mask = cb_dist_i <= dist_threshold
            n_neighbors = close_mask.sum().int()

            # Pick how many neighbors to keep (up to max_discontiguous_res)
            n_sele = torch.randint(
                2,
                n_neighbors.clamp(min=3, max=max_discontiguous_res + 1),
                (1,),
                device=device,
            )

            # Select the indices of CB atoms that are close together
            idxs = torch.arange(n, device=device)[close_mask.bool()]
            idxs = idxs[torch.randperm(len(idxs))[:n_sele]]

            if len(idxs) > 0:
                mask[idxs] = 1

        if np.random.uniform() < sidechain_only_prob:
            mask[:, :5] = 0

        masks.append(mask)

    crop_cond_mask = torch.stack(masks)
    crop_cond_mask = crop_cond_mask * atom_mask
    if recenter_coords:
        motif_masked_array = get_masked_coords_array(atom_coords, crop_cond_mask)
        cond_coords_center = motif_masked_array.mean((1, 2))
        motif_mask = torch.Tensor(1 - cond_coords_center.mask).to(crop_cond_mask)
        means = torch.Tensor(cond_coords_center.data).to(atom_coords) * motif_mask
        coords_out = atom_coords - rearrange(means, "b c -> b 1 1 c")
    else:
        coords_out = atom_coords
    return coords_out, crop_cond_mask


class Dataset(data.Dataset):
    """Loads and processes PDBs into tensors."""

    def __init__(
        self,
        pdb_path,
        fixed_size,
        mode="train",
        overfit=-1,
        short_epoch=False,
        se3_data_augment=True,
    ):
        self.pdb_path = pdb_path
        self.fixed_size = fixed_size
        self.mode = mode
        self.overfit = overfit
        self.short_epoch = short_epoch
        self.se3_data_augment = se3_data_augment

        with open(f"{self.pdb_path}/{mode}_pdb_keys.list") as f:
            self.pdb_keys = np.array(f.read().split("\n")[:-1])

        if overfit > 0:
            n_data = len(self.pdb_keys)
            self.pdb_keys = np.random.choice(
                self.pdb_keys, min(n_data, overfit), replace=False
            ).repeat(n_data // overfit)

    def __len__(self):
        if self.short_epoch:
            return min(len(self.pdb_keys), 256)
        else:
            return len(self.pdb_keys)

    def __getitem__(self, idx):
        pdb_key = self.pdb_keys[idx]
        data = self.get_item(pdb_key)
        # For now, replace dataloading errors with a random pdb. 10 tries
        for _ in range(10):
            if data is not None:
                return data
            pdb_key = self.pdb_keys[np.random.randint(len(self.pdb_keys))]
            data = self.get_item(pdb_key)
        raise Exception("Failed to load data example after 10 tries.")

    def get_item(self, pdb_key):
        example = {}

        if self.pdb_path.endswith("cath_s40_dataset"):  # CATH pdbs
            data_file = f"{self.pdb_path}/dompdb/{pdb_key}"
        elif self.pdb_path.endswith("ingraham_cath_dataset"):  # Ingraham splits
            data_file = f"{self.pdb_path}/pdb_store/{pdb_key}"
        else:
            raise Exception("Invalid pdb path.")

        try:
            example = utils.load_feats_from_pdb(data_file)
            coords_in = example["atom_positions"]
        except FileNotFoundError:
            raise Exception(f"File {pdb_key} not found. Check if dataset is corrupted?")
        except RuntimeError:
            return None

        # Apply data augmentation
        if self.se3_data_augment:
            coords_in = apply_random_se3(coords_in, atom_mask=example["atom_mask"])

        orig_size = coords_in.shape[0]
        example["coords_in"] = coords_in
        example["orig_size"] = torch.ones(1) * orig_size

        fixed_size_example = {}
        seq_mask = None
        for k, v in example.items():
            if k in FEATURES_1D:
                fixed_size_example[k], seq_mask = make_fixed_size_1d(
                    v, fixed_size=self.fixed_size
                )
            else:
                fixed_size_example[k] = v
        if seq_mask is not None:
            fixed_size_example["seq_mask"] = seq_mask

        example_out = {}
        for k, v in fixed_size_example.items():
            if k in FEATURES_FLOAT:
                example_out[k] = v.float()
            elif k in FEATURES_LONG:
                example_out[k] = v.long()

        return example_out

    def collate(self, example_list):
        out = {}
        for ex in example_list:
            for k, v in ex.items():
                out.setdefault(k, []).append(v)
        return {k: torch.stack(v) for k, v in out.items()}

    def sample(self, n=1, return_data=True, return_keys=False):
        keys = self.pdb_keys[torch.randperm(self.__len__())[:n].long()]

        if return_keys and not return_data:
            return keys

        if n == 1:
            data = self.collate([self.get_item(keys)])
        else:
            data = self.collate([self.get_item(key) for key in keys])

        if return_data and return_keys:
            return data, keys
        if return_data and not return_keys:
            return data
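
For context, a sketch of wiring this Dataset into a PyTorch DataLoader with its own collate; the dataset path is illustrative and must end in one of the suffixes checked in get_item:

from torch.utils.data import DataLoader

dataset = Dataset("datasets/ingraham_cath_dataset", fixed_size=384, mode="train")
loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=dataset.collate)
batch = next(iter(loader))
print(batch["coords_in"].shape)  # (32, 384, 37, 3)
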
core/protein.py ADDED
@@ -0,0 +1,341 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Protein data type.
Adapted from original code by alexechu.
"""
import dataclasses
import io
from typing import Any, Mapping, Optional
from core import residue_constants
from Bio.PDB import PDBParser
import numpy as np

FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any]  # Is a nested dict.

# Complete sequence of chain IDs supported by the PDB format.
PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS)  # := 62.


@dataclasses.dataclass(frozen=True)
class Protein:
    """Protein structure representation."""

    # Cartesian coordinates of atoms in angstroms. The atom types correspond to
    # residue_constants.atom_types, i.e. the first three are N, CA, C.
    atom_positions: np.ndarray  # [num_res, num_atom_type, 3]

    # Amino-acid type for each residue represented as an integer between 0 and
    # 20, where 20 is 'X'.
    aatype: np.ndarray  # [num_res]

    # Binary float mask to indicate presence of a particular atom. 1.0 if an atom
    # is present and 0.0 if not. This should be used for loss masking.
    atom_mask: np.ndarray  # [num_res, num_atom_type]

    # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
    residue_index: np.ndarray  # [num_res]

    # 0-indexed number corresponding to the chain in the protein that this residue
    # belongs to.
    chain_index: np.ndarray  # [num_res]

    # B-factors, or temperature factors, of each residue (in sq. angstroms units),
    # representing the displacement of the residue from its ground truth mean
    # value.
    b_factors: np.ndarray  # [num_res, num_atom_type]

    def __post_init__(self):
        if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS:
            raise ValueError(
                f"Cannot build an instance with more than {PDB_MAX_CHAINS} chains "
                "because these cannot be written to PDB format."
            )


def from_pdb_string(
    pdb_str: str, chain_id: Optional[str] = None, protein_only: bool = False
) -> Protein:
    """Takes a PDB string and constructs a Protein object.

    WARNING: All non-standard residue types will be converted into UNK. All
    non-standard atoms will be ignored.

    Args:
      pdb_str: The contents of the pdb file
      chain_id: If chain_id is specified (e.g. A), then only that chain
        is parsed. Otherwise all chains are parsed.

    Returns:
      A new `Protein` parsed from the pdb contents.
    """
    pdb_fh = io.StringIO(pdb_str)
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("none", pdb_fh)
    models = list(structure.get_models())
    if len(models) != 1:
        raise ValueError(
            f"Only single model PDBs are supported. Found {len(models)} models."
        )
    model = models[0]

    atom_positions = []
    aatype = []
    atom_mask = []
    residue_index = []
    chain_ids = []
    b_factors = []

    for chain in model:
        if chain_id is not None and chain.id != chain_id:
            continue
        for res in chain:
            if protein_only and res.id[0] != " ":
                continue
            if res.id[2] != " ":
                # Insertion codes are tolerated here rather than rejected.
                pass
                # raise ValueError(
                #     f"PDB contains an insertion code at chain {chain.id} and residue "
                #     f"index {res.id[1]}. These are not supported."
                # )
            res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
            restype_idx = residue_constants.restype_order.get(
                res_shortname, residue_constants.restype_num
            )
            pos = np.zeros((residue_constants.atom_type_num, 3))
            mask = np.zeros((residue_constants.atom_type_num,))
            res_b_factors = np.zeros((residue_constants.atom_type_num,))
            for atom in res:
                if atom.name not in residue_constants.atom_types:
                    continue
                pos[residue_constants.atom_order[atom.name]] = atom.coord
                mask[residue_constants.atom_order[atom.name]] = 1.0
                res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
            if np.sum(mask) < 0.5:
                # If no known atom positions are reported for the residue then skip it.
                continue
            aatype.append(restype_idx)
            atom_positions.append(pos)
            atom_mask.append(mask)
            residue_index.append(res.id[1])
            chain_ids.append(chain.id)
            b_factors.append(res_b_factors)

    # Chain IDs are usually characters so map these to ints.
    unique_chain_ids = np.unique(chain_ids)
    chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
    chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])

    return Protein(
        atom_positions=np.array(atom_positions),
        atom_mask=np.array(atom_mask),
        aatype=np.array(aatype),
        residue_index=np.array(residue_index),
        chain_index=chain_index,
        b_factors=np.array(b_factors),
    )

def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
    chain_end = "TER"
    return (
        f"{chain_end:<6}{atom_index:>5}      {end_resname:>3} "
        f"{chain_name:>1}{residue_index:>4}"
    )


def are_atoms_bonded(res3name, atom1_name, atom2_name):
    lookup_table = residue_constants.standard_residue_bonds
    for bond in lookup_table[res3name]:
        if bond.atom1_name == atom1_name and bond.atom2_name == atom2_name:
            return True
        elif bond.atom1_name == atom2_name and bond.atom2_name == atom1_name:
            return True
    return False


def to_pdb(prot: Protein, conect=False) -> str:
    """Converts a `Protein` instance to a PDB string.

    Args:
      prot: The protein to convert to PDB.

    Returns:
      PDB string.
    """
    restypes = residue_constants.restypes + ["X"]
    res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], "UNK")
    atom_types = residue_constants.atom_types

    pdb_lines = []

    atom_mask = prot.atom_mask
    aatype = prot.aatype
    atom_positions = prot.atom_positions
    residue_index = prot.residue_index.astype(np.int32)
    chain_index = prot.chain_index.astype(np.int32)
    b_factors = prot.b_factors

    if np.any(aatype > residue_constants.restype_num):
        raise ValueError("Invalid aatypes.")

    # Construct a mapping from chain integer indices to chain ID strings.
    chain_ids = {}
    for i in np.unique(chain_index):  # np.unique gives sorted output.
        if i >= PDB_MAX_CHAINS:
            raise ValueError(
                f"The PDB format supports at most {PDB_MAX_CHAINS} chains."
            )
        chain_ids[i] = PDB_CHAIN_IDS[i]

    pdb_lines.append("MODEL     1")
    atom_index = 1
    last_chain_index = chain_index[0]
    conect_lines = []
    # Add all atom sites.
    for i in range(aatype.shape[0]):
        # Close the previous chain if in a multichain PDB.
        if last_chain_index != chain_index[i]:
            pdb_lines.append(
                _chain_end(
                    atom_index,
                    res_1to3(aatype[i - 1]),
                    chain_ids[chain_index[i - 1]],
                    residue_index[i - 1],
                )
            )
            last_chain_index = chain_index[i]
            atom_index += 1  # Atom index increases at the TER symbol.

        res_name_3 = res_1to3(aatype[i])
        atoms_appended_for_res = []
        for atom_name, pos, mask, b_factor in zip(
            atom_types, atom_positions[i], atom_mask[i], b_factors[i]
        ):
            if mask < 0.5:
                continue

            record_type = "ATOM"
            name = atom_name if len(atom_name) == 4 else f" {atom_name}"
            alt_loc = ""
            insertion_code = ""
            occupancy = 1.00
            element = atom_name[0]  # Protein supports only C, N, O, S, this works.
            charge = ""
            # PDB is a columnar format, every space matters here!
            atom_line = (
                f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
                f"{res_name_3:>3} {chain_ids[chain_index[i]]:>1}"
                f"{residue_index[i]:>4}{insertion_code:>1}   "
                f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
                f"{occupancy:>6.2f}{b_factor:>6.2f}          "
                f"{element:>2}{charge:>2}"
            )
            pdb_lines.append(atom_line)

            # Record intra-residue bonds for optional CONECT output.
            for prev_atom_idx, prev_atom in atoms_appended_for_res:
                if are_atoms_bonded(res_name_3, atom_name, prev_atom):
                    conect_line = f"CONECT{prev_atom_idx:5d}{atom_index:5d}\n"
                    conect_lines.append(conect_line)
            atoms_appended_for_res.append((atom_index, atom_name))
            # Track backbone N and C serials for the inter-residue peptide bond.
            if atom_name == "N":
                n_atom_idx = atom_index
            if atom_name == "C":
                c_atom_idx = atom_index

            atom_index += 1

        if i > 0:
            conect_line = f"CONECT{prev_c_atom_idx:5d}{n_atom_idx:5d}\n"
            conect_lines.append(conect_line)
        prev_c_atom_idx = c_atom_idx

    # Close the final chain.
    pdb_lines.append(
        _chain_end(
            atom_index,
            res_1to3(aatype[-1]),
            chain_ids[chain_index[-1]],
            residue_index[-1],
        )
    )
    pdb_lines.append("ENDMDL")
    pdb_lines.append("END")

    # Pad all lines to 80 characters.
    pdb_lines = [line.ljust(80) for line in pdb_lines]
    pdb_str = "\n".join(pdb_lines) + "\n"  # Add terminating newline.
    if conect:
        conect_str = "".join(conect_lines) + "\n"
        return pdb_str, conect_str
    return pdb_str

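A round-trip sketch for the two converters above (file names are illustrative; assumes the functions are importable from core.protein):

with open("design.pdb") as f:
    prot = from_pdb_string(f.read(), protein_only=True)
print(prot.atom_positions.shape)  # (num_res, 37, 3)

pdb_str, conect_str = to_pdb(prot, conect=True)
with open("design_rt.pdb", "w") as f:
    f.write(pdb_str)
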
def ideal_atom_mask(prot: Protein) -> np.ndarray:
    """Computes an ideal atom mask.

    `Protein.atom_mask` typically is defined according to the atoms that are
    reported in the PDB. This function computes a mask according to heavy atoms
    that should be present in the given sequence of amino acids.

    Args:
      prot: `Protein` whose fields are `numpy.ndarray` objects.

    Returns:
      An ideal atom mask.
    """
    return residue_constants.STANDARD_ATOM_MASK[prot.aatype]


def from_prediction(
    features: FeatureDict,
    result: ModelOutput,
    b_factors: Optional[np.ndarray] = None,
    remove_leading_feature_dimension: bool = True,
) -> Protein:
    """Assembles a protein from a prediction.

    Args:
      features: Dictionary holding model inputs.
      result: Dictionary holding model outputs.
      b_factors: (Optional) B-factors to use for the protein.
      remove_leading_feature_dimension: Whether to remove the leading dimension
        of the `features` values.

    Returns:
      A protein instance.
    """
    fold_output = result["structure_module"]

    def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray:
        return arr[0] if remove_leading_feature_dimension else arr

    if "asym_id" in features:
        chain_index = _maybe_remove_leading_dim(features["asym_id"])
    else:
        chain_index = np.zeros_like(_maybe_remove_leading_dim(features["aatype"]))

    if b_factors is None:
        b_factors = np.zeros_like(fold_output["final_atom_mask"])

    return Protein(
        aatype=_maybe_remove_leading_dim(features["aatype"]),
        atom_positions=fold_output["final_atom_positions"],
        atom_mask=fold_output["final_atom_mask"],
        residue_index=_maybe_remove_leading_dim(features["residue_index"]) + 1,
        chain_index=chain_index,
        b_factors=b_factors,
    )
core/protein_mpnn.py ADDED
@@ -0,0 +1,1886 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2022 Justas Dauparas
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ '''
24
+ Adapted from original code by alexechu.
25
+ '''
26
+ import json, time, os, sys, glob
27
+ import shutil
28
+ import warnings
29
+ import copy
30
+ import random
31
+ import os.path
32
+ import subprocess
33
+ import itertools
34
+
35
+ from einops.layers.torch import Rearrange
36
+ import numpy as np
37
+ import torch
38
+ from torch import optim
39
+ from torch.utils.data import DataLoader
40
+ from torch.utils.data.dataset import random_split, Subset
41
+ import torch.nn as nn
42
+ import torch.nn.functional as F
43
+
44
+
45
+ def get_mpnn_model(model_name='v_48_020', path_to_model_weights='', ca_only=False, backbone_noise=0.0, verbose=False, device=None):
46
+ hidden_dim = 128
47
+ num_layers = 3
48
+ if device is None:
49
+ device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
50
+
51
+ if path_to_model_weights:
52
+ model_folder_path = path_to_model_weights
53
+ if model_folder_path[-1] != '/':
54
+ model_folder_path = model_folder_path + '/'
55
+ else:
56
+ file_path = os.path.realpath(__file__)
57
+ k = file_path.rfind("/")
58
+ if ca_only:
59
+ model_folder_path = file_path[:k] + '/ca_model_weights/'
60
+ else:
61
+ model_folder_path = file_path[:k] + '/vanilla_model_weights/'
62
+
63
+ checkpoint_path = model_folder_path + f'{model_name}.pt'
64
+ checkpoint = torch.load(checkpoint_path, map_location=device)
65
+ noise_level_print = checkpoint['noise_level']
66
+ model = ProteinMPNN(ca_only=ca_only, num_letters=21, node_features=hidden_dim, edge_features=hidden_dim, hidden_dim=hidden_dim,
67
+ num_encoder_layers=num_layers, num_decoder_layers=num_layers, augment_eps=backbone_noise, k_neighbors=checkpoint['num_edges'])
68
+ model.to(device)
69
+ model.load_state_dict(checkpoint['model_state_dict'])
70
+ model.eval()
71
+
72
+ if verbose:
73
+ print(40*'-')
74
+ print('Model loaded...')
75
+ print('Number of edges:', checkpoint['num_edges'])
76
+ print(f'Training noise level: {noise_level_print}A')
77
+
78
+ return model
79
+
80
+
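A minimal usage sketch, assuming the weight folders the path logic above expects (vanilla_model_weights/ or ca_model_weights/) sit next to this file:

from core.protein_mpnn import get_mpnn_model  # assumed import path for this module

model = get_mpnn_model(model_name='v_48_020', backbone_noise=0.0, verbose=True)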
81
+ def run_proteinmpnn(model=None, pdb_path='', pdb_path_chains='', path_to_model_weights='', model_name='v_48_020', seed=0, ca_only=False, out_folder='', num_seq_per_target=1, batch_size=1, sampling_temps=[0.1], backbone_noise=0.0, max_length=200000, omit_AAs=[], print_all=False,
82
+ chain_id_jsonl='', fixed_positions_jsonl='', pssm_jsonl='', omit_AA_jsonl='', bias_AA_jsonl='', tied_positions_jsonl='', bias_by_res_jsonl='', jsonl_path='',
83
+ pssm_threshold=0.0, pssm_multi=0.0, pssm_log_odds_flag=False, pssm_bias_flag=False, write_output_files=False):
84
+
85
+ if model is None:
86
+ model = get_mpnn_model(model_name=model_name, path_to_model_weights=path_to_model_weights, ca_only=ca_only, backbone_noise=backbone_noise, verbose=print_all)
87
+
88
+ if seed:
89
+ pass # keep the caller-provided seed
90
+ else:
91
+ seed=int(np.random.randint(0, high=999, size=1, dtype=int)[0])
92
+
93
+ torch.manual_seed(seed)
94
+ random.seed(seed)
95
+ np.random.seed(seed)
96
+
97
+
98
+
99
+ NUM_BATCHES = num_seq_per_target//batch_size
100
+ BATCH_COPIES = batch_size
101
+ temperatures = sampling_temps
102
+ omit_AAs_list = omit_AAs
103
+ alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
104
+ alphabet_dict = dict(zip(alphabet, range(21)))
105
+ omit_AAs_np = np.array([AA in omit_AAs_list for AA in alphabet]).astype(np.float32)
106
+ device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
107
+ if os.path.isfile(chain_id_jsonl):
108
+ with open(chain_id_jsonl, 'r') as json_file:
109
+ json_list = list(json_file)
110
+ for json_str in json_list:
111
+ chain_id_dict = json.loads(json_str)
112
+ else:
113
+ chain_id_dict = None
114
+ if print_all:
115
+ print(40*'-')
116
+ print('chain_id_jsonl is NOT loaded')
117
+
118
+ if os.path.isfile(fixed_positions_jsonl):
119
+ with open(fixed_positions_jsonl, 'r') as json_file:
120
+ json_list = list(json_file)
121
+ for json_str in json_list:
122
+ fixed_positions_dict = json.loads(json_str)
123
+ else:
124
+ if print_all:
125
+ print(40*'-')
126
+ print('fixed_positions_jsonl is NOT loaded')
127
+ fixed_positions_dict = None
128
+
129
+
130
+ if os.path.isfile(pssm_jsonl):
131
+ with open(pssm_jsonl, 'r') as json_file:
132
+ json_list = list(json_file)
133
+ pssm_dict = {}
134
+ for json_str in json_list:
135
+ pssm_dict.update(json.loads(json_str))
136
+ else:
137
+ if print_all:
138
+ print(40*'-')
139
+ print('pssm_jsonl is NOT loaded')
140
+ pssm_dict = None
141
+
142
+
143
+ if os.path.isfile(omit_AA_jsonl):
144
+ with open(omit_AA_jsonl, 'r') as json_file:
145
+ json_list = list(json_file)
146
+ for json_str in json_list:
147
+ omit_AA_dict = json.loads(json_str)
148
+ else:
149
+ if print_all:
150
+ print(40*'-')
151
+ print('omit_AA_jsonl is NOT loaded')
152
+ omit_AA_dict = None
153
+
154
+
155
+ if os.path.isfile(bias_AA_jsonl):
156
+ with open(bias_AA_jsonl, 'r') as json_file:
157
+ json_list = list(json_file)
158
+ for json_str in json_list:
159
+ bias_AA_dict = json.loads(json_str)
160
+ else:
161
+ if print_all:
162
+ print(40*'-')
163
+ print('bias_AA_jsonl is NOT loaded')
164
+ bias_AA_dict = None
165
+
166
+
167
+ if os.path.isfile(tied_positions_jsonl):
168
+ with open(tied_positions_jsonl, 'r') as json_file:
169
+ json_list = list(json_file)
170
+ for json_str in json_list:
171
+ tied_positions_dict = json.loads(json_str)
172
+ else:
173
+ if print_all:
174
+ print(40*'-')
175
+ print('tied_positions_jsonl is NOT loaded')
176
+ tied_positions_dict = None
177
+
178
+
179
+ if os.path.isfile(bias_by_res_jsonl):
180
+ with open(bias_by_res_jsonl, 'r') as json_file:
181
+ json_list = list(json_file)
182
+
183
+ for json_str in json_list:
184
+ bias_by_res_dict = json.loads(json_str)
185
+ if print_all:
186
+ print('bias by residue dictionary is loaded')
187
+ else:
188
+ if print_all:
189
+ print(40*'-')
190
+ print('bias by residue dictionary is not loaded, or not provided')
191
+ bias_by_res_dict = None
192
+
193
+
194
+ if print_all:
195
+ print(40*'-')
196
+ bias_AAs_np = np.zeros(len(alphabet))
197
+ if bias_AA_dict:
198
+ for n, AA in enumerate(alphabet):
199
+ if AA in list(bias_AA_dict.keys()):
200
+ bias_AAs_np[n] = bias_AA_dict[AA]
201
+
202
+ if pdb_path:
203
+ pdb_dict_list = parse_PDB(pdb_path, ca_only=ca_only)
204
+ dataset_valid = StructureDatasetPDB(pdb_dict_list, truncate=None, max_length=max_length)
205
+ all_chain_list = [item[-1:] for item in list(pdb_dict_list[0]) if item[:9]=='seq_chain'] #['A','B', 'C',...]
206
+ if pdb_path_chains:
207
+ designed_chain_list = [str(item) for item in pdb_path_chains.split()]
208
+ else:
209
+ designed_chain_list = all_chain_list
210
+ fixed_chain_list = [letter for letter in all_chain_list if letter not in designed_chain_list]
211
+ chain_id_dict = {}
212
+ chain_id_dict[pdb_dict_list[0]['name']]= (designed_chain_list, fixed_chain_list)
213
+ else:
214
+ dataset_valid = StructureDataset(jsonl_path, truncate=None, max_length=max_length, verbose=print_all)
215
+
216
+ # Build paths for experiment
217
+ if write_output_files:
218
+ folder_for_outputs = out_folder
219
+ base_folder = folder_for_outputs
220
+ if base_folder[-1] != '/':
221
+ base_folder = base_folder + '/'
222
+ if not os.path.exists(base_folder):
223
+ os.makedirs(base_folder)
224
+ if not os.path.exists(base_folder + 'seqs'):
225
+ os.makedirs(base_folder + 'seqs')
226
+
227
+ # if args.save_score:
228
+ # if not os.path.exists(base_folder + 'scores'):
229
+ # os.makedirs(base_folder + 'scores')
230
+
231
+ # if args.score_only:
232
+ # if not os.path.exists(base_folder + 'score_only'):
233
+ # os.makedirs(base_folder + 'score_only')
234
+
235
+
236
+ # if args.conditional_probs_only:
237
+ # if not os.path.exists(base_folder + 'conditional_probs_only'):
238
+ # os.makedirs(base_folder + 'conditional_probs_only')
239
+
240
+ # if args.unconditional_probs_only:
241
+ # if not os.path.exists(base_folder + 'unconditional_probs_only'):
242
+ # os.makedirs(base_folder + 'unconditional_probs_only')
243
+
244
+ # if args.save_probs:
245
+ # if not os.path.exists(base_folder + 'probs'):
246
+ # os.makedirs(base_folder + 'probs')
247
+
248
+ # Timing
249
+ start_time = time.time()
250
+ total_residues = 0
251
+ protein_list = []
252
+ total_step = 0
253
+ # Validation epoch
254
+ new_mpnn_seqs = []
255
+ with torch.no_grad():
256
+ test_sum, test_weights = 0., 0.
257
+ for ix, protein in enumerate(dataset_valid):
258
+ score_list = []
259
+ global_score_list = []
260
+ all_probs_list = []
261
+ all_log_probs_list = []
262
+ S_sample_list = []
263
+ batch_clones = [copy.deepcopy(protein) for i in range(BATCH_COPIES)]
264
+ X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list, visible_list_list, masked_list_list, masked_chain_length_list_list, chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask, tied_pos_list_of_lists_list, pssm_coef, pssm_bias, pssm_log_odds_all, bias_by_res_all, tied_beta = tied_featurize(batch_clones, device, chain_id_dict, fixed_positions_dict, omit_AA_dict, tied_positions_dict, pssm_dict, bias_by_res_dict, ca_only=ca_only)
265
+ pssm_log_odds_mask = (pssm_log_odds_all > pssm_threshold).float() #1.0 for true, 0.0 for false
266
+ name_ = batch_clones[0]['name']
267
+ if False:
268
+ pass
269
+ # if args.score_only:
270
+ # loop_c = 0
271
+ # if args.path_to_fasta:
272
+ # fasta_names, fasta_seqs = parse_fasta(args.path_to_fasta, omit=["/"])
273
+ # loop_c = len(fasta_seqs)
274
+ # for fc in range(1+loop_c):
275
+ # if fc == 0:
276
+ # structure_sequence_score_file = base_folder + '/score_only/' + batch_clones[0]['name'] + f'_pdb'
277
+ # else:
278
+ # structure_sequence_score_file = base_folder + '/score_only/' + batch_clones[0]['name'] + f'_fasta_{fc}'
279
+ # native_score_list = []
280
+ # global_native_score_list = []
281
+ # if fc > 0:
282
+ # input_seq_length = len(fasta_seqs[fc-1])
283
+ # S_input = torch.tensor([alphabet_dict[AA] for AA in fasta_seqs[fc-1]], device=device)[None,:].repeat(X.shape[0], 1)
284
+ # S[:,:input_seq_length] = S_input #assumes that S and S_input are alphabetically sorted for masked_chains
285
+ # for j in range(NUM_BATCHES):
286
+ # randn_1 = torch.randn(chain_M.shape, device=X.device)
287
+ # log_probs = model(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1)
288
+ # mask_for_loss = mask*chain_M*chain_M_pos
289
+ # scores = _scores(S, log_probs, mask_for_loss)
290
+ # native_score = scores.cpu().data.numpy()
291
+ # native_score_list.append(native_score)
292
+ # global_scores = _scores(S, log_probs, mask)
293
+ # global_native_score = global_scores.cpu().data.numpy()
294
+ # global_native_score_list.append(global_native_score)
295
+ # native_score = np.concatenate(native_score_list, 0)
296
+ # global_native_score = np.concatenate(global_native_score_list, 0)
297
+ # ns_mean = native_score.mean()
298
+ # ns_mean_print = np.format_float_positional(np.float32(ns_mean), unique=False, precision=4)
299
+ # ns_std = native_score.std()
300
+ # ns_std_print = np.format_float_positional(np.float32(ns_std), unique=False, precision=4)
301
+
302
+ # global_ns_mean = global_native_score.mean()
303
+ # global_ns_mean_print = np.format_float_positional(np.float32(global_ns_mean), unique=False, precision=4)
304
+ # global_ns_std = global_native_score.std()
305
+ # global_ns_std_print = np.format_float_positional(np.float32(global_ns_std), unique=False, precision=4)
306
+
307
+ # ns_sample_size = native_score.shape[0]
308
+ # seq_str = _S_to_seq(S[0,], chain_M[0,])
309
+ # np.savez(structure_sequence_score_file, score=native_score, global_score=global_native_score, S=S[0,].cpu().numpy(), seq_str=seq_str)
310
+ # if print_all:
311
+ # if fc == 0:
312
+ # print(f'Score for {name_} from PDB, mean: {ns_mean_print}, std: {ns_std_print}, sample size: {ns_sample_size}, global score, mean: {global_ns_mean_print}, std: {global_ns_std_print}, sample size: {ns_sample_size}')
313
+ # else:
314
+ # print(f'Score for {name_}_{fc} from FASTA, mean: {ns_mean_print}, std: {ns_std_print}, sample size: {ns_sample_size}, global score, mean: {global_ns_mean_print}, std: {global_ns_std_print}, sample size: {ns_sample_size}')
315
+ # elif args.conditional_probs_only:
316
+ # if print_all:
317
+ # print(f'Calculating conditional probabilities for {name_}')
318
+ # conditional_probs_only_file = base_folder + '/conditional_probs_only/' + batch_clones[0]['name']
319
+ # log_conditional_probs_list = []
320
+ # for j in range(NUM_BATCHES):
321
+ # randn_1 = torch.randn(chain_M.shape, device=X.device)
322
+ # log_conditional_probs = model.conditional_probs(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1, args.conditional_probs_only_backbone)
323
+ # log_conditional_probs_list.append(log_conditional_probs.cpu().numpy())
324
+ # concat_log_p = np.concatenate(log_conditional_probs_list, 0) #[B, L, 21]
325
+ # mask_out = (chain_M*chain_M_pos*mask)[0,].cpu().numpy()
326
+ # np.savez(conditional_probs_only_file, log_p=concat_log_p, S=S[0,].cpu().numpy(), mask=mask[0,].cpu().numpy(), design_mask=mask_out)
327
+ # elif args.unconditional_probs_only:
328
+ # if print_all:
329
+ # print(f'Calculating sequence unconditional probabilities for {name_}')
330
+ # unconditional_probs_only_file = base_folder + '/unconditional_probs_only/' + batch_clones[0]['name']
331
+ # log_unconditional_probs_list = []
332
+ # for j in range(NUM_BATCHES):
333
+ # log_unconditional_probs = model.unconditional_probs(X, mask, residue_idx, chain_encoding_all)
334
+ # log_unconditional_probs_list.append(log_unconditional_probs.cpu().numpy())
335
+ # concat_log_p = np.concatenate(log_unconditional_probs_list, 0) #[B, L, 21]
336
+ # mask_out = (chain_M*chain_M_pos*mask)[0,].cpu().numpy()
337
+ # np.savez(unconditional_probs_only_file, log_p=concat_log_p, S=S[0,].cpu().numpy(), mask=mask[0,].cpu().numpy(), design_mask=mask_out)
338
+ else:
339
+ randn_1 = torch.randn(chain_M.shape, device=X.device)
340
+ log_probs = model(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1)
341
+ mask_for_loss = mask*chain_M*chain_M_pos
342
+ scores = _scores(S, log_probs, mask_for_loss) #score only the redesigned part
343
+ native_score = scores.cpu().data.numpy()
344
+ global_scores = _scores(S, log_probs, mask) #score the whole structure-sequence
345
+ global_native_score = global_scores.cpu().data.numpy()
346
+ # Generate some sequences
347
+ if write_output_files:
348
+ ali_file = base_folder + '/seqs/' + batch_clones[0]['name'] + '.fa'
349
+ score_file = base_folder + '/scores/' + batch_clones[0]['name'] + '.npz'
350
+ probs_file = base_folder + '/probs/' + batch_clones[0]['name'] + '.npz'
351
+ f = open(ali_file, 'w')
352
+ if print_all:
353
+ print(f'Generating sequences for: {name_}')
354
+ t0 = time.time()
355
+ for temp in temperatures:
356
+ for j in range(NUM_BATCHES):
357
+ randn_2 = torch.randn(chain_M.shape, device=X.device)
358
+ if tied_positions_dict == None:
359
+ sample_dict = model.sample(X, randn_2, S, chain_M, chain_encoding_all, residue_idx, mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np, bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos, omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef, pssm_bias=pssm_bias, pssm_multi=pssm_multi, pssm_log_odds_flag=bool(pssm_log_odds_flag), pssm_log_odds_mask=pssm_log_odds_mask, pssm_bias_flag=bool(pssm_bias_flag), bias_by_res=bias_by_res_all)
360
+ S_sample = sample_dict["S"]
361
+ else:
362
+ sample_dict = model.tied_sample(X, randn_2, S, chain_M, chain_encoding_all, residue_idx, mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np, bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos, omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef, pssm_bias=pssm_bias, pssm_multi=pssm_multi, pssm_log_odds_flag=bool(pssm_log_odds_flag), pssm_log_odds_mask=pssm_log_odds_mask, pssm_bias_flag=bool(pssm_bias_flag), tied_pos=tied_pos_list_of_lists_list[0], tied_beta=tied_beta, bias_by_res=bias_by_res_all)
363
+ # Compute scores
364
+ S_sample = sample_dict["S"]
365
+ log_probs = model(X, S_sample, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_2, use_input_decoding_order=True, decoding_order=sample_dict["decoding_order"])
366
+ mask_for_loss = mask*chain_M*chain_M_pos
367
+ scores = _scores(S_sample, log_probs, mask_for_loss)
368
+ scores = scores.cpu().data.numpy()
369
+
370
+ global_scores = _scores(S_sample, log_probs, mask) #score the whole structure-sequence
371
+ global_scores = global_scores.cpu().data.numpy()
372
+
373
+ all_probs_list.append(sample_dict["probs"].cpu().data.numpy())
374
+ all_log_probs_list.append(log_probs.cpu().data.numpy())
375
+ S_sample_list.append(S_sample.cpu().data.numpy())
376
+ for b_ix in range(BATCH_COPIES):
377
+ masked_chain_length_list = masked_chain_length_list_list[b_ix]
378
+ masked_list = masked_list_list[b_ix]
379
+ seq_recovery_rate = torch.sum(torch.sum(torch.nn.functional.one_hot(S[b_ix], 21)*torch.nn.functional.one_hot(S_sample[b_ix], 21),axis=-1)*mask_for_loss[b_ix])/torch.sum(mask_for_loss[b_ix])
380
+ seq = _S_to_seq(S_sample[b_ix], chain_M[b_ix])
381
+ new_mpnn_seqs.append(seq)
382
+ score = scores[b_ix]
383
+ score_list.append(score)
384
+ global_score = global_scores[b_ix]
385
+ global_score_list.append(global_score)
386
+ native_seq = _S_to_seq(S[b_ix], chain_M[b_ix])
387
+ if b_ix == 0 and j==0 and temp==temperatures[0]:
388
+ start = 0
389
+ end = 0
390
+ list_of_AAs = []
391
+ for mask_l in masked_chain_length_list:
392
+ end += mask_l
393
+ list_of_AAs.append(native_seq[start:end])
394
+ start = end
395
+ native_seq = "".join(list(np.array(list_of_AAs)[np.argsort(masked_list)]))
396
+ l0 = 0
397
+ for mc_length in list(np.array(masked_chain_length_list)[np.argsort(masked_list)])[:-1]:
398
+ l0 += mc_length
399
+ native_seq = native_seq[:l0] + '/' + native_seq[l0:]
400
+ l0 += 1
401
+ sorted_masked_chain_letters = np.argsort(masked_list_list[0])
402
+ print_masked_chains = [masked_list_list[0][i] for i in sorted_masked_chain_letters]
403
+ sorted_visible_chain_letters = np.argsort(visible_list_list[0])
404
+ print_visible_chains = [visible_list_list[0][i] for i in sorted_visible_chain_letters]
405
+ native_score_print = np.format_float_positional(np.float32(native_score.mean()), unique=False, precision=4)
406
+ global_native_score_print = np.format_float_positional(np.float32(global_native_score.mean()), unique=False, precision=4)
407
+ script_dir = os.path.dirname(os.path.realpath(__file__))
408
+ try:
409
+ commit_str = subprocess.check_output(f'git --git-dir {script_dir}/.git rev-parse HEAD', shell=True, stderr=subprocess.DEVNULL).decode().strip()
410
+ except subprocess.CalledProcessError:
411
+ commit_str = 'unknown'
412
+ if ca_only:
413
+ print_model_name = 'CA_model_name'
414
+ else:
415
+ print_model_name = 'model_name'
416
+ if write_output_files:
417
+ f.write('>{}, score={}, global_score={}, fixed_chains={}, designed_chains={}, {}={}, git_hash={}, seed={}\n{}\n'.format(name_, native_score_print, global_native_score_print, print_visible_chains, print_masked_chains, print_model_name, model_name, commit_str, seed, native_seq)) #write the native sequence
418
+ start = 0
419
+ end = 0
420
+ list_of_AAs = []
421
+ for mask_l in masked_chain_length_list:
422
+ end += mask_l
423
+ list_of_AAs.append(seq[start:end])
424
+ start = end
425
+
426
+ seq = "".join(list(np.array(list_of_AAs)[np.argsort(masked_list)]))
427
+ l0 = 0
428
+ for mc_length in list(np.array(masked_chain_length_list)[np.argsort(masked_list)])[:-1]:
429
+ l0 += mc_length
430
+ seq = seq[:l0] + '/' + seq[l0:]
431
+ l0 += 1
432
+ score_print = np.format_float_positional(np.float32(score), unique=False, precision=4)
433
+ global_score_print = np.format_float_positional(np.float32(global_score), unique=False, precision=4)
434
+ seq_rec_print = np.format_float_positional(np.float32(seq_recovery_rate.detach().cpu().numpy()), unique=False, precision=4)
435
+ sample_number = j*BATCH_COPIES+b_ix+1
436
+ if write_output_files:
437
+ f.write('>T={}, sample={}, score={}, global_score={}, seq_recovery={}\n{}\n'.format(temp,sample_number,score_print,global_score_print,seq_rec_print,seq)) #write generated sequence
438
+ # if args.save_score:
439
+ # np.savez(score_file, score=np.array(score_list, np.float32), global_score=np.array(global_score_list, np.float32))
440
+ # if args.save_probs:
441
+ # all_probs_concat = np.concatenate(all_probs_list)
442
+ # all_log_probs_concat = np.concatenate(all_log_probs_list)
443
+ # S_sample_concat = np.concatenate(S_sample_list)
444
+ # np.savez(probs_file, probs=np.array(all_probs_concat, np.float32), log_probs=np.array(all_log_probs_concat, np.float32), S=np.array(S_sample_concat, np.int32), mask=mask_for_loss.cpu().data.numpy(), chain_order=chain_list_list)
445
+ t1 = time.time()
446
+ dt = round(float(t1-t0), 4)
447
+ num_seqs = len(temperatures)*NUM_BATCHES*BATCH_COPIES
448
+ total_length = X.shape[1]
449
+ if print_all:
450
+ print(f'{num_seqs} sequences of length {total_length} generated in {dt} seconds')
451
+ if write_output_files:
452
+ f.close()
453
+
454
+ return new_mpnn_seqs
455
+
456
+
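A hedged end-to-end sketch: design four sequences for chain A of a local structure. 'input.pdb' is a placeholder path, and with write_output_files left at False nothing is written to disk; the designed sequences are returned directly:

from core.protein_mpnn import run_proteinmpnn  # assumed import path

seqs = run_proteinmpnn(
    pdb_path='input.pdb',
    pdb_path_chains='A',
    num_seq_per_target=4,
    batch_size=1,
    sampling_temps=[0.1],
)
print(seqs)  # list of designed amino-acid strings (designable positions only)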
457
+ def parse_fasta(filename, limit=-1, omit=()): # tuple default avoids a shared mutable argument
458
+ header = []
459
+ sequence = []
460
+ lines = open(filename, "r")
461
+ for line in lines:
462
+ line = line.rstrip()
463
+ if line[0] == ">":
464
+ if len(header) == limit:
465
+ break
466
+ header.append(line[1:])
467
+ sequence.append([])
468
+ else:
469
+ if omit:
470
+ line = [item for item in line if item not in omit]
471
+ line = ''.join(line)
472
+ line = ''.join(line)
473
+ sequence[-1].append(line)
474
+ lines.close()
475
+ sequence = [''.join(seq) for seq in sequence]
476
+ return np.array(header), np.array(sequence)
477
+
478
+ def _scores(S, log_probs, mask):
479
+ """ Negative log probabilities """
480
+ criterion = torch.nn.NLLLoss(reduction='none')
481
+ loss = criterion(
482
+ log_probs.contiguous().view(-1,log_probs.size(-1)),
483
+ S.contiguous().view(-1)
484
+ ).view(S.size())
485
+ scores = torch.sum(loss * mask, dim=-1) / torch.sum(mask, dim=-1)
486
+ return scores
487
+
488
+ def _S_to_seq(S, mask):
489
+ alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
490
+ seq = ''.join([alphabet[c] for c, m in zip(S.tolist(), mask.tolist()) if m > 0])
491
+ return seq
492
+
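A small self-contained check of the two helpers above: score a toy prediction and decode integer labels back into a sequence string:

import torch

S = torch.tensor([[0, 2, 4]])                            # 'A', 'D', 'F' in the alphabet above
log_probs = torch.log(torch.full((1, 3, 21), 1.0 / 21))  # a uniform predictor
mask = torch.ones(1, 3)
print(_scores(S, log_probs, mask))  # tensor([3.0445]) == log(21)
print(_S_to_seq(S[0], mask[0]))     # 'ADF'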
493
+ def parse_PDB_biounits(x, atoms=['N','CA','C'], chain=None):
494
+ '''
495
+ input: x = PDB filename
496
+ atoms = atoms to extract (optional)
497
+ output: (length, atoms, coords=(x,y,z)), sequence
498
+ '''
499
+
500
+ alpha_1 = list("ARNDCQEGHILKMFPSTWYV-")
501
+ states = len(alpha_1)
502
+ alpha_3 = ['ALA','ARG','ASN','ASP','CYS','GLN','GLU','GLY','HIS','ILE',
503
+ 'LEU','LYS','MET','PHE','PRO','SER','THR','TRP','TYR','VAL','GAP']
504
+
505
+ aa_1_N = {a:n for n,a in enumerate(alpha_1)}
506
+ aa_3_N = {a:n for n,a in enumerate(alpha_3)}
507
+ aa_N_1 = {n:a for n,a in enumerate(alpha_1)}
508
+ aa_1_3 = {a:b for a,b in zip(alpha_1,alpha_3)}
509
+ aa_3_1 = {b:a for a,b in zip(alpha_1,alpha_3)}
510
+
511
+ def AA_to_N(x):
512
+ # ["ARND"] -> [[0,1,2,3]]
513
+ x = np.array(x);
514
+ if x.ndim == 0: x = x[None]
515
+ return [[aa_1_N.get(a, states-1) for a in y] for y in x]
516
+
517
+ def N_to_AA(x):
518
+ # [[0,1,2,3]] -> ["ARND"]
519
+ x = np.array(x);
520
+ if x.ndim == 1: x = x[None]
521
+ return ["".join([aa_N_1.get(a,"-") for a in y]) for y in x]
522
+
523
+ xyz,seq,min_resn,max_resn = {},{},1e6,-1e6
524
+ for line in open(x,"rb"):
525
+ line = line.decode("utf-8","ignore").rstrip()
526
+
527
+ if line[:6] == "HETATM" and line[17:17+3] == "MSE":
528
+ line = line.replace("HETATM","ATOM ")
529
+ line = line.replace("MSE","MET")
530
+
531
+ if line[:4] == "ATOM":
532
+ ch = line[21:22]
533
+ if ch == chain or chain is None:
534
+ atom = line[12:12+4].strip()
535
+ resi = line[17:17+3]
536
+ resn = line[22:22+5].strip()
537
+ x,y,z = [float(line[i:(i+8)]) for i in [30,38,46]]
538
+
539
+ if resn[-1].isalpha():
540
+ resa,resn = resn[-1],int(resn[:-1])-1
541
+ else:
542
+ resa,resn = "",int(resn)-1
543
+ # resn = int(resn)
544
+ if resn < min_resn:
545
+ min_resn = resn
546
+ if resn > max_resn:
547
+ max_resn = resn
548
+ if resn not in xyz:
549
+ xyz[resn] = {}
550
+ if resa not in xyz[resn]:
551
+ xyz[resn][resa] = {}
552
+ if resn not in seq:
553
+ seq[resn] = {}
554
+ if resa not in seq[resn]:
555
+ seq[resn][resa] = resi
556
+
557
+ if atom not in xyz[resn][resa]:
558
+ xyz[resn][resa][atom] = np.array([x,y,z])
559
+
560
+ # convert to numpy arrays, fill in missing values
561
+ seq_,xyz_ = [],[]
562
+ try:
563
+ for resn in range(min_resn,max_resn+1):
564
+ if resn in seq:
565
+ for k in sorted(seq[resn]): seq_.append(aa_3_N.get(seq[resn][k],20))
566
+ else: seq_.append(20)
567
+ if resn in xyz:
568
+ for k in sorted(xyz[resn]):
569
+ for atom in atoms:
570
+ if atom in xyz[resn][k]: xyz_.append(xyz[resn][k][atom])
571
+ else: xyz_.append(np.full(3,np.nan))
572
+ else:
573
+ for atom in atoms: xyz_.append(np.full(3,np.nan))
574
+ return np.array(xyz_).reshape(-1,len(atoms),3), N_to_AA(np.array(seq_))
575
+ except TypeError:
576
+ return 'no_chain', 'no_chain'
577
+
578
+ def parse_PDB(path_to_pdb, input_chain_list=None, ca_only=False):
579
+ c=0
580
+ pdb_dict_list = []
581
+ init_alphabet = ['A', 'B', 'C', 'D', 'E', 'F', 'G','H', 'I', 'J','K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T','U', 'V','W','X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g','h', 'i', 'j','k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't','u', 'v','w','x', 'y', 'z']
582
+ extra_alphabet = [str(item) for item in list(np.arange(300))]
583
+ chain_alphabet = init_alphabet + extra_alphabet
584
+
585
+ if input_chain_list:
586
+ chain_alphabet = input_chain_list
587
+
588
+
589
+ biounit_names = [path_to_pdb]
590
+ for biounit in biounit_names:
591
+ my_dict = {}
592
+ s = 0
593
+ concat_seq = ''
594
+ concat_N = []
595
+ concat_CA = []
596
+ concat_C = []
597
+ concat_O = []
598
+ concat_mask = []
599
+ coords_dict = {}
600
+ for letter in chain_alphabet:
601
+ if ca_only:
602
+ sidechain_atoms = ['CA']
603
+ else:
604
+ sidechain_atoms = ['N', 'CA', 'C', 'O']
605
+ xyz, seq = parse_PDB_biounits(biounit, atoms=sidechain_atoms, chain=letter)
606
+ if type(xyz) != str:
607
+ concat_seq += seq[0]
608
+ my_dict['seq_chain_'+letter]=seq[0]
609
+ coords_dict_chain = {}
610
+ if ca_only:
611
+ coords_dict_chain['CA_chain_'+letter]=xyz.tolist()
612
+ else:
613
+ coords_dict_chain['N_chain_' + letter] = xyz[:, 0, :].tolist()
614
+ coords_dict_chain['CA_chain_' + letter] = xyz[:, 1, :].tolist()
615
+ coords_dict_chain['C_chain_' + letter] = xyz[:, 2, :].tolist()
616
+ coords_dict_chain['O_chain_' + letter] = xyz[:, 3, :].tolist()
617
+ my_dict['coords_chain_'+letter]=coords_dict_chain
618
+ s += 1
619
+ fi = biounit.rfind("/")
620
+ my_dict['name']=biounit[(fi+1):-4]
621
+ my_dict['num_of_chains'] = s
622
+ my_dict['seq'] = concat_seq
623
+ if s <= len(chain_alphabet):
624
+ pdb_dict_list.append(my_dict)
625
+ c+=1
626
+ return pdb_dict_list
627
+
628
+
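Hedged usage sketch: parse every chain of a local file into the dictionary format consumed by StructureDatasetPDB below ('input.pdb' is a placeholder path):

pdb_dict_list = parse_PDB('input.pdb', ca_only=False)
entry = pdb_dict_list[0]
print(entry['name'], entry['num_of_chains'], len(entry['seq']))
print([k for k in entry if k.startswith('seq_chain_')])  # one sequence key per chain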
629
+
630
+ def tied_featurize(batch, device, chain_dict, fixed_position_dict=None, omit_AA_dict=None, tied_positions_dict=None, pssm_dict=None, bias_by_res_dict=None, ca_only=False):
631
+ """ Pack and pad batch into torch tensors """
632
+ alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
633
+ B = len(batch)
634
+ lengths = np.array([len(b['seq']) for b in batch], dtype=np.int32) #sum of chain seq lengths
635
+ L_max = max([len(b['seq']) for b in batch])
636
+ if ca_only:
637
+ X = np.zeros([B, L_max, 1, 3])
638
+ else:
639
+ X = np.zeros([B, L_max, 4, 3])
640
+ residue_idx = -100*np.ones([B, L_max], dtype=np.int32)
641
+ chain_M = np.zeros([B, L_max], dtype=np.int32) #1.0 for the bits that need to be predicted
642
+ pssm_coef_all = np.zeros([B, L_max], dtype=np.float32) #per-residue PSSM coefficients
644
+ pssm_bias_all = np.zeros([B, L_max, 21], dtype=np.float32) #per-residue PSSM bias over the 21-letter alphabet
645
+ pssm_log_odds_all = 10000.0*np.ones([B, L_max, 21], dtype=np.float32) #large default so log-odds thresholding keeps every position
646
+ chain_M_pos = np.zeros([B, L_max], dtype=np.int32) #1.0 for the bits that need to be predicted
647
+ bias_by_res_all = np.zeros([B, L_max, 21], dtype=np.float32)
648
+ chain_encoding_all = np.zeros([B, L_max], dtype=np.int32) #integer chain id per residue
648
+ S = np.zeros([B, L_max], dtype=np.int32)
649
+ omit_AA_mask = np.zeros([B, L_max, len(alphabet)], dtype=np.int32)
650
+ # Build the batch
651
+ letter_list_list = []
652
+ visible_list_list = []
653
+ masked_list_list = []
654
+ masked_chain_length_list_list = []
655
+ tied_pos_list_of_lists_list = []
656
+ for i, b in enumerate(batch):
657
+ if chain_dict != None:
658
+ masked_chains, visible_chains = chain_dict[b['name']] #masked_chains a list of chain letters to predict [A, D, F]
659
+ else:
660
+ masked_chains = [item[-1:] for item in list(b) if item[:10]=='seq_chain_']
661
+ visible_chains = []
662
+ masked_chains.sort() #sort masked_chains
663
+ visible_chains.sort() #sort visible_chains
664
+ all_chains = masked_chains + visible_chains
665
+ for i, b in enumerate(batch):
666
+ mask_dict = {}
667
+ a = 0
668
+ x_chain_list = []
669
+ chain_mask_list = []
670
+ chain_seq_list = []
671
+ chain_encoding_list = []
672
+ c = 1
673
+ letter_list = []
674
+ global_idx_start_list = [0]
675
+ visible_list = []
676
+ masked_list = []
677
+ masked_chain_length_list = []
678
+ fixed_position_mask_list = []
679
+ omit_AA_mask_list = []
680
+ pssm_coef_list = []
681
+ pssm_bias_list = []
682
+ pssm_log_odds_list = []
683
+ bias_by_res_list = []
684
+ l0 = 0
685
+ l1 = 0
686
+ for step, letter in enumerate(all_chains):
687
+ if letter in visible_chains:
688
+ letter_list.append(letter)
689
+ visible_list.append(letter)
690
+ chain_seq = b[f'seq_chain_{letter}']
691
+ chain_seq = ''.join([a if a!='-' else 'X' for a in chain_seq])
692
+ chain_length = len(chain_seq)
693
+ global_idx_start_list.append(global_idx_start_list[-1]+chain_length)
694
+ chain_coords = b[f'coords_chain_{letter}'] #this is a dictionary
695
+ chain_mask = np.zeros(chain_length) #0.0 for visible chains
696
+ if ca_only:
697
+ x_chain = np.array(chain_coords[f'CA_chain_{letter}']) #[chain_length,1,3] #CA_diff
698
+ if len(x_chain.shape) == 2:
699
+ x_chain = x_chain[:,None,:]
700
+ else:
701
+ x_chain = np.stack([chain_coords[c] for c in [f'N_chain_{letter}', f'CA_chain_{letter}', f'C_chain_{letter}', f'O_chain_{letter}']], 1) #[chain_length,4,3]
702
+ x_chain_list.append(x_chain)
703
+ chain_mask_list.append(chain_mask)
704
+ chain_seq_list.append(chain_seq)
705
+ chain_encoding_list.append(c*np.ones(np.array(chain_mask).shape[0]))
706
+ l1 += chain_length
707
+ residue_idx[i, l0:l1] = 100*(c-1)+np.arange(l0, l1)
708
+ l0 += chain_length
709
+ c+=1
710
+ fixed_position_mask = np.ones(chain_length)
711
+ fixed_position_mask_list.append(fixed_position_mask)
712
+ omit_AA_mask_temp = np.zeros([chain_length, len(alphabet)], np.int32)
713
+ omit_AA_mask_list.append(omit_AA_mask_temp)
714
+ pssm_coef = np.zeros(chain_length)
715
+ pssm_bias = np.zeros([chain_length, 21])
716
+ pssm_log_odds = 10000.0*np.ones([chain_length, 21])
717
+ pssm_coef_list.append(pssm_coef)
718
+ pssm_bias_list.append(pssm_bias)
719
+ pssm_log_odds_list.append(pssm_log_odds)
720
+ bias_by_res_list.append(np.zeros([chain_length, 21]))
721
+ if letter in masked_chains:
722
+ masked_list.append(letter)
723
+ letter_list.append(letter)
724
+ chain_seq = b[f'seq_chain_{letter}']
725
+ chain_seq = ''.join([a if a!='-' else 'X' for a in chain_seq])
726
+ chain_length = len(chain_seq)
727
+ global_idx_start_list.append(global_idx_start_list[-1]+chain_length)
728
+ masked_chain_length_list.append(chain_length)
729
+ chain_coords = b[f'coords_chain_{letter}'] #this is a dictionary
730
+ chain_mask = np.ones(chain_length) #1.0 for masked
731
+ if ca_only:
732
+ x_chain = np.array(chain_coords[f'CA_chain_{letter}']) #[chain_length,1,3] #CA_diff
733
+ if len(x_chain.shape) == 2:
734
+ x_chain = x_chain[:,None,:]
735
+ else:
736
+ x_chain = np.stack([chain_coords[c] for c in [f'N_chain_{letter}', f'CA_chain_{letter}', f'C_chain_{letter}', f'O_chain_{letter}']], 1) #[chain_length,4,3]
737
+ x_chain_list.append(x_chain)
738
+ chain_mask_list.append(chain_mask)
739
+ chain_seq_list.append(chain_seq)
740
+ chain_encoding_list.append(c*np.ones(np.array(chain_mask).shape[0]))
741
+ l1 += chain_length
742
+ residue_idx[i, l0:l1] = 100*(c-1)+np.arange(l0, l1)
743
+ l0 += chain_length
744
+ c+=1
745
+ fixed_position_mask = np.ones(chain_length)
746
+ if fixed_position_dict!=None:
747
+ fixed_pos_list = fixed_position_dict[b['name']][letter]
748
+ if fixed_pos_list:
749
+ fixed_position_mask[np.array(fixed_pos_list)-1] = 0.0
750
+ fixed_position_mask_list.append(fixed_position_mask)
751
+ omit_AA_mask_temp = np.zeros([chain_length, len(alphabet)], np.int32)
752
+ if omit_AA_dict!=None:
753
+ for item in omit_AA_dict[b['name']][letter]:
754
+ idx_AA = np.array(item[0])-1
755
+ AA_idx = np.array([np.argwhere(np.array(list(alphabet))== AA)[0][0] for AA in item[1]]).repeat(idx_AA.shape[0])
756
+ idx_ = np.array([[a, b] for a in idx_AA for b in AA_idx])
757
+ omit_AA_mask_temp[idx_[:,0], idx_[:,1]] = 1
758
+ omit_AA_mask_list.append(omit_AA_mask_temp)
759
+ pssm_coef = np.zeros(chain_length)
760
+ pssm_bias = np.zeros([chain_length, 21])
761
+ pssm_log_odds = 10000.0*np.ones([chain_length, 21])
762
+ if pssm_dict:
763
+ if pssm_dict[b['name']][letter]:
764
+ pssm_coef = pssm_dict[b['name']][letter]['pssm_coef']
765
+ pssm_bias = pssm_dict[b['name']][letter]['pssm_bias']
766
+ pssm_log_odds = pssm_dict[b['name']][letter]['pssm_log_odds']
767
+ pssm_coef_list.append(pssm_coef)
768
+ pssm_bias_list.append(pssm_bias)
769
+ pssm_log_odds_list.append(pssm_log_odds)
770
+ if bias_by_res_dict:
771
+ bias_by_res_list.append(bias_by_res_dict[b['name']][letter])
772
+ else:
773
+ bias_by_res_list.append(np.zeros([chain_length, 21]))
774
+
775
+
776
+ letter_list_np = np.array(letter_list)
777
+ tied_pos_list_of_lists = []
778
+ tied_beta = np.ones(L_max)
779
+ if tied_positions_dict!=None:
780
+ tied_pos_list = tied_positions_dict[b['name']]
781
+ if tied_pos_list:
782
+ set_chains_tied = set(list(itertools.chain(*[list(item) for item in tied_pos_list])))
783
+ for tied_item in tied_pos_list:
784
+ one_list = []
785
+ for k, v in tied_item.items():
786
+ start_idx = global_idx_start_list[np.argwhere(letter_list_np == k)[0][0]]
787
+ if isinstance(v[0], list):
788
+ for v_count in range(len(v[0])):
789
+ one_list.append(start_idx+v[0][v_count]-1)#make 0 to be the first
790
+ tied_beta[start_idx+v[0][v_count]-1] = v[1][v_count]
791
+ else:
792
+ for v_ in v:
793
+ one_list.append(start_idx+v_-1)#make 0 to be the first
794
+ tied_pos_list_of_lists.append(one_list)
795
+ tied_pos_list_of_lists_list.append(tied_pos_list_of_lists)
796
+
797
+
798
+
799
+ x = np.concatenate(x_chain_list,0) #[L, 4, 3]
800
+ all_sequence = "".join(chain_seq_list)
801
+ m = np.concatenate(chain_mask_list,0) #[L,], 1.0 for places that need to be predicted
802
+ chain_encoding = np.concatenate(chain_encoding_list,0)
803
+ m_pos = np.concatenate(fixed_position_mask_list,0) #[L,], 1.0 for places that need to be predicted
804
+
805
+ pssm_coef_ = np.concatenate(pssm_coef_list,0) #[L,] per-residue PSSM coefficients
807
+ pssm_bias_ = np.concatenate(pssm_bias_list,0) #[L,21] per-residue PSSM bias
808
+ pssm_log_odds_ = np.concatenate(pssm_log_odds_list,0) #[L,21] per-residue PSSM log-odds
808
+
809
+ bias_by_res_ = np.concatenate(bias_by_res_list, 0) #[L,21], 0.0 for places where AA frequencies don't need to be tweaked
810
+
811
+ l = len(all_sequence)
812
+ x_pad = np.pad(x, [[0,L_max-l], [0,0], [0,0]], 'constant', constant_values=(np.nan, ))
813
+ X[i,:,:,:] = x_pad
814
+
815
+ m_pad = np.pad(m, [[0,L_max-l]], 'constant', constant_values=(0.0, ))
816
+ m_pos_pad = np.pad(m_pos, [[0,L_max-l]], 'constant', constant_values=(0.0, ))
817
+ omit_AA_mask_pad = np.pad(np.concatenate(omit_AA_mask_list,0), [[0,L_max-l]], 'constant', constant_values=(0.0, ))
818
+ chain_M[i,:] = m_pad
819
+ chain_M_pos[i,:] = m_pos_pad
820
+ omit_AA_mask[i,] = omit_AA_mask_pad
821
+
822
+ chain_encoding_pad = np.pad(chain_encoding, [[0,L_max-l]], 'constant', constant_values=(0.0, ))
823
+ chain_encoding_all[i,:] = chain_encoding_pad
824
+
825
+ pssm_coef_pad = np.pad(pssm_coef_, [[0,L_max-l]], 'constant', constant_values=(0.0, ))
826
+ pssm_bias_pad = np.pad(pssm_bias_, [[0,L_max-l], [0,0]], 'constant', constant_values=(0.0, ))
827
+ pssm_log_odds_pad = np.pad(pssm_log_odds_, [[0,L_max-l], [0,0]], 'constant', constant_values=(0.0, ))
828
+
829
+ pssm_coef_all[i,:] = pssm_coef_pad
830
+ pssm_bias_all[i,:] = pssm_bias_pad
831
+ pssm_log_odds_all[i,:] = pssm_log_odds_pad
832
+
833
+ bias_by_res_pad = np.pad(bias_by_res_, [[0,L_max-l], [0,0]], 'constant', constant_values=(0.0, ))
834
+ bias_by_res_all[i,:] = bias_by_res_pad
835
+
836
+ # Convert to labels
837
+ indices = np.asarray([alphabet.index(a) for a in all_sequence], dtype=np.int32)
838
+ S[i, :l] = indices
839
+ letter_list_list.append(letter_list)
840
+ visible_list_list.append(visible_list)
841
+ masked_list_list.append(masked_list)
842
+ masked_chain_length_list_list.append(masked_chain_length_list)
843
+
844
+
845
+ isnan = np.isnan(X)
846
+ mask = np.isfinite(np.sum(X,(2,3))).astype(np.float32)
847
+ X[isnan] = 0.
848
+
849
+ # Conversion
850
+ pssm_coef_all = torch.from_numpy(pssm_coef_all).to(dtype=torch.float32, device=device)
851
+ pssm_bias_all = torch.from_numpy(pssm_bias_all).to(dtype=torch.float32, device=device)
852
+ pssm_log_odds_all = torch.from_numpy(pssm_log_odds_all).to(dtype=torch.float32, device=device)
853
+
854
+ tied_beta = torch.from_numpy(tied_beta).to(dtype=torch.float32, device=device)
855
+
856
+ jumps = ((residue_idx[:,1:]-residue_idx[:,:-1])==1).astype(np.float32)
857
+ bias_by_res_all = torch.from_numpy(bias_by_res_all).to(dtype=torch.float32, device=device)
858
+ phi_mask = np.pad(jumps, [[0,0],[1,0]])
859
+ psi_mask = np.pad(jumps, [[0,0],[0,1]])
860
+ omega_mask = np.pad(jumps, [[0,0],[0,1]])
861
+ dihedral_mask = np.concatenate([phi_mask[:,:,None], psi_mask[:,:,None], omega_mask[:,:,None]], -1) #[B,L,3]
862
+ dihedral_mask = torch.from_numpy(dihedral_mask).to(dtype=torch.float32, device=device)
863
+ residue_idx = torch.from_numpy(residue_idx).to(dtype=torch.long,device=device)
864
+ S = torch.from_numpy(S).to(dtype=torch.long,device=device)
865
+ X = torch.from_numpy(X).to(dtype=torch.float32, device=device)
866
+ mask = torch.from_numpy(mask).to(dtype=torch.float32, device=device)
867
+ chain_M = torch.from_numpy(chain_M).to(dtype=torch.float32, device=device)
868
+ chain_M_pos = torch.from_numpy(chain_M_pos).to(dtype=torch.float32, device=device)
869
+ omit_AA_mask = torch.from_numpy(omit_AA_mask).to(dtype=torch.float32, device=device)
870
+ chain_encoding_all = torch.from_numpy(chain_encoding_all).to(dtype=torch.long, device=device)
871
+ if ca_only:
872
+ X_out = X[:,:,0]
873
+ else:
874
+ X_out = X
875
+ return X_out, S, mask, lengths, chain_M, chain_encoding_all, letter_list_list, visible_list_list, masked_list_list, masked_chain_length_list_list, chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask, tied_pos_list_of_lists_list, pssm_coef_all, pssm_bias_all, pssm_log_odds_all, bias_by_res_all, tied_beta
876
+
877
+
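A hedged sketch of featurizing one parsed structure for the model; with chain_dict=None every chain is treated as designable, and shapes follow the code above (X is [B, L, 4, 3], or [B, L, 1, 3] with ca_only=True):

import torch

batch = [pdb_dict_list[0]]  # an entry from parse_PDB above
out = tied_featurize(batch, torch.device('cpu'), None)
X, S, mask = out[0], out[1], out[2]
print(X.shape, S.shape, mask.shape)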
878
+
879
+ def loss_nll(S, log_probs, mask):
880
+ """ Negative log probabilities """
881
+ criterion = torch.nn.NLLLoss(reduction='none')
882
+ loss = criterion(
883
+ log_probs.contiguous().view(-1, log_probs.size(-1)), S.contiguous().view(-1)
884
+ ).view(S.size())
885
+ loss_av = torch.sum(loss * mask) / torch.sum(mask)
886
+ return loss, loss_av
887
+
888
+
889
+ def loss_smoothed(S, log_probs, mask, weight=0.1):
890
+ """ Negative log probabilities """
891
+ S_onehot = torch.nn.functional.one_hot(S, 21).float()
892
+
893
+ # Label smoothing
894
+ S_onehot = S_onehot + weight / float(S_onehot.size(-1))
895
+ S_onehot = S_onehot / S_onehot.sum(-1, keepdim=True)
896
+
897
+ loss = -(S_onehot * log_probs).sum(-1)
898
+ loss_av = torch.sum(loss * mask) / torch.sum(mask)
899
+ return loss, loss_av
900
+
901
+ class StructureDataset():
902
+ def __init__(self, jsonl_file, verbose=True, truncate=None, max_length=100,
903
+ alphabet='ACDEFGHIKLMNPQRSTVWYX-'):
904
+ alphabet_set = set([a for a in alphabet])
905
+ discard_count = {
906
+ 'bad_chars': 0,
907
+ 'too_long': 0,
908
+ 'bad_seq_length': 0
909
+ }
910
+
911
+ with open(jsonl_file) as f:
912
+ self.data = []
913
+
914
+ lines = f.readlines()
915
+ start = time.time()
916
+ for i, line in enumerate(lines):
917
+ entry = json.loads(line)
918
+ seq = entry['seq']
919
+ name = entry['name']
920
+
921
+ # Convert raw coords to np arrays
922
+ #for key, val in entry['coords'].items():
923
+ # entry['coords'][key] = np.asarray(val)
924
+
925
+ # Check if in alphabet
926
+ bad_chars = set([s for s in seq]).difference(alphabet_set)
927
+ if len(bad_chars) == 0:
928
+ if len(entry['seq']) <= max_length:
929
+ if True:
930
+ self.data.append(entry)
931
+ else:
932
+ discard_count['bad_seq_length'] += 1
933
+ else:
934
+ discard_count['too_long'] += 1
935
+ else:
936
+ if verbose:
937
+ print(name, bad_chars, entry['seq'])
938
+ discard_count['bad_chars'] += 1
939
+
940
+ # Truncate early
941
+ if truncate is not None and len(self.data) == truncate:
942
+ return
943
+
944
+ if verbose and (i + 1) % 1000 == 0:
945
+ elapsed = time.time() - start
946
+ print('{} entries ({} loaded) in {:.1f} s'.format(len(self.data), i+1, elapsed))
947
+ if verbose:
948
+ print('discarded', discard_count)
949
+ def __len__(self):
950
+ return len(self.data)
951
+
952
+ def __getitem__(self, idx):
953
+ return self.data[idx]
954
+
955
+
956
+ class StructureDatasetPDB():
957
+ def __init__(self, pdb_dict_list, verbose=True, truncate=None, max_length=100,
958
+ alphabet='ACDEFGHIKLMNPQRSTVWYX-'):
959
+ alphabet_set = set([a for a in alphabet])
960
+ discard_count = {
961
+ 'bad_chars': 0,
962
+ 'too_long': 0,
963
+ 'bad_seq_length': 0
964
+ }
965
+
966
+ self.data = []
967
+
968
+ start = time.time()
969
+ for i, entry in enumerate(pdb_dict_list):
970
+ seq = entry['seq']
971
+ name = entry['name']
972
+
973
+ bad_chars = set([s for s in seq]).difference(alphabet_set)
974
+ if len(bad_chars) == 0:
975
+ if len(entry['seq']) <= max_length:
976
+ self.data.append(entry)
977
+ else:
978
+ discard_count['too_long'] += 1
979
+ else:
980
+ discard_count['bad_chars'] += 1
981
+
982
+ # Truncate early
983
+ if truncate is not None and len(self.data) == truncate:
984
+ return
985
+
986
+ if verbose and (i + 1) % 1000 == 0:
987
+ elapsed = time.time() - start
988
+
989
+ #print('Discarded', discard_count)
990
+ def __len__(self):
991
+ return len(self.data)
992
+
993
+ def __getitem__(self, idx):
994
+ return self.data[idx]
995
+
996
+
997
+
998
+ class StructureLoader():
999
+ def __init__(self, dataset, batch_size=100, shuffle=True,
1000
+ collate_fn=lambda x:x, drop_last=False):
1001
+ self.dataset = dataset
1002
+ self.size = len(dataset)
1003
+ self.lengths = [len(dataset[i]['seq']) for i in range(self.size)]
1004
+ self.batch_size = batch_size
1005
+ sorted_ix = np.argsort(self.lengths)
1006
+
1007
+ # Cluster into batches of similar sizes
1008
+ clusters, batch = [], []
1009
+ batch_max = 0
1010
+ for ix in sorted_ix:
1011
+ size = self.lengths[ix]
1012
+ if size * (len(batch) + 1) <= self.batch_size:
1013
+ batch.append(ix)
1014
+ batch_max = size
1015
+ else:
1016
+ clusters.append(batch)
1017
+ batch, batch_max = [], 0
1018
+ if len(batch) > 0:
1019
+ clusters.append(batch)
1020
+ self.clusters = clusters
1021
+
1022
+ def __len__(self):
1023
+ return len(self.clusters)
1024
+
1025
+ def __iter__(self):
1026
+ np.random.shuffle(self.clusters)
1027
+ for b_idx in self.clusters:
1028
+ batch = [self.dataset[i] for i in b_idx]
1029
+ yield batch
1030
+
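Usage sketch: batch_size here is a residue budget rather than a structure count, so each batch holds a variable number of similarly sized structures; dataset is assumed to be a StructureDatasetPDB instance built as above:

loader = StructureLoader(dataset, batch_size=10000)
for batch in loader:
    print(len(batch), 'structures,', sum(len(b['seq']) for b in batch), 'residues')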
1031
+
1032
+
1033
+ # The following gather functions collect per-neighbor edge/node features by index
1034
+ def gather_edges(edges, neighbor_idx):
1035
+ # Features [B,N,N,C] at Neighbor indices [B,N,K] => Neighbor features [B,N,K,C]
1036
+ neighbors = neighbor_idx.unsqueeze(-1).expand(-1, -1, -1, edges.size(-1))
1037
+ edge_features = torch.gather(edges, 2, neighbors)
1038
+ return edge_features
1039
+
1040
+ def gather_nodes(nodes, neighbor_idx):
1041
+ # Features [B,N,C] at Neighbor indices [B,N,K] => [B,N,K,C]
1042
+ # Flatten and expand indices per batch [B,N,K] => [B,NK] => [B,NK,C]
1043
+ neighbors_flat = neighbor_idx.view((neighbor_idx.shape[0], -1))
1044
+ neighbors_flat = neighbors_flat.unsqueeze(-1).expand(-1, -1, nodes.size(2))
1045
+ # Gather and re-pack
1046
+ neighbor_features = torch.gather(nodes, 1, neighbors_flat)
1047
+ neighbor_features = neighbor_features.view(list(neighbor_idx.shape)[:3] + [-1])
1048
+ return neighbor_features
1049
+
1050
+ def gather_nodes_t(nodes, neighbor_idx):
1051
+ # Features [B,N,C] at Neighbor index [B,K] => Neighbor features[B,K,C]
1052
+ idx_flat = neighbor_idx.unsqueeze(-1).expand(-1, -1, nodes.size(2))
1053
+ neighbor_features = torch.gather(nodes, 1, idx_flat)
1054
+ return neighbor_features
1055
+
1056
+ def cat_neighbors_nodes(h_nodes, h_neighbors, E_idx):
1057
+ h_nodes = gather_nodes(h_nodes, E_idx)
1058
+ h_nn = torch.cat([h_neighbors, h_nodes], -1)
1059
+ return h_nn
1060
+
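Shape sanity check for the gather helpers above, on random tensors:

import torch

B, N, K, C = 2, 5, 3, 8
edges = torch.randn(B, N, N, C)
nodes = torch.randn(B, N, C)
E_idx = torch.randint(0, N, (B, N, K))
print(gather_edges(edges, E_idx).shape)  # torch.Size([2, 5, 3, 8])
print(gather_nodes(nodes, E_idx).shape)  # torch.Size([2, 5, 3, 8])
print(cat_neighbors_nodes(nodes, gather_edges(edges, E_idx), E_idx).shape)  # torch.Size([2, 5, 3, 16])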
1061
+
1062
+ class EncLayer(nn.Module):
1063
+ def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30, time_cond_dim=None):
1064
+ super(EncLayer, self).__init__()
1065
+ self.num_hidden = num_hidden
1066
+ self.num_in = num_in
1067
+ self.scale = scale
1068
+ self.dropout1 = nn.Dropout(dropout)
1069
+ self.dropout2 = nn.Dropout(dropout)
1070
+ self.dropout3 = nn.Dropout(dropout)
1071
+ self.norm1 = nn.LayerNorm(num_hidden)
1072
+ self.norm2 = nn.LayerNorm(num_hidden)
1073
+ self.norm3 = nn.LayerNorm(num_hidden)
1074
+
1075
+ if time_cond_dim is not None:
1076
+ self.time_block1 = nn.Sequential(
1077
+ Rearrange('b 1 d -> b 1 1 d'),
1078
+ nn.SiLU(),
1079
+ nn.Linear(time_cond_dim, num_hidden * 2))
1080
+ self.time_block2 = nn.Sequential(
1081
+ Rearrange('b 1 d -> b 1 1 d'),
1082
+ nn.SiLU(),
1083
+ nn.Linear(time_cond_dim, num_hidden * 2))
1084
+
1085
+ self.W1 = nn.Linear(num_hidden + num_in, num_hidden, bias=True)
1086
+ self.W2 = nn.Linear(num_hidden, num_hidden, bias=True)
1087
+ self.W3 = nn.Linear(num_hidden, num_hidden, bias=True)
1088
+ self.W11 = nn.Linear(num_hidden + num_in, num_hidden, bias=True)
1089
+ self.W12 = nn.Linear(num_hidden, num_hidden, bias=True)
1090
+ self.W13 = nn.Linear(num_hidden, num_hidden, bias=True)
1091
+ self.act = torch.nn.GELU()
1092
+ self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)
1093
+
1094
+ def forward(self, h_V, h_E, E_idx, mask_V=None, mask_attend=None, time_cond=None):
1095
+ """ Parallel computation of full transformer layer """
1096
+
1097
+ h_EV = cat_neighbors_nodes(h_V, h_E, E_idx)
1098
+ h_V_expand = h_V.unsqueeze(-2).expand(-1,-1,h_EV.size(-2),-1)
1099
+ h_EV = torch.cat([h_V_expand, h_EV], -1)
1100
+
1101
+ h_message = self.act(self.W2(self.act(self.W1(h_EV))))
1102
+ if time_cond is not None:
1103
+ scale, shift = self.time_block1(time_cond).chunk(2, dim=-1)
1104
+ h_message = h_message * (scale + 1) + shift
1105
+ h_message = self.W3(h_message)
1106
+
1107
+ if mask_attend is not None:
1108
+ h_message = mask_attend.unsqueeze(-1) * h_message
1109
+ dh = torch.sum(h_message, -2) / self.scale
1110
+ h_V = self.norm1(h_V + self.dropout1(dh))
1111
+
1112
+ dh = self.dense(h_V)
1113
+ h_V = self.norm2(h_V + self.dropout2(dh))
1114
+ if mask_V is not None:
1115
+ mask_V = mask_V.unsqueeze(-1)
1116
+ h_V = mask_V * h_V
1117
+
1118
+ h_EV = cat_neighbors_nodes(h_V, h_E, E_idx)
1119
+ h_V_expand = h_V.unsqueeze(-2).expand(-1,-1,h_EV.size(-2),-1)
1120
+ h_EV = torch.cat([h_V_expand, h_EV], -1)
1121
+
1122
+ h_message = self.act(self.W12(self.act(self.W11(h_EV))))
1123
+ if time_cond is not None:
1124
+ scale, shift = self.time_block2(time_cond).chunk(2, dim=-1)
1125
+ h_message = h_message * (scale + 1) + shift
1126
+ h_message = self.W13(h_message)
1127
+
1128
+ h_E = self.norm3(h_E + self.dropout3(h_message))
1129
+ return h_V, h_E
1130
+
1131
+
1132
+ class DecLayer(nn.Module):
1133
+ def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30, time_cond_dim=None):
1134
+ super(DecLayer, self).__init__()
1135
+ self.num_hidden = num_hidden
1136
+ self.num_in = num_in
1137
+ self.scale = scale
1138
+ self.dropout1 = nn.Dropout(dropout)
1139
+ self.dropout2 = nn.Dropout(dropout)
1140
+ self.norm1 = nn.LayerNorm(num_hidden)
1141
+ self.norm2 = nn.LayerNorm(num_hidden)
1142
+
1143
+ if time_cond_dim is not None:
1144
+ self.time_block = nn.Sequential(
1145
+ Rearrange('b 1 d -> b 1 1 d'),
1146
+ nn.SiLU(),
1147
+ nn.Linear(time_cond_dim, num_hidden * 2))
1148
+
1149
+ self.W1 = nn.Linear(num_hidden + num_in, num_hidden, bias=True)
1150
+ self.W2 = nn.Linear(num_hidden, num_hidden, bias=True)
1151
+ self.W3 = nn.Linear(num_hidden, num_hidden, bias=True)
1152
+ self.act = torch.nn.GELU()
1153
+ self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)
1154
+
1155
+ def forward(self, h_V, h_E, mask_V=None, mask_attend=None, time_cond=None):
1156
+ """ Parallel computation of full transformer layer """
1157
+
1158
+ # Concatenate h_V_i to h_E_ij
1159
+ h_V_expand = h_V.unsqueeze(-2).expand(-1,-1,h_E.size(-2),-1)
1160
+ h_EV = torch.cat([h_V_expand, h_E], -1)
1161
+
1162
+ h_message = self.act(self.W2(self.act(self.W1(h_EV))))
1163
+ if time_cond is not None:
1164
+ scale, shift = self.time_block(time_cond).chunk(2, dim=-1)
1165
+ h_message = h_message * (scale + 1) + shift
1166
+ h_message = self.W3(h_message)
1167
+
1168
+ if mask_attend is not None:
1169
+ h_message = mask_attend.unsqueeze(-1) * h_message
1170
+ dh = torch.sum(h_message, -2) / self.scale
1171
+
1172
+ h_V = self.norm1(h_V + self.dropout1(dh))
1173
+
1174
+ # Position-wise feedforward
1175
+ dh = self.dense(h_V)
1176
+ h_V = self.norm2(h_V + self.dropout2(dh))
1177
+
1178
+ if mask_V is not None:
1179
+ mask_V = mask_V.unsqueeze(-1)
1180
+ h_V = mask_V * h_V
1181
+ return h_V
1182
+
1183
+
1184
+
1185
+ class PositionWiseFeedForward(nn.Module):
1186
+ def __init__(self, num_hidden, num_ff):
1187
+ super(PositionWiseFeedForward, self).__init__()
1188
+ self.W_in = nn.Linear(num_hidden, num_ff, bias=True)
1189
+ self.W_out = nn.Linear(num_ff, num_hidden, bias=True)
1190
+ self.act = torch.nn.GELU()
1191
+ def forward(self, h_V):
1192
+ h = self.act(self.W_in(h_V))
1193
+ h = self.W_out(h)
1194
+ return h
1195
+
1196
+ class PositionalEncodings(nn.Module):
1197
+ def __init__(self, num_embeddings, max_relative_feature=32):
1198
+ super(PositionalEncodings, self).__init__()
1199
+ self.num_embeddings = num_embeddings
1200
+ self.max_relative_feature = max_relative_feature
1201
+ self.linear = nn.Linear(2*max_relative_feature+1+1, num_embeddings)
1202
+
1203
+ def forward(self, offset, mask):
1204
+ d = torch.clip(offset + self.max_relative_feature, 0, 2*self.max_relative_feature)*mask + (1-mask)*(2*self.max_relative_feature+1)
1205
+ d_onehot = torch.nn.functional.one_hot(d, 2*self.max_relative_feature+1+1)
1206
+ E = self.linear(d_onehot.float())
1207
+ return E
1208
+
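A tiny demonstration of the relative-position bucketing above: offsets are clipped into [-32, 32] (66 buckets in total), and masked pairs, e.g. residues on different chains, fall into the extra final bucket:

import torch

enc = PositionalEncodings(num_embeddings=16)
offset = torch.tensor([[[-100, 0, 100]]])  # [B, L, K] relative sequence offsets
mask = torch.tensor([[[1, 1, 0]]])         # 0 marks a pair to bucket separately
print(enc(offset, mask).shape)             # torch.Size([1, 1, 3, 16])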
1209
+
1210
+
1211
+ class CA_ProteinFeatures(nn.Module):
1212
+ def __init__(self, edge_features, node_features, num_positional_embeddings=16,
1213
+ num_rbf=16, top_k=30, augment_eps=0., num_chain_embeddings=16):
1214
+ """ Extract protein features """
1215
+ super(CA_ProteinFeatures, self).__init__()
1216
+ self.edge_features = edge_features
1217
+ self.node_features = node_features
1218
+ self.top_k = top_k
1219
+ self.augment_eps = augment_eps
1220
+ self.num_rbf = num_rbf
1221
+ self.num_positional_embeddings = num_positional_embeddings
1222
+
1223
+ # Positional encoding
1224
+ self.embeddings = PositionalEncodings(num_positional_embeddings)
1225
+ # Normalization and embedding
1226
+ node_in, edge_in = 3, num_positional_embeddings + num_rbf*9 + 7
1227
+ self.node_embedding = nn.Linear(node_in, node_features, bias=False) #NOT USED
1228
+ self.edge_embedding = nn.Linear(edge_in, edge_features, bias=False)
1229
+ self.norm_nodes = nn.LayerNorm(node_features)
1230
+ self.norm_edges = nn.LayerNorm(edge_features)
1231
+
1232
+
1233
+ def _quaternions(self, R):
1234
+ """ Convert a batch of 3D rotations [R] to quaternions [Q]
1235
+ R [...,3,3]
1236
+ Q [...,4]
1237
+ """
1238
+ # Simple Wikipedia version
1239
+ # en.wikipedia.org/wiki/Rotation_matrix#Quaternion
1240
+ # For other options see math.stackexchange.com/questions/2074316/calculating-rotation-axis-from-rotation-matrix
1241
+ diag = torch.diagonal(R, dim1=-2, dim2=-1)
1242
+ Rxx, Ryy, Rzz = diag.unbind(-1)
1243
+ magnitudes = 0.5 * torch.sqrt(torch.abs(1 + torch.stack([
1244
+ Rxx - Ryy - Rzz,
1245
+ - Rxx + Ryy - Rzz,
1246
+ - Rxx - Ryy + Rzz
1247
+ ], -1)))
1248
+ _R = lambda i,j: R[:,:,:,i,j]
1249
+ signs = torch.sign(torch.stack([
1250
+ _R(2,1) - _R(1,2),
1251
+ _R(0,2) - _R(2,0),
1252
+ _R(1,0) - _R(0,1)
1253
+ ], -1))
1254
+ xyz = signs * magnitudes
1255
+ # The relu enforces a non-negative trace
1256
+ w = torch.sqrt(F.relu(1 + diag.sum(-1, keepdim=True))) / 2.
1257
+ Q = torch.cat((xyz, w), -1)
1258
+ Q = F.normalize(Q, dim=-1)
1259
+ return Q
1260
+
1261
+ def _orientations_coarse(self, X, E_idx, eps=1e-6):
1262
+ dX = X[:,1:,:] - X[:,:-1,:]
1263
+ dX_norm = torch.norm(dX,dim=-1)
1264
+ dX_mask = (3.6<dX_norm) & (dX_norm<4.0) #exclude CA-CA jumps
1265
+ dX = dX*dX_mask[:,:,None]
1266
+ U = F.normalize(dX, dim=-1)
1267
+ u_2 = U[:,:-2,:]
1268
+ u_1 = U[:,1:-1,:]
1269
+ u_0 = U[:,2:,:]
1270
+ # Backbone normals
1271
+ n_2 = F.normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
1272
+ n_1 = F.normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)
1273
+
1274
+ # Bond angle calculation
1275
+ cosA = -(u_1 * u_0).sum(-1)
1276
+ cosA = torch.clamp(cosA, -1+eps, 1-eps)
1277
+ A = torch.acos(cosA)
1278
+ # Angle between normals
1279
+ cosD = (n_2 * n_1).sum(-1)
1280
+ cosD = torch.clamp(cosD, -1+eps, 1-eps)
1281
+ D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD)
1282
+ # Backbone features
1283
+ AD_features = torch.stack((torch.cos(A), torch.sin(A) * torch.cos(D), torch.sin(A) * torch.sin(D)), 2)
1284
+ AD_features = F.pad(AD_features, (0,0,1,2), 'constant', 0)
1285
+
1286
+ # Build relative orientations
1287
+ o_1 = F.normalize(u_2 - u_1, dim=-1)
1288
+ O = torch.stack((o_1, n_2, torch.cross(o_1, n_2, dim=-1)), 2)
1289
+ O = O.view(list(O.shape[:2]) + [9])
1290
+ O = F.pad(O, (0,0,1,2), 'constant', 0)
1291
+ O_neighbors = gather_nodes(O, E_idx)
1292
+ X_neighbors = gather_nodes(X, E_idx)
1293
+
1294
+ # Re-view as rotation matrices
1295
+ O = O.view(list(O.shape[:2]) + [3,3])
1296
+ O_neighbors = O_neighbors.view(list(O_neighbors.shape[:3]) + [3,3])
1297
+
1298
+ # Rotate into local reference frames
1299
+ dX = X_neighbors - X.unsqueeze(-2)
1300
+ dU = torch.matmul(O.unsqueeze(2), dX.unsqueeze(-1)).squeeze(-1)
1301
+ dU = F.normalize(dU, dim=-1)
1302
+ R = torch.matmul(O.unsqueeze(2).transpose(-1,-2), O_neighbors)
1303
+ Q = self._quaternions(R)
1304
+
1305
+ # Orientation features
1306
+ O_features = torch.cat((dU,Q), dim=-1)
1307
+ return AD_features, O_features
1308
+
1309
+
1310
+
1311
+ def _dist(self, X, mask, eps=1E-6):
1312
+ """ Pairwise euclidean distances """
1313
+ # Convolutional network on NCHW
1314
+ mask_2D = torch.unsqueeze(mask,1) * torch.unsqueeze(mask,2)
1315
+ dX = torch.unsqueeze(X,1) - torch.unsqueeze(X,2)
1316
+ D = mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps)
1317
+
1318
+ # Identify k nearest neighbors (including self)
1319
+ D_max, _ = torch.max(D, -1, keepdim=True)
1320
+ D_adjust = D + (1. - mask_2D) * D_max
1321
+ D_neighbors, E_idx = torch.topk(D_adjust, np.minimum(self.top_k, X.shape[1]), dim=-1, largest=False)
1322
+ mask_neighbors = gather_edges(mask_2D.unsqueeze(-1), E_idx)
1323
+ return D_neighbors, E_idx, mask_neighbors
1324
+
1325
+ def _rbf(self, D):
1326
+ # Distance radial basis function
1327
+ device = D.device
1328
+ D_min, D_max, D_count = 2., 22., self.num_rbf
1329
+ D_mu = torch.linspace(D_min, D_max, D_count).to(device)
1330
+ D_mu = D_mu.view([1,1,1,-1])
1331
+ D_sigma = (D_max - D_min) / D_count
1332
+ D_expand = torch.unsqueeze(D, -1)
1333
+ RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2)
1334
+ return RBF
1335
+
1336
+ def _get_rbf(self, A, B, E_idx):
1337
+ D_A_B = torch.sqrt(torch.sum((A[:,:,None,:] - B[:,None,:,:])**2,-1) + 1e-6) #[B, L, L]
1338
+ D_A_B_neighbors = gather_edges(D_A_B[:,:,:,None], E_idx)[:,:,:,0] #[B,L,K]
1339
+ RBF_A_B = self._rbf(D_A_B_neighbors)
1340
+ return RBF_A_B
1341
+
1342
+ def forward(self, Ca, mask, residue_idx, chain_labels):
1343
+ """ Featurize coordinates as an attributed graph """
1344
+ if self.augment_eps > 0:
1345
+ Ca = Ca + self.augment_eps * torch.randn_like(Ca)
1346
+
1347
+ D_neighbors, E_idx, mask_neighbors = self._dist(Ca, mask)
1348
+
1349
+ Ca_0 = torch.zeros(Ca.shape, device=Ca.device)
1350
+ Ca_2 = torch.zeros(Ca.shape, device=Ca.device)
1351
+ Ca_0[:,1:,:] = Ca[:,:-1,:]
1352
+ Ca_1 = Ca
1353
+ Ca_2[:,:-1,:] = Ca[:,1:,:]
1354
+
1355
+ V, O_features = self._orientations_coarse(Ca, E_idx)
1356
+
1357
+ RBF_all = []
1358
+ RBF_all.append(self._rbf(D_neighbors)) #Ca_1-Ca_1
1359
+ RBF_all.append(self._get_rbf(Ca_0, Ca_0, E_idx))
1360
+ RBF_all.append(self._get_rbf(Ca_2, Ca_2, E_idx))
1361
+
1362
+ RBF_all.append(self._get_rbf(Ca_0, Ca_1, E_idx))
1363
+ RBF_all.append(self._get_rbf(Ca_0, Ca_2, E_idx))
1364
+
1365
+ RBF_all.append(self._get_rbf(Ca_1, Ca_0, E_idx))
1366
+ RBF_all.append(self._get_rbf(Ca_1, Ca_2, E_idx))
1367
+
1368
+ RBF_all.append(self._get_rbf(Ca_2, Ca_0, E_idx))
1369
+ RBF_all.append(self._get_rbf(Ca_2, Ca_1, E_idx))
1370
+
1371
+
1372
+ RBF_all = torch.cat(tuple(RBF_all), dim=-1)
1373
+
1374
+
1375
+ offset = residue_idx[:,:,None]-residue_idx[:,None,:]
1376
+ offset = gather_edges(offset[:,:,:,None], E_idx)[:,:,:,0] #[B, L, K]
1377
+
1378
+ d_chains = ((chain_labels[:, :, None] - chain_labels[:,None,:])==0).long()
1379
+ E_chains = gather_edges(d_chains[:,:,:,None], E_idx)[:,:,:,0]
1380
+ E_positional = self.embeddings(offset.long(), E_chains)
1381
+ E = torch.cat((E_positional, RBF_all, O_features), -1)
1382
+
1383
+
1384
+ E = self.edge_embedding(E)
1385
+ E = self.norm_edges(E)
1386
+
1387
+ return E, E_idx
1388
+
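The quaternion conversion above can be sanity-checked with a standalone sketch (identity rotation only; this mirrors _quaternions but is not part of the class): the trace of the identity is 3, so w = sqrt(1 + 3) / 2 = 1 and the xyz components vanish.

import torch
import torch.nn.functional as F

R = torch.eye(3).reshape(1, 1, 1, 3, 3)  # identity rotation, shaped [B, L, K, 3, 3]
diag = torch.diagonal(R, dim1=-2, dim2=-1)
Rxx, Ryy, Rzz = diag.unbind(-1)
magnitudes = 0.5 * torch.sqrt(torch.abs(1 + torch.stack(
    [Rxx - Ryy - Rzz, -Rxx + Ryy - Rzz, -Rxx - Ryy + Rzz], -1)))
signs = torch.sign(torch.stack(
    [R[..., 2, 1] - R[..., 1, 2],
     R[..., 0, 2] - R[..., 2, 0],
     R[..., 1, 0] - R[..., 0, 1]], -1))
w = torch.sqrt(F.relu(1 + diag.sum(-1, keepdim=True))) / 2.0
Q = F.normalize(torch.cat((signs * magnitudes, w), -1), dim=-1)
print(Q)  # tensor([[[[0., 0., 0., 1.]]]]) -- the identity quaternion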
1389
+
1390
+ def get_closest_neighbors(X, mask, top_k, eps=1e-6):
1391
+ # X is ca coords (b, n, 3), mask is seq mask
1392
+ mask_2D = torch.unsqueeze(mask,1) * torch.unsqueeze(mask,2)
1393
+ dX = torch.unsqueeze(X,1) - torch.unsqueeze(X,2)
1394
+ D = mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps)
1395
+ D_max, _ = torch.max(D, -1, keepdim=True)
1396
+ D_adjust = D + (1. - mask_2D) * D_max
1397
+ sampled_top_k = top_k
1398
+ D_neighbors, E_idx = torch.topk(D_adjust, np.minimum(top_k, X.shape[1]), dim=-1, largest=False)
1399
+ return D_neighbors, E_idx
1400
+
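A brief usage sketch for get_closest_neighbors (toy coordinates; assumes numpy is imported at module scope as elsewhere in this file): every residue's first neighbour is itself, because the masked self-distance is near zero.

import torch

X = torch.tensor([[[0.0, 0.0, 0.0],      # four CA atoms spaced 3.8 A apart
                   [3.8, 0.0, 0.0],
                   [7.6, 0.0, 0.0],
                   [11.4, 0.0, 0.0]]])
mask = torch.ones(1, 4)
D_neighbors, E_idx = get_closest_neighbors(X, mask, top_k=2)
print(E_idx.shape)              # torch.Size([1, 4, 2])
print(E_idx[0, :, 0].tolist())  # [0, 1, 2, 3] -- self is always the nearest neighbour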
1401
+
1402
+ class ProteinFeatures(nn.Module):
1403
+ def __init__(self, edge_features, node_features, num_positional_embeddings=16,
1404
+ num_rbf=16, top_k=30, augment_eps=0., num_chain_embeddings=16):
1405
+ """ Extract protein features """
1406
+ super(ProteinFeatures, self).__init__()
1407
+ self.edge_features = edge_features
1408
+ self.node_features = node_features
1409
+ self.top_k = top_k
1410
+ self.augment_eps = augment_eps
1411
+ self.num_rbf = num_rbf
1412
+ self.num_positional_embeddings = num_positional_embeddings
1413
+
1414
+ self.embeddings = PositionalEncodings(num_positional_embeddings)
1415
+ node_in, edge_in = 6, num_positional_embeddings + num_rbf*25
1416
+ self.edge_embedding = nn.Linear(edge_in, edge_features, bias=False)
1417
+ self.norm_edges = nn.LayerNorm(edge_features)
1418
+
1419
+ def _dist(self, X, mask, eps=1E-6):
1420
+ # mask_2D = torch.unsqueeze(mask,1) * torch.unsqueeze(mask,2)
1421
+ # dX = torch.unsqueeze(X,1) - torch.unsqueeze(X,2)
1422
+ # D = mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps)
1423
+ # D_max, _ = torch.max(D, -1, keepdim=True)
1424
+ # D_adjust = D + (1. - mask_2D) * D_max
1425
+ # sampled_top_k = self.top_k
1426
+ # D_neighbors, E_idx = torch.topk(D_adjust, np.minimum(self.top_k, X.shape[1]), dim=-1, largest=False)
1427
+ # return D_neighbors, E_idx
1428
+ return get_closest_neighbors(X, mask, self.top_k, eps=eps)
1429
+
1430
+ def _rbf(self, D):
1431
+ device = D.device
1432
+ D_min, D_max, D_count = 2., 22., self.num_rbf
1433
+ D_mu = torch.linspace(D_min, D_max, D_count, device=device)
1434
+ D_mu = D_mu.view([1,1,1,-1])
1435
+ D_sigma = (D_max - D_min) / D_count
1436
+ D_expand = torch.unsqueeze(D, -1)
1437
+ RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2)
1438
+ return RBF
1439
+
1440
+ def _get_rbf(self, A, B, E_idx):
1441
+ D_A_B = torch.sqrt(torch.sum((A[:,:,None,:] - B[:,None,:,:])**2,-1) + 1e-6) #[B, L, L]
1442
+ D_A_B_neighbors = gather_edges(D_A_B[:,:,:,None], E_idx)[:,:,:,0] #[B,L,K]
1443
+ RBF_A_B = self._rbf(D_A_B_neighbors)
1444
+ return RBF_A_B
1445
+
1446
+ def forward(self, X, mask, residue_idx, chain_labels):
1447
+ if self.augment_eps > 0:
1448
+ X = X + self.augment_eps * torch.randn_like(X)
1449
+
1450
+ b = X[:,:,1,:] - X[:,:,0,:]
1451
+ c = X[:,:,2,:] - X[:,:,1,:]
1452
+ a = torch.cross(b, c, dim=-1)
1453
+ Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + X[:,:,1,:]
1454
+ Ca = X[:,:,1,:]
1455
+ N = X[:,:,0,:]
1456
+ C = X[:,:,2,:]
1457
+ O = X[:,:,3,:]
1458
+
1459
+ D_neighbors, E_idx = self._dist(Ca, mask)
1460
+
1461
+ RBF_all = []
1462
+ RBF_all.append(self._rbf(D_neighbors)) #Ca-Ca
1463
+ RBF_all.append(self._get_rbf(N, N, E_idx)) #N-N
1464
+ RBF_all.append(self._get_rbf(C, C, E_idx)) #C-C
1465
+ RBF_all.append(self._get_rbf(O, O, E_idx)) #O-O
1466
+ RBF_all.append(self._get_rbf(Cb, Cb, E_idx)) #Cb-Cb
1467
+ RBF_all.append(self._get_rbf(Ca, N, E_idx)) #Ca-N
1468
+ RBF_all.append(self._get_rbf(Ca, C, E_idx)) #Ca-C
1469
+ RBF_all.append(self._get_rbf(Ca, O, E_idx)) #Ca-O
1470
+ RBF_all.append(self._get_rbf(Ca, Cb, E_idx)) #Ca-Cb
1471
+ RBF_all.append(self._get_rbf(N, C, E_idx)) #N-C
1472
+ RBF_all.append(self._get_rbf(N, O, E_idx)) #N-O
1473
+ RBF_all.append(self._get_rbf(N, Cb, E_idx)) #N-Cb
1474
+ RBF_all.append(self._get_rbf(Cb, C, E_idx)) #Cb-C
1475
+ RBF_all.append(self._get_rbf(Cb, O, E_idx)) #Cb-O
1476
+ RBF_all.append(self._get_rbf(O, C, E_idx)) #O-C
1477
+ RBF_all.append(self._get_rbf(N, Ca, E_idx)) #N-Ca
1478
+ RBF_all.append(self._get_rbf(C, Ca, E_idx)) #C-Ca
1479
+ RBF_all.append(self._get_rbf(O, Ca, E_idx)) #O-Ca
1480
+ RBF_all.append(self._get_rbf(Cb, Ca, E_idx)) #Cb-Ca
1481
+ RBF_all.append(self._get_rbf(C, N, E_idx)) #C-N
1482
+ RBF_all.append(self._get_rbf(O, N, E_idx)) #O-N
1483
+ RBF_all.append(self._get_rbf(Cb, N, E_idx)) #Cb-N
1484
+ RBF_all.append(self._get_rbf(C, Cb, E_idx)) #C-Cb
1485
+ RBF_all.append(self._get_rbf(O, Cb, E_idx)) #O-Cb
1486
+ RBF_all.append(self._get_rbf(C, O, E_idx)) #C-O
1487
+ RBF_all = torch.cat(tuple(RBF_all), dim=-1)
1488
+
1489
+ offset = residue_idx[:,:,None]-residue_idx[:,None,:]
1490
+ offset = gather_edges(offset[:,:,:,None], E_idx)[:,:,:,0] #[B, L, K]
1491
+
1492
+ d_chains = ((chain_labels[:, :, None] - chain_labels[:,None,:])==0).long() #find self vs non-self interaction
1493
+ E_chains = gather_edges(d_chains[:,:,:,None], E_idx)[:,:,:,0]
1494
+ E_positional = self.embeddings(offset.long(), E_chains)
1495
+ E = torch.cat((E_positional, RBF_all), -1)
1496
+ E = self.edge_embedding(E)
1497
+ E = self.norm_edges(E)
1498
+ return E, E_idx
1499
+
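To make the radial-basis featurization above concrete, a standalone sketch with the same constants as _rbf (16 Gaussian bins with centres spread from 2 A to 22 A): each scalar distance becomes a soft one-hot over distance bins, peaking at the nearest centre.

import torch

num_rbf, D_min, D_max = 16, 2.0, 22.0
D_mu = torch.linspace(D_min, D_max, num_rbf).view(1, 1, 1, -1)  # bin centres
D_sigma = (D_max - D_min) / num_rbf                             # 1.25 A width
D = torch.tensor([[[3.8, 10.0]]])                               # toy CA-CA distances
RBF = torch.exp(-((D.unsqueeze(-1) - D_mu) / D_sigma) ** 2)     # (1, 1, 2, 16)
print(RBF.argmax(-1).tolist())  # [[[1, 6]]] -- 3.8 A peaks at bin 1, 10.0 A at bin 6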
1500
+
1501
+
1502
+ class ProteinMPNN(nn.Module):
1503
+ def __init__(self, num_letters, node_features, edge_features,
1504
+ hidden_dim, num_encoder_layers=3, num_decoder_layers=3,
1505
+ vocab=21, k_neighbors=64, augment_eps=0.05, dropout=0.1, ca_only=False, time_cond_dim=None, input_S_is_embeddings=False):
1506
+ super(ProteinMPNN, self).__init__()
1507
+
1508
+ # Hyperparameters
1509
+ self.node_features = node_features
1510
+ self.edge_features = edge_features
1511
+ self.hidden_dim = hidden_dim
1512
+
1513
+ # Featurization layers
1514
+ if ca_only:
1515
+ self.features = CA_ProteinFeatures(node_features, edge_features, top_k=k_neighbors, augment_eps=augment_eps)
1516
+ self.W_v = nn.Linear(node_features, hidden_dim, bias=True)
1517
+ else:
1518
+ self.features = ProteinFeatures(node_features, edge_features, top_k=k_neighbors, augment_eps=augment_eps)
1519
+
1520
+ self.W_e = nn.Linear(edge_features, hidden_dim, bias=True)
1521
+ self.input_S_is_embeddings = input_S_is_embeddings
1522
+ if not self.input_S_is_embeddings:
1523
+ self.W_s = nn.Embedding(vocab, hidden_dim)
1524
+
1525
+ if time_cond_dim is not None:
1526
+ self.time_block = nn.Sequential(
1527
+ nn.SiLU(),
1528
+ nn.Linear(time_cond_dim, hidden_dim)
1529
+ )
1530
+
1531
+ # Encoder layers
1532
+ self.encoder_layers = nn.ModuleList([
1533
+ EncLayer(hidden_dim, hidden_dim*2, dropout=dropout, time_cond_dim=time_cond_dim)
1534
+ for _ in range(num_encoder_layers)
1535
+ ])
1536
+
1537
+ # Decoder layers
1538
+ self.decoder_layers = nn.ModuleList([
1539
+ DecLayer(hidden_dim, hidden_dim*3, dropout=dropout, time_cond_dim=time_cond_dim)
1540
+ for _ in range(num_decoder_layers)
1541
+ ])
1542
+ self.W_out = nn.Linear(hidden_dim, num_letters, bias=True)
1543
+
1544
+ for p in self.parameters():
1545
+ if p.dim() > 1:
1546
+ nn.init.xavier_uniform_(p)
1547
+
1548
+ def forward(self, X, S, mask, chain_M, residue_idx, chain_encoding_all, randn, use_input_decoding_order=False, decoding_order=None, causal_mask=True, time_cond=None, return_node_embs=False):
1549
+ """ Graph-conditioned sequence model """
1550
+ device=X.device
1551
+ # Prepare node and edge embeddings
1552
+ E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
1553
+ h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=E.device)
1554
+ if time_cond is not None:
1555
+ time_cond_nodes = self.time_block(time_cond)
1556
+ h_V += time_cond_nodes # time_cond is b, 1, c
1557
+ h_E = self.W_e(E)
1558
+
1559
+ # Encoder is unmasked self-attention
1560
+ mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
1561
+ mask_attend = mask.unsqueeze(-1) * mask_attend
1562
+ for layer in self.encoder_layers:
1563
+ h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend, time_cond=time_cond)
1564
+
1565
+ encoder_embs = h_V
1566
+
1567
+ # Concatenate sequence embeddings for autoregressive decoder
1568
+ if self.input_S_is_embeddings:
1569
+ h_S = S
1570
+ else:
1571
+ h_S = self.W_s(S)
1572
+ h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)
1573
+
1574
+ # Build encoder embeddings
1575
+ h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
1576
+ h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
1577
+
1578
+
1579
+ chain_M = chain_M*mask #update chain_M to include missing regions
1580
+ mask_size = E_idx.shape[1]
1581
+ if causal_mask:
1582
+ if not use_input_decoding_order:
1583
+ decoding_order = torch.argsort((chain_M+0.0001)*(torch.abs(randn))) #[numbers will be smaller for places where chain_M = 0.0 and higher for places where chain_M = 1.0]
1584
+ permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()
1585
+ order_mask_backward = torch.einsum('ij, biq, bjp->bqp',(1-torch.triu(torch.ones(mask_size,mask_size, device=device))), permutation_matrix_reverse, permutation_matrix_reverse)
1586
+ else:
1587
+ order_mask_backward = torch.ones(X.shape[0], mask_size, mask_size, device=device)
1588
+ mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
1589
+ mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
1590
+ mask_bw = mask_1D * mask_attend
1591
+ mask_fw = mask_1D * (1. - mask_attend)
1592
+
1593
+ h_EXV_encoder_fw = mask_fw * h_EXV_encoder
1594
+ for layer in self.decoder_layers:
1595
+ # Masked positions attend to encoder information, unmasked positions see previously decoded positions.
1596
+ h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
1597
+ h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw
1598
+ h_V = layer(h_V, h_ESV, mask, time_cond=time_cond)
1599
+
1600
+ if return_node_embs:
1601
+ return h_V, encoder_embs
1602
+ else:
1603
+ logits = self.W_out(h_V)
1604
+ log_probs = F.log_softmax(logits, dim=-1)
1605
+ return log_probs
1606
+
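The argsort construction of the decoding order is easy to verify in isolation (toy values, not from the model): scaling |noise| by chain_M + 0.0001 makes fixed positions (chain_M = 0) sort to the front, so they are consumed first and every designable position is conditioned on them.

import torch

chain_M = torch.tensor([[1.0, 0.0, 1.0, 0.0]])  # 0 -> fixed, 1 -> designable
randn = torch.tensor([[0.5, 2.0, 1.5, 0.3]])
decoding_order = torch.argsort((chain_M + 0.0001) * torch.abs(randn))
print(decoding_order.tolist())  # [[3, 1, 0, 2]] -- both fixed positions come first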
1607
+
1608
+ def sample(self, X, randn, S_true, chain_mask, chain_encoding_all, residue_idx, mask=None, temperature=1.0, omit_AAs_np=None, bias_AAs_np=None, chain_M_pos=None, omit_AA_mask=None, pssm_coef=None, pssm_bias=None, pssm_multi=None, pssm_log_odds_flag=None, pssm_log_odds_mask=None, pssm_bias_flag=None, bias_by_res=None):
1609
+ device = X.device
1610
+ # Prepare node and edge embeddings
1611
+ E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
1612
+ h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=device)
1613
+ h_E = self.W_e(E)
1614
+
1615
+ # Encoder is unmasked self-attention
1616
+ mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
1617
+ mask_attend = mask.unsqueeze(-1) * mask_attend
1618
+ for layer in self.encoder_layers:
1619
+ h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)
1620
+
1621
+ # Decoder uses masked self-attention
1622
+ chain_mask = chain_mask*chain_M_pos*mask #update chain_M to include missing regions
1623
+ decoding_order = torch.argsort((chain_mask+0.0001)*(torch.abs(randn))) #[numbers will be smaller for places where chain_M = 0.0 and higher for places where chain_M = 1.0]
1624
+ mask_size = E_idx.shape[1]
1625
+ permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()
1626
+ order_mask_backward = torch.einsum('ij, biq, bjp->bqp',(1-torch.triu(torch.ones(mask_size,mask_size, device=device))), permutation_matrix_reverse, permutation_matrix_reverse)
1627
+ mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
1628
+ mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
1629
+ mask_bw = mask_1D * mask_attend
1630
+ mask_fw = mask_1D * (1. - mask_attend)
1631
+
1632
+ N_batch, N_nodes = X.size(0), X.size(1)
1633
+ log_probs = torch.zeros((N_batch, N_nodes, 21), device=device)
1634
+ all_probs = torch.zeros((N_batch, N_nodes, 21), device=device, dtype=torch.float32)
1635
+ h_S = torch.zeros_like(h_V, device=device)
1636
+ S = torch.zeros((N_batch, N_nodes), dtype=torch.int64, device=device)
1637
+ h_V_stack = [h_V] + [torch.zeros_like(h_V, device=device) for _ in range(len(self.decoder_layers))]
1638
+ constant = torch.tensor(omit_AAs_np, device=device)
1639
+ constant_bias = torch.tensor(bias_AAs_np, device=device)
1640
+ #chain_mask_combined = chain_mask*chain_M_pos
1641
+ omit_AA_mask_flag = omit_AA_mask is not None
1642
+
1643
+
1644
+ h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
1645
+ h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
1646
+ h_EXV_encoder_fw = mask_fw * h_EXV_encoder
1647
+ for t_ in range(N_nodes):
1648
+ t = decoding_order[:,t_] #[B]
1649
+ chain_mask_gathered = torch.gather(chain_mask, 1, t[:,None]) #[B]
1650
+ mask_gathered = torch.gather(mask, 1, t[:,None]) #[B]
1651
+ bias_by_res_gathered = torch.gather(bias_by_res, 1, t[:,None,None].repeat(1,1,21))[:,0,:] #[B, 21]
1652
+ if (mask_gathered==0).all(): #for padded or missing regions only
1653
+ S_t = torch.gather(S_true, 1, t[:,None])
1654
+ else:
1655
+ # Hidden layers
1656
+ E_idx_t = torch.gather(E_idx, 1, t[:,None,None].repeat(1,1,E_idx.shape[-1]))
1657
+ h_E_t = torch.gather(h_E, 1, t[:,None,None,None].repeat(1,1,h_E.shape[-2], h_E.shape[-1]))
1658
+ h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
1659
+ h_EXV_encoder_t = torch.gather(h_EXV_encoder_fw, 1, t[:,None,None,None].repeat(1,1,h_EXV_encoder_fw.shape[-2], h_EXV_encoder_fw.shape[-1]))
1660
+ mask_t = torch.gather(mask, 1, t[:,None])
1661
+ for l, layer in enumerate(self.decoder_layers):
1662
+ # Updated relational features for future states
1663
+ h_ESV_decoder_t = cat_neighbors_nodes(h_V_stack[l], h_ES_t, E_idx_t)
1664
+ h_V_t = torch.gather(h_V_stack[l], 1, t[:,None,None].repeat(1,1,h_V_stack[l].shape[-1]))
1665
+ h_ESV_t = torch.gather(mask_bw, 1, t[:,None,None,None].repeat(1,1,mask_bw.shape[-2], mask_bw.shape[-1])) * h_ESV_decoder_t + h_EXV_encoder_t
1666
+ h_V_stack[l+1].scatter_(1, t[:,None,None].repeat(1,1,h_V.shape[-1]), layer(h_V_t, h_ESV_t, mask_V=mask_t))
1667
+ # Sampling step
1668
+ h_V_t = torch.gather(h_V_stack[-1], 1, t[:,None,None].repeat(1,1,h_V_stack[-1].shape[-1]))[:,0]
1669
+ logits = self.W_out(h_V_t) / temperature
1670
+ probs = F.softmax(logits-constant[None,:]*1e8+constant_bias[None,:]/temperature+bias_by_res_gathered/temperature, dim=-1)
1671
+ if pssm_bias_flag:
1672
+ pssm_coef_gathered = torch.gather(pssm_coef, 1, t[:,None])[:,0]
1673
+ pssm_bias_gathered = torch.gather(pssm_bias, 1, t[:,None,None].repeat(1,1,pssm_bias.shape[-1]))[:,0]
1674
+ probs = (1-pssm_multi*pssm_coef_gathered[:,None])*probs + pssm_multi*pssm_coef_gathered[:,None]*pssm_bias_gathered
1675
+ if pssm_log_odds_flag:
1676
+ pssm_log_odds_mask_gathered = torch.gather(pssm_log_odds_mask, 1, t[:,None, None].repeat(1,1,pssm_log_odds_mask.shape[-1]))[:,0] #[B, 21]
1677
+ probs_masked = probs*pssm_log_odds_mask_gathered
1678
+ probs_masked += probs * 0.001
1679
+ probs = probs_masked/torch.sum(probs_masked, dim=-1, keepdim=True) #[B, 21]
1680
+ if omit_AA_mask_flag:
1681
+ omit_AA_mask_gathered = torch.gather(omit_AA_mask, 1, t[:,None, None].repeat(1,1,omit_AA_mask.shape[-1]))[:,0] #[B, 21]
1682
+ probs_masked = probs*(1.0-omit_AA_mask_gathered)
1683
+ probs = probs_masked/torch.sum(probs_masked, dim=-1, keepdim=True) #[B, 21]
1684
+ S_t = torch.multinomial(probs, 1)
1685
+ all_probs.scatter_(1, t[:,None,None].repeat(1,1,21), (chain_mask_gathered[:,:,None,]*probs[:,None,:]).float())
1686
+ S_true_gathered = torch.gather(S_true, 1, t[:,None])
1687
+ S_t = (S_t*chain_mask_gathered+S_true_gathered*(1.0-chain_mask_gathered)).long()
1688
+ temp1 = self.W_s(S_t)
1689
+ h_S.scatter_(1, t[:,None,None].repeat(1,1,temp1.shape[-1]), temp1)
1690
+ S.scatter_(1, t[:,None], S_t)
1691
+ output_dict = {"S": S, "probs": all_probs, "decoding_order": decoding_order}
1692
+ return output_dict
1693
+
1694
+
1695
+ def tied_sample(self, X, randn, S_true, chain_mask, chain_encoding_all, residue_idx, mask=None, temperature=1.0, omit_AAs_np=None, bias_AAs_np=None, chain_M_pos=None, omit_AA_mask=None, pssm_coef=None, pssm_bias=None, pssm_multi=None, pssm_log_odds_flag=None, pssm_log_odds_mask=None, pssm_bias_flag=None, tied_pos=None, tied_beta=None, bias_by_res=None):
1696
+ device = X.device
1697
+ # Prepare node and edge embeddings
1698
+ E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
1699
+ h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=device)
1700
+ h_E = self.W_e(E)
1701
+ # Encoder is unmasked self-attention
1702
+ mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
1703
+ mask_attend = mask.unsqueeze(-1) * mask_attend
1704
+ for layer in self.encoder_layers:
1705
+ h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)
1706
+
1707
+ # Decoder uses masked self-attention
1708
+ chain_mask = chain_mask*chain_M_pos*mask #update chain_M to include missing regions
1709
+ decoding_order = torch.argsort((chain_mask+0.0001)*(torch.abs(randn))) #[numbers will be smaller for places where chain_M = 0.0 and higher for places where chain_M = 1.0]
1710
+
1711
+ new_decoding_order = []
1712
+ for t_dec in list(decoding_order[0,].cpu().data.numpy()):
1713
+ if t_dec not in list(itertools.chain(*new_decoding_order)):
1714
+ list_a = [item for item in tied_pos if t_dec in item]
1715
+ if list_a:
1716
+ new_decoding_order.append(list_a[0])
1717
+ else:
1718
+ new_decoding_order.append([t_dec])
1719
+ decoding_order = torch.tensor(list(itertools.chain(*new_decoding_order)), device=device)[None,].repeat(X.shape[0],1)
1720
+
1721
+ mask_size = E_idx.shape[1]
1722
+ permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()
1723
+ order_mask_backward = torch.einsum('ij, biq, bjp->bqp',(1-torch.triu(torch.ones(mask_size,mask_size, device=device))), permutation_matrix_reverse, permutation_matrix_reverse)
1724
+ mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
1725
+ mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
1726
+ mask_bw = mask_1D * mask_attend
1727
+ mask_fw = mask_1D * (1. - mask_attend)
1728
+
1729
+ N_batch, N_nodes = X.size(0), X.size(1)
1730
+ log_probs = torch.zeros((N_batch, N_nodes, 21), device=device)
1731
+ all_probs = torch.zeros((N_batch, N_nodes, 21), device=device, dtype=torch.float32)
1732
+ h_S = torch.zeros_like(h_V, device=device)
1733
+ S = torch.zeros((N_batch, N_nodes), dtype=torch.int64, device=device)
1734
+ h_V_stack = [h_V] + [torch.zeros_like(h_V, device=device) for _ in range(len(self.decoder_layers))]
1735
+ constant = torch.tensor(omit_AAs_np, device=device)
1736
+ constant_bias = torch.tensor(bias_AAs_np, device=device)
1737
+ omit_AA_mask_flag = omit_AA_mask is not None
1738
+
1739
+ h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
1740
+ h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
1741
+ h_EXV_encoder_fw = mask_fw * h_EXV_encoder
1742
+ for t_list in new_decoding_order:
1743
+ logits = 0.0
1744
+ logit_list = []
1745
+ done_flag = False
1746
+ for t in t_list:
1747
+ if (mask[:,t]==0).all():
1748
+ S_t = S_true[:,t]
1749
+ for t in t_list:
1750
+ h_S[:,t,:] = self.W_s(S_t)
1751
+ S[:,t] = S_t
1752
+ done_flag = True
1753
+ break
1754
+ else:
1755
+ E_idx_t = E_idx[:,t:t+1,:]
1756
+ h_E_t = h_E[:,t:t+1,:,:]
1757
+ h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
1758
+ h_EXV_encoder_t = h_EXV_encoder_fw[:,t:t+1,:,:]
1759
+ mask_t = mask[:,t:t+1]
1760
+ for l, layer in enumerate(self.decoder_layers):
1761
+ h_ESV_decoder_t = cat_neighbors_nodes(h_V_stack[l], h_ES_t, E_idx_t)
1762
+ h_V_t = h_V_stack[l][:,t:t+1,:]
1763
+ h_ESV_t = mask_bw[:,t:t+1,:,:] * h_ESV_decoder_t + h_EXV_encoder_t
1764
+ h_V_stack[l+1][:,t,:] = layer(h_V_t, h_ESV_t, mask_V=mask_t).squeeze(1)
1765
+ h_V_t = h_V_stack[-1][:,t,:]
1766
+ logit_list.append((self.W_out(h_V_t) / temperature)/len(t_list))
1767
+ logits += tied_beta[t]*(self.W_out(h_V_t) / temperature)/len(t_list)
1768
+ if done_flag:
1769
+ pass
1770
+ else:
1771
+ bias_by_res_gathered = bias_by_res[:,t,:] #[B, 21]
1772
+ probs = F.softmax(logits-constant[None,:]*1e8+constant_bias[None,:]/temperature+bias_by_res_gathered/temperature, dim=-1)
1773
+ if pssm_bias_flag:
1774
+ pssm_coef_gathered = pssm_coef[:,t]
1775
+ pssm_bias_gathered = pssm_bias[:,t]
1776
+ probs = (1-pssm_multi*pssm_coef_gathered[:,None])*probs + pssm_multi*pssm_coef_gathered[:,None]*pssm_bias_gathered
1777
+ if pssm_log_odds_flag:
1778
+ pssm_log_odds_mask_gathered = pssm_log_odds_mask[:,t]
1779
+ probs_masked = probs*pssm_log_odds_mask_gathered
1780
+ probs_masked += probs * 0.001
1781
+ probs = probs_masked/torch.sum(probs_masked, dim=-1, keepdim=True) #[B, 21]
1782
+ if omit_AA_mask_flag:
1783
+ omit_AA_mask_gathered = omit_AA_mask[:,t]
1784
+ probs_masked = probs*(1.0-omit_AA_mask_gathered)
1785
+ probs = probs_masked/torch.sum(probs_masked, dim=-1, keepdim=True) #[B, 21]
1786
+ S_t_repeat = torch.multinomial(probs, 1).squeeze(-1)
1787
+ S_t_repeat = (chain_mask[:,t]*S_t_repeat + (1-chain_mask[:,t])*S_true[:,t]).long() #hard pick fixed positions
1788
+ for t in t_list:
1789
+ h_S[:,t,:] = self.W_s(S_t_repeat)
1790
+ S[:,t] = S_t_repeat
1791
+ all_probs[:,t,:] = probs.float()
1792
+ output_dict = {"S": S, "probs": all_probs, "decoding_order": decoding_order}
1793
+ return output_dict
1794
+
1795
+
1796
+ def conditional_probs(self, X, S, mask, chain_M, residue_idx, chain_encoding_all, randn, backbone_only=False):
1797
+ """ Graph-conditioned sequence model """
1798
+ device=X.device
1799
+ # Prepare node and edge embeddings
1800
+ E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
1801
+ h_V_enc = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=E.device)
1802
+ h_E = self.W_e(E)
1803
+
1804
+ # Encoder is unmasked self-attention
1805
+ mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
1806
+ mask_attend = mask.unsqueeze(-1) * mask_attend
1807
+ for layer in self.encoder_layers:
1808
+ h_V_enc, h_E = layer(h_V_enc, h_E, E_idx, mask, mask_attend)
1809
+
1810
+ # Concatenate sequence embeddings for autoregressive decoder
1811
+ h_S = self.W_s(S)
1812
+ h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)
1813
+
1814
+ # Build encoder embeddings
1815
+ h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
1816
+ h_EXV_encoder = cat_neighbors_nodes(h_V_enc, h_EX_encoder, E_idx)
1817
+
1818
+
1819
+ chain_M = chain_M*mask #update chain_M to include missing regions
1820
+
1821
+ chain_M_np = chain_M.cpu().numpy()
1822
+ idx_to_loop = np.argwhere(chain_M_np[0,:]==1)[:,0]
1823
+ log_conditional_probs = torch.zeros([X.shape[0], chain_M.shape[1], 21], device=device).float()
1824
+
1825
+ for idx in idx_to_loop:
1826
+ h_V = torch.clone(h_V_enc)
1827
+ order_mask = torch.zeros(chain_M.shape[1], device=device).float()
1828
+ if backbone_only:
1829
+ order_mask = torch.ones(chain_M.shape[1], device=device).float()
1830
+ order_mask[idx] = 0.
1831
+ else:
1832
+ order_mask = torch.zeros(chain_M.shape[1], device=device).float()
1833
+ order_mask[idx] = 1.
1834
+ decoding_order = torch.argsort((order_mask[None,]+0.0001)*(torch.abs(randn))) #[numbers will be smaller for places where chain_M = 0.0 and higher for places where chain_M = 1.0]
1835
+ mask_size = E_idx.shape[1]
1836
+ permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()
1837
+ order_mask_backward = torch.einsum('ij, biq, bjp->bqp',(1-torch.triu(torch.ones(mask_size,mask_size, device=device))), permutation_matrix_reverse, permutation_matrix_reverse)
1838
+ mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
1839
+ mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
1840
+ mask_bw = mask_1D * mask_attend
1841
+ mask_fw = mask_1D * (1. - mask_attend)
1842
+
1843
+ h_EXV_encoder_fw = mask_fw * h_EXV_encoder
1844
+ for layer in self.decoder_layers:
1845
+ # Masked positions attend to encoder information, unmasked positions see previously decoded positions.
1846
+ h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
1847
+ h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw
1848
+ h_V = layer(h_V, h_ESV, mask)
1849
+
1850
+ logits = self.W_out(h_V)
1851
+ log_probs = F.log_softmax(logits, dim=-1)
1852
+ log_conditional_probs[:,idx,:] = log_probs[:,idx,:]
1853
+ return log_conditional_probs
1854
+
1855
+
1856
+ def unconditional_probs(self, X, mask, residue_idx, chain_encoding_all):
1857
+ """ Graph-conditioned sequence model """
1858
+ device=X.device
1859
+ # Prepare node and edge embeddings
1860
+ E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
1861
+ h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=E.device)
1862
+ h_E = self.W_e(E)
1863
+
1864
+ # Encoder is unmasked self-attention
1865
+ mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
1866
+ mask_attend = mask.unsqueeze(-1) * mask_attend
1867
+ for layer in self.encoder_layers:
1868
+ h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)
1869
+
1870
+ # Build encoder embeddings
1871
+ h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_V), h_E, E_idx)
1872
+ h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
1873
+
1874
+ order_mask_backward = torch.zeros([X.shape[0], X.shape[1], X.shape[1]], device=device)
1875
+ mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
1876
+ mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
1877
+ mask_bw = mask_1D * mask_attend
1878
+ mask_fw = mask_1D * (1. - mask_attend)
1879
+
1880
+ h_EXV_encoder_fw = mask_fw * h_EXV_encoder
1881
+ for layer in self.decoder_layers:
1882
+ h_V = layer(h_V, h_EXV_encoder_fw, mask)
1883
+
1884
+ logits = self.W_out(h_V)
1885
+ log_probs = F.log_softmax(logits, dim=-1)
1886
+ return log_probs
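A shape-level smoke test for the full module above (hyperparameters are illustrative; helpers such as gather_nodes and cat_neighbors_nodes are assumed to be in scope as in this file):

import torch

B, L = 1, 10
model = ProteinMPNN(num_letters=21, node_features=128, edge_features=128,
                    hidden_dim=128, num_encoder_layers=3, num_decoder_layers=3,
                    k_neighbors=5, augment_eps=0.0)
X = torch.randn(B, L, 4, 3)                       # N, CA, C, O backbone coordinates
S = torch.randint(0, 21, (B, L))                  # sequence tokens
mask = torch.ones(B, L)
chain_M = torch.ones(B, L)                        # every position designable
residue_idx = torch.arange(L)[None].repeat(B, 1)
chain_encoding = torch.ones(B, L)
randn = torch.randn(B, L)                         # noise that sets the decoding order
log_probs = model(X, S, mask, chain_M, residue_idx, chain_encoding, randn)
print(log_probs.shape)  # torch.Size([1, 10, 21]) -- per-residue log-probabilities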
core/residue_constants.py ADDED
@@ -0,0 +1,1104 @@
1
+ # Copyright 2021 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Constants used in AlphaFold.
16
+ Adapted from the original DeepMind code by alexechu.
17
+ """
18
+
19
+ import collections
20
+ import functools
21
+ import os
22
+ from typing import List, Mapping, Tuple
23
+
24
+ import numpy as np
25
+ import tree
26
+
27
+ # Internal import (35fd).
28
+
29
+
30
+ # Distance from one CA to next CA [trans configuration: omega = 180].
31
+ ca_ca = 3.80209737096
32
+
33
+ # Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
34
+ # this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
35
+ # chi angles so their chi angle lists are empty.
36
+ chi_angles_atoms = {
37
+ "ALA": [],
38
+ # Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
39
+ "ARG": [
40
+ ["N", "CA", "CB", "CG"],
41
+ ["CA", "CB", "CG", "CD"],
42
+ ["CB", "CG", "CD", "NE"],
43
+ ["CG", "CD", "NE", "CZ"],
44
+ ],
45
+ "ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
46
+ "ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
47
+ "CYS": [["N", "CA", "CB", "SG"]],
48
+ "GLN": [
49
+ ["N", "CA", "CB", "CG"],
50
+ ["CA", "CB", "CG", "CD"],
51
+ ["CB", "CG", "CD", "OE1"],
52
+ ],
53
+ "GLU": [
54
+ ["N", "CA", "CB", "CG"],
55
+ ["CA", "CB", "CG", "CD"],
56
+ ["CB", "CG", "CD", "OE1"],
57
+ ],
58
+ "GLY": [],
59
+ "HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]],
60
+ "ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]],
61
+ "LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
62
+ "LYS": [
63
+ ["N", "CA", "CB", "CG"],
64
+ ["CA", "CB", "CG", "CD"],
65
+ ["CB", "CG", "CD", "CE"],
66
+ ["CG", "CD", "CE", "NZ"],
67
+ ],
68
+ "MET": [
69
+ ["N", "CA", "CB", "CG"],
70
+ ["CA", "CB", "CG", "SD"],
71
+ ["CB", "CG", "SD", "CE"],
72
+ ],
73
+ "PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
74
+ "PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]],
75
+ "SER": [["N", "CA", "CB", "OG"]],
76
+ "THR": [["N", "CA", "CB", "OG1"]],
77
+ "TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
78
+ "TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
79
+ "VAL": [["N", "CA", "CB", "CG1"]],
80
+ }
81
+
82
+ # If chi angles given in fixed-length array, this matrix determines how to mask
83
+ # them for each AA type. The order is as per restype_order (see below).
84
+ chi_angles_mask = [
85
+ [0.0, 0.0, 0.0, 0.0], # ALA
86
+ [1.0, 1.0, 1.0, 1.0], # ARG
87
+ [1.0, 1.0, 0.0, 0.0], # ASN
88
+ [1.0, 1.0, 0.0, 0.0], # ASP
89
+ [1.0, 0.0, 0.0, 0.0], # CYS
90
+ [1.0, 1.0, 1.0, 0.0], # GLN
91
+ [1.0, 1.0, 1.0, 0.0], # GLU
92
+ [0.0, 0.0, 0.0, 0.0], # GLY
93
+ [1.0, 1.0, 0.0, 0.0], # HIS
94
+ [1.0, 1.0, 0.0, 0.0], # ILE
95
+ [1.0, 1.0, 0.0, 0.0], # LEU
96
+ [1.0, 1.0, 1.0, 1.0], # LYS
97
+ [1.0, 1.0, 1.0, 0.0], # MET
98
+ [1.0, 1.0, 0.0, 0.0], # PHE
99
+ [1.0, 1.0, 0.0, 0.0], # PRO
100
+ [1.0, 0.0, 0.0, 0.0], # SER
101
+ [1.0, 0.0, 0.0, 0.0], # THR
102
+ [1.0, 1.0, 0.0, 0.0], # TRP
103
+ [1.0, 1.0, 0.0, 0.0], # TYR
104
+ [1.0, 0.0, 0.0, 0.0], # VAL
105
+ ]
106
+
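A quick standalone consistency check on the two tables above (the restype order is spelled out by hand here, matching the alphabetical 3-letter ordering described below): the number of torsions listed per residue in chi_angles_atoms equals the number of 1.0 entries in its chi_angles_mask row.

order = ["ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
         "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL"]
for name, mask_row in zip(order, chi_angles_mask):
    assert len(chi_angles_atoms[name]) == int(sum(mask_row)), name
print("chi tables agree")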
107
+ # The following chi angles are pi periodic: they can be rotated by a multiple
108
+ # of pi without affecting the structure.
109
+ chi_pi_periodic = [
110
+ [0.0, 0.0, 0.0, 0.0], # ALA
111
+ [0.0, 0.0, 0.0, 0.0], # ARG
112
+ [0.0, 0.0, 0.0, 0.0], # ASN
113
+ [0.0, 1.0, 0.0, 0.0], # ASP
114
+ [0.0, 0.0, 0.0, 0.0], # CYS
115
+ [0.0, 0.0, 0.0, 0.0], # GLN
116
+ [0.0, 0.0, 1.0, 0.0], # GLU
117
+ [0.0, 0.0, 0.0, 0.0], # GLY
118
+ [0.0, 0.0, 0.0, 0.0], # HIS
119
+ [0.0, 0.0, 0.0, 0.0], # ILE
120
+ [0.0, 0.0, 0.0, 0.0], # LEU
121
+ [0.0, 0.0, 0.0, 0.0], # LYS
122
+ [0.0, 0.0, 0.0, 0.0], # MET
123
+ [0.0, 1.0, 0.0, 0.0], # PHE
124
+ [0.0, 0.0, 0.0, 0.0], # PRO
125
+ [0.0, 0.0, 0.0, 0.0], # SER
126
+ [0.0, 0.0, 0.0, 0.0], # THR
127
+ [0.0, 0.0, 0.0, 0.0], # TRP
128
+ [0.0, 1.0, 0.0, 0.0], # TYR
129
+ [0.0, 0.0, 0.0, 0.0], # VAL
130
+ [0.0, 0.0, 0.0, 0.0], # UNK
131
+ ]
132
+
133
+ # Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi,
134
+ # psi and chi angles:
135
+ # 0: 'backbone group',
136
+ # 1: 'pre-omega-group', (empty)
137
+ # 2: 'phi-group', (currently empty, because it defines only hydrogens)
138
+ # 3: 'psi-group',
139
+ # 4,5,6,7: 'chi1,2,3,4-group'
140
+ # The atom positions are relative to the axis-end-atom of the corresponding
141
+ # rotation axis. The x-axis is in direction of the rotation axis, and the y-axis
142
+ # is defined such that the dihedral-angle-defining atom (the last entry in
143
+ # chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
144
+ # format: [atomname, group_idx, rel_position]
145
+ rigid_group_atom_positions = {
146
+ "ALA": [
147
+ ["N", 0, (-0.525, 1.363, 0.000)],
148
+ ["CA", 0, (0.000, 0.000, 0.000)],
149
+ ["C", 0, (1.526, -0.000, -0.000)],
150
+ ["CB", 0, (-0.529, -0.774, -1.205)],
151
+ ["O", 3, (0.627, 1.062, 0.000)],
152
+ ],
153
+ "ARG": [
154
+ ["N", 0, (-0.524, 1.362, -0.000)],
155
+ ["CA", 0, (0.000, 0.000, 0.000)],
156
+ ["C", 0, (1.525, -0.000, -0.000)],
157
+ ["CB", 0, (-0.524, -0.778, -1.209)],
158
+ ["O", 3, (0.626, 1.062, 0.000)],
159
+ ["CG", 4, (0.616, 1.390, -0.000)],
160
+ ["CD", 5, (0.564, 1.414, 0.000)],
161
+ ["NE", 6, (0.539, 1.357, -0.000)],
162
+ ["NH1", 7, (0.206, 2.301, 0.000)],
163
+ ["NH2", 7, (2.078, 0.978, -0.000)],
164
+ ["CZ", 7, (0.758, 1.093, -0.000)],
165
+ ],
166
+ "ASN": [
167
+ ["N", 0, (-0.536, 1.357, 0.000)],
168
+ ["CA", 0, (0.000, 0.000, 0.000)],
169
+ ["C", 0, (1.526, -0.000, -0.000)],
170
+ ["CB", 0, (-0.531, -0.787, -1.200)],
171
+ ["O", 3, (0.625, 1.062, 0.000)],
172
+ ["CG", 4, (0.584, 1.399, 0.000)],
173
+ ["ND2", 5, (0.593, -1.188, 0.001)],
174
+ ["OD1", 5, (0.633, 1.059, 0.000)],
175
+ ],
176
+ "ASP": [
177
+ ["N", 0, (-0.525, 1.362, -0.000)],
178
+ ["CA", 0, (0.000, 0.000, 0.000)],
179
+ ["C", 0, (1.527, 0.000, -0.000)],
180
+ ["CB", 0, (-0.526, -0.778, -1.208)],
181
+ ["O", 3, (0.626, 1.062, -0.000)],
182
+ ["CG", 4, (0.593, 1.398, -0.000)],
183
+ ["OD1", 5, (0.610, 1.091, 0.000)],
184
+ ["OD2", 5, (0.592, -1.101, -0.003)],
185
+ ],
186
+ "CYS": [
187
+ ["N", 0, (-0.522, 1.362, -0.000)],
188
+ ["CA", 0, (0.000, 0.000, 0.000)],
189
+ ["C", 0, (1.524, 0.000, 0.000)],
190
+ ["CB", 0, (-0.519, -0.773, -1.212)],
191
+ ["O", 3, (0.625, 1.062, -0.000)],
192
+ ["SG", 4, (0.728, 1.653, 0.000)],
193
+ ],
194
+ "GLN": [
195
+ ["N", 0, (-0.526, 1.361, -0.000)],
196
+ ["CA", 0, (0.000, 0.000, 0.000)],
197
+ ["C", 0, (1.526, 0.000, 0.000)],
198
+ ["CB", 0, (-0.525, -0.779, -1.207)],
199
+ ["O", 3, (0.626, 1.062, -0.000)],
200
+ ["CG", 4, (0.615, 1.393, 0.000)],
201
+ ["CD", 5, (0.587, 1.399, -0.000)],
202
+ ["NE2", 6, (0.593, -1.189, -0.001)],
203
+ ["OE1", 6, (0.634, 1.060, 0.000)],
204
+ ],
205
+ "GLU": [
206
+ ["N", 0, (-0.528, 1.361, 0.000)],
207
+ ["CA", 0, (0.000, 0.000, 0.000)],
208
+ ["C", 0, (1.526, -0.000, -0.000)],
209
+ ["CB", 0, (-0.526, -0.781, -1.207)],
210
+ ["O", 3, (0.626, 1.062, 0.000)],
211
+ ["CG", 4, (0.615, 1.392, 0.000)],
212
+ ["CD", 5, (0.600, 1.397, 0.000)],
213
+ ["OE1", 6, (0.607, 1.095, -0.000)],
214
+ ["OE2", 6, (0.589, -1.104, -0.001)],
215
+ ],
216
+ "GLY": [
217
+ ["N", 0, (-0.572, 1.337, 0.000)],
218
+ ["CA", 0, (0.000, 0.000, 0.000)],
219
+ ["C", 0, (1.517, -0.000, -0.000)],
220
+ ["O", 3, (0.626, 1.062, -0.000)],
221
+ ],
222
+ "HIS": [
223
+ ["N", 0, (-0.527, 1.360, 0.000)],
224
+ ["CA", 0, (0.000, 0.000, 0.000)],
225
+ ["C", 0, (1.525, 0.000, 0.000)],
226
+ ["CB", 0, (-0.525, -0.778, -1.208)],
227
+ ["O", 3, (0.625, 1.063, 0.000)],
228
+ ["CG", 4, (0.600, 1.370, -0.000)],
229
+ ["CD2", 5, (0.889, -1.021, 0.003)],
230
+ ["ND1", 5, (0.744, 1.160, -0.000)],
231
+ ["CE1", 5, (2.030, 0.851, 0.002)],
232
+ ["NE2", 5, (2.145, -0.466, 0.004)],
233
+ ],
234
+ "ILE": [
235
+ ["N", 0, (-0.493, 1.373, -0.000)],
236
+ ["CA", 0, (0.000, 0.000, 0.000)],
237
+ ["C", 0, (1.527, -0.000, -0.000)],
238
+ ["CB", 0, (-0.536, -0.793, -1.213)],
239
+ ["O", 3, (0.627, 1.062, -0.000)],
240
+ ["CG1", 4, (0.534, 1.437, -0.000)],
241
+ ["CG2", 4, (0.540, -0.785, -1.199)],
242
+ ["CD1", 5, (0.619, 1.391, 0.000)],
243
+ ],
244
+ "LEU": [
245
+ ["N", 0, (-0.520, 1.363, 0.000)],
246
+ ["CA", 0, (0.000, 0.000, 0.000)],
247
+ ["C", 0, (1.525, -0.000, -0.000)],
248
+ ["CB", 0, (-0.522, -0.773, -1.214)],
249
+ ["O", 3, (0.625, 1.063, -0.000)],
250
+ ["CG", 4, (0.678, 1.371, 0.000)],
251
+ ["CD1", 5, (0.530, 1.430, -0.000)],
252
+ ["CD2", 5, (0.535, -0.774, 1.200)],
253
+ ],
254
+ "LYS": [
255
+ ["N", 0, (-0.526, 1.362, -0.000)],
256
+ ["CA", 0, (0.000, 0.000, 0.000)],
257
+ ["C", 0, (1.526, 0.000, 0.000)],
258
+ ["CB", 0, (-0.524, -0.778, -1.208)],
259
+ ["O", 3, (0.626, 1.062, -0.000)],
260
+ ["CG", 4, (0.619, 1.390, 0.000)],
261
+ ["CD", 5, (0.559, 1.417, 0.000)],
262
+ ["CE", 6, (0.560, 1.416, 0.000)],
263
+ ["NZ", 7, (0.554, 1.387, 0.000)],
264
+ ],
265
+ "MET": [
266
+ ["N", 0, (-0.521, 1.364, -0.000)],
267
+ ["CA", 0, (0.000, 0.000, 0.000)],
268
+ ["C", 0, (1.525, 0.000, 0.000)],
269
+ ["CB", 0, (-0.523, -0.776, -1.210)],
270
+ ["O", 3, (0.625, 1.062, -0.000)],
271
+ ["CG", 4, (0.613, 1.391, -0.000)],
272
+ ["SD", 5, (0.703, 1.695, 0.000)],
273
+ ["CE", 6, (0.320, 1.786, -0.000)],
274
+ ],
275
+ "PHE": [
276
+ ["N", 0, (-0.518, 1.363, 0.000)],
277
+ ["CA", 0, (0.000, 0.000, 0.000)],
278
+ ["C", 0, (1.524, 0.000, -0.000)],
279
+ ["CB", 0, (-0.525, -0.776, -1.212)],
280
+ ["O", 3, (0.626, 1.062, -0.000)],
281
+ ["CG", 4, (0.607, 1.377, 0.000)],
282
+ ["CD1", 5, (0.709, 1.195, -0.000)],
283
+ ["CD2", 5, (0.706, -1.196, 0.000)],
284
+ ["CE1", 5, (2.102, 1.198, -0.000)],
285
+ ["CE2", 5, (2.098, -1.201, -0.000)],
286
+ ["CZ", 5, (2.794, -0.003, -0.001)],
287
+ ],
288
+ "PRO": [
289
+ ["N", 0, (-0.566, 1.351, -0.000)],
290
+ ["CA", 0, (0.000, 0.000, 0.000)],
291
+ ["C", 0, (1.527, -0.000, 0.000)],
292
+ ["CB", 0, (-0.546, -0.611, -1.293)],
293
+ ["O", 3, (0.621, 1.066, 0.000)],
294
+ ["CG", 4, (0.382, 1.445, 0.0)],
295
+ # ['CD', 5, (0.427, 1.440, 0.0)],
296
+ ["CD", 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger
297
+ ],
298
+ "SER": [
299
+ ["N", 0, (-0.529, 1.360, -0.000)],
300
+ ["CA", 0, (0.000, 0.000, 0.000)],
301
+ ["C", 0, (1.525, -0.000, -0.000)],
302
+ ["CB", 0, (-0.518, -0.777, -1.211)],
303
+ ["O", 3, (0.626, 1.062, -0.000)],
304
+ ["OG", 4, (0.503, 1.325, 0.000)],
305
+ ],
306
+ "THR": [
307
+ ["N", 0, (-0.517, 1.364, 0.000)],
308
+ ["CA", 0, (0.000, 0.000, 0.000)],
309
+ ["C", 0, (1.526, 0.000, -0.000)],
310
+ ["CB", 0, (-0.516, -0.793, -1.215)],
311
+ ["O", 3, (0.626, 1.062, 0.000)],
312
+ ["CG2", 4, (0.550, -0.718, -1.228)],
313
+ ["OG1", 4, (0.472, 1.353, 0.000)],
314
+ ],
315
+ "TRP": [
316
+ ["N", 0, (-0.521, 1.363, 0.000)],
317
+ ["CA", 0, (0.000, 0.000, 0.000)],
318
+ ["C", 0, (1.525, -0.000, 0.000)],
319
+ ["CB", 0, (-0.523, -0.776, -1.212)],
320
+ ["O", 3, (0.627, 1.062, 0.000)],
321
+ ["CG", 4, (0.609, 1.370, -0.000)],
322
+ ["CD1", 5, (0.824, 1.091, 0.000)],
323
+ ["CD2", 5, (0.854, -1.148, -0.005)],
324
+ ["CE2", 5, (2.186, -0.678, -0.007)],
325
+ ["CE3", 5, (0.622, -2.530, -0.007)],
326
+ ["NE1", 5, (2.140, 0.690, -0.004)],
327
+ ["CH2", 5, (3.028, -2.890, -0.013)],
328
+ ["CZ2", 5, (3.283, -1.543, -0.011)],
329
+ ["CZ3", 5, (1.715, -3.389, -0.011)],
330
+ ],
331
+ "TYR": [
332
+ ["N", 0, (-0.522, 1.362, 0.000)],
333
+ ["CA", 0, (0.000, 0.000, 0.000)],
334
+ ["C", 0, (1.524, -0.000, -0.000)],
335
+ ["CB", 0, (-0.522, -0.776, -1.213)],
336
+ ["O", 3, (0.627, 1.062, -0.000)],
337
+ ["CG", 4, (0.607, 1.382, -0.000)],
338
+ ["CD1", 5, (0.716, 1.195, -0.000)],
339
+ ["CD2", 5, (0.713, -1.194, -0.001)],
340
+ ["CE1", 5, (2.107, 1.200, -0.002)],
341
+ ["CE2", 5, (2.104, -1.201, -0.003)],
342
+ ["OH", 5, (4.168, -0.002, -0.005)],
343
+ ["CZ", 5, (2.791, -0.001, -0.003)],
344
+ ],
345
+ "VAL": [
346
+ ["N", 0, (-0.494, 1.373, -0.000)],
347
+ ["CA", 0, (0.000, 0.000, 0.000)],
348
+ ["C", 0, (1.527, -0.000, -0.000)],
349
+ ["CB", 0, (-0.533, -0.795, -1.213)],
350
+ ["O", 3, (0.627, 1.062, -0.000)],
351
+ ["CG1", 4, (0.540, 1.429, -0.000)],
352
+ ["CG2", 4, (0.533, -0.776, 1.203)],
353
+ ],
354
+ }
355
+
356
+ # A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
357
+ residue_atoms = {
358
+ "ALA": ["C", "CA", "CB", "N", "O"],
359
+ "ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"],
360
+ "ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"],
361
+ "ASN": ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"],
362
+ "CYS": ["C", "CA", "CB", "N", "O", "SG"],
363
+ "GLU": ["C", "CA", "CB", "CG", "CD", "N", "O", "OE1", "OE2"],
364
+ "GLN": ["C", "CA", "CB", "CG", "CD", "N", "NE2", "O", "OE1"],
365
+ "GLY": ["C", "CA", "N", "O"],
366
+ "HIS": ["C", "CA", "CB", "CG", "CD2", "CE1", "N", "ND1", "NE2", "O"],
367
+ "ILE": ["C", "CA", "CB", "CG1", "CG2", "CD1", "N", "O"],
368
+ "LEU": ["C", "CA", "CB", "CG", "CD1", "CD2", "N", "O"],
369
+ "LYS": ["C", "CA", "CB", "CG", "CD", "CE", "N", "NZ", "O"],
370
+ "MET": ["C", "CA", "CB", "CG", "CE", "N", "O", "SD"],
371
+ "PHE": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O"],
372
+ "PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"],
373
+ "SER": ["C", "CA", "CB", "N", "O", "OG"],
374
+ "THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"],
375
+ "TRP": [
376
+ "C",
377
+ "CA",
378
+ "CB",
379
+ "CG",
380
+ "CD1",
381
+ "CD2",
382
+ "CE2",
383
+ "CE3",
384
+ "CZ2",
385
+ "CZ3",
386
+ "CH2",
387
+ "N",
388
+ "NE1",
389
+ "O",
390
+ ],
391
+ "TYR": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O", "OH"],
392
+ "VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"],
393
+ }
394
+
395
+ # Naming swaps for ambiguous atom names.
396
+ # Due to symmetries in the amino acids the naming of atoms is ambiguous in
397
+ # 4 of the 20 amino acids.
398
+ # (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities
399
+ # in LEU, VAL and ARG can be resolved by using the 3d constellations of
400
+ # the 'ambiguous' atoms and their neighbours)
401
+ residue_atom_renaming_swaps = {
402
+ "ASP": {"OD1": "OD2"},
403
+ "GLU": {"OE1": "OE2"},
404
+ "PHE": {"CD1": "CD2", "CE1": "CE2"},
405
+ "TYR": {"CD1": "CD2", "CE1": "CE2"},
406
+ }
407
+
408
+ # Van der Waals radii [Angstroem] of the atoms (from Wikipedia)
409
+ van_der_waals_radius = {
410
+ "C": 1.7,
411
+ "N": 1.55,
412
+ "O": 1.52,
413
+ "S": 1.8,
414
+ }
415
+
416
+ Bond = collections.namedtuple("Bond", ["atom1_name", "atom2_name", "length", "stddev"])
417
+ BondAngle = collections.namedtuple(
418
+ "BondAngle", ["atom1_name", "atom2_name", "atom3name", "angle_rad", "stddev"]
419
+ )
420
+
421
+
422
+ @functools.lru_cache(maxsize=None)
423
+ def load_stereo_chemical_props() -> (
424
+ Tuple[
425
+ Mapping[str, List[Bond]],
426
+ Mapping[str, List[Bond]],
427
+ Mapping[str, List[BondAngle]],
428
+ ]
429
+ ):
430
+ """Load stereo_chemical_props.txt into a nice structure.
431
+
432
+ Load literature values for bond lengths and bond angles and translate
433
+ bond angles into the length of the opposite edge of the triangle
434
+ ("residue_virtual_bonds").
435
+
436
+ Returns:
437
+ residue_bonds: Dict that maps resname -> list of Bond tuples.
438
+ residue_virtual_bonds: Dict that maps resname -> list of Bond tuples.
439
+ residue_bond_angles: Dict that maps resname -> list of BondAngle tuples.
440
+ """
441
+ stereo_chemical_props_path = os.path.join(
442
+ os.path.dirname(os.path.abspath(__file__)), "stereo_chemical_props.txt"
443
+ )
444
+ with open(stereo_chemical_props_path, "rt") as f:
445
+ stereo_chemical_props = f.read()
446
+ lines_iter = iter(stereo_chemical_props.splitlines())
447
+ # Load bond lengths.
448
+ residue_bonds = {}
449
+ next(lines_iter) # Skip header line.
450
+ for line in lines_iter:
451
+ if line.strip() == "-":
452
+ break
453
+ bond, resname, length, stddev = line.split()
454
+ atom1, atom2 = bond.split("-")
455
+ if resname not in residue_bonds:
456
+ residue_bonds[resname] = []
457
+ residue_bonds[resname].append(Bond(atom1, atom2, float(length), float(stddev)))
458
+ residue_bonds["UNK"] = []
459
+
460
+ # Load bond angles.
461
+ residue_bond_angles = {}
462
+ next(lines_iter) # Skip empty line.
463
+ next(lines_iter) # Skip header line.
464
+ for line in lines_iter:
465
+ if line.strip() == "-":
466
+ break
467
+ bond, resname, angle_degree, stddev_degree = line.split()
468
+ atom1, atom2, atom3 = bond.split("-")
469
+ if resname not in residue_bond_angles:
470
+ residue_bond_angles[resname] = []
471
+ residue_bond_angles[resname].append(
472
+ BondAngle(
473
+ atom1,
474
+ atom2,
475
+ atom3,
476
+ float(angle_degree) / 180.0 * np.pi,
477
+ float(stddev_degree) / 180.0 * np.pi,
478
+ )
479
+ )
480
+ residue_bond_angles["UNK"] = []
481
+
482
+ def make_bond_key(atom1_name, atom2_name):
483
+ """Unique key to lookup bonds."""
484
+ return "-".join(sorted([atom1_name, atom2_name]))
485
+
486
+ # Translate bond angles into distances ("virtual bonds").
487
+ residue_virtual_bonds = {}
488
+ for resname, bond_angles in residue_bond_angles.items():
489
+ # Create a fast lookup dict for bond lengths.
490
+ bond_cache = {}
491
+ for b in residue_bonds[resname]:
492
+ bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b
493
+ residue_virtual_bonds[resname] = []
494
+ for ba in bond_angles:
495
+ bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)]
496
+ bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)]
497
+
498
+ # Compute distance between atom1 and atom3 using the law of cosines
499
+ # c^2 = a^2 + b^2 - 2ab*cos(gamma).
500
+ gamma = ba.angle_rad
501
+ length = np.sqrt(
502
+ bond1.length**2
503
+ + bond2.length**2
504
+ - 2 * bond1.length * bond2.length * np.cos(gamma)
505
+ )
506
+
507
+ # Propagation of uncertainty assuming uncorrelated errors.
508
+ dl_outer = 0.5 / length
509
+ dl_dgamma = (2 * bond1.length * bond2.length * np.sin(gamma)) * dl_outer
510
+ dl_db1 = (2 * bond1.length - 2 * bond2.length * np.cos(gamma)) * dl_outer
511
+ dl_db2 = (2 * bond2.length - 2 * bond1.length * np.cos(gamma)) * dl_outer
512
+ stddev = np.sqrt(
513
+ (dl_dgamma * ba.stddev) ** 2
514
+ + (dl_db1 * bond1.stddev) ** 2
515
+ + (dl_db2 * bond2.stddev) ** 2
516
+ )
517
+ residue_virtual_bonds[resname].append(
518
+ Bond(ba.atom1_name, ba.atom3name, length, stddev)
519
+ )
520
+
521
+ return (residue_bonds, residue_virtual_bonds, residue_bond_angles)
522
+
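The law-of-cosines conversion above is easy to check by hand (standalone sketch; the 1.5 A bonds and 109.5 degree angle are illustrative, not values from stereo_chemical_props.txt):

import numpy as np

b1, b2 = 1.5, 1.5                 # hypothetical adjacent bond lengths (Angstrom)
gamma = np.deg2rad(109.5)         # hypothetical atom1-atom2-atom3 angle
length = np.sqrt(b1**2 + b2**2 - 2 * b1 * b2 * np.cos(gamma))
print(round(float(length), 2))    # 2.45 -- the atom1-atom3 "virtual bond"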
523
+
524
+ # Between-residue bond lengths for general bonds (first element) and for Proline
525
+ # (second element).
526
+ between_res_bond_length_c_n = [1.329, 1.341]
527
+ between_res_bond_length_stddev_c_n = [0.014, 0.016]
528
+
529
+ # Between-residue cos_angles.
530
+ between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315
531
+ between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995
532
+
533
+ # This mapping is used when we need to store atom data in a format that requires
534
+ # fixed atom data size for every residue (e.g. a numpy array).
535
+ atom_types = [
536
+ "N",
537
+ "CA",
538
+ "C",
539
+ "CB",
540
+ "O",
541
+ "CG",
542
+ "CG1",
543
+ "CG2",
544
+ "OG",
545
+ "OG1",
546
+ "SG",
547
+ "CD",
548
+ "CD1",
549
+ "CD2",
550
+ "ND1",
551
+ "ND2",
552
+ "OD1",
553
+ "OD2",
554
+ "SD",
555
+ "CE",
556
+ "CE1",
557
+ "CE2",
558
+ "CE3",
559
+ "NE",
560
+ "NE1",
561
+ "NE2",
562
+ "OE1",
563
+ "OE2",
564
+ "CH2",
565
+ "NH1",
566
+ "NH2",
567
+ "OH",
568
+ "CZ",
569
+ "CZ2",
570
+ "CZ3",
571
+ "NZ",
572
+ "OXT",
573
+ ]
574
+ atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
575
+ atom_type_num = len(atom_types) # := 37.
576
+
577
+ # A compact atom encoding with 14 columns
578
+ # pylint: disable=line-too-long
579
+ # pylint: disable=bad-whitespace
580
+ restype_name_to_atom14_names = {
581
+ "ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""],
582
+ "ARG": [
583
+ "N",
584
+ "CA",
585
+ "C",
586
+ "O",
587
+ "CB",
588
+ "CG",
589
+ "CD",
590
+ "NE",
591
+ "CZ",
592
+ "NH1",
593
+ "NH2",
594
+ "",
595
+ "",
596
+ "",
597
+ ],
598
+ "ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2", "", "", "", "", "", ""],
599
+ "ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2", "", "", "", "", "", ""],
600
+ "CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""],
601
+ "GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2", "", "", "", "", ""],
602
+ "GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2", "", "", "", "", ""],
603
+ "GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""],
604
+ "HIS": [
605
+ "N",
606
+ "CA",
607
+ "C",
608
+ "O",
609
+ "CB",
610
+ "CG",
611
+ "ND1",
612
+ "CD2",
613
+ "CE1",
614
+ "NE2",
615
+ "",
616
+ "",
617
+ "",
618
+ "",
619
+ ],
620
+ "ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1", "", "", "", "", "", ""],
621
+ "LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "", "", "", "", "", ""],
622
+ "LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ", "", "", "", "", ""],
623
+ "MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE", "", "", "", "", "", ""],
624
+ "PHE": [
625
+ "N",
626
+ "CA",
627
+ "C",
628
+ "O",
629
+ "CB",
630
+ "CG",
631
+ "CD1",
632
+ "CD2",
633
+ "CE1",
634
+ "CE2",
635
+ "CZ",
636
+ "",
637
+ "",
638
+ "",
639
+ ],
640
+ "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""],
641
+ "SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""],
642
+ "THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2", "", "", "", "", "", "", ""],
643
+ "TRP": [
644
+ "N",
645
+ "CA",
646
+ "C",
647
+ "O",
648
+ "CB",
649
+ "CG",
650
+ "CD1",
651
+ "CD2",
652
+ "NE1",
653
+ "CE2",
654
+ "CE3",
655
+ "CZ2",
656
+ "CZ3",
657
+ "CH2",
658
+ ],
659
+ "TYR": [
660
+ "N",
661
+ "CA",
662
+ "C",
663
+ "O",
664
+ "CB",
665
+ "CG",
666
+ "CD1",
667
+ "CD2",
668
+ "CE1",
669
+ "CE2",
670
+ "CZ",
671
+ "OH",
672
+ "",
673
+ "",
674
+ ],
675
+ "VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "", "", "", "", "", "", ""],
676
+ "UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
677
+ }
678
+ # pylint: enable=line-too-long
679
+ # pylint: enable=bad-whitespace
680
+
681
+
682
+ # This is the standard residue order when coding AA type as a number.
683
+ # Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
684
+ restypes = [
685
+ "A",
686
+ "R",
687
+ "N",
688
+ "D",
689
+ "C",
690
+ "Q",
691
+ "E",
692
+ "G",
693
+ "H",
694
+ "I",
695
+ "L",
696
+ "K",
697
+ "M",
698
+ "F",
699
+ "P",
700
+ "S",
701
+ "T",
702
+ "W",
703
+ "Y",
704
+ "V",
705
+ ]
706
+ restype_order = {restype: i for i, restype in enumerate(restypes)}
707
+ restype_num = len(restypes) # := 20.
708
+ unk_restype_index = restype_num # Catch-all index for unknown restypes.
709
+
710
+ restypes_with_x = restypes + ["X"]
711
+ restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)}
712
+
713
+
714
+ def sequence_to_onehot(
715
+ sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False
716
+ ) -> np.ndarray:
717
+ """Maps the given sequence into a one-hot encoded matrix.
718
+
719
+ Args:
720
+ sequence: An amino acid sequence.
721
+ mapping: A dictionary mapping amino acids to integers.
722
+ map_unknown_to_x: If True, any amino acid that is not in the mapping will be
723
+ mapped to the unknown amino acid 'X'. If the mapping doesn't contain
724
+ amino acid 'X', an error will be thrown. If False, any amino acid not in
725
+ the mapping will throw an error.
726
+
727
+ Returns:
728
+ A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of
729
+ the sequence.
730
+
731
+ Raises:
732
+ ValueError: If the mapping doesn't contain values from 0 to
733
+ num_unique_aas - 1 without any gaps.
734
+ """
735
+ num_entries = max(mapping.values()) + 1
736
+
737
+ if sorted(set(mapping.values())) != list(range(num_entries)):
738
+ raise ValueError(
739
+ "The mapping must have values from 0 to num_unique_aas-1 "
740
+ "without any gaps. Got: %s" % sorted(mapping.values())
741
+ )
742
+
743
+ one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32)
744
+
745
+ for aa_index, aa_type in enumerate(sequence):
746
+ if map_unknown_to_x:
747
+ if aa_type.isalpha() and aa_type.isupper():
748
+ aa_id = mapping.get(aa_type, mapping["X"])
749
+ else:
750
+ raise ValueError(f"Invalid character in the sequence: {aa_type}")
751
+ else:
752
+ aa_id = mapping[aa_type]
753
+ one_hot_arr[aa_index, aa_id] = 1
754
+
755
+ return one_hot_arr
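A brief usage sketch (assuming this module's names are in scope): with restype_order_with_x as the mapping and map_unknown_to_x enabled, a nonstandard letter such as 'B' falls back to the X column.

    onehot = sequence_to_onehot("ACDB", restype_order_with_x, map_unknown_to_x=True)
    assert onehot.shape == (4, 21)                    # 20 standard residues + X
    assert onehot[3, restype_order_with_x["X"]] == 1  # 'B' mapped to X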
756
+
757
+
758
+ restype_1to3 = {
759
+ "A": "ALA",
760
+ "R": "ARG",
761
+ "N": "ASN",
762
+ "D": "ASP",
763
+ "C": "CYS",
764
+ "Q": "GLN",
765
+ "E": "GLU",
766
+ "G": "GLY",
767
+ "H": "HIS",
768
+ "I": "ILE",
769
+ "L": "LEU",
770
+ "K": "LYS",
771
+ "M": "MET",
772
+ "F": "PHE",
773
+ "P": "PRO",
774
+ "S": "SER",
775
+ "T": "THR",
776
+ "W": "TRP",
777
+ "Y": "TYR",
778
+ "V": "VAL",
779
+ }
780
+
781
+
782
+ # NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
783
+ # 1-to-1 mapping of 3 letter names to one letter names. The latter contains
784
+ # many more, and less common, three letter names as keys and maps many of these
785
+ # to the same one letter name (including 'X' and 'U' which we don't use here).
786
+ restype_3to1 = {v: k for k, v in restype_1to3.items()}
787
+
788
+ # Define a restype name for all unknown residues.
789
+ unk_restype = "UNK"
790
+
791
+ resnames = [restype_1to3[r] for r in restypes] + [unk_restype]
792
+ resname_to_idx = {resname: i for i, resname in enumerate(resnames)}
793
+
794
+
795
+ # Define exploded all-atom representation (atom73)
796
+ atom73_names = ['N', 'CA', 'C', 'CB', 'O']
797
+ for aa1 in restypes:
798
+ aa3 = restype_1to3[aa1]
799
+ atom_list = residue_atoms[aa3]
800
+ for atom in atom_types:
801
+ if atom in atom_list and atom not in atom73_names:
802
+ atom73_names.append(f'{aa1}{atom}')
803
+
804
+ atom73_names_to_idx = {a: i for i, a in enumerate(atom73_names)}
805
+
806
+ restype_atom73_mask = np.zeros((22, 73))
807
+ for i, restype in enumerate(restypes):
808
+ for atom_name in atom_types:
809
+ atom73_name = atom_name
810
+ if atom_name not in ['N', 'CA', 'C', 'CB', 'O']:
811
+ atom73_name = restype + atom_name
812
+ if atom73_name in atom73_names_to_idx:
813
+ atom73_idx = atom73_names_to_idx[atom73_name]
814
+ restype_atom73_mask[i, atom73_idx] = 1
815
+ # Remove CB for glycine
816
+ restype_atom73_mask[restype_order["G"], 3] = 0
817
+ # Backbone atoms for unk and mask
818
+ restype_atom73_mask[-2:, [0, 1, 2, 4]] = 1
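In atom73, the five shared backbone slots keep their plain names while sidechain slots are disambiguated by prefixing the one-letter residue code, e.g. leucine's CG occupies the slot named "LCG". A small lookup sketch (assuming the constants above are in scope):

    cg_idx = atom73_names_to_idx["LCG"]   # leucine-specific sidechain slot
    ca_idx = atom73_names_to_idx["CA"]    # shared backbone slot
    assert restype_atom73_mask[restype_order["L"], cg_idx] == 1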
819
+
820
+
821
+ # The mapping here uses hhblits convention, so that B is mapped to D, J and O
822
+ # are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the
823
+ # remaining 20 amino acids are kept in alphabetical order.
824
+ # There are 2 non-amino acid codes, X (representing any amino acid) and
825
+ # "-" representing a missing amino acid in an alignment. The id for these
826
+ # codes is put at the end (20 and 21) so that they can easily be ignored if
827
+ # desired.
828
+ HHBLITS_AA_TO_ID = {
829
+ "A": 0,
830
+ "B": 2,
831
+ "C": 1,
832
+ "D": 2,
833
+ "E": 3,
834
+ "F": 4,
835
+ "G": 5,
836
+ "H": 6,
837
+ "I": 7,
838
+ "J": 20,
839
+ "K": 8,
840
+ "L": 9,
841
+ "M": 10,
842
+ "N": 11,
843
+ "O": 20,
844
+ "P": 12,
845
+ "Q": 13,
846
+ "R": 14,
847
+ "S": 15,
848
+ "T": 16,
849
+ "U": 1,
850
+ "V": 17,
851
+ "W": 18,
852
+ "X": 20,
853
+ "Y": 19,
854
+ "Z": 3,
855
+ "-": 21,
856
+ }
857
+
858
+ # Partial inversion of HHBLITS_AA_TO_ID.
859
+ ID_TO_HHBLITS_AA = {
860
+ 0: "A",
861
+ 1: "C", # Also U.
862
+ 2: "D", # Also B.
863
+ 3: "E", # Also Z.
864
+ 4: "F",
865
+ 5: "G",
866
+ 6: "H",
867
+ 7: "I",
868
+ 8: "K",
869
+ 9: "L",
870
+ 10: "M",
871
+ 11: "N",
872
+ 12: "P",
873
+ 13: "Q",
874
+ 14: "R",
875
+ 15: "S",
876
+ 16: "T",
877
+ 17: "V",
878
+ 18: "W",
879
+ 19: "Y",
880
+ 20: "X", # Includes J and O.
881
+ 21: "-",
882
+ }
883
+
884
+ restypes_with_x_and_gap = restypes + ["X", "-"]
885
+ MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple(
886
+ restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i])
887
+ for i in range(len(restypes_with_x_and_gap))
888
+ )
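MAP_HHBLITS_AATYPE_TO_OUR_AATYPE is the permutation that re-sorts hhblits ids into this file's ordering; a quick remapping sketch (hypothetical usage, assuming this module's names are in scope):

    import numpy as np

    hh_ids = np.array([HHBLITS_AA_TO_ID[aa] for aa in "ACDE"])   # hhblits convention
    our_ids = np.take(MAP_HHBLITS_AATYPE_TO_OUR_AATYPE, hh_ids)  # this file's convention
    assert [restypes_with_x_and_gap[i] for i in our_ids] == list("ACDE")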
889
+
890
+
891
+ def _make_standard_atom_mask() -> np.ndarray:
892
+ """Returns [num_res_types, num_atom_types] mask array."""
893
+ # +1 to account for unknown (all 0s).
894
+ mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32)
895
+ for restype, restype_letter in enumerate(restypes):
896
+ restype_name = restype_1to3[restype_letter]
897
+ atom_names = residue_atoms[restype_name]
898
+ for atom_name in atom_names:
899
+ atom_type = atom_order[atom_name]
900
+ mask[restype, atom_type] = 1
901
+ return mask
902
+
903
+
904
+ STANDARD_ATOM_MASK = _make_standard_atom_mask()
905
+
906
+
907
+ # A one hot representation for the first and second atoms defining the axis
908
+ # of rotation for each chi-angle in each residue.
909
+ def chi_angle_atom(atom_index: int) -> np.ndarray:
910
+ """Define chi-angle rigid groups via one-hot representations."""
911
+ chi_angles_index = {}
912
+ one_hots = []
913
+
914
+ for k, v in chi_angles_atoms.items():
915
+ indices = [atom_types.index(s[atom_index]) for s in v]
916
+ indices.extend([-1] * (4 - len(indices)))
917
+ chi_angles_index[k] = indices
918
+
919
+ for r in restypes:
920
+ res3 = restype_1to3[r]
921
+ one_hot = np.eye(atom_type_num)[chi_angles_index[res3]]
922
+ one_hots.append(one_hot)
923
+
924
+ one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`.
925
+ one_hot = np.stack(one_hots, axis=0)
926
+ one_hot = np.transpose(one_hot, [0, 2, 1])
927
+
928
+ return one_hot
929
+
930
+
931
+ chi_atom_1_one_hot = chi_angle_atom(1)
932
+ chi_atom_2_one_hot = chi_angle_atom(2)
933
+
934
+ # An array like chi_angles_atoms but using indices rather than names.
935
+ chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes]
936
+ chi_angles_atom_indices = tree.map_structure(
937
+ lambda atom_name: atom_order[atom_name], chi_angles_atom_indices
938
+ )
939
+ chi_angles_atom_indices = np.array(
940
+ [
941
+ chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms)))
942
+ for chi_atoms in chi_angles_atom_indices
943
+ ]
944
+ )
945
+
946
+ # Mapping from (res_name, atom_name) pairs to the atom's chi group index
947
+ # and atom index within that group.
948
+ chi_groups_for_atom = collections.defaultdict(list)
949
+ for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items():
950
+ for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res):
951
+ for atom_i, atom in enumerate(chi_group):
952
+ chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i))
953
+ chi_groups_for_atom = dict(chi_groups_for_atom)
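chi_groups_for_atom answers "which chi angles move this atom": each (chi_group_i, atom_i) pair records that the atom is the atom_i-th member of chi group chi_group_i. A lookup sketch (assuming the standard chi_angles_atoms definitions used above):

    # Arginine's CD closes chi2 and helps define chi3 and chi4.
    assert chi_groups_for_atom[("ARG", "CD")] == [(1, 3), (2, 2), (3, 1)]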
954
+
955
+
956
+ def _make_rigid_transformation_4x4(ex, ey, translation):
957
+ """Create a rigid 4x4 transformation matrix from two axes and a translation."""
958
+ # Normalize ex.
959
+ ex_normalized = ex / np.linalg.norm(ex)
960
+
961
+ # make ey perpendicular to ex
962
+ ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
963
+ ey_normalized /= np.linalg.norm(ey_normalized)
964
+
965
+ # compute ez as cross product
966
+ eznorm = np.cross(ex_normalized, ey_normalized)
967
+ m = np.stack([ex_normalized, ey_normalized, eznorm, translation]).transpose()
968
+ m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0)
969
+ return m
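The result is a homogeneous transform: the first three columns hold the orthonormal frame built from ex and ey, the fourth holds the translation. A sanity-check sketch (hypothetical inputs):

    import numpy as np

    m = _make_rigid_transformation_4x4(
        ex=np.array([1.0, 0.0, 0.0]),
        ey=np.array([0.0, 1.0, 0.0]),
        translation=np.array([5.0, 0.0, 0.0]),
    )
    origin = m @ np.array([0.0, 0.0, 0.0, 1.0])      # a point in homogeneous coordinates
    assert np.allclose(origin[:3], [5.0, 0.0, 0.0])  # frame origin maps to the translation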
970
+
971
+
972
+ # create an array with (restype, atomtype) --> rigid_group_idx
973
+ # and an array with (restype, atomtype, coord) for the atom positions
974
+ # and compute affine transformation matrices (4,4) from one rigid group to the
975
+ # previous group
976
+ restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int)
977
+ restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
978
+ restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
979
+ restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
980
+ restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
981
+ restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
982
+ restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
983
+
984
+
985
+ def _make_rigid_group_constants():
986
+ """Fill the arrays above."""
987
+ for restype, restype_letter in enumerate(restypes):
988
+ resname = restype_1to3[restype_letter]
989
+ for atomname, group_idx, atom_position in rigid_group_atom_positions[resname]:
990
+ atomtype = atom_order[atomname]
991
+ restype_atom37_to_rigid_group[restype, atomtype] = group_idx
992
+ restype_atom37_mask[restype, atomtype] = 1
993
+ restype_atom37_rigid_group_positions[restype, atomtype, :] = atom_position
994
+
995
+ atom14idx = restype_name_to_atom14_names[resname].index(atomname)
996
+ restype_atom14_to_rigid_group[restype, atom14idx] = group_idx
997
+ restype_atom14_mask[restype, atom14idx] = 1
998
+ restype_atom14_rigid_group_positions[restype, atom14idx, :] = atom_position
999
+
1000
+ for restype, restype_letter in enumerate(restypes):
1001
+ resname = restype_1to3[restype_letter]
1002
+ atom_positions = {
1003
+ name: np.array(pos) for name, _, pos in rigid_group_atom_positions[resname]
1004
+ }
1005
+
1006
+ # backbone to backbone is the identity transform
1007
+ restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4)
1008
+
1009
+ # pre-omega-frame to backbone (currently dummy identity matrix)
1010
+ restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4)
1011
+
1012
+ # phi-frame to backbone
1013
+ mat = _make_rigid_transformation_4x4(
1014
+ ex=atom_positions["N"] - atom_positions["CA"],
1015
+ ey=np.array([1.0, 0.0, 0.0]),
1016
+ translation=atom_positions["N"],
1017
+ )
1018
+ restype_rigid_group_default_frame[restype, 2, :, :] = mat
1019
+
1020
+ # psi-frame to backbone
1021
+ mat = _make_rigid_transformation_4x4(
1022
+ ex=atom_positions["C"] - atom_positions["CA"],
1023
+ ey=atom_positions["CA"] - atom_positions["N"],
1024
+ translation=atom_positions["C"],
1025
+ )
1026
+ restype_rigid_group_default_frame[restype, 3, :, :] = mat
1027
+
1028
+ # chi1-frame to backbone
1029
+ if chi_angles_mask[restype][0]:
1030
+ base_atom_names = chi_angles_atoms[resname][0]
1031
+ base_atom_positions = [atom_positions[name] for name in base_atom_names]
1032
+ mat = _make_rigid_transformation_4x4(
1033
+ ex=base_atom_positions[2] - base_atom_positions[1],
1034
+ ey=base_atom_positions[0] - base_atom_positions[1],
1035
+ translation=base_atom_positions[2],
1036
+ )
1037
+ restype_rigid_group_default_frame[restype, 4, :, :] = mat
1038
+
1039
+ # chi2-frame to chi1-frame
1040
+ # chi3-frame to chi2-frame
1041
+ # chi4-frame to chi3-frame
1042
+ # luckily all rotation axes for the next frame start at (0,0,0) of the
1043
+ # previous frame
1044
+ for chi_idx in range(1, 4):
1045
+ if chi_angles_mask[restype][chi_idx]:
1046
+ axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]
1047
+ axis_end_atom_position = atom_positions[axis_end_atom_name]
1048
+ mat = _make_rigid_transformation_4x4(
1049
+ ex=axis_end_atom_position,
1050
+ ey=np.array([-1.0, 0.0, 0.0]),
1051
+ translation=axis_end_atom_position,
1052
+ )
1053
+ restype_rigid_group_default_frame[restype, 4 + chi_idx, :, :] = mat
1054
+
1055
+
1056
+ _make_rigid_group_constants()
1057
+
1058
+
1059
+ def make_atom14_dists_bounds(overlap_tolerance=1.5, bond_length_tolerance_factor=15):
1060
+ """Compute upper and lower bounds for bonds to assess violations."""
1061
+ restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32)
1062
+ restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32)
1063
+ restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32)
1064
+ residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props()
1065
+ for restype, restype_letter in enumerate(restypes):
1066
+ resname = restype_1to3[restype_letter]
1067
+ atom_list = restype_name_to_atom14_names[resname]
1068
+
1069
+ # create lower and upper bounds for clashes
1070
+ for atom1_idx, atom1_name in enumerate(atom_list):
1071
+ if not atom1_name:
1072
+ continue
1073
+ atom1_radius = van_der_waals_radius[atom1_name[0]]
1074
+ for atom2_idx, atom2_name in enumerate(atom_list):
1075
+ if (not atom2_name) or atom1_idx == atom2_idx:
1076
+ continue
1077
+ atom2_radius = van_der_waals_radius[atom2_name[0]]
1078
+ lower = atom1_radius + atom2_radius - overlap_tolerance
1079
+ upper = 1e10
1080
+ restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
1081
+ restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
1082
+ restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
1083
+ restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
1084
+
1085
+ # overwrite lower and upper bounds for bonds and angles
1086
+ for b in residue_bonds[resname] + residue_virtual_bonds[resname]:
1087
+ atom1_idx = atom_list.index(b.atom1_name)
1088
+ atom2_idx = atom_list.index(b.atom2_name)
1089
+ lower = b.length - bond_length_tolerance_factor * b.stddev
1090
+ upper = b.length + bond_length_tolerance_factor * b.stddev
1091
+ restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
1092
+ restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
1093
+ restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
1094
+ restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
1095
+ restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev
1096
+ restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev
1097
+ return {
1098
+ "lower_bound": restype_atom14_bond_lower_bound, # shape (21,14,14)
1099
+ "upper_bound": restype_atom14_bond_upper_bound, # shape (21,14,14)
1100
+ "stddev": restype_atom14_bond_stddev, # shape (21,14,14)
1101
+ }
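A usage sketch (hedged; values come from the stereo_chemical_props.txt tables loaded via load_stereo_chemical_props): the returned tensors give, per residue type and atom14 pair, the allowed distance window used to flag violations.

    bounds = make_atom14_dists_bounds(overlap_tolerance=1.5, bond_length_tolerance_factor=15)
    ala = resname_to_idx["ALA"]
    n_i = restype_name_to_atom14_names["ALA"].index("N")
    ca_i = restype_name_to_atom14_names["ALA"].index("CA")
    # N-CA in ALA is 1.459 +/- 0.020, so the window is 1.459 +/- 15 * 0.020
    assert bounds["lower_bound"][ala, n_i, ca_i] < 1.459 < bounds["upper_bound"][ala, n_i, ca_i]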
1102
+
1103
+
1104
+ standard_residue_bonds, _, standard_residue_bond_angles = load_stereo_chemical_props()
core/stereo_chemical_props.txt ADDED
@@ -0,0 +1,345 @@
1
+ Bond Residue Mean StdDev
2
+ CA-CB ALA 1.520 0.021
3
+ N-CA ALA 1.459 0.020
4
+ CA-C ALA 1.525 0.026
5
+ C-O ALA 1.229 0.019
6
+ CA-CB ARG 1.535 0.022
7
+ CB-CG ARG 1.521 0.027
8
+ CG-CD ARG 1.515 0.025
9
+ CD-NE ARG 1.460 0.017
10
+ NE-CZ ARG 1.326 0.013
11
+ CZ-NH1 ARG 1.326 0.013
12
+ CZ-NH2 ARG 1.326 0.013
13
+ N-CA ARG 1.459 0.020
14
+ CA-C ARG 1.525 0.026
15
+ C-O ARG 1.229 0.019
16
+ CA-CB ASN 1.527 0.026
17
+ CB-CG ASN 1.506 0.023
18
+ CG-OD1 ASN 1.235 0.022
19
+ CG-ND2 ASN 1.324 0.025
20
+ N-CA ASN 1.459 0.020
21
+ CA-C ASN 1.525 0.026
22
+ C-O ASN 1.229 0.019
23
+ CA-CB ASP 1.535 0.022
24
+ CB-CG ASP 1.513 0.021
25
+ CG-OD1 ASP 1.249 0.023
26
+ CG-OD2 ASP 1.249 0.023
27
+ N-CA ASP 1.459 0.020
28
+ CA-C ASP 1.525 0.026
29
+ C-O ASP 1.229 0.019
30
+ CA-CB CYS 1.526 0.013
31
+ CB-SG CYS 1.812 0.016
32
+ N-CA CYS 1.459 0.020
33
+ CA-C CYS 1.525 0.026
34
+ C-O CYS 1.229 0.019
35
+ CA-CB GLU 1.535 0.022
36
+ CB-CG GLU 1.517 0.019
37
+ CG-CD GLU 1.515 0.015
38
+ CD-OE1 GLU 1.252 0.011
39
+ CD-OE2 GLU 1.252 0.011
40
+ N-CA GLU 1.459 0.020
41
+ CA-C GLU 1.525 0.026
42
+ C-O GLU 1.229 0.019
43
+ CA-CB GLN 1.535 0.022
44
+ CB-CG GLN 1.521 0.027
45
+ CG-CD GLN 1.506 0.023
46
+ CD-OE1 GLN 1.235 0.022
47
+ CD-NE2 GLN 1.324 0.025
48
+ N-CA GLN 1.459 0.020
49
+ CA-C GLN 1.525 0.026
50
+ C-O GLN 1.229 0.019
51
+ N-CA GLY 1.456 0.015
52
+ CA-C GLY 1.514 0.016
53
+ C-O GLY 1.232 0.016
54
+ CA-CB HIS 1.535 0.022
55
+ CB-CG HIS 1.492 0.016
56
+ CG-ND1 HIS 1.369 0.015
57
+ CG-CD2 HIS 1.353 0.017
58
+ ND1-CE1 HIS 1.343 0.025
59
+ CD2-NE2 HIS 1.415 0.021
60
+ CE1-NE2 HIS 1.322 0.023
61
+ N-CA HIS 1.459 0.020
62
+ CA-C HIS 1.525 0.026
63
+ C-O HIS 1.229 0.019
64
+ CA-CB ILE 1.544 0.023
65
+ CB-CG1 ILE 1.536 0.028
66
+ CB-CG2 ILE 1.524 0.031
67
+ CG1-CD1 ILE 1.500 0.069
68
+ N-CA ILE 1.459 0.020
69
+ CA-C ILE 1.525 0.026
70
+ C-O ILE 1.229 0.019
71
+ CA-CB LEU 1.533 0.023
72
+ CB-CG LEU 1.521 0.029
73
+ CG-CD1 LEU 1.514 0.037
74
+ CG-CD2 LEU 1.514 0.037
75
+ N-CA LEU 1.459 0.020
76
+ CA-C LEU 1.525 0.026
77
+ C-O LEU 1.229 0.019
78
+ CA-CB LYS 1.535 0.022
79
+ CB-CG LYS 1.521 0.027
80
+ CG-CD LYS 1.520 0.034
81
+ CD-CE LYS 1.508 0.025
82
+ CE-NZ LYS 1.486 0.025
83
+ N-CA LYS 1.459 0.020
84
+ CA-C LYS 1.525 0.026
85
+ C-O LYS 1.229 0.019
86
+ CA-CB MET 1.535 0.022
87
+ CB-CG MET 1.509 0.032
88
+ CG-SD MET 1.807 0.026
89
+ SD-CE MET 1.774 0.056
90
+ N-CA MET 1.459 0.020
91
+ CA-C MET 1.525 0.026
92
+ C-O MET 1.229 0.019
93
+ CA-CB PHE 1.535 0.022
94
+ CB-CG PHE 1.509 0.017
95
+ CG-CD1 PHE 1.383 0.015
96
+ CG-CD2 PHE 1.383 0.015
97
+ CD1-CE1 PHE 1.388 0.020
98
+ CD2-CE2 PHE 1.388 0.020
99
+ CE1-CZ PHE 1.369 0.019
100
+ CE2-CZ PHE 1.369 0.019
101
+ N-CA PHE 1.459 0.020
102
+ CA-C PHE 1.525 0.026
103
+ C-O PHE 1.229 0.019
104
+ CA-CB PRO 1.531 0.020
105
+ CB-CG PRO 1.495 0.050
106
+ CG-CD PRO 1.502 0.033
107
+ CD-N PRO 1.474 0.014
108
+ N-CA PRO 1.468 0.017
109
+ CA-C PRO 1.524 0.020
110
+ C-O PRO 1.228 0.020
111
+ CA-CB SER 1.525 0.015
112
+ CB-OG SER 1.418 0.013
113
+ N-CA SER 1.459 0.020
114
+ CA-C SER 1.525 0.026
115
+ C-O SER 1.229 0.019
116
+ CA-CB THR 1.529 0.026
117
+ CB-OG1 THR 1.428 0.020
118
+ CB-CG2 THR 1.519 0.033
119
+ N-CA THR 1.459 0.020
120
+ CA-C THR 1.525 0.026
121
+ C-O THR 1.229 0.019
122
+ CA-CB TRP 1.535 0.022
123
+ CB-CG TRP 1.498 0.018
124
+ CG-CD1 TRP 1.363 0.014
125
+ CG-CD2 TRP 1.432 0.017
126
+ CD1-NE1 TRP 1.375 0.017
127
+ NE1-CE2 TRP 1.371 0.013
128
+ CD2-CE2 TRP 1.409 0.012
129
+ CD2-CE3 TRP 1.399 0.015
130
+ CE2-CZ2 TRP 1.393 0.017
131
+ CE3-CZ3 TRP 1.380 0.017
132
+ CZ2-CH2 TRP 1.369 0.019
133
+ CZ3-CH2 TRP 1.396 0.016
134
+ N-CA TRP 1.459 0.020
135
+ CA-C TRP 1.525 0.026
136
+ C-O TRP 1.229 0.019
137
+ CA-CB TYR 1.535 0.022
138
+ CB-CG TYR 1.512 0.015
139
+ CG-CD1 TYR 1.387 0.013
140
+ CG-CD2 TYR 1.387 0.013
141
+ CD1-CE1 TYR 1.389 0.015
142
+ CD2-CE2 TYR 1.389 0.015
143
+ CE1-CZ TYR 1.381 0.013
144
+ CE2-CZ TYR 1.381 0.013
145
+ CZ-OH TYR 1.374 0.017
146
+ N-CA TYR 1.459 0.020
147
+ CA-C TYR 1.525 0.026
148
+ C-O TYR 1.229 0.019
149
+ CA-CB VAL 1.543 0.021
150
+ CB-CG1 VAL 1.524 0.021
151
+ CB-CG2 VAL 1.524 0.021
152
+ N-CA VAL 1.459 0.020
153
+ CA-C VAL 1.525 0.026
154
+ C-O VAL 1.229 0.019
155
+ -
156
+
157
+ Angle Residue Mean StdDev
158
+ N-CA-CB ALA 110.1 1.4
159
+ CB-CA-C ALA 110.1 1.5
160
+ N-CA-C ALA 111.0 2.7
161
+ CA-C-O ALA 120.1 2.1
162
+ N-CA-CB ARG 110.6 1.8
163
+ CB-CA-C ARG 110.4 2.0
164
+ CA-CB-CG ARG 113.4 2.2
165
+ CB-CG-CD ARG 111.6 2.6
166
+ CG-CD-NE ARG 111.8 2.1
167
+ CD-NE-CZ ARG 123.6 1.4
168
+ NE-CZ-NH1 ARG 120.3 0.5
169
+ NE-CZ-NH2 ARG 120.3 0.5
170
+ NH1-CZ-NH2 ARG 119.4 1.1
171
+ N-CA-C ARG 111.0 2.7
172
+ CA-C-O ARG 120.1 2.1
173
+ N-CA-CB ASN 110.6 1.8
174
+ CB-CA-C ASN 110.4 2.0
175
+ CA-CB-CG ASN 113.4 2.2
176
+ CB-CG-ND2 ASN 116.7 2.4
177
+ CB-CG-OD1 ASN 121.6 2.0
178
+ ND2-CG-OD1 ASN 121.9 2.3
179
+ N-CA-C ASN 111.0 2.7
180
+ CA-C-O ASN 120.1 2.1
181
+ N-CA-CB ASP 110.6 1.8
182
+ CB-CA-C ASP 110.4 2.0
183
+ CA-CB-CG ASP 113.4 2.2
184
+ CB-CG-OD1 ASP 118.3 0.9
185
+ CB-CG-OD2 ASP 118.3 0.9
186
+ OD1-CG-OD2 ASP 123.3 1.9
187
+ N-CA-C ASP 111.0 2.7
188
+ CA-C-O ASP 120.1 2.1
189
+ N-CA-CB CYS 110.8 1.5
190
+ CB-CA-C CYS 111.5 1.2
191
+ CA-CB-SG CYS 114.2 1.1
192
+ N-CA-C CYS 111.0 2.7
193
+ CA-C-O CYS 120.1 2.1
194
+ N-CA-CB GLU 110.6 1.8
195
+ CB-CA-C GLU 110.4 2.0
196
+ CA-CB-CG GLU 113.4 2.2
197
+ CB-CG-CD GLU 114.2 2.7
198
+ CG-CD-OE1 GLU 118.3 2.0
199
+ CG-CD-OE2 GLU 118.3 2.0
200
+ OE1-CD-OE2 GLU 123.3 1.2
201
+ N-CA-C GLU 111.0 2.7
202
+ CA-C-O GLU 120.1 2.1
203
+ N-CA-CB GLN 110.6 1.8
204
+ CB-CA-C GLN 110.4 2.0
205
+ CA-CB-CG GLN 113.4 2.2
206
+ CB-CG-CD GLN 111.6 2.6
207
+ CG-CD-OE1 GLN 121.6 2.0
208
+ CG-CD-NE2 GLN 116.7 2.4
209
+ OE1-CD-NE2 GLN 121.9 2.3
210
+ N-CA-C GLN 111.0 2.7
211
+ CA-C-O GLN 120.1 2.1
212
+ N-CA-C GLY 113.1 2.5
213
+ CA-C-O GLY 120.6 1.8
214
+ N-CA-CB HIS 110.6 1.8
215
+ CB-CA-C HIS 110.4 2.0
216
+ CA-CB-CG HIS 113.6 1.7
217
+ CB-CG-ND1 HIS 123.2 2.5
218
+ CB-CG-CD2 HIS 130.8 3.1
219
+ CG-ND1-CE1 HIS 108.2 1.4
220
+ ND1-CE1-NE2 HIS 109.9 2.2
221
+ CE1-NE2-CD2 HIS 106.6 2.5
222
+ NE2-CD2-CG HIS 109.2 1.9
223
+ CD2-CG-ND1 HIS 106.0 1.4
224
+ N-CA-C HIS 111.0 2.7
225
+ CA-C-O HIS 120.1 2.1
226
+ N-CA-CB ILE 110.8 2.3
227
+ CB-CA-C ILE 111.6 2.0
228
+ CA-CB-CG1 ILE 111.0 1.9
229
+ CB-CG1-CD1 ILE 113.9 2.8
230
+ CA-CB-CG2 ILE 110.9 2.0
231
+ CG1-CB-CG2 ILE 111.4 2.2
232
+ N-CA-C ILE 111.0 2.7
233
+ CA-C-O ILE 120.1 2.1
234
+ N-CA-CB LEU 110.4 2.0
235
+ CB-CA-C LEU 110.2 1.9
236
+ CA-CB-CG LEU 115.3 2.3
237
+ CB-CG-CD1 LEU 111.0 1.7
238
+ CB-CG-CD2 LEU 111.0 1.7
239
+ CD1-CG-CD2 LEU 110.5 3.0
240
+ N-CA-C LEU 111.0 2.7
241
+ CA-C-O LEU 120.1 2.1
242
+ N-CA-CB LYS 110.6 1.8
243
+ CB-CA-C LYS 110.4 2.0
244
+ CA-CB-CG LYS 113.4 2.2
245
+ CB-CG-CD LYS 111.6 2.6
246
+ CG-CD-CE LYS 111.9 3.0
247
+ CD-CE-NZ LYS 111.7 2.3
248
+ N-CA-C LYS 111.0 2.7
249
+ CA-C-O LYS 120.1 2.1
250
+ N-CA-CB MET 110.6 1.8
251
+ CB-CA-C MET 110.4 2.0
252
+ CA-CB-CG MET 113.3 1.7
253
+ CB-CG-SD MET 112.4 3.0
254
+ CG-SD-CE MET 100.2 1.6
255
+ N-CA-C MET 111.0 2.7
256
+ CA-C-O MET 120.1 2.1
257
+ N-CA-CB PHE 110.6 1.8
258
+ CB-CA-C PHE 110.4 2.0
259
+ CA-CB-CG PHE 113.9 2.4
260
+ CB-CG-CD1 PHE 120.8 0.7
261
+ CB-CG-CD2 PHE 120.8 0.7
262
+ CD1-CG-CD2 PHE 118.3 1.3
263
+ CG-CD1-CE1 PHE 120.8 1.1
264
+ CG-CD2-CE2 PHE 120.8 1.1
265
+ CD1-CE1-CZ PHE 120.1 1.2
266
+ CD2-CE2-CZ PHE 120.1 1.2
267
+ CE1-CZ-CE2 PHE 120.0 1.8
268
+ N-CA-C PHE 111.0 2.7
269
+ CA-C-O PHE 120.1 2.1
270
+ N-CA-CB PRO 103.3 1.2
271
+ CB-CA-C PRO 111.7 2.1
272
+ CA-CB-CG PRO 104.8 1.9
273
+ CB-CG-CD PRO 106.5 3.9
274
+ CG-CD-N PRO 103.2 1.5
275
+ CA-N-CD PRO 111.7 1.4
276
+ N-CA-C PRO 112.1 2.6
277
+ CA-C-O PRO 120.2 2.4
278
+ N-CA-CB SER 110.5 1.5
279
+ CB-CA-C SER 110.1 1.9
280
+ CA-CB-OG SER 111.2 2.7
281
+ N-CA-C SER 111.0 2.7
282
+ CA-C-O SER 120.1 2.1
283
+ N-CA-CB THR 110.3 1.9
284
+ CB-CA-C THR 111.6 2.7
285
+ CA-CB-OG1 THR 109.0 2.1
286
+ CA-CB-CG2 THR 112.4 1.4
287
+ OG1-CB-CG2 THR 110.0 2.3
288
+ N-CA-C THR 111.0 2.7
289
+ CA-C-O THR 120.1 2.1
290
+ N-CA-CB TRP 110.6 1.8
291
+ CB-CA-C TRP 110.4 2.0
292
+ CA-CB-CG TRP 113.7 1.9
293
+ CB-CG-CD1 TRP 127.0 1.3
294
+ CB-CG-CD2 TRP 126.6 1.3
295
+ CD1-CG-CD2 TRP 106.3 0.8
296
+ CG-CD1-NE1 TRP 110.1 1.0
297
+ CD1-NE1-CE2 TRP 109.0 0.9
298
+ NE1-CE2-CD2 TRP 107.3 1.0
299
+ CE2-CD2-CG TRP 107.3 0.8
300
+ CG-CD2-CE3 TRP 133.9 0.9
301
+ NE1-CE2-CZ2 TRP 130.4 1.1
302
+ CE3-CD2-CE2 TRP 118.7 1.2
303
+ CD2-CE2-CZ2 TRP 122.3 1.2
304
+ CE2-CZ2-CH2 TRP 117.4 1.0
305
+ CZ2-CH2-CZ3 TRP 121.6 1.2
306
+ CH2-CZ3-CE3 TRP 121.2 1.1
307
+ CZ3-CE3-CD2 TRP 118.8 1.3
308
+ N-CA-C TRP 111.0 2.7
309
+ CA-C-O TRP 120.1 2.1
310
+ N-CA-CB TYR 110.6 1.8
311
+ CB-CA-C TYR 110.4 2.0
312
+ CA-CB-CG TYR 113.4 1.9
313
+ CB-CG-CD1 TYR 121.0 0.6
314
+ CB-CG-CD2 TYR 121.0 0.6
315
+ CD1-CG-CD2 TYR 117.9 1.1
316
+ CG-CD1-CE1 TYR 121.3 0.8
317
+ CG-CD2-CE2 TYR 121.3 0.8
318
+ CD1-CE1-CZ TYR 119.8 0.9
319
+ CD2-CE2-CZ TYR 119.8 0.9
320
+ CE1-CZ-CE2 TYR 119.8 1.6
321
+ CE1-CZ-OH TYR 120.1 2.7
322
+ CE2-CZ-OH TYR 120.1 2.7
323
+ N-CA-C TYR 111.0 2.7
324
+ CA-C-O TYR 120.1 2.1
325
+ N-CA-CB VAL 111.5 2.2
326
+ CB-CA-C VAL 111.4 1.9
327
+ CA-CB-CG1 VAL 110.9 1.5
328
+ CA-CB-CG2 VAL 110.9 1.5
329
+ CG1-CB-CG2 VAL 110.9 1.6
330
+ N-CA-C VAL 111.0 2.7
331
+ CA-C-O VAL 120.1 2.1
332
+ -
333
+
334
+ Non-bonded distance Minimum Dist Tolerance
335
+ C-C 3.4 1.5
336
+ C-N 3.25 1.5
337
+ C-S 3.5 1.5
338
+ C-O 3.22 1.5
339
+ N-N 3.1 1.5
340
+ N-S 3.35 1.5
341
+ N-O 3.07 1.5
342
+ O-S 3.32 1.5
343
+ O-O 3.04 1.5
344
+ S-S 2.03 1.0
345
+ -
core/utils.py ADDED
@@ -0,0 +1,1062 @@
1
+ """
2
+ https://github.com/ProteinDesignLab/protpardelle
3
+ License: MIT
4
+ Author: Alex Chu
5
+
6
+ Various utils for handling protein data.
7
+ """
8
+
9
+ import os
10
+ import shlex
11
+ import subprocess
12
+ import sys
13
+ import torch
14
+ import yaml
15
+ import argparse
16
+
17
+ from einops import rearrange, repeat
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn.functional as F
21
+ import Bio
22
+ from Bio.PDB.DSSP import DSSP
23
+
24
+ from core import protein
25
+ from core import protein_mpnn
26
+ from core import residue_constants
27
+
28
+
29
+ PATH_TO_TMALIGN = "/home/alexechu/essentials_kit/ml_utils/align/TMalign/TMalign"
30
+
31
+
32
+ ################ STRUCTURE/FORMAT UTILS #############################
33
+
34
+
35
+ def aatype_to_seq(aatype, seq_mask=None):
36
+ if seq_mask is None:
37
+ seq_mask = torch.ones_like(aatype)
38
+
39
+ mapping = residue_constants.restypes_with_x
40
+ mapping = mapping + ["<mask>"]
41
+
42
+ unbatched = False
43
+ if len(aatype.shape) == 1:
44
+ unbatched = True
45
+ aatype = [aatype]
46
+ seq_mask = [seq_mask]
47
+
48
+ seqs = []
49
+ for i, ai in enumerate(aatype):
50
+ seq = []
51
+ for j, aa in enumerate(ai):
52
+ if seq_mask[i][j] == 1:
53
+ try:
54
+ seq.append(mapping[aa])
55
+ except IndexError:
56
+ print(aatype[i])
57
+ raise Exception(f"Error in mapping {aa} at {i},{j}")
58
+ seqs.append("".join(seq))
59
+
60
+ if unbatched:
61
+ seqs = seqs[0]
62
+ return seqs
63
+
64
+
65
+ def seq_to_aatype(seq, num_tokens=21):
66
+ if num_tokens == 20:
67
+ mapping = residue_constants.restype_order
68
+ if num_tokens == 21:
69
+ mapping = residue_constants.restype_order_with_x
70
+ if num_tokens == 22:
71
+ mapping = dict(residue_constants.restype_order_with_x)  # copy; avoid mutating the shared dict
72
+ mapping["<mask>"] = 21
73
+ return torch.Tensor([mapping[aa] for aa in seq]).long()
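These two helpers invert each other over the 21-token vocabulary; a round-trip sketch (assuming this module's namespace):

    aatype = seq_to_aatype("ACDEFG")          # LongTensor of shape (6,)
    assert aatype_to_seq(aatype) == "ACDEFG"  # unbatched input returns a single string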
74
+
75
+
76
+ def batched_seq_to_aatype_and_mask(seqs, max_len=None):
77
+ if max_len is None:
78
+ max_len = max([len(s) for s in seqs])
79
+ aatypes = []
80
+ seq_mask = []
81
+ for s in seqs:
82
+ pad_size = max_len - len(s)
83
+ aatype = seq_to_aatype(s)
84
+ aatypes.append(F.pad(aatype, (0, pad_size)))
85
+ mask = torch.ones_like(aatype).float()
86
+ seq_mask.append(F.pad(mask, (0, pad_size)))
87
+ return torch.stack(aatypes), torch.stack(seq_mask)
88
+
89
+
90
+ def atom37_mask_from_aatype(aatype, seq_mask=None):
91
+ # source_mask is (21,37) originally
92
+ source_mask = torch.Tensor(residue_constants.restype_atom37_mask).to(aatype.device)
93
+ bb_atoms = source_mask[residue_constants.restype_order["G"]][None]
94
+ # Use only the first 20 plus bb atoms for X, mask
95
+ source_mask = torch.cat([source_mask[:-1], bb_atoms, bb_atoms], 0)
96
+ atom_mask = source_mask[aatype]
97
+ if seq_mask is not None:
98
+ atom_mask *= seq_mask[..., None]
99
+ return atom_mask
100
+
101
+
102
+ def atom37_coords_from_atom14(atom14_coords, aatype, return_mask=False):
103
+ # Unbatched
104
+ device = atom14_coords.device
105
+ atom37_coords = torch.zeros((atom14_coords.shape[0], 37, 3)).to(device)
106
+ for i in range(atom14_coords.shape[0]): # per residue
107
+ aa = aatype[i]
108
+ aa_3name = residue_constants.restype_1to3[residue_constants.restypes[aa]]
109
+ atom14_atoms = residue_constants.restype_name_to_atom14_names[aa_3name]
110
+ for j in range(14):
111
+ atom_name = atom14_atoms[j]
112
+ if atom_name != "":
113
+ atom37_idx = residue_constants.atom_order[atom_name]
114
+ atom37_coords[i, atom37_idx, :] = atom14_coords[i, j, :]
115
+
116
+ if return_mask:
117
+ atom37_mask = atom37_mask_from_aatype(aatype)
118
+ return atom37_coords, atom37_mask
119
+ return atom37_coords
120
+
121
+
122
+ def atom73_mask_from_aatype(aatype, seq_mask=None):
123
+ source_mask = torch.Tensor(residue_constants.restype_atom73_mask).to(aatype.device)
124
+ atom_mask = source_mask[aatype]
125
+ if seq_mask is not None:
126
+ atom_mask *= seq_mask[..., None]
127
+ return atom_mask
128
+
129
+
130
+ def atom37_to_atom73(atom37, aatype, return_mask=False):
131
+ # Unbatched
132
+ atom73 = torch.zeros((atom37.shape[0], 73, 3)).to(atom37)
133
+ for i in range(atom37.shape[0]):
134
+ aa = aatype[i]
135
+ aa1 = residue_constants.restypes[aa]
136
+ for j, atom37_name in enumerate(residue_constants.atom_types):
137
+ atom73_name = atom37_name
138
+ if atom37_name not in ["N", "CA", "C", "O", "CB"]:
139
+ atom73_name = aa1 + atom73_name
140
+ if atom73_name in residue_constants.atom73_names_to_idx:
141
+ atom73_idx = residue_constants.atom73_names_to_idx[atom73_name]
142
+ atom73[i, atom73_idx, :] = atom37[i, j, :]
143
+
144
+ if return_mask:
145
+ atom73_mask = atom73_mask_from_aatype(aatype)
146
+ return atom73, atom73_mask
147
+ return atom73
148
+
149
+
150
+ def atom73_to_atom37(atom73, aatype, return_mask=False):
151
+ # Unbatched
152
+ atom37_coords = torch.zeros((atom73.shape[0], 37, 3)).to(atom73)
153
+ for i in range(atom73.shape[0]): # per residue
154
+ aa = aatype[i]
155
+ aa1 = residue_constants.restypes[aa]
156
+ for j, atom_type in enumerate(residue_constants.atom_types):
157
+ atom73_name = atom_type
158
+ if atom73_name not in ["N", "CA", "C", "O", "CB"]:
159
+ atom73_name = aa1 + atom73_name
160
+ if atom73_name in residue_constants.atom73_names_to_idx:
161
+ atom73_idx = residue_constants.atom73_names_to_idx[atom73_name]
162
+ atom37_coords[i, j, :] = atom73[i, atom73_idx, :]
163
+
164
+ if return_mask:
165
+ atom37_mask = atom37_mask_from_aatype(aatype)
166
+ return atom37_coords, atom37_mask
167
+ return atom37_coords
168
+
169
+
170
+ def get_dmap(pdb, atoms=["N", "CA", "C", "O"], batched=True, out="torch", device=None):
171
+ def _dmap_from_coords(coords):
172
+ coords = coords.contiguous()
173
+ dmaps = torch.cdist(coords, coords).unsqueeze(1)
174
+ if out == "numpy":
175
+ return dmaps.detach().cpu().numpy()
176
+ elif out == "torch":
177
+ if device is not None:
178
+ return dmaps.to(device)
179
+ else:
180
+ return dmaps
181
+
182
+ if isinstance(pdb, str): # input is pdb file
183
+ coords = load_coords_from_pdb(pdb, atoms=atoms).view(1, -1, 3)
184
+ return _dmap_from_coords(coords)
185
+ elif len(pdb.shape) == 2: # single set of coords
186
+ if isinstance(pdb, np.ndarray):
187
+ pdb = torch.Tensor(pdb)
188
+ return _dmap_from_coords(pdb.unsqueeze(0))
189
+ elif len(pdb.shape) == 3 and batched:
190
+ return _dmap_from_coords(pdb)
191
+ elif len(pdb.shape) == 3 and not batched:
192
+ return _dmap_from_coords(pdb.view(1, -1, 3))
193
+ elif len(pdb.shape) == 4:
194
+ return _dmap_from_coords(pdb.view(pdb.size(0), -1, 3))
195
+
196
+
197
+ def get_channeled_dmap(coords):
198
+ # coords is b, nres, natom, 3
199
+ coords = coords.permute(0, 2, 1, 3)
200
+ dvecs = coords[..., None, :] - coords[..., None, :, :] # b, natom, nres, nres, 3
201
+ dists = torch.sqrt(dvecs.pow(2).sum(-1) + 1e-8)
202
+ return dists
203
+
204
+
205
+ def fill_in_cbeta_for_atom37(coords):
206
+ b = coords[..., 1, :] - coords[..., 0, :]
207
+ c = coords[..., 2, :] - coords[..., 1, :]
208
+ a = torch.cross(b, c, dim=-1)
209
+ cbeta = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + coords[..., 1, :]
210
+ new_coords = torch.clone(coords)
211
+ new_coords[..., 3, :] = cbeta
212
+ return new_coords
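The CB slot is filled with the standard ideal-geometry virtual-CB construction from N, CA, and C, so even glycine receives a placeholder CB. A usage sketch with random backbones (hypothetical shapes):

    import torch

    coords = torch.randn(2, 50, 37, 3)  # batch of 2, 50 residues, atom37 layout
    coords_cb = fill_in_cbeta_for_atom37(coords)
    assert coords_cb.shape == coords.shape  # only atom slot 3 (CB) is overwritten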
213
+
214
+
215
+ def get_distogram(coords, n_bins=20, start=2, return_onehot=True, seq_mask=None):
216
+ # coords is b, nres, natom, 3
217
+ # distogram over the CB atom (atom37 index 3)
218
+ coords_with_cb = fill_in_cbeta_for_atom37(coords)
219
+ dists = get_channeled_dmap(coords_with_cb[:, :, 3:4]).squeeze(1)
220
+ bins = torch.arange(start, start + n_bins - 1).to(dists.device)
221
+ dgram = torch.bucketize(dists, bins)
222
+ dgram_oh = F.one_hot(dgram, n_bins)
223
+ if seq_mask is not None:
224
+ mask_2d = seq_mask[:, :, None] * seq_mask[:, None, :]
225
+ dgram = dgram * mask_2d
226
+ dgram_oh = dgram_oh * mask_2d[..., None]
227
+
228
+ if return_onehot:
229
+ return dgram_oh
230
+ return dgram
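A usage sketch (hypothetical shapes): with the default 20 bins starting at 2 Angstroms, each CB-CB distance is bucketized into a one-hot class, and the sequence mask zeroes padded pairs.

    import torch

    coords = torch.randn(1, 32, 37, 3)
    seq_mask = torch.ones(1, 32)
    dgram = get_distogram(coords, n_bins=20, seq_mask=seq_mask)
    assert dgram.shape == (1, 32, 32, 20)  # one-hot over distance bins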
231
+
232
+
233
+ def get_contacts(coords=None, distogram=None, seq_mask=None):
234
+ if distogram is None:
235
+ distogram = get_distogram(coords)
236
+ contacts = (distogram.argmax(-1) < 6).float()
237
+ if seq_mask is not None:
238
+ contacts *= seq_mask[..., None] * seq_mask[..., None, :]
239
+ return contacts
240
+
241
+
242
+ def dihedral(a, b, c, d):
243
+ # inputs can be (1,3), (n,3), or (bs,n,3)
244
+ b1 = a - b
245
+ b2 = b - c
246
+ b3 = c - d
247
+ n1 = F.normalize(torch.cross(b1, b2, dim=-1), dim=-1)
248
+ n2 = F.normalize(torch.cross(b2, b3, dim=-1), dim=-1)
249
+ m1 = torch.cross(n1, b2 / b2.norm(dim=-1).unsqueeze(-1), dim=-1)
250
+ y = (m1 * n2).sum(dim=-1)
251
+ x = (n1 * n2).sum(dim=-1)
252
+ return torch.atan2(y, x)
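A sign-convention sanity check (hand-picked points forming a +90 degree torsion about the b-c axis):

    import math
    import torch

    a = torch.tensor([[0.0, 1.0, 0.0]])
    b = torch.tensor([[0.0, 0.0, 0.0]])
    c = torch.tensor([[1.0, 0.0, 0.0]])
    d = torch.tensor([[1.0, 0.0, 1.0]])
    angle = dihedral(a, b, c, d)
    assert torch.allclose(angle, torch.tensor([math.pi / 2]))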
253
+
254
+
255
+ def get_torsions_from_coords(
256
+ coords, atoms=["N", "CA", "C", "O"], batched=True, out="torch", device=None
257
+ ):
258
+ """
259
+ Returns an n-dim array of shape (bs, nres, ntors), where ntors is the
260
+ number of torsion angles (e.g. 2 if using phi and psi), with units of radians.
261
+ """
262
+ if isinstance(coords, np.ndarray):
263
+ coords = torch.Tensor(coords)
264
+ if len(coords.shape) == 2:
265
+ coords = coords.unsqueeze(0)
266
+ if len(coords.shape) == 4:
267
+ coords = coords.view(coords.size(0), -1, 3)
268
+ if len(coords.shape) == 3 and not batched:
269
+ coords = coords.view(1, -1, 3)
270
+ if len(coords.shape) == 3:
271
+ bs = coords.size(0)
272
+ if "O" in atoms:
273
+ idxs = [
274
+ i for i in range(coords.size(1)) if i % 4 != 3
275
+ ] # deselect O atoms for N-Ca-C-O coords
276
+ coords = coords[:, idxs, :]
277
+ a, b, c, d = (
278
+ coords[:, :-3, :],
279
+ coords[:, 1:-2, :],
280
+ coords[:, 2:-1, :],
281
+ coords[:, 3:, :],
282
+ )
283
+ torsions = dihedral(
284
+ a, b, c, d
285
+ ) # output order is psi-omega-phi, reorganize to (bs, nres, 3)
286
+ torsions = torsions.view(bs, torsions.size(1) // 3, 3)
287
+ omegaphi = torch.cat(
288
+ (torch.zeros(bs, 1, 2).to(coords.device), torsions[:, :, 1:]), 1
289
+ )
290
+ psi = torch.cat((torsions[:, :, 0], torch.zeros(bs, 1).to(coords.device)), 1)
291
+ torsions = torch.cat(
292
+ (
293
+ omegaphi[:, :, 1].unsqueeze(-1),
294
+ psi.unsqueeze(-1),
295
+ omegaphi[:, :, 0].unsqueeze(-1),
296
+ ),
297
+ -1,
298
+ )
299
+ else:
300
+ raise Exception("input coords not of correct dims")
301
+
302
+ if out == "numpy":
303
+ return torsions.detach().cpu().numpy()
304
+ elif out == "torch":
305
+ if device is not None:
306
+ return torsions.to(device)
307
+ else:
308
+ return torsions
309
+
310
+
311
+ def get_trig_from_torsions(torsions, out="torch", device=None):
312
+ """
313
+ Calculate unit-circle projections from torsion-angle input.
314
+
315
+ Returns an n-dim array of shape (bs, nres, ntors, 2), where ntors is the
316
+ number of torsion angles (e.g. 2 if using phi and psi), and the last
317
+ dimension is the xy unit-circle coordinates of the corresponding angle.
318
+ """
319
+ if isinstance(torsions, np.ndarray):
320
+ torsions = torch.Tensor(torsions)
321
+ x = torsions.cos()
322
+ y = torsions.sin()
323
+ trig = torch.cat((x.unsqueeze(-1), y.unsqueeze(-1)), -1)
324
+ if out == "numpy":
325
+ return trig.detach().cpu().numpy()
326
+ elif out == "torch":
327
+ if device is not None:
328
+ return trig.to(device)
329
+ else:
330
+ return trig
331
+
332
+
333
+ def get_abego_string_from_torsions(torsions):
334
+ A_bin = (-75, 50)
335
+ G_bin = (-100, 100)
336
+ torsions = torsions * 180.0 / np.pi
337
+ phi, psi = torsions[:, :, 0], torsions[:, :, 1]
338
+ abego_vec = np.zeros((torsions.size(0), torsions.size(1))).astype(str)
339
+ A = (phi <= 0) & (psi <= A_bin[1]) & (psi > A_bin[0])
340
+ B = (phi <= 0) & ((psi > A_bin[1]) | (psi <= A_bin[0]))
341
+ G = (phi > 0) & (psi <= G_bin[1]) & (psi > G_bin[0])
342
+ E = (phi > 0) & ((psi > G_bin[1]) | (psi <= G_bin[0]))
343
+ abego_vec[A] = "A"
344
+ abego_vec[B] = "B"
345
+ abego_vec[G] = "G"
346
+ abego_vec[E] = "E"
347
+ abego_strs = ["".join(v) for v in abego_vec]
348
+ return abego_strs
349
+
350
+
351
+ def get_bond_lengths_from_coords(coords, batched=True, out="torch", device=None):
352
+ """
353
+ Returns array of shape (bs, n_res, 4), where final dim is bond lengths
354
+ in order of N-Ca, Ca-C, C-O, C-N (none for last residue)
355
+ """
356
+ if isinstance(coords, np.ndarray):
357
+ coords = torch.Tensor(coords)
358
+ if len(coords.shape) == 2:
359
+ coords = coords.unsqueeze(0)
360
+ if len(coords.shape) == 3 and not batched:
361
+ coords = coords.view(1, -1, 3)
362
+ if len(coords.shape) == 4:
363
+ coords = coords.view(coords.size(0), -1, 3)
364
+ N = coords[:, ::4, :]
365
+ Ca = coords[:, 1::4, :]
366
+ C = coords[:, 2::4, :]
367
+ O = coords[:, 3::4, :]
368
+ NCa = (Ca - N).norm(dim=-1).unsqueeze(-1)
369
+ CaC = (C - Ca).norm(dim=-1).unsqueeze(-1)
370
+ CO = (O - C).norm(dim=-1).unsqueeze(-1)
371
+ CN = (N[:, 1:] - C[:, :-1]).norm(dim=-1)
372
+ CN = torch.cat([CN, torch.zeros(CN.size(0), 1).to(CN.device)], 1).unsqueeze(-1)
373
+ blengths = torch.cat((NCa, CaC, CO, CN), -1)
374
+ if out == "numpy":
375
+ return blengths.detach().cpu().numpy()
376
+ elif out == "torch":
377
+ if device is not None:
378
+ return blengths.to(device)
379
+ else:
380
+ return blengths
381
+
382
+
383
+ def get_bond_angles_from_coords(coords, batched=True, out="torch", device=None):
384
+ """
385
+ Returns array of shape (bs, n_res, 5), where final dim is bond angles
386
+ in order of N-Ca-C, Ca-C-O, Ca-C-N, O-C-N, C-N-Ca (none for last residue)
387
+ """
388
+
389
+ def _angle(v1, v2):
390
+ cos = (v1 * v2).sum(-1) / (v1.norm(dim=-1) * v2.norm(dim=-1))
391
+ return cos.acos()
392
+
393
+ if isinstance(coords, np.ndarray):
394
+ coords = torch.Tensor(coords)
395
+ if len(coords.shape) == 2:
396
+ coords = coords.unsqueeze(0)
397
+ if len(coords.shape) == 3 and not batched:
398
+ coords = coords.view(1, -1, 3)
399
+ if len(coords.shape) == 4:
400
+ coords = coords.view(coords.size(0), -1, 3)
401
+ N = coords[:, ::4, :]
402
+ Nnext = coords[:, 4::4, :]
403
+ Ca = coords[:, 1::4, :]
404
+ Canext = coords[:, 5::4, :]
405
+ C = coords[:, 2::4, :]
406
+ O = coords[:, 3::4, :]
407
+ CaN = N - Ca
408
+ CaC = C - Ca
409
+ CCa = Ca - C
410
+ CO = O - C
411
+ CNnext = Nnext - C[:, :-1, :]
412
+ NnextC = -1 * CNnext
413
+ NnextCanext = Canext - Nnext
414
+ NCaC = _angle(CaN, CaC).unsqueeze(-1)
415
+ CaCO = _angle(CCa, CO).unsqueeze(-1)
416
+ CaCN = _angle(CCa[:, :-1], CNnext).unsqueeze(-1)
417
+ CaCN = _extend(CaCN)
418
+ OCN = _angle(CO[:, :-1], CNnext).unsqueeze(-1)
419
+ OCN = _extend(OCN)
420
+ CNCa = _angle(NnextC, NnextCanext).unsqueeze(-1)
421
+ # CNCa = torch.cat([CNCa, torch.zeros(CNCa.size(0), 1).to(CNCa.device)], 1).unsqueeze(-1)
422
+ CNCa = _extend(CNCa)
423
+ bangles = torch.cat((NCaC, CaCO, CaCN, OCN, CNCa), -1)
424
+ if out == "numpy":
425
+ return bangles.detach().cpu().numpy()
426
+ elif out == "torch":
427
+ if device is not None:
428
+ return bangles.to(device)
429
+ else:
430
+ return bangles
431
+
432
+
433
+ def get_buried_positions_mask(coords, seq_mask=None, threshold=6.0):
434
+ ca_idx = residue_constants.atom_order["CA"] # typically 1
435
+ cb_idx = residue_constants.atom_order["CB"] # typically 3
436
+ if seq_mask is None:
437
+ seq_mask = torch.ones_like(coords)[..., 0, 0]
438
+ coords = fill_in_cbeta_for_atom37(coords)
439
+
440
+ # get 8 closest neighbors by CB
441
+ neighbor_coords = coords[:, :, cb_idx]
442
+
443
+ ca_neighbor_dists, edge_index = protein_mpnn.get_closest_neighbors(
444
+ neighbor_coords, seq_mask, 9
445
+ )
446
+ edge_index = edge_index[..., 1:].contiguous()
447
+
448
+ # compute avg CB distance
449
+ cb_coords = coords[:, :, cb_idx]
450
+ neighbor_cb = protein_mpnn.gather_nodes(cb_coords, edge_index)
451
+ avg_cb_dist = (neighbor_cb - cb_coords[..., None, :]).pow(2).sum(-1).sqrt().mean(-1)
452
+
453
+ buried_positions_mask = (avg_cb_dist < threshold).float() * seq_mask
454
+ return buried_positions_mask
455
+
456
+
457
+ def get_fullatom_bond_lengths_from_coords(
458
+ coords, aatype, atom_mask=None, return_format="per_aa"
459
+ ):
460
+ # Computes backbone and sidechain bond lengths. All unbatched; returns a list or dict of dicts.
461
+ def dist(xyz1, xyz2):
462
+ return (xyz1 - xyz2).pow(2).sum().sqrt().detach().cpu().item()
463
+
464
+ assert aatype.max() <= 19
465
+ seq = aatype_to_seq(aatype)
466
+ # residue-wise list of dicts [{'N-CA': a, 'CA-C': b}, {'N-CA': a, 'CA-C': b}]
467
+ all_bond_lens_by_pos = []
468
+ # aa-wise dict of dicts of lists {'A': {'N-CA': [a, b, c], 'CA-C': [a, b, c]}}
469
+ all_bond_lens_by_aa = {aa: {} for aa in residue_constants.restypes}
470
+ for i, res in enumerate(coords):
471
+ aa3 = residue_constants.restype_1to3[seq[i]]
472
+ res_bond_lens = {}
473
+ for j, atom1 in enumerate(residue_constants.atom_types):
474
+ for k, atom2 in enumerate(residue_constants.atom_types):
475
+ if j < k and protein.are_atoms_bonded(aa3, atom1, atom2):
476
+ if atom_mask is None or (
477
+ atom_mask[i, j] > 0.5 and atom_mask[i, k] > 0.5
478
+ ):
479
+ bond_name = f"{atom1}-{atom2}"
480
+ bond_len = dist(res[j], res[k])
481
+ res_bond_lens[bond_name] = bond_len
482
+ all_bond_lens_by_pos.append(res_bond_lens)
483
+ for key, val in res_bond_lens.items():
484
+ all_bond_lens_by_aa[seq[i]].setdefault(key, []).append(val)
485
+
486
+ if return_format == "per_aa":
487
+ return all_bond_lens_by_aa
488
+ elif return_format == "per_position":
489
+ return all_bond_lens_by_pos
490
+
491
+
492
+ def batched_fullatom_bond_lengths_from_coords(
493
+ coords, aatype, atom_mask=None, return_format="per_aa"
494
+ ):
495
+ # Expects trimmed coords (no mask)
496
+ if return_format == "per_position":
497
+ batched_bond_lens = []
498
+ elif return_format == "per_aa":
499
+ batched_bond_lens = {aa: {} for aa in residue_constants.restypes}
500
+ for i, c in enumerate(coords):
501
+ atom_mask_i = None if atom_mask is None else atom_mask[i]
502
+ bond_lens = get_fullatom_bond_lengths_from_coords(
503
+ c, aatype[i], atom_mask=atom_mask_i, return_format=return_format
504
+ )
505
+ if return_format == "per_position":
506
+ batched_bond_lens.extend(bond_lens)
507
+ elif return_format == "per_aa":
508
+ for aa, d in bond_lens.items():
509
+ for bond, lengths in d.items():
510
+ batched_bond_lens[aa].setdefault(bond, []).extend(lengths)
511
+ return batched_bond_lens
512
+
513
+
514
+ def batched_fullatom_bond_angles_from_coords(coords, aatype, return_format="per_aa"):
515
+ # Expects trimmed coords (no mask)
516
+ if return_format == "per_position":
517
+ batched_bond_angles = []
518
+ elif return_format == "per_aa":
519
+ batched_bond_angles = {aa: {} for aa in residue_constants.restypes}
520
+ for i, c in enumerate(coords):
521
+ bond_angles = get_fullatom_bond_angles_from_coords(
522
+ c, aatype[i], return_format=return_format
523
+ )
524
+ if return_format == "per_position":
525
+ batched_bond_angles.extend(bond_angles)
526
+ elif return_format == "per_aa":
527
+ for aa, d in bond_angles.items():
528
+ for bond, lengths in d.items():
529
+ batched_bond_angles[aa].setdefault(bond, []).extend(lengths)
530
+ return batched_bond_angles
531
+
532
+
533
+ def get_chi_angles(coords, aatype, atom_mask=None, seq_mask=None):
534
+ # unbatched
535
+ # return (n, 4) chis in degrees and mask
536
+ chis = []
537
+ chi_mask = []
538
+ atom_order = residue_constants.atom_order
539
+
540
+ seq = aatype_to_seq(aatype, seq_mask=seq_mask)
541
+
542
+ for i, aa1 in enumerate(seq): # per residue
543
+ if seq_mask is not None and seq_mask[i] == 0:
544
+ chis.append([0, 0, 0, 0])
545
+ chi_mask.append([0, 0, 0, 0])
546
+ else:
547
+ chi = []
548
+ mask = []
549
+ chi_atoms = residue_constants.chi_angles_atoms[
550
+ residue_constants.restype_1to3[aa1]
551
+ ]
552
+ for j in range(4): # per chi angle
553
+ if j > len(chi_atoms) - 1:
554
+ chi.append(0)
555
+ mask.append(0)
556
+ elif atom_mask is not None and any(
557
+ [atom_mask[i, atom_order[a]] < 0.5 for a in chi_atoms[j]]
558
+ ):
559
+ chi.append(0)
560
+ mask.append(0)
561
+ else:
562
+ # Four atoms per dihedral
563
+ xyz4 = [coords[i, atom_order[a]] for a in chi_atoms[j]]
564
+ angle = dihedral(*xyz4) * 180 / np.pi
565
+ chi.append(angle)
566
+ mask.append(1)
567
+ chis.append(chi)
568
+ chi_mask.append(mask)
569
+
570
+ chis = torch.Tensor(chis)
571
+ chi_mask = torch.Tensor(chi_mask)
572
+
573
+ return chis, chi_mask
574
+
575
+
576
+ def fill_Os_from_NCaC_coords(
577
+ coords: torch.Tensor, out: str = "torch", device: str = None
578
+ ):
579
+ """Given NCaC coords, add O atom coordinates in.
580
+ (bs, 3n, 3) -> (bs, 4n, 3)
581
+ """
582
+ CO_LEN = 1.231
583
+ if len(coords.shape) == 2:
584
+ coords = coords.unsqueeze(0)
585
+ Cs = coords[:, 2:-1:3, :] # all but last C
586
+ CCa_norm = F.normalize(coords[:, 1:-2:3, :] - Cs, dim=-1) # all but last Ca
587
+ CN_norm = F.normalize(coords[:, 3::3, :] - Cs, dim=-1) # all but first N
588
+ Os = F.normalize(CCa_norm + CN_norm, dim=-1) * -CO_LEN
589
+ Os += Cs
590
+ # TODO place C-term O atom properly
591
+ Os = torch.cat([Os, coords[:, -1, :].view(-1, 1, 3) + 1], 1)
592
+ coords_out = []
593
+ for i in range(Os.size(1)):
594
+ coords_out.append(coords[:, i * 3 : (i + 1) * 3, :])
595
+ coords_out.append(Os[:, i, :].view(-1, 1, 3))
596
+ coords_out = torch.cat(coords_out, 1)
597
+ if out == "numpy":
598
+ return coords_out.detach().cpu().numpy()
599
+ elif out == "torch":
600
+ if device is not None:
601
+ return coords_out.to(device)
602
+ else:
603
+ return coords_out
604
+
605
+
606
+ def _extend(x, axis=1, n=1, prepend=False):
607
+ # Add an extra zeros 'residue' to the end (or beginning, prepend=True) of a Tensor
608
+ # Used to extend torsions when there is no 'psi' for last residue
609
+ shape = list(x.shape)
610
+ shape[axis] = n
611
+ if prepend:
612
+ return torch.cat([torch.zeros(shape).to(x.device), x], axis)
613
+ else:
614
+ return torch.cat([x, torch.zeros(shape).to(x.device)], axis)
615
+
616
+
617
+ def trim_coords(coords, n_res, batched=True):
618
+ if batched: # Return list of tensors
619
+ front = (coords.shape[1] - n_res) // 2
620
+ return [
621
+ coords[i, front[i] : front[i] + n_res[i]] for i in range(coords.shape[0])
622
+ ]
623
+ else:
624
+ if isinstance(n_res, torch.Tensor):
625
+ n_res = n_res.int()
626
+ front_pad = (coords.shape[0] - n_res) // 2
627
+ return coords[front_pad : front_pad + n_res]
628
+
629
+
630
+ def batch_align_on_calpha(x, y):
631
+ aligned_x = []
632
+ for i, xi in enumerate(x):
633
+ xi_calpha = xi[:, 1, :]
634
+ _, (R, t) = kabsch_align(xi_calpha, y[i, :, 1, :])
635
+ xi_ctr = xi - xi_calpha.mean(0, keepdim=True)
636
+ xi_aligned = xi_ctr @ R.t() + t
637
+ aligned_x.append(xi_aligned)
638
+ return torch.stack(aligned_x)
639
+
640
+
641
+ def kabsch_align(p, q):
642
+ if len(p.shape) > 2:
643
+ p = p.reshape(-1, 3)
644
+ if len(q.shape) > 2:
645
+ q = q.reshape(-1, 3)
646
+ p_ctr = p - p.mean(0, keepdim=True)
647
+ t = q.mean(0, keepdim=True)
648
+ q_ctr = q - t
649
+ H = p_ctr.t() @ q_ctr
650
+ U, S, V = torch.svd(H)
651
+ R = V @ U.t()
652
+ I_ = torch.eye(3).to(p)
653
+ I_[-1, -1] = R.det().sign()
654
+ R = V @ I_ @ U.t()
655
+ p_aligned = p_ctr @ R.t() + t
656
+ return p_aligned, (R, t)
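A verification sketch: aligning a rigidly rotated and shifted copy back onto the reference recovers it to numerical precision (hypothetical data):

    import math
    import torch

    theta = 0.3
    rot = torch.tensor([
        [math.cos(theta), -math.sin(theta), 0.0],
        [math.sin(theta), math.cos(theta), 0.0],
        [0.0, 0.0, 1.0],
    ])
    q = torch.randn(10, 3)    # reference point cloud
    p = q @ rot.t() + 2.0     # rotated and translated copy
    p_aligned, (R, t) = kabsch_align(p, q)
    assert torch.allclose(p_aligned, q, atol=1e-4)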
657
+
658
+
659
+ def get_dssp_string(pdb):
660
+ try:
661
+ structure = Bio.PDB.PDBParser(QUIET=True).get_structure(pdb[:-3], pdb)
662
+ dssp = DSSP(structure[0], pdb, dssp="mkdssp")
663
+ dssp_string = "".join([dssp[k][2] for k in dssp.keys()])
664
+ return dssp_string
665
+ except Exception as e:
666
+ print(e)
667
+ return None
668
+
669
+
670
+ def pool_dssp_symbols(dssp_string, newchar=None, chars=["-", "T", "S", "C", " "]):
671
+ """Replaces all instances of chars with newchar. DSSP chars are helix=GHI, strand=EB, loop=- TSC"""
672
+ if newchar is None:
673
+ newchar = chars[0]
674
+ string_out = dssp_string
675
+ for c in chars:
676
+ string_out = string_out.replace(c, newchar)
677
+ return string_out
678
+
679
+
680
+ def get_3state_dssp(pdb=None, coords=None):
681
+ if coords is not None:
682
+ pdb = "temp_dssp.pdb"
683
+ write_coords_to_pdb(coords, pdb, batched=False)
684
+ dssp_string = get_dssp_string(pdb)
685
+ if dssp_string is not None:
686
+ dssp_string = pool_dssp_symbols(dssp_string, newchar="L")
687
+ dssp_string = pool_dssp_symbols(dssp_string, chars=["H", "G", "I"])
688
+ dssp_string = pool_dssp_symbols(dssp_string, chars=["E", "B"])
689
+ if coords is not None:
690
+ subprocess.run(shlex.split(f"rm {pdb}"))
691
+ return dssp_string
692
+
693
+
694
+ ############## SAVE/LOAD UTILS #################################
695
+
696
+
697
+ def load_feats_from_pdb(
698
+ pdb, bb_atoms=["N", "CA", "C", "O"], load_atom73=False, **kwargs
699
+ ):
700
+ feats = {}
701
+ with open(pdb, "r") as f:
702
+ pdb_str = f.read()
703
+ protein_obj = protein.from_pdb_string(pdb_str, **kwargs)
704
+ bb_idxs = [residue_constants.atom_order[a] for a in bb_atoms]
705
+ bb_coords = torch.from_numpy(protein_obj.atom_positions[:, bb_idxs])
706
+ feats["bb_coords"] = bb_coords.float()
707
+ for k, v in vars(protein_obj).items():
708
+ feats[k] = torch.Tensor(v)
709
+ feats["aatype"] = feats["aatype"].long()
710
+ if load_atom73:
711
+ feats["atom73_coords"], feats["atom73_mask"] = atom37_to_atom73(
712
+ feats["atom_positions"], feats["aatype"], return_mask=True
713
+ )
714
+ return feats
715
+
716
+
717
+ def load_coords_from_pdb(
718
+ pdb,
719
+ atoms=["N", "CA", "C", "O"],
720
+ method="raw",
721
+ also_bfactors=False,
722
+ normalize_bfactors=True,
723
+ ):
724
+ """Returns array of shape (1, n_res, len(atoms), 3)"""
725
+ coords = []
726
+ bfactors = []
727
+ if method == "raw": # Raw numpy implementation, faster than biopdb
728
+ # Indexing into PDB format, allowing XXXX.XXX
729
+ coords_in_pdb = [slice(30, 38), slice(38, 46), slice(46, 54)]
730
+ # Indexing into PDB format, allowing XXX.XX
731
+ bfactor_in_pdb = slice(60, 66)
732
+
733
+ with open(pdb, "r") as f:
734
+ resi_prev = 1
735
+ counter = 0
736
+ for l in f:
737
+ l_split = l.rstrip("\n").split()
738
+ if len(l_split) > 0 and l_split[0] == "ATOM" and l_split[2] in atoms:
739
+ resi = l_split[5]
740
+ if resi == resi_prev:
741
+ counter += 1
742
+ else:
743
+ counter = 0
744
+ if counter < len(atoms):
745
+ xyz = [
746
+ np.array(l[s].strip()).astype(float) for s in coords_in_pdb
747
+ ]
748
+ coords.append(xyz)
749
+ if also_bfactors:
750
+ bfactor = np.array(l[bfactor_in_pdb].strip()).astype(float)
751
+ bfactors.append(bfactor)
752
+ resi_prev = resi
753
+ coords = torch.Tensor(np.array(coords)).view(1, -1, len(atoms), 3)
754
+ if also_bfactors:
755
+ bfactors = torch.Tensor(np.array(bfactors)).view(1, -1, len(atoms))
756
+ elif method == "biopdb":
757
+ structure = Bio.PDB.PDBParser(QUIET=True).get_structure(pdb[:-3], pdb)
758
+ for model in structure:
759
+ for chain in model:
760
+ for res in chain:
761
+ for atom in atoms:
762
+ try:
763
+ coords.append(np.asarray(res[atom].get_coord()))
764
+ if also_bfactors:
765
+ bfactors.append(np.asarray(res[atom].get_bfactor()))
766
+ except KeyError:  # residue lacks this atom
767
+ continue
768
+ else:
769
+ raise NotImplementedError(f"Invalid method for reading coords: {method}")
770
+ if also_bfactors:
771
+ if normalize_bfactors: # Normalize over Calphas
772
+ mean_b = bfactors[..., 1].mean()
773
+ std_b = bfactors[..., 1].var().sqrt()
774
+ bfactors = (bfactors - mean_b) / (std_b + 1e-6)
775
+ return coords, bfactors
776
+ return coords
777
+
778
+
779
+ def feats_to_pdb_str(
780
+ atom_positions,
781
+ aatype=None,
782
+ atom_mask=None,
783
+ residue_index=None,
784
+ chain_index=None,
785
+ b_factors=None,
786
+ atom_lines_only=True,
787
+ conect=False,
788
+ **kwargs,
789
+ ):
790
+ # Expects unbatched, cropped inputs. needs at least one of atom_mask, aatype
791
+ # Uses all-GLY aatype if aatype not given: does not infer from atom_mask
792
+ assert aatype is not None or atom_mask is not None
793
+ if atom_mask is None:
794
+ aatype = aatype.cpu()
795
+ atom_mask = atom37_mask_from_aatype(aatype, torch.ones_like(aatype))
796
+ if aatype is None:
797
+ seq_mask = atom_mask[:, residue_constants.atom_order["CA"]].cpu()
798
+ aatype = seq_mask * residue_constants.restype_order["G"]
799
+ if residue_index is None:
800
+ residue_index = torch.arange(aatype.shape[-1])
801
+ if chain_index is None:
802
+ chain_index = torch.ones_like(aatype)
803
+ if b_factors is None:
804
+ b_factors = torch.ones_like(atom_mask)
805
+
806
+ cast = lambda x: np.array(x.detach().cpu()) if isinstance(x, torch.Tensor) else x
807
+ prot = protein.Protein(
808
+ atom_positions=cast(atom_positions),
809
+ atom_mask=cast(atom_mask),
810
+ aatype=cast(aatype),
811
+ residue_index=cast(residue_index),
812
+ chain_index=cast(chain_index),
813
+ b_factors=cast(b_factors),
814
+ )
815
+ pdb_str = protein.to_pdb(prot, conect=conect)
816
+ if conect:
817
+ pdb_str, conect_str = pdb_str
818
+ if atom_lines_only:
819
+ pdb_lines = pdb_str.split("\n")
820
+ atom_lines = [
821
+ l for l in pdb_lines if len(l.split()) > 1 and l.split()[0] == "ATOM"
822
+ ]
823
+ pdb_str = "\n".join(atom_lines) + "\n"
824
+ if conect:
825
+ pdb_str = pdb_str + conect_str
826
+ return pdb_str
827
+
828
+
829
+ def bb_coords_to_pdb_str(coords, atoms=["N", "CA", "C", "O"]):
830
+ def _bb_pdb_line(atom, atomnum, resnum, coords, elem, res="GLY"):
831
+ atm = "ATOM".ljust(6)
832
+ atomnum = str(atomnum).rjust(5)
833
+ atomname = atom.center(4)
834
+ resname = res.ljust(3)
835
+ chain = "A".rjust(1)
836
+ resnum = str(resnum).rjust(4)
837
+ x = str("%8.3f" % (float(coords[0]))).rjust(8)
838
+ y = str("%8.3f" % (float(coords[1]))).rjust(8)
839
+ z = str("%8.3f" % (float(coords[2]))).rjust(8)
840
+ occ = str("%6.2f" % (float(1))).rjust(6)
841
+ temp = str("%6.2f" % (float(20))).ljust(6)
842
+ elname = elem.rjust(12)
843
+ return "%s%s %s %s %s%s %s%s%s%s%s%s\n" % (
844
+ atm,
845
+ atomnum,
846
+ atomname,
847
+ resname,
848
+ chain,
849
+ resnum,
850
+ x,
851
+ y,
852
+ z,
853
+ occ,
854
+ temp,
855
+ elname,
856
+ )
857
+
858
+ n = coords.shape[0]
859
+ na = len(atoms)
860
+ pdb_str = ""
861
+ for j in range(0, n, na):
862
+ for idx, atom in enumerate(atoms):
863
+ pdb_str += _bb_pdb_line(
864
+ atom,
865
+ j + idx + 1,
866
+ (j + na) // na,
867
+ coords[j + idx],
868
+ atom[0],
869
+ )
870
+ return pdb_str
871
+
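# Editor's sketch (not part of the original source): rendering two glycine
# residues with the fixed-width writer above; the coordinates are dummies.
def _example_bb_pdb_str():
    coords = torch.zeros(8, 3)  # 2 residues x 4 backbone atoms, pre-flattened
    return bb_coords_to_pdb_str(coords)  # eight "ATOM" lines, residues 1 and 2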
872
+
873
+ def write_coords_to_pdb(
874
+ coords_in,
875
+ filename,
876
+ batched=True,
877
+ write_to_frames=False,
878
+ conect=False,
879
+ **all_atom_feats,
880
+ ):
881
+ def _write_pdb_string(pdb_str, filename, append=False):
882
+ write_mode = "a" if append else "w"
883
+ with open(filename, write_mode) as f:
884
+ if write_to_frames:
885
+ f.write("MODEL\n")
886
+ f.write(pdb_str)
887
+ if write_to_frames:
888
+ f.write("ENDMDL\n")
889
+
890
+ if not (batched or write_to_frames):
891
+ coords_in = [coords_in]
892
+ filename = [filename]
893
+ all_atom_feats = {k: [v] for k, v in all_atom_feats.items()}
894
+
895
+ n_atoms_in = coords_in[0].shape[-2]
896
+ is_bb_or_ca_pdb = n_atoms_in <= 4
897
+ for i, c in enumerate(coords_in):
898
+ n_res = c.shape[0]
899
+ if isinstance(filename, list):
900
+ fname = filename[i]
901
+ elif write_to_frames or len(coords_in) == 1:
902
+ fname = filename
903
+ else:
904
+ fname = f"{filename[:-4]}_{i}.pdb"
905
+
906
+ if is_bb_or_ca_pdb:
907
+ c_flat = rearrange(c, "n a c -> (n a) c")
908
+ if n_atoms_in == 1:
909
+ atoms = ["CA"]
910
+ if n_atoms_in == 3:
911
+ atoms = ["N", "CA", "C"]
912
+ if n_atoms_in == 4:
913
+ atoms = ["N", "CA", "C", "O"]
914
+ pdb_str = bb_coords_to_pdb_str(c_flat, atoms)
915
+ else:
916
+ feats_i = {k: v[i][:n_res] for k, v in all_atom_feats.items()}
917
+ pdb_str = feats_to_pdb_str(c, conect=conect, **feats_i)
918
+ _write_pdb_string(pdb_str, fname, append=write_to_frames and i > 0)
919
+
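# Editor's sketch (not part of the original source): writing a single
# 10-residue backbone; N/CA/C/O atom names are inferred from the atom dimension.
def _example_write_backbone_pdb():
    coords = torch.randn(10, 4, 3)
    write_coords_to_pdb(coords, "example_bb.pdb", batched=False)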
920
+
921
+ ###################### LOSSES ###################################
922
+
923
+
924
+ def masked_cross_entropy(logprobs, target, loss_mask):
925
+ # target is onehot
926
+ cel = -(target * logprobs)
927
+ cel = cel * loss_mask[..., None]
928
+ cel = cel.sum((-1, -2)) / loss_mask.sum(-1).clamp(min=1e-6)
929
+ return cel
930
+
931
+
932
+ def masked_mse(x, y, mask, weight=None):
933
+ data_dims = tuple(range(1, len(x.shape)))
934
+ mse = (x - y).pow(2) * mask
935
+ if weight is not None:
936
+ mse = mse * expand(weight, mse)
937
+ mse = mse.sum(data_dims) / mask.sum(data_dims).clamp(min=1e-6)
938
+ return mse
939
+
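# Editor's sketch (not part of the original source): the masked losses above
# divide by the mask sum, so padded positions cannot dilute the loss.
def _example_masked_mse():
    x, y = torch.randn(2, 5, 3), torch.randn(2, 5, 3)
    mask = torch.tensor([[1.0, 1.0, 1.0, 0.0, 0.0]] * 2)[..., None]
    return masked_mse(x, y, mask)  # shape (2,): channel sums averaged over the 3 valid rows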
940
+
941
+ ###################### ALIGN ###################################
942
+
943
+
944
+ def quick_tmalign(
945
+ p, p_sele, q_sele, tmscore_type="avg", differentiable_rmsd=False, rmsd_type="ca"
946
+ ):
947
+ # sota 210712 (author's shorthand, likely "state of the art, 2021-07-12")
948
+ write_coords_to_pdb(p_sele[:, 1:2], "temp_p.pdb", atoms=["CA"], batched=False)
949
+ write_coords_to_pdb(q_sele[:, 1:2], "temp_q.pdb", atoms=["CA"], batched=False)
950
+ cmd = f"{PATH_TO_TMALIGN} temp_p.pdb temp_q.pdb -m temp_matrix.txt"
951
+ outputs = subprocess.run(shlex.split(cmd), capture_output=True, text=True)
952
+
953
+ # Get RMSD and TM scores
954
+ tmout = outputs.stdout.split("\n")
955
+ rmsd = float(tmout[16].split()[4][:-1])
956
+ tmscore1 = float(tmout[17].split()[1])
957
+ tmscore2 = float(tmout[18].split()[1])
958
+ if tmscore_type == "avg":
959
+ tmscore = (tmscore1 + tmscore2) / 2
960
+ elif tmscore_type == "1" or tmscore_type == "query":
961
+ tmscore = tmscore1
962
+ elif tmscore_type == "2":
963
+ tmscore = tmscore2
964
+ elif tmscore_type == "both":
965
+ tmscore = (tmscore1, tmscore2)
966
+
967
+ # Get R, t and transform p coords
968
+ m = open("temp_matrix.txt", "r").readlines()[2:5]
969
+ m = [l.strip()[1:].strip() for l in m]
970
+ m = torch.Tensor([[float(i) for i in l.split()] for l in m]).to(p_sele.device)
971
+ R = m[:, 1:].t()
972
+ t = m[:, 0]
973
+ aligned_psele = p_sele @ R + t
974
+ aligned = p @ R + t
975
+
976
+ # Option 2 for rms - MSE of aligned against target coords using TMalign seq alignment. Differentiable
977
+ if differentiable_rmsd:
978
+ pi, qi = 0, 0
979
+ p_idxs, q_idxs = [], []
980
+ for i, c in enumerate(tmout[23]):
981
+ if c in [":", "."]:
982
+ p_idxs.append(pi)
983
+ q_idxs.append(qi)
984
+ if tmout[22][i] != "-":
985
+ pi += 1
986
+ if tmout[24][i] != "-":
987
+ qi += 1
988
+ tmalign_seq_p = p_sele[p_idxs]
989
+ tmalign_seq_q = q_sele[q_idxs]
990
+ if rmsd_type == "ca":
991
+ tmalign_seq_p = tmalign_seq_p[:, 1]
992
+ tmalign_seq_q = tmalign_seq_q[:, 1]
993
+ elif rmsd_type == "bb":
994
+ pass
995
+ rmsd = (tmalign_seq_p - tmalign_seq_q).pow(2).sum(-1).sqrt().mean()
996
+
997
+ # Delete temp files: temp_p.pdb, temp_q.pdb, temp_matrix.txt
998
+ subprocess.run(shlex.split("rm temp_p.pdb"))
999
+ subprocess.run(shlex.split("rm temp_q.pdb"))
1000
+ subprocess.run(shlex.split("rm temp_matrix.txt"))
1001
+
1002
+ return {"aligned": aligned, "rmsd": rmsd, "tm_score": tmscore, "R": R, "t": t}
1003
+
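# Editor's sketch (not part of the original source): aligning two (n, a, 3)
# structures; assumes the module-level PATH_TO_TMALIGN points to a TM-align
# binary and that atom index 1 is CA, as quick_tmalign expects.
def _example_quick_tmalign(p, q):
    out = quick_tmalign(p, p_sele=p, q_sele=q, tmscore_type="avg")
    return out["tm_score"], out["aligned"]  # score in [0, 1]; p mapped into q's frame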
1004
+
1005
+ ###################### OTHER ###################################
1006
+
1007
+
1008
+ def expand(x, tgt=None, dim=1):
1009
+ if tgt is None:
1010
+ for _ in range(dim):
1011
+ x = x[..., None]
1012
+ else:
1013
+ while len(x.shape) < len(tgt.shape):
1014
+ x = x[..., None]
1015
+ return x
1016
+
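# Editor's sketch (not part of the original source): expand() right-pads
# singleton dims so a per-example scalar broadcasts over higher-rank tensors.
def _example_expand():
    noise = torch.rand(8)  # one noise level per batch element
    coords = torch.randn(8, 64, 37, 3)
    return coords * expand(noise, coords)  # noise viewed as (8, 1, 1, 1)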
1017
+
1018
+ def hookfn(name, verbose=False):
1019
+ def f(grad):
1020
+ if check_nan_inf(grad) > 0:
1021
+ print(name, "grad nan/infs", grad.shape, check_nan_inf(grad), grad)
1022
+ if verbose:
1023
+ print(name, "grad shape", grad.shape, "norm", grad.norm())
1024
+
1025
+ return f
1026
+
1027
+
1028
+ def trigger_nan_check(name, x):
1029
+ if check_nan_inf(x) > 0:
1030
+ print(name, check_nan_inf(x))
1031
+ raise Exception
1032
+
1033
+
1034
+ def check_nan_inf(x):
1035
+ return torch.isinf(x).sum() + torch.isnan(x).sum()
1036
+
1037
+
1038
+ def directory_find(atom, root="."):
1039
+ for path, dirs, files in os.walk(root):
1040
+ if atom in dirs:
1041
+ return os.path.join(path, atom)
1042
+
1043
+
1044
+ def dict2namespace(config):
1045
+ namespace = argparse.Namespace()
1046
+ for key, value in config.items():
1047
+ if isinstance(value, dict):
1048
+ new_value = dict2namespace(value)
1049
+ else:
1050
+ new_value = value
1051
+ setattr(namespace, key, new_value)
1052
+ return namespace
1053
+
1054
+
1055
+ def load_config(path, return_dict=False):
1056
+ with open(path, "r") as f:
1057
+ config_dict = yaml.safe_load(f)
1058
+ config = dict2namespace(config_dict)
1059
+ if return_dict:
1060
+ return config, config_dict
1061
+ else:
1062
+ return config
diffusion.py ADDED
@@ -0,0 +1,66 @@
1
+ """
2
+ https://github.com/ProteinDesignLab/protpardelle
3
+ License: MIT
4
+ Author: Alex Chu
5
+
6
+ Noise and diffusion utils.
7
+ """
8
+ from scipy.stats import norm
9
+ import torch
10
+ from torchtyping import TensorType
11
+
12
+ from core import utils
13
+
14
+
15
+ def noise_schedule(
16
+ time: TensorType[float],
17
+ function: str = "uniform",
18
+ sigma_data: float = 10.0,
19
+ psigma_mean: float = -1.2,
20
+ psigma_std: float = 1.2,
21
+ s_min: float = 0.001,
22
+ s_max: float = 60,
23
+ rho: float = 7.0,
24
+ time_power: float = 4.0,
25
+ constant_val: float = 0.0,
26
+ ):
27
+ def sampling_noise(time):
28
+ # high noise = 1; low noise = 0. opposite of Karras et al. schedule
29
+ term1 = s_max ** (1 / rho)
30
+ term2 = (1 - time) * (s_min ** (1 / rho) - s_max ** (1 / rho))
31
+ noise_level = sigma_data * ((term1 + term2) ** rho)
32
+ return noise_level
33
+
34
+ if function == "lognormal":
35
+ normal_sample = torch.Tensor(norm.ppf(time.cpu())).to(time)
36
+ noise_level = sigma_data * torch.exp(psigma_mean + psigma_std * normal_sample)
37
+ elif function == "uniform":
38
+ noise_level = sampling_noise(time)
39
+ elif function == "mpnn":
40
+ time = time**time_power
41
+ noise_level = sampling_noise(time)
42
+ elif function == "constant":
43
+ noise_level = torch.ones_like(time) * constant_val
44
+ return noise_level
45
+
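# Editor's sketch (not part of the original source): with the defaults above,
# the "uniform" schedule maps time=1 to sigma_data * s_max (maximum noise) and
# time=0 to sigma_data * s_min (minimum noise).
def _example_schedule_endpoints():
    hi = noise_schedule(torch.tensor([1.0]))  # ~ 10.0 * 60 = 600
    lo = noise_schedule(torch.tensor([0.0]))  # ~ 10.0 * 0.001 = 0.01
    return hi, lo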
46
+
47
+ def noise_coords(
48
+ coords: TensorType["b n a x", float],
49
+ noise_level: TensorType["b", float],
50
+ dummy_fill_masked_atoms: bool = False,
51
+ atom_mask: TensorType["b n a"] = None,
52
+ ):
53
+ # Does not apply atom mask after adding noise
54
+ if dummy_fill_masked_atoms:
55
+ assert atom_mask is not None
56
+ dummy_fill_mask = 1 - atom_mask
57
+ dummy_fill_value = coords[..., 1:2, :] # CA
58
+ # dummy_fill_value = utils.fill_in_cbeta_for_atom37(coords)[..., 3:4, :] # CB
59
+ coords = (
60
+ coords * atom_mask[..., None]
61
+ + dummy_fill_value * dummy_fill_mask[..., None]
62
+ )
63
+
64
+ noise = torch.randn_like(coords) * utils.expand(noise_level, coords)
65
+ noisy_coords = coords + noise
66
+ return noisy_coords
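# Editor's sketch (not part of the original source): perturbing a batch of
# coordinates at per-example noise levels; the second example is 10x noisier.
def _example_noise_coords():
    coords = torch.randn(2, 16, 37, 3)
    return noise_coords(coords, torch.tensor([1.0, 10.0]))  # same shape as input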
draw_samples.py ADDED
@@ -0,0 +1,353 @@
1
+ """
2
+ https://github.com/ProteinDesignLab/protpardelle
3
+ License: MIT
4
+ Author: Alex Chu
5
+
6
+ Entry point for unconditional or simple conditional sampling.
7
+ """
8
+ import argparse
9
+ from datetime import datetime
10
+ import json
11
+ import os
12
+ import shlex
13
+ import subprocess
14
+ import sys
15
+ import time
16
+
17
+ from einops import repeat
18
+ import torch
19
+
20
+ from core import data
21
+ from core import residue_constants
22
+ from core import utils
23
+ import diffusion
24
+ import models
25
+ import sampling
26
+
27
+
28
+ def draw_and_save_samples(
29
+ model,
30
+ samples_per_len=8,
31
+ lengths=range(50, 512),
32
+ save_dir="./",
33
+ mode="backbone",
34
+ **sampling_kwargs,
35
+ ):
36
+ device = model.device
37
+ if mode == "backbone":
38
+ total_sampling_time = 0
39
+ for l in lengths:
40
+ prot_lens = torch.ones(samples_per_len).long() * l
41
+ seq_mask = model.make_seq_mask_for_sampling(prot_lens=prot_lens)
42
+ aux = sampling.draw_backbone_samples(
43
+ model,
44
+ seq_mask=seq_mask,
45
+ pdb_save_path=f"{save_dir}/len{format(l, '03d')}_samp",
46
+ return_aux=True,
47
+ return_sampling_runtime=True,
48
+ **sampling_kwargs,
49
+ )
50
+ total_sampling_time += aux["runtime"]
51
+ print("Samples drawn for length", l)
52
+ return total_sampling_time
53
+ elif mode == "allatom":
54
+ total_sampling_time = 0
55
+ for l in lengths:
56
+ prot_lens = torch.ones(samples_per_len).long() * l
57
+ seq_mask = model.make_seq_mask_for_sampling(prot_lens=prot_lens)
58
+ aux = sampling.draw_allatom_samples(
59
+ model,
60
+ seq_mask=seq_mask,
61
+ pdb_save_path=f"{save_dir}/len{format(l, '03d')}",
62
+ return_aux=True,
63
+ **sampling_kwargs,
64
+ )
65
+ total_sampling_time += aux["runtime"]
66
+ print("Samples drawn for length", l)
67
+ return total_sampling_time
68
+
69
+
70
+ def parse_idx_string(idx_str):
71
+ spans = idx_str.split(",")
72
+ idxs = []
73
+ for s in spans:
74
+ if "-" in s:
75
+ start, stop = s.split("-")
76
+ idxs.extend(list(range(int(start), int(stop))))
77
+ else:
78
+ idxs.append(int(s))
79
+ return idxs
80
+
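# Editor's sketch (not part of the original source). Note the dash form is
# half-open: the stop index is excluded, so "2-5" expands to 2, 3, 4.
def _example_parse_idx_string():
    return parse_idx_string("0,2-5,7")  # -> [0, 2, 3, 4, 7]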
81
+
82
+ class Manager(object):
83
+ def __init__(self):
84
+ self.parser = argparse.ArgumentParser(
85
+ formatter_class=argparse.RawTextHelpFormatter
86
+ )
87
+
88
+ self.parser.add_argument(
89
+ "--model_checkpoint",
90
+ type=str,
91
+ default="checkpoints",
92
+ help="Path to denoiser model weights and config",
93
+ )
94
+ self.parser.add_argument(
95
+ "--mpnnpath",
96
+ type=str,
97
+ default="checkpoints/minimpnn_state_dict.pth",
98
+ help="Path to minimpnn model weights",
99
+ )
100
+ self.parser.add_argument(
101
+ "--modeldir",
102
+ type=str,
103
+ help="Model base directory, ex 'training_logs/other/lemon-shape-51'",
104
+ )
105
+ self.parser.add_argument("--modelepoch", type=int, help="Model epoch, ex 1000")
106
+ self.parser.add_argument(
107
+ "--type", type=str, default="allatom", help="Type of model"
108
+ )
109
+ self.parser.add_argument(
110
+ "--param", type=str, default=None, help="Which sampling param to vary"
111
+ )
112
+ self.parser.add_argument(
113
+ "--paramval", type=str, default=None, help="Which param val to use"
114
+ )
115
+ self.parser.add_argument(
116
+ "--parampath",
117
+ type=str,
118
+ default=None,
119
+ help="Path to json file with params, either use param/paramval or parampath, not both",
120
+ )
121
+ self.parser.add_argument(
122
+ "--perlen", type=int, default=2, help="How many samples per sequence length"
123
+ )
124
+ self.parser.add_argument(
125
+ "--minlen", type=int, required=False, help="Minimum sequence length"
126
+ )
127
+ self.parser.add_argument(
128
+ "--maxlen",
129
+ type=int,
130
+ required=False,
131
+ help="Maximum sequence length, not inclusive",
132
+ )
133
+ self.parser.add_argument(
134
+ "--steplen",
135
+ type=int,
136
+ required=False,
137
+ help="How frequently to select sequence length, for steplen 2, would be 50, 52, 54, etc",
138
+ )
139
+ self.parser.add_argument(
140
+ "--num_lens",
141
+ type=int,
142
+ required=False,
143
+ help="If steplen not provided, how many random lengths to sample at",
144
+ )
145
+ self.parser.add_argument(
146
+ "--targetdir", type=str, default=".", help="Directory to save results"
147
+ )
148
+ self.parser.add_argument(
149
+ "--input_pdb", type=str, required=False, help="PDB file to condition on"
150
+ )
151
+ self.parser.add_argument(
152
+ "--resample_idxs",
153
+ type=str,
154
+ required=False,
155
+ help="Indices from PDB file to resample. Zero-indexed, comma-delimited, can use dashes, eg 0,2-5,7",
156
+ )
157
+
158
+ def add_argument(self, *args, **kwargs):
159
+ self.parser.add_argument(*args, **kwargs)
160
+
161
+ def parse_args(self):
162
+ self.args = self.parser.parse_args()
163
+
164
+ return self.args
165
+
166
+
167
+ def main():
168
+ # Set up params, arguments, sampling config
169
+ ####################
170
+ manager = Manager()
171
+ manager.parse_args()
172
+ args = manager.args
173
+ print(args)
174
+ is_test_run = False
175
+ seed = 0
176
+ samples_per_len = args.perlen
177
+ min_len = args.minlen
178
+ max_len = args.maxlen
179
+ len_step_size = args.steplen
180
+ device = "cuda:0"
181
+
182
+ # setting default sampling config
183
+ if args.type == "backbone":
184
+ sampling_config = sampling.default_backbone_sampling_config()
185
+ elif args.type == "allatom":
186
+ sampling_config = sampling.default_allatom_sampling_config()
187
+
188
+ sampling_kwargs = vars(sampling_config)
189
+
190
+ # Parse conditioning inputs
191
+ input_pdb_len = None
192
+ if args.input_pdb:
193
+ input_feats = utils.load_feats_from_pdb(args.input_pdb, protein_only=True)
194
+ input_pdb_len = input_feats["aatype"].shape[0]
195
+ if args.resample_idxs:
196
+ print(
197
+ f"Warning: when sampling conditionally, the input pdb length ({input_pdb_len} residues) is used automatically for the sampling lengths."
198
+ )
199
+ resample_idxs = parse_idx_string(args.resample_idxs)
200
+ else:
201
+ resample_idxs = list(range(input_pdb_len))
202
+ cond_idxs = [i for i in range(input_pdb_len) if i not in resample_idxs]
203
+ to_batch_size = lambda x: repeat(x, "... -> b ...", b=samples_per_len).to(
204
+ device
205
+ )
206
+
207
+ # For unconditional model, center coords on whole structure
208
+ centered_coords = data.apply_random_se3(
209
+ input_feats["atom_positions"],
210
+ atom_mask=input_feats["atom_mask"],
211
+ translation_scale=0.0,
212
+ )
213
+ cond_kwargs = {}
214
+ cond_kwargs["gt_coords"] = to_batch_size(centered_coords)
215
+ cond_kwargs["gt_cond_atom_mask"] = to_batch_size(input_feats["atom_mask"])
216
+ cond_kwargs["gt_cond_atom_mask"][:, resample_idxs] = 0
217
+ cond_kwargs["gt_aatype"] = to_batch_size(input_feats["aatype"])
218
+ cond_kwargs["gt_cond_seq_mask"] = torch.zeros_like(cond_kwargs["gt_aatype"])
219
+ cond_kwargs["gt_cond_seq_mask"][:, cond_idxs] = 1
220
+ sampling_kwargs.update(cond_kwargs)
221
+
222
+ # Determine lengths to sample at
223
+ if min_len is not None and max_len is not None:
224
+ if len_step_size is not None:
225
+ sampling_lengths = range(min_len, max_len, len_step_size)
226
+ else:
227
+ sampling_lengths = list(
228
+ torch.randint(min_len, max_len, size=(args.num_lens,))
229
+ )
230
+ elif input_pdb_len is not None:
231
+ sampling_lengths = [input_pdb_len]
232
+ else:
233
+ raise Exception("Need to provide a set of protein lengths or an input pdb.")
234
+
235
+ total_num_samples = len(list(sampling_lengths)) * samples_per_len
236
+
237
+ model_directory = args.modeldir
238
+ epoch = args.modelepoch
239
+ base_dir = args.targetdir
240
+
241
+ date_string = datetime.now().strftime("%y-%m-%d-%H-%M-%S")
242
+ if is_test_run:
243
+ date_string = f"test-{date_string}"
244
+
245
+ # Update sampling config with arguments
246
+ if args.param:
247
+ var_param = args.param
248
+ var_value = args.paramval
249
+ sampling_kwargs[var_param] = (
250
+ None
251
+ if var_value == "None"
252
+ else int(var_value)
253
+ if var_param == "n_steps"
254
+ else float(var_value)
255
+ )
256
+ elif args.parampath:
257
+ with open(args.parampath) as f:
258
+ var_params = json.loads(f.read())
259
+ sampling_kwargs.update(var_params)
260
+
261
+ # This is only used for the readme; keep s_min and s_max as params instead of struct_noise_schedule.
262
+ sampling_kwargs_readme = list(sampling_kwargs.items())
263
+
264
+ print("Base directory:", base_dir)
265
+ save_dir = f"{base_dir}/samples"
266
+ save_init_dir = f"{base_dir}/samples_inits"
267
+
268
+ print("Samples saved to:", save_dir)
269
+ ####################
270
+
271
+ torch.manual_seed(seed)
272
+ if not os.path.exists(save_dir):
273
+ subprocess.run(shlex.split(f"mkdir -p {save_dir}"))
274
+
275
+ if not os.path.exists(save_init_dir):
276
+ subprocess.run(shlex.split(f"mkdir -p {save_init_dir}"))
277
+
278
+ # Load model
279
+ if args.type == "backbone":
280
+ if args.model_checkpoint:
281
+ checkpoint = f"{args.model_checkpoint}/backbone_state_dict.pth"
282
+ cfg_path = f"{args.model_checkpoint}/backbone.yml"
283
+ else:
284
+ checkpoint = (
285
+ f"{model_directory}/checkpoints/epoch{epoch}_training_state.pth"
286
+ )
287
+ cfg_path = f"{model_directory}/configs/backbone.yml"
288
+ cfg = utils.load_config(cfg_path)
289
+ weights = torch.load(checkpoint, map_location=device)["model_state_dict"]
290
+ model = models.Protpardelle(cfg, device=device)
291
+ model.load_state_dict(weights)
292
+ model.to(device)
293
+ model.eval()
294
+ model.device = device
295
+ elif args.type == "allatom":
296
+ if args.model_checkpoint:
297
+ checkpoint = f"{args.model_checkpoint}/allatom_state_dict.pth"
298
+ cfg_path = f"{args.model_checkpoint}/allatom.yml"
299
+ else:
300
+ checkpoint = (
301
+ f"{model_directory}/checkpoints/epoch{epoch}_training_state.pth"
302
+ )
303
+ cfg_path = f"{model_directory}/configs/allatom.yml"
304
+ config = utils.load_config(cfg_path)
305
+ weights = torch.load(checkpoint, map_location=device)["model_state_dict"]
306
+ model = models.Protpardelle(config, device=device)
307
+ model.load_state_dict(weights)
308
+ model.load_minimpnn(args.mpnnpath)
309
+ model.to(device)
310
+ model.eval()
311
+ model.device = device
312
+
313
+ # Sampling
314
+ with open(base_dir + "/readme.txt", "w") as f:
315
+ f.write(f"Sampling run for {date_string}\n")
316
+ f.write(f"Random seed {seed}\n")
317
+ f.write(f"Model checkpoint: {checkpoint}\n")
318
+ f.write(
319
+ f"{samples_per_len} samples per length from {min_len}:{max_len}:{len_step_size}\n"
320
+ )
321
+ f.write("Sampling params:\n")
322
+ for k, v in sampling_kwargs_readme:
323
+ f.write(f"{k}\t{v}\n")
324
+
325
+ print(f"Model loaded from {checkpoint}")
326
+ print(f"Beginning sampling for {date_string}...")
327
+
328
+ # Draw samples
329
+ start_time = time.time()
330
+ sampling_time = draw_and_save_samples(
331
+ model,
332
+ samples_per_len=samples_per_len,
333
+ lengths=sampling_lengths,
334
+ save_dir=save_dir,
335
+ mode=args.type,
336
+ **sampling_kwargs,
337
+ )
338
+ time_elapsed = time.time() - start_time
339
+
340
+ print(f"Sampling concluded after {time_elapsed} seconds.")
341
+ print(f"Of this, {sampling_time} seconds were for actual sampling.")
342
+ print(f"{total_num_samples} total samples were drawn.")
343
+
344
+ with open(base_dir + "/readme.txt", "a") as f:
345
+ f.write(f"Total job time: {time_elapsed} seconds\n")
346
+ f.write(f"Model run time: {sampling_time} seconds\n")
347
+ f.write(f"Total samples drawn: {total_num_samples}\n")
348
+
349
+ return
350
+
351
+
352
+ if __name__ == "__main__":
353
+ main()
evaluation.py ADDED
@@ -0,0 +1,406 @@
1
+ """
2
+ https://github.com/ProteinDesignLab/protpardelle
3
+ License: MIT
4
+ Author: Alex Chu
5
+
6
+ Utils for computing evaluation metrics.
7
+ """
8
+ import argparse
9
+ import os
10
+ import warnings
11
+ from typing import Tuple
12
+
13
+ from Bio.Align import substitution_matrices
14
+ import numpy as np
15
+ import torch
16
+ from transformers import AutoTokenizer, EsmForProteinFolding
17
+ from torchtyping import TensorType
18
+
19
+ from core import residue_constants
20
+ from core import utils
21
+ from core import protein_mpnn as mpnn
22
+ import modules
23
+ import sampling
24
+
25
+
26
+ def mean(x):
27
+ if len(x) == 0:
28
+ return 0
29
+ return sum(x) / len(x)
30
+
31
+
32
+ def calculate_seq_identity(seq1, seq2, seq_mask=None):
33
+ identity = (seq1 == seq2.to(seq1)).float()
34
+ if seq_mask is not None:
35
+ identity *= seq_mask.to(seq1)
36
+ return identity.sum(-1) / seq_mask.to(seq1).sum(-1).clamp(min=1)
37
+ else:
38
+ return identity.mean(-1)
39
+
40
+
41
+ def design_sequence(coords, model=None, num_seqs=1, disallow_aas=["C"]):
42
+ # Returns list of strs; seqs like 'MKRLLDS', not aatypes
43
+ if model is None:
44
+ model = mpnn.get_mpnn_model()
45
+ if isinstance(coords, str):
46
+ temp_pdb = False
47
+ pdb_fn = coords
48
+ else:
49
+ temp_pdb = True
50
+ pdb_fn = f"tmp{np.random.randint(0, 1e8)}.pdb"
51
+ gly_idx = residue_constants.restype_order["G"]
52
+ gly_aatype = (torch.ones(coords.shape[0]) * gly_idx).long()
53
+ utils.write_coords_to_pdb(coords, pdb_fn, batched=False, aatype=gly_aatype)
54
+
55
+ with torch.no_grad():
56
+ designed_seqs = mpnn.run_proteinmpnn(
57
+ model=model,
58
+ pdb_path=pdb_fn,
59
+ num_seq_per_target=num_seqs,
60
+ omit_AAs=disallow_aas,
61
+ )
62
+
63
+ if temp_pdb:
64
+ os.system("rm " + pdb_fn)
65
+ return designed_seqs
66
+
67
+
68
+ def get_esmfold_model(device=None):
69
+ if device is None:
70
+ device = "cuda" if torch.cuda.is_available() else "cpu"
71
+ model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1").to(device)
72
+ model.esm = model.esm.half()
73
+ return model
74
+
75
+
76
+ def inference_esmfold(sequence_list, model, tokenizer):
77
+ inputs = tokenizer(
78
+ sequence_list,
79
+ return_tensors="pt",
80
+ padding=True,
81
+ add_special_tokens=False,
82
+ ).to(model.device)
83
+ outputs = model(**inputs)
84
+ # positions is shape (l, b, n, a, c)
85
+ pred_coords = outputs.positions[-1].contiguous()
86
+ plddts = (outputs.plddt[:, :, 1] * inputs.attention_mask).sum(
87
+ -1
88
+ ) / inputs.attention_mask.sum(-1).clamp(min=1e-3)
89
+ return pred_coords, plddts
90
+
91
+
92
+ def predict_structures(sequences, model="esmfold", tokenizer=None, force_unk_to_X=True):
93
+ # Expects seqs like 'MKRLLDS', not aatypes
94
+ # model can be a model, or a string describing which pred model to load
95
+ if isinstance(sequences, str):
96
+ sequences = [sequences]
97
+ if model == "esmfold":
98
+ model = get_esmfold_model()
99
+ device = model.device
100
+ if tokenizer is None:
101
+ tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
102
+
103
+ aatype = [utils.seq_to_aatype(seq).to(device) for seq in sequences]
104
+
105
+ with torch.no_grad():
106
+ if isinstance(model, EsmForProteinFolding):
107
+ pred_coords, plddts = inference_esmfold(sequences, model, tokenizer)
108
+
109
+ seq_lens = [len(s) for s in sequences]
110
+ trimmed_coords = [c[: seq_lens[i]] for i, c in enumerate(pred_coords)]
111
+ trimmed_coords_atom37 = [
112
+ utils.atom37_coords_from_atom14(c, aatype[i])
113
+ for i, c in enumerate(trimmed_coords)
114
+ ]
115
+ return trimmed_coords_atom37, plddts
116
+
117
+
118
+ def compute_structure_metric(coords1, coords2, metric="ca_rmsd", atom_mask=None):
119
+ # coords1 tensor[l][a][3]
120
+ def _tmscore(a, b, mask=None):
121
+ length = len(b)
122
+ dists = (a - b).pow(2).sum(-1)
123
+ d0 = 1.24 * ((length - 15) ** (1 / 3)) - 1.8
124
+ term = 1 / (1 + ((dists) / (d0**2)))
125
+ if mask is None:
126
+ return term.mean()
127
+ else:
128
+ term = term * mask
129
+ return term.sum() / mask.sum().clamp(min=1)
130
+
131
+ aligned_coords1_ca, (R, t) = utils.kabsch_align(coords1[:, 1], coords2[:, 1])
132
+ aligned_coords1 = coords1 - coords1[:, 1:2].mean(0, keepdim=True)
133
+ aligned_coords1 = aligned_coords1 @ R.t() + t
134
+ if metric == "ca_rmsd":
135
+ return (aligned_coords1_ca - coords2[:, 1]).pow(2).sum(-1).sqrt().mean()
136
+ elif metric == "tm_score":
137
+ tm = _tmscore(aligned_coords1_ca, coords2[:, 1])
138
+ # TODO: return 1 - tm score for now so sorts work properly
139
+ return 1 - tm
140
+ elif metric == "allatom_tm":
141
+ # Align on Ca, compute allatom TM
142
+ assert atom_mask is not None
143
+ return _tmscore(aligned_coords1, coords2, mask=atom_mask)
144
+ elif metric == "allatom_lddt":
145
+ assert atom_mask is not None
146
+ lddt = modules.lddt(
147
+ coords1.reshape(-1, 3),
148
+ coords2.reshape(-1, 3),
149
+ atom_mask.reshape(-1, 1),
150
+ per_residue=False,
151
+ )
152
+ return lddt
153
+ else:
154
+ raise NotImplementedError
155
+
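# Editor's note (not part of the original source): the d0 used in _tmscore
# above is the standard Zhang-Skolnick length-dependent normalization, which
# keeps TM-scores comparable across protein sizes.
def _example_tm_d0(length=100):
    return 1.24 * ((length - 15) ** (1 / 3)) - 1.8  # ~3.65 Angstroms at length 100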
156
+
157
+ def compute_self_consistency(
158
+ comparison_structures, # can be sampled or ground truth
159
+ sampled_sequences=None,
160
+ mpnn_model=None,
161
+ struct_pred_model=None,
162
+ tokenizer=None,
163
+ num_seqs=1,
164
+ return_aux=False,
165
+ metric="ca_rmsd",
166
+ output_file=None,
167
+ ):
168
+ # Typically used for eval of backbone sampling or sequence design or joint sampling
169
+ # (Maybe MPNN) + Fold + TM/RMSD
170
+ # Expects seqs like 'MKRLLDS', not aatypes
171
+ per_sample_primary_metrics = []
172
+ per_sample_secondary_metrics = []
173
+ per_sample_plddts = []
174
+ per_sample_coords = []
175
+ per_sample_seqs = []
176
+ aux = {}
177
+ for i, coords in enumerate(comparison_structures):
178
+ if sampled_sequences is None:
179
+ seqs_to_predict = design_sequence(
180
+ coords, model=mpnn_model, num_seqs=num_seqs
181
+ )
182
+ else:
183
+ seqs_to_predict = sampled_sequences[i]
184
+ pred_coords, plddts = predict_structures(
185
+ seqs_to_predict, model=struct_pred_model, tokenizer=tokenizer
186
+ )
187
+ primary_metric_name = "tm_score" if metric == "tm_score" else "ca_rmsd"
188
+ secondary_metric_name = "tm_score" if metric == "both" else None
189
+ primary_metrics = [
190
+ compute_structure_metric(coords.to(pred), pred, metric=primary_metric_name)
191
+ for pred in pred_coords
192
+ ]
193
+ if secondary_metric_name:
194
+ secondary_metrics = [
195
+ compute_structure_metric(
196
+ coords.to(pred), pred, metric=secondary_metric_name
197
+ )
198
+ for pred in pred_coords
199
+ ]
200
+ aux.setdefault(secondary_metric_name, []).extend(secondary_metrics)
201
+ else:
202
+ secondary_metrics = primary_metrics
203
+
204
+ aux.setdefault("pred", []).extend(pred_coords)
205
+ seqs_to_predict_arr = seqs_to_predict
206
+ if isinstance(seqs_to_predict_arr, str):
207
+ seqs_to_predict_arr = [seqs_to_predict_arr]
208
+
209
+ aux.setdefault("seqs", []).extend(seqs_to_predict_arr)
210
+ aux.setdefault("plddt", []).extend(plddts)
211
+ aux.setdefault("rmsd", []).extend(primary_metrics)
212
+
213
+ # Report best rmsd design only among MPNN reps
214
+ all_designs = [
215
+ (m, p, t, c, s)
216
+ for m, p, t, c, s in zip(
217
+ primary_metrics,
218
+ plddts,
219
+ secondary_metrics,
220
+ pred_coords,
221
+ seqs_to_predict_arr,
222
+ )
223
+ ]
224
+ best_rmsd_design = min(all_designs, key=lambda x: x[0])
225
+ per_sample_primary_metrics.append(best_rmsd_design[0].detach().cpu())
226
+ per_sample_plddts.append(best_rmsd_design[1].detach().cpu())
227
+ per_sample_secondary_metrics.append(best_rmsd_design[2].detach().cpu())
228
+ per_sample_coords.append(best_rmsd_design[3])
229
+ per_sample_seqs.append(best_rmsd_design[4])
230
+ best_idx = np.argmin(per_sample_primary_metrics)
231
+ metrics = {
232
+ "sc_rmsd_best": per_sample_primary_metrics[best_idx],
233
+ "sc_plddt_best": per_sample_plddts[best_idx],
234
+ "sc_rmsd_mean": mean(per_sample_primary_metrics),
235
+ "sc_plddt_mean": mean(per_sample_plddts),
236
+ }
237
+ if metric == "both":
238
+ metrics["sc_tmscore_best"] = per_sample_secondary_metrics[best_idx]
239
+ metrics["sc_tmscore_mean"] = mean(per_sample_secondary_metrics)
240
+
241
+ if output_file:
242
+ pred_coords = per_sample_coords
243
+ designed_seqs = per_sample_seqs
244
+
245
+ if torch.isnan(pred_coords[best_idx]).sum() == 0:
246
+ designed_seq = utils.seq_to_aatype(designed_seqs[best_idx])
247
+ utils.write_coords_to_pdb(
248
+ pred_coords[best_idx],
249
+ output_file,
250
+ batched=False,
251
+ aatype=designed_seq,
252
+ )
253
+
254
+ if return_aux:
255
+ return metrics, best_idx, aux
256
+ else:
257
+ return metrics, best_idx
258
+
259
+
260
+ def compute_secondary_structure_content(coords_batch):
261
+ dssp_sample = []
262
+ for i, c in enumerate(coords_batch):
263
+ with warnings.catch_warnings():
264
+ warnings.simplefilter("ignore")
265
+ dssp_str = utils.get_3state_dssp(coords=c)
266
+ if dssp_str is None or len(dssp_str) == 0:
267
+ pass
268
+ else:
269
+ dssp_sample.append(dssp_str)
270
+ dssp_sample = "".join(dssp_sample)
271
+ metrics = {}
272
+ metrics["sample_pct_beta"] = mean([c == "E" for c in dssp_sample])
273
+ metrics["sample_pct_alpha"] = mean([c == "H" for c in dssp_sample])
274
+ return metrics
275
+
276
+
277
+ def compute_bond_length_metric(
278
+ cropped_coords_list, cropped_aatypes_list, atom_mask=None
279
+ ):
280
+ bond_length_dict = utils.batched_fullatom_bond_lengths_from_coords(
281
+ cropped_coords_list, cropped_aatypes_list, atom_mask=atom_mask
282
+ )
283
+ all_errors = {}
284
+ for aa1, d in bond_length_dict.items():
285
+ aa3 = residue_constants.restype_1to3[aa1]
286
+ per_bond_errors = []
287
+ for bond, lengths in d.items():
288
+ a1, a2 = bond.split("-")
289
+ ideal_val = None
290
+ for ref_bond in residue_constants.standard_residue_bonds[aa3]:  # reference ideal geometry
291
+ if (
292
+ ref_bond.atom1_name == a1
293
+ and ref_bond.atom2_name == a2
294
+ or ref_bond.atom1_name == a2
295
+ and ref_bond.atom2_name == a1
296
+ ):
297
+ ideal_val = ref_bond.length
298
+ break
299
+ error = (np.array(lengths) - ideal_val) ** 2
300
+ per_bond_errors.append(error.mean() ** 0.5)
301
+ if len(per_bond_errors) > 0: # often no Cys
302
+ per_res_errors = np.mean(per_bond_errors)
303
+ all_errors[aa1] = per_res_errors
304
+ return np.mean(list(all_errors.values()))
305
+
306
+
307
+ def evaluate_backbone_generation(
308
+ model,
309
+ n_samples=1,
310
+ mpnn_model=None,
311
+ struct_pred_model=None,
312
+ tokenizer=None,
313
+ sample_length_range=(50, 512),
314
+ ):
315
+ sampling_config = sampling.default_backbone_sampling_config()
316
+ trimmed_coords, seq_mask = sampling.draw_backbone_samples(
317
+ model,
318
+ n_samples=n_samples,
319
+ sample_length_range=sample_length_range,
320
+ **vars(sampling_config),
321
+ )
322
+ sc_metrics, best_idx, aux = compute_self_consistency(
323
+ trimmed_coords,
324
+ mpnn_model=mpnn_model,
325
+ struct_pred_model=struct_pred_model,
326
+ tokenizer=tokenizer,
327
+ return_aux=True,
328
+ )
329
+ dssp_metrics = compute_secondary_structure_content(trimmed_coords)
330
+ all_metrics = {**sc_metrics, **dssp_metrics}
331
+ all_metrics = {f"bb_{k}": v for k, v in all_metrics.items()}
332
+ return all_metrics, (trimmed_coords, seq_mask, best_idx, aux["pred"], aux["seqs"])
333
+
334
+
335
+ def evaluate_allatom_generation(
336
+ model,
337
+ n_samples,
338
+ two_stage_sampling=True,
339
+ struct_pred_model=None,
340
+ tokenizer=None,
341
+ sample_length_range=(50, 512),
342
+ ):
343
+ # Convert allatom model to codesign model by loading miniMPNN
344
+ model.task = "codesign"
345
+ model.load_minimpnn()
346
+ model.eval()
347
+
348
+ sampling_config = sampling.default_allatom_sampling_config()
349
+ ret = sampling.draw_allatom_samples(
350
+ model,
351
+ n_samples=n_samples,
352
+ two_stage_sampling=two_stage_sampling,
353
+ **vars(sampling_config),
354
+ )
355
+ (
356
+ cropped_samp_coords,
357
+ cropped_samp_aatypes,
358
+ samp_atom_mask,
359
+ stage1_coords,
360
+ seq_mask,
361
+ ) = ret
362
+
363
+ # Compute self consistency
364
+ if struct_pred_model is None:
365
+ struct_pred_model = EsmForProteinFolding.from_pretrained(
366
+ "facebook/esmfold_v1"
367
+ ).to(model.device)
368
+ struct_pred_model.esm = struct_pred_model.esm.half()
369
+ if tokenizer is None:
370
+ tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
371
+ designed_seqs = [utils.aatype_to_seq(a) for a in cropped_samp_aatypes]
372
+ sc_metrics, best_idx, sc_aux = compute_self_consistency(
373
+ comparison_structures=cropped_samp_coords,
374
+ sampled_sequences=designed_seqs,
375
+ struct_pred_model=struct_pred_model,
376
+ tokenizer=tokenizer,
377
+ return_aux=True,
378
+ )
379
+ aa_metrics_out = {f"aa_{k}": v for k, v in sc_metrics.items()}
380
+
381
+ # Compute secondary structure content
382
+ cropped_bb_coords = [c[..., [0, 1, 2, 4], :] for c in cropped_samp_coords]
383
+ dssp_metrics = compute_secondary_structure_content(cropped_bb_coords)
384
+ aa_metrics_out = {**aa_metrics_out, **dssp_metrics}
385
+
386
+ # Compute bond length RMSE
387
+ if two_stage_sampling: # compute on original sample
388
+ bond_rmse_coords = stage1_coords
389
+ else:
390
+ bond_rmse_coords = cropped_samp_coords
391
+ bond_rmse = compute_bond_length_metric(
392
+ bond_rmse_coords, cropped_samp_aatypes, samp_atom_mask
393
+ )
394
+ aa_metrics_out["aa_bond_rmse"] = bond_rmse
395
+
396
+ # Convert codesign model back to allatom model and return metrics
397
+ model.task = "allatom"
398
+ model.remove_minimpnn()
399
+ aa_aux_out = (
400
+ cropped_samp_coords,
401
+ cropped_samp_aatypes,
402
+ samp_atom_mask,
403
+ sc_aux["pred"],
404
+ best_idx,
405
+ )
406
+ return aa_metrics_out, aa_aux_out
models.py ADDED
@@ -0,0 +1,778 @@
1
+ """
2
+ https://github.com/ProteinDesignLab/protpardelle
3
+ License: MIT
4
+ Author: Alex Chu
5
+
6
+ Top-level model definitions.
7
+ Typically these are initialized with config rather than arguments.
8
+ """
9
+ import argparse
10
+ from functools import partial
11
+ import os
12
+ from typing import Callable, List, Optional
13
+
14
+ from einops import rearrange, repeat
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from torchtyping import TensorType
20
+
21
+ from core import protein_mpnn
22
+ from core import residue_constants
23
+ from core import utils
24
+ import diffusion
25
+ import evaluation
26
+ import modules
27
+
28
+
29
+ class MiniMPNN(nn.Module):
30
+ """Wrapper for ProteinMPNN network to predict sequence from structure."""
31
+
32
+ def __init__(self, config: argparse.Namespace):
33
+ super().__init__()
34
+ self.config = config
35
+ self.model_config = cfg = config.model.mpnn_model
36
+ self.n_tokens = config.data.n_aatype_tokens
37
+ self.seq_emb_dim = cfg.n_channel
38
+ time_cond_dim = cfg.n_channel * cfg.noise_cond_mult
39
+
40
+ self.noise_block = modules.NoiseConditioningBlock(cfg.n_channel, time_cond_dim)
41
+ self.token_embedding = nn.Linear(self.n_tokens, self.seq_emb_dim)
42
+ self.mpnn_net = modules.NoiseConditionalProteinMPNN(
43
+ n_channel=cfg.n_channel,
44
+ n_layers=cfg.n_layers,
45
+ n_neighbors=cfg.n_neighbors,
46
+ time_cond_dim=time_cond_dim,
47
+ vocab_size=config.data.n_aatype_tokens,
48
+ input_S_is_embeddings=True,
49
+ )
50
+ self.proj_out = nn.Linear(cfg.n_channel, self.n_tokens)
51
+
52
+ def forward(
53
+ self,
54
+ denoised_coords: TensorType["b n a x", float],
55
+ coords_noise_level: TensorType["b", float],
56
+ seq_mask: TensorType["b n", float],
57
+ residue_index: TensorType["b n", int],
58
+ seq_self_cond: Optional[TensorType["b n t", float]] = None, # logprobs
59
+ return_embeddings: bool = False,
60
+ ):
61
+ coords_noise_level_scaled = 0.25 * torch.log(coords_noise_level)
62
+ noise_cond = self.noise_block(coords_noise_level_scaled)
63
+
64
+ b, n, _, _ = denoised_coords.shape
65
+ if seq_self_cond is None or not self.model_config.use_self_conditioning:
66
+ seq_emb_in = torch.zeros(b, n, self.seq_emb_dim).to(denoised_coords)
67
+ else:
68
+ seq_emb_in = self.token_embedding(seq_self_cond.exp())
69
+
70
+ node_embs, encoder_embs = self.mpnn_net(
71
+ denoised_coords, seq_emb_in, seq_mask, residue_index, noise_cond
72
+ )
73
+
74
+ logits = self.proj_out(node_embs)
75
+ pred_logprobs = F.log_softmax(logits, -1)
76
+
77
+ if return_embeddings:
78
+ return pred_logprobs, node_embs, encoder_embs
79
+ return pred_logprobs
80
+
81
+
82
+ class CoordinateDenoiser(nn.Module):
83
+ """Wrapper for U-ViT module to denoise structure coordinates."""
84
+
85
+ def __init__(self, config: argparse.Namespace):
86
+ super().__init__()
87
+ self.config = config
88
+
89
+ # Configuration
90
+ self.sigma_data = config.data.sigma_data
91
+ m_cfg = config.model.struct_model
92
+ nc = m_cfg.n_channel
93
+ bb_atoms = ["N", "CA", "C", "O"]
94
+ n_atoms = config.model.struct_model.n_atoms
95
+ self.use_conv = len(m_cfg.uvit.n_filt_per_layer) > 0
96
+ if self.use_conv and n_atoms == 37:
97
+ n_atoms += 1 # make it an even number
98
+ self.n_atoms = n_atoms
99
+ self.bb_idxs = [residue_constants.atom_order[a] for a in bb_atoms]
100
+ n_xyz = 9 if config.model.crop_conditional else 6
101
+ nc_in = n_xyz * n_atoms # xyz + selfcond xyz + maybe cropcond xyz
102
+
103
+ # Neural networks
104
+ n_noise_channel = nc * m_cfg.noise_cond_mult
105
+ self.net = modules.TimeCondUViT(
106
+ seq_len=config.data.fixed_size,
107
+ patch_size=m_cfg.uvit.patch_size,
108
+ dim=nc,
109
+ depth=m_cfg.uvit.n_layers,
110
+ n_filt_per_layer=m_cfg.uvit.n_filt_per_layer,
111
+ heads=m_cfg.uvit.n_heads,
112
+ dim_head=m_cfg.uvit.dim_head,
113
+ conv_skip_connection=m_cfg.uvit.conv_skip_connection,
114
+ n_atoms=n_atoms,
115
+ channels_per_atom=n_xyz,
116
+ time_cond_dim=n_noise_channel,
117
+ position_embedding_type=m_cfg.uvit.position_embedding_type,
118
+ )
119
+ self.noise_block = modules.NoiseConditioningBlock(nc, n_noise_channel)
120
+
121
+ def forward(
122
+ self,
123
+ noisy_coords: TensorType["b n a x", float],
124
+ noise_level: TensorType["b", float],
125
+ seq_mask: TensorType["b n", float],
126
+ residue_index: Optional[TensorType["b n", int]] = None,
127
+ struct_self_cond: Optional[TensorType["b n a x", float]] = None,
128
+ struct_crop_cond: Optional[TensorType["b n a x", float]] = None,
129
+ ):
130
+ # Prep inputs and time conditioning
131
+ actual_var_data = self.sigma_data**2
132
+ var_noisy_coords = noise_level**2 + actual_var_data
133
+ emb = noisy_coords / utils.expand(var_noisy_coords.sqrt(), noisy_coords)
134
+ struct_noise_scaled = 0.25 * torch.log(noise_level)
135
+ noise_cond = self.noise_block(struct_noise_scaled)
136
+
137
+ # Prepare self- and crop-conditioning and concatenate along channels
138
+ if struct_self_cond is None:
139
+ struct_self_cond = torch.zeros_like(noisy_coords)
140
+ if self.config.model.crop_conditional:
141
+ if struct_crop_cond is None:
142
+ struct_crop_cond = torch.zeros_like(noisy_coords)
143
+ else:
144
+ struct_crop_cond = struct_crop_cond / self.sigma_data
145
+ emb = torch.cat([emb, struct_self_cond, struct_crop_cond], -1)
146
+ else:
147
+ emb = torch.cat([emb, struct_self_cond], -1)
148
+
149
+ # Run neural network
150
+ emb = self.net(emb, noise_cond, seq_mask=seq_mask, residue_index=residue_index)
151
+
152
+ # Preconditioning from Karras et al.
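# (Editor's note, not in the original source) In EDM notation these are
# c_out = sigma * sigma_data / sqrt(sigma^2 + sigma_data^2) and
# c_skip = sigma_data^2 / (sigma^2 + sigma_data^2), so the network only
# predicts the part of x0 not already carried by the (scaled) noisy input.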
153
+ out_scale = noise_level * actual_var_data**0.5 / torch.sqrt(var_noisy_coords)
154
+ skip_scale = actual_var_data / var_noisy_coords
155
+ emb = emb * utils.expand(out_scale, emb)
156
+ skip_info = noisy_coords * utils.expand(skip_scale, noisy_coords)
157
+ denoised_coords_x0 = emb + skip_info
158
+
159
+ # Don't use atom mask; denoise all atoms
160
+ denoised_coords_x0 *= utils.expand(seq_mask, denoised_coords_x0)
161
+ return denoised_coords_x0
162
+
163
+
164
+ class Protpardelle(nn.Module):
165
+ """All-atom protein diffusion-based generative model.
166
+
167
+ This class wraps a structure denoising network and a sequence prediction network
168
+ to do structure/sequence co-design (for all-atom generation), or backbone generation.
169
+
170
+ It can be trained for one of four main tasks. To produce the all-atom (co-design)
171
+ Protpardelle model, we will typically pretrain an 'allatom' model, then use this
172
+ to train a 'seqdes' model. A 'seqdes' model can be trained with either a backbone
173
+ or allatom denoiser. The two can be combined to yield all-atom (co-design) Protpardelle
174
+ without further training.
175
+ 'backbone': train only a backbone coords denoiser.
176
+ 'seqdes': train only a mini-MPNN, using a pretrained coords denoiser.
177
+ 'allatom': train only an allatom coords denoiser (cannot do all-atom generation
178
+ by itself).
179
+ 'codesign': train both an allatom denoiser and mini-MPNN at once.
180
+
181
+ """
182
+
183
+ def __init__(self, config: argparse.Namespace, device: str = "cpu"):
184
+ super().__init__()
185
+ self.config = config
186
+ self.device = device
187
+ self.task = config.model.task
188
+ self.n_tokens = config.data.n_aatype_tokens
189
+
190
+ self.use_mpnn_model = self.task in ["seqdes", "codesign"]
191
+
192
+ # Modules
193
+ self.all_modules = {}
194
+ self.bb_idxs = [0, 1, 2, 4]
195
+ self.n_atoms = 37
196
+ self.struct_model = CoordinateDenoiser(config)
197
+ self.all_modules["struct_model"] = self.struct_model
198
+ self.bb_idxs = self.struct_model.bb_idxs
199
+ self.n_atoms = self.struct_model.n_atoms
200
+
201
+ if self.use_mpnn_model:
202
+ self.mpnn_model = MiniMPNN(config)
203
+ self.all_modules["mpnn_model"] = self.mpnn_model
204
+
205
+ # Load any pretrained modules
206
+ for module_name in self.config.model.pretrained_modules:
207
+ self.load_pretrained_module(module_name)
208
+
209
+ # Diffusion-related
210
+ self.sigma_data = self.struct_model.sigma_data
211
+ self.training_noise_schedule = partial(
212
+ diffusion.noise_schedule,
213
+ sigma_data=self.sigma_data,
214
+ **vars(config.diffusion.training),
215
+ )
216
+ self.sampling_noise_schedule_default = self.make_sampling_noise_schedule()
217
+
218
+ def load_pretrained_module(self, module_name: str, ckpt_path: Optional[str] = None):
219
+ """Load pretrained weights for a given module name."""
220
+ assert module_name in ["struct_model", "mpnn_model"], module_name
221
+
222
+ # Load pretrained checkpoint
223
+ if ckpt_path is None:
224
+ ckpt_path = getattr(self.config.model, f"{module_name}_checkpoint")
225
+ ckpt_path = os.path.join(self.config.train.home_dir, ckpt_path)
226
+ ckpt_dict = torch.load(ckpt_path, map_location=self.device)
227
+ model_state_dict = ckpt_dict["model_state_dict"]
228
+
229
+ # Get only submodule state_dict
230
+ submodule_state_dict = {
231
+ sk[len(module_name) + 1 :]: sv
232
+ for sk, sv in model_state_dict.items()
233
+ if sk.startswith(module_name)
234
+ }
235
+
236
+ # Load into module
237
+ module = dict(self.named_modules())[module_name]
238
+ module.load_state_dict(submodule_state_dict)
239
+
240
+ # Freeze unneeded modules
241
+ if module_name == "struct_model":
242
+ self.struct_model = module
243
+ if self.task == "seqdes":
244
+ for p in module.parameters():
245
+ p.requires_grad = False
246
+ if module_name == "mpnn_model":
247
+ self.mpnn_model = module
248
+ if self.task not in ["codesign", "seqdes"]:
249
+ for p in module.parameters():
250
+ p.requires_grad = False
251
+
252
+ return module
253
+
254
+ def load_minimpnn(self, mpnn_ckpt_path: Optional[str] = None):
255
+ """Convert an allatom model to a codesign model."""
256
+ if mpnn_ckpt_path is None:
257
+ mpnn_ckpt_path = "checkpoints/minimpnn_state_dict.pth"
258
+ self.mpnn_model = MiniMPNN(self.config).to(self.device)
259
+ self.load_pretrained_module("mpnn_model", ckpt_path=mpnn_ckpt_path)
260
+ self.use_mpnn_model = True
261
+ return
262
+
263
+ def remove_minimpnn(self):
264
+ """Revert a codesign model to an allatom model to a codesign model."""
265
+ self.use_mpnn_model = False
266
+ self.mpnn_model = None
267
+ self.all_modules["mpnn_model"] = None
268
+
269
+ def make_sampling_noise_schedule(self, **noise_kwargs):
270
+ """Make the default sampling noise schedule function."""
271
+ noise_schedule_kwargs = vars(self.config.diffusion.sampling)
272
+ if len(noise_kwargs) > 0:
273
+ noise_schedule_kwargs.update(noise_kwargs)
274
+ return partial(diffusion.noise_schedule, **noise_schedule_kwargs)
275
+
276
+ def forward(
277
+ self,
278
+ *,
279
+ noisy_coords: TensorType["b n a x", float],
280
+ noise_level: TensorType["b", float],
281
+ seq_mask: TensorType["b n", float],
282
+ residue_index: TensorType["b n", int],
283
+ struct_self_cond: Optional[TensorType["b n a x", float]] = None,
284
+ struct_crop_cond: Optional[TensorType["b n a x", float]] = None,
285
+ seq_self_cond: Optional[TensorType["b n t", float]] = None, # logprobs
286
+ run_struct_model: bool = True,
287
+ run_mpnn_model: bool = True,
288
+ ):
289
+ """Main forward function for denoising/co-design.
290
+
291
+ Arguments:
292
+ noisy_coords: noisy array of xyz coordinates.
293
+ noise_level: std of noise for each example in the batch.
294
+ seq_mask: mask indicating which indexes contain data.
295
+ residue_index: residue ordering. This is used by proteinMPNN, but currently
296
+ only used by the diffusion model when the 'absolute_residx' or
297
+ 'relative' position_embedding_type is specified.
298
+ struct_self_cond: denoised coordinates from the previous step, scaled
299
+ down by sigma data.
300
+ struct_crop_cond: unnoised coordinates. unscaled (scaled down by sigma
301
+ data inside the denoiser)
302
+ seq_self_cond: mpnn-predicted sequence logprobs from the previous step.
303
+ run_struct_model: flag to optionally not run structure denoiser.
304
+ run_mpnn_model: flag to optionally not run mini-mpnn.
305
+ """
306
+
307
+ # Coordinate denoiser
308
+ denoised_x0 = noisy_coords
309
+ if run_struct_model:
310
+ denoised_x0 = self.struct_model(
311
+ noisy_coords,
312
+ noise_level,
313
+ seq_mask,
314
+ residue_index=residue_index,
315
+ struct_self_cond=struct_self_cond,
316
+ struct_crop_cond=struct_crop_cond,
317
+ )
318
+
319
+ # Mini-MPNN
320
+ aatype_logprobs = None
321
+ if self.use_mpnn_model and run_mpnn_model:
322
+ aatype_logprobs = self.mpnn_model(
323
+ denoised_x0.detach(),
324
+ noise_level,
325
+ seq_mask,
326
+ residue_index,
327
+ seq_self_cond=seq_self_cond,
328
+ return_embeddings=False,
329
+ )
330
+ aatype_logprobs = aatype_logprobs * seq_mask[..., None]
331
+
332
+ # Process outputs
333
+ if aatype_logprobs is None:
334
+ aatype_logprobs = repeat(seq_mask, "b n -> b n t", t=self.n_tokens)
335
+ aatype_logprobs = torch.ones_like(aatype_logprobs)
336
+ aatype_logprobs = F.log_softmax(aatype_logprobs, -1)
337
+ struct_self_cond_out = denoised_x0.detach() / self.sigma_data
338
+ seq_self_cond_out = aatype_logprobs.detach()
339
+
340
+ return denoised_x0, aatype_logprobs, struct_self_cond_out, seq_self_cond_out
341
+
342
+ def make_seq_mask_for_sampling(
343
+ self,
344
+ prot_lens: Optional[TensorType["b", int]] = None,
345
+ n_samples: int = 1,
346
+ min_len: int = 50,
347
+ max_len: Optional[int] = None,
348
+ ):
349
+ """Makes a sequence mask of varying protein lengths (only input required
350
+ to begin sampling).
351
+ """
352
+ if max_len is None:
353
+ max_len = self.config.data.fixed_size
354
+ if prot_lens is None:
355
+ possible_lens = np.arange(min_len, max_len)
356
+ prot_lens = torch.Tensor(np.random.choice(possible_lens, n_samples))
357
+ else:
358
+ n_samples = len(prot_lens)
359
+ max_len = max(prot_lens)
360
+ mask = repeat(torch.arange(max_len), "n -> b n", b=n_samples)
361
+ mask = (mask < prot_lens[:, None]).float().to(self.device)
362
+ return mask
363
+
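# Editor's sketch (not part of the original source): for two proteins of
# lengths 60 and 80, make_seq_mask_for_sampling(prot_lens=torch.tensor([60, 80]))
# returns a (2, 80) mask whose row i is 1.0 for the first prot_lens[i] positions.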
364
+ def sample(
365
+ self,
366
+ *,
367
+ seq_mask: TensorType["b n", float] = None,
368
+ n_samples: int = 1,
369
+ min_len: int = 50,
370
+ max_len: int = 512,
371
+ residue_index: TensorType["b n", int] = None,
372
+ gt_coords: TensorType["b n a x", float] = None,
373
+ gt_coords_traj: List[TensorType["b n a x", float]] = None,
374
+ gt_cond_atom_mask: TensorType["b n a", float] = None,
375
+ gt_aatype: TensorType["b n", int] = None,
376
+ gt_cond_seq_mask: TensorType["b n", float] = None,
377
+ apply_cond_proportion: float = 1.0,
378
+ n_steps: int = 200,
379
+ step_scale: float = 1.2,
380
+ s_churn: float = 50.0,
381
+ noise_scale: float = 1.0,
382
+ s_t_min: float = 0.01,
383
+ s_t_max: float = 50.0,
384
+ temperature: float = 1.0,
385
+ top_p: float = 1.0,
386
+ disallow_aas: List[int] = [4, 20], # cys, unk
387
+ sidechain_mode: bool = False,
388
+ skip_mpnn_proportion: float = 0.7,
389
+ anneal_seq_resampling_rate: Optional[str] = None, # linear, cosine
390
+ use_fullmpnn: bool = False,
391
+ use_fullmpnn_for_final: bool = True,
392
+ use_reconstruction_guidance: bool = False,
393
+ use_classifier_free_guidance: bool = False, # defaults to replacement guidance if these are all false
394
+ guidance_scale: float = 1.0,
395
+ noise_schedule: Optional[Callable] = None,
396
+ tqdm_pbar: Optional[Callable] = None,
397
+ return_last: bool = True,
398
+ return_aux: bool = False,
399
+ ):
400
+ """Sampling function for backbone or all-atom diffusion. All arguments are optional.
401
+
402
+ Arguments:
403
+ seq_mask: mask defining the number and lengths of proteins to be sampled.
404
+ n_samples: number of samples to draw (if seq_mask not provided).
405
+ min_len: minimum length of proteins to be sampled (if seq_mask not provided).
406
+ max_len: maximum length of proteins to be sampled (if seq_mask not provided).
407
+ residue_index: residue index of proteins to be sampled.
408
+ gt_coords: conditioning information for coords.
409
+ gt_coords_traj: conditioning information for coords specified for each timestep
410
+ (if gt_coords is not provided).
411
+ gt_cond_atom_mask: mask identifying atoms to apply gt_coords.
412
+ gt_aatype: conditioning information for sequence.
413
+ gt_cond_seq_mask: sequence positions to apply gt_aatype.
414
+ apply_cond_proportion: the proportion of timesteps to apply the conditioning.
415
+ e.g. if 0.5, then the first 50% of steps use conditioning, and the last 50%
416
+ are unconditional.
417
+ n_steps: number of denoising steps (ODE discretizations).
418
+ step_scale: scale to apply to the score.
419
+ s_churn: gamma = s_churn / n_steps describes the additional noise to add
420
+ relatively at each denoising step. Use 0.0 for deterministic sampling or
421
+ 0.2 * n_steps as a rough default for stochastic sampling.
422
+ noise_scale: scale to apply to gamma.
423
+ s_t_min: don't apply s_churn below this noise level.
424
+ s_t_max: don't apply s_churn above this noise level.
425
+ temperature: scale to apply to aatype logits.
426
+ top_p: don't sample tokens which fall outside this proportion of the total probability.
427
+ disallow_aas: don't sample these token indices.
428
+ sidechain_mode: whether to do all-atom sampling (False for backbone-only).
429
+ skip_mpnn_proportion: proportion of timesteps from the start to skip running
430
+ mini-MPNN.
431
+ anneal_seq_resampling_rate: whether and how to decay the probability of
432
+ running mini-MPNN. None, 'linear', or 'cosine'
433
+ use_fullmpnn: use "full" ProteinMPNN at each step.
434
+ use_fullmpnn_for_final: use "full" ProteinMPNN at the final step.
435
+ use_reconstruction_guidance: use reconstruction guidance on the conditioning.
436
+ use_classifier_free_guidance: use classifier-free guidance on the conditioning.
437
+ guidance_scale: weight for reconstruction/classifier-free guidance.
438
+ noise_schedule: specify the noise level timesteps for sampling.
439
+ tqdm_pbar: progress bar in interactive contexts.
440
+ return_last: return only the sampled structure and sequence.
441
+ return_aux: return a dict of everything associated with the sampling run.
442
+ """
443
+
444
        def ode_step(sigma_in, sigma_next, xt_in, x0_pred, gamma, guidance_in=None):
            if gamma > 0:
                t_hat = sigma_in + gamma * sigma_in
                sigma_delta = torch.sqrt(t_hat**2 - sigma_in**2)
                noisier_x = xt_in + utils.expand(
                    sigma_delta, xt_in
                ) * noise_scale * torch.randn_like(xt_in).to(xt_in)
                xt_in = noisier_x * utils.expand(seq_mask, noisier_x)
                sigma_in = t_hat

            mask = (sigma_in > 0).float()
            score = (xt_in - x0_pred) / utils.expand(sigma_in.clamp(min=1e-6), xt_in)
            score = score * utils.expand(mask, score)
            if use_reconstruction_guidance:
                guidance, guidance_mask = guidance_in
                guidance = guidance * guidance_mask[..., None]
                guidance_std = guidance[guidance_mask.bool()].var().sqrt()
                score_std = score[guidance_mask.bool()].var().sqrt()
                score = score + guidance * guidance_scale
            if use_classifier_free_guidance:
                # guidance_in is the unconditional x0 (x0_pred is the conditional x0)
                # guidance_scale = 1 + w from Ho paper
                # ==0: use only unconditional score; <1: interpolate the scores;
                # ==1: use only conditional score; >1: skew towards conditional score
                uncond_x0 = guidance_in
                uncond_score = (xt_in - uncond_x0) / utils.expand(
                    sigma_in.clamp(min=1e-6), xt_in
                )
                uncond_score = uncond_score * utils.expand(mask, uncond_score)
                score = guidance_scale * score + (1 - guidance_scale) * uncond_score
            step = score * step_scale * utils.expand(sigma_next - sigma_in, score)
            new_xt = xt_in + step
            return new_xt

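        # Editor's note (illustrative, not part of the original code): in
        # sample_aatype below, top_p=0.9 applied to sorted probabilities
        # [0.5, 0.3, 0.15, 0.05] gives cumulative sums [0.5, 0.8, 0.95, 1.0],
        # so only the first two tokens survive truncation; the surviving
        # probabilities are then log-transformed, temperature-scaled, and
        # renormalized by the softmax before the categorical draw.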
        def sample_aatype(logprobs):
            # Top-p truncation
            probs = F.softmax(logprobs.clone(), dim=-1)
            sorted_prob, sorted_idxs = torch.sort(probs, descending=True)
            cumsum_prob = torch.cumsum(sorted_prob, dim=-1)
            sorted_indices_to_remove = cumsum_prob > top_p
            sorted_indices_to_remove[..., 0] = 0
            sorted_prob[sorted_indices_to_remove] = 0
            orig_probs = torch.scatter(
                torch.zeros_like(sorted_prob),
                dim=-1,
                index=sorted_idxs,
                src=sorted_prob,
            )

            # Apply temperature and disallowed AAs and sample
            assert temperature >= 0.0
            scaled_logits = orig_probs.clamp(min=1e-9).log() / (temperature + 1e-4)
            if disallow_aas:
                unwanted_mask = torch.zeros(scaled_logits.shape[-1]).to(scaled_logits)
                unwanted_mask[disallow_aas] = 1
                scaled_logits -= unwanted_mask * 1e10
            orig_probs = F.softmax(scaled_logits, dim=-1)
            categorical = torch.distributions.Categorical(probs=orig_probs)
            samp_aatype = categorical.sample()
            return samp_aatype

        def design_with_fullmpnn(batched_coords, seq_mask):
            seq_lens = seq_mask.sum(-1).long()
            designed_seqs = [
                evaluation.design_sequence(c[: seq_lens[i]], model=fullmpnn_model)[0]
                for i, c in enumerate(batched_coords)
            ]
            designed_aatypes, _ = utils.batched_seq_to_aatype_and_mask(
                designed_seqs, max_len=seq_mask.shape[-1]
            )
            return designed_aatypes

        # Initialize masks/features
        if seq_mask is None:  # Sample random lengths
            assert gt_aatype is None  # Don't condition on aatype without seq_mask
            seq_mask = self.make_seq_mask_for_sampling(
                n_samples=n_samples,
                min_len=min_len,
                max_len=max_len,
            )
        if residue_index is None:
            residue_index = torch.arange(seq_mask.shape[-1])
            residue_index = repeat(residue_index, "n -> b n", b=seq_mask.shape[0])
            residue_index = residue_index.to(seq_mask) * seq_mask
        if use_fullmpnn or use_fullmpnn_for_final:
            fullmpnn_model = protein_mpnn.get_mpnn_model(
                path_to_model_weights=self.config.train.home_dir
                + "/ProteinMPNN/vanilla_model_weights",
                device=self.device,
            )

        # Initialize noise schedule/parameters
        to_batch_size = lambda x: x * torch.ones(seq_mask.shape[0]).to(self.device)
        s_t_min = s_t_min * self.sigma_data
        s_t_max = s_t_max * self.sigma_data
        if noise_schedule is None:
            noise_schedule = self.sampling_noise_schedule_default
        sigma = noise_schedule(1)
        timesteps = torch.linspace(1, 0, n_steps + 1)

        # Set up conditioning/guidance information
        crop_cond_coords = None
        if gt_coords is None:
            coords_shape = seq_mask.shape + (self.n_atoms, 3)
            xt = torch.randn(*coords_shape).to(self.device) * sigma
            xt *= utils.expand(seq_mask, xt)
        else:
            assert gt_coords_traj is None
            noise_levels = [to_batch_size(noise_schedule(t)) for t in timesteps]
            gt_coords_traj = [
                diffusion.noise_coords(gt_coords, nl) for nl in noise_levels
            ]
            xt = gt_coords_traj[0]
            if gt_cond_atom_mask is not None:
                crop_cond_coords = gt_coords * gt_cond_atom_mask[..., None]
        gt_atom_mask = None
        if gt_aatype is not None:
            gt_atom_mask = utils.atom37_mask_from_aatype(gt_aatype, seq_mask)
        fake_logits = repeat(seq_mask, "b n -> b n t", t=self.n_tokens)
        s_hat = (sample_aatype(fake_logits) * seq_mask).long()

        # Initialize superposition for all-atom sampling
        if sidechain_mode:
            b, n = seq_mask.shape[:2]

            # Latest predicted x0 for sidechain superpositions
            atom73_state_0 = torch.zeros(b, n, 73, 3).to(xt)

            # Current state xt for sidechain superpositions (denoised to different levels)
            atom73_state_t = torch.randn(b, n, 73, 3).to(xt) * sigma

            # Noise level of xt
            sigma73_last = torch.ones(b, n, 73).to(xt) * sigma

            # Seqhat and mask used to choose sidechains for euler step (b, n)
            s_hat = (seq_mask * 7).long()
            mask37 = utils.atom37_mask_from_aatype(s_hat, seq_mask).bool()
            mask73 = utils.atom73_mask_from_aatype(s_hat, seq_mask).bool()
            begin_mpnn_step = int(n_steps * skip_mpnn_proportion)

        # Prepare to run sampling trajectory
        sigma = to_batch_size(sigma)
        x0 = None
        x0_prev = None
        x_self_cond = None
        s_logprobs = None
        s_self_cond = None
        if tqdm_pbar is None:
            tqdm_pbar = lambda x: x
        torch.set_grad_enabled(False)

        # *t_traj is the denoising trajectory; *0_traj is the evolution of predicted clean data
        # s0 are aatype probs of shape (b n t); s_hat are discrete aatype of shape (b n)
        xt_traj, x0_traj, st_traj, s0_traj = [], [], [], []

        # Sampling trajectory
        for i, t in tqdm_pbar(enumerate(iter(timesteps[1:]))):
            # Set up noise levels
            sigma_next = noise_schedule(t)
            if i == n_steps - 1:
                sigma_next *= 0
            gamma = (
                s_churn / n_steps
                if (sigma_next >= s_t_min and sigma_next <= s_t_max)
                else 0.0
            )
            sigma_next = to_batch_size(sigma_next)

            if sidechain_mode:
                # Fill in noise for masked positions since xt is initialized to zeros at each step
                dummy_fill_noise = torch.randn_like(xt) * utils.expand(sigma, xt)
                zero_atom_mask = utils.atom37_mask_from_aatype(s_hat, seq_mask)
                dummy_fill_mask = 1 - zero_atom_mask[..., None]
                xt = xt * zero_atom_mask[..., None] + dummy_fill_noise * dummy_fill_mask
            else:  # backbone only
                bb_seq = (seq_mask * residue_constants.restype_order["G"]).long()
                bb_atom_mask = utils.atom37_mask_from_aatype(bb_seq, seq_mask)
                xt *= bb_atom_mask[..., None]

            # Enable grad for reconstruction guidance
            if use_reconstruction_guidance:
                torch.set_grad_enabled(True)
                xt.requires_grad = True

            # Run denoising network
            run_mpnn = not sidechain_mode or i > begin_mpnn_step
            x0, s_logprobs, x_self_cond, s_self_cond = self.forward(
                noisy_coords=xt,
                noise_level=sigma,
                seq_mask=seq_mask,
                residue_index=residue_index,
                struct_self_cond=x_self_cond,
                struct_crop_cond=crop_cond_coords,
                seq_self_cond=s_self_cond,
                run_mpnn_model=run_mpnn,
            )

            # Compute additional stuff for guidance
            if use_reconstruction_guidance:
                loss = (x0 - gt_coords).pow(2).sum(-1)
                loss = loss * gt_cond_atom_mask
                loss = loss.sum() / gt_cond_atom_mask.sum().clamp(min=1)
                xt.retain_grad()
                loss.backward()
                guidance = xt.grad.clone()
                xt.grad *= 0
                torch.set_grad_enabled(False)
            if use_classifier_free_guidance:
                assert not use_reconstruction_guidance
                uncond_x0, _, _, _ = self.forward(
                    noisy_coords=xt,
                    noise_level=sigma,
                    seq_mask=seq_mask,
                    residue_index=residue_index,
                    struct_self_cond=x_self_cond,
                    seq_self_cond=s_self_cond,
                    run_mpnn_model=run_mpnn,
                )

            # Structure denoising step
            if not sidechain_mode:  # backbone
                if sigma[0] > 0:
                    xt = ode_step(sigma, sigma_next, xt, x0, gamma)
                else:
                    xt = x0
            else:  # allatom
                # Write x0 into atom73_state_0 for atoms corresponding to old seqhat
                atom73_state_0[mask73] = x0[mask37]

                # Determine sequence resampling probability
                if anneal_seq_resampling_rate is not None:
                    step_time = 1 - (i - begin_mpnn_step) / max(
                        1, n_steps - begin_mpnn_step
                    )
                    if anneal_seq_resampling_rate == "linear":
                        resampling_rate = step_time
                    elif anneal_seq_resampling_rate == "cosine":
                        k = 2
                        resampling_rate = (
                            1 + np.cos(2 * np.pi * (step_time - 0.5))
                        ) / k
                    resample_this_step = np.random.uniform() < resampling_rate

                # Resample sequence or design with full ProteinMPNN
                if i == n_steps - 1 and use_fullmpnn_for_final:
                    s_hat = design_with_fullmpnn(x0, seq_mask).to(x0.device)
                elif anneal_seq_resampling_rate is None or resample_this_step:
                    if run_mpnn and use_fullmpnn:
                        s_hat = design_with_fullmpnn(x0, seq_mask).to(x0.device)
                    else:
                        s_hat = sample_aatype(s_logprobs)

                # Overwrite s_hat with any conditioning information
                if (i + 1) / n_steps <= apply_cond_proportion:
                    if gt_cond_seq_mask is not None and gt_aatype is not None:
                        s_hat = (
                            1 - gt_cond_seq_mask
                        ) * s_hat + gt_cond_seq_mask * gt_aatype
                        s_hat = s_hat.long()

                # Set masks for collapsing superposition using new sequence
                mask37 = utils.atom37_mask_from_aatype(s_hat, seq_mask).bool()
                mask73 = utils.atom73_mask_from_aatype(s_hat, seq_mask).bool()

                # Determine prev noise levels for atoms corresponding to new sequence
                step_sigma_prev = (
                    torch.ones(*xt.shape[:-1]).to(xt) * sigma[..., None, None]
                )
                step_sigma_prev[mask37] = sigma73_last[mask73]  # b, n, 37
                step_sigma_next = sigma_next[..., None, None]  # b, 1, 1

                # Denoising step on atoms corresponding to new sequence
                b, n = mask37.shape[:2]
                step_xt = torch.zeros(b, n, 37, 3).to(xt)
                step_x0 = torch.zeros(b, n, 37, 3).to(xt)
                step_xt[mask37] = atom73_state_t[mask73]
                step_x0[mask37] = atom73_state_0[mask73]

                guidance_in = None
                if (i + 1) / n_steps <= apply_cond_proportion:
                    if use_reconstruction_guidance:
                        guidance_in = (guidance, mask37.float())
                    elif use_classifier_free_guidance:
                        guidance_in = uncond_x0

                step_xt = ode_step(
                    step_sigma_prev,
                    step_sigma_next,
                    step_xt,
                    step_x0,
                    gamma,
                    guidance_in=guidance_in,
                )
                xt = step_xt

                # Write new xt into atom73_state_t for atoms corresponding to new seqhat and update sigma_last
                atom73_state_t[mask73] = step_xt[mask37]
                sigma73_last[mask73] = step_sigma_next[0].item()

            # Replacement guidance if conditioning information provided
            if (i + 1) / n_steps <= apply_cond_proportion:
                if gt_coords_traj is not None:
                    if gt_cond_atom_mask is None:
                        xt = gt_coords_traj[i + 1]
                    else:
                        xt = (1 - gt_cond_atom_mask)[
                            ..., None
                        ] * xt + gt_cond_atom_mask[..., None] * gt_coords_traj[i + 1]

            sigma = sigma_next

            # Logging
            xt_scale = self.sigma_data / utils.expand(
                torch.sqrt(sigma_next**2 + self.sigma_data**2), xt
            )
            scaled_xt = xt * xt_scale
            xt_traj.append(scaled_xt.cpu())
            x0_traj.append(x0.cpu())
            st_traj.append(s_hat.cpu())
            s0_traj.append(s_logprobs.cpu())

        if return_last:
            return xt, s_hat, seq_mask
        elif return_aux:
            return {
                "x": xt,
                "s": s_hat,
                "seq_mask": seq_mask,
                "xt_traj": xt_traj,
                "x0_traj": x0_traj,
                "st_traj": st_traj,
                "s0_traj": s0_traj,
            }
        else:
            return xt_traj, x0_traj, st_traj, s0_traj, seq_mask
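Editor's usage sketch (not part of this commit): a minimal call into the sample() method above. `model` is assumed to be a loaded Protpardelle-style module exposing this method; the argument values mirror the backbone defaults in sampling.py further down.

# Hedged example; `model` is an assumption, not defined in this diff.
coords, aatype, seq_mask = model.sample(
    n_samples=2,           # two proteins of random length in [60, 100]
    min_len=60,
    max_len=100,
    n_steps=500,
    s_churn=200,           # stochastic sampling; 0.0 would be deterministic
    step_scale=1.2,
    sidechain_mode=False,  # backbone-only
    return_last=True,      # -> (xt, s_hat, seq_mask)
)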
modules.py ADDED
@@ -0,0 +1,696 @@
"""
https://github.com/ProteinDesignLab/protpardelle
License: MIT
Author: Alex Chu

Neural network modules. Many of these are adapted from open source modules.
"""
from typing import List, Sequence, Optional

from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange
import numpy as np
from rotary_embedding_torch import RotaryEmbedding
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, EsmModel

from core import protein_mpnn
from core import residue_constants
from core import utils


########################################
# Adapted from https://github.com/ermongroup/ddim


def downsample(x):
    return nn.functional.avg_pool2d(x, 2, 2, ceil_mode=True)


def upsample_coords(x, shape):
    new_l, new_w = shape
    return nn.functional.interpolate(x, size=(new_l, new_w), mode="nearest")


########################################
# Adapted from https://github.com/aqlaboratory/openfold


def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.contiguous().permute(first_inds + [zero_index + i for i in inds])


def lddt(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    cutoff: float = 15.0,
    eps: float = 1e-10,
    per_residue: bool = True,
) -> torch.Tensor:
    n = all_atom_mask.shape[-2]
    dmat_true = torch.sqrt(
        eps
        + torch.sum(
            (all_atom_positions[..., None, :] - all_atom_positions[..., None, :, :])
            ** 2,
            dim=-1,
        )
    )

    dmat_pred = torch.sqrt(
        eps
        + torch.sum(
            (all_atom_pred_pos[..., None, :] - all_atom_pred_pos[..., None, :, :]) ** 2,
            dim=-1,
        )
    )
    dists_to_score = (
        (dmat_true < cutoff)
        * all_atom_mask
        * permute_final_dims(all_atom_mask, (1, 0))
        * (1.0 - torch.eye(n, device=all_atom_mask.device))
    )

    dist_l1 = torch.abs(dmat_true - dmat_pred)

    score = (
        (dist_l1 < 0.5).type(dist_l1.dtype)
        + (dist_l1 < 1.0).type(dist_l1.dtype)
        + (dist_l1 < 2.0).type(dist_l1.dtype)
        + (dist_l1 < 4.0).type(dist_l1.dtype)
    )
    score = score * 0.25

    dims = (-1,) if per_residue else (-2, -1)
    norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
    score = norm * (eps + torch.sum(dists_to_score * score, dim=dims))

    return score

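# Editor's note (illustrative, not part of the original code): lddt() above
# scores each pairwise distance whose true value is under 15 A against four
# absolute-error thresholds (0.5, 1, 2, 4 A) and averages the four hits, so a
# prediction whose error on some pair is 1.5 A earns 2 of 4 hits (0.5 credit)
# for that pair; with per_residue=True the average is taken over each
# residue's row of pairs rather than over the whole matrix.
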
class RelativePositionalEncoding(nn.Module):
    def __init__(self, attn_dim=8, max_rel_idx=32):
        super().__init__()
        self.max_rel_idx = max_rel_idx
        self.n_rel_pos = 2 * self.max_rel_idx + 1
        self.linear = nn.Linear(self.n_rel_pos, attn_dim)

    def forward(self, residue_index):
        d_ij = residue_index[..., None] - residue_index[..., None, :]
        v_bins = torch.arange(self.n_rel_pos).to(d_ij.device) - self.max_rel_idx
        idxs = (d_ij[..., None] - v_bins[None, None]).abs().argmin(-1)
        p_ij = nn.functional.one_hot(idxs, num_classes=self.n_rel_pos)
        embeddings = self.linear(p_ij.float())
        return embeddings


########################################
# Adapted from https://github.com/NVlabs/edm


class Noise_Embedding(nn.Module):
    def __init__(self, num_channels, max_positions=10000, endpoint=False):
        super().__init__()
        self.num_channels = num_channels
        self.max_positions = max_positions
        self.endpoint = endpoint

    def forward(self, x):
        freqs = torch.arange(
            start=0, end=self.num_channels // 2, dtype=torch.float32, device=x.device
        )
        freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0))
        freqs = (1 / self.max_positions) ** freqs
        x = x.outer(freqs.to(x.dtype))
        x = torch.cat([x.cos(), x.sin()], dim=1)
        return x


########################################
# Adapted from github.com/lucidrains
# https://github.com/lucidrains/denoising-diffusion-pytorch
# https://github.com/lucidrains/recurrent-interface-network-pytorch


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d


def posemb_sincos_1d(patches, temperature=10000, residue_index=None):
    _, n, dim, device, dtype = *patches.shape, patches.device, patches.dtype

    n = torch.arange(n, device=device) if residue_index is None else residue_index
    assert (dim % 2) == 0, "feature dimension must be multiple of 2 for sincos emb"
    omega = torch.arange(dim // 2, device=device) / (dim // 2 - 1)
    omega = 1.0 / (temperature**omega)

    n = n[..., None] * omega
    pe = torch.cat((n.sin(), n.cos()), dim=-1)
    return pe.type(dtype)


class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)


class NoiseConditioningBlock(nn.Module):
    def __init__(self, n_in_channel, n_out_channel):
        super().__init__()
        self.block = nn.Sequential(
            Noise_Embedding(n_in_channel),
            nn.Linear(n_in_channel, n_out_channel),
            nn.SiLU(),
            nn.Linear(n_out_channel, n_out_channel),
            Rearrange("b d -> b 1 d"),
        )

    def forward(self, noise_level):
        return self.block(noise_level)


class TimeCondResnetBlock(nn.Module):
    def __init__(
        self, nic, noc, cond_nc, conv_layer=nn.Conv2d, dropout=0.1, n_norm_in_groups=4
    ):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(num_groups=nic // n_norm_in_groups, num_channels=nic),
            nn.SiLU(),
            conv_layer(nic, noc, 3, 1, 1),
        )
        self.cond_proj = nn.Linear(cond_nc, noc * 2)
        self.mid_norm = nn.GroupNorm(num_groups=noc // 4, num_channels=noc)
        self.dropout = dropout if dropout is None else nn.Dropout(dropout)
        self.block2 = nn.Sequential(
            nn.GroupNorm(num_groups=noc // 4, num_channels=noc),
            nn.SiLU(),
            conv_layer(noc, noc, 3, 1, 1),
        )
        self.mismatch = False
        if nic != noc:
            self.mismatch = True
            self.conv_match = conv_layer(nic, noc, 1, 1, 0)

    def forward(self, x, time=None):
        h = self.block1(x)

        if time is not None:
            h = self.mid_norm(h)
            scale, shift = self.cond_proj(time).chunk(2, dim=-1)
            h = (h * (utils.expand(scale, h) + 1)) + utils.expand(shift, h)

        if self.dropout is not None:
            h = self.dropout(h)

        h = self.block2(h)

        if self.mismatch:
            x = self.conv_match(x)

        return x + h


class TimeCondAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_context=None,
        heads=4,
        dim_head=32,
        norm=False,
        norm_context=False,
        time_cond_dim=None,
        attn_bias_dim=None,
        rotary_embedding_module=None,
    ):
        super().__init__()
        hidden_dim = dim_head * heads
        dim_context = default(dim_context, dim)

        self.time_cond = None

        if exists(time_cond_dim):
            self.time_cond = nn.Sequential(nn.SiLU(), nn.Linear(time_cond_dim, dim * 2))

            nn.init.zeros_(self.time_cond[-1].weight)
            nn.init.zeros_(self.time_cond[-1].bias)

        self.scale = dim_head**-0.5
        self.heads = heads

        self.norm = LayerNorm(dim) if norm else nn.Identity()
        self.norm_context = LayerNorm(dim_context) if norm_context else nn.Identity()

        self.attn_bias_proj = None
        if attn_bias_dim is not None:
            self.attn_bias_proj = nn.Sequential(
                Rearrange("b a i j -> b i j a"),
                nn.Linear(attn_bias_dim, heads),
                Rearrange("b i j a -> b a i j"),
            )

        self.to_q = nn.Linear(dim, hidden_dim, bias=False)
        self.to_kv = nn.Linear(dim_context, hidden_dim * 2, bias=False)
        self.to_out = nn.Linear(hidden_dim, dim, bias=False)
        nn.init.zeros_(self.to_out.weight)

        self.use_rope = False
        if rotary_embedding_module is not None:
            self.use_rope = True
            self.rope = rotary_embedding_module

    def forward(self, x, context=None, time=None, attn_bias=None, seq_mask=None):
        # attn_bias is b, c, i, j
        h = self.heads
        has_context = exists(context)

        context = default(context, x)

        if x.shape[-1] != self.norm.gamma.shape[-1]:
            print(context.shape, x.shape, self.norm.gamma.shape)

        x = self.norm(x)

        if exists(time):
            scale, shift = self.time_cond(time).chunk(2, dim=-1)
            x = (x * (scale + 1)) + shift

        if has_context:
            context = self.norm_context(context)

        if seq_mask is not None:
            x = x * seq_mask[..., None]

        qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)

        q = q * self.scale

        if self.use_rope:
            q = self.rope.rotate_queries_or_keys(q)
            k = self.rope.rotate_queries_or_keys(k)

        sim = torch.einsum("b h i d, b h j d -> b h i j", q, k)
        if attn_bias is not None:
            if self.attn_bias_proj is not None:
                attn_bias = self.attn_bias_proj(attn_bias)
            sim += attn_bias
        if seq_mask is not None:
            attn_mask = torch.einsum("b i, b j -> b i j", seq_mask, seq_mask)[:, None]
            sim -= (1 - attn_mask) * 1e6
        attn = sim.softmax(dim=-1)

        out = torch.einsum("b h i j, b h j d -> b h i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.to_out(out)
        if seq_mask is not None:
            out = out * seq_mask[..., None]
        return out


class TimeCondFeedForward(nn.Module):
    def __init__(self, dim, mult=4, dim_out=None, time_cond_dim=None, dropout=0.1):
        super().__init__()
        if dim_out is None:
            dim_out = dim
        self.norm = LayerNorm(dim)

        self.time_cond = None
        self.dropout = None
        inner_dim = int(dim * mult)

        if exists(time_cond_dim):
            self.time_cond = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, inner_dim * 2),
            )

            nn.init.zeros_(self.time_cond[-1].weight)
            nn.init.zeros_(self.time_cond[-1].bias)

        self.linear_in = nn.Linear(dim, inner_dim)
        self.nonlinearity = nn.SiLU()
        if dropout is not None:
            self.dropout = nn.Dropout(dropout)
        self.linear_out = nn.Linear(inner_dim, dim_out)
        nn.init.zeros_(self.linear_out.weight)
        nn.init.zeros_(self.linear_out.bias)

    def forward(self, x, time=None):
        x = self.norm(x)
        x = self.linear_in(x)
        x = self.nonlinearity(x)

        if exists(time):
            scale, shift = self.time_cond(time).chunk(2, dim=-1)
            x = (x * (scale + 1)) + shift

        if exists(self.dropout):
            x = self.dropout(x)

        return self.linear_out(x)


class TimeCondTransformer(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        heads,
        dim_head,
        time_cond_dim,
        attn_bias_dim=None,
        mlp_inner_dim_mult=4,
        position_embedding_type: str = "rotary",
    ):
        super().__init__()

        self.rope = None
        self.pos_emb_type = position_embedding_type
        if position_embedding_type == "rotary":
            self.rope = RotaryEmbedding(dim=32)
        elif position_embedding_type == "relative":
            self.relpos = nn.Sequential(
                RelativePositionalEncoding(attn_dim=heads),
                Rearrange("b i j d -> b d i j"),
            )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        TimeCondAttention(
                            dim,
                            heads=heads,
                            dim_head=dim_head,
                            norm=True,
                            time_cond_dim=time_cond_dim,
                            attn_bias_dim=attn_bias_dim,
                            rotary_embedding_module=self.rope,
                        ),
                        TimeCondFeedForward(
                            dim, mlp_inner_dim_mult, time_cond_dim=time_cond_dim
                        ),
                    ]
                )
            )

    def forward(
        self,
        x,
        time=None,
        attn_bias=None,
        context=None,
        seq_mask=None,
        residue_index=None,
    ):
        if self.pos_emb_type == "absolute":
            pos_emb = posemb_sincos_1d(x)
            x = x + pos_emb
        elif self.pos_emb_type == "absolute_residx":
            assert residue_index is not None
            pos_emb = posemb_sincos_1d(x, residue_index=residue_index)
            x = x + pos_emb
        elif self.pos_emb_type == "relative":
            assert residue_index is not None
            pos_emb = self.relpos(residue_index)
            attn_bias = pos_emb if attn_bias is None else attn_bias + pos_emb
        if seq_mask is not None:
            x = x * seq_mask[..., None]

        for i, (attn, ff) in enumerate(self.layers):
            x = x + attn(
                x, context=context, time=time, attn_bias=attn_bias, seq_mask=seq_mask
            )
            x = x + ff(x, time=time)
            if seq_mask is not None:
                x = x * seq_mask[..., None]

        return x


class TimeCondUViT(nn.Module):
    def __init__(
        self,
        *,
        seq_len: int,
        dim: int,
        patch_size: int = 1,
        depth: int = 6,
        heads: int = 8,
        dim_head: int = 32,
        n_filt_per_layer: List[int] = [],
        n_blocks_per_layer: int = 2,
        n_atoms: int = 37,
        channels_per_atom: int = 6,
        attn_bias_dim: int = None,
        time_cond_dim: int = None,
        conv_skip_connection: bool = False,
        position_embedding_type: str = "rotary",
    ):
        super().__init__()

        # Initialize configuration params
        if time_cond_dim is None:
            time_cond_dim = dim * 4
        self.position_embedding_type = position_embedding_type
        channels = channels_per_atom
        self.n_conv_layers = n_conv_layers = len(n_filt_per_layer)
        if n_conv_layers > 0:
            post_conv_filt = n_filt_per_layer[-1]
        self.conv_skip_connection = conv_skip_connection and n_conv_layers == 1
        transformer_seq_len = seq_len // (2**n_conv_layers)
        assert transformer_seq_len % patch_size == 0
        num_patches = transformer_seq_len // patch_size
        dim_a = post_conv_atom_dim = max(1, n_atoms // (2 ** (n_conv_layers - 1)))
        if n_conv_layers == 0:
            patch_dim = patch_size * n_atoms * channels_per_atom
            patch_dim_out = patch_size * n_atoms * 3
            dim_a = n_atoms
        elif conv_skip_connection and n_conv_layers == 1:
            patch_dim = patch_size * (channels + post_conv_filt) * post_conv_atom_dim
            patch_dim_out = patch_size * post_conv_filt * post_conv_atom_dim
        elif n_conv_layers > 0:
            patch_dim = patch_dim_out = patch_size * post_conv_filt * post_conv_atom_dim

        # Make downsampling conv
        # Downsamples n-1 times where n is n_conv_layers
        down_conv = []
        block_in = channels
        for i, nf in enumerate(n_filt_per_layer):
            block_out = nf
            layer = []
            for j in range(n_blocks_per_layer):
                n_groups = 2 if i == 0 and j == 0 else 4
                layer.append(
                    TimeCondResnetBlock(
                        block_in, block_out, time_cond_dim, n_norm_in_groups=n_groups
                    )
                )
                block_in = block_out
            down_conv.append(nn.ModuleList(layer))
        self.down_conv = nn.ModuleList(down_conv)

        # Make transformer
        self.to_patch_embedding = nn.Sequential(
            Rearrange("b c (n p) a -> b n (p c a)", p=patch_size),
            nn.Linear(patch_dim, dim),
            LayerNorm(dim),
        )
        self.transformer = TimeCondTransformer(
            dim,
            depth,
            heads,
            dim_head,
            time_cond_dim,
            attn_bias_dim=attn_bias_dim,
            position_embedding_type=position_embedding_type,
        )
        self.from_patch = nn.Sequential(
            LayerNorm(dim),
            nn.Linear(dim, patch_dim_out),
            Rearrange("b n (p c a) -> b c (n p) a", p=patch_size, a=dim_a),
        )
        nn.init.zeros_(self.from_patch[-2].weight)
        nn.init.zeros_(self.from_patch[-2].bias)

        # Make upsampling conv
        up_conv = []
        for i, nf in enumerate(reversed(n_filt_per_layer)):
            skip_in = nf
            block_out = nf
            layer = []
            for j in range(n_blocks_per_layer):
                layer.append(
                    TimeCondResnetBlock(block_in + skip_in, block_out, time_cond_dim)
                )
                block_in = block_out
            up_conv.append(nn.ModuleList(layer))
        self.up_conv = nn.ModuleList(up_conv)

        # Conv out
        if n_conv_layers > 0:
            self.conv_out = nn.Sequential(
                nn.GroupNorm(num_groups=block_out // 4, num_channels=block_out),
                nn.SiLU(),
                nn.Conv2d(block_out, channels // 2, 3, 1, 1),
            )

    def forward(
        self, coords, time_cond, pair_bias=None, seq_mask=None, residue_index=None
    ):
        if self.n_conv_layers > 0:  # pad up to even dims
            coords = F.pad(coords, (0, 0, 0, 0, 0, 1, 0, 0))

        x = rearr_coords = rearrange(coords, "b n a c -> b c n a")
        hiddens = []
        for i, layer in enumerate(self.down_conv):
            for block in layer:
                x = block(x, time=time_cond)
                hiddens.append(x)
            if i != self.n_conv_layers - 1:
                x = downsample(x)

        if self.conv_skip_connection:
            x = torch.cat([x, rearr_coords], 1)

        x = self.to_patch_embedding(x)
        # if self.position_embedding_type == 'absolute':
        #     pos_emb = posemb_sincos_1d(x)
        #     x = x + pos_emb
        if seq_mask is not None and x.shape[1] == seq_mask.shape[1]:
            x *= seq_mask[..., None]
        x = self.transformer(
            x,
            time=time_cond,
            attn_bias=pair_bias,
            seq_mask=seq_mask,
            residue_index=residue_index,
        )
        x = self.from_patch(x)

        for i, layer in enumerate(self.up_conv):
            for block in layer:
                x = torch.cat([x, hiddens.pop()], 1)
                x = block(x, time=time_cond)
            if i != self.n_conv_layers - 1:
                x = upsample_coords(x, hiddens[-1].shape[2:])

        if self.n_conv_layers > 0:
            x = self.conv_out(x)
            x = x[..., :-1, :]  # drop even-dims padding

        x = rearrange(x, "b c n a -> b n a c")
        return x


########################################


class LinearWarmupCosineDecay(torch.optim.lr_scheduler._LRScheduler):
    def __init__(
        self,
        optimizer,
        max_lr,
        warmup_steps=1000,
        decay_steps=int(1e6),
        min_lr=1e-6,
        **kwargs,
    ):
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.warmup_steps = warmup_steps
        self.decay_steps = decay_steps
        self.total_steps = warmup_steps + decay_steps
        super(LinearWarmupCosineDecay, self).__init__(optimizer, **kwargs)

    def get_lr(self):
        # TODO double check for off-by-one errors
        if self.last_epoch < self.warmup_steps:
            curr_lr = self.last_epoch / self.warmup_steps * self.max_lr
            return [curr_lr for group in self.optimizer.param_groups]
        elif self.last_epoch < self.total_steps:
            time = (self.last_epoch - self.warmup_steps) / self.decay_steps * np.pi
            curr_lr = self.min_lr + (self.max_lr - self.min_lr) * 0.5 * (
                1 + np.cos(time)
            )
            return [curr_lr for group in self.optimizer.param_groups]
        else:
            return [self.min_lr for group in self.optimizer.param_groups]

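# Editor's note (illustrative, not part of the original code): for
# LinearWarmupCosineDecay above with max_lr=1e-3, warmup_steps=1000, and
# decay_steps=1e6, step 500 yields lr = 500/1000 * 1e-3 = 5e-4; halfway
# through the decay the cosine term is 0, giving lr = min_lr +
# (max_lr - min_lr) / 2; after warmup + decay steps the schedule holds at
# min_lr.
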
class NoiseConditionalProteinMPNN(nn.Module):
    def __init__(
        self,
        n_channel=128,
        n_layers=3,
        n_neighbors=32,
        time_cond_dim=None,
        vocab_size=21,
        input_S_is_embeddings=False,
    ):
        super().__init__()
        self.n_channel = n_channel
        self.n_layers = n_layers
        self.n_neighbors = n_neighbors
        self.time_cond_dim = time_cond_dim
        self.vocab_size = vocab_size
        self.bb_idxs_if_atom37 = [
            residue_constants.atom_order[a] for a in ["N", "CA", "C", "O"]
        ]

        self.mpnn = protein_mpnn.ProteinMPNN(
            num_letters=vocab_size,
            node_features=n_channel,
            edge_features=n_channel,
            hidden_dim=n_channel,
            num_encoder_layers=n_layers,
            num_decoder_layers=n_layers,
            vocab=vocab_size,
            k_neighbors=n_neighbors,
            augment_eps=0.0,
            dropout=0.1,
            ca_only=False,
            time_cond_dim=time_cond_dim,
            input_S_is_embeddings=input_S_is_embeddings,
        )

    def forward(
        self, denoised_coords, noisy_aatype, seq_mask, residue_index, time_cond
    ):
        if denoised_coords.shape[-2] == 37:
            denoised_coords = denoised_coords[:, :, self.bb_idxs_if_atom37]

        node_embs, encoder_embs = self.mpnn(
            X=denoised_coords,
            S=noisy_aatype,
            mask=seq_mask,
            chain_M=seq_mask,
            residue_idx=residue_index,
            chain_encoding_all=seq_mask,
            randn=None,
            use_input_decoding_order=False,
            decoding_order=None,
            causal_mask=False,
            time_cond=time_cond,
            return_node_embs=True,
        )
        return node_embs, encoder_embs
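Editor's sketch (not part of this commit): a quick shape check for TimeCondUViT with no conv layers, assuming the repo's dependencies and its core package are importable from the repo root; all names and values below are illustrative.

import torch
from modules import NoiseConditioningBlock, TimeCondUViT

net = TimeCondUViT(seq_len=64, dim=128, depth=2, heads=4, dim_head=32, n_atoms=37)
noise_cond = NoiseConditioningBlock(128, 128 * 4)  # time_cond_dim defaults to dim * 4

coords = torch.randn(1, 64, 37, 6)            # (batch, length, atoms, channels_per_atom)
time_cond = noise_cond(torch.tensor([10.0]))  # (1, 1, 512) noise-level embedding
denoised = net(coords, time_cond)
print(denoised.shape)                         # torch.Size([1, 64, 37, 3])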
output_helpers.py ADDED
The diff for this file is too large to render. See raw diff
 
package.txt ADDED
@@ -0,0 +1 @@
dssp
protpardelle_pymol.py ADDED
@@ -0,0 +1,159 @@
from pymol import cmd

import os
import json
import time
import threading

try:
    from gradio_client import Client
except ImportError:
    print("gradio_client not installed, trying install:")
    import pip
    pip.main(['install', 'gradio_client'])
    from gradio_client import Client


if os.environ.get("GRADIO_LOCAL") is not None:
    public_link = "http://127.0.0.1:7862"
else:
    public_link = "spacesplaceholder"


def thread_protpardelle(input_pdb,
                        resample_idxs,
                        modeltype,
                        mode,
                        minlen=50,
                        maxlen=60,
                        steplen=2,
                        per_len=2):
    client = Client(public_link)

    job = client.submit(
        input_pdb,  # str in 'PDB Content' Textbox component
        modeltype,  # str in 'Choose a Mode' Radio component
        f'"{resample_idxs}"',  # str in 'Resampled Idxs' Textbox component
        mode,  # str (Option from: ['backbone', 'allatom'])
        minlen,  # int | float (numeric value between 2 and 200) in 'minlen' Slider component
        maxlen,  # int | float (numeric value between 3 and 200) in 'maxlen' Slider component
        steplen,  # int | float (numeric value between 1 and 50) in 'steplen' Slider component
        per_len,  # int | float (numeric value between 1 and 200) in 'perlen' Slider component
        api_name="/protpardelle"
    )
    # start time
    start = time.time()

    while not job.done():
        status = job.status()
        elapsed = time.time() - start
        # format as hh:mm:ss
        elapsed = time.strftime("%H:%M:%S", time.gmtime(elapsed))

        print(f"\r protpardelle running since {elapsed}", end="")
        time.sleep(1)
    results = job.result()

    # load each result into pymol
    results = json.loads(results)

    for (name, pdb_content) in results:
        print(name)
        cmd.read_pdbstr(pdb_content, os.path.basename(name))


def query_protpardelle(
    name_of_input: str,
    selection_resample_idxs: str = "",
    per_len: int = 2,
    mode: str = "allatom",
):
    """
    AUTHOR
    Simon Duerr
    https://twitter.com/simonduerr
    DESCRIPTION
    Run Protpardelle
    USAGE
    protpardelle name_of_input, selection_resampled_idx, modeltype, mode, per_len
    PARAMETERS
    name_of_input = string: name of input object
    selection_resampled_idx = string: selection of resampled protein residues
    per_len = int: per_len (default: 2)
    mode = string: mode (default: 'allatom')
    """
    if name_of_input != "":
        input_pdb = cmd.get_pdbstr(name_of_input)

        all_aa = cmd.index(name_of_input + " and name CA")
        idx = cmd.index(selection_resample_idxs + " and name CA")

        # map to zero-indexed values
        aa_mapping = {aa[1]: i for i, aa in enumerate(all_aa)}

        idx = ",".join([str(aa_mapping[aa[1]]) for aa in idx])

        print("resampling", idx, "(zero indexed) from", name_of_input)

        t = threading.Thread(target=thread_protpardelle,
                             args=(input_pdb, idx, "conditional", mode),
                             kwargs={'per_len': per_len},
                             daemon=True)
        t.start()


def query_protpardelle_uncond(
    minlen: int = 50,
    maxlen: int = 60,
    steplen: int = 2,
    per_len: int = 2,
    mode: str = "allatom",
):
    """
    AUTHOR
    Simon Duerr
    https://twitter.com/simonduerr
    DESCRIPTION
    Run Protpardelle
    USAGE
    protpardelle_uncond minlen, maxlen, steplen, per_len, mode
    PARAMETERS
    minlen = int: minlen
    maxlen = int: maxlen
    steplen = int: steplen
    per_len = int: per_len
    mode = string: mode (default: 'allatom')
    """
    modeltype = "unconditional"
    idx = None
    input_pdb = None

    t = threading.Thread(target=thread_protpardelle,
                         args=(input_pdb, idx, modeltype, mode),
                         kwargs={'minlen': minlen, 'maxlen': maxlen,
                                 'steplen': steplen, 'per_len': per_len},
                         daemon=True)
    t.start()


def setprotpardellelink(link: str):
    global public_link
    try:
        client = Client(link)
    except Exception:
        print("could not connect to:", link)

    public_link = link


cmd.extend("protpardelle_setlink", setprotpardellelink)

cmd.extend("protpardelle", query_protpardelle)

cmd.extend("protpardelle_uncond", query_protpardelle_uncond)
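Editor's usage sketch (not part of this commit): once this script is loaded in PyMOL, the registered commands can be invoked as below; the link, object, and selection are illustrative.

# In the PyMOL command line:
run protpardelle_pymol.py
protpardelle_setlink http://127.0.0.1:7862
fetch 1ubq
protpardelle 1ubq, 1ubq and resi 10-20, per_len=2, mode=allatom
protpardelle_uncond minlen=50, maxlen=60, steplen=2, per_len=2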
requirements.txt ADDED
@@ -0,0 +1,14 @@
torch==1.12.1+cu116
transformers==4.29.1
einops
tqdm
wandb
rotary-embedding-torch
biopython
scipy
dm-tree
matplotlib
seaborn
black
ipython
--extra-index-url https://download.pytorch.org/whl/cu116
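These pins are installed in the usual way (`pip install -r requirements.txt`); the trailing `--extra-index-url` line lets pip resolve the CUDA 11.6 build of the pinned torch wheel.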
sampling.py ADDED
@@ -0,0 +1,213 @@
"""
https://github.com/ProteinDesignLab/protpardelle
License: MIT
Author: Alex Chu

Configs and convenience functions for wrapping the model sample() function.
"""
import argparse
import time
from typing import Optional, Tuple

import torch
from torchtyping import TensorType

from core import residue_constants
from core import utils
import diffusion


def default_backbone_sampling_config():
    config = argparse.Namespace(
        n_steps=500,
        s_churn=200,
        step_scale=1.2,
        sidechain_mode=False,
        noise_schedule=lambda t: diffusion.noise_schedule(t, s_max=80, s_min=0.001),
    )
    return config


def default_allatom_sampling_config():
    noise_schedule = lambda t: diffusion.noise_schedule(t, s_max=80, s_min=0.001)
    stage2 = argparse.Namespace(
        apply_cond_proportion=1.0,
        n_steps=200,
        s_churn=100,
        step_scale=1.2,
        sidechain_mode=True,
        skip_mpnn_proportion=1.0,
        noise_schedule=noise_schedule,
    )
    config = argparse.Namespace(
        n_steps=500,
        s_churn=200,
        step_scale=1.2,
        sidechain_mode=True,
        skip_mpnn_proportion=0.6,
        use_fullmpnn=False,
        use_fullmpnn_for_final=True,
        anneal_seq_resampling_rate="linear",
        noise_schedule=noise_schedule,
        stage_2=stage2,
    )
    return config

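# Editor's note (illustrative, not part of the original code): the all-atom
# config above encodes the two-stage routine implemented by
# draw_allatom_samples below. Stage 1 runs 500 steps, skipping the mini-MPNN
# sequence head for the first 60% of steps (skip_mpnn_proportion=0.6) and
# designing the final sequence with full ProteinMPNN; stage 2 reruns 200
# steps with skip_mpnn_proportion=1.0 and apply_cond_proportion=1.0,
# conditioning on the stage-1 backbone and sequence to refine sidechain
# placement.
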
def draw_backbone_samples(
    model: torch.nn.Module,
    seq_mask: TensorType["b n", float] = None,
    n_samples: int = None,
    sample_length_range: Tuple[int] = (50, 512),
    pdb_save_path: Optional[str] = None,
    return_aux: bool = False,
    return_sampling_runtime: bool = False,
    **sampling_kwargs,
):
    device = model.device
    if seq_mask is None:
        assert n_samples is not None
        seq_mask = model.make_seq_mask_for_sampling(
            n_samples=n_samples,
            min_len=sample_length_range[0],
            max_len=sample_length_range[1],
        )

    start = time.time()
    aux = model.sample(
        seq_mask=seq_mask, return_last=False, return_aux=True, **sampling_kwargs
    )
    aux["runtime"] = time.time() - start
    seq_lens = seq_mask.sum(-1).long()
    cropped_samp_coords = [
        s[: seq_lens[i], model.bb_idxs] for i, s in enumerate(aux["xt_traj"][-1])
    ]

    if pdb_save_path is not None:
        gly_aatype = (seq_mask * residue_constants.restype_order["G"]).long()
        trimmed_aatype = [a[: seq_lens[i]] for i, a in enumerate(gly_aatype)]
        atom_mask = utils.atom37_mask_from_aatype(gly_aatype, seq_mask).cpu()
        for i in range(len(cropped_samp_coords)):
            utils.write_coords_to_pdb(
                cropped_samp_coords[i],
                f"{pdb_save_path}{i}.pdb",
                batched=False,
                aatype=trimmed_aatype[i],
                atom_mask=atom_mask[i],
            )

    if return_aux:
        return aux
    else:
        if return_sampling_runtime:
            return cropped_samp_coords, seq_mask, aux["runtime"]
        else:
            return cropped_samp_coords, seq_mask


def draw_allatom_samples(
    model: torch.nn.Module,
    seq_mask: TensorType["b n", float] = None,
    n_samples: int = None,
    sample_length_range: Tuple[int] = (50, 512),
    two_stage_sampling: bool = True,
    pdb_save_path: Optional[str] = None,
    return_aux: bool = False,
    return_sampling_runtime: bool = False,
    **sampling_kwargs,
):
    """Implement the default 2-stage all-atom sampling routine."""

    def save_allatom_samples(aux, path):
        seq_lens = aux["seq_mask"].sum(-1).long()
        cropped_samp_coords = [
            c[: seq_lens[i]] for i, c in enumerate(aux["xt_traj"][-1])
        ]
        cropped_samp_aatypes = [
            s[: seq_lens[i]] for i, s in enumerate(aux["st_traj"][-1])
        ]
        samp_atom_mask = utils.atom37_mask_from_aatype(
            aux["st_traj"][-1].to(device), seq_mask
        )
        samp_atom_mask = [m[: seq_lens[i]] for i, m in enumerate(samp_atom_mask)]
        for i, c in enumerate(cropped_samp_coords):
            utils.write_coords_to_pdb(
                c,
                f"{path}{i}.pdb",
                batched=False,
                aatype=cropped_samp_aatypes[i],
                atom_mask=samp_atom_mask[i],
                conect=True,
            )

    device = model.device
    if seq_mask is None:
        assert n_samples is not None
        seq_mask = model.make_seq_mask_for_sampling(
            n_samples=n_samples,
            min_len=sample_length_range[0],
            max_len=sample_length_range[1],
        )
    sampling_runtime = 0.0

    # Stage 1 sampling
    start = time.time()
    if "stage_2" in sampling_kwargs:
        stage_2_kwargs = vars(sampling_kwargs.pop("stage_2"))
    aux = model.sample(
        seq_mask=seq_mask,
        return_last=False,
        return_aux=True,
        **sampling_kwargs,
    )
    sampling_runtime = time.time() - start
    if pdb_save_path is not None and two_stage_sampling:
        save_allatom_samples(aux, pdb_save_path + "_init")

    # Stage 2 sampling (sidechain refinement only)
    if two_stage_sampling:
        samp_seq = aux["st_traj"][-1]
        samp_coords = aux["xt_traj"][-1]
        cond_atom_mask = utils.atom37_mask_from_aatype((seq_mask * 7).long(), seq_mask)
        aux = {f"stage1_{k}": v for k, v in aux.items()}
        start = time.time()
        stage2_aux = model.sample(
            gt_cond_atom_mask=cond_atom_mask.to(device),  # condition on backbone
            gt_cond_seq_mask=seq_mask.to(device),
            gt_coords=samp_coords.to(device),
            gt_aatype=samp_seq.to(device),
            seq_mask=seq_mask,
            return_last=False,
            return_aux=True,
            **stage_2_kwargs,
        )
        sampling_runtime += time.time() - start
        aux = {**aux, **stage2_aux}
    if pdb_save_path is not None:
        save_allatom_samples(aux, pdb_save_path + "_samp")
    aux["runtime"] = sampling_runtime

    # Process outputs, crop to correct length
    if return_aux:
        return aux
    else:
        xt_traj = aux["xt_traj"]
        st_traj = aux["st_traj"]
        seq_mask = aux["seq_mask"]
        seq_lens = seq_mask.sum(-1).long()
        cropped_samp_coords = [c[: seq_lens[i]] for i, c in enumerate(xt_traj[-1])]
        cropped_samp_aatypes = [s[: seq_lens[i]] for i, s in enumerate(st_traj[-1])]
        samp_atom_mask = utils.atom37_mask_from_aatype(st_traj[-1].to(device), seq_mask)
        samp_atom_mask = [m[: seq_lens[i]] for i, m in enumerate(samp_atom_mask)]
        orig_xt_traj = aux["stage1_xt_traj"]
        stage1_coords = [c[: seq_lens[i]] for i, c in enumerate(orig_xt_traj[-1])]
        ret = (
            cropped_samp_coords,
            cropped_samp_aatypes,
            samp_atom_mask,
            stage1_coords,
            seq_mask,
        )
        if return_sampling_runtime:
            ret = ret + (sampling_runtime,)
        return ret
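Editor's usage sketch (not part of this commit): wiring a model through the all-atom helper above. The checkpoint loader is a hypothetical placeholder; everything else mirrors the functions in this file.

import sampling

model = load_protpardelle_checkpoint("path/to/weights.pt")  # hypothetical loader
cfg = sampling.default_allatom_sampling_config()
coords, aatypes, atom_masks, stage1_coords, seq_mask = sampling.draw_allatom_samples(
    model,
    n_samples=2,
    sample_length_range=(50, 100),
    pdb_save_path="samples/allatom",  # writes samples/allatom_samp{i}.pdb
    **vars(cfg),                      # includes the stage_2 Namespace handled above
)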