# STOUTII / app.py — uploaded by unruffle ("Upload 3 files", commit fcedb1b, verified)
# -*- coding: utf-8 -*-
import logging
import os
from typing import List

import gradio as gr
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from rdkit import Chem
from rdkit.Chem import MolToSmiles
from STOUT import translate_forward
# FastAPI application hosting both the JSON API and (mounted at the end of
# the file) the Gradio UI.
app = FastAPI(title="SMILES → IUPAC (STOUT-V2)")

# Wildcard CORS: any origin may call the API with any method/header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
def canonicalize_smiles(s: str) -> str:
    """Return RDKit's canonical SMILES for ``s``.

    Args:
        s: SMILES text to parse.

    Returns:
        The canonical SMILES produced by ``MolToSmiles(..., canonical=True)``.

    Raises:
        ValueError: if RDKit cannot parse ``s`` (``MolFromSmiles`` gives None).
    """
    parsed = Chem.MolFromSmiles(s)
    if parsed is not None:
        return MolToSmiles(parsed, canonical=True)
    raise ValueError(f"非法SMILES,RDKit无法解析:{s!r}")
class SMILESItem(BaseModel):
    """Request body for one SMILES -> IUPAC translation (also a batch element)."""

    smiles: str  # SMILES to translate; the handlers strip surrounding whitespace
    canonicalize: bool = True  # if true, canonicalize with RDKit before translating
class BatchRequest(BaseModel):
    """Request body for the batch endpoint: a list of independent items."""

    inputs: List[SMILESItem]  # processed in order; one result entry per item
@app.get("/healthz")
def healthz():
    """Health-check endpoint: always returns the constant ``{"ok": True}``."""
    payload = {"ok": True}
    return payload
@app.post("/api/smiles2iupac")
def api_smiles2iupac(req: SMILESItem):
    """Translate a single SMILES string into an IUPAC name.

    The SMILES is stripped of surrounding whitespace and, when
    ``req.canonicalize`` is true, canonicalized with RDKit before being passed
    to STOUT's ``translate_forward``.

    Failures are reported in the response body, not via HTTP status. Every
    response contains ``success`` and ``input`` (the stripped SMILES).
    Successful responses add ``smiles_processed`` (the string given to the
    model) and ``iupac``. Failed responses add ``error``.
    """
    s = (req.smiles or "").strip()
    if not s:
        return {"success": False, "input": s, "error": "输入为空"}
    try:
        s_norm = canonicalize_smiles(s) if req.canonicalize else s
        name = translate_forward(s_norm)
    except ValueError as e:
        # canonicalize_smiles signals unparseable SMILES with ValueError:
        # a bad input rather than a server fault, so no traceback is logged.
        return {"success": False, "input": s, "error": str(e)}
    except Exception as e:
        # Unexpected failure (e.g. inside the translation model): keep the
        # traceback server-side and never hand the client an empty message.
        logging.getLogger(__name__).exception(
            "SMILES -> IUPAC translation failed for %r", s
        )
        return {"success": False, "input": s, "error": str(e) or type(e).__name__}
    return {
        "success": True,
        "input": s,
        "smiles_processed": s_norm,
        "iupac": name,
    }
@app.post("/api/smiles2iupac/batch")
def api_smiles2iupac_batch(req: BatchRequest):
    """Translate a list of SMILES strings into IUPAC names.

    Items are processed independently and in order. A failing item produces
    an error entry instead of aborting the batch, so ``out[i]`` always
    corresponds to ``req.inputs[i]``.

    Every entry contains ``success`` and ``input`` (the whitespace-stripped
    SMILES, the same value for success and failure). Successful entries add
    ``smiles_processed`` (the string given to the model, RDKit-canonicalized
    when ``item.canonicalize`` is true) and ``iupac``. Failed entries add
    ``error``.
    """
    logger = logging.getLogger(__name__)
    out = []
    for item in req.inputs:
        s = (item.smiles or "").strip()
        if not s:
            out.append({"success": False, "input": s, "error": "输入为空"})
            continue
        try:
            s_norm = canonicalize_smiles(s) if item.canonicalize else s
            name = translate_forward(s_norm)
        except ValueError as e:
            # canonicalize_smiles signals unparseable SMILES with ValueError:
            # a bad input rather than a server fault, so no traceback is logged.
            out.append({"success": False, "input": s, "error": str(e)})
            continue
        except Exception as e:
            # Unexpected failure (e.g. inside the translation model): keep the
            # traceback server-side and never return an empty message.
            logger.exception("SMILES -> IUPAC translation failed for %r", s)
            out.append(
                {"success": False, "input": s, "error": str(e) or type(e).__name__}
            )
            continue
        out.append({
            "success": True,
            "input": s,
            "smiles_processed": s_norm,
            "iupac": name,
        })
    return out
def gradio_fn(s: str, canonicalize: bool):
    """Gradio callback: translate a SMILES string into an IUPAC name.

    Args:
        s: SMILES text from the input box (may be None or whitespace-padded).
        canonicalize: whether to canonicalize with RDKit before translating.

    Returns:
        A pair ``(iupac_name, debug_text)`` for the two output boxes. On empty
        input or failure the name is ``""`` and ``debug_text`` explains why.
    """
    # Strip like the JSON endpoints do, so the UI and the API feed identical
    # strings to RDKit / the model (previously only the emptiness check
    # ignored whitespace).
    s = (s or "").strip()
    if not s:
        return "", "输入为空"
    try:
        s_norm = canonicalize_smiles(s) if canonicalize else s
        name = translate_forward(s_norm)
    except ValueError as e:
        # Unparseable SMILES (raised by canonicalize_smiles): user error.
        return "", f"Error: {e}"
    except Exception as e:
        # Unexpected failure: log the traceback and show a non-empty message.
        logging.getLogger(__name__).exception(
            "SMILES -> IUPAC translation failed for %r", s
        )
        return "", f"Error: {str(e) or type(e).__name__}"
    return name, f"输入SMILES: {s}\n规范化SMILES: {s_norm}"
# Web UI: one SMILES textbox plus a canonicalization toggle in; the IUPAC
# name and a debug/status text out. Both are produced by gradio_fn.
demo = gr.Interface(
    title="SMILES → IUPAC (STOUT-V2)",
    description="Model: Kohulan/STOUT-V2 | 基于TensorFlow的翻译模型",
    fn=gradio_fn,
    inputs=[
        gr.Textbox(placeholder="支持任意合法SMILES写法", label="输入SMILES"),
        gr.Checkbox(value=True, label="自动RDKit规范化"),
    ],
    outputs=[
        # Left editable so users can adjust/copy the generated name.
        gr.Textbox(interactive=True, label="IUPAC名称"),
        gr.Textbox(label="调试信息"),
    ],
)

# Mount the Gradio UI onto the FastAPI app at "/"; the returned app replaces
# the module-level `app`.
app = gr.mount_gradio_app(app, demo, path="/")