tjhalanigrid commited on
Commit
f2ede4f
·
verified ·
1 Parent(s): 87987b5

Upload 4 files

Browse files
Files changed (4) hide show
  1. README.md +13 -3
  2. app.py +569 -0
  3. db.zip +3 -0
  4. requirements.txt +10 -0
README.md CHANGED
@@ -1,3 +1,13 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Text2sql Demo
3
+ emoji: 📊
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 5.8.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ python_version: 3.10.13
12
+ short_description: 'Text to SQL with RLHF'
13
+ ---
app.py ADDED
@@ -0,0 +1,569 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GRADIO DEMO UI - LAZY LOADING EDITION
3
+ NL → SQL → Result Table
4
+ """
5
+
6
+ import gradio as gr
7
+ import pandas as pd
8
+ import re
9
+ import time
10
+ import os
11
+ import torch
12
+ import sys
13
+ import json
14
+ import subprocess
15
+ import base64
16
+ import io
17
+ from pathlib import Path
18
+ from typing import Iterator
19
+
20
+ # ==========================================
21
+ # RELATIVE PATH RESOLUTION (GLOBAL)
22
+ # ==========================================
23
# Resolve the project root robustly: __file__ is undefined in some embedded /
# interactive execution contexts, so fall back to the current directory.
try:
    PROJECT_ROOT = Path(__file__).resolve().parent
except NameError:
    # No __file__ (REPL / exec) — use the working directory instead.
    PROJECT_ROOT = Path(".").resolve()

# Prefer the Spider-style layout (data/database/<db_id>/...); otherwise use
# the flat final_databases/ folder shipped alongside this app.
if (PROJECT_ROOT / "data" / "database").exists():
    DB_ROOT = PROJECT_ROOT / "data" / "database"
else:
    DB_ROOT = PROJECT_ROOT / "final_databases"
32
+
33
def get_db_path(db_id: str) -> str:
    """Resolve the SQLite file for *db_id*.

    Prefers the nested layout DB_ROOT/<db_id>/<db_id>.sqlite; if that file
    does not exist, falls back to the flat DB_ROOT/<db_id>.sqlite path
    (returned even if it is also missing, so callers see a best-guess path).
    """
    nested = DB_ROOT / db_id / f"{db_id}.sqlite"
    if nested.exists():
        return str(nested)
    return str(DB_ROOT / f"{db_id}.sqlite")
37
+
38
+ # ==========================================
39
+ # 🔥 CUDA MOCK PATCH FOR MAC (MPS) / CPU
40
+ # ==========================================
41
# When CUDA is unavailable (macOS / CPU-only hosts), stub out the small
# torch.cuda surface this app touches so timing code elsewhere keeps working.
if not torch.cuda.is_available():
    class MockCUDAEvent:
        # Drop-in stand-in for torch.cuda.Event backed by wall-clock time.
        def __init__(self, enable_timing=False, blocking=False, interprocess=False):
            # Timestamp of the last record() call, in perf_counter seconds.
            self.t = 0.0
        def record(self, stream=None):
            # Capture a monotonic timestamp instead of enqueuing a CUDA event.
            self.t = time.perf_counter()
        def elapsed_time(self, end_event):
            # Milliseconds between two recorded events, mirroring CUDA semantics.
            return (end_event.t - self.t) * 1000.0

    torch.cuda.Event = MockCUDAEvent
    # NOTE(review): torch.cuda.synchronize exists on CPU-only torch builds as
    # well (it raises when called), so this hasattr guard likely never installs
    # the no-op — confirm whether the guard should be removed.
    if not hasattr(torch.cuda, 'synchronize'):
        torch.cuda.synchronize = lambda: None
53
+
54
+ # ==========================================
55
+ # IMPORTS & ENGINE SETUP
56
+ # ==========================================
57
from src.quantized_text2sql_engine import QuantizedText2SQLEngine
from src.schema_encoder import SchemaEncoder

# Directory holding the int8 dynamically-quantized model artifact.
DEFAULT_QUANT_ARTIFACT = str(PROJECT_ROOT / "int8_dynamic")

# In-process telemetry stores; module lifetime only, nothing is persisted.
_ENGINE_CACHE = {}   # (artifact, flags...) tuple -> loaded engine instance
_QUERY_LOG = []      # failed/blocked requests, feeds the error dashboard
_PERF_LOG = []       # per-request latency/constraint payloads
_SUCCESS_LOG = []    # successful queries with their SQL-clause breakdown

# Per-clause success/failure counters feeding the clause-telemetry tab;
# keys must match the tags produced by sql_ops().
_OP_STATS = {
    "SELECT": {"ok": 0, "fail": 0}, "WHERE": {"ok": 0, "fail": 0}, "JOIN": {"ok": 0, "fail": 0},
    "GROUP_BY": {"ok": 0, "fail": 0}, "ORDER_BY": {"ok": 0, "fail": 0}, "HAVING": {"ok": 0, "fail": 0}, "LIMIT": {"ok": 0, "fail": 0},
}
71
+
72
def get_quant_engine(artifact_dir: str, use_constrained: bool = False, exec_workers: int = 8, use_cache: bool = True):
    """Return a memoised QuantizedText2SQLEngine for this configuration.

    Engines are cached per (artifact, constrained-flag, workers, cache-flag)
    tuple so repeated UI requests reuse the already-loaded model.  Older
    engine builds whose constructor rejects the keyword options fall back to
    the positional-only form.
    """
    cache_key = (artifact_dir, bool(use_constrained), int(exec_workers), bool(use_cache))
    if cache_key not in _ENGINE_CACHE:
        try:
            engine = QuantizedText2SQLEngine(
                artifact_dir,
                device="cpu",
                use_constrained=bool(use_constrained),
                exec_workers=int(exec_workers),
                use_cache=bool(use_cache),
            )
        except TypeError:
            # Older engine signature: positional artifact dir only.
            engine = QuantizedText2SQLEngine(artifact_dir)
        _ENGINE_CACHE[cache_key] = engine
    return _ENGINE_CACHE[cache_key]
80
+
81
# 🚨 LAZY LOADING: We DO NOT load the model here! We only load the fast Schema Encoder.
quant_engine = None
try:
    schema_encoder = SchemaEncoder(DB_ROOT)
except Exception as e:
    # Schema browsing is optional; the app degrades gracefully without it.
    print(f"Warning: SchemaEncoder failed to load: {e}")
    schema_encoder = None

# Curated (question, db_id) pairs shown in the sample dropdown; picking a
# question auto-selects the database it belongs to (see load_sample()).
SAMPLES = [
    ("Show 10 distinct employee first names.", "chinook_1"), ("Which artist has the most albums?", "chinook_1"),
    ("List all the tracks that belong to the 'Rock' genre.", "chinook_1"), ("What are the names of all the cities?", "flight_1"),
    ("Find the flight number and cost of the cheapest flight.", "flight_1"), ("List the airlines that fly out of New York.", "flight_1"),
    ("Which campus was opened between 1935 and 1939?", "csu_1"), ("Count the number of students in each department.", "college_2"),
    ("List the names of all clubs.", "club_1"), ("How many members does each club have?", "club_1"),
    ("Show the names of all cinemas.", "cinema"), ("Which cinema has the most screens?", "cinema")
]
# Question texts only, for the dropdown choices.
SAMPLE_QUESTIONS = [q[0] for q in SAMPLES]
98
+
99
def explain_sql(sql):
    """Build a short human-readable summary of what *sql* does.

    Returns "" for empty input; otherwise a base sentence followed by one
    bullet per recognised clause keyword found in the lowercased query.
    """
    if not sql:
        return ""
    lowered = sql.lower()
    notes = ["This SQL query retrieves information from the database."]
    clause_notes = (
        ("join", "\n• It combines data from multiple tables using JOIN."),
        ("where", "\n• It filters rows using a WHERE condition."),
        ("group by", "\n• It groups results using GROUP BY."),
        ("order by", "\n• It sorts the results using ORDER BY."),
        ("limit", "\n• It limits the number of returned rows."),
    )
    for keyword, note in clause_notes:
        if keyword in lowered:
            notes.append(note)
    return "".join(notes)
109
+
110
def sql_ops(sql: str) -> list[str]:
    """Return the clause tags present in *sql*, in a fixed order.

    Every query is assumed to contain SELECT; the remaining clauses are
    detected with a space-padded substring search so keywords at the very
    start or end of the string still match.  Tags mirror the keys of
    _OP_STATS.  None/empty input yields just ["SELECT"].
    """
    # Build the padded, lowercased haystack once instead of once per clause.
    padded = f" {(sql or '').lower()} "
    ops = ["SELECT"]
    clause_tags = (
        (" where ", "WHERE"),
        (" join ", "JOIN"),
        (" group by ", "GROUP_BY"),
        (" order by ", "ORDER_BY"),
        (" having ", "HAVING"),
        (" limit ", "LIMIT"),
    )
    for keyword, tag in clause_tags:
        if keyword in padded:
            ops.append(tag)
    return ops
120
+
121
+ def classify_error(sql: str, error_msg: str | None = None, *, timed_out: bool = False):
122
+ s = (sql or "").lower()
123
+ m = (error_msg or "").lower()
124
+ if timed_out or "interrupted" in m or "timeout" in m: return "timeout"
125
+ if not s.strip().startswith(("select", "with")): return "syntax_error"
126
+ if " join " in f" {s} " and " on " not in f" {s} ": return "missing_join"
127
+ if " where " in f" {s} " and not any(op in s for op in ["=", ">", "<", " in ", " like ", " between ", " is null", " is not null"]): return "wrong_where"
128
+ if ("is null" in s or "is not null" in s) and ("no such column" in m or "misuse" in m): return "null_handling"
129
+ if "no such table" in m: return "missing_table"
130
+ if "no such column" in m: return "missing_column"
131
+ if "ambiguous column name" in m: return "ambiguous_column"
132
+ if "datatype mismatch" in m or "type mismatch" in m: return "type_mismatch"
133
+ if "misuse of aggregate" in m or "misuse of aggregate function" in m: return "wrong_aggregation"
134
+ if "syntax error" in m: return "syntax_error"
135
+ if "near" in m and "syntax error" in m: return "syntax_error"
136
+ if "runtime" in m or "constraint failed" in m: return "runtime_error"
137
+ return "other"
138
+
139
def get_hint(error_type):
    """Return a one-line remediation hint for a classify_error() label.

    Unknown labels fall back to the generic "Review query." message.
    """
    hints = {
        "missing_join": "Check JOIN conditions between tables.", "wrong_aggregation": "Use proper aggregation like avg(column).",
        "wrong_where": "Check WHERE condition syntax.", "syntax_error": "Ensure SQL starts with SELECT.",
        "missing_table": "Use only tables from the provided schema.", "missing_column": "Use only columns from the provided schema.",
        "ambiguous_column": "Disambiguate by using table.column.", "timeout": "Query took too long; simplify joins.", "other": "Review SQL logic.",
        # Labels emitted by classify_error() that previously fell through to the default:
        "null_handling": "Use IS NULL / IS NOT NULL on valid columns.",
        "type_mismatch": "Compare columns with values of the matching type.",
        "runtime_error": "Check constraints and data values used by the query.",
    }
    return hints.get(error_type, "Review query.")
147
+
148
def is_relevant_to_schema(question, db_id):
    """Heuristically decide whether *question* mentions anything in the schema.

    Tokenises both the question and the encoded schema, drops common stop
    words and pure numbers, and accepts if any remaining question word (or
    its naive singular, trailing 's' stripped) appears among schema tokens.
    Fails open (returns True) when the encoder is unavailable, the schema
    lookup fails, or no meaningful words remain.
    """
    if schema_encoder is None: return True
    try: raw_schema = schema_encoder.structured_schema(db_id).lower()
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit propagate;
    # schema lookup failures still fail open.
    except Exception: return True
    schema_words = set(re.findall(r'[a-z0-9_]+', raw_schema))
    q_words = re.findall(r'[a-z0-9_]+', question.lower())
    stop_words = {"show", "list", "all", "what", "is", "the", "how", "many", "count", "find", "get", "me", "a", "an", "of", "in", "for", "from", "with", "which", "are", "there", "give", "tell", "details", "info", "data", "everything"}
    meaningful_q_words = [w for w in q_words if w not in stop_words and not w.isdigit()]
    if not meaningful_q_words: return True
    for word in meaningful_q_words:
        # Naive singularisation so e.g. "artists" matches an "artist" table.
        singular_word = word[:-1] if word.endswith('s') else word
        if word in schema_words or singular_word in schema_words: return True
    return False
161
+
162
def run_query(method, sample_q, custom_q, db_id):
    """End-to-end handler for the "Generate & Run SQL" button.

    Pipeline: lazily load the quantized engine on first use -> guard the
    input (empty / gibberish / DML / off-topic) -> generate SQL -> apply
    heuristic post-processing (LIMIT inference, spurious aggregation
    removal) -> execute against the selected SQLite DB (falling back to the
    raw model SQL if the post-processed one fails) -> record telemetry.

    Returns a 3-tuple for the output widgets:
    (sql_code_string, result_dataframe, explanation_text).
    """
    global quant_engine

    # 🚨 LAZY LOADING: We load the heavy AI model ONLY when the button is clicked.
    if quant_engine is None:
        print(f"First request detected! Loading AI model from {DEFAULT_QUANT_ARTIFACT}...", flush=True)
        try:
            quant_engine = get_quant_engine(DEFAULT_QUANT_ARTIFACT, use_constrained=False, exec_workers=8, use_cache=True)
            if quant_engine is None:
                return "-- ❌ ENGINE CRASH", pd.DataFrame(columns=["Error"]), "Failed to load model. Did you move the tokenizer files and add config.json to int8_dynamic/?"
        except Exception as e:
            return f"-- ❌ ENGINE CRASH\n-- {str(e)}", pd.DataFrame(columns=["Error Status"]), f"Critical failure loading model: {e}"

    # Append one failure/blocked-request record to the module-level query log.
    def _log(error_type: str, *, question: str, db_id_val: str, sql: str = "", error_msg: str = "") -> None:
        _QUERY_LOG.append({"t": time.time(), "db_id": str(db_id_val), "question": str(question), "sql": str(sql), "error_type": str(error_type), "error_msg": str(error_msg)})

    # Append a perf payload, trimming the log so it stays bounded (~1000 rows).
    def _perf_log(payload: dict) -> None:
        _PERF_LOG.append(payload)
        if len(_PERF_LOG) > 1000: del _PERF_LOG[:200]

    raw_question = sample_q if method == "💡 Pick a Sample" else custom_q

    # --- Input guards -----------------------------------------------------
    if not raw_question or str(raw_question).strip() == "":
        return "-- No input provided", pd.DataFrame(columns=["Warning"]), "⚠️ Please enter a question."
    if not db_id or str(db_id).strip() == "":
        return "-- No database selected", pd.DataFrame(columns=["Warning"]), "⚠️ Please select a database."

    # Light typo normalisation of common misspellings of command verbs.
    typo_corrections = [(r'\bshaw\b', 'show'), (r'\bshw\b', 'show'), (r'\bsho\b', 'show'), (r'\blsit\b', 'list'), (r'\blis\b', 'list'), (r'\bfidn\b', 'find'), (r'\bfnd\b', 'find'), (r'\bgte\b', 'get')]
    question = str(raw_question)
    for bad, good in typo_corrections: question = re.sub(bad, good, question, flags=re.IGNORECASE)
    q_lower = question.strip().lower()

    # Single-word inputs are treated as gibberish.
    if len(q_lower.split()) < 2:
        _log("gibberish", question=question, db_id_val=str(db_id), error_msg="gibberish filtered")
        return "-- Input Blocked", pd.DataFrame(columns=["Warning"]), "⚠️ Please enter a clear, meaningful natural language question (more than one word)."

    # Refuse anything that asks for data modification (read-only demo).
    if re.search(r'\b(delete|update|insert|drop|alter|truncate)\b', q_lower):
        _log("blocked_dml", question=question, db_id_val=str(db_id), error_msg="DML blocked")
        return "-- ❌ BLOCKED: Data Modification", pd.DataFrame(columns=["Security Alert"]), "🛑 Security Alert: Modifying or deleting data is strictly prohibited."

    # Refuse questions with no lexical overlap with the selected schema.
    if not is_relevant_to_schema(question, db_id):
        _log("out_of_domain", question=question, db_id_val=str(db_id), error_msg="out of domain")
        return "-- ❌ BLOCKED: Out of Domain", pd.DataFrame(columns=["Domain Alert"]), f"🛑 Relevance Alert: I don't see anything related to your question in the '{db_id}' schema."

    start_time = time.time()
    t0 = time.perf_counter()
    ui_warnings = ""

    # --- Generation (newer engine kwargs first, positional fallback) ------
    try:
        try:
            result = quant_engine.ask(question, str(db_id), num_beams=4, max_new_tokens=120, timeout_s=2.0)
        except TypeError:
            result = quant_engine.ask(question, str(db_id))
    except Exception as e:
        _log("backend_crash", question=question, db_id_val=str(db_id), error_msg=str(e))
        return f"-- ❌ BACKEND CRASH\n-- {str(e)}", pd.DataFrame(columns=["Error Status"]), f"❌ CRITICAL BACKEND CRASH:\n{str(e)}"

    final_sql = str(result.get("sql", ""))
    # Keep the untouched model output so we can retry with it if the
    # post-processed SQL fails to execute.
    model_sql = final_sql

    # --- Heuristic post-processing: numbers in "show/list N ..." phrasing
    # are usually row limits, which the model sometimes misuses as filters.
    num_match = re.search(r'\b(?:show|list|top|limit|get|first|last|sample|of)\s+(?:[a-zA-Z_]+\s+)?(\d+)\b', q_lower)
    if not num_match and q_lower.startswith(("show", "list", "get")):
        num_match = re.search(r'\b(\d+)\b', q_lower)

    if num_match and final_sql:
        limit_val = num_match.group(1)
        # Strip filters that wrongly compare against the requested row count.
        final_sql = re.sub(rf"(?i)\s*(?:where|having|and)?\s*count\s*\(\s*\*\s*\)\s*=\s*{limit_val}", "", final_sql)
        final_sql = re.sub(rf"(?i)\s*(?:where|and)\s+[a-zA-Z0-9_.]+\s*=\s*['\"]?{limit_val}['\"]?", "", final_sql)
        # Clean up any dangling WHERE the removals may have left behind.
        final_sql = re.sub(r"(?i)\s*where\s*$", "", final_sql)
        final_sql = re.sub(r"(?i)\s*where\s+(group by|order by|limit)", r" \1", final_sql)

        # Drop spurious aggregation when the question has no aggregate intent.
        agg_kws = ["most", "top", "highest", "lowest", "count", "many", "group", "frequent", "popular"]
        if not any(k in q_lower for k in agg_kws):
            final_sql = re.sub(r"(?i)\s*group by\s+[a-zA-Z0-9_.]+\s*order by\s+count\(\*\)\s*(?:desc|asc)?", "", final_sql)
            final_sql = re.sub(r"(?i)\s*order by\s+count\(\*\)\s*(?:desc|asc)?", "", final_sql)
            final_sql = re.sub(r"(?i),\s*count\(\*\)", "", final_sql)
            final_sql = re.sub(r"(?i)count\(\*\)\s*,", "", final_sql)

        # GROUP BY without any aggregate function is almost certainly noise.
        if "group by" in final_sql.lower() and not re.search(r'(?i)\b(count|sum|avg|max|min)\b\(', final_sql):
            final_sql = re.sub(r"(?i)\s*group by\s+[a-zA-Z0-9_.]+", "", final_sql)

        # Finally, enforce the requested row limit.
        if "limit" not in final_sql.lower():
            final_sql = f"{final_sql.strip().rstrip(';')} LIMIT {limit_val}"

    # Execution
    from src.sql_validator import validate_sql_schema
    db_path = get_db_path(str(db_id))

    # Strict validator verdict is informational only (shown in perf block).
    try: strict_valid, _ = validate_sql_schema(final_sql, db_path)
    except Exception: strict_valid = False

    error_msg = None
    rows, cols = [], []
    sqlite_success = False

    try:
        rows, cols = quant_engine._execute_one(final_sql, db_path, timeout_s=2.0)
        sqlite_success = True
    except Exception as e:
        error_msg = str(e)
        sqlite_success = False

    # If post-processing broke the query, retry with the raw model output.
    if not sqlite_success and model_sql and model_sql != final_sql:
        try:
            alt_rows, alt_cols = quant_engine._execute_one(model_sql, db_path, timeout_s=2.0)
            final_sql = model_sql
            rows, cols = alt_rows, alt_cols
            sqlite_success = True
            error_msg = None
        except Exception: pass

    valid = sqlite_success

    # Log the failure (the user-facing message is built further below).
    if error_msg or not valid:
        et = classify_error(final_sql, str(error_msg or ""), timed_out=("interrupted" in str(error_msg or "").lower()))
        _log(et, question=str(question), db_id_val=str(db_id), sql=str(final_sql), error_msg=str(error_msg or "Execution failed"))

    latency = round(time.time() - start_time, 3)
    t1 = time.perf_counter()

    engine_stats_after = quant_engine.stats() if hasattr(quant_engine, 'stats') else {}

    # --- Telemetry payload + rolling summary over the last 50 requests ----
    perf = {
        "db_id": str(db_id), "use_constrained_decoding": False, "num_beams": 4,
        "latency_total_ms": round((t1 - t0) * 1000.0, 2), "constraint_ok": bool(strict_valid), "has_error": bool(error_msg),
        "exec_cache_hit_rate": float(engine_stats_after.get("exec_cache_hit_rate", 0.0) or 0.0),
    }
    _perf_log(perf)

    window = _PERF_LOG[-50:]
    avg_ms = sum(float(x.get("latency_total_ms", 0.0) or 0.0) for x in window) / len(window) if window else 0.0
    constraint_rate = sum(1 for x in window if x.get("constraint_ok")) / len(window) if window else 0.0

    perf_block = (
        "\n\n---\nPerformance (task impact)\n"
        f"- Total latency (ms): {perf['latency_total_ms']}\n"
        f"- Strict Python Validator OK (Task 3): {perf['constraint_ok']}\n"
        f"- Exec cache hit-rate (Task 1/5): {round(perf['exec_cache_hit_rate'], 3)}\n"
        f"- Rolling avg latency last 50 (ms): {round(avg_ms, 2)}\n"
        f"- Rolling constraint rate last 50: {round(constraint_rate, 3)}\n"
    )

    # --- Failure path: show the SQL, an empty table and diagnostic text ---
    if error_msg or not valid:
        display_sql = final_sql if final_sql.strip() else "-- ❌ INVALID SQL"
        explanation = f"{ui_warnings}❌ Error Details:\n\n"
        if error_msg: explanation += f"{error_msg}\n\n"

        error_type = classify_error(final_sql, str(error_msg or ""))
        explanation += f"Error Type: {error_type}\nHint: {get_hint(error_type)}"
        explanation += perf_block
        ops = sql_ops(final_sql)
        for op in ops:
            if op in _OP_STATS: _OP_STATS[op]["fail"] += 1
        return display_sql, pd.DataFrame(columns=["Execution Notice"]), explanation

    # --- Success path -----------------------------------------------------
    safe_cols = cols if cols else ["Result"]
    explanation = f"{ui_warnings}✅ Query executed successfully\n\nRows returned: {len(rows)}\nExecution Time: {latency} sec\n\n{explain_sql(final_sql)}{perf_block}"

    ops = sql_ops(final_sql)
    for op in ops:
        if op in _OP_STATS: _OP_STATS[op]["ok"] += 1
    _SUCCESS_LOG.append({"t": time.time(), "db_id": str(db_id), "question": question, "sql": final_sql, "ops": ops})

    # Tell the user when fewer rows matched than the LIMIT allowed.
    limit_match = re.search(r'LIMIT\s+(\d+)', final_sql, re.IGNORECASE)
    if limit_match and len(rows) < int(limit_match.group(1)):
        explanation += f"\n\nℹ️ Query allowed up to {int(limit_match.group(1))} rows but only {len(rows)} matched."

    return final_sql, pd.DataFrame(rows, columns=safe_cols), explanation
330
+
331
def task1_benchmark(n_rollouts: int, max_workers: int) -> Iterator[tuple[str, str]]:
    """Stream the Task-1 parallel-reward benchmark as (log_text, plot_html).

    Runs scripts/benchmark_parallel_reward.py as a subprocess and yields its
    combined stdout/stderr roughly every 0.5 s so the Gradio textbox updates
    live; on completion yields the full log plus an inline base64 plot if
    results/task1_plot.png was produced.
    """
    project_root = str(PROJECT_ROOT)
    env = os.environ.copy()
    # Make sure the subprocess can import the project package.
    env["PYTHONPATH"] = project_root + (os.pathsep + env["PYTHONPATH"] if env.get("PYTHONPATH") else "")
    # Headless matplotlib with a writable config dir (host HOME may be read-only).
    env.setdefault("MPLBACKEND", "Agg")
    env.setdefault("MPLCONFIGDIR", "/tmp/mplconfig")
    try: os.makedirs(env["MPLCONFIGDIR"], exist_ok=True)
    except Exception: pass

    cmd = [sys.executable, "-u", "scripts/benchmark_parallel_reward.py", "--n", str(int(n_rollouts)), "--max-workers", str(int(max_workers)), "--skip-profile"]
    # stderr is merged into stdout so the UI shows a single log stream.
    proc = subprocess.Popen(cmd, cwd=project_root, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
    last_yield = time.perf_counter()
    lines: list[str] = []
    yield "Running Task 1 benchmark...\n", "<i>Running...</i>"

    assert proc.stdout is not None
    for line in proc.stdout:
        lines.append(line)
        now = time.perf_counter()
        # Throttle UI updates to every 0.5 s, showing only the last 200 lines.
        if now - last_yield >= 0.5:
            last_yield = now
            yield "".join(lines[-200:]).strip(), "<i>Running...</i>"

    proc.wait()
    out = "".join(lines).strip()

    plot_path = str(PROJECT_ROOT / "results" / "task1_plot.png")
    if os.path.exists(plot_path):
        try:
            b64 = base64.b64encode(Path(plot_path).read_bytes()).decode("ascii")
            yield out, f"<img src='data:image/png;base64,{b64}' style='max-width: 100%; border: 1px solid #e2e8f0; border-radius: 8px;' />"
            return
        except Exception:
            # Could not inline the image — fall back to showing its path.
            yield out, f"<pre>{plot_path}</pre>"
            return

    yield out, "<i>No plot generated</i>"
368
+
369
def task2_dashboard_structured():
    """Aggregate _QUERY_LOG into the error dashboard.

    Returns (error-count DataFrame, recent-errors DataFrame, dropdown update
    listing the observed error types with the most frequent pre-selected).
    All three are empty when nothing has been logged yet.
    """
    if not _QUERY_LOG:
        return (
            pd.DataFrame(columns=["error_type", "count", "hint"]),
            pd.DataFrame(columns=["time", "db_id", "error_type", "question", "error_msg"]),
            gr.update(choices=[], value=None),
        )

    # Tally error types over (at most) the last 1000 records.
    tally = {}
    for record in _QUERY_LOG[-1000:]:
        label = record.get("error_type") or "other"
        tally[label] = tally.get(label, 0) + 1
    # Most frequent first; ties broken alphabetically.
    summary_rows = [
        {"error_type": label, "count": int(n), "hint": get_hint(label)}
        for label, n in sorted(tally.items(), key=lambda item: (-item[1], item[0]))
    ]
    counts_df = pd.DataFrame(summary_rows)

    # Last 100 raw records with a human-readable timestamp.
    recent_rows = []
    for record in _QUERY_LOG[-100:]:
        stamp = record.get("t")
        try:
            formatted = time.strftime("%H:%M:%S", time.localtime(float(stamp))) if stamp else ""
        except Exception:
            formatted = ""
        recent_rows.append({
            "time": formatted,
            "db_id": record.get("db_id", ""),
            "error_type": record.get("error_type", ""),
            "question": record.get("question", ""),
            "error_msg": record.get("error_msg", ""),
        })
    recent_df = pd.DataFrame(recent_rows)

    dropdown_choices = [str(row["error_type"]) for row in summary_rows]
    preselected = dropdown_choices[0] if dropdown_choices else None
    return counts_df, recent_df, gr.update(choices=dropdown_choices, value=preselected)
393
+
394
def task2_error_examples(error_type: str) -> str:
    """Render up to three recent logged failures of *error_type* with its hint."""
    if not error_type:
        return ""
    hint = get_hint(error_type)
    wanted = str(error_type)
    # Newest first; stop after three examples.
    matches = []
    for record in reversed(_QUERY_LOG):
        if (record.get("error_type") or "") == wanted:
            matches.append(record)
            if len(matches) == 3:
                break
    if not matches:
        return f"Error type: {error_type}\nHint: {hint}\n\nNo examples yet."
    report = [f"Error type: {error_type}", f"Hint: {hint}", ""]
    for idx, record in enumerate(matches, 1):
        report.append(f"Example {idx}")
        report.append(f"DB: {record.get('db_id','')}")
        report.append(f"Q: {record.get('question','')}")
        report.append(f"SQL: {record.get('sql','')}")
        report.append(f"Msg: {record.get('error_msg','')}")
        report.append("")
    return "\n".join(report).strip()
403
+
404
def _plot_op_stats_html() -> str:
    """Render _OP_STATS as an inline base64 <img> stacked bar chart.

    Any failure (e.g. matplotlib missing) is reported as a <pre> block
    instead of raising, so the dashboard never breaks.
    """
    try:
        import matplotlib.pyplot as plt

        op_names = list(_OP_STATS.keys())
        ok_counts = [int(_OP_STATS[name]["ok"]) for name in op_names]
        fail_counts = [int(_OP_STATS[name]["fail"]) for name in op_names]

        fig, axis = plt.subplots(figsize=(9, 3.5))
        positions = list(range(len(op_names)))
        # Successes on the bottom, failures stacked on top.
        axis.bar(positions, ok_counts, label="ok", color="#16a34a")
        axis.bar(positions, fail_counts, bottom=ok_counts, label="fail", color="#dc2626")
        axis.set_xticks(positions)
        axis.set_xticklabels(op_names, rotation=30, ha="right")
        axis.set_title("Success/Failure by SQL operation")
        axis.legend()
        fig.tight_layout()

        png_buffer = io.BytesIO()
        fig.savefig(png_buffer, format="png", dpi=160)
        plt.close(fig)
        encoded = base64.b64encode(png_buffer.getvalue()).decode("ascii")
        return f"<img src='data:image/png;base64,{encoded}' style='max-width: 100%; border: 1px solid #e2e8f0; border-radius: 8px;' />"
    except Exception as e:
        return f"<pre>Plot error: {e}</pre>"
427
+
428
def task2_ops_table():
    """Summarise per-clause success/failure counters plus the rendered plot.

    Returns (DataFrame with op/ok/fail/total/success_rate rows, plot HTML).
    """
    summary = []
    for clause, counters in _OP_STATS.items():
        ok_count = int(counters.get("ok", 0))
        fail_count = int(counters.get("fail", 0))
        attempts = ok_count + fail_count
        summary.append({
            "op": clause,
            "ok": ok_count,
            "fail": fail_count,
            "total": attempts,
            # Guard against division by zero before any queries have run.
            "success_rate": (ok_count / attempts) if attempts else 0.0,
        })
    return pd.DataFrame(summary), _plot_op_stats_html()
436
+
437
def toggle_input_method(method, current_sample):
    """Switch input widgets between sample-question and free-text modes.

    Returns updates for (sample_dropdown, type_own_warning, custom_question,
    db_id).  Sample mode locks the DB picker to the sample's database;
    free-text mode unlocks it and reveals the warning plus textbox.
    """
    if method != "💡 Pick a Sample":
        # Free-text mode: hide samples, show warning + textbox, unlock DB picker.
        return (gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(interactive=True))
    matching_db = next((db for q, db in SAMPLES if q == current_sample), "chinook_1")
    return (gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(value=matching_db, interactive=False))
442
+
443
def load_sample(selected_question):
    """Auto-select the database paired with the chosen sample question."""
    if not selected_question:
        return gr.update()
    for sample_text, sample_db in SAMPLES:
        if sample_text == selected_question:
            return gr.update(value=sample_db)
    # Unknown question text — fall back to the default database.
    return gr.update(value="chinook_1")
446
+
447
def clear_inputs():
    """Reset the inference tab back to its initial sample-question state.

    Returns updates for (input_method, sample_dropdown, type_own_warning,
    custom_question, db_id, final_sql, result_table, explanation).
    """
    return (
        gr.update(value="💡 Pick a Sample"),
        gr.update(value=SAMPLE_QUESTIONS[0], visible=True),
        gr.update(visible=False),
        gr.update(value="", visible=False),
        gr.update(value="chinook_1", interactive=False),
        "",
        pd.DataFrame(),
        "",
    )
449
+
450
def update_schema(db_id):
    """Render the schema of *db_id* as scrollable HTML.

    Lines shaped like "table(col, col, ...)" become a bold table name with a
    lighter column list; anything else is shown as plain text.  Returns ""
    when no DB is selected or the encoder is unavailable, and an inline red
    error message if the schema lookup fails.
    """
    if not db_id or schema_encoder is None: return ""
    try:
        raw_schema = schema_encoder.structured_schema(db_id)
        pieces = ["<div style='max-height: 250px; overflow-y: auto; background: #f8fafc; padding: 12px; border-radius: 8px; border: 1px solid #e2e8f0; font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace; font-size: 0.9em; line-height: 1.6;'>"]
        for raw_line in raw_schema.strip().split('\n'):
            entry = raw_line.strip()
            if not entry: continue
            table_match = re.search(r'^([a-zA-Z0-9_]+)\s*\((.*)\)', entry)
            if table_match:
                pieces.append(f"<div style='margin-bottom: 8px;'><strong style='color: #0f172a; font-size: 1.05em; font-weight: 800;'>{table_match.group(1).upper()}</strong> <span style='color: #64748b;'>( {table_match.group(2).lower()} )</span></div>")
            else:
                pieces.append(f"<div style='color: #475569;'>{entry}</div>")
        pieces.append("</div>")
        return "".join(pieces)
    except Exception as e: return f"<div style='color: red;'>Error loading schema: {str(e)}</div>"
464
+
465
# =========================
# UI LAYOUT
# =========================
# Two tabs: "Inference" (ask a question, see SQL + results) and
# "Diagnostics" (benchmark runner plus live error/clause telemetry).
with gr.Blocks(title="Text-to-SQL RLHF") as demo:
    # Page banner.
    gr.HTML("""
    <div style="text-align: center; background-color: #e0e7ff; padding: 20px; border-radius: 10px; margin-bottom: 20px; border: 1px solid #c7d2fe;">
        <h1 style="color: #3730a3; margin-top: 0; margin-bottom: 10px; font-size: 2.2em;"> Text-to-SQL using RLHF + Execution Reward</h1>
        <p style="color: #4f46e5; font-size: 1.1em; margin: 0;">Convert Natural Language to SQL, strictly validated and safely executed on local SQLite databases.</p>
    </div>
    """)

    # Databases offered in the picker; presumably these match the folders
    # available under DB_ROOT — verify against the shipped db.zip contents.
    DBS = sorted(["flight_1", "student_assessment", "store_1", "bike_1", "book_2", "chinook_1", "academic", "aircraft", "car_1", "cinema", "club_1", "csu_1", "college_1", "college_2", "company_1", "company_employee", "customer_complaints", "department_store", "employee_hire_evaluation", "museum_visit", "products_for_hire", "restaurant_1", "school_finance", "shop_membership", "small_bank_1", "student_1", "tvshow", "voter_1", "world_1"])

    with gr.Tabs():
        with gr.Tab("Inference"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### 1. Configuration & Input")
                    input_method = gr.Radio(choices=["💡 Pick a Sample", "✍️ Type my own"], value="💡 Pick a Sample", label="How do you want to ask?")
                    sample_dropdown = gr.Dropdown(choices=SAMPLE_QUESTIONS, value=SAMPLE_QUESTIONS[0], label="Select a Sample Question", info="The database will be selected automatically.", visible=True)
                    type_own_warning = gr.Markdown("**⚠️ Please select a Database first, then type your custom question below:**", visible=False)
                    gr.Markdown("---")
                    # Locked in sample mode; toggle_input_method() unlocks it.
                    db_id = gr.Dropdown(choices=DBS, value="chinook_1", label="Select Database", interactive=False)
                    custom_question = gr.Textbox(label="Ask your Custom Question", placeholder="Type your own question here...", lines=3, visible=False)

                    gr.Markdown("#### 📋 Database Structure")
                    gr.HTML("<p style='font-size: 0.85em; color: #64748b; margin-top: -10px; margin-bottom: 5px;'>Use these exact names! Table names are <strong>Dark</strong>, Column names are <span style='color: #94a3b8;'>Light</span>.</p>")
                    # Pre-render the default database's schema on page load.
                    schema_display = gr.HTML(value=update_schema("chinook_1"))

                    with gr.Row():
                        clear_btn = gr.Button("🗑️ Clear", variant="secondary")
                        run_btn = gr.Button(" Generate & Run SQL", variant="primary")

                with gr.Column(scale=2):
                    gr.Markdown("### 2. Execution Results")
                    final_sql = gr.Code(language="sql", label="Final Executed SQL")
                    result_table = gr.Dataframe(label="Query Result Table", interactive=False, wrap=True)
                    explanation = gr.Textbox(label="AI Explanation + Execution Details", lines=8)

        with gr.Tab("Diagnostics"):
            gr.Markdown("## Diagnostics & Telemetry")

            with gr.Accordion("Task 1: Parallel Reward Benchmark", open=False):
                gr.Markdown("*(Simulates the heavy RLHF training workload by running hundreds of complex SQL queries concurrently to test SQLite multi-threading performance.)*")
                t1_n = gr.Number(value=20, precision=0, label="Rollouts (n)")
                t1_workers = gr.Number(value=10, precision=0, label="Max workers")
                t1_run = gr.Button("Run Task 1 benchmark")
                t1_out = gr.Textbox(label="Output", lines=12)
                t1_plot = gr.HTML(label="Plot (if generated)")
                # task1_benchmark is a generator, so the textbox streams live.
                t1_run.click(fn=task1_benchmark, inputs=[t1_n, t1_workers], outputs=[t1_out, t1_plot])

            with gr.Accordion("Task 2: Error Dashboard", open=True):
                gr.Markdown("*(Live telemetry tracking the most common SQL failures. Populates automatically when queries fail in the Inference tab.)*")
                t2_refresh = gr.Button("Refresh dashboard")
                t2_counts = gr.Dataframe(label="Error counts", interactive=False, wrap=True)
                t2_recent = gr.Dataframe(label="Recent errors", interactive=False, wrap=True)
                t2_type = gr.Dropdown(choices=[], value=None, label="Select error type")
                t2_examples = gr.Textbox(label="Examples + hint", lines=10)

                t2_refresh.click(fn=task2_dashboard_structured, inputs=[], outputs=[t2_counts, t2_recent, t2_type])
                t2_type.change(fn=task2_error_examples, inputs=[t2_type], outputs=[t2_examples])

            with gr.Accordion("Task 2: Clause Telemetry", open=False):
                gr.Markdown("*(Analyzes which specific SQL clauses—SELECT, WHERE, JOIN, etc.—are most prone to errors during natural language generation.)*")
                t2_ops_refresh = gr.Button("Refresh SQL-op stats")
                t2_ops_tbl = gr.Dataframe(label="Success/failure by op", interactive=False, wrap=True)
                t2_ops_plot = gr.HTML(label="Op plot")
                t2_ops_refresh.click(fn=task2_ops_table, inputs=[], outputs=[t2_ops_tbl, t2_ops_plot])

    # EVENT BINDING: The .then() forces the diagnostic tab to update live in the background!
    input_method.change(fn=toggle_input_method, inputs=[input_method, sample_dropdown], outputs=[sample_dropdown, type_own_warning, custom_question, db_id])
    sample_dropdown.change(fn=load_sample, inputs=[sample_dropdown], outputs=[db_id])
    db_id.change(fn=update_schema, inputs=[db_id], outputs=[schema_display])

    # Run the query, then refresh both diagnostic views with the new telemetry.
    run_btn.click(
        fn=run_query,
        inputs=[input_method, sample_dropdown, custom_question, db_id],
        outputs=[final_sql, result_table, explanation]
    ).then(
        fn=task2_dashboard_structured, inputs=[], outputs=[t2_counts, t2_recent, t2_type]
    ).then(
        fn=task2_ops_table, inputs=[], outputs=[t2_ops_tbl, t2_ops_plot]
    )

    clear_btn.click(fn=clear_inputs, inputs=[], outputs=[input_method, sample_dropdown, type_own_warning, custom_question, db_id, final_sql, result_table, explanation])
550
+
551
if __name__ == "__main__":
    # Bind address / starting port are overridable via standard Gradio env vars.
    server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
    base_port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
    max_retries = 10

    # Walk forward through ports until launch succeeds (handles busy ports on
    # shared dev machines); the for/else below reports total failure.
    for port in range(base_port, base_port + max_retries):
        try:
            print(f"Attempting to start Gradio UI on {server_name}:{port}...", flush=True)
            demo.launch(server_name=server_name, server_port=port)
            break # If successful, exit the loop
        except OSError as e:
            # Only port-conflict errors trigger a retry on the next port.
            if "Cannot find empty port" in str(e) or "Address already in use" in str(e):
                print(f"⚠️ Port {port} is in use, trying next port...")
                continue
            else:
                # If it's a different OSError, raise it normally
                raise e
    else:
        # Loop completed without a successful launch (no break hit).
        print(f"❌ Could not find an open port between {base_port} and {base_port + max_retries - 1}.")
db.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb0e6ef12110c4c9808205cb210a35b5c4412397e15f47ed437e739e161d4213
3
+ size 53803466
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio==5.8.0
2
+ pandas
3
+ sqlparse
4
+ transformers
5
+ --extra-index-url https://download.pytorch.org/whl/cpu
+ torch
6
+ peft
7
+ trl
8
+ sentencepiece
9
+ matplotlib
10
+ huggingface_hub