yzimmermann commited on
Commit
8087713
1 Parent(s): 6e8f3ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -1
app.py CHANGED
@@ -1,3 +1,42 @@
 
 
 
1
  import gradio as gr
2
 
3
- gr.load("models/yzimmermann/FART").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from rdkit import Chem
2
+ from rdkit.Chem import Draw
3
+ from transformers import pipeline
4
  import gradio as gr
5
 
6
+ model_checkpoint = "yzimmermann/FART"
7
+ classifier = pipeline("text-classification", model=model_checkpoint, return_all_scores=True)
8
+
9
+ def process_smiles(smiles):
10
+ # Validate and canonicalize SMILES
11
+ mol = Chem.MolFromSmiles(smiles)
12
+ if mol is None:
13
+ return "Invalid SMILES", None, "Invalid SMILES"
14
+ canonical_smiles = Chem.MolToSmiles(mol)
15
+
16
+ # Predict using the pipeline
17
+ predictions = classifier(canonical_smiles)
18
+
19
+ # Generate molecule image
20
+ img_path = "molecule.png"
21
+ img = Draw.MolToImage(mol)
22
+ img.save(img_path)
23
+
24
+ # Convert predictions to a friendly format
25
+ prediction_dict = {pred["label"]: pred["score"] for pred in predictions[0]}
26
+
27
+ return prediction_dict, img_path, canonical_smiles
28
+
29
+ # Set up the Gradio interface
30
+ iface = gr.Interface(
31
+ fn=process_smiles,
32
+ inputs=gr.inputs.Textbox(label="Input SMILES"),
33
+ outputs=[
34
+ gr.outputs.Label(num_top_classes=3, label="Classification Probabilities"),
35
+ gr.outputs.Image(type="file", label="Molecule Image"),
36
+ gr.outputs.Textbox(label="Canonical SMILES")
37
+ ],
38
+ title="FART",
39
+ description="Enter a SMILES string to get the taste classification probabilities."
40
+ )
41
+
42
+ iface.launch()