BarlowDTI / app.py
mschuh's picture
Upload 5 files
4f0db87 verified
raw
history blame
4.28 kB
import gradio as gr
import plotly.graph_objects as go
import numpy as np
import pandas as pd
from model.model import DTIModel
import spaces
# Version tag shared by every model artifact path built in predict_and_plot:
# both the Barlow Twins stash directory and the XGBoost model JSON files.
# Looks like a DDMMYYYY_HHMM timestamp; NOTE(review): confirm with the authors.
dt_str: str = "14062024_0910"
def make_spider_plot(predictions, model_names, smiles_list):
    """Build a radar ("spider") chart with one filled trace per molecule.

    Args:
        predictions: One row of scores per molecule. Each row holds one score
            per entry of ``model_names``, in the same order.
        model_names: Labels for the angular axis, one per model.
        smiles_list: Legend labels, one per row of ``predictions``.

    Returns:
        A ``plotly.graph_objects.Figure``. The radial axis is fixed to
        ``[0, 1]`` so traces for different molecules share one scale.
    """
    fig = go.Figure()
    # zip stops at the shorter input. predict_and_plot passes exactly one
    # row per SMILES, so nothing is dropped there.
    for prediction, smiles in zip(predictions, smiles_list):
        fig.add_trace(
            go.Scatterpolar(
                r=prediction,
                theta=model_names,
                fill="toself",
                name=smiles,
            )
        )
    fig.update_layout(
        polar=dict(radialaxis=dict(visible=True, range=[0, 1])),
        showlegend=True,
    )
    return fig
@spaces.GPU
def predict_and_plot(amino_acid_sequence, smiles_input, datasets):
    """Score every SMILES against one protein with the selected models.

    Args:
        amino_acid_sequence: A single protein sequence without a FASTA header.
            All whitespace is removed first, so a sequence wrapped over
            several lines is joined into one string.
        smiles_input: SMILES strings, one per line. Surrounding whitespace and
            blank lines are ignored, and ``\\r\\n`` line endings are handled.
        datasets: Names of the models to run (a list, or a single name as a
            string). Each must be a key of ``gbm_model_paths`` below.
            Duplicates are ignored and the first-seen order is kept.

    Returns:
        ``(fig, df)``. ``fig`` is the spider plot from ``make_spider_plot``.
        ``df`` is a DataFrame with a ``SMILES`` column followed by one score
        column per selected model.

    Raises:
        gr.Error: If the sequence or SMILES input is empty, no model is
            selected, or an unknown model name is given. Gradio shows the
            message to the user.
    """
    gbm_model_paths = {
        "BindingDB": f"model/xgb_models/xgb_model_BindingDB_{dt_str}_bt_optimized_0.json",
        "BioSNAP": f"model/xgb_models/xgb_model_BIOSNAP_full_data_{dt_str}_bt_optimized_0.json",
        "DAVIS": f"model/xgb_models/xgb_model_DAVIS_{dt_str}_bt_optimized_0.json",
        "BarlowDTI XXL": f"model/xgb_models/{dt_str}_barlowdti_xxl_model.json",
    }

    # Validate all inputs before loading any model so bad input fails fast.
    sequence = "".join((amino_acid_sequence or "").split())
    if not sequence:
        raise gr.Error("Please enter a protein sequence.")

    # splitlines() also strips "\r" from Windows line endings. Blank lines
    # would otherwise reach the model as empty SMILES strings.
    smiles_list = [
        line.strip() for line in (smiles_input or "").splitlines() if line.strip()
    ]
    if not smiles_list:
        raise gr.Error("Please enter at least one SMILES string.")

    if isinstance(datasets, str):
        datasets = [datasets]
    # Deduplicate while keeping order, so the table columns stay aligned with
    # the list of predictions.
    selected = list(dict.fromkeys(datasets or []))
    if not selected:
        raise gr.Error("Please select at least one model.")
    unknown = [name for name in selected if name not in gbm_model_paths]
    if unknown:
        raise gr.Error(f"Unknown model(s): {', '.join(map(str, unknown))}")

    # Load and run one model at a time. The previous model is released when
    # `model` is rebound, instead of holding every model in memory at once.
    predictions = []
    for name in selected:
        print(f"Loading model {name}")
        model = DTIModel(
            bt_model_path=f"model/stash/{dt_str}",
            gbm_model_path=gbm_model_paths[name],
        )
        predictions.append(model.predict(smiles_list, sequence))

    # Assumes model.predict returns one score per SMILES (TODO confirm).
    # Stacking gives (models x molecules); transposing gives one row per molecule.
    scores = np.array(predictions).transpose().tolist()

    df = pd.DataFrame(scores, columns=selected)
    df.insert(0, "SMILES", smiles_list)

    fig = make_spider_plot(scores, selected, smiles_list)
    return fig, df
# Model choices shown in the UI checkbox group, XXL model first. Each name
# must match a key of the gbm_model_paths mapping in predict_and_plot.
dataset_names = ["BarlowDTI XXL", "BindingDB", "BioSNAP", "DAVIS"]
# Page heading and intro text, passed to gr.Interface as `title` and
# `description`. The <span> tags render "BarlowDTI" in small caps.
title = "Predict Drug-Target Interactions with <span style='font-variant:small-caps;'>BarlowDTI</span>"
description = """
Enter an amino acid sequence and one or more SMILES (one per line) to get interaction predictions, visualized as a spider graph and in a table.
The values can be interpreted as the probability of interaction between the drug and the target (0 = no interaction, 1 = interaction).
Thank you for using <span style='font-variant:small-caps;'>BarlowDTI</span>!
Note: Inference may take a while; for faster inference, you can upgrade to a paid GPU-enabled plan.
"""
# Footer text shown below the interface (passed as `article` to gr.Interface):
# dataset credits plus a BibTeX entry inside a fenced code block.
article = """
This interface lets the scientific community use <span style='font-variant:small-caps;'>BarlowDTI</span><sub>XXL</sub> to predict drug-target interactions.
The model ensemble consists of four models trained on different datasets: our own curated and refined dataset based on
[Golts et al.](https://doi.org/10.48550/arXiv.2401.17174)
in combination with
[BindingDB](https://doi.org/10.1093/nar/gkl999),
[BioSNAP](https://snap.stanford.edu/index.html), and
[DAVIS](https://doi.org/10.1038/nbt.1990).
If you use our approach in your research, please cite our paper:
```
@misc{schuh2024barlowtwinsdeepneural,
title={Barlow Twins Deep Neural Network for Advanced 1D Drug-Target Interaction Prediction},
author={Maximilian G. Schuh and Davide Boldini and Stephan A. Sieber},
year={2024},
eprint={2408.00040},
archivePrefix={arXiv},
primaryClass={q-bio.BM},
url={https://arxiv.org/abs/2408.00040},
}
```
"""
# Base Gradio theme with a violet accent. The UI uses IBM Plex Sans (a Google
# font) and falls back to generic system sans-serif fonts.
theme = gr.themes.Base(
    font=[
        gr.themes.GoogleFont("IBM Plex Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    primary_hue="violet",
)
# Single-page UI. Inputs are a protein sequence, SMILES, and the model
# selection; outputs are a spider plot and a predictions table.
iface = gr.Interface(
    fn=predict_and_plot,
    inputs=[
        gr.Textbox(
            label="Protein Sequence",
            info="Just one sequence is allowed. Remove FASTA syntax (e.g. >ABC).",
            # Protein sequences are long; show a few rows instead of one.
            lines=3,
        ),
        gr.Textbox(
            label="Molecule SMILES",
            info="One per line, multiple allowed.",
            # Multi-row box so several SMILES can be entered, one per line,
            # as the info text asks.
            lines=5,
        ),
        gr.CheckboxGroup(
            choices=dataset_names,
            label="Select Models for Prediction",
            # CheckboxGroup selects a list of options; make that explicit.
            value=["BarlowDTI XXL"],
        ),
    ],
    outputs=[
        gr.Plot(label="Predictions Visualization"),
        gr.DataFrame(label="Predictions DataFrame"),
    ],
    title=title,
    description=description,
    article=article,
    theme=theme,
)
iface.launch()