"""Gradio front end for Metal3D metal-site prediction.

The app resolves a protein structure (uploaded file, UniProt accession or
4-character PDB code), voxelizes the selected residues, runs the model and
writes a probability cube file plus a probe PDB file under ``output/``.
"""

import os
import re
import sys
import urllib
import urllib.request
import warnings

import gradio as gr
import ipywidgets as widgets
import moleculekit
import numpy as np
import torch
import torch.nn as nn
from ipywidgets import interact, fixed

# The wildcard import stays first among the project imports so that the
# explicit imports below keep precedence over any names it brings in.
from utils.helpers import *
from utils.voxelization import processStructures
from utils.model import Model

print(moleculekit.__version__)

# UniProt accession pattern; an input that matches it completely is fetched
# from the AlphaFold database.
_UNIPROT_PATTERN = re.compile(
    "[OPQ][0-9][A-Z0-9]{3}[0-9]|[A-NR-Z][0-9]([A-Z][A-Z0-9]{2}[0-9]){1,2}"
)


def _resolve_structure_path(inp, file):
    """Return the path of the structure to process, or ``None`` if unusable.

    An uploaded ``file`` takes precedence. Otherwise ``inp`` is treated as a
    UniProt accession (downloaded from AlphaFold) or, when it has exactly
    four characters, as a PDB code (downloaded from RCSB). Downloads are
    stored as ``files/<inp>.pdb``.
    """
    try:
        return file.name
    except AttributeError:
        # No upload available (e.g. ``file`` is None): use the text input.
        print("using pdbfile")

    pdb_file = inp
    # fullmatch (instead of match + comparing the matched text) makes sure a
    # partially matching identifier reaches the error path rather than
    # leaving the file path undefined.
    if _UNIPROT_PATTERN.fullmatch(pdb_file):
        urllib.request.urlretrieve(
            f"https://alphafold.ebi.ac.uk/files/AF-{pdb_file}-F1-model_v2.pdb",
            f"files/{pdb_file}.pdb",
        )
        return f"files/{pdb_file}.pdb"

    if len(pdb_file) == 4:
        urllib.request.urlretrieve(
            f"http://files.rcsb.org/download/{pdb_file.lower()}.pdb1",
            f"files/{pdb_file}.pdb",
        )
        return f"files/{pdb_file}.pdb"

    return None


def update(inp, file, mode, custom_resids, clustering_threshold, distance_cutoff):
    """Run a prediction for one structure and return ``(status, html)``.

    Args:
        inp: Text input holding a PDB code or UniProt accession.
        file: Uploaded file object exposing ``.name``, or ``None``.
        mode: Residue selection mode; ``"All residues"`` selects every
            protein residue and takes precedence over ``custom_resids``.
        custom_resids: Comma separated residue list; used when non-empty and
            ``mode`` is not ``"All residues"``.
        clustering_threshold: Forwarded to ``find_unique_sites`` as ``p``.
        distance_cutoff: Forwarded to ``find_unique_sites`` as ``threshold``.

    Returns:
        A 2-tuple: the status message and the HTML produced by ``molecule``.
        On invalid input or a voxelization failure an error message is
        returned in place of the results.
    """
    filepath = _resolve_structure_path(inp, file)
    if filepath is None:
        return "pdb code must be 4 letters or Uniprot code does not match", ""

    identifier = os.path.basename(filepath)

    if mode == "All residues":
        print("using all residues")
        ids = get_all_protein_resids(filepath)
    elif custom_resids:
        print("using listed residues", custom_resids)
        ids = get_all_resids_from_list(filepath, custom_resids.replace(",", " "))
    else:
        print("using metalbinding")
        ids = get_all_metalbinding_resids(filepath)
    print(filepath)
    print(ids)

    try:
        voxels, prot_centers, prot_N, prots = processStructures(filepath, ids)
    except Exception as e:
        # Boundary of the request handler: report the failure to the user.
        print(e)
        return (
            "Error",
            "\nSomething went wrong with the voxelization, reset custom "
            "residues and other input fiels and check error message\n\n"
            f"{e}\n",
        )

    # Tensor.to() returns the moved tensor; it does not modify in place.
    voxels = voxels.to(device)
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            output = model(voxels)
    print(output.shape)

    prot_v = np.vstack(prot_centers)
    output_v = output.flatten().cpu().detach().numpy()
    bb = get_bb(prot_v)
    gridres = 0.5
    grid, box_N = create_grid_fromBB(bb, voxelSize=gridres)
    probability_values = get_probability_mean(grid, prot_v, output_v)
    print(probability_values.shape)

    write_cubefile(
        bb,
        probability_values,
        box_N,
        outname=f"output/metal_{identifier}.cube",
        gridres=gridres,
    )
    message = find_unique_sites(
        probability_values,
        grid,
        writeprobes=True,
        probefile=f"output/probes_{identifier}.pdb",
        threshold=distance_cutoff,
        p=clustering_threshold,
    )

    # Drop the input tensor and release cached GPU memory between requests.
    del voxels
    torch.cuda.empty_cache()

    return message, molecule(
        filepath,
        f"output/probes_{identifier}.pdb",
        f"output/metal_{identifier}.cube",
    )


def read_mol(molpath):
    """Return the full text content of the file at ``molpath``."""
    with open(molpath, "r") as fp:
        return fp.read()


def molecule(pdb, probes, cube):
    """Read the structure, probe and cube files and return the viewer HTML.

    NOTE(review): in this copy of the file the HTML template is empty, so the
    function returns an empty string and the file contents are unused. The
    viewer markup appears to be missing and should be restored from the
    original source.
    """
    mol = read_mol(pdb)
    probes = read_mol(probes)
    cubefile = read_mol(cube)
    x = "\nIsovalue 0.5\n"
    return ""


def set_examples(example):
    """Map a clicked example row onto the ``[label, code, residues]`` outputs."""
    n, code, resids = example
    return [n, code, resids]


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model()
model.to(device)
model.load_state_dict(
    torch.load(
        "weights/metal_0.5A_v3_d0.2_16Abox.pth",
        map_location=device,
    )
)
model.eval()

# NOTE(review): the nesting of the layout below was reconstructed from syntax
# because the original indentation was not available - confirm against the
# rendered UI.
metal3d = gr.Blocks()
with metal3d:
    gr.Markdown("# Metal3D")
    gr.Markdown(
        " Inference using CPU-only, can be quite slow for more than 20 residues. "
        "Use [Colab notebook](https://colab.research.google.com/github/lcbc-epfl/metal-site-prediction/blob/main/Metal3D/ColabMetal.ipynb) "
        "for GPU acceleration or use the docker image locally "
        "docker run -it -p 7860:7860 --platform=linux/amd64 "
        "registry.hf.space/simonduerr-metal3d:latest python app.py "
    )
    with gr.Tabs():
        with gr.TabItem("Input"):
            inp = gr.Textbox(
                placeholder="PDB Code or Uniprot identifier or upload file below",
                label="Input molecule",
            )
            file = gr.File(file_count="single", type="file")

        with gr.TabItem("Settings"):
            with gr.Row():
                mode = gr.Radio(
                    ["All metalbinding residues (ASP, CYS, GLU, HIS)", "All residues"],
                    label="Residues to use for prediction",
                )
                custom_resids = gr.Textbox(
                    placeholder="Comma separated list of residues",
                    label="Custom residues",
                )
            with gr.Row():
                clustering_threshold = gr.Slider(
                    minimum=0.15,
                    maximum=1,
                    value=0.15,
                    step=0.05,
                    label="Clustering threshold",
                )
                distance_cutoff = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=7,
                    step=0.5,
                    label="Clustering distance cutoff",
                )

    btn = gr.Button("Run")
    # Hidden textbox that only carries the example label.
    n = gr.Textbox(label="Label", visible=False)

    examples = gr.Dataset(
        components=[n, inp, custom_resids],
        samples=[
            ["HCA2", "2CBA", ""],
            ["Nickel in GB1 dimer", "6F5N", ""],
            ["Zebrafish palmitoyltransferase ZDHHC15B PDB", "6BMS", ""],
            [
                "Human palmitoyltransferase ZDHHC23 AlphaFold",
                "Q8IYP9",
                "280,273,263,260,274,277,274,287",
            ],
        ],
    )
    examples.click(fn=set_examples, inputs=examples, outputs=examples.components)

    gr.Markdown("# Output")
    out = gr.Textbox(label="status")
    mol = gr.HTML()
    btn.click(
        fn=update,
        inputs=[inp, file, mode, custom_resids, clustering_threshold, distance_cutoff],
        outputs=[out, mol],
    )

metal3d.launch()