"""Gradio front end for BarlowDTI drug-target interaction prediction.

The app takes one protein sequence and one or more SMILES strings, runs the
selected models and shows the scores as a spider plot and as a table.
"""

# NOTE: import order is kept exactly as it was; do not regroup without checking
# whether `spaces` / `model.model` rely on it.
import gradio as gr
import plotly.graph_objects as go
import numpy as np
import pandas as pd
from model.model import DTIModel
import spaces

# Identifier embedded in the checkpoint directory and model file names below.
dt_str = "14062024_0910"

# Display name of each selectable model -> path of its gradient-boosting file.
# Built once at import time instead of on every request.
_GBM_MODEL_PATHS = {
    "BindingDB": f"model/xgb_models/xgb_model_BindingDB_{dt_str}_bt_optimized_0.json",
    "BioSNAP": f"model/xgb_models/xgb_model_BIOSNAP_full_data_{dt_str}_bt_optimized_0.json",
    "DAVIS": f"model/xgb_models/xgb_model_DAVIS_{dt_str}_bt_optimized_0.json",
    "BarlowDTI XXL": f"model/xgb_models/{dt_str}_barlowdti_xxl_model.json",
}


def make_spider_plot(predictions, model_names, smiles_list):
    """Build a polar ("spider") chart with one filled trace per molecule.

    Args:
        predictions: One row per molecule; each row holds one score per entry
            of ``model_names``.
        model_names: Labels used for the angular axis.
        smiles_list: Legend label for each row of ``predictions``.

    Returns:
        A ``plotly.graph_objects.Figure`` whose radial axis is fixed to [0, 1].
    """
    fig = go.Figure()

    for prediction, smiles in zip(predictions, smiles_list):
        fig.add_trace(
            go.Scatterpolar(
                r=prediction,
                theta=model_names,
                fill='toself',
                name=smiles,
            )
        )

    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 1],
            )
        ),
        showlegend=True,
    )
    return fig


def _parse_smiles(smiles_input):
    """Return the non-blank lines of ``smiles_input`` with whitespace trimmed.

    Trimming each line also removes the ``\\r`` left behind by Windows line
    endings; blank lines are dropped so they are never sent to the model.
    """
    lines = (smiles_input or "").splitlines()
    return [line.strip() for line in lines if line.strip()]


@spaces.GPU
def predict_and_plot(amino_acid_sequence, smiles_input, datasets):
    """Run the selected models on every SMILES line and visualise the scores.

    Args:
        amino_acid_sequence: Protein sequence text from the UI.
        smiles_input: Text with one SMILES string per line.
        datasets: Names of the models to run; each must be a key of
            ``_GBM_MODEL_PATHS``.

    Returns:
        A ``(figure, dataframe)`` pair: the spider plot and a table with a
        ``SMILES`` column followed by one column per selected model.

    Raises:
        gr.Error: If the sequence, the SMILES list or the model selection is
            empty. Raised before any model is loaded so the user gets a
            readable message instead of a numpy/pandas shape error.
    """
    sequence = (amino_acid_sequence or "").strip()
    if not sequence:
        raise gr.Error("Please enter a protein sequence.")

    smiles_list = _parse_smiles(smiles_input)
    if not smiles_list:
        raise gr.Error("Please enter at least one SMILES string.")

    datasets = list(datasets or [])
    if not datasets:
        raise gr.Error("Please select at least one model.")

    model_ensemble = {}
    for model in datasets:
        print(f"Loading model {model}")
        model_ensemble[model] = DTIModel(
            bt_model_path=f"model/stash/{dt_str}",
            gbm_model_path=_GBM_MODEL_PATHS[model],
        )

    # NOTE(review): assumes predict() yields one score per SMILES - confirm
    # against DTIModel.
    predictions = [
        model.predict(smiles_list, sequence)
        for model in model_ensemble.values()
    ]
    # Collected as (models x molecules); flip to (molecules x models) so each
    # row describes one molecule.
    predictions = np.array(predictions).transpose().tolist()

    df = pd.DataFrame(predictions, index=smiles_list, columns=datasets).reset_index()
    df.columns = ["SMILES"] + datasets

    fig = make_spider_plot(predictions, datasets, smiles_list)
    return fig, df


dataset_names = [
    "BarlowDTI XXL",
    "BindingDB",
    "BioSNAP",
    "DAVIS",
]

title = "Predict Drug-Target Interactions with BarlowDTI"

# NOTE(review): the whitespace inside the two literals below is reproduced
# exactly as found. The citation fence in `article` sits on a single line and
# may not render as a code block - confirm against the deployed page.
description = (
    " Enter the amino acid sequence and SMILES to get interaction predictions "
    "visualized as a spider graph and in a table. The values can be "
    "interpreted as the probability of interaction between the drug and the "
    "target (0 = no interaction, 1 = interaction). \n"
    "Thank you for using BarlowDTI! Note: Inference may take longer, you can "
    "upgrade to a paid GPU-enabled plan for faster inference. "
)

article = (
    " This interface lets the scientific community use BarlowDTIXXL to predict "
    "drug-target interactions. The model ensemble consists of four models "
    "trained on different datasets: our own curated and refined dataset based "
    "on [Golts et. al](https://doi.org/10.48550/arXiv.2401.17174) in "
    "combination with [BindingDB](https://doi.org/10.1093/nar/gkl999), "
    "[BioSNAP](https://snap.stanford.edu/index.html), and "
    "[DAVIS](https://doi.org/10.1038/nbt.1990). If you use our approach in "
    "your research, please cite our paper: ``` "
    "@misc{schuh2024barlowtwinsdeepneural, "
    "title={Barlow Twins Deep Neural Network for Advanced 1D Drug-Target "
    "Interaction Prediction}, "
    "author={Maximilian G. Schuh and Davide Boldini and Stephan A. Sieber}, "
    "year={2024}, eprint={2408.00040}, archivePrefix={arXiv}, "
    "primaryClass={q-bio.BM}, url={https://arxiv.org/abs/2408.00040}, } ``` "
)

theme = gr.themes.Base(
    primary_hue="violet",
    font=[
        gr.themes.GoogleFont('IBM Plex Sans'),
        'ui-sans-serif',
        'system-ui',
        'sans-serif',
    ],
)

iface = gr.Interface(
    fn=predict_and_plot,
    inputs=[
        gr.Textbox(
            label="Protein Sequence",
            info="Just one sequence is allowed. Remove FASTA syntax (e.g. >ABC).",
        ),
        gr.Textbox(
            label="Molecule SMILES",
            info="One per line, multiple allowed.",
        ),
        gr.CheckboxGroup(
            choices=dataset_names,
            label="Select Models for Prediction",
            value="BarlowDTI XXL",
        ),
    ],
    outputs=[
        gr.Plot(label="Predictions Visualization"),
        gr.DataFrame(label="Predictions DataFrame"),
    ],
    title=title,
    description=description,
    article=article,
    theme=theme,
)

iface.launch()