from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from functools import lru_cache
from transformers import TapexTokenizer, BartForConditionalGeneration
from deep_translator import GoogleTranslator
from pathlib import Path
import os
import json
import pandas as pd
import torch

# ------------------------
# Config
# ------------------------
HF_MODEL_ID  = os.getenv("HF_MODEL_ID",  "stvnnnnnn/tapex-wikisql-best")
SPLIT        = os.getenv("TABLE_SPLIT",  "validation")   # "validation" ~ "dev"
INDEX        = int(os.getenv("TABLE_INDEX", "10"))
MAX_ROWS     = int(os.getenv("MAX_ROWS",    "12"))
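# All four values can be overridden from the environment before startup, e.g.
# (assumed shell invocation; 'main' is a hypothetical module name):
#   TABLE_SPLIT=validation TABLE_INDEX=3 MAX_ROWS=20 uvicorn main:app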

# ------------------------
# App
# ------------------------
app = FastAPI(title="NL→SQL – TAPEX + WikiSQL (API)")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], allow_credentials=True,
)

class NLQuery(BaseModel):
    nl_query: str

# ------------------------
# Model
# ------------------------
tok = TapexTokenizer.from_pretrained(HF_MODEL_ID)
model = BartForConditionalGeneration.from_pretrained(HF_MODEL_ID)
if torch.cuda.is_available():
    model = model.to("cuda")
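# Note: from_pretrained already returns the model in eval mode, and
# generate() runs without gradient tracking, so no extra torch.no_grad()
# is needed at inference time.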

# ------------------------
# Robust loading utilities
# ------------------------
def _read_json_or_jsonl(p: Path) -> dict:
    """
    Read a plain JSON file (.json) or a JSONL file (.jsonl) and return the first object.
    """
    txt = p.read_text(encoding="utf-8").strip()
    if p.suffix.lower() == ".jsonl":
        for line in txt.splitlines():
            s = line.strip()
            if s:
                return json.loads(s)
        raise ValueError(f"{p} está vacío.")
    return json.loads(txt)
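
# Expected shape of a local table file (a .json object, or the first line of
# a .jsonl). The keys match what get_table reads below; the values here are
# illustrative only:
#   {"header": ["Player", "No."], "rows": [["Antonio Lang", "21"]]}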

@lru_cache(maxsize=32)
def get_table(split: str, index: int, max_rows: int) -> pd.DataFrame:
    """
    1) Try to load ./data/<split>.json or ./data/<split>.jsonl (mapping 'validation'->'dev').
    2) If neither exists, fall back to a WikiSQL example (official Parquet conversion).
    """
    base_dir = Path(__file__).parent
    data_dir = base_dir / "data"

    # Normalize the local file name (the demo uses 'dev')
    local_name = "dev" if split.lower() in ("validation", "dev") else split.lower()

    # 1) Look for a local file
    for candidate in (data_dir / f"{local_name}.json", data_dir / f"{local_name}.jsonl"):
        if candidate.exists():
            js = _read_json_or_jsonl(candidate)
            header = [str(h) for h in js["header"]]
            rows   = js["rows"][:max_rows]
            df = pd.DataFrame(rows, columns=header)
            df.columns = [str(c) for c in df.columns]
            return df

    # 2) Fallback: load an example from the WikiSQL dataset (converted Parquet)
    try:
        from datasets import load_dataset  # deferred import for faster startup
        ds = load_dataset("Salesforce/wikisql", split="validation", revision="refs/convert/parquet")
        if not (0 <= index < len(ds)):
            index = 0  # guard against out-of-range indices
        ex = ds[index]
        header = [str(h) for h in ex["table"]["header"]]
        rows   = ex["table"]["rows"][:max_rows]
        df = pd.DataFrame(rows, columns=header)
        df.columns = [str(c) for c in df.columns]
        return df
    except Exception as e:
        raise RuntimeError(f"No se pudo obtener una tabla: {e}")

# ------------------------
# Endpoints
# ------------------------
@app.get("/api/health")
def health():
    return {"ok": True, "model": HF_MODEL_ID, "split": SPLIT, "index": INDEX}

@app.get("/api/preview")
def preview():
    try:
        df = get_table(SPLIT, INDEX, MAX_ROWS)
        return {"columns": df.columns.tolist(), "rows": df.head(8).to_dict(orient="records")}
    except Exception as e:
        return {"error": str(e)}

@app.post("/api/nl2sql")
def nl2sql(q: NLQuery):
    text = (q.nl_query or "").strip()
    if not text:
        # An empty query is a client error; raise 400 here instead of letting
        # the generic handler below turn it into a 500
        raise HTTPException(status_code=400, detail="Empty query.")
    try:

        # Detect whether the input already looks like SQL
        lower = text.lower().strip()
        looks_like_sql = lower.startswith(("select", "with", "insert", "update", "delete", "create", "drop", "alter"))

        # Translate to English when the input is not SQL
        query_en = text
        if not looks_like_sql:
            try:
                translated = GoogleTranslator(source="auto", target="en").translate(text)
                if translated:
                    query_en = translated
            except Exception:
                query_en = text  # safe fallback: keep the original text

        # Run TAPEX on the table + query
        df = get_table(SPLIT, INDEX, MAX_ROWS)
        enc = tok(table=df, query=query_en, return_tensors="pt", truncation=True)
        if torch.cuda.is_available():
            enc = {k: v.to("cuda") for k, v in enc.items()}
        out = model.generate(**enc, max_length=160, num_beams=1)
        sql = tok.batch_decode(out, skip_special_tokens=True)[0]

        # Spanish response keys are kept: they are the API contract clients expect
        return {
            "consulta_original": text,
            "consulta_traducida": query_en,
            "sql_generado": sql
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
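
# Example call (assumed host/port; the response keys are the Spanish API
# contract noted above):
#   curl -X POST http://localhost:8000/api/nl2sql \
#        -H "Content-Type: application/json" \
#        -d '{"nl_query": "How many players are there?"}'
#   -> {"consulta_original": ..., "consulta_traducida": ..., "sql_generado": ...}

# Minimal local runner: a sketch that assumes this file is saved as main.py
# and that uvicorn is installed.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=8000)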