ribesstefano's picture
Update app.py
13490f0 verified
import difflib
import json
import tempfile
from typing import Dict, List, Literal

import gradio as gr
import numpy as np
import pandas as pd
import protac_degradation_predictor as pdp
import torch
from rdkit import Chem
def _closest_matches_hint(query, choices) -> str:
    """ Build the suggestion part of a validation error message.

    Args:
        query: The value entered by the user
        choices: The valid values to compare the query against

    Returns:
        str: The (up to) 3 closest matches formatted as a question, or a fixed
            message when no candidate is similar enough
    """
    suggestions = difflib.get_close_matches(query, choices, n=3, cutoff=0.5)
    if not suggestions:
        return "No close matches found."
    return f"Did you mean: {', '.join(suggestions)}?"


def gradio_app(
    protac_smiles: str | List[str],
    e3_ligase: str | List[str],
    target_uniprot: str | List[str],
    cell_line: str | List[str],
    use_models_from_cv: bool = False,
) -> tuple[Dict[str, float], str, pd.DataFrame, str]:
    """ Wrapper for the Gradio interface.

    Validates the Uniprot ID and the cell line, runs the predictor, and
    packages the result into the four values consumed by the interface outputs.

    NOTE(review): the validation and the `float(...)` conversions below treat
    the inputs as single values; list inputs are allowed by the signature but
    look unsupported by this wrapper - confirm before relying on them.

    Args:
        protac_smiles (str | List[str]): PROTAC SMILES string or list of strings
        e3_ligase (str | List[str]): E3 ligase string or list of strings
        target_uniprot (str | List[str]): Uniprot ID string or list of strings
        cell_line (str | List[str]): Cell line string or list of strings
        use_models_from_cv (bool): Whether to use models trained during cross-validation

    Returns:
        tuple: A 4-tuple made of:
            - Dict[str, float]: "Active" / "Inactive" probabilities from the mean prediction
            - str: "Active" or "Inactive", according to the majority vote
            - pd.DataFrame: One row per model, with columns "Model" and "Active Probability"
            - str: Path to a JSON file containing the inputs and all of the above

    Raises:
        gr.Error: If the Uniprot ID or the cell line is not among the available ones
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    avail_uniprots = pdp.avail_uniprots()
    avail_cells = pdp.avail_cell_lines()

    # Check if input Uniprot ID and cell line are among the available ones. If
    # not, suggest the 3 closest matches.
    if target_uniprot not in avail_uniprots:
        hint = _closest_matches_hint(target_uniprot, avail_uniprots)
        raise gr.Error(f"Invalid Uniprot ID. {hint}", duration=None)
    if cell_line not in avail_cells:
        hint = _closest_matches_hint(cell_line, avail_cells)
        raise gr.Error(f"Invalid Cell Line. {hint}", duration=None)

    prediction = pdp.get_protac_active_proba(
        protac_smiles,
        e3_ligase,
        target_uniprot,
        cell_line,
        device=device,
        use_models_from_cv=use_models_from_cv,
    )

    active_proba = float(prediction['mean'])
    mean_pred = {"Active": active_proba, "Inactive": 1. - active_proba}
    majvote_pred = "Active" if prediction['majority_vote'] else "Inactive"

    # Create a DataFrame for the raw predictions suitable for a bar plot
    raw_preds_df = pd.DataFrame({
        'Model': [f"Model n.{i}" for i, _ in enumerate(prediction["preds"])],
        'Active Probability': [float(p[0]) for p in prediction["preds"]],
    })

    json_data = {
        "protac_smiles": protac_smiles,
        "e3_ligase": e3_ligase,
        "target_uniprot": target_uniprot,
        "cell_line": cell_line,
        "mean_prediction": mean_pred,
        "majority_vote_prediction": majvote_pred,
        "model_predictions": raw_preds_df.to_dict(orient="records"),
    }

    # Write the JSON to a uniquely named file in the system temp directory, so
    # that overlapping requests cannot overwrite each other's download and the
    # location does not depend on a POSIX-only "/tmp". The file is kept on
    # disk (delete=False) because it is read after this function returns.
    with tempfile.NamedTemporaryFile(
        mode="w",
        prefix="predictions_",
        suffix=".json",
        delete=False,
    ) as f:
        json.dump(json_data, f, indent=4)
        json_filename = f.name

    return mean_pred, majvote_pred, raw_preds_df, json_filename
# Markdown text handed to `gr.Interface` through its `description` argument.
# It is a raw string, so the backslashes of the LaTeX activity criterion and
# the BibTeX entry below reach Gradio exactly as written here.
description = r"""A machine learning-based tool for predicting PROTAC protein degradation activity. This is a GUI app of the [PROTAC-Degradation-Predictor Github repository](https://github.com/ribesstefano/PROTAC-Degradation-Predictor/).
After having input the PROTAC SMILES string and its biological context, the app will predict its activity. A PROTAC is defined active when:
$$D_{max} \ge 60\\% \ \ \mathrm{and} \ \ pDC_{50} \ge 6$$
If you find this tool useful, please cite the following paper:
```
@article{Ribes_2024,
title={Modeling PROTAC degradation activity with machine learning},
volume={6},
ISSN={2667-3185},
url={http://dx.doi.org/10.1016/j.ailsci.2024.100104},
DOI={10.1016/j.ailsci.2024.100104},
journal={Artificial Intelligence in the Life Sciences},
publisher={Elsevier BV},
author={Ribes, Stefano and Nittinger, Eva and Tyrchan, Christian and Mercado, Rocío},
year={2024},
month=dec, pages={100104}
}
```
"""
# Input widgets, listed in the same order as the parameters of `gradio_app`.
_input_components = [
    gr.Textbox(placeholder="PROTAC SMILES", label="PROTAC SMILES"),
    gr.Dropdown(pdp.avail_e3_ligases(), label="E3 ligase"),
    gr.Textbox(placeholder="E.g., Q92769", label="Target Uniprot"),
    gr.Textbox(placeholder="E.g., HeLa", label="Cell line"),
    gr.Checkbox(label="Use models trained during cross-validation"),
]

# Output widgets, listed in the same order as the values returned by
# `gradio_app`: mean probabilities, majority vote, per-model table, JSON file.
_models_bar_plot = gr.BarPlot(
    x="Model",
    y="Active Probability",
    vertical=False,
    y_lim=[0, 1],
    tooltip="Active Probability",
    title="Models' activity probability prediction",
    label="Models' activity probability prediction",
    show_label=False,
)
_output_components = [
    gr.Label(label="Average probability (confidence)"),
    gr.Label(label="Majority vote prediction"),
    _models_bar_plot,
    gr.DownloadButton(label="Download as JSON", size="sm"),
]

# A single example row; it supplies values for the first four inputs only
# (SMILES, E3 ligase, Uniprot ID, cell line) and none for the checkbox.
_example_rows = [
    [
        "Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O)[C@@H](NC(=O)COCCCCCCCCCOCC(=O)Nc2ccc(C(=O)Nc3ccc(F)cc3N)cc2)C(C)(C)C)cc1",
        "VHL",
        "Q92769",
        "HeLa",
    ],
]

# Assemble the interface around `gradio_app` and start serving it.
demo = gr.Interface(
    fn=gradio_app,
    inputs=_input_components,
    outputs=_output_components,
    examples=_example_rows,
    title="PROTAC Degradation Predictor",
    submit_btn="Predict Activity",
    description=description,
)

demo.launch()