# FART / app.py
# Author: yzimmermann
# Last commit: "Update app.py" (8087713, verified)
# (Hugging Face file-view residue: raw | history | blame | No virus | 1.33 kB)
import tempfile

import gradio as gr
from rdkit import Chem
from rdkit.Chem import Draw
from transformers import pipeline
# Hugging Face Hub checkpoint of the FART taste classifier.
model_checkpoint = "yzimmermann/FART"

# Text-classification pipeline applied to (canonical) SMILES strings.
# return_all_scores=True asks for a score for every label; process_smiles
# iterates over all of them to build its label -> score mapping.
classifier = pipeline(
    task="text-classification",
    model=model_checkpoint,
    return_all_scores=True,
)
def process_smiles(smiles):
    """Classify the taste of a molecule given as a SMILES string.

    The input is parsed and canonicalized with RDKit. The canonical SMILES
    is scored by the module-level ``classifier`` pipeline, and a depiction
    of the molecule is rendered to a PNG file.

    Args:
        smiles: SMILES string entered by the user.

    Returns:
        A 3-tuple ``(predictions, image_path, canonical_smiles)``:
        ``predictions`` maps each class label to its score, and
        ``image_path`` is the path of the rendered PNG. For input that is
        missing, blank, or not parseable as SMILES, the result is
        ``("Invalid SMILES", None, "Invalid SMILES")``.
    """
    # RDKit parses "" into an empty (zero-atom) molecule rather than
    # returning None, so blank input has to be rejected explicitly.
    if smiles is None or not smiles.strip():
        return "Invalid SMILES", None, "Invalid SMILES"

    # Validate and canonicalize SMILES.
    mol = Chem.MolFromSmiles(smiles.strip())
    if mol is None or mol.GetNumAtoms() == 0:
        return "Invalid SMILES", None, "Invalid SMILES"
    canonical_smiles = Chem.MolToSmiles(mol)

    # Predict using the pipeline.
    predictions = classifier(canonical_smiles)
    # With return_all_scores=True a single string yields
    # [[{"label": ..., "score": ...}, ...]]; the top_k=None call style yields
    # the flat [{"label": ..., "score": ...}, ...]. Accept either shape.
    if predictions and isinstance(predictions[0], list):
        scores = predictions[0]
    else:
        scores = predictions
    prediction_dict = {pred["label"]: pred["score"] for pred in scores}

    # Render to a unique temporary file. A fixed filename would be shared by
    # concurrent requests, letting one user's image overwrite another's.
    # NOTE(review): the file is not deleted here. Presumably Gradio reads the
    # returned path after this function returns, so cleanup must happen
    # elsewhere (or be left to the OS temp directory). TODO confirm policy.
    img = Draw.MolToImage(mol)
    with tempfile.NamedTemporaryFile(
        prefix="molecule_", suffix=".png", delete=False
    ) as tmp:
        img.save(tmp, format="PNG")
        img_path = tmp.name

    return prediction_dict, img_path, canonical_smiles
# Set up the Gradio interface.
# Uses the top-level component classes (gr.Textbox, gr.Label, gr.Image):
# the gr.inputs / gr.outputs namespaces were deprecated in Gradio 3 and
# removed in Gradio 4.
iface = gr.Interface(
    fn=process_smiles,
    inputs=gr.Textbox(label="Input SMILES"),
    outputs=[
        # Receives the label -> score dict (or the "Invalid SMILES" string).
        gr.Label(num_top_classes=3, label="Classification Probabilities"),
        # process_smiles returns a path to a PNG file (or None).
        gr.Image(type="filepath", label="Molecule Image"),
        gr.Textbox(label="Canonical SMILES"),
    ],
    title="FART",
    description="Enter a SMILES string to get the taste classification probabilities.",
)
iface.launch()