Simon Duerr commited on
Commit
0605e17
0 Parent(s):

first commit

Browse files
Files changed (5) hide show
  1. README.md +13 -0
  2. app.py +660 -0
  3. packages.txt +1 -0
  4. requirements.txt +5 -0
  5. rosettafold_pymol.py +168 -0
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: RoseTTAfold2
3
+ emoji: 🏢
4
+ colorFrom: indigo
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 3.33.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+
app.py ADDED
@@ -0,0 +1,660 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, time, sys
2
+
3
+
4
+ if not os.path.isfile("RF2_apr23.pt"):
5
+ # send param download into background
6
+ os.system(
7
+ "(apt-get install aria2; aria2c -q -x 16 https://colabfold.steineggerlab.workers.dev/RF2_apr23.pt) &"
8
+ )
9
+
10
+ if not os.path.isdir("RoseTTAFold2"):
11
+ print("install RoseTTAFold2")
12
+ os.system("git clone https://github.com/sokrypton/RoseTTAFold2.git")
13
+ os.system(
14
+ "cd RoseTTAFold2/SE3Transformer; pip -q install --no-cache-dir -r requirements.txt; pip -q install ."
15
+ )
16
+ os.system(
17
+ "wget https://raw.githubusercontent.com/sokrypton/ColabFold/beta/colabfold/mmseqs/api.py"
18
+ )
19
+
20
+ # install hhsuite
21
+ print("install hhsuite")
22
+ os.makedirs("hhsuite", exist_ok=True)
23
+ os.system(
24
+ f"curl -fsSL https://github.com/soedinglab/hh-suite/releases/download/v3.3.0/hhsuite-3.3.0-SSE2-Linux.tar.gz | tar xz -C hhsuite/"
25
+ )
26
+
27
+
28
+ if os.path.isfile(f"RF2_apr23.pt.aria2"):
29
+ print("downloading RoseTTAFold2 params")
30
+ while os.path.isfile(f"RF2_apr23.pt.aria2"):
31
+ time.sleep(5)
32
+
33
+ os.environ["DGLBACKEND"] = "pytorch"
34
+ sys.path.append("RoseTTAFold2/network")
35
+ if "hhsuite" not in os.environ["PATH"]:
36
+ os.environ["PATH"] += ":hhsuite/bin:hhsuite/scripts"
37
+
38
+ import matplotlib.pyplot as plt
39
+ import numpy as np
40
+ from parsers import parse_a3m
41
+ from api import run_mmseqs2
42
+ import py3Dmol
43
+ import torch
44
+ from string import ascii_uppercase, ascii_lowercase
45
+ import hashlib, re, os
46
+ import random
47
+
48
+ from Bio.PDB import *
49
+
50
+
51
+ def get_hash(x):
52
+ return hashlib.sha1(x.encode()).hexdigest()
53
+
54
+
55
+ alphabet_list = list(ascii_uppercase + ascii_lowercase)
56
+ from collections import OrderedDict, Counter
57
+
58
+ import gradio as gr
59
+
60
+ if not "pred" in dir():
61
+ from predict import Predictor
62
+
63
+ print("compile RoseTTAFold2")
64
+ model_params = "RF2_apr23.pt"
65
+ if torch.cuda.is_available():
66
+ pred = Predictor(model_params, torch.device("cuda:0"))
67
+ else:
68
+ print("WARNING: using CPU")
69
+ pred = Predictor(model_params, torch.device("cpu"))
70
+
71
+
72
+ def get_unique_sequences(seq_list):
73
+ unique_seqs = list(OrderedDict.fromkeys(seq_list))
74
+ return unique_seqs
75
+
76
+
77
+ def get_msa(seq, jobname, cov=50, id=90, max_msa=2048, mode="unpaired_paired"):
78
+ assert mode in ["unpaired", "paired", "unpaired_paired"]
79
+ seqs = [seq] if isinstance(seq, str) else seq
80
+
81
+ # collapse homooligomeric sequences
82
+ counts = Counter(seqs)
83
+ u_seqs = list(counts.keys())
84
+ u_nums = list(counts.values())
85
+
86
+ # expand homooligomeric sequences
87
+ first_seq = "/".join(sum([[x] * n for x, n in zip(u_seqs, u_nums)], []))
88
+ msa = [first_seq]
89
+
90
+ path = os.path.join(jobname, "msa")
91
+ os.makedirs(path, exist_ok=True)
92
+ if mode in ["paired", "unpaired_paired"] and len(u_seqs) > 1:
93
+ print("getting paired MSA")
94
+ out_paired = run_mmseqs2(u_seqs, f"{path}/", use_pairing=True)
95
+ headers, sequences = [], []
96
+ for a3m_lines in out_paired:
97
+ n = -1
98
+ for line in a3m_lines.split("\n"):
99
+ if len(line) > 0:
100
+ if line.startswith(">"):
101
+ n += 1
102
+ if len(headers) < (n + 1):
103
+ headers.append([])
104
+ sequences.append([])
105
+ headers[n].append(line)
106
+ else:
107
+ sequences[n].append(line)
108
+ # filter MSA
109
+ with open(f"{path}/paired_in.a3m", "w") as handle:
110
+ for n, sequence in enumerate(sequences):
111
+ handle.write(f">n{n}\n{''.join(sequence)}\n")
112
+ os.system(
113
+ f"hhfilter -i {path}/paired_in.a3m -id {id} -cov {cov} -o {path}/paired_out.a3m"
114
+ )
115
+ with open(f"{path}/paired_out.a3m", "r") as handle:
116
+ for line in handle:
117
+ if line.startswith(">"):
118
+ n = int(line[2:])
119
+ xs = sequences[n]
120
+ # expand homooligomeric sequences
121
+ xs = ["/".join([x] * num) for x, num in zip(xs, u_nums)]
122
+ msa.append("/".join(xs))
123
+
124
+ if len(msa) < max_msa and (
125
+ mode in ["unpaired", "unpaired_paired"] or len(u_seqs) == 1
126
+ ):
127
+ print("getting unpaired MSA")
128
+ out = run_mmseqs2(u_seqs, f"{path}/")
129
+ Ls = [len(seq) for seq in u_seqs]
130
+ sub_idx = []
131
+ sub_msa = []
132
+ sub_msa_num = 0
133
+ for n, a3m_lines in enumerate(out):
134
+ sub_msa.append([])
135
+ with open(f"{path}/in_{n}.a3m", "w") as handle:
136
+ handle.write(a3m_lines)
137
+ # filter
138
+ os.system(
139
+ f"hhfilter -i {path}/in_{n}.a3m -id {id} -cov {cov} -o {path}/out_{n}.a3m"
140
+ )
141
+ with open(f"{path}/out_{n}.a3m", "r") as handle:
142
+ for line in handle:
143
+ if not line.startswith(">"):
144
+ xs = ["-" * l for l in Ls]
145
+ xs[n] = line.rstrip()
146
+ # expand homooligomeric sequences
147
+ xs = ["/".join([x] * num) for x, num in zip(xs, u_nums)]
148
+ sub_msa[-1].append("/".join(xs))
149
+ sub_msa_num += 1
150
+ sub_idx.append(list(range(len(sub_msa[-1]))))
151
+
152
+ while len(msa) < max_msa and sub_msa_num > 0:
153
+ for n in range(len(sub_idx)):
154
+ if len(sub_idx[n]) > 0:
155
+ msa.append(sub_msa[n][sub_idx[n].pop(0)])
156
+ sub_msa_num -= 1
157
+ if len(msa) == max_msa:
158
+ break
159
+
160
+ with open(f"{jobname}/msa.a3m", "w") as handle:
161
+ for n, sequence in enumerate(msa):
162
+ handle.write(f">n{n}\n{sequence}\n")
163
+
164
+
165
+ from Bio.PDB.PDBExceptions import PDBConstructionWarning
166
+ import warnings
167
+ from Bio.PDB import *
168
+ import numpy as np
169
+
170
+
171
+ def add_plddt_to_cif(best_plddts, best_plddt, best_seed, jobname):
172
+ pdb_parser = PDBParser()
173
+ warnings.filterwarnings("ignore", category=PDBConstructionWarning)
174
+ structure = pdb_parser.get_structure(
175
+ "pdb", f"{jobname}/rf2_seed{best_seed}_00_pred.pdb"
176
+ )
177
+ io = MMCIFIO()
178
+ io.set_structure(structure)
179
+ io.save(f"{jobname}/rf2_seed{best_seed}_00_pred.cif")
180
+ plddt_cif = f"""#
181
+ loop_
182
+ _ma_qa_metric.id
183
+ _ma_qa_metric.mode
184
+ _ma_qa_metric.name
185
+ _ma_qa_metric.software_group_id
186
+ _ma_qa_metric.type
187
+ 1 global pLDDT 1 pLDDT
188
+ 2 local pLDDT 1 pLDDT
189
+ #
190
+ _ma_qa_metric_global.metric_id 1
191
+ _ma_qa_metric_global.metric_value {best_plddt:.3f}
192
+ _ma_qa_metric_global.model_id 1
193
+ _ma_qa_metric_global.ordinal_id 1
194
+ #
195
+ loop_
196
+ _ma_qa_metric_local.label_asym_id
197
+ _ma_qa_metric_local.label_comp_id
198
+ _ma_qa_metric_local.label_seq_id
199
+ _ma_qa_metric_local.metric_id
200
+ _ma_qa_metric_local.metric_value
201
+ _ma_qa_metric_local.model_id
202
+ _ma_qa_metric_local.ordinal_id"""
203
+
204
+ for chain in structure[0]:
205
+ for i, residue in enumerate(chain):
206
+ plddt_cif += f"\n{chain.id} {residue.resname} {residue.id[1]} 2 {best_plddts[i]*100:.2f} 1 {residue.id[1]}"
207
+ plddt_cif += "\n#"
208
+ with open(f"{jobname}/rf2_seed{best_seed}_00_pred.cif", "a") as f:
209
+ f.write(plddt_cif)
210
+
211
+
212
+ def predict(
213
+ sequence,
214
+ jobname,
215
+ sym,
216
+ order,
217
+ msa_concat_mode,
218
+ msa_method,
219
+ pair_mode,
220
+ collapse_identical,
221
+ num_recycles,
222
+ use_mlm,
223
+ use_dropout,
224
+ max_msa,
225
+ random_seed,
226
+ num_models,
227
+ mode="web",
228
+ ):
229
+ if not os.path.exists("/home/user/app"): # crude check if on spaces
230
+ if len(sequence) > 600:
231
+ raise gr.Error(
232
+ f"Your sequence is too long ({len(sequence)}). "
233
+ "Please use the full version of RoseTTAfold2 directly from GitHub."
234
+ )
235
+ random_seed = int(random_seed)
236
+ num_models = int(num_models)
237
+ max_msa = int(max_msa)
238
+ num_recycles = int(num_recycles)
239
+ order = int(order)
240
+
241
+ max_extra_msa = max_msa * 8
242
+ sequence = re.sub("[^A-Z:]", "", sequence.replace("/", ":").upper())
243
+ sequence = re.sub(":+", ":", sequence)
244
+ sequence = re.sub("^[:]+", "", sequence)
245
+ sequence = re.sub("[:]+$", "", sequence)
246
+
247
+ if sym in ["X", "C"]:
248
+ copies = int(order)
249
+ elif sym in ["D"]:
250
+ copies = int(order) * 2
251
+ else:
252
+ copies = {"T": 12, "O": 24, "I": 60}[sym]
253
+ order = ""
254
+ symm = sym + str(order)
255
+
256
+ sequences = sequence.replace(":", "/").split("/")
257
+ if collapse_identical:
258
+ u_sequences = get_unique_sequences(sequences)
259
+ else:
260
+ u_sequences = sequences
261
+ sequences = sum([u_sequences] * copies, [])
262
+ lengths = [len(s) for s in sequences]
263
+
264
+ # TODO
265
+ subcrop = 1000 if sum(lengths) > 1400 else -1
266
+
267
+ sequence = "/".join(sequences)
268
+ jobname = jobname + "_" + symm + "_" + get_hash(sequence)[:5]
269
+
270
+ print(f"jobname: {jobname}")
271
+ print(f"lengths: {lengths}")
272
+
273
+ os.makedirs(jobname, exist_ok=True)
274
+ if msa_method == "mmseqs2":
275
+ get_msa(u_sequences, jobname, mode=pair_mode, max_msa=max_extra_msa)
276
+
277
+ elif msa_method == "single_sequence":
278
+ u_sequence = "/".join(u_sequences)
279
+ with open(f"{jobname}/msa.a3m", "w") as a3m:
280
+ a3m.write(f">{jobname}\n{u_sequence}\n")
281
+
282
+ elif msa_method == "custom_a3m":
283
+ print("upload custom a3m")
284
+ # msa_dict = files.upload()
285
+ lines = msa_dict[list(msa_dict.keys())[0]].decode().splitlines()
286
+ a3m_lines = []
287
+ for line in lines:
288
+ line = line.replace("\x00", "")
289
+ if len(line) > 0 and not line.startswith("#"):
290
+ a3m_lines.append(line)
291
+
292
+ with open(f"{jobname}/msa.a3m", "w") as a3m:
293
+ a3m.write("\n".join(a3m_lines))
294
+
295
+ best_plddt = None
296
+ best_seed = None
297
+ for seed in range(int(random_seed), int(random_seed) + int(num_models)):
298
+ torch.manual_seed(seed)
299
+ random.seed(seed)
300
+ np.random.seed(seed)
301
+ npz = f"{jobname}/rf2_seed{seed}_00.npz"
302
+ pred.predict(
303
+ inputs=[f"{jobname}/msa.a3m"],
304
+ out_prefix=f"{jobname}/rf2_seed{seed}",
305
+ symm=symm,
306
+ ffdb=None, # TODO (templates),
307
+ n_recycles=num_recycles,
308
+ msa_mask=0.15 if use_mlm else 0.0,
309
+ msa_concat_mode=msa_concat_mode,
310
+ nseqs=max_msa,
311
+ nseqs_full=max_extra_msa,
312
+ subcrop=subcrop,
313
+ is_training=use_dropout,
314
+ )
315
+ plddt = np.load(npz)["lddt"].mean()
316
+ if best_plddt is None or plddt > best_plddt:
317
+ best_plddt = plddt
318
+ best_plddts = np.load(npz)["lddt"]
319
+ best_seed = seed
320
+
321
+ if mode == "web":
322
+ # Mol* only displays AlphaFold plDDT if they are in a cif.
323
+ pdb_parser = PDBParser()
324
+ mmcif_parser = MMCIFParser()
325
+
326
+ plddt_cif = add_plddt_to_cif(best_plddts, best_plddt, best_seed, jobname)
327
+
328
+ return f"{jobname}/rf2_seed{best_seed}_00_pred.cif"
329
+ else:
330
+ # for api just return a pdb file
331
+ return f"{jobname}/rf2_seed{best_seed}_00_pred.pdb"
332
+
333
+
334
+ def predict_api(
335
+ sequence,
336
+ jobname,
337
+ sym,
338
+ order,
339
+ msa_concat_mode,
340
+ msa_method,
341
+ pair_mode,
342
+ collapse_identical,
343
+ num_recycles,
344
+ use_mlm,
345
+ use_dropout,
346
+ max_msa,
347
+ random_seed,
348
+ num_models,
349
+ ):
350
+ filename = predict(
351
+ sequence,
352
+ jobname,
353
+ sym,
354
+ order,
355
+ msa_concat_mode,
356
+ msa_method,
357
+ pair_mode,
358
+ collapse_identical,
359
+ num_recycles,
360
+ use_mlm,
361
+ use_dropout,
362
+ max_msa,
363
+ random_seed,
364
+ num_models,
365
+ mode="api",
366
+ )
367
+ with open(f"{filename}") as fp:
368
+ return fp.read()
369
+
370
+
371
+ def molecule(input_pdb, public_link):
372
+ print(input_pdb)
373
+ print(public_link + "/file=" + input_pdb)
374
+ link = public_link + "/file=" + input_pdb
375
+ x = (
376
+ """<!DOCTYPE html>
377
+ <html lang="en">
378
+ <head>
379
+ <meta charset="utf-8" />
380
+ <meta name="viewport" content="width=device-width, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0">
381
+ <title>PDBe Molstar - Helper functions</title>
382
+ <!-- Molstar CSS & JS -->
383
+ <link rel="stylesheet" type="text/css" href="https://www.ebi.ac.uk/pdbe/pdb-component-library/css/pdbe-molstar-light-3.1.0.css">
384
+ <script type="text/javascript" src="https://www.ebi.ac.uk/pdbe/pdb-component-library/js/pdbe-molstar-plugin-3.1.0.js"></script>
385
+ <style>
386
+ * {
387
+ margin: 0;
388
+ padding: 0;
389
+ box-sizing: border-box;
390
+ }
391
+ .msp-plugin ::-webkit-scrollbar-thumb {
392
+ background-color: #474748 !important;
393
+ }
394
+ .viewerSection {
395
+ margin: 120px 0 0 0px;
396
+ }
397
+ #myViewer{
398
+ float:left;
399
+ width:100%;
400
+ height: 800px;
401
+ position:relative;
402
+ }
403
+ .btn{
404
+
405
+ font-family: "Open Sans", sans-serif;
406
+ display: inline-block;
407
+ outline: none;
408
+ cursor: pointer;
409
+ font-weight: 600;
410
+ border-radius: 3px;
411
+ padding: 12px 24px;
412
+ border: 0;
413
+ margin:0 10px;
414
+ line-height: 1.15;
415
+ font-size: 16px;
416
+ text-decoration: none;
417
+ }
418
+ .btn-orange{
419
+ background: #ff5000;
420
+ color: #fff;
421
+
422
+ }
423
+ .btn-gray{
424
+ color: #3a4149;
425
+ background: #e7ebee;
426
+
427
+ }
428
+ .btn:hover{
429
+ transition: all .1s ease;
430
+ box-shadow: 0 0 0 0 #fff, 0 0 0 3px #ddd;}
431
+ .text-center{
432
+ display: flex;
433
+ align-items: center;
434
+ justify-content: center;
435
+ padding: 20px 0;
436
+ }
437
+ .flex{
438
+ padding: 10px;
439
+ display: flex;
440
+ align-items: center;
441
+ justify-content: center;
442
+ width:fit-content;
443
+ }
444
+ .flex svg{
445
+ margin-right: 10px;
446
+ width:16px;
447
+ height:16px;
448
+ }
449
+ .flex a{
450
+ margin:0 10px;
451
+ }
452
+
453
+ </style>
454
+ </head>
455
+ <body>
456
+ <div class="text-center">
457
+ <a class="btn btn-orange flex" href=\""""
458
+ + link
459
+ + """\" target="_blank"> <svg fill="none" stroke="currentColor" stroke-width="1.5" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg" aria-hidden="true">
460
+ <path stroke-linecap="round" stroke-linejoin="round" d="M19.5 13.5L12 21m0 0l-7.5-7.5M12 21V3"></path>
461
+ </svg> <span>CIF File</span></a>
462
+ <a class="btn btn-gray flex" href=\""""
463
+ + link.replace(".cif", ".pdb")
464
+ + """\" target="_blank"> <svg fill="none" stroke="currentColor" stroke-width="1.5" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg" aria-hidden="true">
465
+ <path stroke-linecap="round" stroke-linejoin="round" d="M19.5 13.5L12 21m0 0l-7.5-7.5M12 21V3"></path>
466
+ </svg> <span>PDB File</span></a>
467
+
468
+ </div>
469
+ <div class="viewerSection">
470
+ <!-- Molstar container -->
471
+ <div id="myViewer"></div>
472
+
473
+ </div>
474
+ <script>
475
+ //Create plugin instance
476
+ var viewerInstance = new PDBeMolstarPlugin();
477
+
478
+ //Set options (Checkout available options list in the documentation)
479
+ var options = {
480
+ customData: {
481
+ url: \""""
482
+ + link
483
+ + """\",
484
+ format: "cif"
485
+ },
486
+ alphafoldView: true,
487
+ bgColor: {r:255, g:255, b:255},
488
+ //hideCanvasControls: ["selection", "animation", "controlToggle", "controlInfo"]
489
+ }
490
+
491
+ //Get element from HTML/Template to place the viewer
492
+ var viewerContainer = document.getElementById("myViewer");
493
+
494
+ //Call render method to display the 3D view
495
+ viewerInstance.render(viewerContainer, options);
496
+
497
+ </script>
498
+ </body>
499
+ </html>"""
500
+ )
501
+
502
+ return f"""<iframe style="width: 100%; height: 1000px" name="result" allow="midi; geolocation; microphone; camera;
503
+ display-capture; encrypted-media;" sandbox="allow-modals allow-forms
504
+ allow-scripts allow-same-origin allow-popups
505
+ allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
506
+ allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
507
+
508
+
509
+ def predict_web(
510
+ sequence,
511
+ jobname,
512
+ sym,
513
+ order,
514
+ msa_concat_mode,
515
+ msa_method,
516
+ pair_mode,
517
+ collapse_identical,
518
+ num_recycles,
519
+ use_mlm,
520
+ use_dropout,
521
+ max_msa,
522
+ random_seed,
523
+ num_models,
524
+ ):
525
+ if os.path.exists("/home/user/app"):
526
+ public_link = "https://simonduerr-rosettafold2.hf.space/"
527
+ else:
528
+ public_link = "http://localhost:7860"
529
+
530
+ filename = predict(
531
+ sequence,
532
+ jobname,
533
+ sym,
534
+ order,
535
+ msa_concat_mode,
536
+ msa_method,
537
+ pair_mode,
538
+ collapse_identical,
539
+ num_recycles,
540
+ use_mlm,
541
+ use_dropout,
542
+ max_msa,
543
+ random_seed,
544
+ num_models,
545
+ mode="web",
546
+ )
547
+
548
+ return molecule(filename, public_link)
549
+
550
+
551
+ with gr.Blocks() as rosettafold:
552
+ gr.Markdown("# RoseTTAFold2")
553
+ gr.Markdown(
554
+ """If using please cite: [manuscript](https://www.biorxiv.org/content/10.1101/2023.05.24.542179v1)
555
+ <br> Heavily based on [RoseTTAFold2 ColabFold notebook](https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/RoseTTAFold2.ipynb)"""
556
+ )
557
+ with gr.Accordion("How to use in PyMol", open=False):
558
+ gr.Markdown(
559
+ """```os.system('wget https://huggingface.co/spaces/simonduerr/rosettafold2/raw/main/rosettafold_pymol.py')
560
+ run rosettafold_pymol.py
561
+ rosettafold2 sequence, jobname, [sym, order, msa_concat_mode, msa_method, pair_mode, collapse_identical, num_recycles, use_mlm, use_dropout, max_msa, random_seed, num_models]
562
+ color_plddt jobname ```
563
+ """
564
+ )
565
+ sequence = gr.Textbox(
566
+ label="sequence",
567
+ value="PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK",
568
+ )
569
+ jobname = gr.Textbox(label="jobname", value="test")
570
+
571
+ with gr.Accordion("Additional settings", open=False):
572
+ sym = gr.Textbox(label="sym", value="X")
573
+ order = gr.Slider(label="order", value=1, step=1, minimum=1, maximum=12)
574
+ msa_concat_mode = gr.Dropdown(
575
+ label="msa_concat_mode",
576
+ value="default",
577
+ choices=["diag", "repeat", "default"],
578
+ )
579
+
580
+ msa_method = gr.Dropdown(
581
+ label="msa_method",
582
+ value="single_sequence",
583
+ choices=[
584
+ "mmseqs2",
585
+ "single_sequence",
586
+ ], # dont allow custom a3m for now , "custom_a3m"
587
+ )
588
+ pair_mode = gr.Dropdown(
589
+ label="pair_mode",
590
+ value="unpaired_paired",
591
+ choices=["unpaired_paired", "paired", "unpaired"],
592
+ )
593
+
594
+ num_recycles = gr.Dropdown(
595
+ label="num_recycles", value="6", choices=["0", "1", "3", "6", "12", "24"]
596
+ )
597
+
598
+ use_mlm = gr.Checkbox(label="use_mlm", value=False)
599
+ use_dropout = gr.Checkbox(label="use_dropout", value=False)
600
+ collapse_identical = gr.Checkbox(label="collapse_identical", value=False)
601
+ max_msa = gr.Dropdown(
602
+ choices=["16", "32", "64", "128", "256", "512"],
603
+ value="16",
604
+ label="max_msa",
605
+ )
606
+ random_seed = gr.Textbox(label="random_seed", value=0)
607
+ num_models = gr.Dropdown(
608
+ label="num_models", value="1", choices=["1", "2", "4", "8", "16", "32"]
609
+ )
610
+
611
+ btn = gr.Button("Run", visible=False)
612
+ btn_web = gr.Button("Run")
613
+
614
+ output_plain = gr.HTML()
615
+ output = gr.HTML()
616
+
617
+ btn.click(
618
+ fn=predict_api,
619
+ inputs=[
620
+ sequence,
621
+ jobname,
622
+ sym,
623
+ order,
624
+ msa_concat_mode,
625
+ msa_method,
626
+ pair_mode,
627
+ collapse_identical,
628
+ num_recycles,
629
+ use_mlm,
630
+ use_dropout,
631
+ max_msa,
632
+ random_seed,
633
+ num_models,
634
+ ],
635
+ outputs=output_plain,
636
+ api_name="rosettafold2",
637
+ )
638
+ btn_web.click(
639
+ fn=predict_web,
640
+ inputs=[
641
+ sequence,
642
+ jobname,
643
+ sym,
644
+ order,
645
+ msa_concat_mode,
646
+ msa_method,
647
+ pair_mode,
648
+ collapse_identical,
649
+ num_recycles,
650
+ use_mlm,
651
+ use_dropout,
652
+ max_msa,
653
+ random_seed,
654
+ num_models,
655
+ ],
656
+ outputs=output,
657
+ )
658
+
659
+
660
+ rosettafold.launch(share=True, debug=True)
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ aria2
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ dgl==1.0.2+cu116
2
+ matplotlib
3
+ numpy
4
+ torch
5
+ -f https://data.dgl.ai/wheels/cu116/repo.html
rosettafold_pymol.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pymol import cmd
2
+ import requests
3
+
4
+
5
+ # from gradio_client import Client
6
+
7
+
8
+ def color_plddt(selection="all"):
9
+ """
10
+ AUTHOR
11
+ Jinyuan Sun
12
+ https://github.com/JinyuanSun/PymolFold/tree/main
13
+ MIT License
14
+
15
+ DESCRIPTION
16
+ Colors Predicted Structures by pLDDT
17
+
18
+ USAGE
19
+ color_plddt sele
20
+
21
+ PARAMETERS
22
+
23
+ sele (string)
24
+ The name of the selection/object to color by pLDDT. Default: all
25
+ """
26
+ # Alphafold color scheme for plddt
27
+ cmd.set_color("high_lddt_c", [0, 0.325490196078431, 0.843137254901961])
28
+ cmd.set_color(
29
+ "normal_lddt_c", [0.341176470588235, 0.792156862745098, 0.976470588235294]
30
+ )
31
+ cmd.set_color("medium_lddt_c", [1, 0.858823529411765, 0.070588235294118])
32
+ cmd.set_color("low_lddt_c", [1, 0.494117647058824, 0.270588235294118])
33
+
34
+ # test the scale of predicted_lddt (0~1 or 0~100 ) as b-factors
35
+ cmd.select("test_b_scale", f"b>1 and ({selection})")
36
+ b_scale = cmd.count_atoms("test_b_scale")
37
+
38
+ if b_scale > 0:
39
+ cmd.select("high_lddt", f"({selection}) and (b >90 or b =90)")
40
+ cmd.select("normal_lddt", f"({selection}) and ((b <90 and b >70) or (b =70))")
41
+ cmd.select("medium_lddt", f"({selection}) and ((b <70 and b >50) or (b=50))")
42
+ cmd.select("low_lddt", f"({selection}) and ((b <50 and b >0 ) or (b=0))")
43
+ else:
44
+ cmd.select("high_lddt", f"({selection}) and (b >.90 or b =.90)")
45
+ cmd.select(
46
+ "normal_lddt", f"({selection}) and ((b <.90 and b >.70) or (b =.70))"
47
+ )
48
+ cmd.select("medium_lddt", f"({selection}) and ((b <.70 and b >.50) or (b=.50))")
49
+ cmd.select("low_lddt", f"({selection}) and ((b <.50 and b >0 ) or (b=0))")
50
+
51
+ cmd.delete("test_b_scale")
52
+
53
+ # set color based on plddt values
54
+ cmd.color("high_lddt_c", "high_lddt")
55
+ cmd.color("normal_lddt_c", "normal_lddt")
56
+ cmd.color("medium_lddt_c", "medium_lddt")
57
+ cmd.color("low_lddt_c", "low_lddt")
58
+
59
+ # set background color
60
+ cmd.bg_color("white")
61
+
62
+
63
+ def query_rosettafold2(
64
+ sequence: str,
65
+ jobname: str,
66
+ sym: str = "X",
67
+ order: int = 1,
68
+ msa_concat_mode: str = "diag",
69
+ msa_method: str = "single_sequence",
70
+ pair_mode: str = "unpaired_paired",
71
+ collapse_identical: bool = True,
72
+ num_recycles: int = 0,
73
+ use_mlm: bool = True,
74
+ use_dropout: bool = True,
75
+ max_msa: int = 16,
76
+ random_seed: int = 0,
77
+ num_models: int = 0,
78
+ ):
79
+ """
80
+ AUTHOR
81
+ Simon Duerr
82
+ https://twitter.com/simonduerr
83
+
84
+
85
+ DESCRIPTION
86
+ Predict a structure using rosettafold2
87
+
88
+ USAGE
89
+ rosettafold2 sequence, jobname, [sym, order, msa_concat_mode, msa_method, pair_mode, collapse_identical, num_recycles, use_mlm, use_dropout, max_msa, random_seed, num_models]
90
+
91
+ PARAMETERS
92
+
93
+ sequence: (string)
94
+ one letter amino acid codes that you want to predict
95
+
96
+ jobname: string
97
+ name of the pdbfile that will be outputted
98
+
99
+ sym: string
100
+ symmetry Default: X
101
+
102
+ order:
103
+ Default 1,
104
+
105
+ msa_concat_mode:
106
+ MSA concatenation mode Default:"diag" Options: "diag", "repeat", "default"
107
+
108
+ msa_method:
109
+ MSA method Default:"single_sequence" Options: "mmseqs2", "single_sequence"
110
+
111
+ pair_mode:
112
+ Pair mode Default:"unpaired_paired" Options: "unpaired_paired", "paired", "unpaired"
113
+
114
+ collapse_identical:
115
+ Collapse identical sequences Default:True
116
+
117
+ num_recycles:
118
+ Number of recycles Default:0 Options: 0, 1, 3, 6, 12, 24
119
+
120
+ use_mlm:
121
+ Use MLM Default:True
122
+
123
+ use_dropout:
124
+ Use dropout Default:True
125
+
126
+ max_msa:
127
+ Max MSA Default:16
128
+
129
+ random_seed:
130
+ Random seed Default:0
131
+
132
+ num_models:
133
+ Number of models Default:0
134
+ """
135
+ response = requests.post(
136
+ "http://localhost:7860/run/rosettafold2/",
137
+ json={
138
+ "data": [
139
+ sequence, # str in 'sequence' Textbox component
140
+ jobname, # str in 'jobname' Textbox component
141
+ sym, # str in 'sym' Textbox component
142
+ order, # int | float (numeric value between 1 and 12) in 'order' Slider component
143
+ "diag", # str (Option from: ['diag', 'repeat', 'default']) in 'msa_concat_mode' Dropdown component
144
+ "single_sequence", # str (Option from: ['mmseqs2', 'single_sequence', 'custom_a3m']) in 'msa_method' Dropdown component
145
+ "unpaired_paired", # str (Option from: ['unpaired_paired', 'paired', 'unpaired']) in 'pair_mode' Dropdown component
146
+ True, # bool in 'collapse_identical' Checkbox component
147
+ 0, # int (Option from: ['0', '1', '3', '6', '12', '24']) in 'num_recycles' Dropdown component
148
+ True, # bool in 'use_mlm' Checkbox component
149
+ True, # bool in 'use_dropout' Checkbox component
150
+ 16, # int (Option from: ['16', '32', '64', '128', '256', '512']) in 'max_msa' Dropdown component
151
+ 0, # int in 'random_seed' Textbox component
152
+ 1, # int (Option from: ['1', '2', '4', '8', '16', '32']) in 'num_models' Dropdown component
153
+ ]
154
+ },
155
+ ).json()
156
+ print(response)
157
+ try:
158
+ data = response["data"]
159
+ except KeyError:
160
+ print(response["error"])
161
+ return None
162
+ with open(f"{jobname}.pdb", "w") as out:
163
+ out.writelines(data)
164
+ cmd.load(f"{jobname}.pdb")
165
+
166
+
167
+ cmd.extend("rosettafold2", query_rosettafold2)
168
+ cmd.extend("color_plddt", color_plddt)