stvnnnnnn committed · Commit c1c647f · verified · 1 Parent(s): 719fbce

Update app.py

Files changed (1):
  1. app.py (+128 -38)
app.py CHANGED
@@ -1,75 +1,165 @@
 
 
  from fastapi import FastAPI, HTTPException
  from fastapi.middleware.cors import CORSMiddleware
  from pydantic import BaseModel
- from deep_translator import GoogleTranslator
  from datasets import load_dataset
  from transformers import TapexTokenizer, BartForConditionalGeneration
- import pandas as pd, torch, os

- # === Config ===
  HF_MODEL_ID = os.getenv("HF_MODEL_ID", "stvnnnnnn/tapex-wikisql-best")
  TABLE_SPLIT = os.getenv("TABLE_SPLIT", "validation")
  TABLE_INDEX = int(os.getenv("TABLE_INDEX", "10"))
- MAX_ROWS = int(os.getenv("MAX_ROWS", "12"))

- torch.set_num_threads(1)

- app = FastAPI(title="NL→SQL – TAPEX + WikiSQL (HF Space)")

- # CORS: lets Vercel (or any origin) consume the API
  app.add_middleware(
      CORSMiddleware,
-     allow_origins=["*"], allow_credentials=False,
-     allow_methods=["*"], allow_headers=["*"],
  )

- # Load model/tokenizer with a low RAM peak (CPU)
- tok = TapexTokenizer.from_pretrained(HF_MODEL_ID)
- model = BartForConditionalGeneration.from_pretrained(
-     HF_MODEL_ID, low_cpu_mem_usage=True
- ).to("cpu")

- class NLQuery(BaseModel):
-     nl_query: str

- def get_example(split, index):
-     # streaming, so the whole of WikiSQL is never loaded into RAM
-     ds = load_dataset("Salesforce/wikisql", split=split, streaming=True)
-     for i, ex in enumerate(ds):
-         if i == index:
-             return ex
-     raise IndexError("Index fuera de rango")

- def load_table(split=TABLE_SPLIT, index=TABLE_INDEX, max_rows=MAX_ROWS):
-     ex = get_example(split, index)
      header = [str(h) for h in ex["table"]["header"]]
      rows = ex["table"]["rows"][:max_rows]
-     return pd.DataFrame(rows, columns=header)

  @app.get("/api/health")
  def health():
      return {"ok": True, "model": HF_MODEL_ID, "split": TABLE_SPLIT, "index": TABLE_INDEX}

  @app.get("/api/preview")
  def preview():
-     df = load_table()
-     return {"columns": df.columns.tolist(), "rows": df.head(8).to_dict(orient="records")}

  @app.post("/api/nl2sql")
  def nl2sql(q: NLQuery):
      try:
-         text = (q.nl_query or "").strip()
-         if not text:
-             raise ValueError("Consulta vacía.")

-         is_ascii = all(ord(c) < 128 for c in text)
-         query_en = text if is_ascii else GoogleTranslator(source="auto", target="en").translate(text)

-         df = load_table()
-         enc = tok(table=df, query=query_en, return_tensors="pt", truncation=True)
-         out = model.generate(**enc, max_length=160, num_beams=1)
-         sql = tok.batch_decode(out, skip_special_tokens=True)[0]

-         return {"consulta_original": text, "consulta_traducida": query_en, "sql_generado": sql}
      except Exception as e:
          raise HTTPException(status_code=500, detail=str(e))
 
+ # app.py — NL→SQL (TAPEX + WikiSQL) backend for HF Spaces
+
  from fastapi import FastAPI, HTTPException
+ from fastapi.responses import HTMLResponse, JSONResponse
  from fastapi.middleware.cors import CORSMiddleware
  from pydantic import BaseModel
+
+ import os
+ import torch
+ import pandas as pd
+ from functools import lru_cache
+
  from datasets import load_dataset
+ from deep_translator import GoogleTranslator
  from transformers import TapexTokenizer, BartForConditionalGeneration

+
+ # --------- Config and defaults ----------
  HF_MODEL_ID = os.getenv("HF_MODEL_ID", "stvnnnnnn/tapex-wikisql-best")
  TABLE_SPLIT = os.getenv("TABLE_SPLIT", "validation")
  TABLE_INDEX = int(os.getenv("TABLE_INDEX", "10"))
+ MAX_ROWS = int(os.getenv("MAX_ROWS", "100"))  # conservative limit on CPU

+ # Make sure the Space has a writable cache
+ os.environ["HF_HOME"] = "/app/.cache/huggingface"
+ os.makedirs(os.environ["HF_HOME"], exist_ok=True)


+ # --------- App & CORS ----------
+ app = FastAPI(title="NL→SQL – TAPEX + WikiSQL (HF Space)")
  app.add_middleware(
      CORSMiddleware,
+     allow_origins=["*"],  # switch to your Vercel domain once you have one
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
  )


+ # --------- Lazy loading ---------
+ @lru_cache(maxsize=1)
+ def get_model_and_tokenizer():
+     tok = TapexTokenizer.from_pretrained(HF_MODEL_ID)
+     # Free Spaces run on CPU; avoid device_map="auto" so accelerate is not required
+     model = BartForConditionalGeneration.from_pretrained(HF_MODEL_ID)
+     model.eval()
+     return tok, model
+

+ @lru_cache(maxsize=32)
+ def get_table(split: str, index: int, max_rows: int) -> pd.DataFrame:
+     """
+     Load a WikiSQL table and return it as a DataFrame.
+     We avoid the parquet revision (which sometimes fails on Spaces) and use the regular split.
+     """
+     ds = load_dataset("Salesforce/wikisql", split=split)
+     if index < 0 or index >= len(ds):
+         raise IndexError(f"TABLE_INDEX fuera de rango (0..{len(ds)-1}).")

+     ex = ds[index]
      header = [str(h) for h in ex["table"]["header"]]
      rows = ex["table"]["rows"][:max_rows]
+     df = pd.DataFrame(rows, columns=header)
+     # Normalize column names to strings
+     df.columns = [str(c) for c in df.columns]
+     return df
+
+
+ # --------- Request schema ----------
+ class NLQuery(BaseModel):
+     nl_query: str
+
+
+ # --------- Mini UI at / ----------
+ INDEX_HTML = """
+ <!doctype html>
+ <html lang="es">
+ <meta charset="utf-8">
+ <title>NL→SQL (TAPEX + WikiSQL)</title>
+ <style>
+ body{font-family:system-ui,-apple-system,Segoe UI,Roboto,Ubuntu,sans-serif;max-width:860px;margin:30px auto;padding:0 16px;color:#eaeaea;background:#0f1115}
+ h1{font-size:1.6rem;margin:0 0 8px}
+ .card{background:#171923;border:1px solid #232736;border-radius:12px;padding:16px;margin:18px 0}
+ input,button,select{font-size:1rem}
+ input{width:100%;padding:10px 12px;border-radius:8px;border:1px solid #2a3042;background:#0f1115;color:#eaeaea}
+ button{padding:10px 16px;border-radius:8px;border:1px solid #2a3042;background:#1f2433;color:#eaeaea;cursor:pointer}
+ pre{white-space:pre-wrap;background:#0f1115;border:1px solid #2a3042;padding:12px;border-radius:8px}
+ .row{display:flex;gap:12px;align-items:center}
+ </style>
+ <h1>🧠 NL → SQL (TAPEX + WikiSQL)</h1>
+ <div class="card">
+ <p><b>Backend:</b> este Space ofrece endpoints REST; prueba una consulta:</p>
+ <div class="row">
+ <input id="q" placeholder="Ej.: Muestra los jugadores que son Guards." />
+ <button onclick="run()">Generar SQL</button>
+ <button onclick="prev()">Ver preview tabla</button>
+ </div>
+ <p style="font-size:.9rem;opacity:.8">Swagger: <a href="./docs" target="_blank">/docs</a> · Salud: <a href="./api/health" target="_blank">/api/health</a></p>
+ <pre id="out">Listo para generar...</pre>
+ </div>
+ <script>
+ async function run(){
+   const q = document.getElementById('q').value.trim();
+   const r = await fetch('./api/nl2sql', {method:'POST', headers:{'Content-Type':'application/json'}, body: JSON.stringify({nl_query: q}) });
+   document.getElementById('out').textContent = JSON.stringify(await r.json(), null, 2);
+ }
+ async function prev(){
+   const r = await fetch('./api/preview');
+   document.getElementById('out').textContent = JSON.stringify(await r.json(), null, 2);
+ }
+ </script>
+ </html>
+ """
+
+ @app.get("/", response_class=HTMLResponse)
+ def home():
+     return INDEX_HTML

+
+ # --------- API routes ----------
  @app.get("/api/health")
  def health():
      return {"ok": True, "model": HF_MODEL_ID, "split": TABLE_SPLIT, "index": TABLE_INDEX}

+
  @app.get("/api/preview")
  def preview():
+     try:
+         df = get_table(TABLE_SPLIT, TABLE_INDEX, MAX_ROWS)
+         # Return just the first 8 rows to keep the payload small
+         data = df.head(8).to_dict(orient="records")
+         return {"columns": list(df.columns), "rows": data}
+     except Exception as e:
+         # Send a simple message (to keep the logs clean)
+         return JSONResponse(status_code=500, content={"error": str(e)})
+

  @app.post("/api/nl2sql")
  def nl2sql(q: NLQuery):
+     nl = (q.nl_query or "").strip()
+     if not nl:
+         raise HTTPException(status_code=400, detail="Consulta vacía.")
+
+     # ES→EN translation if we detect non-ASCII characters
      try:
+         is_ascii = all(ord(c) < 128 for c in nl)
+         nl_en = nl if is_ascii else GoogleTranslator(source="auto", target="en").translate(nl)
+     except Exception:
+         # If translation fails, continue with the original text
+         nl_en = nl

+     try:
+         df = get_table(TABLE_SPLIT, TABLE_INDEX, MAX_ROWS)
+         tok, model = get_model_and_tokenizer()

+         # Bounded tokenization
+         enc = tok(table=df, query=nl_en, return_tensors="pt", truncation=True, max_length=512)
+
+         with torch.inference_mode():
+             out = model.generate(**enc, max_length=160, num_beams=1)

+         sql = tok.batch_decode(out, skip_special_tokens=True)[0]
+         return {"consulta_original": nl, "consulta_traducida": nl_en, "sql_generado": sql}
      except Exception as e:
          raise HTTPException(status_code=500, detail=str(e))
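
For quick testing of the endpoints this commit exposes, a minimal client sketch (the base URL below is a placeholder for the actual Space URL, and the `requests` library is assumed to be installed; neither appears in the commit itself):

import requests

# Placeholder; substitute the real Space URL.
BASE = "https://your-space.hf.space"

# Health check: reports the model id and the configured table split/index.
print(requests.get(f"{BASE}/api/health", timeout=30).json())

# Preview of the served WikiSQL table (the backend returns the first 8 rows).
print(requests.get(f"{BASE}/api/preview", timeout=60).json())

# NL→SQL: POST the natural-language question as JSON.
r = requests.post(
    f"{BASE}/api/nl2sql",
    json={"nl_query": "Muestra los jugadores que son Guards."},
    timeout=120,  # the first call is slow: the model is loaded lazily via lru_cache
)
print(r.json())  # keys: consulta_original, consulta_traducida, sql_generado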