# protpardelle / app.py
# (Page-scrape residue kept as a comment: "simonduerr's picture", "Update app.py", commit 53da8f8)
import gradio as gr
import re
import urllib
import tempfile
from output_helpers import viewer_html, output_html, load_js, get_js
import json
import os
import shlex
import subprocess
from datetime import datetime
from einops import repeat
import torch
from core import data
from core import utils
import models
import sampling
# draw_and_save_samples and parse_idx_string are defined locally below instead of
# being imported from draw_samples:
# from draw_samples import draw_and_save_samples, parse_resample_idx_string
# Log the working directory at startup: the checkpoint and output paths used
# further down ("checkpoints/...", ".") are relative to it.
print("working directory", os.getcwd())
def draw_and_save_samples(
    model,
    samples_per_len=8,
    lengths=range(50, 512),
    save_dir="./",
    mode="backbone",
    **sampling_kwargs,
):
    """Draw ``samples_per_len`` samples at every length and return their PDB paths.

    Args:
        model: Loaded Protpardelle model; must provide ``make_seq_mask_for_sampling``.
        samples_per_len: Number of samples drawn per length (one batch per length).
        lengths: Iterable of sequence lengths to sample at.
        save_dir: Directory the sampler writes PDB files into.
        mode: ``"backbone"`` or ``"allatom"``; selects the sampler.
        **sampling_kwargs: Forwarded unchanged to the sampler.

    Returns:
        List of paths ``{save_dir}/len{LLL}_samp{i}.pdb``, in length order.

    Raises:
        ValueError: If ``mode`` is not ``"backbone"`` or ``"allatom"``.
    """
    if mode not in ("backbone", "allatom"):
        raise ValueError(
            f"Unknown sampling mode {mode!r}; expected 'backbone' or 'allatom'."
        )

    sample_files = []
    for length in lengths:
        prefix = f"{save_dir}/len{length:03d}"
        prot_lens = torch.ones(samples_per_len).long() * length
        seq_mask = model.make_seq_mask_for_sampling(prot_lens=prot_lens)
        if mode == "backbone":
            sampling.draw_backbone_samples(
                model,
                seq_mask=seq_mask,
                pdb_save_path=f"{prefix}_samp",
                return_aux=True,
                return_sampling_runtime=True,
                **sampling_kwargs,
            )
        else:
            # The all-atom sampler gets the bare prefix, yet the paths below use
            # "_samp{i}" -- NOTE(review): assumes the sampler appends "_samp" itself.
            sampling.draw_allatom_samples(
                model,
                seq_mask=seq_mask,
                pdb_save_path=prefix,
                return_aux=True,
                **sampling_kwargs,
            )
        sample_files.extend(f"{prefix}_samp{i}.pdb" for i in range(samples_per_len))
    return sample_files
def parse_idx_string(idx_str):
    """Parse a comma-delimited index string into a flat list of ints.

    Each comma-separated span is either a single integer (``"7"``) or a
    dash-separated range ``"start-stop"``. Ranges are half-open like Python's
    ``range`` (the stop index is excluded), matching the original behaviour.
    Whitespace around numbers is ignored and empty spans (e.g. produced by a
    trailing comma) are skipped.

    >>> parse_idx_string("0,2-5,7")
    [0, 2, 3, 4, 7]
    """
    idxs = []
    for span in idx_str.split(","):
        span = span.strip()
        if not span:
            continue
        if "-" in span:
            start, stop = span.split("-")
            idxs.extend(range(int(start), int(stop)))
        else:
            idxs.append(int(span))
    return idxs
def changemode(m):
    """Toggle component visibility when the sampling-mode radio changes.

    Returns five visibility updates, in the order (unconditional group,
    conditional group, "Run unconditional" button, "Run conditional" button,
    sizes accordion). The unconditional widgets are shown only when
    ``m == "unconditional"``; the conditional widgets are shown otherwise.
    """
    show_uncond = m == "unconditional"
    show_cond = not show_uncond
    return (
        gr.update(visible=show_uncond),
        gr.update(visible=show_cond),
        gr.update(visible=show_uncond),
        gr.update(visible=show_cond),
        gr.update(visible=show_uncond),
    )
def fileselection(val):
    """Switch between typed-ID input and file upload for the structure source.

    Returns visibility updates for (ID textbox, file upload, "Load PDB" button).
    The upload widget is shown only when ``val == "upload"``; the textbox and
    load button are shown for every other source.
    """
    use_upload = val == "upload"
    return (
        gr.update(visible=not use_upload),
        gr.update(visible=use_upload),
        gr.update(visible=not use_upload),
    )
# UniProt accession format (see https://www.uniprot.org/help/accession_numbers),
# matched against the whole stripped, upper-cased input with fullmatch().
UNIPROT_ACCESSION_RE = re.compile(
    r"[OPQ][0-9][A-Z0-9]{3}[0-9]|[A-NR-Z][0-9]([A-Z][A-Z0-9]{2}[0-9]){1,2}"
)


def _structure_viewer_iframe(pdb_path, representations):
    """Return a sandboxed iframe whose ``srcdoc`` is ``viewer_html`` for ``pdb_path``."""
    return f"""<iframe style="width: 100%; height: 930px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{viewer_html(pdb_path, representations=representations)}'></iframe>"""


def update_structuresel(pdb, radio_val):
    """Fetch or accept a structure to condition on and show it in the viewer.

    Args:
        pdb: Depends on ``radio_val``:
            - ``"PDB"``: a 4-character PDB ID.
            - ``"AF2 EBI DB"`` (legacy alias ``"AFDB2"``): a UniProt accession.
            - ``"upload"``: the ``gr.File`` value (``None`` when the upload is cleared).
        radio_val: The selected structure source.

    Returns:
        Three Gradio updates, in order:
            - the input accordion (closed on success);
            - the hidden path-to-file textbox;
            - the viewer HTML.

    Raises:
        gr.Error: If downloading the structure fails.
    """
    # `import urllib` alone does not guarantee the `request` submodule is loaded.
    import urllib.request

    representations = [{
        "model": 0,
        "chain": "",
        "resname": "",
        "style": "cartoon",
        "color": "whiteCarbon",
        "residue_range": "",
        "around": 0,
        "byres": False,
        "visible": False,
    }]

    def _loaded(path):
        # Close the accordion, remember the path, and render the structure.
        return (
            gr.update(open=False),
            gr.update(value=path),
            gr.update(value=_structure_viewer_iframe(path, representations), visible=True),
        )

    def _download(url):
        # delete=False: the file must outlive this call (its path is later passed
        # to sampling); close our handle immediately so it does not leak.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdb") as tmp:
            path = tmp.name
        try:
            urllib.request.urlretrieve(url, path)
        except OSError as err:  # URLError/HTTPError are OSError subclasses
            os.remove(path)
            raise gr.Error(f"Could not download {url}: {err}") from err
        return path

    if radio_val == "PDB":
        code = (pdb or "").strip()
        if len(code) != 4:
            return gr.update(open=True), gr.update(), gr.update(value="", visible=False)
        return _loaded(_download(f"https://files.rcsb.org/download/{code.lower()}.pdb1"))

    # The radio choice is labelled "AF2 EBI DB"; "AFDB2" is kept for compatibility.
    if radio_val in ("AF2 EBI DB", "AFDB2"):
        accession = (pdb or "").strip().upper()
        if UNIPROT_ACCESSION_RE.fullmatch(accession) is None:
            return (
                gr.update(open=True),
                gr.update(),
                gr.update(value="regex not matched", visible=True),
            )
        # NOTE(review): model_v2 is an old AFDB release; confirm it is still served.
        return _loaded(
            _download(f"https://alphafold.ebi.ac.uk/files/AF-{accession}-F1-model_v2.pdb")
        )

    # "upload": pdb is the uploaded file object, or None when it was cleared.
    if pdb is None:
        return gr.update(open=True), gr.update(), gr.update(value="", visible=False)
    return _loaded(f"{pdb.name}")
from Bio.PDB import PDBParser, cealign
from Bio.PDB.PDBIO import PDBIO
class dotdict(dict):
    """A dict whose keys can also be read, set and deleted as attributes.

    Reading a missing key as an attribute yields ``None`` (``dict.get``
    semantics) rather than raising ``AttributeError``; deleting a missing
    attribute raises ``KeyError``.
    """

    def __getattr__(self, name):
        return self.get(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        del self[name]
def _load_protpardelle_model(model_type, model_checkpoint, model_directory, epoch, mpnn_path, device):
    """Build a Protpardelle model of ``model_type`` and load its weights onto ``device``.

    Weights/config come from ``{model_checkpoint}/{type}_state_dict.pth`` and
    ``{type}_pretrained.yml`` when ``model_checkpoint`` is set, otherwise from the
    training-log directory ``model_directory`` at ``epoch``.

    Returns:
        ``(model, config, checkpoint_path)``.
    """
    if model_checkpoint:
        checkpoint = f"{model_checkpoint}/{model_type}_state_dict.pth"
        cfg_path = f"{model_checkpoint}/{model_type}_pretrained.yml"
    else:
        checkpoint = f"{model_directory}/checkpoints/epoch{epoch}_training_state.pth"
        cfg_path = f"{model_directory}/configs/{model_type}.yml"
    config = utils.load_config(cfg_path)
    weights = torch.load(checkpoint, map_location=device)["model_state_dict"]
    model = models.Protpardelle(config, device=device)
    model.load_state_dict(weights)
    if model_type == "allatom":
        # Only the all-atom path loads the separate miniMPNN weights.
        model.load_minimpnn(mpnn_path)
    model.to(device)
    model.eval()
    model.device = device
    return model, config, checkpoint


def _conditioning_kwargs(input_feats, resample_idxs, samples_per_len, device):
    """Build the ``gt_*`` conditioning kwargs from features loaded from a PDB.

    Positions in ``resample_idxs`` get their atom mask zeroed; every other
    position is set to 1 in ``gt_cond_seq_mask``. Each tensor is repeated along a
    new leading batch axis of size ``samples_per_len`` and moved to ``device``.
    """
    input_pdb_len = input_feats["aatype"].shape[0]
    resample_set = set(resample_idxs)  # O(1) membership for the complement below
    cond_idxs = [i for i in range(input_pdb_len) if i not in resample_set]

    def to_batch_size(x):
        return repeat(x, "... -> b ...", b=samples_per_len).to(device)

    # translation_scale=0.0: no random translation is applied
    # (original note: "center coords on whole structure").
    centered_coords = data.apply_random_se3(
        input_feats["atom_positions"],
        atom_mask=input_feats["atom_mask"],
        translation_scale=0.0,
    )
    cond_kwargs = {
        "gt_coords": to_batch_size(centered_coords),
        "gt_cond_atom_mask": to_batch_size(input_feats["atom_mask"]),
        "gt_aatype": to_batch_size(input_feats["aatype"]),
    }
    cond_kwargs["gt_cond_atom_mask"][:, resample_idxs] = 0
    cond_kwargs["gt_cond_seq_mask"] = torch.zeros_like(cond_kwargs["gt_aatype"])
    cond_kwargs["gt_cond_seq_mask"][:, cond_idxs] = 1
    return cond_kwargs


def protpardelle(path_to_file, m, resample_idx, modeltype, minlen, maxlen, steplen, perlen):
    """Sample structures with a pretrained Protpardelle model; return the PDB paths.

    Args:
        path_to_file: PDB file to condition on (used only when ``m == "conditional"``).
        m: ``"conditional"`` or ``"unconditional"``.
        resample_idx: Index string whose first and last characters are stripped
            before ``parse_idx_string`` (NOTE(review): assumed to be brackets).
            Parsed indices are *resampled*; the remaining positions are held fixed.
            Ignored in unconditional mode.
        modeltype: ``"backbone"`` or ``"allatom"``.
        minlen, maxlen: Inclusive length range (unconditional mode only).
        steplen: Step between sampled lengths.
        perlen: Number of samples per length.

    Returns:
        List of written PDB paths (from ``draw_and_save_samples``).

    Raises:
        ValueError: For an unknown ``modeltype``, ``maxlen < minlen``, or no lengths
            and no input PDB.
    """
    if modeltype not in ("backbone", "allatom"):
        raise ValueError(
            f"Unknown model type {modeltype!r}; expected 'backbone' or 'allatom'."
        )

    # Fixed settings for this demo.
    model_checkpoint = "checkpoints"  # holds {type}_state_dict.pth / {type}_pretrained.yml
    mpnn_path = "checkpoints/minimpnn_state_dict.pth"  # only used for allatom
    model_directory = None  # training-log dir; only used if model_checkpoint is empty
    epoch = None
    base_dir = "."  # outputs go to {base_dir}/samples/<timestamp>
    seed = 0
    device = "cuda:0"

    samples_per_len = int(perlen)
    len_step_size = int(steplen)
    if m == "conditional":
        input_pdb = path_to_file
        resample_idx_str = resample_idx[1:-1]
        # Sampling length is taken from the input PDB instead.
        min_len = max_len = None
    else:
        input_pdb = None
        resample_idx_str = None
        min_len = int(minlen)
        max_len = int(maxlen) + 1  # range() stop is exclusive; make maxlen inclusive
        if max_len <= min_len:
            raise ValueError(f"maxlen ({int(maxlen)}) must be >= minlen ({min_len}).")

    if modeltype == "backbone":
        sampling_config = sampling.default_backbone_sampling_config()
    else:
        sampling_config = sampling.default_allatom_sampling_config()
    # Copy so the conditioning tensors added below do not land in the config object.
    sampling_kwargs = dict(vars(sampling_config))

    # Parse conditioning inputs
    input_pdb_len = None
    if input_pdb:
        input_feats = utils.load_feats_from_pdb(input_pdb, protein_only=True)
        input_pdb_len = input_feats["aatype"].shape[0]
        if resample_idx_str:
            print(
                f"Warning: when sampling conditionally, the input pdb length ({input_pdb_len} residues) is used automatically for the sampling lengths."
            )
            resample_idxs = parse_idx_string(resample_idx_str)
        else:
            # No selection given: resample every position.
            resample_idxs = list(range(input_pdb_len))
        sampling_kwargs.update(
            _conditioning_kwargs(input_feats, resample_idxs, samples_per_len, device)
        )
        print("input_pdb_len", input_pdb_len)

    # Determine lengths to sample at
    if min_len is not None and max_len is not None:
        sampling_lengths = range(min_len, max_len, len_step_size)
    elif input_pdb_len is not None:
        sampling_lengths = [input_pdb_len]
    else:
        raise ValueError("Need to provide a set of protein lengths or an input pdb.")

    date_string = datetime.now().strftime("%y-%m-%d-%H-%M-%S")
    # Snapshot of the sampling params written to run_parameters.txt.
    sampling_kwargs_readme = list(sampling_kwargs.items())
    print("Base directory:", base_dir)
    save_dir = f"{base_dir}/samples/{date_string}"
    save_init_dir = f"{base_dir}/samples_inits/{date_string}"
    # os.makedirs instead of shelling out to `mkdir -p`.
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(save_init_dir, exist_ok=True)
    print("Samples saved to:", save_dir)

    torch.manual_seed(seed)
    model, config, checkpoint = _load_protpardelle_model(
        modeltype, model_checkpoint, model_directory, epoch, mpnn_path, device
    )
    if config.train.home_dir == '':
        config.train.home_dir = os.getcwd()

    with open(f"{save_dir}/run_parameters.txt", "w") as f:
        f.write(f"Sampling run for {date_string}\n")
        f.write(f"Random seed {seed}\n")
        f.write(f"Model checkpoint: {checkpoint}\n")
        f.write(
            f"{samples_per_len} samples per length from {min_len}:{max_len}:{len_step_size}\n"
        )
        f.write("Sampling params:\n")
        for k, v in sampling_kwargs_readme:
            f.write(f"{k}\t{v}\n")
    print(f"Model loaded from {checkpoint}")
    print(f"Beginning sampling for {date_string}...")

    # Draw samples
    return draw_and_save_samples(
        model,
        samples_per_len=samples_per_len,
        lengths=sampling_lengths,
        save_dir=save_dir,
        mode=modeltype,
        **sampling_kwargs,
    )
def api_predict(pdb_content, m, resample_idx, modeltype, minlen, maxlen, steplen, perlen):
    """Named API endpoint: sample designs and return them as a JSON string.

    Args:
        pdb_content: PDB text to condition on (used only when ``m == "conditional"``).
        m, resample_idx, modeltype, minlen, maxlen, steplen, perlen: Forwarded to
            ``protpardelle``.

    Returns:
        JSON array of ``[path, pdb_text]`` pairs, one per design.

    Raises:
        gr.Error: If sampling fails; the message is the underlying error text.
    """
    path_to_file = None
    if m == "conditional":
        # delete=False so protpardelle can reopen it by name; removed in `finally`.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdb") as tmp:
            tmp.write(pdb_content.encode())
            path_to_file = tmp.name
    try:
        designs = protpardelle(path_to_file, m, resample_idx, modeltype, minlen, maxlen, steplen, perlen)
    except Exception as e:
        print(e)
        raise gr.Error(str(e)) from e
    finally:
        # The input PDB is only read during sampling; don't leak temp files.
        if path_to_file is not None:
            os.remove(path_to_file)

    # load each design as string
    results = []
    for design in designs:
        with open(design, "r") as f:
            results.append((design, f.read()))
    return json.dumps(results)
def predict(pdb_radio, path_to_file, m, resample_idx, modeltype, minlen, maxlen, steplen, perlen):
    """Run sampling from the UI and render the designs in the output viewer.

    ``pdb_radio`` is part of the event's input list but is not used. In
    conditional mode each design is CE-aligned onto the input structure
    (``path_to_file``), written next to the design as ``<name>_al.pdb`` and
    annotated with its RMSD; in unconditional mode designs are shown as sampled.

    Returns:
        Gradio updates for (accordion, output HTML).

    Raises:
        gr.Error: If sampling fails; the message is the underlying error text.
    """
    print("running predict")
    try:
        designs = protpardelle(path_to_file, m, resample_idx, modeltype, minlen, maxlen, steplen, perlen)
    except Exception as e:
        print(e)
        raise gr.Error(str(e)) from e

    parser = PDBParser()
    metrics = []
    if m == "conditional":
        aligner = cealign.CEAligner()
        io = PDBIO()
        aligner.set_reference(parser.get_structure("ref", path_to_file))
        aligned_designs = []
        for d in designs:
            design = parser.get_structure("design", d)
            aligner.align(design)
            metrics.append({"rms": f"{aligner.rms:.1f}", "len": len(list(design[0].get_residues()))})
            # Swap only the extension, not every ".pdb" substring in the path.
            aligned_path = os.path.splitext(d)[0] + "_al.pdb"
            io.set_structure(design)
            io.save(aligned_path)
            aligned_designs.append(aligned_path)
    else:
        for d in designs:
            design = parser.get_structure("design", d)
            metrics.append({"len": len(list(design[0].get_residues()))})
        aligned_designs = designs

    output_view = f"""<iframe style="width: 100%; height: 900px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{output_html(path_to_file, aligned_designs, metrics, resample_idx=resample_idx, mode=m)}'></iframe>"""
    return gr.update(open=False), gr.update(value=output_view, visible=True)
# Gradio UI. Nesting below is reconstructed; component order matches the original.
protpardelleDemo = gr.Blocks()
with protpardelleDemo:
    gr.Markdown("# Protpardelle")
    gr.Markdown(""" An all-atom protein generative model
Alexander E. Chu, Lucy Cheng, Gina El Nesr, Minkai Xu, Po-Ssu Huang
doi: https://doi.org/10.1101/2023.05.24.542194""")
    with gr.Accordion(label="Input options", open=True) as input_accordion:
        # Passed to predict/api_predict as `modeltype`.
        model = gr.Dropdown(["backbone", "allatom"], value="allatom", label="What to sample?")
        m = gr.Radio(['unconditional','conditional'],value="unconditional", label="Choose a Mode")
        # Unconditional-mode widgets
        with gr.Group(visible=True) as uncond:
            gr.Markdown("Unconditional Sampling")
            # length = gr.Slider(minimum=0, maximum=200, step=1, value=50, label="length")
            # param = gr.Dropdown(["length", "param"], value="length", label="Which sampling param to vary?")
            # paramval = gr.Dropdown(["nsteps"], label="paramval", info="Which param val to use?")
        # Conditional-mode widgets
        with gr.Group(visible=False) as cond:
            # NOTE(review): this rebinds `input_accordion`. Every later use
            # (btn_load/pdbfile/btn_unconditional/resample_idxs outputs) therefore
            # targets this inner accordion, not "Input options" above. Confirm
            # whether predict was meant to close the outer one.
            with gr.Accordion(label="Structure to condition on", open=True) as input_accordion:
                pdb_radio = gr.Radio(['PDB','AF2 EBI DB', 'upload'],value="PDB", label="source of the structure")
                # NOTE(review): this textbox also takes 4-letter PDB IDs when
                # pdb_radio == "PDB", despite the UniProt-only label.
                pdbcode = gr.Textbox(label="Uniprot code to be retrieved Alphafold2 Database", visible=True)
                pdbfile = gr.File(label="PDB File", visible=False)
                btn_load = gr.Button("Load PDB")
                # Show either the ID textbox + load button or the upload widget.
                pdb_radio.change(fileselection, inputs=pdb_radio, outputs=[pdbcode, pdbfile, btn_load])
            pdb_html = gr.HTML("", visible=False)
            # Hidden: holds the local path of the loaded structure for predict().
            path_to_file = gr.Textbox(label="Path to file", visible=False)
            # NOTE(review): labelled "Cond Idxs", but protpardelle() treats these as
            # the indices to *resample* (the complement is held fixed) -- confirm
            # which set get_js emits.
            resample_idxs = gr.Textbox(label="Cond Idxs", interactive=False, info="Zero indexed list of indices to condition on, select in sequence viewer above")
            btn_load.click(update_structuresel, inputs=[pdbcode, pdb_radio], outputs=[input_accordion,path_to_file,pdb_html])
            pdbfile.change(update_structuresel, inputs=[pdbfile,pdb_radio], outputs=[input_accordion,path_to_file,pdb_html])
        # Length controls; hidden in conditional mode (see changemode).
        with gr.Accordion(label="Sizes", open=True) as size_uncond:
            with gr.Row():
                minlen = gr.Slider(minimum=2, maximum=200,value=50, step=1, label="minlen", info="Minimum sequence length")
                maxlen = gr.Slider(minimum=3, maximum=200,value=60, step=1, label="maxlen", info="Maximum sequence length")
                steplen = gr.Slider(minimum=1, maximum=50, step=1, value=1, label="steplen", info="How frequently to select sequence length?" )
                perlen = gr.Slider(minimum=1, maximum=200, step=1, value=2, label="perlen", info="How many samples per sequence length?")
        btn_conditional = gr.Button("Run conditional",visible=False)
        btn_unconditional = gr.Button("Run unconditional")
        m.change(changemode, inputs=m, outputs=[uncond, cond, btn_unconditional, btn_conditional, size_uncond])
    out = gr.HTML("", visible=True)
    btn_unconditional.click(predict, inputs=[pdb_radio, path_to_file,m, resample_idxs, model, minlen, maxlen, steplen, perlen], outputs=[input_accordion, out])
    # Conditional runs are two-step: this client-side JS (get_js, no Python fn)
    # writes resample_idxs, and resample_idxs.change below then calls predict.
    btn_conditional.click(fn=None,
        inputs=[resample_idxs],
        outputs=[resample_idxs],
        _js=get_js
    ) #
    out_text = gr.Textbox(label="Output", visible=False)
    #hidden button for named api route
    # (exposed as api_name="protpardelle"; returns api_predict's JSON string)
    pdb_content = gr.Textbox(label="PDB Content", visible=False)
    btn_api = gr.Button("Run API",visible=False)
    btn_api.click(api_predict, inputs=[pdb_content,m, resample_idxs, model, minlen, maxlen, steplen, perlen], outputs=[out_text], api_name="protpardelle")
    resample_idxs.change(predict, inputs=[pdb_radio, path_to_file,m, resample_idxs, model, minlen, maxlen, steplen, perlen], outputs=[input_accordion, out])
    # Run load_js in the browser on page load (no Python function).
    protpardelleDemo.load(None, None, None, _js=load_js)
protpardelleDemo.queue()
# allowed_paths lets the app serve files written under ./samples.
protpardelleDemo.launch(allowed_paths=['samples'])