gnina-torch / app.py
Rocco Meli
add references
4d2dccd
import os
def load_file(fpath: str) -> str:
"""
Load file content.
Parameters
----------
fpath: str
File path
Returns
-------
str
File content
"""
with open(fpath, "r") as f:
return f.read()
def load_html(html_file: str) -> str:
return load_file(os.path.join("html", html_file))
def load_md(md_file: str) -> str:
return load_file(os.path.join("md", md_file))
def load_protein_from_file(protein_file) -> str:
"""
Parameters
----------
protein_file: _TemporaryFileWrapper
GradIO file object
Returns
-------
str
Protein PDB file content
"""
with open(protein_file.name, "r") as f:
return f.read()
def load_ligand_from_file(ligand_file) -> str:
"""
Load ligand from file.
Parameters
----------
ligand_file: _TemporaryFileWrapper
GradIO file object
Returns
-------
str
Ligand SDF file content
"""
with open(ligand_file.name, "r") as f:
return f.read()
def protein_html_from_file(protein_file) -> str:
"""
Wrap 3Dmol.js code around protein PDB file.
Parameters
----------
protein_file: _TemporaryFileWrapper
GradIO file object
Returns
-------
str
3Dmol.js HTML code for displaying a PDB file
"""
protein = load_protein_from_file(protein_file)
protein_html = load_html("protein.html")
html = protein_html.replace("%%%PDB%%%", protein)
wrapper = load_html("wrapper.html")
return wrapper.replace("%%%HTML%%%", html)
def ligand_html_from_file(ligand_file) -> str:
"""
Wrap 3Dmol.js code around ligand SDF file.
Parameters
----------
ligand_file: _TemporaryFileWrapper
GradIO file object
Returns
-------
str
3Dmol.js HTML code for displaying a SDF file
"""
ligand = load_ligand_from_file(ligand_file)
ligand_html = load_html("ligand.html")
html = ligand_html.replace("%%%SDF%%%", ligand)
wrapper = load_html("wrapper.html")
return wrapper.replace("%%%HTML%%%", html)
def protein_ligand_html_from_file(protein_file, ligand_file):
protein = load_protein_from_file(protein_file)
ligand = load_ligand_from_file(ligand_file)
protein_ligand_html = load_html("pl.html")
html = protein_ligand_html.replace("%%%PDB%%%", protein)
html = html.replace("%%%SDF%%%", ligand)
wrapper = load_html("wrapper.html")
return wrapper.replace("%%%HTML%%%", html)
def predict(protein_file, ligand_file, cnn: str = "default"):
"""
Run gnina-torch on protein-ligand complex.
Parameters
----------
protein_file: _TemporaryFileWrapper
GradIO file object
ligand_file: _TemporaryFileWrapper
GradIO file object
cnn: str
CNN model to use
Returns
-------
dict[str, float]
CNNscore, CNNaffinity, and CNNvariance
"""
import molgrid
from gninatorch import gnina, dataloaders
import torch
import pandas as pd
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model, ensemble = gnina.setup_gnina_model(cnn, 23.5, 0.5)
model.eval()
model.to(device)
example_provider = molgrid.ExampleProvider(
data_root="",
balanced=False,
shuffle=False,
default_batch_size=1,
iteration_scheme=molgrid.IterationScheme.SmallEpoch,
)
# FIXME: Do this properly... =( [Might require light gnina-torch refactoring]
with open("data.in", "w") as f:
f.write(protein_file.name)
f.write(" ")
f.write(ligand_file.name)
print("Populating example provider... ", end="")
example_provider.populate("data.in")
print("done")
grid_maker = molgrid.GridMaker(resolution=0.5, dimension=23.5)
# TODO: Allow average over different rotations
loader = dataloaders.GriddedExamplesLoader(
example_provider=example_provider,
grid_maker=grid_maker,
random_translation=0.0, # No random translations for inference
random_rotation=False, # No random rotations for inference
grids_only=True,
device=device,
)
print("Loading and gridding data... ", end="")
batch = next(loader)
print("done")
print("Predicting... ", end="")
with torch.no_grad():
log_pose, affinity, affinity_var = model(batch)
print("done")
return pd.DataFrame(
{
"CNNscore": [torch.exp(log_pose[:, -1]).item()],
"CNNaffinity": [affinity.item()],
"CNNvariance": [affinity_var.item()],
}
).round(6)
if __name__ == "__main__":
import gradio as gr
demo = gr.Blocks()
with demo:
gr.Markdown(load_md("intro.md"))
gr.Markdown(load_md("input.md"))
with gr.Row():
with gr.Box():
pfile = gr.File(file_count="single", label="Protein file (PDB)")
gr.Examples(["mols/1cbr_protein.pdb"], inputs=pfile)
pbtn = gr.Button("View Protein")
pbtn.click(fn=protein_html_from_file, inputs=[pfile], outputs=gr.HTML())
with gr.Box():
lfile = gr.File(file_count="single", label="Ligand file (SDF)")
gr.Examples(["mols/1cbr_ligand.sdf"], inputs=lfile)
lbtn = gr.Button("View Ligand")
lbtn.click(fn=ligand_html_from_file, inputs=[lfile], outputs=gr.HTML())
with gr.Box():
with gr.Column():
# TODO: Automatically display complex when both files are uploaded
plbtn = gr.Button("View Protein-Ligand Complex")
plbtn.click(
fn=protein_ligand_html_from_file,
inputs=[pfile, lfile],
outputs=gr.HTML(),
)
gr.Markdown(load_md("scoring.md"))
with gr.Row():
df = gr.Dataframe()
with gr.Column():
dd = gr.Dropdown(
choices=[
"default",
"redock_default2018_ensemble",
"general_default2018_ensemble",
"crossdock_default2018_ensemble",
],
value="default",
label="CNN model",
)
with gr.Row():
btn = gr.Button("Score!")
btn.click(fn=predict, inputs=[pfile, lfile, dd], outputs=df)
gr.Markdown(
load_md("acknowledgements.md"),
)
gr.Markdown(load_md("references.md"))
demo.launch()