stvnnnnnn committed
Commit 260e798 · verified · 1 Parent(s): 5fe84ff

Update app.py

Files changed (1)
  1. app.py +55 -78
app.py CHANGED
@@ -2,19 +2,18 @@ from fastapi import FastAPI, HTTPException
  from fastapi.middleware.cors import CORSMiddleware
  from pydantic import BaseModel
  from functools import lru_cache
- from huggingface_hub import hf_hub_download
  from transformers import TapexTokenizer, BartForConditionalGeneration
  from deep_translator import GoogleTranslator
+ from pathlib import Path
  import os, json, pandas as pd, torch

  # ------------------------
  # Config
  # ------------------------
- HF_MODEL_ID = os.getenv("HF_MODEL_ID", "stvnnnnnn/tapex-wikisql-best")
- WIKISQL_REPO = os.getenv("WIKISQL_REPO", "Salesforce/wikisql")  # official dataset
- SPLIT = os.getenv("TABLE_SPLIT", "validation")  # "validation" == dev in WikiSQL
- INDEX = int(os.getenv("TABLE_INDEX", "10"))
- MAX_ROWS = int(os.getenv("MAX_ROWS", "12"))
+ HF_MODEL_ID = os.getenv("HF_MODEL_ID", "stvnnnnnn/tapex-wikisql-best")
+ SPLIT = os.getenv("TABLE_SPLIT", "validation")  # "validation" ~ "dev"
+ INDEX = int(os.getenv("TABLE_INDEX", "10"))
+ MAX_ROWS = int(os.getenv("MAX_ROWS", "12"))

  # ------------------------
  # App
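Note: the Config block above is driven entirely by environment variables, so the served table can be changed without editing the code. A minimal sketch of overriding them before the module is imported; the variable names come from the os.getenv() calls above, while the values and the `import app` line are illustrative assumptions:

import os

# Hypothetical overrides; any key left unset falls back to the defaults above.
os.environ["TABLE_SPLIT"] = "validation"   # which WikiSQL split to sample from
os.environ["TABLE_INDEX"] = "3"            # which example's table to serve
os.environ["MAX_ROWS"] = "20"              # how many rows to pass to TAPEX

import app  # the module reads these variables at import time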
@@ -37,80 +36,57 @@ if torch.cuda.is_available():
      model = model.to("cuda")

  # ------------------------
- # Util: load WikiSQL (JSONL)
+ # Robust loading utilities
  # ------------------------
- def _read_jsonl(path):
-     with open(path, "r", encoding="utf-8") as f:
-         for line in f:
-             if line.strip():
-                 yield json.loads(line)
-
- def _download_file(filename: str) -> str:
-     # download from the dataset hub
-     return hf_hub_download(repo_id=WIKISQL_REPO, filename=filename, repo_type="dataset")
+ def _read_json_or_jsonl(p: Path) -> dict:
+     """
+     Reads a plain JSON (.json) or a JSONL (.jsonl) file and returns the first object.
+     """
+     txt = p.read_text(encoding="utf-8").strip()
+     if p.suffix.lower() == ".jsonl":
+         for line in txt.splitlines():
+             s = line.strip()
+             if s:
+                 return json.loads(s)
+         raise ValueError(f"{p} is empty.")
+     return json.loads(txt)

  @lru_cache(maxsize=32)
- def get_table_from_wikisql(split: str, index: int, max_rows: int) -> pd.DataFrame:
+ def get_table(split: str, index: int, max_rows: int) -> pd.DataFrame:
      """
-     Loads the WikiSQL table without scripts, using the repo's JSONL files directly:
-     - dev.jsonl (validation = 'dev' in the original terminology)
-     - dev.tables.jsonl
-     If you change split to 'train' or 'test', it tries the equivalent names.
+     1) Try to load ./data/<split>.json or ./data/<split>.jsonl (mapping 'validation' -> 'dev').
+     2) If neither exists, fall back to a WikiSQL example (official Parquet conversion).
      """
-     # Simple mapping: validation->dev, train->train, test->test
-     split_map = {"validation": "dev", "dev": "dev", "train": "train", "test": "test"}
-     base = split_map.get(split.lower(), "dev")
-
-     # Possible file names in the repo (some mirrors use variants)
-     qa_candidates = [f"data/{base}.jsonl", f"data/{base}.json", f"{base}.jsonl"]
-     tbl_candidates = [f"data/{base}.tables.jsonl", f"{base}.tables.jsonl"]
-
-     qa_path = None
-     tbl_path = None
-
-     # Download QA
-     for cand in qa_candidates:
-         try:
-             qa_path = _download_file(cand)
-             break
-         except Exception:
-             continue
-     if qa_path is None:
-         raise RuntimeError(f"QA file not found for split={split}. Tried: {qa_candidates}")
-
-     # Download tables
-     for cand in tbl_candidates:
-         try:
-             tbl_path = _download_file(cand)
-             break
-         except Exception:
-             continue
-     if tbl_path is None:
-         raise RuntimeError(f"Tables file not found for split={split}. Tried: {tbl_candidates}")
-
-     # Read question N (to get its table_id); skip this if you don't need the question
-     qa_list = list(_read_jsonl(qa_path))
-     if not (0 <= index < len(qa_list)):
-         raise IndexError(f"index={index} out of range (0..{len(qa_list)-1}) for split={split}")
-     table_id = qa_list[index].get("table_id") or qa_list[index].get("table", {}).get("id")
-     if table_id is None:
-         raise RuntimeError("Could not extract 'table_id' from the QA record.")
-
-     # Find that table in dev.tables.jsonl
-     header, rows = None, None
-     for obj in _read_jsonl(tbl_path):
-         if obj.get("id") == table_id:
-             header = [str(h) for h in obj["header"]]
-             rows = obj["rows"]
-             break
-     if header is None or rows is None:
-         raise RuntimeError(f"Table with id={table_id} not found in {os.path.basename(tbl_path)}")
-
-     # trim rows
-     rows = rows[:max_rows]
-     df = pd.DataFrame(rows, columns=header)
-     df.columns = [str(c) for c in df.columns]
-     return df
+     base_dir = Path(__file__).parent
+     data_dir = base_dir / "data"
+
+     # Normalize the local file name (the demo uses 'dev')
+     local_name = "dev" if split.lower() in ("validation", "dev") else split.lower()
+
+     # 1) Look for a local file
+     for candidate in (data_dir / f"{local_name}.json", data_dir / f"{local_name}.jsonl"):
+         if candidate.exists():
+             js = _read_json_or_jsonl(candidate)
+             header = [str(h) for h in js["header"]]
+             rows = js["rows"][:max_rows]
+             df = pd.DataFrame(rows, columns=header)
+             df.columns = [str(c) for c in df.columns]
+             return df
+
+     # 2) Fallback: load an example from the WikiSQL dataset (converted Parquet)
+     try:
+         from datasets import load_dataset  # deferred import for faster startup
+         ds = load_dataset("Salesforce/wikisql", split="validation", revision="refs/convert/parquet")
+         if not (0 <= index < len(ds)):
+             index = 0  # safety fallback
+         ex = ds[index]
+         header = [str(h) for h in ex["table"]["header"]]
+         rows = ex["table"]["rows"][:max_rows]
+         df = pd.DataFrame(rows, columns=header)
+         df.columns = [str(c) for c in df.columns]
+         return df
+     except Exception as e:
+         raise RuntimeError(f"Could not obtain a table: {e}")

  # ------------------------
  # Endpoints
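Note: the local-file branch of the new get_table() expects ./data/dev.json (or .jsonl) to hold a single object with "header" and "rows" keys, mirroring a WikiSQL table record. A minimal sketch of creating such a file and exercising the function; the sample contents are invented and it assumes the script runs from the app directory:

import json
from pathlib import Path

sample = {
    "header": ["Player", "Team", "Points"],
    "rows": [["Ana", "Lions", "31"], ["Luis", "Hawks", "27"]],
}
Path("data").mkdir(exist_ok=True)
Path("data/dev.json").write_text(json.dumps(sample), encoding="utf-8")

# With this file in place, get_table("validation", 0, 12) returns a 2x3
# DataFrame and never reaches the datasets fallback.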
@@ -122,7 +98,7 @@ def health():
  @app.get("/api/preview")
  def preview():
      try:
-         df = get_table_from_wikisql(SPLIT, INDEX, MAX_ROWS)
+         df = get_table(SPLIT, INDEX, MAX_ROWS)
          return {"columns": df.columns.tolist(), "rows": df.head(8).to_dict(orient="records")}
      except Exception as e:
          return {"error": str(e)}
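A quick smoke test for the /api/preview endpoint, sketched against a locally running instance; the host and port (127.0.0.1:7860, the usual Spaces default) are assumptions, not part of this commit:

import requests

resp = requests.get("http://127.0.0.1:7860/api/preview", timeout=30)
data = resp.json()
print(data.get("columns"))  # header of the table served by get_table()
print(data.get("rows"))     # at most 8 preview rows, as returned by preview()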
@@ -134,14 +110,15 @@ def nl2sql(q: NLQuery):
      if not text:
          raise ValueError("Empty query.")

-     # ES->EN translation if we detect accents or other non-ASCII characters
+     # ES->EN translation if we detect non-ASCII characters
      is_ascii = all(ord(c) < 128 for c in text)
      query_en = text if is_ascii else GoogleTranslator(source="auto", target="en").translate(text)

-     df = get_table_from_wikisql(SPLIT, INDEX, MAX_ROWS)
+     df = get_table(SPLIT, INDEX, MAX_ROWS)
      enc = tok(table=df, query=query_en, return_tensors="pt", truncation=True)
      if torch.cuda.is_available():
          enc = {k: v.to("cuda") for k, v in enc.items()}
+
      out = model.generate(**enc, max_length=160, num_beams=1)
      sql = tok.batch_decode(out, skip_special_tokens=True)[0]
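For reference, the TAPEX round-trip that nl2sql() performs can be reproduced outside FastAPI. A sketch using the model ID from the Config block above; the example table and question are invented:

import pandas as pd
from transformers import TapexTokenizer, BartForConditionalGeneration

tok = TapexTokenizer.from_pretrained("stvnnnnnn/tapex-wikisql-best")
model = BartForConditionalGeneration.from_pretrained("stvnnnnnn/tapex-wikisql-best")

df = pd.DataFrame([["Ana", "Lions", "31"]], columns=["Player", "Team", "Points"])
enc = tok(table=df, query="which team does Ana play for?", return_tensors="pt", truncation=True)
out = model.generate(**enc, max_length=160, num_beams=1)
print(tok.batch_decode(out, skip_special_tokens=True)[0])  # decoded prediction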