Upload 4 files
Browse files
README.md
CHANGED
|
@@ -1,3 +1,13 @@
|
|
| 1 |
-
---
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Text2sql Demo
|
| 3 |
+
emoji: 📊
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.8.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
python_version: 3.10.13
|
| 12 |
+
short_description: 'Text to SQL with RLHF'
|
| 13 |
+
---
|
app.py
ADDED
|
@@ -0,0 +1,569 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GRADIO DEMO UI - LAZY LOADING EDITION
|
| 3 |
+
NL → SQL → Result Table
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import re
|
| 9 |
+
import time
|
| 10 |
+
import os
|
| 11 |
+
import torch
|
| 12 |
+
import sys
|
| 13 |
+
import json
|
| 14 |
+
import subprocess
|
| 15 |
+
import base64
|
| 16 |
+
import io
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Iterator
|
| 19 |
+
|
| 20 |
+
# ==========================================
|
| 21 |
+
# RELATIVE PATH RESOLUTION (GLOBAL)
|
| 22 |
+
# ==========================================
|
| 23 |
+
# Resolve the project root robustly: __file__ is undefined when the module is
# exec'd interactively (e.g. pasted into a REPL), so fall back to the CWD.
try:
    PROJECT_ROOT = Path(__file__).resolve().parent
except NameError:
    PROJECT_ROOT = Path(".").resolve()

# Prefer the full repo layout (data/database); fall back to the flat
# "final_databases" directory shipped with the Space deployment.
if (PROJECT_ROOT / "data" / "database").exists():
    DB_ROOT = PROJECT_ROOT / "data" / "database"
else:
    DB_ROOT = PROJECT_ROOT / "final_databases"
|
| 32 |
+
|
| 33 |
+
def get_db_path(db_id: str) -> str:
    """Resolve the SQLite file backing *db_id*.

    Prefers the nested Spider-style layout ``DB_ROOT/<db_id>/<db_id>.sqlite``;
    falls back to the flat ``DB_ROOT/<db_id>.sqlite`` when it is absent.
    """
    nested = DB_ROOT / db_id / f"{db_id}.sqlite"
    if nested.exists():
        return str(nested)
    return str(DB_ROOT / f"{db_id}.sqlite")
|
| 37 |
+
|
| 38 |
+
# ==========================================
|
| 39 |
+
# 🔥 CUDA MOCK PATCH FOR MAC (MPS) / CPU
|
| 40 |
+
# ==========================================
|
| 41 |
+
# On machines without CUDA (Mac/MPS or plain CPU) the engine's timing code
# that uses torch.cuda.Event would fail, so substitute a wall-clock shim.
if not torch.cuda.is_available():
    class MockCUDAEvent:
        """Drop-in stand-in for torch.cuda.Event backed by time.perf_counter."""
        def __init__(self, enable_timing=False, blocking=False, interprocess=False):
            # Constructor args mirror torch.cuda.Event's signature but are ignored.
            self.t = 0.0
        def record(self, stream=None):
            # Capture the current wall-clock instant (stream arg ignored on CPU).
            self.t = time.perf_counter()
        def elapsed_time(self, end_event):
            # Match the CUDA API: elapsed time is reported in milliseconds.
            return (end_event.t - self.t) * 1000.0

    torch.cuda.Event = MockCUDAEvent
    # NOTE(review): torch.cuda normally always exposes synchronize, so this
    # guard rarely fires; kept as-is to avoid changing behavior.
    if not hasattr(torch.cuda, 'synchronize'):
        torch.cuda.synchronize = lambda: None
|
| 53 |
+
|
| 54 |
+
# ==========================================
|
| 55 |
+
# IMPORTS & ENGINE SETUP
|
| 56 |
+
# ==========================================
|
| 57 |
+
from src.quantized_text2sql_engine import QuantizedText2SQLEngine
|
| 58 |
+
from src.schema_encoder import SchemaEncoder
|
| 59 |
+
|
| 60 |
+
DEFAULT_QUANT_ARTIFACT = str(PROJECT_ROOT / "int8_dynamic")
|
| 61 |
+
|
| 62 |
+
_ENGINE_CACHE = {}
|
| 63 |
+
_QUERY_LOG = []
|
| 64 |
+
_PERF_LOG = []
|
| 65 |
+
_SUCCESS_LOG = []
|
| 66 |
+
|
| 67 |
+
_OP_STATS = {
|
| 68 |
+
"SELECT": {"ok": 0, "fail": 0}, "WHERE": {"ok": 0, "fail": 0}, "JOIN": {"ok": 0, "fail": 0},
|
| 69 |
+
"GROUP_BY": {"ok": 0, "fail": 0}, "ORDER_BY": {"ok": 0, "fail": 0}, "HAVING": {"ok": 0, "fail": 0}, "LIMIT": {"ok": 0, "fail": 0},
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
def get_quant_engine(artifact_dir: str, use_constrained: bool = False, exec_workers: int = 8, use_cache: bool = True):
    """Return a memoized QuantizedText2SQLEngine for this configuration.

    Engines are cached per (artifact_dir, use_constrained, exec_workers,
    use_cache) tuple. Older engine builds whose constructor rejects the
    keyword arguments are instantiated with the artifact dir alone.
    """
    cache_key = (artifact_dir, bool(use_constrained), int(exec_workers), bool(use_cache))
    if cache_key in _ENGINE_CACHE:
        return _ENGINE_CACHE[cache_key]
    try:
        engine = QuantizedText2SQLEngine(
            artifact_dir,
            device="cpu",
            use_constrained=bool(use_constrained),
            exec_workers=int(exec_workers),
            use_cache=bool(use_cache),
        )
    except TypeError:
        # Backward compatibility: legacy constructor takes only the artifact dir.
        engine = QuantizedText2SQLEngine(artifact_dir)
    _ENGINE_CACHE[cache_key] = engine
    return engine
|
| 80 |
+
|
| 81 |
+
# 🚨 LAZY LOADING: We DO NOT load the model here! We only load the fast Schema Encoder.
|
| 82 |
+
quant_engine = None
|
| 83 |
+
try:
|
| 84 |
+
schema_encoder = SchemaEncoder(DB_ROOT)
|
| 85 |
+
except Exception as e:
|
| 86 |
+
print(f"Warning: SchemaEncoder failed to load: {e}")
|
| 87 |
+
schema_encoder = None
|
| 88 |
+
|
| 89 |
+
SAMPLES = [
|
| 90 |
+
("Show 10 distinct employee first names.", "chinook_1"), ("Which artist has the most albums?", "chinook_1"),
|
| 91 |
+
("List all the tracks that belong to the 'Rock' genre.", "chinook_1"), ("What are the names of all the cities?", "flight_1"),
|
| 92 |
+
("Find the flight number and cost of the cheapest flight.", "flight_1"), ("List the airlines that fly out of New York.", "flight_1"),
|
| 93 |
+
("Which campus was opened between 1935 and 1939?", "csu_1"), ("Count the number of students in each department.", "college_2"),
|
| 94 |
+
("List the names of all clubs.", "club_1"), ("How many members does each club have?", "club_1"),
|
| 95 |
+
("Show the names of all cinemas.", "cinema"), ("Which cinema has the most screens?", "cinema")
|
| 96 |
+
]
|
| 97 |
+
SAMPLE_QUESTIONS = [q[0] for q in SAMPLES]
|
| 98 |
+
|
| 99 |
+
def explain_sql(sql):
    """Return a short plain-English summary of the clauses present in *sql*."""
    if not sql:
        return ""
    lowered = sql.lower()
    # Keyword -> sentence appended whenever the clause appears in the query.
    clause_notes = (
        ("join", "\n• It combines data from multiple tables using JOIN."),
        ("where", "\n• It filters rows using a WHERE condition."),
        ("group by", "\n• It groups results using GROUP BY."),
        ("order by", "\n• It sorts the results using ORDER BY."),
        ("limit", "\n• It limits the number of returned rows."),
    )
    parts = ["This SQL query retrieves information from the database."]
    parts.extend(note for keyword, note in clause_notes if keyword in lowered)
    return "".join(parts)
|
| 109 |
+
|
| 110 |
+
def sql_ops(sql: str) -> list[str]:
    """List the SQL operations used by *sql*; SELECT is always included.

    Matching is done on space-padded lowercase text so keywords are only
    detected as whole words.
    """
    padded = f" {(sql or '').lower()} "
    detected = ["SELECT"]
    for needle, label in (
        (" where ", "WHERE"),
        (" join ", "JOIN"),
        (" group by ", "GROUP_BY"),
        (" order by ", "ORDER_BY"),
        (" having ", "HAVING"),
        (" limit ", "LIMIT"),
    ):
        if needle in padded:
            detected.append(label)
    return detected
|
| 120 |
+
|
| 121 |
+
def classify_error(sql: str, error_msg: str | None = None, *, timed_out: bool = False):
|
| 122 |
+
s = (sql or "").lower()
|
| 123 |
+
m = (error_msg or "").lower()
|
| 124 |
+
if timed_out or "interrupted" in m or "timeout" in m: return "timeout"
|
| 125 |
+
if not s.strip().startswith(("select", "with")): return "syntax_error"
|
| 126 |
+
if " join " in f" {s} " and " on " not in f" {s} ": return "missing_join"
|
| 127 |
+
if " where " in f" {s} " and not any(op in s for op in ["=", ">", "<", " in ", " like ", " between ", " is null", " is not null"]): return "wrong_where"
|
| 128 |
+
if ("is null" in s or "is not null" in s) and ("no such column" in m or "misuse" in m): return "null_handling"
|
| 129 |
+
if "no such table" in m: return "missing_table"
|
| 130 |
+
if "no such column" in m: return "missing_column"
|
| 131 |
+
if "ambiguous column name" in m: return "ambiguous_column"
|
| 132 |
+
if "datatype mismatch" in m or "type mismatch" in m: return "type_mismatch"
|
| 133 |
+
if "misuse of aggregate" in m or "misuse of aggregate function" in m: return "wrong_aggregation"
|
| 134 |
+
if "syntax error" in m: return "syntax_error"
|
| 135 |
+
if "near" in m and "syntax error" in m: return "syntax_error"
|
| 136 |
+
if "runtime" in m or "constraint failed" in m: return "runtime_error"
|
| 137 |
+
return "other"
|
| 138 |
+
|
| 139 |
+
def get_hint(error_type):
    """Return a one-line remediation hint for a classify_error() category.

    Unknown categories fall back to the generic "Review query." hint.
    """
    hints = {
        "missing_join": "Check JOIN conditions between tables.",
        "wrong_aggregation": "Use proper aggregation like avg(column).",
        "wrong_where": "Check WHERE condition syntax.",
        "syntax_error": "Ensure SQL starts with SELECT.",
        "missing_table": "Use only tables from the provided schema.",
        "missing_column": "Use only columns from the provided schema.",
        "ambiguous_column": "Disambiguate by using table.column.",
        "timeout": "Query took too long; simplify joins.",
        # Fix: classify_error() can return these categories, but they
        # previously had no hint and fell through to the generic fallback.
        "null_handling": "Check IS NULL / IS NOT NULL usage against the schema.",
        "type_mismatch": "Compare columns against values of the matching type.",
        "runtime_error": "Check constraints and the data values used by the query.",
        "other": "Review SQL logic.",
    }
    return hints.get(error_type, "Review query.")
|
| 147 |
+
|
| 148 |
+
def is_relevant_to_schema(question, db_id):
    """Heuristically decide whether *question* mentions anything in db_id's schema.

    Tokenizes both the schema text and the question, drops common stop words
    and pure numbers, and accepts the question if any remaining word (or its
    naive singular, trailing-'s' stripped) appears in the schema vocabulary.
    Fails open (returns True) when the schema encoder is unavailable, the
    schema cannot be read, or no meaningful words remain.
    """
    if schema_encoder is None:
        return True
    try:
        raw_schema = schema_encoder.structured_schema(db_id).lower()
    except Exception:
        # Fix: was a bare `except:` which also swallowed SystemExit and
        # KeyboardInterrupt; keep the fail-open behavior for real errors only.
        return True
    schema_words = set(re.findall(r'[a-z0-9_]+', raw_schema))
    q_words = re.findall(r'[a-z0-9_]+', question.lower())
    stop_words = {"show", "list", "all", "what", "is", "the", "how", "many", "count", "find", "get", "me", "a", "an", "of", "in", "for", "from", "with", "which", "are", "there", "give", "tell", "details", "info", "data", "everything"}
    meaningful_q_words = [w for w in q_words if w not in stop_words and not w.isdigit()]
    if not meaningful_q_words:
        return True
    for word in meaningful_q_words:
        # Naive singularization so "albums" still matches an "album" column.
        singular_word = word[:-1] if word.endswith('s') else word
        if word in schema_words or singular_word in schema_words:
            return True
    return False
|
| 161 |
+
|
| 162 |
+
def run_query(method, sample_q, custom_q, db_id):
    """End-to-end handler for the "Run" button: NL question -> SQL -> result table.

    Returns a 3-tuple consumed by the Gradio outputs:
    (sql_text, pandas DataFrame of rows, explanation/diagnostics string).
    Every early-exit path returns the same 3-tuple shape.
    """
    global quant_engine

    # 🚨 LAZY LOADING: the heavy AI model is loaded only on the first request,
    # keeping Space startup fast.
    if quant_engine is None:
        print(f"First request detected! Loading AI model from {DEFAULT_QUANT_ARTIFACT}...", flush=True)
        try:
            quant_engine = get_quant_engine(DEFAULT_QUANT_ARTIFACT, use_constrained=False, exec_workers=8, use_cache=True)
            if quant_engine is None:
                return "-- ❌ ENGINE CRASH", pd.DataFrame(columns=["Error"]), "Failed to load model. Did you move the tokenizer files and add config.json to int8_dynamic/?"
        except Exception as e:
            return f"-- ❌ ENGINE CRASH\n-- {str(e)}", pd.DataFrame(columns=["Error Status"]), f"Critical failure loading model: {e}"

    # Local logging helpers (closures over the module-level ring buffers).
    def _log(error_type: str, *, question: str, db_id_val: str, sql: str = "", error_msg: str = "") -> None:
        # Append one structured record to the query/error log.
        _QUERY_LOG.append({"t": time.time(), "db_id": str(db_id_val), "question": str(question), "sql": str(sql), "error_type": str(error_type), "error_msg": str(error_msg)})

    def _perf_log(payload: dict) -> None:
        # Append a perf record, trimming the oldest 200 once the log exceeds 1000.
        _PERF_LOG.append(payload)
        if len(_PERF_LOG) > 1000: del _PERF_LOG[:200]

    # The question comes either from the sample dropdown or the free-text box.
    raw_question = sample_q if method == "💡 Pick a Sample" else custom_q

    # --- Input guards -----------------------------------------------------
    if not raw_question or str(raw_question).strip() == "":
        return "-- No input provided", pd.DataFrame(columns=["Warning"]), "⚠️ Please enter a question."
    if not db_id or str(db_id).strip() == "":
        return "-- No database selected", pd.DataFrame(columns=["Warning"]), "⚠️ Please select a database."

    # Normalize a handful of known typos before any further checks.
    typo_corrections = [(r'\bshaw\b', 'show'), (r'\bshw\b', 'show'), (r'\bsho\b', 'show'), (r'\blsit\b', 'list'), (r'\blis\b', 'list'), (r'\bfidn\b', 'find'), (r'\bfnd\b', 'find'), (r'\bgte\b', 'get')]
    question = str(raw_question)
    for bad, good in typo_corrections: question = re.sub(bad, good, question, flags=re.IGNORECASE)
    q_lower = question.strip().lower()

    # Reject one-word / gibberish input.
    if len(q_lower.split()) < 2:
        _log("gibberish", question=question, db_id_val=str(db_id), error_msg="gibberish filtered")
        return "-- Input Blocked", pd.DataFrame(columns=["Warning"]), "⚠️ Please enter a clear, meaningful natural language question (more than one word)."

    # Block anything that reads as a data-modification request.
    if re.search(r'\b(delete|update|insert|drop|alter|truncate)\b', q_lower):
        _log("blocked_dml", question=question, db_id_val=str(db_id), error_msg="DML blocked")
        return "-- ❌ BLOCKED: Data Modification", pd.DataFrame(columns=["Security Alert"]), "🛑 Security Alert: Modifying or deleting data is strictly prohibited."

    # Block questions with no lexical overlap with the selected schema.
    if not is_relevant_to_schema(question, db_id):
        _log("out_of_domain", question=question, db_id_val=str(db_id), error_msg="out of domain")
        return "-- ❌ BLOCKED: Out of Domain", pd.DataFrame(columns=["Domain Alert"]), f"🛑 Relevance Alert: I don't see anything related to your question in the '{db_id}' schema."

    start_time = time.time()
    t0 = time.perf_counter()
    ui_warnings = ""

    # --- Model inference --------------------------------------------------
    # Try the rich signature first; fall back for engines that only accept
    # (question, db_id).
    try:
        try:
            result = quant_engine.ask(question, str(db_id), num_beams=4, max_new_tokens=120, timeout_s=2.0)
        except TypeError:
            result = quant_engine.ask(question, str(db_id))
    except Exception as e:
        _log("backend_crash", question=question, db_id_val=str(db_id), error_msg=str(e))
        return f"-- ❌ BACKEND CRASH\n-- {str(e)}", pd.DataFrame(columns=["Error Status"]), f"❌ CRITICAL BACKEND CRASH:\n{str(e)}"

    final_sql = str(result.get("sql", ""))
    # Keep the untouched model output so we can fall back to it if the
    # post-processed SQL fails to execute.
    model_sql = final_sql

    # --- Heuristic LIMIT post-processing ----------------------------------
    # Detect a row-count request like "show 10 ..." in the question.
    num_match = re.search(r'\b(?:show|list|top|limit|get|first|last|sample|of)\s+(?:[a-zA-Z_]+\s+)?(\d+)\b', q_lower)
    if not num_match and q_lower.startswith(("show", "list", "get")):
        num_match = re.search(r'\b(\d+)\b', q_lower)

    if num_match and final_sql:
        limit_val = num_match.group(1)
        # The model often misreads "show N" as a filter (count(*) = N or
        # col = N); strip those spurious predicates.
        final_sql = re.sub(rf"(?i)\s*(?:where|having|and)?\s*count\s*\(\s*\*\s*\)\s*=\s*{limit_val}", "", final_sql)
        final_sql = re.sub(rf"(?i)\s*(?:where|and)\s+[a-zA-Z0-9_.]+\s*=\s*['\"]?{limit_val}['\"]?", "", final_sql)
        # Clean up any dangling WHERE left by the removals above.
        final_sql = re.sub(r"(?i)\s*where\s*$", "", final_sql)
        final_sql = re.sub(r"(?i)\s*where\s+(group by|order by|limit)", r" \1", final_sql)

        # Drop spurious aggregation when the question has no aggregate intent.
        agg_kws = ["most", "top", "highest", "lowest", "count", "many", "group", "frequent", "popular"]
        if not any(k in q_lower for k in agg_kws):
            final_sql = re.sub(r"(?i)\s*group by\s+[a-zA-Z0-9_.]+\s*order by\s+count\(\*\)\s*(?:desc|asc)?", "", final_sql)
            final_sql = re.sub(r"(?i)\s*order by\s+count\(\*\)\s*(?:desc|asc)?", "", final_sql)
            final_sql = re.sub(r"(?i),\s*count\(\*\)", "", final_sql)
            final_sql = re.sub(r"(?i)count\(\*\)\s*,", "", final_sql)

        # GROUP BY without any aggregate function is meaningless; remove it.
        if "group by" in final_sql.lower() and not re.search(r'(?i)\b(count|sum|avg|max|min)\b\(', final_sql):
            final_sql = re.sub(r"(?i)\s*group by\s+[a-zA-Z0-9_.]+", "", final_sql)

        # Finally, enforce the requested row count as a LIMIT.
        if "limit" not in final_sql.lower():
            final_sql = f"{final_sql.strip().rstrip(';')} LIMIT {limit_val}"

    # --- Execution --------------------------------------------------------
    from src.sql_validator import validate_sql_schema
    db_path = get_db_path(str(db_id))

    # Strict schema validation is advisory (reported in perf), not blocking.
    try: strict_valid, _ = validate_sql_schema(final_sql, db_path)
    except Exception: strict_valid = False

    error_msg = None
    rows, cols = [], []
    sqlite_success = False

    try:
        rows, cols = quant_engine._execute_one(final_sql, db_path, timeout_s=2.0)
        sqlite_success = True
    except Exception as e:
        error_msg = str(e)
        sqlite_success = False

    # If the post-processed SQL failed, retry with the raw model output.
    if not sqlite_success and model_sql and model_sql != final_sql:
        try:
            alt_rows, alt_cols = quant_engine._execute_one(model_sql, db_path, timeout_s=2.0)
            final_sql = model_sql
            rows, cols = alt_rows, alt_cols
            sqlite_success = True
            error_msg = None
        except Exception: pass

    valid = sqlite_success

    # Log the failure classification as soon as it is known.
    if error_msg or not valid:
        et = classify_error(final_sql, str(error_msg or ""), timed_out=("interrupted" in str(error_msg or "").lower()))
        _log(et, question=str(question), db_id_val=str(db_id), sql=str(final_sql), error_msg=str(error_msg or "Execution failed"))

    # --- Performance bookkeeping ------------------------------------------
    latency = round(time.time() - start_time, 3)
    t1 = time.perf_counter()

    engine_stats_after = quant_engine.stats() if hasattr(quant_engine, 'stats') else {}

    perf = {
        "db_id": str(db_id), "use_constrained_decoding": False, "num_beams": 4,
        "latency_total_ms": round((t1 - t0) * 1000.0, 2), "constraint_ok": bool(strict_valid), "has_error": bool(error_msg),
        "exec_cache_hit_rate": float(engine_stats_after.get("exec_cache_hit_rate", 0.0) or 0.0),
    }
    _perf_log(perf)

    # Rolling stats over the last 50 requests for the UI perf footer.
    window = _PERF_LOG[-50:]
    avg_ms = sum(float(x.get("latency_total_ms", 0.0) or 0.0) for x in window) / len(window) if window else 0.0
    constraint_rate = sum(1 for x in window if x.get("constraint_ok")) / len(window) if window else 0.0

    perf_block = (
        "\n\n---\nPerformance (task impact)\n"
        f"- Total latency (ms): {perf['latency_total_ms']}\n"
        f"- Strict Python Validator OK (Task 3): {perf['constraint_ok']}\n"
        f"- Exec cache hit-rate (Task 1/5): {round(perf['exec_cache_hit_rate'], 3)}\n"
        f"- Rolling avg latency last 50 (ms): {round(avg_ms, 2)}\n"
        f"- Rolling constraint rate last 50: {round(constraint_rate, 3)}\n"
    )

    # --- Failure response --------------------------------------------------
    if error_msg or not valid:
        display_sql = final_sql if final_sql.strip() else "-- ❌ INVALID SQL"
        explanation = f"{ui_warnings}❌ Error Details:\n\n"
        if error_msg: explanation += f"{error_msg}\n\n"

        error_type = classify_error(final_sql, str(error_msg or ""))
        explanation += f"Error Type: {error_type}\nHint: {get_hint(error_type)}"
        explanation += perf_block
        # Count every operation in the failed SQL toward the fail tally.
        ops = sql_ops(final_sql)
        for op in ops:
            if op in _OP_STATS: _OP_STATS[op]["fail"] += 1
        return display_sql, pd.DataFrame(columns=["Execution Notice"]), explanation

    # --- Success response --------------------------------------------------
    safe_cols = cols if cols else ["Result"]
    explanation = f"{ui_warnings}✅ Query executed successfully\n\nRows returned: {len(rows)}\nExecution Time: {latency} sec\n\n{explain_sql(final_sql)}{perf_block}"

    ops = sql_ops(final_sql)
    for op in ops:
        if op in _OP_STATS: _OP_STATS[op]["ok"] += 1
    _SUCCESS_LOG.append({"t": time.time(), "db_id": str(db_id), "question": question, "sql": final_sql, "ops": ops})

    # Note when the LIMIT ceiling was not reached by the actual result set.
    limit_match = re.search(r'LIMIT\s+(\d+)', final_sql, re.IGNORECASE)
    if limit_match and len(rows) < int(limit_match.group(1)):
        explanation += f"\n\nℹ️ Query allowed up to {int(limit_match.group(1))} rows but only {len(rows)} matched."

    return final_sql, pd.DataFrame(rows, columns=safe_cols), explanation
|
| 330 |
+
|
| 331 |
+
def task1_benchmark(n_rollouts: int, max_workers: int) -> Iterator[tuple[str, str]]:
    """Run the Task-1 benchmark script as a subprocess, streaming its output.

    Yields (log_text, html) pairs for the two Gradio output widgets: the
    tail of the subprocess log, and either a "Running..." placeholder or the
    final plot rendered as an inline base64 <img>.
    """
    project_root = str(PROJECT_ROOT)
    env = os.environ.copy()
    # Prepend the project root so the script can import the src package.
    env["PYTHONPATH"] = project_root + (os.pathsep + env["PYTHONPATH"] if env.get("PYTHONPATH") else "")
    # Headless matplotlib with a writable config dir (Spaces have a read-only home).
    env.setdefault("MPLBACKEND", "Agg")
    env.setdefault("MPLCONFIGDIR", "/tmp/mplconfig")
    try: os.makedirs(env["MPLCONFIGDIR"], exist_ok=True)
    except Exception: pass

    # -u forces unbuffered output so the stream updates live.
    cmd = [sys.executable, "-u", "scripts/benchmark_parallel_reward.py", "--n", str(int(n_rollouts)), "--max-workers", str(int(max_workers)), "--skip-profile"]
    proc = subprocess.Popen(cmd, cwd=project_root, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
    last_yield = time.perf_counter()
    lines: list[str] = []
    yield "Running Task 1 benchmark...\n", "<i>Running...</i>"

    assert proc.stdout is not None
    for line in proc.stdout:
        lines.append(line)
        now = time.perf_counter()
        # Throttle UI updates to at most ~2 per second, showing the last 200 lines.
        if now - last_yield >= 0.5:
            last_yield = now
            yield "".join(lines[-200:]).strip(), "<i>Running...</i>"

    proc.wait()
    out = "".join(lines).strip()

    # If the script wrote its plot, embed it inline; otherwise report the path
    # or the absence of a plot.
    plot_path = str(PROJECT_ROOT / "results" / "task1_plot.png")
    if os.path.exists(plot_path):
        try:
            b64 = base64.b64encode(Path(plot_path).read_bytes()).decode("ascii")
            yield out, f"<img src='data:image/png;base64,{b64}' style='max-width: 100%; border: 1px solid #e2e8f0; border-radius: 8px;' />"
            return
        except Exception:
            yield out, f"<pre>{plot_path}</pre>"
            return

    yield out, "<i>No plot generated</i>"
|
| 368 |
+
|
| 369 |
+
def task2_dashboard_structured():
    """Build the Task-2 error dashboard from the in-memory query log.

    Returns (error-count DataFrame, recent-queries DataFrame, dropdown update)
    for the three dashboard widgets. Empty placeholders are returned when
    nothing has been logged yet.
    """
    if not _QUERY_LOG:
        return (
            pd.DataFrame(columns=["error_type", "count", "hint"]),
            pd.DataFrame(columns=["time", "db_id", "error_type", "question", "error_msg"]),
            gr.update(choices=[], value=None),
        )

    # Tally error types over the most recent 1000 log entries.
    tallies = {}
    for entry in _QUERY_LOG[-1000:]:
        kind = entry.get("error_type") or "other"
        tallies[kind] = tallies.get(kind, 0) + 1
    ordered = sorted(tallies.items(), key=lambda kv: (-kv[1], kv[0]))
    summary_rows = [{"error_type": kind, "count": int(n), "hint": get_hint(kind)} for kind, n in ordered]

    # Tabulate the most recent 100 queries with a human-readable timestamp.
    recent_rows = []
    for entry in _QUERY_LOG[-100:]:
        stamp = entry.get("t")
        try:
            when = time.strftime("%H:%M:%S", time.localtime(float(stamp))) if stamp else ""
        except Exception:
            when = ""
        recent_rows.append({
            "time": when,
            "db_id": entry.get("db_id", ""),
            "error_type": entry.get("error_type", ""),
            "question": entry.get("question", ""),
            "error_msg": entry.get("error_msg", ""),
        })

    dropdown_choices = [str(row["error_type"]) for row in summary_rows]
    selected = dropdown_choices[0] if dropdown_choices else None
    return pd.DataFrame(summary_rows), pd.DataFrame(recent_rows), gr.update(choices=dropdown_choices, value=selected)
|
| 393 |
+
|
| 394 |
+
def task2_error_examples(error_type: str) -> str:
    """Format up to three recent logged examples of *error_type* as plain text."""
    if not error_type:
        return ""
    hint = get_hint(error_type)
    wanted = str(error_type)
    # Newest entries first; keep at most three matches.
    matches = [rec for rec in reversed(_QUERY_LOG) if (rec.get("error_type") or "") == wanted][:3]
    if not matches:
        return f"Error type: {error_type}\nHint: {hint}\n\nNo examples yet."
    sections = [f"Error type: {error_type}", f"Hint: {hint}", ""]
    for idx, rec in enumerate(matches, 1):
        sections.append(f"Example {idx}")
        sections.append(f"DB: {rec.get('db_id','')}")
        sections.append(f"Q: {rec.get('question','')}")
        sections.append(f"SQL: {rec.get('sql','')}")
        sections.append(f"Msg: {rec.get('error_msg','')}")
        sections.append("")
    return "\n".join(sections).strip()
|
| 403 |
+
|
| 404 |
+
def _plot_op_stats_html() -> str:
    """Render _OP_STATS as a stacked ok/fail bar chart inlined as an HTML <img>.

    Returns a <pre> block with the error text if plotting fails for any reason.
    """
    try:
        # Imported lazily so the app still starts if matplotlib is missing.
        import matplotlib.pyplot as plt
        labels = list(_OP_STATS.keys())
        oks = [int(_OP_STATS[k]["ok"]) for k in labels]
        fails = [int(_OP_STATS[k]["fail"]) for k in labels]

        fig, ax = plt.subplots(figsize=(9, 3.5))
        x = list(range(len(labels)))
        # Stacked bars: failures drawn on top of successes.
        ax.bar(x, oks, label="ok", color="#16a34a")
        ax.bar(x, fails, bottom=oks, label="fail", color="#dc2626")
        ax.set_xticks(x)
        ax.set_xticklabels(labels, rotation=30, ha="right")
        ax.set_title("Success/Failure by SQL operation")
        ax.legend()
        fig.tight_layout()

        # Serialize to an in-memory PNG and close the figure to avoid leaks.
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=160)
        plt.close(fig)
        b64 = base64.b64encode(buf.getvalue()).decode("ascii")
        return f"<img src='data:image/png;base64,{b64}' style='max-width: 100%; border: 1px solid #e2e8f0; border-radius: 8px;' />"
    except Exception as e: return f"<pre>Plot error: {e}</pre>"
|
| 427 |
+
|
| 428 |
+
def task2_ops_table():
    """Summarize per-operation success/failure counts.

    Returns a DataFrame with ok/fail/total/success_rate per SQL operation,
    plus the rendered bar chart HTML from _plot_op_stats_html().
    """
    summary = []
    for name, tally in _OP_STATS.items():
        ok_count = int(tally.get("ok", 0))
        fail_count = int(tally.get("fail", 0))
        total = ok_count + fail_count
        summary.append({
            "op": name,
            "ok": ok_count,
            "fail": fail_count,
            "total": total,
            # Guard against division by zero before any queries have run.
            "success_rate": ok_count / total if total else 0.0,
        })
    return pd.DataFrame(summary), _plot_op_stats_html()
|
| 436 |
+
|
| 437 |
+
def toggle_input_method(method, current_sample):
    """Switch the UI between sample-question mode and free-text mode.

    Returns visibility/interactivity updates for (sample dropdown, custom
    textbox, custom-mode extras, database dropdown).
    """
    if method != "💡 Pick a Sample":
        # Free-text mode: hide the sample picker, unlock the DB dropdown.
        return (gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(interactive=True))
    # Sample mode: lock the DB dropdown to the database backing the sample.
    matched_db = "chinook_1"
    for sample_question, sample_db in SAMPLES:
        if sample_question == current_sample:
            matched_db = sample_db
            break
    return (gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(value=matched_db, interactive=False))
|
| 442 |
+
|
| 443 |
+
def load_sample(selected_question):
    """Sync the database dropdown to the DB backing the chosen sample question."""
    if not selected_question:
        return gr.update()
    for sample_question, sample_db in SAMPLES:
        if sample_question == selected_question:
            return gr.update(value=sample_db)
    # Unknown question: fall back to the default database.
    return gr.update(value="chinook_1")
|
| 446 |
+
|
| 447 |
+
def clear_inputs():
    """Reset every input/output widget to its initial state."""
    return (
        gr.update(value="💡 Pick a Sample"),                       # input-method radio
        gr.update(value=SAMPLE_QUESTIONS[0], visible=True),        # sample dropdown
        gr.update(visible=False),                                  # custom textbox
        gr.update(value="", visible=False),                        # custom-mode extras
        gr.update(value="chinook_1", interactive=False),           # database dropdown
        "",                                                        # SQL output
        pd.DataFrame(),                                            # result table
        "",                                                        # explanation
    )
|
| 449 |
+
|
| 450 |
+
def update_schema(db_id):
    """Render the selected database's schema as a scrollable HTML panel.

    Lines shaped like ``table_name(col_a, col_b, ...)`` get a bold table
    name with muted columns; any other line is shown as plain text.
    Returns "" when no db is selected or the encoder is unavailable, and
    an inline error message if schema retrieval fails.
    """
    if not db_id or schema_encoder is None:
        return ""
    try:
        raw_schema = schema_encoder.structured_schema(db_id)
        # Accumulate fragments and join once instead of repeated +=.
        parts = ["<div style='max-height: 250px; overflow-y: auto; background: #f8fafc; padding: 12px; border-radius: 8px; border: 1px solid #e2e8f0; font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace; font-size: 0.9em; line-height: 1.6;'>"]
        for raw_line in raw_schema.strip().split('\n'):
            text = raw_line.strip()
            if not text:
                continue
            # "table_name(col_a, col_b)" -> emphasized table, muted columns.
            header = re.search(r'^([a-zA-Z0-9_]+)\s*\((.*)\)', text)
            if header:
                parts.append(f"<div style='margin-bottom: 8px;'><strong style='color: #0f172a; font-size: 1.05em; font-weight: 800;'>{header.group(1).upper()}</strong> <span style='color: #64748b;'>( {header.group(2).lower()} )</span></div>")
            else:
                parts.append(f"<div style='color: #475569;'>{text}</div>")
        parts.append("</div>")
        return "".join(parts)
    except Exception as e:
        return f"<div style='color: red;'>Error loading schema: {str(e)}</div>"
|
| 464 |
+
|
| 465 |
+
# =========================
# UI LAYOUT
# =========================
with gr.Blocks(title="Text-to-SQL RLHF") as demo:
    # Page banner.
    gr.HTML("""
    <div style="text-align: center; background-color: #e0e7ff; padding: 20px; border-radius: 10px; margin-bottom: 20px; border: 1px solid #c7d2fe;">
        <h1 style="color: #3730a3; margin-top: 0; margin-bottom: 10px; font-size: 2.2em;"> Text-to-SQL using RLHF + Execution Reward</h1>
        <p style="color: #4f46e5; font-size: 1.1em; margin: 0;">Convert Natural Language to SQL, strictly validated and safely executed on local SQLite databases.</p>
    </div>
    """)

    # Database ids selectable in the UI — presumably matching the bundled
    # SQLite corpus (db.zip); verify the list stays in sync with it.
    DBS = sorted(["flight_1", "student_assessment", "store_1", "bike_1", "book_2", "chinook_1", "academic", "aircraft", "car_1", "cinema", "club_1", "csu_1", "college_1", "college_2", "company_1", "company_employee", "customer_complaints", "department_store", "employee_hire_evaluation", "museum_visit", "products_for_hire", "restaurant_1", "school_finance", "shop_membership", "small_bank_1", "student_1", "tvshow", "voter_1", "world_1"])

    with gr.Tabs():
        # ----- Tab 1: interactive NL -> SQL inference -----
        with gr.Tab("Inference"):
            with gr.Row():
                # Left column: input configuration.
                with gr.Column(scale=1):
                    gr.Markdown("### 1. Configuration & Input")
                    input_method = gr.Radio(choices=["💡 Pick a Sample", "✍️ Type my own"], value="💡 Pick a Sample", label="How do you want to ask?")
                    # Visible in sample mode; toggled by toggle_input_method.
                    sample_dropdown = gr.Dropdown(choices=SAMPLE_QUESTIONS, value=SAMPLE_QUESTIONS[0], label="Select a Sample Question", info="The database will be selected automatically.", visible=True)
                    # Shown only in free-text mode.
                    type_own_warning = gr.Markdown("**⚠️ Please select a Database first, then type your custom question below:**", visible=False)
                    gr.Markdown("---")
                    # Locked (non-interactive) while in sample mode.
                    db_id = gr.Dropdown(choices=DBS, value="chinook_1", label="Select Database", interactive=False)
                    custom_question = gr.Textbox(label="Ask your Custom Question", placeholder="Type your own question here...", lines=3, visible=False)

                    gr.Markdown("#### 📋 Database Structure")
                    gr.HTML("<p style='font-size: 0.85em; color: #64748b; margin-top: -10px; margin-bottom: 5px;'>Use these exact names! Table names are <strong>Dark</strong>, Column names are <span style='color: #94a3b8;'>Light</span>.</p>")
                    # Pre-render the default database's schema on page load.
                    schema_display = gr.HTML(value=update_schema("chinook_1"))

                    with gr.Row():
                        clear_btn = gr.Button("🗑️ Clear", variant="secondary")
                        run_btn = gr.Button(" Generate & Run SQL", variant="primary")

                # Right column: generated SQL and execution output.
                with gr.Column(scale=2):
                    gr.Markdown("### 2. Execution Results")
                    final_sql = gr.Code(language="sql", label="Final Executed SQL")
                    result_table = gr.Dataframe(label="Query Result Table", interactive=False, wrap=True)
                    explanation = gr.Textbox(label="AI Explanation + Execution Details", lines=8)

        # ----- Tab 2: diagnostics / telemetry panels -----
        with gr.Tab("Diagnostics"):
            gr.Markdown("## Diagnostics & Telemetry")

            with gr.Accordion("Task 1: Parallel Reward Benchmark", open=False):
                gr.Markdown("*(Simulates the heavy RLHF training workload by running hundreds of complex SQL queries concurrently to test SQLite multi-threading performance.)*")
                t1_n = gr.Number(value=20, precision=0, label="Rollouts (n)")
                t1_workers = gr.Number(value=10, precision=0, label="Max workers")
                t1_run = gr.Button("Run Task 1 benchmark")
                t1_out = gr.Textbox(label="Output", lines=12)
                t1_plot = gr.HTML(label="Plot (if generated)")
                t1_run.click(fn=task1_benchmark, inputs=[t1_n, t1_workers], outputs=[t1_out, t1_plot])

            with gr.Accordion("Task 2: Error Dashboard", open=True):
                gr.Markdown("*(Live telemetry tracking the most common SQL failures. Populates automatically when queries fail in the Inference tab.)*")
                t2_refresh = gr.Button("Refresh dashboard")
                t2_counts = gr.Dataframe(label="Error counts", interactive=False, wrap=True)
                t2_recent = gr.Dataframe(label="Recent errors", interactive=False, wrap=True)
                # Choices are filled in by task2_dashboard_structured on refresh.
                t2_type = gr.Dropdown(choices=[], value=None, label="Select error type")
                t2_examples = gr.Textbox(label="Examples + hint", lines=10)

                t2_refresh.click(fn=task2_dashboard_structured, inputs=[], outputs=[t2_counts, t2_recent, t2_type])
                t2_type.change(fn=task2_error_examples, inputs=[t2_type], outputs=[t2_examples])

            with gr.Accordion("Task 2: Clause Telemetry", open=False):
                gr.Markdown("*(Analyzes which specific SQL clauses—SELECT, WHERE, JOIN, etc.—are most prone to errors during natural language generation.)*")
                t2_ops_refresh = gr.Button("Refresh SQL-op stats")
                t2_ops_tbl = gr.Dataframe(label="Success/failure by op", interactive=False, wrap=True)
                t2_ops_plot = gr.HTML(label="Op plot")
                t2_ops_refresh.click(fn=task2_ops_table, inputs=[], outputs=[t2_ops_tbl, t2_ops_plot])

    # EVENT BINDING: The .then() forces the diagnostic tab to update live in the background!
    input_method.change(fn=toggle_input_method, inputs=[input_method, sample_dropdown], outputs=[sample_dropdown, type_own_warning, custom_question, db_id])
    sample_dropdown.change(fn=load_sample, inputs=[sample_dropdown], outputs=[db_id])
    db_id.change(fn=update_schema, inputs=[db_id], outputs=[schema_display])

    # Run the query, then refresh both diagnostics panels so telemetry from
    # this run is visible without a manual refresh.
    run_btn.click(
        fn=run_query,
        inputs=[input_method, sample_dropdown, custom_question, db_id],
        outputs=[final_sql, result_table, explanation]
    ).then(
        fn=task2_dashboard_structured, inputs=[], outputs=[t2_counts, t2_recent, t2_type]
    ).then(
        fn=task2_ops_table, inputs=[], outputs=[t2_ops_tbl, t2_ops_plot]
    )

    clear_btn.click(fn=clear_inputs, inputs=[], outputs=[input_method, sample_dropdown, type_own_warning, custom_question, db_id, final_sql, result_table, explanation])
|
| 550 |
+
|
| 551 |
+
if __name__ == "__main__":
    # Bind address/port are overridable via the standard Gradio env vars
    # (Hugging Face Spaces sets these); defaults match the Spaces convention.
    server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
    base_port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
    max_retries = 10  # number of consecutive ports to probe

    for port in range(base_port, base_port + max_retries):
        try:
            print(f"Attempting to start Gradio UI on {server_name}:{port}...", flush=True)
            demo.launch(server_name=server_name, server_port=port)
            break  # launch() returned normally; stop probing ports
        except OSError as e:
            # Gradio reports a busy port as an OSError; move on to the next.
            if "Cannot find empty port" in str(e) or "Address already in use" in str(e):
                print(f"⚠️ Port {port} is in use, trying next port...")
                continue
            # Unrelated OSError: bare `raise` re-raises with the original
            # traceback intact (idiomatic; `raise e` appends a new frame).
            raise
    else:
        # for/else: only reached when every candidate port was busy.
        print(f"❌ Could not find an open port between {base_port} and {base_port + max_retries - 1}.")
|
db.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bb0e6ef12110c4c9808205cb210a35b5c4412397e15f47ed437e739e161d4213
|
| 3 |
+
size 53803466
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio==5.8.0
|
| 2 |
+
pandas
|
| 3 |
+
sqlparse
|
| 4 |
+
transformers
|
| 5 |
+
--extra-index-url https://download.pytorch.org/whl/cpu
torch
|
| 6 |
+
peft
|
| 7 |
+
trl
|
| 8 |
+
sentencepiece
|
| 9 |
+
matplotlib
|
| 10 |
+
huggingface_hub
|