# Metadata captured from the file viewer this source was copied from
# (kept as a comment so the module parses): file size 1,330 bytes;
# revisions 8087713, 5b5221b.
import tempfile

import gradio as gr
from rdkit import Chem
from rdkit.Chem import Draw
from transformers import pipeline

# Model identifier handed to transformers.pipeline (presumably a Hugging Face
# Hub repo id).
model_checkpoint = "yzimmermann/FART"

# Text-classification pipeline configured to return a score for every label.
# NOTE(review): newer transformers releases deprecate return_all_scores in
# favour of top_k=None, which returns a flat list instead of the nested list
# that process_smiles indexes with predictions[0] — migrate both together.
classifier = pipeline(
    task="text-classification",
    model=model_checkpoint,
    return_all_scores=True,
)

def process_smiles(smiles):
    """Validate a SMILES string, classify it, and render the molecule.

    Args:
        smiles: SMILES string entered by the user. ``None``, empty, or
            whitespace-only input is treated as invalid.

    Returns:
        A 3-tuple ``(predictions, image_path, canonical_smiles)`` matching the
        three Gradio outputs. For valid input, ``predictions`` maps each label
        to its score, ``image_path`` is the path of a freshly written PNG of
        the molecule, and ``canonical_smiles`` is RDKit's canonical SMILES.
        For invalid input, returns ``("Invalid SMILES", None, "Invalid SMILES")``.
    """
    invalid = ("Invalid SMILES", None, "Invalid SMILES")

    # Chem.MolFromSmiles("") yields an empty (non-None) molecule, so blank
    # input must be rejected explicitly instead of being classified as "".
    smiles = (smiles or "").strip()
    if not smiles:
        return invalid

    # Validate and canonicalize SMILES
    mol = Chem.MolFromSmiles(smiles)
    if mol is None or mol.GetNumAtoms() == 0:
        return invalid
    canonical_smiles = Chem.MolToSmiles(mol)

    # Predict using the pipeline
    predictions = classifier(canonical_smiles)

    # With return_all_scores=True a single string yields [[{label, score}, ...]];
    # the top_k=None style yields a flat [{label, score}, ...]. Accept both.
    scores = (
        predictions[0]
        if predictions and isinstance(predictions[0], list)
        else predictions
    )
    prediction_dict = {pred["label"]: pred["score"] for pred in scores}

    # Write the image to a unique temp file: a single shared path could be
    # overwritten by a concurrent request before Gradio reads it back.
    # delete=False because the file is read after this function returns.
    img = Draw.MolToImage(mol)
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        img.save(tmp, format="PNG")
        img_path = tmp.name

    return prediction_dict, img_path, canonical_smiles

# Gradio UI wiring: one SMILES textbox in; label scores, a rendered image,
# and the canonical SMILES out — in the order process_smiles returns them.
# NOTE(review): gr.inputs / gr.outputs are Gradio's legacy namespaces (removed
# in Gradio 4); migrating to gr.Textbox / gr.Label / gr.Image(type="filepath")
# depends on the pinned Gradio version — confirm before changing.
smiles_input = gr.inputs.Textbox(label="Input SMILES")
result_outputs = [
    gr.outputs.Label(num_top_classes=3, label="Classification Probabilities"),
    gr.outputs.Image(type="file", label="Molecule Image"),
    gr.outputs.Textbox(label="Canonical SMILES"),
]

iface = gr.Interface(
    fn=process_smiles,
    inputs=smiles_input,
    outputs=result_outputs,
    title="FART",
    description=(
        "Enter a SMILES string to get the taste classification probabilities."
    ),
)

# Start the web server whenever this module is executed or imported.
iface.launch()