metal3d / app.py
simonduerr's picture
Update app.py
1f8abe6
raw
history blame
9.18 kB
import gradio as gr
import urllib
import re
import sys
import warnings
import torch
import torch.nn as nn
import ipywidgets as widgets
from ipywidgets import interact, fixed
from utils.helpers import *
from utils.voxelization import processStructures
from utils.model import Model
import numpy as np
import os
import moleculekit
# Print the installed moleculekit version to stdout when this module is executed.
print(moleculekit.__version__)
# UniProt accession pattern (same expression the app has always used); compiled
# once at import time because it is checked on every request.
_UNIPROT_ACCESSION = re.compile(
    r"[OPQ][0-9][A-Z0-9]{3}[0-9]|[A-NR-Z][0-9]([A-Z][A-Z0-9]{2}[0-9]){1,2}"
)
# A PDB code is accepted only as exactly four ASCII alphanumerics, so the value
# is safe to interpolate into a URL and a local file name.
_PDB_CODE = re.compile(r"[0-9A-Za-z]{4}")


class _StructureInputError(ValueError):
    """Raised with a user-facing message when no structure can be obtained."""


def _resolve_structure(inp, file):
    """Locate the structure file to run the prediction on.

    An uploaded file (any object exposing ``.name``) takes precedence over the
    text input. Otherwise *inp* is interpreted as a UniProt accession (fetched
    from the AlphaFold database) or a 4-character PDB code (fetched from RCSB)
    and downloaded to ``files/<inp>.pdb``.

    Returns:
        ``(filepath, label)`` where *label* is the identifier used to name the
        output files.

    Raises:
        _StructureInputError: *inp* is neither kind of identifier, or the
            download failed.
    """
    import urllib.error
    import urllib.request

    uploaded = getattr(file, "name", None)
    if uploaded is not None:
        # Name the outputs after the uploaded file; previously no label was
        # defined on this path and writing the results raised NameError.
        label = os.path.splitext(os.path.basename(uploaded))[0]
        return uploaded, label

    print("using pdbfile")
    code = (inp or "").strip()
    # fullmatch: an input that merely *starts* with a valid accession must be
    # rejected instead of silently falling through with no file path set.
    if _UNIPROT_ACCESSION.fullmatch(code):
        # NOTE(review): the model version suffix is pinned to "v2" — confirm
        # the AlphaFold database still serves this file name.
        url = f"https://alphafold.ebi.ac.uk/files/AF-{code}-F1-model_v2.pdb"
    elif _PDB_CODE.fullmatch(code):
        url = f"http://files.rcsb.org/download/{code.lower()}.pdb1"
    else:
        raise _StructureInputError(
            "pdb code must be 4 letters or Uniprot code does not match"
        )

    os.makedirs("files", exist_ok=True)
    filepath = f"files/{code}.pdb"
    try:
        urllib.request.urlretrieve(url, filepath)
    except urllib.error.URLError as err:
        raise _StructureInputError(f"could not download {url}: {err}") from err
    return filepath, code


def update(inp, file, mode, custom_resids, clustering_threshold):
    """Run the Metal3D prediction for one structure and build the viewer HTML.

    Args:
        inp: Text input: a UniProt accession or a 4-character PDB code. Only
            used when no file is uploaded.
        file: Uploaded file object exposing ``.name`` (a path), or ``None``.
        mode: Residue selection mode. ``"All residues"`` selects every protein
            residue; any other value falls back to *custom_resids* or, if that
            is empty, to the metal-binding residues.
        custom_resids: Comma separated residue list; commas are replaced by
            spaces before it is handed to ``get_all_resids_from_list``.
        clustering_threshold: Passed to ``find_unique_sites`` as ``p``.

    Returns:
        A ``(message, html)`` tuple. On invalid input or a failed download,
        *message* describes the problem and *html* is ``""``. Otherwise
        *message* is the result of ``find_unique_sites`` and *html* is the
        iframe markup produced by ``molecule``.
    """
    try:
        filepath, pdb_file = _resolve_structure(inp, file)
    except _StructureInputError as err:
        return str(err), ""

    if mode == "All residues":
        ids = get_all_protein_resids(filepath)
    elif custom_resids:
        ids = get_all_resids_from_list(custom_resids.replace(",", " "))
    else:
        ids = get_all_metalbinding_resids(filepath)

    voxels, prot_centers, _, _ = processStructures(filepath, ids)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Tensor.to() returns the moved tensor instead of modifying it in place;
    # without the assignment the input stays on the CPU while the model is
    # moved to the GPU.
    voxels = voxels.to(device)
    print(voxels.shape)

    model = Model()
    model.to(device)
    model.load_state_dict(
        torch.load(
            "weights/metal_0.5A_v3_d0.2_16Abox.pth",
            map_location=torch.device("cpu"),
        )
    )
    model.eval()
    # Pure inference: no autograd graph is needed.
    with warnings.catch_warnings(), torch.no_grad():
        warnings.filterwarnings("ignore")
        output = model(voxels)
    print(output.shape)

    prot_v = np.vstack(prot_centers)
    output_v = output.flatten().cpu().detach().numpy()

    bb = get_bb(prot_v)
    gridres = 0.5
    grid, box_N = create_grid_fromBB(bb, voxelSize=gridres)
    probability_values = get_probability_mean(grid, prot_v, output_v)
    print(probability_values.shape)

    os.makedirs("output", exist_ok=True)
    cube_path = f"output/metal_{pdb_file}.cube"
    probe_path = f"output/probes_{pdb_file}.pdb"
    write_cubefile(
        bb,
        probability_values,
        box_N,
        outname=cube_path,
        gridres=gridres,
    )
    # NOTE(review): threshold is fixed at 7 here; the UI also defines a
    # "Clustering distance cutoff" slider (default 7) that is never passed
    # in — presumably the same quantity, confirm before wiring it up.
    message = find_unique_sites(
        probability_values,
        grid,
        writeprobes=True,
        probefile=probe_path,
        threshold=7,
        p=clustering_threshold,
    )
    return message, molecule(filepath, probe_path, cube_path)
def test():
    """Return a static demo iframe that shows PDB entry 2POR with 3Dmol.js.

    The embedded page is fixed; nothing is read from disk or from the caller.
    """
    page = "\n".join(
        (
            "<!DOCTYPE html>",
            "<html>",
            "<head>",
            '<meta http-equiv="content-type" content="text/html; charset=UTF-8" />',
            "</head>",
            "<body>",
            '<script src="https://3Dmol.org/build/3Dmol-min.js" async></script> <div style="height: 400px; width: 400px; position: relative;" class="viewer_3Dmoljs" data-pdb="2POR" data-backgroundcolor="0xffffff" data-style="stick" ></div>',
            "</body></html>",
        )
    )
    # The iframe attributes keep the line breaks of the original markup.
    frame_open = "\n".join(
        (
            '<iframe style="width: 100%; height: 480px" name="result" allow="midi; geolocation; microphone; camera;',
            'display-capture; encrypted-media;" sandbox="allow-modals allow-forms',
            "allow-scripts allow-same-origin allow-popups",
            'allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""',
            'allowpaymentrequest="" frameborder="0" srcdoc=\'',
        )
    )
    return frame_open + page + "'></iframe>"
def read_mol(molpath):
    """Return the entire text content of the file at *molpath*.

    Args:
        molpath: Path of the file to read (opened in text mode).

    Returns:
        The file content as a single string; ``""`` for an empty file.
    """
    # One read() yields the same string as joining readlines(), without the
    # repeated string concatenation.
    with open(molpath, "r") as fp:
        return fp.read()
def molecule(pdb, probes, cube):
    """Build the iframe markup that visualises a prediction with 3Dmol.js.

    Args:
        pdb: Path of the structure file; added to the viewer as model 0.
        probes: Path of the probe PDB file; added as model 1, with ``ZN``
            residues drawn as spheres and a hover label.
        cube: Path of the cube file; rendered as an isosurface whose isovalue
            can be changed with the slider.

    Returns:
        An ``<iframe>`` element (as a string) whose ``srcdoc`` holds the
        complete viewer page with the three files embedded.
    """
    import html
    import json

    def js_string(text):
        # A JSON string literal is also a valid JavaScript string literal, so
        # quotes, backticks, backslashes and "${" in the file content can no
        # longer terminate or alter the script. "</" is split up so that the
        # content cannot close the surrounding <script> element either.
        return json.dumps(text).replace("</", "<\\/")

    mol_js = js_string(read_mol(pdb))
    probes_js = js_string(read_mol(probes))
    cube_js = js_string(read_mol(cube))

    # NOTE(review): the page calls jQuery's `$` but only loads 3Dmol and
    # rangeslider — presumably the 3Dmol bundle provides jQuery; confirm
    # against the 3Dmol build that is served.
    # NOTE(review): the slider and its label start at 0.5 while the first
    # isosurface is drawn at 0.7 — confirm which value is intended.
    x = (
        """<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<style>
body{
font-family:sans-serif
}
.mol-container {
width: 100%;
height: 400px;
position: relative;
}
.slider{
width:80%;
margin:0 auto
}
.slidercontainer{
display:flex;
}
.slidercontainer > * + * {
margin-left: 0.5rem;
}
#isovalue{
text-align:right}
</style>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/rangeslider.js/2.3.3/rangeslider.min.js" integrity="sha512-BUlWdwDeJo24GIubM+z40xcj/pjw7RuULBkxOTc+0L9BaGwZPwiwtbiSVzv31qR7TWx7bs6OPTE5IyfLOorboQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
</head>
<body>
<div class="slidercontainer">
<span>Isovalue </span>
<span id="isovalue">0.5</span>
<input class="slider" type="range" id="rangeslider" min="0" max="1" step="0.025" value=0.5>
</div>
<div id="container" class="mol-container"></div>
<script>
let viewer = null;
let voldata = null;
$(document).ready(function () {
let element = $("#container");
let config = { backgroundColor: "white" };
viewer = $3Dmol.createViewer( element, config );
viewer.ui.initiateUI();
let data = """
        + mol_js
        + """;
viewer.addModel( data, "pdb" );
let cubefile = """
        + cube_js
        + """;
voldata = new $3Dmol.VolumeData(cubefile, "cube");
viewer.addIsosurface(voldata, { isoval: 0.7 , color: "blue", alpha: 0.85, smoothness: 1 });
viewer.getModel(0).setStyle({}, {cartoon: {color: "grayCarbon"}});
let probes = """
        + probes_js
        + """;
viewer.addModel(probes, "pdb");
viewer.getModel(1).setStyle({ "resn": "ZN" }, { "sphere": { }});
viewer.getModel(1).setHoverable({}, true,
function (atom, viewer, event, container) {
if (!atom.label) {
atom.label = viewer.addLabel("ZN p=" + atom.pdbline.substring(55, 60), { position: atom, backgroundColor: "mintcream", fontColor: "black" });
}
},
function (atom, viewer) {
if (atom.label) {
viewer.removeLabel(atom.label);
delete atom.label;
}
}
);
viewer.zoomTo();
viewer.render();
viewer.zoom(0.8, 2000);
});
</script>
<script>
$("#rangeslider").rangeslider().on("change", function (el) {
isoval = parseFloat(el.target.value);
$("#isovalue").text(el.target.value)
viewer.removeAllShapes();
viewer.addIsosurface(voldata, { isoval: parseFloat(el.target.value), color: "blue", alpha: 0.85, smoothness: 1 });
viewer.render();
});
</script>
</body></html>"""
    )
    # The page sits inside an HTML attribute, so it must be attribute-escaped;
    # the browser decodes the entities again when it parses srcdoc. Unescaped,
    # the first apostrophe in the structure file ended the attribute early.
    escaped = html.escape(x, quote=True)
    return f"""<iframe style="width: 100%; height: 480px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{escaped}'></iframe>"""
# Gradio user interface: structure input, prediction settings, a run button
# and the output widgets, followed by the app launch.
metal3d = gr.Blocks()
with metal3d:
gr.Markdown("# Metal3D")
with gr.Tabs():
with gr.TabItem("Input"):
# Either an identifier typed here or a file uploaded below is handed to update().
inp = gr.Textbox( placeholder="PDB Code or Uniprot identifier or upload file below", label="Input molecule"
)
file = gr.File(file_count="single", type="file")
with gr.TabItem("Settings"):
with gr.Row():
# No initial value is set for the radio group; update() only compares the
# selection against "All residues".
mode = gr.Radio(
["All metalbinding residues (ASP, CYS, GLU, HIS)", "All residues"],
label="Residues to use for prediction",
)
custom_resids = gr.Textbox(placeholder="Comma separated list of residues", label="Custom residues")
with gr.Row():
clustering_threshold = gr.Slider(minimum=0.15,maximum=1, value=0.15,step=0.05, label="Clustering threshold")
# NOTE(review): this slider is not listed in the inputs of btn.click below,
# so its value never reaches update().
distance_cutoff = gr.Slider(minimum=1,maximum=10, value=7,step=1, label="Clustering distance cutoff")
btn = gr.Button("Run")
gr.Markdown(
""" <small>Inference using CPU-only, can be quite slow for more than 20 residues. Use Colab notebook for GPU acceleration</small>
"""
)
gr.Markdown("# Output")
out = gr.Textbox(label="status")
mol = gr.HTML()
# Clicking "Run" calls update() with the five inputs; its two return values
# fill the status textbox and the HTML viewer, in that order.
btn.click(fn=update, inputs=[inp, file, mode, custom_resids, clustering_threshold], outputs=[out, mol])
# Start the Gradio app.
metal3d.launch()