# rosettafold2 / rosettafold_pymol.py
# Last commit by simonduerr: "Update rosettafold_pymol.py" (2cd9afa)
from pymol import cmd
import requests
import os
import time
# from gradio_client import Client
def color_plddt(selection="all"):
    """
    AUTHOR
    Jinyuan Sun
    https://github.com/JinyuanSun/PymolFold/tree/main
    MIT License
    DESCRIPTION
    Colors Predicted Structures by pLDDT (AlphaFold colour scheme).
    The pLDDT is read from the B-factor column. If any selected atom has
    b > 1 the values are treated as 0-100, otherwise as 0-1. Atoms are
    binned into high [90, ...), normal [70, 90), medium [50, 70) and
    low [0, 50) confidence (thresholds divided by 100 on the 0-1 scale).
    The named selections high_lddt, normal_lddt, medium_lddt and
    low_lddt are kept afterwards. Also sets the background to white.
    USAGE
    color_plddt [selection]
    PARAMETERS
    selection (string)
    The name of the selection/object to color by pLDDT. Default: all
    """
    # AlphaFold colour scheme for pLDDT, highest band first:
    # (colour name, RGB, selection name, lower bound on the 0-100 scale).
    bands = (
        ("high_lddt_c", [0, 0.325490196078431, 0.843137254901961],
         "high_lddt", 90),
        ("normal_lddt_c", [0.341176470588235, 0.792156862745098, 0.976470588235294],
         "normal_lddt", 70),
        ("medium_lddt_c", [1, 0.858823529411765, 0.070588235294118],
         "medium_lddt", 50),
        ("low_lddt_c", [1, 0.494117647058824, 0.270588235294118],
         "low_lddt", 0),
    )

    # Test the scale of predicted_lddt (0~1 or 0~100) stored as b-factors.
    # Count straight from the expression rather than via a temporary named
    # selection, which would overwrite and then delete a user selection
    # that happened to share its name.
    percent_scale = cmd.count_atoms(f"b>1 and ({selection})") > 0
    divisor = 1 if percent_scale else 100

    upper = None  # lower bound of the previous (higher) band, as text
    for color_name, rgb, sele_name, lower_pct in bands:
        # ":g" renders 90.0 as "90" and 0.9 as "0.9", matching the
        # thresholds PyMOL's selection language is given.
        lower = f"{lower_pct / divisor:g}"
        if upper is None:
            condition = f"(b >{lower} or b ={lower})"
        else:
            condition = f"((b <{upper} and b >{lower}) or (b ={lower}))"
        cmd.set_color(color_name, rgb)
        cmd.select(sele_name, f"({selection}) and {condition}")
        cmd.color(color_name, sele_name)
        upper = lower

    # set background color
    cmd.bg_color("white")
_LOCAL_SERVER_URL = "http://localhost:7860"
_REMOTE_SERVER_URL = "https://simonduerr-rosettafold2.hf.space"
# Seconds to wait for the local server to answer before falling back.
_PROBE_TIMEOUT_S = 5
_TRUE_STRINGS = {"1", "true", "yes", "on"}
_FALSE_STRINGS = {"0", "false", "no", "off", ""}


def _as_bool(value):
    """Convert a command-line string such as "False" to a bool; pass other values through.

    Arguments typed on the PyMOL command line arrive as strings, and any
    non-empty string (including "False") would otherwise be truthy.

    Raises:
        ValueError: if a string is not a recognised boolean spelling.
    """
    if not isinstance(value, str):
        return value
    text = value.strip().lower()
    if text in _TRUE_STRINGS:
        return True
    if text in _FALSE_STRINGS:
        return False
    raise ValueError(f"cannot interpret {value!r} as a boolean")


def _as_int(value):
    """Convert a command-line string such as "3" to an int; pass other values through.

    Raises:
        ValueError: if a string is not an integer literal.
    """
    if isinstance(value, str):
        return int(value)
    return value


def _find_server_url():
    """Return the local server URL if it answers within the probe timeout, else the public Space."""
    try:
        requests.get(_LOCAL_SERVER_URL, timeout=_PROBE_TIMEOUT_S)
    except requests.exceptions.RequestException:
        # Refused connection, timeout, etc.: no usable local server.
        return _REMOTE_SERVER_URL
    return _LOCAL_SERVER_URL


def query_rosettafold2(
    sequence: str,
    jobname: str,
    sym: str = "X",
    order: int = 1,
    msa_concat_mode: str = "diag",
    msa_method: str = "single_sequence",
    pair_mode: str = "unpaired_paired",
    collapse_identical: bool = False,
    num_recycles: int = 1,
    use_mlm: bool = False,
    use_dropout: bool = False,
    max_msa: int = 16,
    random_seed: int = 0,
    num_models: int = 1,
):
    """
    AUTHOR
    Simon Duerr
    https://twitter.com/simonduerr
    DESCRIPTION
    Predict a structure using rosettafold2.
    Uses a server on http://localhost:7860 if one answers, otherwise the
    public Hugging Face Space. The result is written to <jobname>.pdb in
    the current working directory and loaded into PyMOL. Boolean and
    integer arguments given as strings (as typed on the PyMOL command
    line) are converted before sending. Returns None; on failure an
    error message is printed and nothing is written or loaded.
    USAGE
    rosettafold2 sequence, jobname, [sym, order, msa_concat_mode, msa_method, pair_mode, collapse_identical, num_recycles, use_mlm, use_dropout, max_msa, random_seed, num_models]
    PARAMETERS
    sequence: (string)
    one letter amino acid codes that you want to predict
    jobname: string
    name of the pdbfile that will be outputted (".pdb" is appended)
    sym: string
    symmetry Default: X
    order:
    Integer between 1 and 12 Default: 1
    msa_concat_mode:
    MSA concatenation mode Default:"diag" Options: "diag", "repeat", "default"
    msa_method:
    MSA method Default:"single_sequence" Options: "mmseqs2", "single_sequence"
    pair_mode:
    Pair mode Default:"unpaired_paired" Options: "unpaired_paired", "paired", "unpaired"
    collapse_identical:
    Collapse identical sequences Default:False
    num_recycles:
    Number of recycles Default:1 Options: 0, 1, 3, 6, 12, 24
    use_mlm:
    Use MLM Default:False
    use_dropout:
    Use dropout Default:False
    max_msa:
    Max MSA Default:16 Options: 16, 32, 64, 128, 256, 512
    random_seed:
    Random seed Default:0
    num_models:
    Number of models Default:1 Options: 1, 2, 4, 8, 16, 32
    RAISES
    ValueError if a string argument cannot be converted to bool/int.
    """
    # Convert first so bad input fails before any network traffic.
    payload = [
        sequence,                       # 'sequence' Textbox
        jobname,                        # 'jobname' Textbox
        sym,                            # 'sym' Textbox
        _as_int(order),                 # 'order' Slider (1-12)
        msa_concat_mode,                # 'msa_concat_mode' Dropdown
        msa_method,                     # 'msa_method' Dropdown
        pair_mode,                      # 'pair_mode' Dropdown
        _as_bool(collapse_identical),   # 'collapse_identical' Checkbox
        _as_int(num_recycles),          # 'num_recycles' Dropdown
        _as_bool(use_mlm),              # 'use_mlm' Checkbox
        _as_bool(use_dropout),          # 'use_dropout' Checkbox
        _as_int(max_msa),               # 'max_msa' Dropdown
        _as_int(random_seed),           # 'random_seed' Textbox
        _as_int(num_models),            # 'num_models' Dropdown
    ]

    # check if server is running on localhost
    url = _find_server_url()
    print(f'querying {url}')

    # No timeout on purpose: a prediction may run for a long time
    # (NOTE(review): add one if hanging requests become a problem).
    reply = requests.post(url + "/run/rosettafold2/", json={"data": payload})
    try:
        response = reply.json()
    except ValueError:
        print(f"rosettafold2: server returned a non-JSON reply (HTTP {reply.status_code})")
        return None

    data = response.get("data") if isinstance(response, dict) else None
    if data is None:
        error = response.get("error") if isinstance(response, dict) else response
        print(f"rosettafold2: prediction failed: {error}")
        return None

    pdb_path = f"{jobname}.pdb"
    with open(pdb_path, "w", encoding="utf-8") as out:
        out.writelines(data)
    cmd.load(pdb_path)
# Register the functions above as PyMOL commands, in this order:
# "rosettafold2" -> query_rosettafold2, then "color_plddt" -> color_plddt.
for _command_name, _command_func in (
    ("rosettafold2", query_rosettafold2),
    ("color_plddt", color_plddt),
):
    cmd.extend(_command_name, _command_func)
# Do not leave the loop variables behind as module globals.
del _command_name, _command_func