Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
import plotly.graph_objects as go | |
import numpy as np | |
import pandas as pd | |
from model.model import DTIModel | |
import spaces | |
# Version tag embedded in every model artifact path used by predict_and_plot
# (presumably a DDMMYYYY_HHMM training timestamp — TODO confirm).
dt_str: str = "14062024_0910"
def make_spider_plot(predictions, model_names, smiles_list, radial_range=(0, 1)):
    """Build a radar ("spider") chart with one filled trace per molecule.

    Args:
        predictions: One sequence of scores per molecule, ordered like
            ``model_names``; each sequence becomes the radial values of one
            trace.
        model_names: Labels for the angular axis, one per model.
        smiles_list: Legend label for each trace, parallel to ``predictions``.
        radial_range: ``(min, max)`` of the radial axis. Defaults to
            ``(0, 1)``, matching scores that are read as interaction
            probabilities.

    Returns:
        A ``plotly.graph_objects.Figure`` with the legend shown.
    """
    fig = go.Figure()
    for prediction, smiles in zip(predictions, smiles_list):
        fig.add_trace(
            go.Scatterpolar(
                r=prediction,
                theta=model_names,
                fill="toself",
                name=smiles,
            )
        )
    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=list(radial_range),
            )
        ),
        showlegend=True,
    )
    return fig
def _parse_smiles(smiles_input):
    """Return the non-empty, whitespace-stripped lines of ``smiles_input``.

    ``str.splitlines`` also splits on ``\\r\\n``, so SMILES pasted with Windows
    line endings do not keep a trailing carriage return. Blank lines are
    dropped instead of being sent to the model as empty SMILES.
    """
    return [line.strip() for line in (smiles_input or "").splitlines() if line.strip()]


def predict_and_plot(amino_acid_sequence, smiles_input, datasets):
    """Score every SMILES against one protein with each selected model.

    Args:
        amino_acid_sequence: A single protein sequence without FASTA syntax.
            All whitespace (e.g. line breaks from a wrapped paste) is removed
            before prediction.
        smiles_input: Newline-separated SMILES strings.
        datasets: Names of the models to run; each must be a key of
            ``gbm_model_paths`` below. A single string is treated as a
            one-item selection, and duplicates are ignored.

    Returns:
        ``(fig, df)``: the spider plot from ``make_spider_plot`` and a
        DataFrame with a ``SMILES`` column followed by one column per model.

    Raises:
        gr.Error: If no model is selected, or the sequence or the SMILES input
            is empty, so Gradio shows a readable message instead of a traceback.
    """
    if isinstance(datasets, str):
        datasets = [datasets]
    # dict.fromkeys drops duplicates but keeps selection order, so the
    # DataFrame columns stay aligned with the (unique) keys of model_ensemble.
    datasets = list(dict.fromkeys(datasets or []))
    if not datasets:
        raise gr.Error("Please select at least one model.")

    sequence = "".join((amino_acid_sequence or "").split())
    if not sequence:
        raise gr.Error("Please enter a protein sequence.")

    smiles_list = _parse_smiles(smiles_input)
    if not smiles_list:
        raise gr.Error("Please enter at least one SMILES string.")

    gbm_model_paths = {
        "BindingDB": f"model/xgb_models/xgb_model_BindingDB_{dt_str}_bt_optimized_0.json",
        "BioSNAP": f"model/xgb_models/xgb_model_BIOSNAP_full_data_{dt_str}_bt_optimized_0.json",
        "DAVIS": f"model/xgb_models/xgb_model_DAVIS_{dt_str}_bt_optimized_0.json",
        "BarlowDTI XXL": f"model/xgb_models/{dt_str}_barlowdti_xxl_model.json",
    }

    # NOTE(review): the selected models are rebuilt from disk on every request;
    # consider caching them if DTIModel instances are safe to reuse.
    # NOTE(review): `spaces` is imported at module level but this function is
    # not decorated with @spaces.GPU; confirm whether DTIModel requests the
    # ZeroGPU device itself.
    model_ensemble = {}
    for model in datasets:
        print(f"Loading model {model}")
        model_ensemble[model] = DTIModel(
            bt_model_path=f"model/stash/{dt_str}",
            gbm_model_path=gbm_model_paths[model],
        )

    # Presumably each predict() returns one score per SMILES (TODO confirm), so
    # the stacked array is (models x molecules); transpose to one row per molecule.
    per_model = [model.predict(smiles_list, sequence) for model in model_ensemble.values()]
    predictions = np.array(per_model).transpose().tolist()

    df = pd.DataFrame(predictions, columns=datasets)
    df.insert(0, "SMILES", smiles_list)

    fig = make_spider_plot(predictions, datasets, smiles_list)
    return fig, df
# Checkbox choices for the model selector, in display order. Each name must be
# a key of the GBM model-path mapping inside predict_and_plot.
dataset_names = ["BarlowDTI XXL", "BindingDB", "BioSNAP", "DAVIS"]
# Page heading; the inline HTML span renders "BarlowDTI" in small caps.
title = (
    "Predict Drug-Target Interactions with "
    "<span style='font-variant:small-caps;'>BarlowDTI</span>"
)
# Markdown shown beneath the title. Blank lines separate the paragraphs;
# without them Markdown joins consecutive lines into one run-on paragraph.
description = """
Enter the amino acid sequence and SMILES to get interaction predictions visualized as a spider graph and in a table.
The values can be interpreted as the probability of interaction between the drug and the target (0 = no interaction, 1 = interaction).

Thank you for using <span style='font-variant:small-caps;'>BarlowDTI</span>!

*Note: Thanks to ZeroGPU, you can run this model on a GPU for free.*
"""
# Markdown shown below the interface: dataset credits and the citation entry.
# Blank lines separate the paragraphs so Markdown does not merge them.
article = """
This interface lets the scientific community use <span style='font-variant:small-caps;'>BarlowDTI</span><sub>XXL</sub> to predict drug-target interactions.

The model ensemble consists of four models trained on different datasets: our own curated and refined dataset based on
[Golts et al.](https://doi.org/10.48550/arXiv.2401.17174)
in combination with
[BindingDB](https://doi.org/10.1093/nar/gkl999),
[BioSNAP](https://snap.stanford.edu/index.html), and
[DAVIS](https://doi.org/10.1038/nbt.1990).

If you use our model in your research, please cite our paper:
```
@misc{schuh2024barlowtwinsdeepneural,
title={Barlow Twins Deep Neural Network for Advanced 1D Drug-Target Interaction Prediction},
author={Maximilian G. Schuh and Davide Boldini and Annkathrin I. Bohne and Stephan A. Sieber},
year={2024},
eprint={2408.00040},
archivePrefix={arXiv},
primaryClass={q-bio.BM},
url={https://arxiv.org/abs/2408.00040},
}
```
"""
# Gradio Base theme with a violet primary hue. IBM Plex Sans is requested from
# Google Fonts; the remaining entries are generic sans-serif fallbacks.
theme = gr.themes.Base(
    font=[
        gr.themes.GoogleFont("IBM Plex Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    primary_hue="violet",
)
# The Gradio UI: a protein sequence, one or more SMILES, and a model selection
# go in; a spider plot and a table of per-model scores come out.
iface = gr.Interface(
    fn=predict_and_plot,
    inputs=[
        gr.Textbox(label="Protein Sequence", info="Just one sequence is allowed. Remove FASTA syntax (e.g. > ABC).", placeholder="MRSWSTVMLAVLATAATVFGHDADPEMKMTTPQIIMRWGYPAMIYDVTTEDGYILELHRI"),
        gr.Textbox(label="Molecule SMILES", info="One per line, multiple allowed.", placeholder="C1CSSC1CCCCC(=O)O\nCC1=CC(=C(C=C1)C(=O)O)O"),
        # A CheckboxGroup's value is the list of selected choices, so the
        # default selection is given as a list.
        gr.CheckboxGroup(choices=dataset_names, label="Select Models for Prediction", value=["BarlowDTI XXL"]),
    ],
    outputs=[
        gr.Plot(label="Predictions Visualization"),
        gr.DataFrame(label="Predictions DataFrame"),
    ],
    title=title,
    description=description,
    article=article,
    theme=theme,
)

# Start the server only when this file is run as a script, so importing the
# module (e.g. from tests) does not launch the app.
if __name__ == "__main__":
    iface.launch()