Spaces:
Running
Running
File size: 5,084 Bytes
9c83e37 c1c647f 9c83e37 349e4b9 260e798 349e4b9 9c83e37 349e4b9 260e798 349e4b9 619ce97 9c83e37 349e4b9 9c83e37 349e4b9 260e798 349e4b9 260e798 c1c647f 260e798 79b7bce 260e798 79b7bce 260e798 c1c647f 349e4b9 9c83e37 349e4b9 9c83e37 c1c647f 260e798 349e4b9 c1c647f 349e4b9 c1c647f 9c83e37 c1c647f 349e4b9 f2d0983 0524a3d f2d0983 0524a3d f2d0983 0524a3d f2d0983 349e4b9 f2d0983 3aa054d 349e4b9 c1c647f 349e4b9 f2d0983 9c83e37 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from functools import lru_cache
from transformers import TapexTokenizer, BartForConditionalGeneration
from deep_translator import GoogleTranslator
from pathlib import Path
import os, json, pandas as pd, torch
# ------------------------
# Config
# ------------------------
# Hugging Face model id of the fine-tuned TAPEX checkpoint.
HF_MODEL_ID = os.getenv("HF_MODEL_ID", "stvnnnnnn/tapex-wikisql-best")
SPLIT = os.getenv("TABLE_SPLIT", "validation")  # "validation" ~ "dev"
INDEX = int(os.getenv("TABLE_INDEX", "10"))  # which WikiSQL example backs the demo table
MAX_ROWS = int(os.getenv("MAX_ROWS", "12"))  # cap on rows fed to the model
# ------------------------
# App
# ------------------------
app = FastAPI(title="NL→SQL – TAPEX + WikiSQL (API)")
# CORS is wide open on purpose: demo API meant to be called from any origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], allow_credentials=True,
)
class NLQuery(BaseModel):
    """Request body for /api/nl2sql."""
    # Natural-language question (or raw SQL) sent by the client.
    nl_query: str
# ------------------------
# Model
# ------------------------
# Loaded once at import time; all requests share this tokenizer/model pair.
tok = TapexTokenizer.from_pretrained(HF_MODEL_ID)
model = BartForConditionalGeneration.from_pretrained(HF_MODEL_ID)
if torch.cuda.is_available():
    model = model.to("cuda")
# ------------------------
# Utilidades de carga robustas
# ------------------------
def _read_json_or_jsonl(p: Path) -> dict:
"""
Lee un JSON normal (.json) o un JSONL (.jsonl) y devuelve el primer objeto.
"""
txt = p.read_text(encoding="utf-8").strip()
if p.suffix.lower() == ".jsonl":
for line in txt.splitlines():
s = line.strip()
if s:
return json.loads(s)
raise ValueError(f"{p} está vacío.")
return json.loads(txt)
@lru_cache(maxsize=32)
def get_table(split: str, index: int, max_rows: int) -> pd.DataFrame:
    """
    Resolve the demo table as a pandas DataFrame.

    1) Try ./data/<split>.json or ./data/<split>.jsonl (mapping 'validation' -> 'dev').
    2) Otherwise fall back to one example of WikiSQL (official Parquet conversion).

    Args:
        split: dataset split name; 'validation' and 'dev' map to the local 'dev' file.
        index: example index for the WikiSQL fallback (clamped to 0 if out of range).
        max_rows: maximum number of rows kept in the returned table.

    Raises:
        RuntimeError: when neither a local file nor the dataset fallback works.

    NOTE: results are memoized by lru_cache, so the SAME DataFrame object is
    shared across calls — callers must not mutate it.
    """
    base_dir = Path(__file__).parent
    data_dir = base_dir / "data"
    # Normalize the local file name (the demo ships the split as 'dev').
    local_name = "dev" if split.lower() in ("validation", "dev") else split.lower()
    # 1) Local file, if present.
    for candidate in (data_dir / f"{local_name}.json", data_dir / f"{local_name}.jsonl"):
        if candidate.exists():
            js = _read_json_or_jsonl(candidate)
            return _build_frame(js["header"], js["rows"], max_rows)
    # 2) Fallback: load one WikiSQL example (Parquet conversion).
    try:
        from datasets import load_dataset  # deferred import for faster startup
        # NOTE(review): the fallback always loads the "validation" split,
        # regardless of `split` — appears intentional for the demo.
        ds = load_dataset("Salesforce/wikisql", split="validation", revision="refs/convert/parquet")
        if not (0 <= index < len(ds)):
            index = 0  # clamp out-of-range indices instead of failing
        ex = ds[index]
        return _build_frame(ex["table"]["header"], ex["table"]["rows"], max_rows)
    except Exception as e:
        # Chain the cause so the original traceback survives for debugging.
        raise RuntimeError(f"No se pudo obtener una tabla: {e}") from e


def _build_frame(header, rows, max_rows: int) -> pd.DataFrame:
    """Build a DataFrame with stringified column names, truncated to max_rows."""
    cols = [str(h) for h in header]
    df = pd.DataFrame(rows[:max_rows], columns=cols)
    # TAPEX tokenization requires string column labels.
    df.columns = [str(c) for c in df.columns]
    return df
# ------------------------
# Endpoints
# ------------------------
@app.get("/api/health")
def health():
    """Liveness probe: report the model id and the active table selection."""
    return {
        "ok": True,
        "model": HF_MODEL_ID,
        "split": SPLIT,
        "index": INDEX,
    }
@app.get("/api/preview")
def preview():
    """Return the demo table's columns and its first 8 rows as JSON records."""
    try:
        table = get_table(SPLIT, INDEX, MAX_ROWS)
        payload = {
            "columns": table.columns.tolist(),
            "rows": table.head(8).to_dict(orient="records"),
        }
    except Exception as e:
        # Best-effort endpoint: surface the failure in the JSON body.
        return {"error": str(e)}
    return payload
@app.post("/api/nl2sql")
def nl2sql(q: NLQuery):
    """
    Translate a natural-language question into SQL with TAPEX.

    Pipeline: validate input -> (translate to English unless the input
    already looks like SQL) -> encode table + question -> beam-1 generation.

    Raises:
        HTTPException 400: empty query (client error, not a server failure).
        HTTPException 500: any model / data failure.
    """
    text = (q.nl_query or "").strip()
    if not text:
        # FIX: an empty body is a client error; previously this surfaced as a 500.
        raise HTTPException(status_code=400, detail="Consulta vacía.")
    try:
        # Heuristic: skip translation when the input already looks like SQL.
        lower = text.lower()
        looks_like_sql = lower.startswith(
            ("select", "with", "insert", "update", "delete", "create", "drop", "alter")
        )
        # Translate to English when it is not SQL.
        query_en = text
        if not looks_like_sql:
            try:
                translated = GoogleTranslator(source="auto", target="en").translate(text)
                if translated:
                    query_en = translated
            except Exception:
                query_en = text  # safe fallback: use the original text untranslated
        # Run TAPEX over the demo table plus the (possibly translated) question.
        df = get_table(SPLIT, INDEX, MAX_ROWS)
        enc = tok(table=df, query=query_en, return_tensors="pt", truncation=True)
        if torch.cuda.is_available():
            enc = {k: v.to("cuda") for k, v in enc.items()}
        out = model.generate(**enc, max_length=160, num_beams=1)
        sql = tok.batch_decode(out, skip_special_tokens=True)[0]
        return {
            "consulta_original": text,
            "consulta_traducida": query_en,
            "sql_generado": sql,
        }
    except HTTPException:
        # Never re-wrap deliberate HTTP errors into a 500.
        raise
    except Exception as e:
        # Chain the cause so server logs keep the original traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e