metal3d / app.py
simonduerr's picture
fix jquery
c223338
import gradio as gr
import urllib
import re
import sys
import warnings
import torch
import torch.nn as nn
import ipywidgets as widgets
from ipywidgets import interact, fixed
from utils.helpers import *
from utils.voxelization import processStructures
from utils.model import Model
import numpy as np
import os
import moleculekit
# Log the installed moleculekit version at startup.
print(str(moleculekit.__version__))
# UniProt accession pattern. It is matched against the WHOLE identifier
# (``fullmatch``); inputs that match are fetched as AlphaFold DB models.
_UNIPROT_ACCESSION = re.compile(
    r"[OPQ][0-9][A-Z0-9]{3}[0-9]|[A-NR-Z][0-9]([A-Z][A-Z0-9]{2}[0-9]){1,2}"
)


def _structure_source(code):
    """Return ``(download_url, local_path)`` for *code*, or ``None`` if unrecognised.

    UniProt accessions map to the AlphaFold DB ``model_v2`` file; any other
    4-character code is treated as a PDB ID and mapped to the RCSB ``.pdb1``
    download. The local path is ``files/<code>.pdb``.
    """
    if _UNIPROT_ACCESSION.fullmatch(code):
        url = f"https://alphafold.ebi.ac.uk/files/AF-{code}-F1-model_v2.pdb"
    elif len(code) == 4:
        url = f"http://files.rcsb.org/download/{code.lower()}.pdb1"
    else:
        return None
    return url, f"files/{code}.pdb"


def update(inp, file, mode, custom_resids, clustering_threshold, distance_cutoff):
    """Predict metal sites for one structure and build the output panel.

    Parameters
    ----------
    inp : str
        PDB code or UniProt accession; only used when no file is uploaded.
    file : file-like or None
        Uploaded structure; its ``name`` attribute is used as the input path.
    mode : str or None
        ``"All residues"`` scores every protein residue. Otherwise the residues
        listed in *custom_resids* are used if any, else the metal-binding
        residues.
    custom_resids : str
        Comma-separated residue ids (commas are replaced by spaces before
        being passed on).
    clustering_threshold : float
        Forwarded to ``find_unique_sites`` as ``p``.
    distance_cutoff : float
        Forwarded to ``find_unique_sites`` as ``threshold``.

    Returns
    -------
    tuple[str, str]
        Status message and HTML for the viewer panel. The HTML is ``""`` when
        the input identifier is invalid or cannot be downloaded.
    """
    import html
    import urllib.request

    filepath = getattr(file, "name", None)
    if filepath is None:
        print("using pdbfile")
        code = (inp or "").strip()
        source = _structure_source(code)
        if source is None:
            return "pdb code must be 4 letters or Uniprot code does not match", ""
        url, filepath = source
        try:
            urllib.request.urlretrieve(url, filepath)
        except OSError as e:  # HTTPError/URLError are OSError subclasses
            print(e)
            return f"Could not download structure {code}: {e}", ""

    identifier = os.path.basename(filepath)
    if mode == "All residues":
        print("using all residues")
        ids = get_all_protein_resids(filepath)
    elif custom_resids:
        print("using listed residues", custom_resids)
        ids = get_all_resids_from_list(filepath, custom_resids.replace(",", " "))
    else:
        print("using metalbinding")
        ids = get_all_metalbinding_resids(filepath)
    print(filepath)
    print(ids)

    try:
        voxels, prot_centers, _, _ = processStructures(filepath, ids)
    except Exception as e:
        print(e)
        return (
            "Error",
            f"""<div class="text-center mt-4"> Something went wrong with the voxelization, reset custom residues and other input fields and check error message <br> <br> <code>{html.escape(str(e))}</code></div>""",
        )

    # Tensor.to() returns a new tensor instead of moving in place; keep the
    # result so the input lives on the same device as the model.
    voxels = voxels.to(device)
    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad(), warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        output = model(voxels)
    print(output.shape)

    prot_v = np.vstack(prot_centers)
    output_v = output.flatten().cpu().detach().numpy()
    bb = get_bb(prot_v)
    gridres = 0.5
    grid, box_N = create_grid_fromBB(bb, voxelSize=gridres)
    probability_values = get_probability_mean(grid, prot_v, output_v)
    print(probability_values.shape)
    write_cubefile(
        bb,
        probability_values,
        box_N,
        outname=f"output/metal_{identifier}.cube",
        gridres=gridres,
    )
    message = find_unique_sites(
        probability_values,
        grid,
        writeprobes=True,
        probefile=f"output/probes_{identifier}.pdb",
        threshold=distance_cutoff,
        p=clustering_threshold,
    )
    del voxels
    torch.cuda.empty_cache()
    return message, molecule(
        filepath,
        f"output/probes_{identifier}.pdb",
        f"output/metal_{identifier}.cube",
    )
def read_mol(molpath):
    """Return the full text content of the file at *molpath*."""
    with open(molpath, "r") as fp:
        return fp.read()


def _js_string(text):
    """Encode *text* as a JavaScript string literal that is safe inside ``<script>``.

    ``json.dumps`` escapes quotes, backslashes, newlines and non-ASCII
    characters. Replacing ``</`` with ``<\\/`` (an equivalent JS escape) stops
    file content from closing the surrounding script element.
    """
    import json

    return json.dumps(text).replace("</", "<\\/")


def molecule(pdb, probes, cube):
    """Return an ``<iframe>`` that renders a prediction with 3Dmol.js.

    Parameters
    ----------
    pdb : str
        Path to the input structure (loaded as model 0, cartoon style).
    probes : str
        Path to the probe PDB file (loaded as model 1; residues named ``ZN``
        are drawn as hoverable spheres).
    cube : str
        Path to the cube file, drawn as an isosurface whose isovalue is
        adjustable with a slider.

    The three file contents are inlined into the page, which also offers them
    for download. They are encoded as JS string literals, and the whole
    document is HTML-escaped before going into the ``srcdoc`` attribute, so
    quotes, backticks or ``${`` in the files cannot break the page.
    """
    import html

    mol = read_mol(pdb)
    probes = read_mol(probes)
    cubefile = read_mol(cube)
    x = (
        """<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<link rel="stylesheet" href="https://unpkg.com/flowbite@1.4.5/dist/flowbite.min.css" />
<style>
body{
font-family:sans-serif
}
.mol-container {
width: 100%;
height: 600px;
position: relative;
}
.slider{
width:80%;
margin:0 auto
}
.slidercontainer{
display:flex;
}
.slidercontainer > * + * {
margin-left: 0.5rem;
}
#isovalue{
text-align:right}
</style>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/rangeslider.js/2.3.3/rangeslider.min.js" integrity="sha512-BUlWdwDeJo24GIubM+z40xcj/pjw7RuULBkxOTc+0L9BaGwZPwiwtbiSVzv31qR7TWx7bs6OPTE5IyfLOorboQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
</head>
<body>
<div class="slidercontainer my-8">
<span>Isovalue </span>
<span id="isovalue">0.5</span>
<input class="slider text-blue-400" type="range" id="rangeslider" min="0" max="1" step="0.05" value=0.5>
</div>
<div id="container" class="mol-container"></div>
<div class="flex items-center justify-center my-4">
<div class="px-4">
<label for="sidechain" class="relative inline-flex items-center mb-4 cursor-pointer ">
<input id="sidechain" type="checkbox" class="sr-only peer">
<div class="w-11 h-6 bg-gray-200 rounded-full peer peer-focus:ring-4 peer-focus:ring-blue-300 dark:peer-focus:ring-blue-800 dark:bg-gray-700 peer-checked:after:translate-x-full peer-checked:after:border-white after:absolute after:top-0.5 after:left-[2px] after:bg-white after:border-gray-300 after:border after:rounded-full after:h-5 after:w-5 after:transition-all dark:border-gray-600 peer-checked:bg-blue-600"></div>
<span class="ml-3 text-sm font-medium text-gray-900 dark:text-gray-300">Show side chains</span>
</label>
</div>
<div class="px-4">
<label for="pdbmetal" class="relative inline-flex items-center mb-4 cursor-pointer ">
<input id="pdbmetal" type="checkbox" class="sr-only peer">
<div class="w-11 h-6 bg-gray-200 rounded-full peer peer-focus:ring-4 peer-focus:ring-blue-300 dark:peer-focus:ring-blue-800 dark:bg-gray-700 peer-checked:after:translate-x-full peer-checked:after:border-white after:absolute after:top-0.5 after:left-[2px] after:bg-white after:border-gray-300 after:border after:rounded-full after:h-5 after:w-5 after:transition-all dark:border-gray-600 peer-checked:bg-blue-600"></div>
<span class="ml-3 text-sm font-medium text-gray-900 dark:text-gray-300">Show PDB metals</span>
</label>
</div>
<div class="px-4">
<label for="probes" class="relative inline-flex items-center mb-4 cursor-pointer ">
<input id="probes" type="checkbox" class="sr-only peer" checked>
<div class="w-11 h-6 bg-gray-200 rounded-full peer peer-focus:ring-4 peer-focus:ring-blue-300 dark:peer-focus:ring-blue-800 dark:bg-gray-700 peer-checked:after:translate-x-full peer-checked:after:border-white after:absolute after:top-0.5 after:left-[2px] after:bg-white after:border-gray-300 after:border after:rounded-full after:h-5 after:w-5 after:transition-all dark:border-gray-600 peer-checked:bg-blue-600"></div>
<span class="ml-3 text-sm font-medium text-gray-900 dark:text-gray-300">Show Probes</span>
</label>
</div>
</div>
<div class="flex items-center justify-center my-4">
<button type="button" class="text-gray-900 bg-white hover:bg-gray-100 border border-gray-200 focus:ring-4 focus:outline-none focus:ring-gray-100 font-medium rounded-lg text-sm px-5 py-2.5 text-center inline-flex items-center dark:focus:ring-gray-600 dark:bg-gray-800 dark:border-gray-700 dark:text-white dark:hover:bg-gray-700 mr-2 mb-2" id="download">
<svg class="w-6 h-6 mr-2 -ml-1" fill="none" stroke="currentColor" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4"></path></svg>
Download predictions
</button>
</div>
<script>
let viewer = null;
let voldata = null;
let shape = null;
let sidechain = null;
let metal = null;
$(document).ready(function () {
let element = $("#container");
let config = { backgroundColor: "white" };
viewer = $3Dmol.createViewer( element, config );
viewer.ui.initiateUI();
let data = """
        + _js_string(mol)
        + """;
viewer.addModel( data, "pdb" );
let cubefile = """
        + _js_string(cubefile)
        + """;
voldata = new $3Dmol.VolumeData(cubefile, "cube");
shape = viewer.addIsosurface(voldata, { isoval: 0.5 , color: "blue", alpha: 0.85, smoothness: 1 });
viewer.getModel(0).setStyle({}, {cartoon: {}});
let probes = """
        + _js_string(probes)
        + """;
viewer.addModel(probes, "pdb");
viewer.getModel(1).setStyle({ "resn": "ZN" }, { "sphere": { }});
viewer.getModel(1).setHoverable({}, true,
function (atom, viewer, event, container) {
if (!atom.label) {
atom.label = viewer.addLabel("ZN p=" + atom.pdbline.substring(55, 60), { position: atom, backgroundColor: "mintcream", fontColor: "black" });
}
},
function (atom, viewer) {
if (atom.label) {
viewer.removeLabel(atom.label);
delete atom.label;
}
}
);
viewer.zoomTo();
viewer.render();
viewer.zoom(0.8, 2000);
$("#sidechain").change(function () {
const BB = ["C", "O", "N"];
if (this.checked) {
viewer.getModel(0).setStyle( {"and": [{resn: ["GLY", "PRO"], invert: true},{atom: BB, invert: true},]},{stick: {hidden:false, colorscheme: "WhiteCarbon", radius: 0.3}, cartoon: {}});
viewer.render()
$("#pdbmetal").prop( "checked", false );
} else {
viewer.getModel(0).setStyle({"and": [{resn: ["GLY", "PRO"], invert: true},{atom: BB, invert: true},]},{stick: {colorscheme: "WhiteCarbon",hidden:true, radius: 0.3}, cartoon: {}});
viewer.render()
$("#pdbmetal").prop( "checked", false );
}
});
$("#pdbmetal").change(function () {
if (this.checked) {
viewer.getModel(0).setStyle({ "resn": ["ZN","MG","NA","FE", "NI","MN","CA", "CU", "CU1"] }, { "sphere": {hidden:false}});
viewer.render()
} else {
viewer.getModel(0).setStyle({ "resn": ["ZN","MG","NA","FE","NI", "MN","CA", "CU", "CU1"] }, { "sphere": {hidden:true}});
viewer.render()
}
});
$("#probes").change(function () {
if (this.checked) {
viewer.getModel(1).setStyle({ "resn": "ZN" }, { "sphere": { }});
viewer.render()
} else {
viewer.getModel(1).setStyle({});
viewer.render()
}
});
$("#download").click(function () {
download("protein.pdb", data);
download("metaldensity.cube", cubefile);
download("probes.pdb", probes);
})
});
function download(filename, text) {
var element = document.createElement("a");
element.setAttribute("href", "data:text/plain;charset=utf-8," + encodeURIComponent(text));
element.setAttribute("download", filename);
element.style.display = "none";
document.body.appendChild(element);
element.click();
document.body.removeChild(element);
}
</script>
<script>
$("#rangeslider").rangeslider().on("change", function (el) {
const isoval = parseFloat(el.target.value);
$("#isovalue").text(el.target.value)
console.log("Change isosurface to "+el.target.value)
viewer.removeShape(shape)
shape=viewer.addIsosurface(voldata, { isoval: isoval, color: "blue", alpha: 0.85, smoothness: 1 });
viewer.render();
});
</script>
</body></html>"""
    )
    # Escape the document for the single-quoted srcdoc attribute; the browser
    # decodes the entities back before rendering the iframe content.
    return f"""<iframe style="width: 100%; height: 1000px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{html.escape(x, quote=True)}'></iframe>"""
def set_examples(example):
    """Turn a clicked example row into values for the (label, code, residues) inputs.

    The row must have exactly three fields; unpacking raises ``ValueError``
    otherwise. The fields are returned as a new list in the same order.
    """
    (label, structure_code, residue_list) = example
    return list((label, structure_code, residue_list))
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network, move it to `device`, and load the pretrained weights.
model = Model()
model.to(device)
model.load_state_dict(
    torch.load(
        "weights/metal_0.5A_v3_d0.2_16Abox.pth",
        # Reuse `device` (instead of re-deriving it) so checkpoint tensors are
        # mapped onto the same device the model was moved to.
        map_location=device,
    )
)
# Inference only: switch to eval mode (affects layers such as dropout or
# batch-norm, if the model contains any).
model.eval()
# Gradio interface. Components are laid out in the order they are created
# inside the `with` contexts, so statement order here defines the page layout.
metal3d = gr.Blocks()
with metal3d:
    gr.Markdown("# Metal3D")
    # Notice about CPU-only inference, with pointers to the Colab notebook and
    # the docker image for running elsewhere.
    gr.Markdown(
        """ <em>Inference using CPU-only, can be quite slow for more than 20 residues. Use [Colab notebook](https://colab.research.google.com/github/lcbc-epfl/metal-site-prediction/blob/main/Metal3D/ColabMetal.ipynb) for GPU acceleration or use the docker image locally </em>
<small><code>docker run -it -p 7860:7860 --platform=linux/amd64 registry.hf.space/simonduerr-metal3d:latest python app.py</code></small>
"""
    )
    with gr.Tabs():
        # Structure input: an identifier to download, or an uploaded file.
        # `update` uses the uploaded file when one is present.
        with gr.TabItem("Input"):
            inp = gr.Textbox(
                placeholder="PDB Code or Uniprot identifier or upload file below",
                label="Input molecule",
            )
            file = gr.File(file_count="single", type="file")
        # Prediction settings forwarded to `update`.
        with gr.TabItem("Settings"):
            with gr.Row():
                # `update` treats only "All residues" specially; any other value
                # (including no selection) uses the custom residues if given,
                # otherwise the metal-binding residues.
                mode = gr.Radio(
                    ["All metalbinding residues (ASP, CYS, GLU, HIS)", "All residues"],
                    label="Residues to use for prediction",
                )
                custom_resids = gr.Textbox(
                    placeholder="Comma separated list of residues",
                    label="Custom residues",
                )
            with gr.Row():
                # Passed to `find_unique_sites` as `p` (via `update`).
                clustering_threshold = gr.Slider(
                    minimum=0.15,
                    maximum=1,
                    value=0.15,
                    step=0.05,
                    label="Clustering threshold",
                )
                # Passed to `find_unique_sites` as `threshold` (via `update`).
                distance_cutoff = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=7,
                    step=0.5,
                    label="Clustering distance cutoff",
                )
    btn = gr.Button("Run")
    # Hidden textbox that receives the example label, so every Dataset column
    # maps onto a component.
    n = gr.Textbox(label="Label", visible=False)
    # Example rows: (label, PDB code or UniProt accession, custom residues).
    examples = gr.Dataset(
        components=[n, inp, custom_resids],
        samples=[
            ["HCA2", "2CBA", ""],
            ["Nickel in GB1 dimer", "6F5N", ""],
            ["Zebrafish palmitoyltransferase ZDHHC15B PDB", "6BMS", ""],
            [
                "Human palmitoyltransferase ZDHHC23 AlphaFold",
                "Q8IYP9",
                "280,273,263,260,274,277,274,287",
            ],
        ],
    )
    # Clicking a sample copies its fields into the label/input/residue components.
    examples.click(fn=set_examples, inputs=examples, outputs=examples.components)
    gr.Markdown("# Output")
    # `update` returns (status message, viewer HTML) for these two outputs.
    out = gr.Textbox(label="status")
    mol = gr.HTML()
    btn.click(
        fn=update,
        inputs=[inp, file, mode, custom_resids, clustering_threshold, distance_cutoff],
        outputs=[out, mol],
    )
# Start the web server.
metal3d.launch()