# FART / app.py
# Last commit: "Update app.py" (147ef4b, verified), shown next to yzimmermann's avatar.
# File size: 1.37 kB
import tempfile

from rdkit import Chem
from rdkit.Chem import Draw
from transformers import pipeline
import gradio as gr
# Identifier of the model checkpoint handed to the pipeline below.
model_checkpoint = "FartLabs/FART_Chemberta_PubChem10M"

# Text-classification pipeline, built once at import time and reused by every
# call to process_smiles. ``top_k=None`` requests the scores of all labels
# rather than only the best one.
classifier = pipeline(
    task="text-classification",
    model=model_checkpoint,
    top_k=None,
)
def process_smiles(smiles: str):
    """Classify a molecule given as a SMILES string and render its picture.

    The input is parsed and canonicalized with RDKit, the canonical SMILES is
    passed to the module-level ``classifier`` pipeline, and the molecule is
    drawn to a PNG file.

    Args:
        smiles: SMILES string entered by the user. Leading and trailing
            whitespace is ignored.

    Returns:
        A 3-tuple ``(scores, image_path, canonical_smiles)``:

        * ``scores`` -- dict mapping each predicted label to its score.
        * ``image_path`` -- path of the PNG file with the rendered molecule.
        * ``canonical_smiles`` -- RDKit's canonical form of the input.

        If the input is blank or cannot be parsed, the tuple
        ``("Invalid SMILES", None, "Invalid SMILES")`` is returned instead.
    """
    # RDKit parses an empty string to an atom-less molecule instead of None,
    # so blank input has to be rejected explicitly before parsing.
    if not smiles or not smiles.strip():
        return "Invalid SMILES", None, "Invalid SMILES"

    # Validate and canonicalize SMILES
    mol = Chem.MolFromSmiles(smiles.strip())
    if mol is None:
        return "Invalid SMILES", None, "Invalid SMILES"
    canonical_smiles = Chem.MolToSmiles(mol)

    # Predict using the pipeline and convert the result to a friendly
    # {label: score} mapping. predictions[0] holds the per-label entries
    # for the single input string.
    predictions = classifier(canonical_smiles)
    prediction_dict = {pred["label"]: pred["score"] for pred in predictions[0]}

    # Generate molecule image. Each call writes to its own temporary file:
    # a single fixed file name would be overwritten by overlapping requests,
    # so one user could be shown another user's molecule. The file is not
    # deleted here because the caller reads it after this function returns.
    img = Draw.MolToImage(mol)
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        img.save(tmp, format="PNG")
    img_path = tmp.name

    return prediction_dict, img_path, canonical_smiles
# Input widget, pre-filled with an example SMILES string.
smiles_input = gr.Textbox(
    label="Input SMILES",
    value="O1[C@H](CO)[C@@H](O)[C@H](O)[C@@H](O)[C@H]1O[C@@]2(O[C@@H]([C@@H](O)[C@@H]2O)CO)CO",
)

# Output widgets, listed in the same order as the tuple returned by
# process_smiles: label scores, image file path, canonical SMILES.
result_outputs = [
    gr.Label(num_top_classes=3, label="Classification Probabilities"),
    gr.Image(type="filepath", label="Molecule Image"),
    gr.Textbox(label="Canonical SMILES"),
]

# Wire the widgets to the prediction function.
iface = gr.Interface(
    fn=process_smiles,
    inputs=smiles_input,
    outputs=result_outputs,
    title="FART",
    description="Enter a SMILES string to get the taste classification probabilities.",
)

# Start serving the interface as soon as the module is executed.
iface.launch()