# nlp2sql-space / app.py
# Uploaded by Dub973 — commit 03f92a0 (verified)
import os, time, json, sqlite3, textwrap, requests, sys
import gradio as gr
# ----------------- CONFIG -----------------
# Model served through the HF Inference API; gpt2 is public so the demo works
# without a token.
MODEL_ID = "gpt2" # always public; swap later for sqlcoder
# Hosted inference endpoint for MODEL_ID.
API_URL = f"https://api-inference.huggingface.co/models/{MODEL_ID}"
# Optional API token read from the environment (set as a Space secret).
HF_TOKEN = os.getenv("HF_TOKEN")
# Auth header only when a token is configured; otherwise call anonymously.
HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
# Local SQLite database file, created on first run from SCHEMA_FILE.
DB_PATH = "company.db"
# DDL script used to bootstrap DB_PATH.
SCHEMA_FILE = "schema.sql"
# -------------- UTIL: DB ------------------
def create_db_if_needed():
    """Bootstrap the SQLite database from SCHEMA_FILE on first run.

    No-op when DB_PATH already exists, so repeated calls are cheap.
    """
    if not os.path.exists(DB_PATH):
        with open(SCHEMA_FILE) as ddl, sqlite3.connect(DB_PATH) as conn:
            conn.executescript(ddl.read())
# -------------- UTIL: CALL API ------------
def nlp_to_sql(question, schema_ddl):
    """Translate a natural-language *question* into SQL via the HF Inference API.

    Parameters:
        question:   user's natural-language query.
        schema_ddl: the database DDL, embedded in the prompt for context.

    Returns:
        The text after the last "### SQL" marker of the generation, or a
        bracketed error string ("[ConnErr]", "[API <code>]", "[APIErr]",
        "[JSONErr]", "[Empty SQL]") — callers must tolerate both.
    """
    prompt = textwrap.dedent(f"""
Translate the natural language question to a SQL query.
### Schema
{schema_ddl}
### Question
{question}
### SQL
""")
    payload = {"inputs": prompt, "parameters": {"max_new_tokens": 64}}
    # ---------- DEBUG PRINTS ----------
    print("=" * 60, file=sys.stderr)
    print("DEBUG URL:", API_URL, file=sys.stderr)
    print("DEBUG Token present?:", bool(HF_TOKEN), file=sys.stderr)
    # ----------------------------------
    try:
        r = requests.post(API_URL, headers=HEADERS, json=payload, timeout=60)
    except Exception as e:
        # Network-level failure (DNS, timeout, refused connection, ...).
        return f"[ConnErr] {e}"
    # ---------- MORE DEBUG ----------
    print("DEBUG Status code:", r.status_code, file=sys.stderr)
    print("DEBUG Raw response (first 500 bytes):", r.text[:500], file=sys.stderr)
    print("=" * 60, file=sys.stderr)
    # ---------------------------------
    if r.status_code != 200:
        return f"[API {r.status_code}] {r.text[:200]}"
    try:
        out = r.json()
        # BUG FIX: the API can return a dict payload (e.g. {"error": ...})
        # even with a 200 status; `out[0]` on a dict previously surfaced as
        # the misleading "[JSONErr] 0". Report it as an API error instead.
        if isinstance(out, dict):
            return f"[APIErr] {out.get('error', out)}"
        generated = out[0].get("generated_text", "No generated_text")
    except Exception as e:
        return f"[JSONErr] {e}"
    # The model echoes the prompt; keep only what follows the final SQL marker.
    return generated.split("### SQL")[-1].strip() or "[Empty SQL]"
# -------------- PIPELINE ------------------
def run(query):
    """Full pipeline: NL question -> SQL (LLM) -> execute -> display strings.

    Parameters:
        query: natural-language question from the UI textbox.

    Returns:
        (sql, result_json, trace_text) — the generated SQL (or an error
        marker string), the query result / error as pretty-printed JSON,
        and a newline-joined step trace.
    """
    t0, trace = time.time(), []
    create_db_if_needed()
    with open(SCHEMA_FILE) as f:
        schema = f.read()
    trace.append(("Schema", "loaded"))
    sql = nlp_to_sql(query, schema)
    trace.append(("LLM", sql))
    # SECURITY NOTE: this executes model-generated SQL verbatim. Acceptable
    # only against this throwaway local SQLite demo DB — never point it at
    # real data without validation/allow-listing.
    try:
        # BUG FIX: `with sqlite3.connect(...)` commits on exit but never
        # closes the connection, leaking a handle per request. Close it
        # explicitly; close without commit rolls back on the error path.
        conn = sqlite3.connect(DB_PATH)
        try:
            cur = conn.execute(sql)
            rows = cur.fetchall()
            cols = [d[0] for d in cur.description] if cur.description else []
            result = {"columns": cols, "rows": rows}
            trace.append(("Exec", f"{len(rows)} rows"))
            conn.commit()  # preserve the context-manager commit semantics
        finally:
            conn.close()
    except Exception as e:
        # Covers both bad SQL and the bracketed error strings from nlp_to_sql.
        result = {"error": str(e)}
        trace.append(("Exec error", str(e)))
    trace.append(("Time", f"{time.time()-t0:.2f}s"))
    return sql, json.dumps(result, indent=2), "\n".join(f"{s}: {m}" for s, m in trace)
# -------------- UI ------------------------
# -------------- UI ------------------------
# Single-column layout: question in, three read-only panes out.
with gr.Blocks(title="Debug NLP→SQL") as demo:
    gr.Markdown("### Debugging Hugging Face Inference API calls")
    question_in = gr.Textbox(label="Ask", placeholder="Example: List employees")
    with gr.Row():
        sql_out = gr.Code(label="Generated SQL / debug output")
        result_out = gr.Code(label="Query result")
    trace_out = gr.Textbox(label="Trace")
    run_button = gr.Button("Run")
    # Wire the button to the pipeline: one input, three outputs.
    run_button.click(run, question_in, [sql_out, result_out, trace_out])

if __name__ == "__main__":
    demo.launch()