tjhalanigrid committed on
Commit
cf17729
·
1 Parent(s): b70f6fd

Added full project

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +13 -3
  2. app.py +569 -0
  3. db.zip +3 -0
  4. int8_dynamic/meta.json +7 -0
  5. int8_dynamic/model.pt +3 -0
  6. int8_dynamic/tokenizer/merges.txt +0 -0
  7. int8_dynamic/tokenizer/special_tokens_map.json +753 -0
  8. int8_dynamic/tokenizer/tokenizer.json +0 -0
  9. int8_dynamic/tokenizer/tokenizer_config.json +959 -0
  10. int8_dynamic/tokenizer/vocab.json +0 -0
  11. requirements.txt +10 -0
  12. scripts/__pycache__/benchmark_parallel_reward.cpython-310.pyc +0 -0
  13. scripts/__pycache__/benchmark_parallel_reward.cpython-313.pyc +0 -0
  14. scripts/__pycache__/benchmark_quantization.cpython-310.pyc +0 -0
  15. scripts/__pycache__/benchmark_rollout_generation.cpython-310.pyc +0 -0
  16. scripts/__pycache__/quantize_export.cpython-310.pyc +0 -0
  17. scripts/__pycache__/quantized_infer_harness.cpython-310.pyc +0 -0
  18. scripts/benchmark_parallel_reward.py +202 -0
  19. scripts/benchmark_quantization.py +108 -0
  20. scripts/benchmark_rollout_generation.py +66 -0
  21. scripts/error_dashboard.py +99 -0
  22. scripts/evaluate.py +170 -0
  23. scripts/plot_task2.py +58 -0
  24. scripts/plot_task3.py +15 -0
  25. scripts/plot_task3_plotly.py +103 -0
  26. scripts/quantize_export.py +86 -0
  27. scripts/quantized_infer_harness.py +46 -0
  28. src/__pycache__/execution_reward.cpython-310.pyc +0 -0
  29. src/__pycache__/quantization_utils.cpython-310.pyc +0 -0
  30. src/__pycache__/quantized_text2sql_engine.cpython-310.pyc +0 -0
  31. src/__pycache__/schema_encoder.cpython-310.pyc +0 -0
  32. src/__pycache__/schema_utils.cpython-310.pyc +0 -0
  33. src/__pycache__/sql_validator.cpython-310.pyc +0 -0
  34. src/__pycache__/text2sql_engine.cpython-310.pyc +0 -0
  35. src/ask.py +93 -0
  36. src/component_analysis.py +229 -0
  37. src/constrained_decoding.py +1058 -0
  38. src/constrained_decoding_sample.py +516 -0
  39. src/convert_to_hf_dataset.py +8 -0
  40. src/eval_baseline_codet5.py +112 -0
  41. src/eval_both_metrics.py +144 -0
  42. src/eval_rl_fixed.py +756 -0
  43. src/eval_rl_t5.py +279 -0
  44. src/eval_single_model.py +218 -0
  45. src/evaluate_model_codet5.py +392 -0
  46. src/evaluate_model_t5_small_sft.py +179 -0
  47. src/evaluate_rl_bart.py +138 -0
  48. src/evaluate_sft_bart.py +190 -0
  49. src/evaluate_without_constraied.py +503 -0
  50. src/execution_reward copy.py +831 -0
README.md CHANGED
@@ -1,3 +1,13 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Text2sql Demo
3
+ emoji: 📊
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 5.8.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ python_version: 3.10.13
12
+ short_description: 'Text to SQL with RLHF'
13
+ ---
app.py ADDED
@@ -0,0 +1,569 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GRADIO DEMO UI - LAZY LOADING EDITION
3
+ NL → SQL → Result Table
4
+ """
5
+
6
+ import gradio as gr
7
+ import pandas as pd
8
+ import re
9
+ import time
10
+ import os
11
+ import torch
12
+ import sys
13
+ import json
14
+ import subprocess
15
+ import base64
16
+ import io
17
+ from pathlib import Path
18
+ from typing import Iterator
19
+
20
+ # ==========================================
21
+ # RELATIVE PATH RESOLUTION (GLOBAL)
22
+ # ==========================================
23
# Resolve the project root from this file's location; fall back to the current
# working directory when __file__ is undefined (e.g. code pasted into a REPL).
try:
    PROJECT_ROOT = Path(__file__).resolve().parent
except NameError:
    PROJECT_ROOT = Path(".").resolve()

# Prefer the nested layout data/database/<db>/<db>.sqlite when present;
# otherwise use the flattened final_databases/ directory shipped with the app.
if (PROJECT_ROOT / "data" / "database").exists():
    DB_ROOT = PROJECT_ROOT / "data" / "database"
else:
    DB_ROOT = PROJECT_ROOT / "final_databases"
32
+
33
def get_db_path(db_id: str) -> str:
    """Resolve the sqlite file for *db_id*, trying the nested layout first.

    Checks DB_ROOT/<db_id>/<db_id>.sqlite and falls back to the flat
    DB_ROOT/<db_id>.sqlite path (returned even if neither file exists).
    """
    nested = DB_ROOT / db_id / f"{db_id}.sqlite"
    flat = DB_ROOT / f"{db_id}.sqlite"
    chosen = nested if nested.exists() else flat
    return str(chosen)
37
+
38
+ # ==========================================
39
+ # 🔥 CUDA MOCK PATCH FOR MAC (MPS) / CPU
40
+ # ==========================================
41
# When CUDA is unavailable (Mac/MPS or CPU-only Spaces), patch torch.cuda.Event
# with a wall-clock stand-in so downstream timing code keeps working unmodified.
if not torch.cuda.is_available():
    class MockCUDAEvent:
        """Drop-in replacement for torch.cuda.Event backed by time.perf_counter."""
        def __init__(self, enable_timing=False, blocking=False, interprocess=False):
            self.t = 0.0  # last recorded timestamp, in seconds
        def record(self, stream=None):
            self.t = time.perf_counter()
        def elapsed_time(self, end_event):
            # Mirrors CUDA Event semantics: elapsed time in milliseconds.
            return (end_event.t - self.t) * 1000.0

    torch.cuda.Event = MockCUDAEvent
    # NOTE(review): torch.cuda.synchronize typically exists even on CPU-only
    # builds, so this hasattr guard may never install the no-op — confirm on
    # the target torch version.
    if not hasattr(torch.cuda, 'synchronize'):
        torch.cuda.synchronize = lambda: None
53
+
54
+ # ==========================================
55
+ # IMPORTS & ENGINE SETUP
56
+ # ==========================================
57
+ from src.quantized_text2sql_engine import QuantizedText2SQLEngine
58
+ from src.schema_encoder import SchemaEncoder
59
+
60
# Directory holding the quantized-model artifact (int8 dynamic export).
DEFAULT_QUANT_ARTIFACT = str(PROJECT_ROOT / "int8_dynamic")

_ENGINE_CACHE = {}   # (artifact, constrained, workers, cache) -> engine instance
_QUERY_LOG = []      # blocked/failed query records feeding the error dashboard
_PERF_LOG = []       # per-request latency/validator telemetry (trimmed by _perf_log)
_SUCCESS_LOG = []    # successful queries together with their clause breakdown

# Per-clause ok/fail counters feeding the "Clause Telemetry" tab.
_OP_STATS = {
    "SELECT": {"ok": 0, "fail": 0}, "WHERE": {"ok": 0, "fail": 0}, "JOIN": {"ok": 0, "fail": 0},
    "GROUP_BY": {"ok": 0, "fail": 0}, "ORDER_BY": {"ok": 0, "fail": 0}, "HAVING": {"ok": 0, "fail": 0}, "LIMIT": {"ok": 0, "fail": 0},
}
71
+
72
def get_quant_engine(artifact_dir: str, use_constrained: bool = False, exec_workers: int = 8, use_cache: bool = True):
    """Return a memoised QuantizedText2SQLEngine for this configuration.

    Engines are cached per (artifact_dir, use_constrained, exec_workers,
    use_cache) key so repeated UI requests reuse the loaded model. The
    TypeError fallback supports older engine builds whose constructor only
    accepts the artifact directory.
    """
    key = (artifact_dir, bool(use_constrained), int(exec_workers), bool(use_cache))
    if key not in _ENGINE_CACHE:
        try:
            _ENGINE_CACHE[key] = QuantizedText2SQLEngine(artifact_dir, device="cpu", use_constrained=bool(use_constrained), exec_workers=int(exec_workers), use_cache=bool(use_cache))
        except TypeError:
            # Older engine signature: positional artifact dir only.
            _ENGINE_CACHE[key] = QuantizedText2SQLEngine(artifact_dir)
    return _ENGINE_CACHE[key]
80
+
81
# 🚨 LAZY LOADING: We DO NOT load the model here! We only load the fast Schema Encoder.
# quant_engine stays None until the first run_query() call populates it.
quant_engine = None
try:
    schema_encoder = SchemaEncoder(DB_ROOT)
except Exception as e:
    # Without an encoder, relevance checks and schema display degrade gracefully.
    print(f"Warning: SchemaEncoder failed to load: {e}")
    schema_encoder = None

# (question, database id) pairs exposed in the sample-question dropdown.
SAMPLES = [
    ("Show 10 distinct employee first names.", "chinook_1"), ("Which artist has the most albums?", "chinook_1"),
    ("List all the tracks that belong to the 'Rock' genre.", "chinook_1"), ("What are the names of all the cities?", "flight_1"),
    ("Find the flight number and cost of the cheapest flight.", "flight_1"), ("List the airlines that fly out of New York.", "flight_1"),
    ("Which campus was opened between 1935 and 1939?", "csu_1"), ("Count the number of students in each department.", "college_2"),
    ("List the names of all clubs.", "club_1"), ("How many members does each club have?", "club_1"),
    ("Show the names of all cinemas.", "cinema"), ("Which cinema has the most screens?", "cinema")
]
SAMPLE_QUESTIONS = [q[0] for q in SAMPLES]
98
+
99
def explain_sql(sql):
    """Return a short plain-English summary of the clauses used in *sql*.

    Empty/None input yields an empty string; otherwise a base sentence plus
    one bullet per recognised clause, in a fixed order.
    """
    if not sql:
        return ""
    lowered = sql.lower()
    clause_notes = (
        ("join", "\n• It combines data from multiple tables using JOIN."),
        ("where", "\n• It filters rows using a WHERE condition."),
        ("group by", "\n• It groups results using GROUP BY."),
        ("order by", "\n• It sorts the results using ORDER BY."),
        ("limit", "\n• It limits the number of returned rows."),
    )
    summary = "This SQL query retrieves information from the database."
    summary += "".join(note for keyword, note in clause_notes if keyword in lowered)
    return summary
109
+
110
def sql_ops(sql: str) -> list[str]:
    """List the coarse SQL clauses present in *sql*.

    Always includes "SELECT"; the remaining labels are appended in a fixed
    order when the space-padded keyword appears in the lowered query text.
    """
    padded = f" {(sql or '').lower()} "
    found = ["SELECT"]
    for token, label in (
        (" where ", "WHERE"),
        (" join ", "JOIN"),
        (" group by ", "GROUP_BY"),
        (" order by ", "ORDER_BY"),
        (" having ", "HAVING"),
        (" limit ", "LIMIT"),
    ):
        if token in padded:
            found.append(label)
    return found
120
+
121
+ def classify_error(sql: str, error_msg: str | None = None, *, timed_out: bool = False):
122
+ s = (sql or "").lower()
123
+ m = (error_msg or "").lower()
124
+ if timed_out or "interrupted" in m or "timeout" in m: return "timeout"
125
+ if not s.strip().startswith(("select", "with")): return "syntax_error"
126
+ if " join " in f" {s} " and " on " not in f" {s} ": return "missing_join"
127
+ if " where " in f" {s} " and not any(op in s for op in ["=", ">", "<", " in ", " like ", " between ", " is null", " is not null"]): return "wrong_where"
128
+ if ("is null" in s or "is not null" in s) and ("no such column" in m or "misuse" in m): return "null_handling"
129
+ if "no such table" in m: return "missing_table"
130
+ if "no such column" in m: return "missing_column"
131
+ if "ambiguous column name" in m: return "ambiguous_column"
132
+ if "datatype mismatch" in m or "type mismatch" in m: return "type_mismatch"
133
+ if "misuse of aggregate" in m or "misuse of aggregate function" in m: return "wrong_aggregation"
134
+ if "syntax error" in m: return "syntax_error"
135
+ if "near" in m and "syntax error" in m: return "syntax_error"
136
+ if "runtime" in m or "constraint failed" in m: return "runtime_error"
137
+ return "other"
138
+
139
def get_hint(error_type):
    """Map a classify_error() category to a one-line remediation hint.

    Unknown categories (including the internal 'gibberish'/'blocked_dml'
    buckets logged by run_query) fall back to a generic hint.
    """
    hints = {
        "missing_join": "Check JOIN conditions between tables.", "wrong_aggregation": "Use proper aggregation like avg(column).",
        "wrong_where": "Check WHERE condition syntax.", "syntax_error": "Ensure SQL starts with SELECT.",
        "missing_table": "Use only tables from the provided schema.", "missing_column": "Use only columns from the provided schema.",
        "ambiguous_column": "Disambiguate by using table.column.", "timeout": "Query took too long; simplify joins.", "other": "Review SQL logic.",
        # Categories classify_error can emit that previously only got the
        # generic fallback:
        "null_handling": "Use IS NULL / IS NOT NULL on existing columns.",
        "type_mismatch": "Compare columns against values of the matching type.",
        "runtime_error": "Check constraints and data values used by the query.",
    }
    return hints.get(error_type, "Review query.")
147
+
148
def is_relevant_to_schema(question, db_id):
    """Heuristic guard: does *question* mention anything from the db schema?

    Returns True ("let it through") whenever the check cannot be performed:
    no encoder available, schema lookup failure, or a question consisting
    only of stop-words/digits. A naive plural fallback is applied per word
    ('artists' matches schema word 'artist').
    """
    if schema_encoder is None:
        return True
    try:
        raw_schema = schema_encoder.structured_schema(db_id).lower()
    except Exception:
        # Was a bare `except:`, which also swallowed SystemExit and
        # KeyboardInterrupt — narrowed while keeping the fail-open behavior.
        return True
    schema_words = set(re.findall(r'[a-z0-9_]+', raw_schema))
    q_words = re.findall(r'[a-z0-9_]+', question.lower())
    stop_words = {"show", "list", "all", "what", "is", "the", "how", "many", "count", "find", "get", "me", "a", "an", "of", "in", "for", "from", "with", "which", "are", "there", "give", "tell", "details", "info", "data", "everything"}
    meaningful_q_words = [w for w in q_words if w not in stop_words and not w.isdigit()]
    if not meaningful_q_words:
        return True
    for word in meaningful_q_words:
        # Crude singularisation: strip one trailing 's'.
        singular_word = word[:-1] if word.endswith('s') else word
        if word in schema_words or singular_word in schema_words:
            return True
    return False
161
+
162
def run_query(method, sample_q, custom_q, db_id):
    """Full inference pipeline for one UI request.

    Steps: lazily load the engine, validate/sanitise the question, generate
    SQL, post-process it with regex heuristics, execute against sqlite, and
    log telemetry. Always returns a 3-tuple of
    (sql_text, pandas.DataFrame, explanation_text) matching the three output
    widgets bound to the Run button.
    """
    global quant_engine

    # 🚨 LAZY LOADING: We load the heavy AI model ONLY when the button is clicked.
    if quant_engine is None:
        print(f"First request detected! Loading AI model from {DEFAULT_QUANT_ARTIFACT}...", flush=True)
        try:
            quant_engine = get_quant_engine(DEFAULT_QUANT_ARTIFACT, use_constrained=False, exec_workers=8, use_cache=True)
            if quant_engine is None:
                return "-- ❌ ENGINE CRASH", pd.DataFrame(columns=["Error"]), "Failed to load model. Did you move the tokenizer files and add config.json to int8_dynamic/?"
        except Exception as e:
            return f"-- ❌ ENGINE CRASH\n-- {str(e)}", pd.DataFrame(columns=["Error Status"]), f"Critical failure loading model: {e}"

    # Append one record to the module-level error log (read by the dashboard tab).
    def _log(error_type: str, *, question: str, db_id_val: str, sql: str = "", error_msg: str = "") -> None:
        _QUERY_LOG.append({"t": time.time(), "db_id": str(db_id_val), "question": str(question), "sql": str(sql), "error_type": str(error_type), "error_msg": str(error_msg)})

    # Append perf telemetry, trimming the oldest 200 entries past 1000 to bound memory.
    def _perf_log(payload: dict) -> None:
        _PERF_LOG.append(payload)
        if len(_PERF_LOG) > 1000: del _PERF_LOG[:200]

    raw_question = sample_q if method == "💡 Pick a Sample" else custom_q

    # --- Input validation guards (each early-returns the 3-tuple) ---
    if not raw_question or str(raw_question).strip() == "":
        return "-- No input provided", pd.DataFrame(columns=["Warning"]), "⚠️ Please enter a question."
    if not db_id or str(db_id).strip() == "":
        return "-- No database selected", pd.DataFrame(columns=["Warning"]), "⚠️ Please select a database."

    # Light typo repair on common command words before any other check.
    typo_corrections = [(r'\bshaw\b', 'show'), (r'\bshw\b', 'show'), (r'\bsho\b', 'show'), (r'\blsit\b', 'list'), (r'\blis\b', 'list'), (r'\bfidn\b', 'find'), (r'\bfnd\b', 'find'), (r'\bgte\b', 'get')]
    question = str(raw_question)
    for bad, good in typo_corrections: question = re.sub(bad, good, question, flags=re.IGNORECASE)
    q_lower = question.strip().lower()

    # Single-word inputs are treated as gibberish.
    if len(q_lower.split()) < 2:
        _log("gibberish", question=question, db_id_val=str(db_id), error_msg="gibberish filtered")
        return "-- Input Blocked", pd.DataFrame(columns=["Warning"]), "⚠️ Please enter a clear, meaningful natural language question (more than one word)."

    # Block anything that even mentions a DML/DDL verb — read-only demo.
    if re.search(r'\b(delete|update|insert|drop|alter|truncate)\b', q_lower):
        _log("blocked_dml", question=question, db_id_val=str(db_id), error_msg="DML blocked")
        return "-- ❌ BLOCKED: Data Modification", pd.DataFrame(columns=["Security Alert"]), "🛑 Security Alert: Modifying or deleting data is strictly prohibited."

    # Reject questions with no lexical overlap with the selected schema.
    if not is_relevant_to_schema(question, db_id):
        _log("out_of_domain", question=question, db_id_val=str(db_id), error_msg="out of domain")
        return "-- ❌ BLOCKED: Out of Domain", pd.DataFrame(columns=["Domain Alert"]), f"🛑 Relevance Alert: I don't see anything related to your question in the '{db_id}' schema."

    start_time = time.time()
    t0 = time.perf_counter()
    ui_warnings = ""

    # --- Generation: try the richer signature first, fall back for older engines ---
    try:
        try:
            result = quant_engine.ask(question, str(db_id), num_beams=4, max_new_tokens=120, timeout_s=2.0)
        except TypeError:
            result = quant_engine.ask(question, str(db_id))
    except Exception as e:
        _log("backend_crash", question=question, db_id_val=str(db_id), error_msg=str(e))
        return f"-- ❌ BACKEND CRASH\n-- {str(e)}", pd.DataFrame(columns=["Error Status"]), f"❌ CRITICAL BACKEND CRASH:\n{str(e)}"

    final_sql = str(result.get("sql", ""))
    model_sql = final_sql  # keep the untouched model output as a fallback

    # --- Heuristic post-processing: a number in "show/list N ..." questions is
    # a row limit, not a value filter; the model often conflates the two. ---
    num_match = re.search(r'\b(?:show|list|top|limit|get|first|last|sample|of)\s+(?:[a-zA-Z_]+\s+)?(\d+)\b', q_lower)
    if not num_match and q_lower.startswith(("show", "list", "get")):
        num_match = re.search(r'\b(\d+)\b', q_lower)

    if num_match and final_sql:
        limit_val = num_match.group(1)
        # Strip spurious filters that compare against the requested row count.
        final_sql = re.sub(rf"(?i)\s*(?:where|having|and)?\s*count\s*\(\s*\*\s*\)\s*=\s*{limit_val}", "", final_sql)
        final_sql = re.sub(rf"(?i)\s*(?:where|and)\s+[a-zA-Z0-9_.]+\s*=\s*['\"]?{limit_val}['\"]?", "", final_sql)
        # Clean up any WHERE left dangling by the removals above.
        final_sql = re.sub(r"(?i)\s*where\s*$", "", final_sql)
        final_sql = re.sub(r"(?i)\s*where\s+(group by|order by|limit)", r" \1", final_sql)

        # Drop aggregation scaffolding unless the question actually asked for it.
        agg_kws = ["most", "top", "highest", "lowest", "count", "many", "group", "frequent", "popular"]
        if not any(k in q_lower for k in agg_kws):
            final_sql = re.sub(r"(?i)\s*group by\s+[a-zA-Z0-9_.]+\s*order by\s+count\(\*\)\s*(?:desc|asc)?", "", final_sql)
            final_sql = re.sub(r"(?i)\s*order by\s+count\(\*\)\s*(?:desc|asc)?", "", final_sql)
            final_sql = re.sub(r"(?i),\s*count\(\*\)", "", final_sql)
            final_sql = re.sub(r"(?i)count\(\*\)\s*,", "", final_sql)

        # GROUP BY with no aggregate left is meaningless — remove it.
        if "group by" in final_sql.lower() and not re.search(r'(?i)\b(count|sum|avg|max|min)\b\(', final_sql):
            final_sql = re.sub(r"(?i)\s*group by\s+[a-zA-Z0-9_.]+", "", final_sql)

        # Finally enforce the requested row limit.
        if "limit" not in final_sql.lower():
            final_sql = f"{final_sql.strip().rstrip(';')} LIMIT {limit_val}"

    # --- Execution ---
    from src.sql_validator import validate_sql_schema
    db_path = get_db_path(str(db_id))

    # Strict schema validation is advisory only: it feeds the perf block.
    try: strict_valid, _ = validate_sql_schema(final_sql, db_path)
    except Exception: strict_valid = False

    error_msg = None
    rows, cols = [], []
    sqlite_success = False

    try:
        rows, cols = quant_engine._execute_one(final_sql, db_path, timeout_s=2.0)
        sqlite_success = True
    except Exception as e:
        error_msg = str(e)
        sqlite_success = False

    # If the post-processed SQL failed, retry with the raw model output.
    if not sqlite_success and model_sql and model_sql != final_sql:
        try:
            alt_rows, alt_cols = quant_engine._execute_one(model_sql, db_path, timeout_s=2.0)
            final_sql = model_sql
            rows, cols = alt_rows, alt_cols
            sqlite_success = True
            error_msg = None
        except Exception: pass

    valid = sqlite_success

    if error_msg or not valid:
        et = classify_error(final_sql, str(error_msg or ""), timed_out=("interrupted" in str(error_msg or "").lower()))
        _log(et, question=str(question), db_id_val=str(db_id), sql=str(final_sql), error_msg=str(error_msg or "Execution failed"))

    latency = round(time.time() - start_time, 3)
    t1 = time.perf_counter()

    # --- Telemetry for the performance footer and the diagnostics tab ---
    engine_stats_after = quant_engine.stats() if hasattr(quant_engine, 'stats') else {}

    perf = {
        "db_id": str(db_id), "use_constrained_decoding": False, "num_beams": 4,
        "latency_total_ms": round((t1 - t0) * 1000.0, 2), "constraint_ok": bool(strict_valid), "has_error": bool(error_msg),
        "exec_cache_hit_rate": float(engine_stats_after.get("exec_cache_hit_rate", 0.0) or 0.0),
    }
    _perf_log(perf)

    # Rolling statistics over the last 50 requests.
    window = _PERF_LOG[-50:]
    avg_ms = sum(float(x.get("latency_total_ms", 0.0) or 0.0) for x in window) / len(window) if window else 0.0
    constraint_rate = sum(1 for x in window if x.get("constraint_ok")) / len(window) if window else 0.0

    perf_block = (
        "\n\n---\nPerformance (task impact)\n"
        f"- Total latency (ms): {perf['latency_total_ms']}\n"
        f"- Strict Python Validator OK (Task 3): {perf['constraint_ok']}\n"
        f"- Exec cache hit-rate (Task 1/5): {round(perf['exec_cache_hit_rate'], 3)}\n"
        f"- Rolling avg latency last 50 (ms): {round(avg_ms, 2)}\n"
        f"- Rolling constraint rate last 50: {round(constraint_rate, 3)}\n"
    )

    # --- Failure path: show the error, a classified hint, and bump fail counters ---
    if error_msg or not valid:
        display_sql = final_sql if final_sql.strip() else "-- ❌ INVALID SQL"
        explanation = f"{ui_warnings}❌ Error Details:\n\n"
        if error_msg: explanation += f"{error_msg}\n\n"

        error_type = classify_error(final_sql, str(error_msg or ""))
        explanation += f"Error Type: {error_type}\nHint: {get_hint(error_type)}"
        explanation += perf_block
        ops = sql_ops(final_sql)
        for op in ops:
            if op in _OP_STATS: _OP_STATS[op]["fail"] += 1
        return display_sql, pd.DataFrame(columns=["Execution Notice"]), explanation

    # --- Success path: build the result table, explanation, and ok counters ---
    safe_cols = cols if cols else ["Result"]
    explanation = f"{ui_warnings}✅ Query executed successfully\n\nRows returned: {len(rows)}\nExecution Time: {latency} sec\n\n{explain_sql(final_sql)}{perf_block}"

    ops = sql_ops(final_sql)
    for op in ops:
        if op in _OP_STATS: _OP_STATS[op]["ok"] += 1
    _SUCCESS_LOG.append({"t": time.time(), "db_id": str(db_id), "question": question, "sql": final_sql, "ops": ops})

    # Explain under-full result sets so a small table isn't mistaken for a bug.
    limit_match = re.search(r'LIMIT\s+(\d+)', final_sql, re.IGNORECASE)
    if limit_match and len(rows) < int(limit_match.group(1)):
        explanation += f"\n\nℹ️ Query allowed up to {int(limit_match.group(1))} rows but only {len(rows)} matched."

    return final_sql, pd.DataFrame(rows, columns=safe_cols), explanation
330
+
331
def task1_benchmark(n_rollouts: int, max_workers: int) -> Iterator[tuple[str, str]]:
    """Run the Task 1 benchmark script as a subprocess, streaming its output.

    Yields (log_text, plot_html) tuples so the bound Gradio outputs update
    live; the final yield carries the full log plus the generated plot as an
    inline base64 image (or a placeholder if no plot was produced).
    """
    project_root = str(PROJECT_ROOT)
    env = os.environ.copy()
    # Make project imports resolvable inside the child process.
    env["PYTHONPATH"] = project_root + (os.pathsep + env["PYTHONPATH"] if env.get("PYTHONPATH") else "")
    # Headless matplotlib with a writable config dir (Spaces sandboxes /home).
    env.setdefault("MPLBACKEND", "Agg")
    env.setdefault("MPLCONFIGDIR", "/tmp/mplconfig")
    try: os.makedirs(env["MPLCONFIGDIR"], exist_ok=True)
    except Exception: pass

    # -u forces unbuffered child stdout so streaming actually streams.
    cmd = [sys.executable, "-u", "scripts/benchmark_parallel_reward.py", "--n", str(int(n_rollouts)), "--max-workers", str(int(max_workers)), "--skip-profile"]
    proc = subprocess.Popen(cmd, cwd=project_root, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
    last_yield = time.perf_counter()
    lines: list[str] = []
    yield "Running Task 1 benchmark...\n", "<i>Running...</i>"

    assert proc.stdout is not None
    for line in proc.stdout:
        lines.append(line)
        now = time.perf_counter()
        # Throttle UI updates to at most twice per second, last 200 lines only.
        if now - last_yield >= 0.5:
            last_yield = now
            yield "".join(lines[-200:]).strip(), "<i>Running...</i>"

    proc.wait()
    out = "".join(lines).strip()

    plot_path = str(PROJECT_ROOT / "results" / "task1_plot.png")
    if os.path.exists(plot_path):
        try:
            # Embed the PNG directly so no static-file route is needed.
            b64 = base64.b64encode(Path(plot_path).read_bytes()).decode("ascii")
            yield out, f"<img src='data:image/png;base64,{b64}' style='max-width: 100%; border: 1px solid #e2e8f0; border-radius: 8px;' />"
            return
        except Exception:
            yield out, f"<pre>{plot_path}</pre>"
            return

    yield out, "<i>No plot generated</i>"
368
+
369
def task2_dashboard_structured():
    """Build the error-dashboard widgets.

    Returns (error-count DataFrame, recent-errors DataFrame, dropdown update)
    derived from the in-memory _QUERY_LOG, most frequent error type first.
    """
    if not _QUERY_LOG:
        return (
            pd.DataFrame(columns=["error_type", "count", "hint"]),
            pd.DataFrame(columns=["time", "db_id", "error_type", "question", "error_msg"]),
            gr.update(choices=[], value=None),
        )

    # Tally error types over the most recent 1000 log entries.
    tallies = {}
    for entry in _QUERY_LOG[-1000:]:
        bucket = entry.get("error_type") or "other"
        tallies[bucket] = tallies.get(bucket, 0) + 1
    ordered = sorted(tallies.items(), key=lambda kv: (-kv[1], kv[0]))
    summary_rows = [{"error_type": et, "count": int(n), "hint": get_hint(et)} for et, n in ordered]
    counts_df = pd.DataFrame(summary_rows)

    # Most recent 100 raw entries, with human-readable timestamps.
    recent_rows = []
    for entry in _QUERY_LOG[-100:]:
        stamp = entry.get("t")
        try:
            stamp_s = time.strftime("%H:%M:%S", time.localtime(float(stamp))) if stamp else ""
        except Exception:
            stamp_s = ""
        recent_rows.append({
            "time": stamp_s,
            "db_id": entry.get("db_id", ""),
            "error_type": entry.get("error_type", ""),
            "question": entry.get("question", ""),
            "error_msg": entry.get("error_msg", ""),
        })
    recent_df = pd.DataFrame(recent_rows)

    dropdown_choices = [str(r["error_type"]) for r in summary_rows]
    initial = dropdown_choices[0] if dropdown_choices else None
    return counts_df, recent_df, gr.update(choices=dropdown_choices, value=initial)
393
+
394
def task2_error_examples(error_type: str) -> str:
    """Show up to three recent logged examples of *error_type*, plus its hint."""
    if not error_type:
        return ""
    hint = get_hint(error_type)
    wanted = str(error_type)
    # Walk the log newest-first, collecting at most three matching entries.
    matches = []
    for entry in reversed(_QUERY_LOG):
        if (entry.get("error_type") or "") == wanted:
            matches.append(entry)
            if len(matches) == 3:
                break
    if not matches:
        return f"Error type: {error_type}\nHint: {hint}\n\nNo examples yet."
    parts = [f"Error type: {error_type}", f"Hint: {hint}", ""]
    for idx, entry in enumerate(matches, 1):
        parts.extend([
            f"Example {idx}",
            f"DB: {entry.get('db_id','')}",
            f"Q: {entry.get('question','')}",
            f"SQL: {entry.get('sql','')}",
            f"Msg: {entry.get('error_msg','')}",
            "",
        ])
    return "\n".join(parts).strip()
403
+
404
+ def _plot_op_stats_html() -> str:
405
+ try:
406
+ import matplotlib.pyplot as plt
407
+ labels = list(_OP_STATS.keys())
408
+ oks = [int(_OP_STATS[k]["ok"]) for k in labels]
409
+ fails = [int(_OP_STATS[k]["fail"]) for k in labels]
410
+
411
+ fig, ax = plt.subplots(figsize=(9, 3.5))
412
+ x = list(range(len(labels)))
413
+ ax.bar(x, oks, label="ok", color="#16a34a")
414
+ ax.bar(x, fails, bottom=oks, label="fail", color="#dc2626")
415
+ ax.set_xticks(x)
416
+ ax.set_xticklabels(labels, rotation=30, ha="right")
417
+ ax.set_title("Success/Failure by SQL operation")
418
+ ax.legend()
419
+ fig.tight_layout()
420
+
421
+ buf = io.BytesIO()
422
+ fig.savefig(buf, format="png", dpi=160)
423
+ plt.close(fig)
424
+ b64 = base64.b64encode(buf.getvalue()).decode("ascii")
425
+ return f"<img src='data:image/png;base64,{b64}' style='max-width: 100%; border: 1px solid #e2e8f0; border-radius: 8px;' />"
426
+ except Exception as e: return f"<pre>Plot error: {e}</pre>"
427
+
428
def task2_ops_table():
    """Summarise the per-clause counters as (DataFrame, plot HTML)."""
    summary = []
    for clause, counters in _OP_STATS.items():
        ok = int(counters.get("ok", 0))
        fail = int(counters.get("fail", 0))
        total = ok + fail
        # Guard against division by zero for clauses never seen yet.
        rate = (ok / total) if total else 0.0
        summary.append({"op": clause, "ok": ok, "fail": fail, "total": total, "success_rate": rate})
    return pd.DataFrame(summary), _plot_op_stats_html()
436
+
437
def toggle_input_method(method, current_sample):
    """Switch widget visibility between sample-picker and free-text modes.

    Returns updates for (sample_dropdown, type_own_warning, custom_question,
    db_id) in that order.
    """
    if method != "💡 Pick a Sample":
        # Custom mode: hide samples, show warning + textbox, unlock the DB picker.
        return (gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(interactive=True))
    # Sample mode: lock the DB picker to the sample's paired database.
    matched_db = next((db for q, db in SAMPLES if q == current_sample), "chinook_1")
    return (gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(value=matched_db, interactive=False))
442
+
443
def load_sample(selected_question):
    """Auto-select the database paired with the chosen sample question."""
    if not selected_question:
        return gr.update()
    for question, database in SAMPLES:
        if question == selected_question:
            return gr.update(value=database)
    return gr.update(value="chinook_1")
446
+
447
def clear_inputs():
    """Reset every inference-tab widget back to its initial state.

    Output order matches the clear button binding: input_method,
    sample_dropdown, type_own_warning, custom_question, db_id, final_sql,
    result_table, explanation.
    """
    return (
        gr.update(value="💡 Pick a Sample"),
        gr.update(value=SAMPLE_QUESTIONS[0], visible=True),
        gr.update(visible=False),
        gr.update(value="", visible=False),
        gr.update(value="chinook_1", interactive=False),
        "",
        pd.DataFrame(),
        "",
    )
449
+
450
def update_schema(db_id):
    """Render the database schema as scrollable HTML.

    Table lines matching ``name(cols)`` get a bold uppercase table name with
    lowercase columns; other lines render as plain rows. Returns "" when no
    db is chosen or no encoder is available, and an inline error div on
    failure.
    """
    if not db_id or schema_encoder is None:
        return ""
    try:
        raw_schema = schema_encoder.structured_schema(db_id)
        pieces = ["<div style='max-height: 250px; overflow-y: auto; background: #f8fafc; padding: 12px; border-radius: 8px; border: 1px solid #e2e8f0; font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace; font-size: 0.9em; line-height: 1.6;'>"]
        for raw_line in raw_schema.strip().split('\n'):
            entry = raw_line.strip()
            if not entry:
                continue
            table_match = re.search(r'^([a-zA-Z0-9_]+)\s*\((.*)\)', entry)
            if table_match:
                pieces.append(f"<div style='margin-bottom: 8px;'><strong style='color: #0f172a; font-size: 1.05em; font-weight: 800;'>{table_match.group(1).upper()}</strong> <span style='color: #64748b;'>( {table_match.group(2).lower()} )</span></div>")
            else:
                pieces.append(f"<div style='color: #475569;'>{entry}</div>")
        pieces.append("</div>")
        return "".join(pieces)
    except Exception as e:
        return f"<div style='color: red;'>Error loading schema: {str(e)}</div>"
464
+
465
# =========================
# UI LAYOUT
# =========================
with gr.Blocks(title="Text-to-SQL RLHF") as demo:
    # Static header banner.
    gr.HTML("""
    <div style="text-align: center; background-color: #e0e7ff; padding: 20px; border-radius: 10px; margin-bottom: 20px; border: 1px solid #c7d2fe;">
    <h1 style="color: #3730a3; margin-top: 0; margin-bottom: 10px; font-size: 2.2em;"> Text-to-SQL using RLHF + Execution Reward</h1>
    <p style="color: #4f46e5; font-size: 1.1em; margin: 0;">Convert Natural Language to SQL, strictly validated and safely executed on local SQLite databases.</p>
    </div>
    """)

    # Databases selectable in the UI; each must resolve via get_db_path().
    DBS = sorted(["flight_1", "student_assessment", "store_1", "bike_1", "book_2", "chinook_1", "academic", "aircraft", "car_1", "cinema", "club_1", "csu_1", "college_1", "college_2", "company_1", "company_employee", "customer_complaints", "department_store", "employee_hire_evaluation", "museum_visit", "products_for_hire", "restaurant_1", "school_finance", "shop_membership", "small_bank_1", "student_1", "tvshow", "voter_1", "world_1"])

    with gr.Tabs():
        with gr.Tab("Inference"):
            with gr.Row():
                # Left column: input configuration + schema preview.
                with gr.Column(scale=1):
                    gr.Markdown("### 1. Configuration & Input")
                    input_method = gr.Radio(choices=["💡 Pick a Sample", "✍️ Type my own"], value="💡 Pick a Sample", label="How do you want to ask?")
                    sample_dropdown = gr.Dropdown(choices=SAMPLE_QUESTIONS, value=SAMPLE_QUESTIONS[0], label="Select a Sample Question", info="The database will be selected automatically.", visible=True)
                    type_own_warning = gr.Markdown("**⚠️ Please select a Database first, then type your custom question below:**", visible=False)
                    gr.Markdown("---")
                    # Locked while in sample mode; toggle_input_method unlocks it.
                    db_id = gr.Dropdown(choices=DBS, value="chinook_1", label="Select Database", interactive=False)
                    custom_question = gr.Textbox(label="Ask your Custom Question", placeholder="Type your own question here...", lines=3, visible=False)

                    gr.Markdown("#### 📋 Database Structure")
                    gr.HTML("<p style='font-size: 0.85em; color: #64748b; margin-top: -10px; margin-bottom: 5px;'>Use these exact names! Table names are <strong>Dark</strong>, Column names are <span style='color: #94a3b8;'>Light</span>.</p>")
                    schema_display = gr.HTML(value=update_schema("chinook_1"))

                    with gr.Row():
                        clear_btn = gr.Button("🗑️ Clear", variant="secondary")
                        run_btn = gr.Button(" Generate & Run SQL", variant="primary")

                # Right column: SQL, result table, and explanation outputs.
                with gr.Column(scale=2):
                    gr.Markdown("### 2. Execution Results")
                    final_sql = gr.Code(language="sql", label="Final Executed SQL")
                    result_table = gr.Dataframe(label="Query Result Table", interactive=False, wrap=True)
                    explanation = gr.Textbox(label="AI Explanation + Execution Details", lines=8)

        with gr.Tab("Diagnostics"):
            gr.Markdown("## Diagnostics & Telemetry")

            with gr.Accordion("Task 1: Parallel Reward Benchmark", open=False):
                gr.Markdown("*(Simulates the heavy RLHF training workload by running hundreds of complex SQL queries concurrently to test SQLite multi-threading performance.)*")
                t1_n = gr.Number(value=20, precision=0, label="Rollouts (n)")
                t1_workers = gr.Number(value=10, precision=0, label="Max workers")
                t1_run = gr.Button("Run Task 1 benchmark")
                t1_out = gr.Textbox(label="Output", lines=12)
                t1_plot = gr.HTML(label="Plot (if generated)")
                # task1_benchmark is a generator, so the outputs stream live.
                t1_run.click(fn=task1_benchmark, inputs=[t1_n, t1_workers], outputs=[t1_out, t1_plot])

            with gr.Accordion("Task 2: Error Dashboard", open=True):
                gr.Markdown("*(Live telemetry tracking the most common SQL failures. Populates automatically when queries fail in the Inference tab.)*")
                t2_refresh = gr.Button("Refresh dashboard")
                t2_counts = gr.Dataframe(label="Error counts", interactive=False, wrap=True)
                t2_recent = gr.Dataframe(label="Recent errors", interactive=False, wrap=True)
                t2_type = gr.Dropdown(choices=[], value=None, label="Select error type")
                t2_examples = gr.Textbox(label="Examples + hint", lines=10)

                t2_refresh.click(fn=task2_dashboard_structured, inputs=[], outputs=[t2_counts, t2_recent, t2_type])
                t2_type.change(fn=task2_error_examples, inputs=[t2_type], outputs=[t2_examples])

            with gr.Accordion("Task 2: Clause Telemetry", open=False):
                gr.Markdown("*(Analyzes which specific SQL clauses—SELECT, WHERE, JOIN, etc.—are most prone to errors during natural language generation.)*")
                t2_ops_refresh = gr.Button("Refresh SQL-op stats")
                t2_ops_tbl = gr.Dataframe(label="Success/failure by op", interactive=False, wrap=True)
                t2_ops_plot = gr.HTML(label="Op plot")
                t2_ops_refresh.click(fn=task2_ops_table, inputs=[], outputs=[t2_ops_tbl, t2_ops_plot])

    # EVENT BINDING: The .then() forces the diagnostic tab to update live in the background!
    input_method.change(fn=toggle_input_method, inputs=[input_method, sample_dropdown], outputs=[sample_dropdown, type_own_warning, custom_question, db_id])
    sample_dropdown.change(fn=load_sample, inputs=[sample_dropdown], outputs=[db_id])
    db_id.change(fn=update_schema, inputs=[db_id], outputs=[schema_display])

    run_btn.click(
        fn=run_query,
        inputs=[input_method, sample_dropdown, custom_question, db_id],
        outputs=[final_sql, result_table, explanation]
    ).then(
        fn=task2_dashboard_structured, inputs=[], outputs=[t2_counts, t2_recent, t2_type]
    ).then(
        fn=task2_ops_table, inputs=[], outputs=[t2_ops_tbl, t2_ops_plot]
    )

    clear_btn.click(fn=clear_inputs, inputs=[], outputs=[input_method, sample_dropdown, type_own_warning, custom_question, db_id, final_sql, result_table, explanation])
550
+
551
if __name__ == "__main__":
    # Bind address/port are overridable via the standard Gradio env vars.
    server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
    base_port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
    max_retries = 10

    # Walk forward from the base port until a launch succeeds.
    for port in range(base_port, base_port + max_retries):
        try:
            print(f"Attempting to start Gradio UI on {server_name}:{port}...", flush=True)
            demo.launch(server_name=server_name, server_port=port)
            break  # If successful, exit the loop
        except OSError as e:
            if "Cannot find empty port" in str(e) or "Address already in use" in str(e):
                print(f"⚠️ Port {port} is in use, trying next port...")
                continue
            else:
                # If it's a different OSError, raise it normally
                raise e
    else:
        # for/else: runs only when every candidate port failed without a break.
        print(f"❌ Could not find an open port between {base_port} and {base_port + max_retries - 1}.")
db.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb0e6ef12110c4c9808205cb210a35b5c4412397e15f47ed437e739e161d4213
3
+ size 53803466
int8_dynamic/meta.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "mode": "int8_dynamic",
3
+ "base_model": "Salesforce/codet5-base",
4
+ "adapter_path": "checkpoints/best_rlhf_model_2",
5
+ "created_at_s": 1774418718.320342,
6
+ "estimated_model_bytes": 98804736
7
+ }
int8_dynamic/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f398e044cd49fc84553b746d26ad79beb1dd565d90cf8f6f5e50d27f48d08228
3
+ size 322871519
int8_dynamic/tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
int8_dynamic/tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,753 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ {
4
+ "content": "<extra_id_99>",
5
+ "lstrip": true,
6
+ "normalized": true,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ },
10
+ {
11
+ "content": "<extra_id_98>",
12
+ "lstrip": true,
13
+ "normalized": true,
14
+ "rstrip": false,
15
+ "single_word": false
16
+ },
17
+ {
18
+ "content": "<extra_id_97>",
19
+ "lstrip": true,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ {
25
+ "content": "<extra_id_96>",
26
+ "lstrip": true,
27
+ "normalized": true,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ },
31
+ {
32
+ "content": "<extra_id_95>",
33
+ "lstrip": true,
34
+ "normalized": true,
35
+ "rstrip": false,
36
+ "single_word": false
37
+ },
38
+ {
39
+ "content": "<extra_id_94>",
40
+ "lstrip": true,
41
+ "normalized": true,
42
+ "rstrip": false,
43
+ "single_word": false
44
+ },
45
+ {
46
+ "content": "<extra_id_93>",
47
+ "lstrip": true,
48
+ "normalized": true,
49
+ "rstrip": false,
50
+ "single_word": false
51
+ },
52
+ {
53
+ "content": "<extra_id_92>",
54
+ "lstrip": true,
55
+ "normalized": true,
56
+ "rstrip": false,
57
+ "single_word": false
58
+ },
59
+ {
60
+ "content": "<extra_id_91>",
61
+ "lstrip": true,
62
+ "normalized": true,
63
+ "rstrip": false,
64
+ "single_word": false
65
+ },
66
+ {
67
+ "content": "<extra_id_90>",
68
+ "lstrip": true,
69
+ "normalized": true,
70
+ "rstrip": false,
71
+ "single_word": false
72
+ },
73
+ {
74
+ "content": "<extra_id_89>",
75
+ "lstrip": true,
76
+ "normalized": true,
77
+ "rstrip": false,
78
+ "single_word": false
79
+ },
80
+ {
81
+ "content": "<extra_id_88>",
82
+ "lstrip": true,
83
+ "normalized": true,
84
+ "rstrip": false,
85
+ "single_word": false
86
+ },
87
+ {
88
+ "content": "<extra_id_87>",
89
+ "lstrip": true,
90
+ "normalized": true,
91
+ "rstrip": false,
92
+ "single_word": false
93
+ },
94
+ {
95
+ "content": "<extra_id_86>",
96
+ "lstrip": true,
97
+ "normalized": true,
98
+ "rstrip": false,
99
+ "single_word": false
100
+ },
101
+ {
102
+ "content": "<extra_id_85>",
103
+ "lstrip": true,
104
+ "normalized": true,
105
+ "rstrip": false,
106
+ "single_word": false
107
+ },
108
+ {
109
+ "content": "<extra_id_84>",
110
+ "lstrip": true,
111
+ "normalized": true,
112
+ "rstrip": false,
113
+ "single_word": false
114
+ },
115
+ {
116
+ "content": "<extra_id_83>",
117
+ "lstrip": true,
118
+ "normalized": true,
119
+ "rstrip": false,
120
+ "single_word": false
121
+ },
122
+ {
123
+ "content": "<extra_id_82>",
124
+ "lstrip": true,
125
+ "normalized": true,
126
+ "rstrip": false,
127
+ "single_word": false
128
+ },
129
+ {
130
+ "content": "<extra_id_81>",
131
+ "lstrip": true,
132
+ "normalized": true,
133
+ "rstrip": false,
134
+ "single_word": false
135
+ },
136
+ {
137
+ "content": "<extra_id_80>",
138
+ "lstrip": true,
139
+ "normalized": true,
140
+ "rstrip": false,
141
+ "single_word": false
142
+ },
143
+ {
144
+ "content": "<extra_id_79>",
145
+ "lstrip": true,
146
+ "normalized": true,
147
+ "rstrip": false,
148
+ "single_word": false
149
+ },
150
+ {
151
+ "content": "<extra_id_78>",
152
+ "lstrip": true,
153
+ "normalized": true,
154
+ "rstrip": false,
155
+ "single_word": false
156
+ },
157
+ {
158
+ "content": "<extra_id_77>",
159
+ "lstrip": true,
160
+ "normalized": true,
161
+ "rstrip": false,
162
+ "single_word": false
163
+ },
164
+ {
165
+ "content": "<extra_id_76>",
166
+ "lstrip": true,
167
+ "normalized": true,
168
+ "rstrip": false,
169
+ "single_word": false
170
+ },
171
+ {
172
+ "content": "<extra_id_75>",
173
+ "lstrip": true,
174
+ "normalized": true,
175
+ "rstrip": false,
176
+ "single_word": false
177
+ },
178
+ {
179
+ "content": "<extra_id_74>",
180
+ "lstrip": true,
181
+ "normalized": true,
182
+ "rstrip": false,
183
+ "single_word": false
184
+ },
185
+ {
186
+ "content": "<extra_id_73>",
187
+ "lstrip": true,
188
+ "normalized": true,
189
+ "rstrip": false,
190
+ "single_word": false
191
+ },
192
+ {
193
+ "content": "<extra_id_72>",
194
+ "lstrip": true,
195
+ "normalized": true,
196
+ "rstrip": false,
197
+ "single_word": false
198
+ },
199
+ {
200
+ "content": "<extra_id_71>",
201
+ "lstrip": true,
202
+ "normalized": true,
203
+ "rstrip": false,
204
+ "single_word": false
205
+ },
206
+ {
207
+ "content": "<extra_id_70>",
208
+ "lstrip": true,
209
+ "normalized": true,
210
+ "rstrip": false,
211
+ "single_word": false
212
+ },
213
+ {
214
+ "content": "<extra_id_69>",
215
+ "lstrip": true,
216
+ "normalized": true,
217
+ "rstrip": false,
218
+ "single_word": false
219
+ },
220
+ {
221
+ "content": "<extra_id_68>",
222
+ "lstrip": true,
223
+ "normalized": true,
224
+ "rstrip": false,
225
+ "single_word": false
226
+ },
227
+ {
228
+ "content": "<extra_id_67>",
229
+ "lstrip": true,
230
+ "normalized": true,
231
+ "rstrip": false,
232
+ "single_word": false
233
+ },
234
+ {
235
+ "content": "<extra_id_66>",
236
+ "lstrip": true,
237
+ "normalized": true,
238
+ "rstrip": false,
239
+ "single_word": false
240
+ },
241
+ {
242
+ "content": "<extra_id_65>",
243
+ "lstrip": true,
244
+ "normalized": true,
245
+ "rstrip": false,
246
+ "single_word": false
247
+ },
248
+ {
249
+ "content": "<extra_id_64>",
250
+ "lstrip": true,
251
+ "normalized": true,
252
+ "rstrip": false,
253
+ "single_word": false
254
+ },
255
+ {
256
+ "content": "<extra_id_63>",
257
+ "lstrip": true,
258
+ "normalized": true,
259
+ "rstrip": false,
260
+ "single_word": false
261
+ },
262
+ {
263
+ "content": "<extra_id_62>",
264
+ "lstrip": true,
265
+ "normalized": true,
266
+ "rstrip": false,
267
+ "single_word": false
268
+ },
269
+ {
270
+ "content": "<extra_id_61>",
271
+ "lstrip": true,
272
+ "normalized": true,
273
+ "rstrip": false,
274
+ "single_word": false
275
+ },
276
+ {
277
+ "content": "<extra_id_60>",
278
+ "lstrip": true,
279
+ "normalized": true,
280
+ "rstrip": false,
281
+ "single_word": false
282
+ },
283
+ {
284
+ "content": "<extra_id_59>",
285
+ "lstrip": true,
286
+ "normalized": true,
287
+ "rstrip": false,
288
+ "single_word": false
289
+ },
290
+ {
291
+ "content": "<extra_id_58>",
292
+ "lstrip": true,
293
+ "normalized": true,
294
+ "rstrip": false,
295
+ "single_word": false
296
+ },
297
+ {
298
+ "content": "<extra_id_57>",
299
+ "lstrip": true,
300
+ "normalized": true,
301
+ "rstrip": false,
302
+ "single_word": false
303
+ },
304
+ {
305
+ "content": "<extra_id_56>",
306
+ "lstrip": true,
307
+ "normalized": true,
308
+ "rstrip": false,
309
+ "single_word": false
310
+ },
311
+ {
312
+ "content": "<extra_id_55>",
313
+ "lstrip": true,
314
+ "normalized": true,
315
+ "rstrip": false,
316
+ "single_word": false
317
+ },
318
+ {
319
+ "content": "<extra_id_54>",
320
+ "lstrip": true,
321
+ "normalized": true,
322
+ "rstrip": false,
323
+ "single_word": false
324
+ },
325
+ {
326
+ "content": "<extra_id_53>",
327
+ "lstrip": true,
328
+ "normalized": true,
329
+ "rstrip": false,
330
+ "single_word": false
331
+ },
332
+ {
333
+ "content": "<extra_id_52>",
334
+ "lstrip": true,
335
+ "normalized": true,
336
+ "rstrip": false,
337
+ "single_word": false
338
+ },
339
+ {
340
+ "content": "<extra_id_51>",
341
+ "lstrip": true,
342
+ "normalized": true,
343
+ "rstrip": false,
344
+ "single_word": false
345
+ },
346
+ {
347
+ "content": "<extra_id_50>",
348
+ "lstrip": true,
349
+ "normalized": true,
350
+ "rstrip": false,
351
+ "single_word": false
352
+ },
353
+ {
354
+ "content": "<extra_id_49>",
355
+ "lstrip": true,
356
+ "normalized": true,
357
+ "rstrip": false,
358
+ "single_word": false
359
+ },
360
+ {
361
+ "content": "<extra_id_48>",
362
+ "lstrip": true,
363
+ "normalized": true,
364
+ "rstrip": false,
365
+ "single_word": false
366
+ },
367
+ {
368
+ "content": "<extra_id_47>",
369
+ "lstrip": true,
370
+ "normalized": true,
371
+ "rstrip": false,
372
+ "single_word": false
373
+ },
374
+ {
375
+ "content": "<extra_id_46>",
376
+ "lstrip": true,
377
+ "normalized": true,
378
+ "rstrip": false,
379
+ "single_word": false
380
+ },
381
+ {
382
+ "content": "<extra_id_45>",
383
+ "lstrip": true,
384
+ "normalized": true,
385
+ "rstrip": false,
386
+ "single_word": false
387
+ },
388
+ {
389
+ "content": "<extra_id_44>",
390
+ "lstrip": true,
391
+ "normalized": true,
392
+ "rstrip": false,
393
+ "single_word": false
394
+ },
395
+ {
396
+ "content": "<extra_id_43>",
397
+ "lstrip": true,
398
+ "normalized": true,
399
+ "rstrip": false,
400
+ "single_word": false
401
+ },
402
+ {
403
+ "content": "<extra_id_42>",
404
+ "lstrip": true,
405
+ "normalized": true,
406
+ "rstrip": false,
407
+ "single_word": false
408
+ },
409
+ {
410
+ "content": "<extra_id_41>",
411
+ "lstrip": true,
412
+ "normalized": true,
413
+ "rstrip": false,
414
+ "single_word": false
415
+ },
416
+ {
417
+ "content": "<extra_id_40>",
418
+ "lstrip": true,
419
+ "normalized": true,
420
+ "rstrip": false,
421
+ "single_word": false
422
+ },
423
+ {
424
+ "content": "<extra_id_39>",
425
+ "lstrip": true,
426
+ "normalized": true,
427
+ "rstrip": false,
428
+ "single_word": false
429
+ },
430
+ {
431
+ "content": "<extra_id_38>",
432
+ "lstrip": true,
433
+ "normalized": true,
434
+ "rstrip": false,
435
+ "single_word": false
436
+ },
437
+ {
438
+ "content": "<extra_id_37>",
439
+ "lstrip": true,
440
+ "normalized": true,
441
+ "rstrip": false,
442
+ "single_word": false
443
+ },
444
+ {
445
+ "content": "<extra_id_36>",
446
+ "lstrip": true,
447
+ "normalized": true,
448
+ "rstrip": false,
449
+ "single_word": false
450
+ },
451
+ {
452
+ "content": "<extra_id_35>",
453
+ "lstrip": true,
454
+ "normalized": true,
455
+ "rstrip": false,
456
+ "single_word": false
457
+ },
458
+ {
459
+ "content": "<extra_id_34>",
460
+ "lstrip": true,
461
+ "normalized": true,
462
+ "rstrip": false,
463
+ "single_word": false
464
+ },
465
+ {
466
+ "content": "<extra_id_33>",
467
+ "lstrip": true,
468
+ "normalized": true,
469
+ "rstrip": false,
470
+ "single_word": false
471
+ },
472
+ {
473
+ "content": "<extra_id_32>",
474
+ "lstrip": true,
475
+ "normalized": true,
476
+ "rstrip": false,
477
+ "single_word": false
478
+ },
479
+ {
480
+ "content": "<extra_id_31>",
481
+ "lstrip": true,
482
+ "normalized": true,
483
+ "rstrip": false,
484
+ "single_word": false
485
+ },
486
+ {
487
+ "content": "<extra_id_30>",
488
+ "lstrip": true,
489
+ "normalized": true,
490
+ "rstrip": false,
491
+ "single_word": false
492
+ },
493
+ {
494
+ "content": "<extra_id_29>",
495
+ "lstrip": true,
496
+ "normalized": true,
497
+ "rstrip": false,
498
+ "single_word": false
499
+ },
500
+ {
501
+ "content": "<extra_id_28>",
502
+ "lstrip": true,
503
+ "normalized": true,
504
+ "rstrip": false,
505
+ "single_word": false
506
+ },
507
+ {
508
+ "content": "<extra_id_27>",
509
+ "lstrip": true,
510
+ "normalized": true,
511
+ "rstrip": false,
512
+ "single_word": false
513
+ },
514
+ {
515
+ "content": "<extra_id_26>",
516
+ "lstrip": true,
517
+ "normalized": true,
518
+ "rstrip": false,
519
+ "single_word": false
520
+ },
521
+ {
522
+ "content": "<extra_id_25>",
523
+ "lstrip": true,
524
+ "normalized": true,
525
+ "rstrip": false,
526
+ "single_word": false
527
+ },
528
+ {
529
+ "content": "<extra_id_24>",
530
+ "lstrip": true,
531
+ "normalized": true,
532
+ "rstrip": false,
533
+ "single_word": false
534
+ },
535
+ {
536
+ "content": "<extra_id_23>",
537
+ "lstrip": true,
538
+ "normalized": true,
539
+ "rstrip": false,
540
+ "single_word": false
541
+ },
542
+ {
543
+ "content": "<extra_id_22>",
544
+ "lstrip": true,
545
+ "normalized": true,
546
+ "rstrip": false,
547
+ "single_word": false
548
+ },
549
+ {
550
+ "content": "<extra_id_21>",
551
+ "lstrip": true,
552
+ "normalized": true,
553
+ "rstrip": false,
554
+ "single_word": false
555
+ },
556
+ {
557
+ "content": "<extra_id_20>",
558
+ "lstrip": true,
559
+ "normalized": true,
560
+ "rstrip": false,
561
+ "single_word": false
562
+ },
563
+ {
564
+ "content": "<extra_id_19>",
565
+ "lstrip": true,
566
+ "normalized": true,
567
+ "rstrip": false,
568
+ "single_word": false
569
+ },
570
+ {
571
+ "content": "<extra_id_18>",
572
+ "lstrip": true,
573
+ "normalized": true,
574
+ "rstrip": false,
575
+ "single_word": false
576
+ },
577
+ {
578
+ "content": "<extra_id_17>",
579
+ "lstrip": true,
580
+ "normalized": true,
581
+ "rstrip": false,
582
+ "single_word": false
583
+ },
584
+ {
585
+ "content": "<extra_id_16>",
586
+ "lstrip": true,
587
+ "normalized": true,
588
+ "rstrip": false,
589
+ "single_word": false
590
+ },
591
+ {
592
+ "content": "<extra_id_15>",
593
+ "lstrip": true,
594
+ "normalized": true,
595
+ "rstrip": false,
596
+ "single_word": false
597
+ },
598
+ {
599
+ "content": "<extra_id_14>",
600
+ "lstrip": true,
601
+ "normalized": true,
602
+ "rstrip": false,
603
+ "single_word": false
604
+ },
605
+ {
606
+ "content": "<extra_id_13>",
607
+ "lstrip": true,
608
+ "normalized": true,
609
+ "rstrip": false,
610
+ "single_word": false
611
+ },
612
+ {
613
+ "content": "<extra_id_12>",
614
+ "lstrip": true,
615
+ "normalized": true,
616
+ "rstrip": false,
617
+ "single_word": false
618
+ },
619
+ {
620
+ "content": "<extra_id_11>",
621
+ "lstrip": true,
622
+ "normalized": true,
623
+ "rstrip": false,
624
+ "single_word": false
625
+ },
626
+ {
627
+ "content": "<extra_id_10>",
628
+ "lstrip": true,
629
+ "normalized": true,
630
+ "rstrip": false,
631
+ "single_word": false
632
+ },
633
+ {
634
+ "content": "<extra_id_9>",
635
+ "lstrip": true,
636
+ "normalized": true,
637
+ "rstrip": false,
638
+ "single_word": false
639
+ },
640
+ {
641
+ "content": "<extra_id_8>",
642
+ "lstrip": true,
643
+ "normalized": true,
644
+ "rstrip": false,
645
+ "single_word": false
646
+ },
647
+ {
648
+ "content": "<extra_id_7>",
649
+ "lstrip": true,
650
+ "normalized": true,
651
+ "rstrip": false,
652
+ "single_word": false
653
+ },
654
+ {
655
+ "content": "<extra_id_6>",
656
+ "lstrip": true,
657
+ "normalized": true,
658
+ "rstrip": false,
659
+ "single_word": false
660
+ },
661
+ {
662
+ "content": "<extra_id_5>",
663
+ "lstrip": true,
664
+ "normalized": true,
665
+ "rstrip": false,
666
+ "single_word": false
667
+ },
668
+ {
669
+ "content": "<extra_id_4>",
670
+ "lstrip": true,
671
+ "normalized": true,
672
+ "rstrip": false,
673
+ "single_word": false
674
+ },
675
+ {
676
+ "content": "<extra_id_3>",
677
+ "lstrip": true,
678
+ "normalized": true,
679
+ "rstrip": false,
680
+ "single_word": false
681
+ },
682
+ {
683
+ "content": "<extra_id_2>",
684
+ "lstrip": true,
685
+ "normalized": true,
686
+ "rstrip": false,
687
+ "single_word": false
688
+ },
689
+ {
690
+ "content": "<extra_id_1>",
691
+ "lstrip": true,
692
+ "normalized": true,
693
+ "rstrip": false,
694
+ "single_word": false
695
+ },
696
+ {
697
+ "content": "<extra_id_0>",
698
+ "lstrip": true,
699
+ "normalized": true,
700
+ "rstrip": false,
701
+ "single_word": false
702
+ }
703
+ ],
704
+ "bos_token": {
705
+ "content": "<s>",
706
+ "lstrip": false,
707
+ "normalized": true,
708
+ "rstrip": false,
709
+ "single_word": false
710
+ },
711
+ "cls_token": {
712
+ "content": "<s>",
713
+ "lstrip": false,
714
+ "normalized": true,
715
+ "rstrip": false,
716
+ "single_word": false
717
+ },
718
+ "eos_token": {
719
+ "content": "</s>",
720
+ "lstrip": false,
721
+ "normalized": true,
722
+ "rstrip": false,
723
+ "single_word": false
724
+ },
725
+ "mask_token": {
726
+ "content": "<mask>",
727
+ "lstrip": true,
728
+ "normalized": true,
729
+ "rstrip": false,
730
+ "single_word": false
731
+ },
732
+ "pad_token": {
733
+ "content": "<pad>",
734
+ "lstrip": false,
735
+ "normalized": true,
736
+ "rstrip": false,
737
+ "single_word": false
738
+ },
739
+ "sep_token": {
740
+ "content": "</s>",
741
+ "lstrip": false,
742
+ "normalized": true,
743
+ "rstrip": false,
744
+ "single_word": false
745
+ },
746
+ "unk_token": {
747
+ "content": "<unk>",
748
+ "lstrip": false,
749
+ "normalized": true,
750
+ "rstrip": false,
751
+ "single_word": false
752
+ }
753
+ }
int8_dynamic/tokenizer/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
int8_dynamic/tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,959 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "<pad>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "1": {
13
+ "content": "<s>",
14
+ "lstrip": false,
15
+ "normalized": true,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "2": {
21
+ "content": "</s>",
22
+ "lstrip": false,
23
+ "normalized": true,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "3": {
29
+ "content": "<unk>",
30
+ "lstrip": false,
31
+ "normalized": true,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "4": {
37
+ "content": "<mask>",
38
+ "lstrip": true,
39
+ "normalized": true,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "32000": {
45
+ "content": "<extra_id_99>",
46
+ "lstrip": true,
47
+ "normalized": true,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": true
51
+ },
52
+ "32001": {
53
+ "content": "<extra_id_98>",
54
+ "lstrip": true,
55
+ "normalized": true,
56
+ "rstrip": false,
57
+ "single_word": false,
58
+ "special": true
59
+ },
60
+ "32002": {
61
+ "content": "<extra_id_97>",
62
+ "lstrip": true,
63
+ "normalized": true,
64
+ "rstrip": false,
65
+ "single_word": false,
66
+ "special": true
67
+ },
68
+ "32003": {
69
+ "content": "<extra_id_96>",
70
+ "lstrip": true,
71
+ "normalized": true,
72
+ "rstrip": false,
73
+ "single_word": false,
74
+ "special": true
75
+ },
76
+ "32004": {
77
+ "content": "<extra_id_95>",
78
+ "lstrip": true,
79
+ "normalized": true,
80
+ "rstrip": false,
81
+ "single_word": false,
82
+ "special": true
83
+ },
84
+ "32005": {
85
+ "content": "<extra_id_94>",
86
+ "lstrip": true,
87
+ "normalized": true,
88
+ "rstrip": false,
89
+ "single_word": false,
90
+ "special": true
91
+ },
92
+ "32006": {
93
+ "content": "<extra_id_93>",
94
+ "lstrip": true,
95
+ "normalized": true,
96
+ "rstrip": false,
97
+ "single_word": false,
98
+ "special": true
99
+ },
100
+ "32007": {
101
+ "content": "<extra_id_92>",
102
+ "lstrip": true,
103
+ "normalized": true,
104
+ "rstrip": false,
105
+ "single_word": false,
106
+ "special": true
107
+ },
108
+ "32008": {
109
+ "content": "<extra_id_91>",
110
+ "lstrip": true,
111
+ "normalized": true,
112
+ "rstrip": false,
113
+ "single_word": false,
114
+ "special": true
115
+ },
116
+ "32009": {
117
+ "content": "<extra_id_90>",
118
+ "lstrip": true,
119
+ "normalized": true,
120
+ "rstrip": false,
121
+ "single_word": false,
122
+ "special": true
123
+ },
124
+ "32010": {
125
+ "content": "<extra_id_89>",
126
+ "lstrip": true,
127
+ "normalized": true,
128
+ "rstrip": false,
129
+ "single_word": false,
130
+ "special": true
131
+ },
132
+ "32011": {
133
+ "content": "<extra_id_88>",
134
+ "lstrip": true,
135
+ "normalized": true,
136
+ "rstrip": false,
137
+ "single_word": false,
138
+ "special": true
139
+ },
140
+ "32012": {
141
+ "content": "<extra_id_87>",
142
+ "lstrip": true,
143
+ "normalized": true,
144
+ "rstrip": false,
145
+ "single_word": false,
146
+ "special": true
147
+ },
148
+ "32013": {
149
+ "content": "<extra_id_86>",
150
+ "lstrip": true,
151
+ "normalized": true,
152
+ "rstrip": false,
153
+ "single_word": false,
154
+ "special": true
155
+ },
156
+ "32014": {
157
+ "content": "<extra_id_85>",
158
+ "lstrip": true,
159
+ "normalized": true,
160
+ "rstrip": false,
161
+ "single_word": false,
162
+ "special": true
163
+ },
164
+ "32015": {
165
+ "content": "<extra_id_84>",
166
+ "lstrip": true,
167
+ "normalized": true,
168
+ "rstrip": false,
169
+ "single_word": false,
170
+ "special": true
171
+ },
172
+ "32016": {
173
+ "content": "<extra_id_83>",
174
+ "lstrip": true,
175
+ "normalized": true,
176
+ "rstrip": false,
177
+ "single_word": false,
178
+ "special": true
179
+ },
180
+ "32017": {
181
+ "content": "<extra_id_82>",
182
+ "lstrip": true,
183
+ "normalized": true,
184
+ "rstrip": false,
185
+ "single_word": false,
186
+ "special": true
187
+ },
188
+ "32018": {
189
+ "content": "<extra_id_81>",
190
+ "lstrip": true,
191
+ "normalized": true,
192
+ "rstrip": false,
193
+ "single_word": false,
194
+ "special": true
195
+ },
196
+ "32019": {
197
+ "content": "<extra_id_80>",
198
+ "lstrip": true,
199
+ "normalized": true,
200
+ "rstrip": false,
201
+ "single_word": false,
202
+ "special": true
203
+ },
204
+ "32020": {
205
+ "content": "<extra_id_79>",
206
+ "lstrip": true,
207
+ "normalized": true,
208
+ "rstrip": false,
209
+ "single_word": false,
210
+ "special": true
211
+ },
212
+ "32021": {
213
+ "content": "<extra_id_78>",
214
+ "lstrip": true,
215
+ "normalized": true,
216
+ "rstrip": false,
217
+ "single_word": false,
218
+ "special": true
219
+ },
220
+ "32022": {
221
+ "content": "<extra_id_77>",
222
+ "lstrip": true,
223
+ "normalized": true,
224
+ "rstrip": false,
225
+ "single_word": false,
226
+ "special": true
227
+ },
228
+ "32023": {
229
+ "content": "<extra_id_76>",
230
+ "lstrip": true,
231
+ "normalized": true,
232
+ "rstrip": false,
233
+ "single_word": false,
234
+ "special": true
235
+ },
236
+ "32024": {
237
+ "content": "<extra_id_75>",
238
+ "lstrip": true,
239
+ "normalized": true,
240
+ "rstrip": false,
241
+ "single_word": false,
242
+ "special": true
243
+ },
244
+ "32025": {
245
+ "content": "<extra_id_74>",
246
+ "lstrip": true,
247
+ "normalized": true,
248
+ "rstrip": false,
249
+ "single_word": false,
250
+ "special": true
251
+ },
252
+ "32026": {
253
+ "content": "<extra_id_73>",
254
+ "lstrip": true,
255
+ "normalized": true,
256
+ "rstrip": false,
257
+ "single_word": false,
258
+ "special": true
259
+ },
260
+ "32027": {
261
+ "content": "<extra_id_72>",
262
+ "lstrip": true,
263
+ "normalized": true,
264
+ "rstrip": false,
265
+ "single_word": false,
266
+ "special": true
267
+ },
268
+ "32028": {
269
+ "content": "<extra_id_71>",
270
+ "lstrip": true,
271
+ "normalized": true,
272
+ "rstrip": false,
273
+ "single_word": false,
274
+ "special": true
275
+ },
276
+ "32029": {
277
+ "content": "<extra_id_70>",
278
+ "lstrip": true,
279
+ "normalized": true,
280
+ "rstrip": false,
281
+ "single_word": false,
282
+ "special": true
283
+ },
284
+ "32030": {
285
+ "content": "<extra_id_69>",
286
+ "lstrip": true,
287
+ "normalized": true,
288
+ "rstrip": false,
289
+ "single_word": false,
290
+ "special": true
291
+ },
292
+ "32031": {
293
+ "content": "<extra_id_68>",
294
+ "lstrip": true,
295
+ "normalized": true,
296
+ "rstrip": false,
297
+ "single_word": false,
298
+ "special": true
299
+ },
300
+ "32032": {
301
+ "content": "<extra_id_67>",
302
+ "lstrip": true,
303
+ "normalized": true,
304
+ "rstrip": false,
305
+ "single_word": false,
306
+ "special": true
307
+ },
308
+ "32033": {
309
+ "content": "<extra_id_66>",
310
+ "lstrip": true,
311
+ "normalized": true,
312
+ "rstrip": false,
313
+ "single_word": false,
314
+ "special": true
315
+ },
316
+ "32034": {
317
+ "content": "<extra_id_65>",
318
+ "lstrip": true,
319
+ "normalized": true,
320
+ "rstrip": false,
321
+ "single_word": false,
322
+ "special": true
323
+ },
324
+ "32035": {
325
+ "content": "<extra_id_64>",
326
+ "lstrip": true,
327
+ "normalized": true,
328
+ "rstrip": false,
329
+ "single_word": false,
330
+ "special": true
331
+ },
332
+ "32036": {
333
+ "content": "<extra_id_63>",
334
+ "lstrip": true,
335
+ "normalized": true,
336
+ "rstrip": false,
337
+ "single_word": false,
338
+ "special": true
339
+ },
340
+ "32037": {
341
+ "content": "<extra_id_62>",
342
+ "lstrip": true,
343
+ "normalized": true,
344
+ "rstrip": false,
345
+ "single_word": false,
346
+ "special": true
347
+ },
348
+ "32038": {
349
+ "content": "<extra_id_61>",
350
+ "lstrip": true,
351
+ "normalized": true,
352
+ "rstrip": false,
353
+ "single_word": false,
354
+ "special": true
355
+ },
356
+ "32039": {
357
+ "content": "<extra_id_60>",
358
+ "lstrip": true,
359
+ "normalized": true,
360
+ "rstrip": false,
361
+ "single_word": false,
362
+ "special": true
363
+ },
364
+ "32040": {
365
+ "content": "<extra_id_59>",
366
+ "lstrip": true,
367
+ "normalized": true,
368
+ "rstrip": false,
369
+ "single_word": false,
370
+ "special": true
371
+ },
372
+ "32041": {
373
+ "content": "<extra_id_58>",
374
+ "lstrip": true,
375
+ "normalized": true,
376
+ "rstrip": false,
377
+ "single_word": false,
378
+ "special": true
379
+ },
380
+ "32042": {
381
+ "content": "<extra_id_57>",
382
+ "lstrip": true,
383
+ "normalized": true,
384
+ "rstrip": false,
385
+ "single_word": false,
386
+ "special": true
387
+ },
388
+ "32043": {
389
+ "content": "<extra_id_56>",
390
+ "lstrip": true,
391
+ "normalized": true,
392
+ "rstrip": false,
393
+ "single_word": false,
394
+ "special": true
395
+ },
396
+ "32044": {
397
+ "content": "<extra_id_55>",
398
+ "lstrip": true,
399
+ "normalized": true,
400
+ "rstrip": false,
401
+ "single_word": false,
402
+ "special": true
403
+ },
404
+ "32045": {
405
+ "content": "<extra_id_54>",
406
+ "lstrip": true,
407
+ "normalized": true,
408
+ "rstrip": false,
409
+ "single_word": false,
410
+ "special": true
411
+ },
412
+ "32046": {
413
+ "content": "<extra_id_53>",
414
+ "lstrip": true,
415
+ "normalized": true,
416
+ "rstrip": false,
417
+ "single_word": false,
418
+ "special": true
419
+ },
420
+ "32047": {
421
+ "content": "<extra_id_52>",
422
+ "lstrip": true,
423
+ "normalized": true,
424
+ "rstrip": false,
425
+ "single_word": false,
426
+ "special": true
427
+ },
428
+ "32048": {
429
+ "content": "<extra_id_51>",
430
+ "lstrip": true,
431
+ "normalized": true,
432
+ "rstrip": false,
433
+ "single_word": false,
434
+ "special": true
435
+ },
436
+ "32049": {
437
+ "content": "<extra_id_50>",
438
+ "lstrip": true,
439
+ "normalized": true,
440
+ "rstrip": false,
441
+ "single_word": false,
442
+ "special": true
443
+ },
444
+ "32050": {
445
+ "content": "<extra_id_49>",
446
+ "lstrip": true,
447
+ "normalized": true,
448
+ "rstrip": false,
449
+ "single_word": false,
450
+ "special": true
451
+ },
452
+ "32051": {
453
+ "content": "<extra_id_48>",
454
+ "lstrip": true,
455
+ "normalized": true,
456
+ "rstrip": false,
457
+ "single_word": false,
458
+ "special": true
459
+ },
460
+ "32052": {
461
+ "content": "<extra_id_47>",
462
+ "lstrip": true,
463
+ "normalized": true,
464
+ "rstrip": false,
465
+ "single_word": false,
466
+ "special": true
467
+ },
468
+ "32053": {
469
+ "content": "<extra_id_46>",
470
+ "lstrip": true,
471
+ "normalized": true,
472
+ "rstrip": false,
473
+ "single_word": false,
474
+ "special": true
475
+ },
476
+ "32054": {
477
+ "content": "<extra_id_45>",
478
+ "lstrip": true,
479
+ "normalized": true,
480
+ "rstrip": false,
481
+ "single_word": false,
482
+ "special": true
483
+ },
484
+ "32055": {
485
+ "content": "<extra_id_44>",
486
+ "lstrip": true,
487
+ "normalized": true,
488
+ "rstrip": false,
489
+ "single_word": false,
490
+ "special": true
491
+ },
492
+ "32056": {
493
+ "content": "<extra_id_43>",
494
+ "lstrip": true,
495
+ "normalized": true,
496
+ "rstrip": false,
497
+ "single_word": false,
498
+ "special": true
499
+ },
500
+ "32057": {
501
+ "content": "<extra_id_42>",
502
+ "lstrip": true,
503
+ "normalized": true,
504
+ "rstrip": false,
505
+ "single_word": false,
506
+ "special": true
507
+ },
508
+ "32058": {
509
+ "content": "<extra_id_41>",
510
+ "lstrip": true,
511
+ "normalized": true,
512
+ "rstrip": false,
513
+ "single_word": false,
514
+ "special": true
515
+ },
516
+ "32059": {
517
+ "content": "<extra_id_40>",
518
+ "lstrip": true,
519
+ "normalized": true,
520
+ "rstrip": false,
521
+ "single_word": false,
522
+ "special": true
523
+ },
524
+ "32060": {
525
+ "content": "<extra_id_39>",
526
+ "lstrip": true,
527
+ "normalized": true,
528
+ "rstrip": false,
529
+ "single_word": false,
530
+ "special": true
531
+ },
532
+ "32061": {
533
+ "content": "<extra_id_38>",
534
+ "lstrip": true,
535
+ "normalized": true,
536
+ "rstrip": false,
537
+ "single_word": false,
538
+ "special": true
539
+ },
540
+ "32062": {
541
+ "content": "<extra_id_37>",
542
+ "lstrip": true,
543
+ "normalized": true,
544
+ "rstrip": false,
545
+ "single_word": false,
546
+ "special": true
547
+ },
548
+ "32063": {
549
+ "content": "<extra_id_36>",
550
+ "lstrip": true,
551
+ "normalized": true,
552
+ "rstrip": false,
553
+ "single_word": false,
554
+ "special": true
555
+ },
556
+ "32064": {
557
+ "content": "<extra_id_35>",
558
+ "lstrip": true,
559
+ "normalized": true,
560
+ "rstrip": false,
561
+ "single_word": false,
562
+ "special": true
563
+ },
564
+ "32065": {
565
+ "content": "<extra_id_34>",
566
+ "lstrip": true,
567
+ "normalized": true,
568
+ "rstrip": false,
569
+ "single_word": false,
570
+ "special": true
571
+ },
572
+ "32066": {
573
+ "content": "<extra_id_33>",
574
+ "lstrip": true,
575
+ "normalized": true,
576
+ "rstrip": false,
577
+ "single_word": false,
578
+ "special": true
579
+ },
580
+ "32067": {
581
+ "content": "<extra_id_32>",
582
+ "lstrip": true,
583
+ "normalized": true,
584
+ "rstrip": false,
585
+ "single_word": false,
586
+ "special": true
587
+ },
588
+ "32068": {
589
+ "content": "<extra_id_31>",
590
+ "lstrip": true,
591
+ "normalized": true,
592
+ "rstrip": false,
593
+ "single_word": false,
594
+ "special": true
595
+ },
596
+ "32069": {
597
+ "content": "<extra_id_30>",
598
+ "lstrip": true,
599
+ "normalized": true,
600
+ "rstrip": false,
601
+ "single_word": false,
602
+ "special": true
603
+ },
604
+ "32070": {
605
+ "content": "<extra_id_29>",
606
+ "lstrip": true,
607
+ "normalized": true,
608
+ "rstrip": false,
609
+ "single_word": false,
610
+ "special": true
611
+ },
612
+ "32071": {
613
+ "content": "<extra_id_28>",
614
+ "lstrip": true,
615
+ "normalized": true,
616
+ "rstrip": false,
617
+ "single_word": false,
618
+ "special": true
619
+ },
620
+ "32072": {
621
+ "content": "<extra_id_27>",
622
+ "lstrip": true,
623
+ "normalized": true,
624
+ "rstrip": false,
625
+ "single_word": false,
626
+ "special": true
627
+ },
628
+ "32073": {
629
+ "content": "<extra_id_26>",
630
+ "lstrip": true,
631
+ "normalized": true,
632
+ "rstrip": false,
633
+ "single_word": false,
634
+ "special": true
635
+ },
636
+ "32074": {
637
+ "content": "<extra_id_25>",
638
+ "lstrip": true,
639
+ "normalized": true,
640
+ "rstrip": false,
641
+ "single_word": false,
642
+ "special": true
643
+ },
644
+ "32075": {
645
+ "content": "<extra_id_24>",
646
+ "lstrip": true,
647
+ "normalized": true,
648
+ "rstrip": false,
649
+ "single_word": false,
650
+ "special": true
651
+ },
652
+ "32076": {
653
+ "content": "<extra_id_23>",
654
+ "lstrip": true,
655
+ "normalized": true,
656
+ "rstrip": false,
657
+ "single_word": false,
658
+ "special": true
659
+ },
660
+ "32077": {
661
+ "content": "<extra_id_22>",
662
+ "lstrip": true,
663
+ "normalized": true,
664
+ "rstrip": false,
665
+ "single_word": false,
666
+ "special": true
667
+ },
668
+ "32078": {
669
+ "content": "<extra_id_21>",
670
+ "lstrip": true,
671
+ "normalized": true,
672
+ "rstrip": false,
673
+ "single_word": false,
674
+ "special": true
675
+ },
676
+ "32079": {
677
+ "content": "<extra_id_20>",
678
+ "lstrip": true,
679
+ "normalized": true,
680
+ "rstrip": false,
681
+ "single_word": false,
682
+ "special": true
683
+ },
684
+ "32080": {
685
+ "content": "<extra_id_19>",
686
+ "lstrip": true,
687
+ "normalized": true,
688
+ "rstrip": false,
689
+ "single_word": false,
690
+ "special": true
691
+ },
692
+ "32081": {
693
+ "content": "<extra_id_18>",
694
+ "lstrip": true,
695
+ "normalized": true,
696
+ "rstrip": false,
697
+ "single_word": false,
698
+ "special": true
699
+ },
700
+ "32082": {
701
+ "content": "<extra_id_17>",
702
+ "lstrip": true,
703
+ "normalized": true,
704
+ "rstrip": false,
705
+ "single_word": false,
706
+ "special": true
707
+ },
708
+ "32083": {
709
+ "content": "<extra_id_16>",
710
+ "lstrip": true,
711
+ "normalized": true,
712
+ "rstrip": false,
713
+ "single_word": false,
714
+ "special": true
715
+ },
716
+ "32084": {
717
+ "content": "<extra_id_15>",
718
+ "lstrip": true,
719
+ "normalized": true,
720
+ "rstrip": false,
721
+ "single_word": false,
722
+ "special": true
723
+ },
724
+ "32085": {
725
+ "content": "<extra_id_14>",
726
+ "lstrip": true,
727
+ "normalized": true,
728
+ "rstrip": false,
729
+ "single_word": false,
730
+ "special": true
731
+ },
732
+ "32086": {
733
+ "content": "<extra_id_13>",
734
+ "lstrip": true,
735
+ "normalized": true,
736
+ "rstrip": false,
737
+ "single_word": false,
738
+ "special": true
739
+ },
740
+ "32087": {
741
+ "content": "<extra_id_12>",
742
+ "lstrip": true,
743
+ "normalized": true,
744
+ "rstrip": false,
745
+ "single_word": false,
746
+ "special": true
747
+ },
748
+ "32088": {
749
+ "content": "<extra_id_11>",
750
+ "lstrip": true,
751
+ "normalized": true,
752
+ "rstrip": false,
753
+ "single_word": false,
754
+ "special": true
755
+ },
756
+ "32089": {
757
+ "content": "<extra_id_10>",
758
+ "lstrip": true,
759
+ "normalized": true,
760
+ "rstrip": false,
761
+ "single_word": false,
762
+ "special": true
763
+ },
764
+ "32090": {
765
+ "content": "<extra_id_9>",
766
+ "lstrip": true,
767
+ "normalized": true,
768
+ "rstrip": false,
769
+ "single_word": false,
770
+ "special": true
771
+ },
772
+ "32091": {
773
+ "content": "<extra_id_8>",
774
+ "lstrip": true,
775
+ "normalized": true,
776
+ "rstrip": false,
777
+ "single_word": false,
778
+ "special": true
779
+ },
780
+ "32092": {
781
+ "content": "<extra_id_7>",
782
+ "lstrip": true,
783
+ "normalized": true,
784
+ "rstrip": false,
785
+ "single_word": false,
786
+ "special": true
787
+ },
788
+ "32093": {
789
+ "content": "<extra_id_6>",
790
+ "lstrip": true,
791
+ "normalized": true,
792
+ "rstrip": false,
793
+ "single_word": false,
794
+ "special": true
795
+ },
796
+ "32094": {
797
+ "content": "<extra_id_5>",
798
+ "lstrip": true,
799
+ "normalized": true,
800
+ "rstrip": false,
801
+ "single_word": false,
802
+ "special": true
803
+ },
804
+ "32095": {
805
+ "content": "<extra_id_4>",
806
+ "lstrip": true,
807
+ "normalized": true,
808
+ "rstrip": false,
809
+ "single_word": false,
810
+ "special": true
811
+ },
812
+ "32096": {
813
+ "content": "<extra_id_3>",
814
+ "lstrip": true,
815
+ "normalized": true,
816
+ "rstrip": false,
817
+ "single_word": false,
818
+ "special": true
819
+ },
820
+ "32097": {
821
+ "content": "<extra_id_2>",
822
+ "lstrip": true,
823
+ "normalized": true,
824
+ "rstrip": false,
825
+ "single_word": false,
826
+ "special": true
827
+ },
828
+ "32098": {
829
+ "content": "<extra_id_1>",
830
+ "lstrip": true,
831
+ "normalized": true,
832
+ "rstrip": false,
833
+ "single_word": false,
834
+ "special": true
835
+ },
836
+ "32099": {
837
+ "content": "<extra_id_0>",
838
+ "lstrip": true,
839
+ "normalized": true,
840
+ "rstrip": false,
841
+ "single_word": false,
842
+ "special": true
843
+ }
844
+ },
845
+ "additional_special_tokens": [
846
+ "<extra_id_99>",
847
+ "<extra_id_98>",
848
+ "<extra_id_97>",
849
+ "<extra_id_96>",
850
+ "<extra_id_95>",
851
+ "<extra_id_94>",
852
+ "<extra_id_93>",
853
+ "<extra_id_92>",
854
+ "<extra_id_91>",
855
+ "<extra_id_90>",
856
+ "<extra_id_89>",
857
+ "<extra_id_88>",
858
+ "<extra_id_87>",
859
+ "<extra_id_86>",
860
+ "<extra_id_85>",
861
+ "<extra_id_84>",
862
+ "<extra_id_83>",
863
+ "<extra_id_82>",
864
+ "<extra_id_81>",
865
+ "<extra_id_80>",
866
+ "<extra_id_79>",
867
+ "<extra_id_78>",
868
+ "<extra_id_77>",
869
+ "<extra_id_76>",
870
+ "<extra_id_75>",
871
+ "<extra_id_74>",
872
+ "<extra_id_73>",
873
+ "<extra_id_72>",
874
+ "<extra_id_71>",
875
+ "<extra_id_70>",
876
+ "<extra_id_69>",
877
+ "<extra_id_68>",
878
+ "<extra_id_67>",
879
+ "<extra_id_66>",
880
+ "<extra_id_65>",
881
+ "<extra_id_64>",
882
+ "<extra_id_63>",
883
+ "<extra_id_62>",
884
+ "<extra_id_61>",
885
+ "<extra_id_60>",
886
+ "<extra_id_59>",
887
+ "<extra_id_58>",
888
+ "<extra_id_57>",
889
+ "<extra_id_56>",
890
+ "<extra_id_55>",
891
+ "<extra_id_54>",
892
+ "<extra_id_53>",
893
+ "<extra_id_52>",
894
+ "<extra_id_51>",
895
+ "<extra_id_50>",
896
+ "<extra_id_49>",
897
+ "<extra_id_48>",
898
+ "<extra_id_47>",
899
+ "<extra_id_46>",
900
+ "<extra_id_45>",
901
+ "<extra_id_44>",
902
+ "<extra_id_43>",
903
+ "<extra_id_42>",
904
+ "<extra_id_41>",
905
+ "<extra_id_40>",
906
+ "<extra_id_39>",
907
+ "<extra_id_38>",
908
+ "<extra_id_37>",
909
+ "<extra_id_36>",
910
+ "<extra_id_35>",
911
+ "<extra_id_34>",
912
+ "<extra_id_33>",
913
+ "<extra_id_32>",
914
+ "<extra_id_31>",
915
+ "<extra_id_30>",
916
+ "<extra_id_29>",
917
+ "<extra_id_28>",
918
+ "<extra_id_27>",
919
+ "<extra_id_26>",
920
+ "<extra_id_25>",
921
+ "<extra_id_24>",
922
+ "<extra_id_23>",
923
+ "<extra_id_22>",
924
+ "<extra_id_21>",
925
+ "<extra_id_20>",
926
+ "<extra_id_19>",
927
+ "<extra_id_18>",
928
+ "<extra_id_17>",
929
+ "<extra_id_16>",
930
+ "<extra_id_15>",
931
+ "<extra_id_14>",
932
+ "<extra_id_13>",
933
+ "<extra_id_12>",
934
+ "<extra_id_11>",
935
+ "<extra_id_10>",
936
+ "<extra_id_9>",
937
+ "<extra_id_8>",
938
+ "<extra_id_7>",
939
+ "<extra_id_6>",
940
+ "<extra_id_5>",
941
+ "<extra_id_4>",
942
+ "<extra_id_3>",
943
+ "<extra_id_2>",
944
+ "<extra_id_1>",
945
+ "<extra_id_0>"
946
+ ],
947
+ "bos_token": "<s>",
948
+ "clean_up_tokenization_spaces": true,
949
+ "cls_token": "<s>",
950
+ "eos_token": "</s>",
951
+ "errors": "replace",
952
+ "mask_token": "<mask>",
953
+ "model_max_length": 512,
954
+ "pad_token": "<pad>",
955
+ "sep_token": "</s>",
956
+ "tokenizer_class": "RobertaTokenizer",
957
+ "trim_offsets": true,
958
+ "unk_token": "<unk>"
959
+ }
int8_dynamic/tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio==5.8.0
2
+ pandas
3
+ sqlparse
4
+ transformers
5
+ torch --index-url https://download.pytorch.org/whl/cpu
6
+ peft
7
+ trl
8
+ sentencepiece
9
+ matplotlib
10
+ huggingface_hub
scripts/__pycache__/benchmark_parallel_reward.cpython-310.pyc ADDED
Binary file (6.31 kB). View file
 
scripts/__pycache__/benchmark_parallel_reward.cpython-313.pyc ADDED
Binary file (10.3 kB). View file
 
scripts/__pycache__/benchmark_quantization.cpython-310.pyc ADDED
Binary file (3.79 kB). View file
 
scripts/__pycache__/benchmark_rollout_generation.cpython-310.pyc ADDED
Binary file (2.75 kB). View file
 
scripts/__pycache__/quantize_export.cpython-310.pyc ADDED
Binary file (2.05 kB). View file
 
scripts/__pycache__/quantized_infer_harness.cpython-310.pyc ADDED
Binary file (1.62 kB). View file
 
scripts/benchmark_parallel_reward.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ # Ensure headless-safe matplotlib + writable cache when called from Gradio/subprocess.
3
+ os.environ.setdefault("MPLBACKEND", "Agg")
4
+ os.environ.setdefault("MPLCONFIGDIR", os.environ.get("MPLCONFIGDIR", "/tmp/mplconfig"))
5
+ import time
6
+ import json
7
+ import argparse
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ import sys
11
+ from pathlib import Path
12
+
13
+ # ==========================================
14
+ # RELATIVE PATH RESOLUTION
15
+ # ==========================================
16
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
17
+ sys.path.append(str(PROJECT_ROOT))
18
+
19
+ # Dynamically resolve where the databases are kept
20
+ if (PROJECT_ROOT / "data" / "database").exists() and list((PROJECT_ROOT / "data" / "database").rglob("*.sqlite")):
21
+ DB_ROOT = PROJECT_ROOT / "data" / "database"
22
+ else:
23
+ DB_ROOT = PROJECT_ROOT / "final_databases"
24
+
25
+ from src.execution_reward import (
26
+ execution_reward_batch_sequential,
27
+ execution_reward_batch_parallel,
28
+ execution_reward_batch_parallel_by_db,
29
+ execution_reward_timed,
30
+ set_use_cache,
31
+ set_use_schema_validation,
32
+ clear_result_cache
33
+ )
34
+
35
def generate_mock_rollouts(num_rollouts: int = 100, heavy_n: int = 500_000):
    """Generates heavy queries across multiple databases to properly test true concurrency."""
    print(f"\nGenerating {num_rollouts} heavy rollouts to simulate RLHF query workload...", flush=True)

    # Discover every SQLite file under the resolved database root.
    real_dbs = [str(p) for p in DB_ROOT.rglob("*.sqlite")]

    if not real_dbs:
        print(f"❌ CRITICAL ERROR: No real databases found in {DB_ROOT}. Cannot run benchmark.", flush=True)
        sys.exit(1)
    print(f"Found {len(real_dbs)} real SQLite databases in {DB_ROOT}. Distributing workload...", flush=True)

    rollouts = []
    for idx in range(num_rollouts):
        # Round-robin the workload across every available database file.
        db_path = real_dbs[idx % len(real_dbs)]

        # Heavy deterministic CPU-ish query (may be cut off by the 2s timeout depending on machine).
        heavy_sql = f"""
        WITH RECURSIVE cnt(x) AS (
            SELECT 1
            UNION ALL
            SELECT x+1 FROM cnt WHERE x < {heavy_n + (idx % 10_000)}
        )
        SELECT sum(x) FROM cnt;
        """
        clean_sql = heavy_sql.replace("\n", " ").strip()
        # Rollout tuple is (pred_sql, db_path, gold_sql); pred == gold here.
        rollouts.append((clean_sql, db_path, clean_sql))
        if num_rollouts >= 500 and (idx + 1) % 250 == 0:
            print(f"  generated {idx + 1}/{num_rollouts}...", flush=True)

    return rollouts
67
+
68
def profile_bottlenecks(rollouts, sample_size: int = 20, print_every: int = 5):
    """Profiles CPU usage to identify time spent in parsing, planning, and execution.

    Args:
        rollouts: List of (pred_sql, db_path, gold_sql) tuples.
        sample_size: How many rollouts to profile (clamped to len(rollouts)).
        print_every: Progress-print interval; falsy disables progress output.
    """
    # Clamp BEFORE printing so the banner reports the real sample size.
    # (The old banner hard-coded "100 Rollouts" regardless of the input.)
    sample_size = min(int(sample_size), len(rollouts))

    print("\n" + "=" * 65)
    print(f" 🔍 CPU PROFILING: IDENTIFYING BOTTLENECKS ({sample_size} Rollouts)")
    print("=" * 65)

    clear_result_cache()
    set_use_cache(False)  # Disable cache to force real work
    set_use_schema_validation(False)  # CTE-heavy benchmark queries may fail schema validation

    total_parse = 0.0
    total_plan = 0.0
    total_exec = 0.0

    # Profile a small subset by default so the script prints quickly.
    sample_rollouts = rollouts[:sample_size]

    for i, (pred, db, gold) in enumerate(sample_rollouts, 1):
        _, timings = execution_reward_timed(pred, db, gold, measure_plan=True)
        total_parse += timings['parse_s']
        total_plan += timings['plan_s']
        total_exec += timings['exec_s']
        if print_every and (i % int(print_every) == 0 or i == sample_size):
            print(f"  profiled {i}/{sample_size}...", flush=True)

    total_time = total_parse + total_plan + total_exec
    if total_time == 0:
        total_time = 0.0001  # Prevent division by zero when all phases are instant

    # Guard sample_size == 0 too (rollouts may be empty) — previously a ZeroDivisionError.
    denom = max(sample_size, 1)

    print(f"{'Phase':<15} | {'Avg Time (ms)':<15} | {'% of Total CPU':<15}")
    print("-" * 65)
    print(f"{'Regex Parsing':<15} | {(total_parse/denom)*1000:<15.2f} | {(total_parse/total_time)*100:<14.1f}%")
    print(f"{'Query Planning':<15} | {(total_plan/denom)*1000:<15.2f} | {(total_plan/total_time)*100:<14.1f}%")
    print(f"{'DB Execution':<15} | {(total_exec/denom)*1000:<15.2f} | {(total_exec/total_time)*100:<14.1f}%")
    print("="*65 + "\n")
103
+
104
def run_benchmark_for_setting(rollouts, use_cache: bool, max_workers: int):
    """Time one sequential and one parallel pass over *rollouts*; return timings + speedup."""
    set_use_cache(use_cache)
    set_use_schema_validation(False)  # benchmark focuses on execution speed

    # --- Sequential pass ---
    clear_result_cache()
    t0 = time.perf_counter()
    execution_reward_batch_sequential(rollouts)
    sequential_s = time.perf_counter() - t0

    # --- Parallel pass: 1 thread per DB (recommended) ---
    clear_result_cache()
    t0 = time.perf_counter()
    execution_reward_batch_parallel_by_db(rollouts, max_workers=max_workers)
    parallel_s = time.perf_counter() - t0

    return {
        "sequential_s": sequential_s,
        "parallel_s": parallel_s,
        "speedup": sequential_s / parallel_s if parallel_s > 0 else 0,
    }
128
+
129
def print_comparison_table(results):
    """Render the benchmark results as a fixed-width console comparison table.

    Expects results["with_cache"] / results["without_cache"] dicts, each with
    'sequential_s', 'parallel_s' and 'speedup' keys.
    """
    print("="*65)
    print(f"{'Setting':<16} | {'Sequential (s)':<14} | {'Parallel (s)':<14} | {'Speedup':<10}")
    print("-" * 65)
    rows = (("With Cache", "with_cache"), ("Without Cache", "without_cache"))
    for label, key in rows:
        entry = results[key]
        print(
            f"{label:<16} | {entry['sequential_s']:<14.4f} | "
            f"{entry['parallel_s']:<14.4f} | {entry['speedup']:<9.2f}x"
        )
    print("="*65 + "\n")
139
+
140
def plot_results(results, output_path: str):
    """Save a grouped bar chart comparing sequential vs parallel wall time to *output_path*."""
    labels = ['With Cache', 'Without Cache']
    keys = ('with_cache', 'without_cache')
    seq_times = [results[k]['sequential_s'] for k in keys]
    par_times = [results[k]['parallel_s'] for k in keys]

    positions = np.arange(len(labels))
    bar_width = 0.35

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.bar(positions - bar_width / 2, seq_times, bar_width, label='Sequential', color='#4C72B0')
    ax.bar(positions + bar_width / 2, par_times, bar_width, label='Parallel', color='#DD8452')

    ax.set_ylabel('Execution Time (seconds)')
    ax.set_title('Text2SQL Reward Execution: Sequential vs Parallel')
    ax.set_xticks(positions)
    ax.set_xticklabels(labels)
    ax.legend()

    # Annotate each bar with its numeric height for quick reading.
    for container in ax.containers:
        ax.bar_label(container, fmt='%.2f', padding=3)

    fig.tight_layout()
    plt.savefig(output_path, dpi=300)
    plt.close()
164
+
165
def main():
    """CLI entry point: generate rollouts, optionally profile, then run both benchmarks."""
    parser = argparse.ArgumentParser(description="Benchmark SQL Execution Reward")
    parser.add_argument("--n", type=int, default=1000, help="Number of rollouts to benchmark")
    parser.add_argument("--max-workers", type=int, default=20, help="Max workers for parallel execution")
    parser.add_argument("--heavy-n", type=int, default=200_000, help="Recursive CTE upper bound (controls heaviness)")
    parser.add_argument("--skip-profile", action="store_true", help="Skip the CPU profiling section for faster startup")
    parser.add_argument("--profile-n", type=int, default=20, help="Number of rollouts to use for CPU profiling")
    args = parser.parse_args()

    os.makedirs(str(PROJECT_ROOT / "results"), exist_ok=True)

    rollouts = generate_mock_rollouts(args.n, heavy_n=args.heavy_n)

    if not args.skip_profile:
        profile_bottlenecks(rollouts, sample_size=args.profile_n)

    print("Starting Main Scalability Benchmarks...")

    print("Running Experiment A: Cache ENABLED...")
    cached_run = run_benchmark_for_setting(rollouts, use_cache=True, max_workers=args.max_workers)

    print("Running Experiment B: Cache DISABLED...")
    uncached_run = run_benchmark_for_setting(rollouts, use_cache=False, max_workers=args.max_workers)

    final_results = {
        "with_cache": cached_run,
        "without_cache": uncached_run,
    }

    # Persist raw numbers alongside the human-readable table and plot.
    json_path = str(PROJECT_ROOT / "results" / "task1_results.json")
    with open(json_path, 'w') as f:
        json.dump(final_results, f, indent=4)

    print_comparison_table(final_results)
    plot_results(final_results, str(PROJECT_ROOT / "results" / "task1_plot.png"))


if __name__ == "__main__":
    main()
scripts/benchmark_quantization.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ import os
6
+ import time
7
+ from pathlib import Path
8
+ from typing import Dict, List, Tuple
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ from src.execution_reward import execution_reward
14
+ from src.prompting import encode_prompt
15
+ from src.quantization_utils import load_fp32_model, load_quant_artifact
16
+
17
+
18
+ def _load_dev_items(root: Path, n: int, seed: int = 42) -> List[dict]:
19
+ data = json.loads((root / "data" / "dev.json").read_text())
20
+ if n >= len(data):
21
+ return data
22
+ rng = np.random.default_rng(seed)
23
+ idxs = rng.choice(len(data), size=n, replace=False)
24
+ return [data[int(i)] for i in idxs]
25
+
26
+
27
def _bench_variant(name: str, tok, model, items: List[dict], device: str) -> Dict[str, float]:
    """Benchmark one model variant: per-item generation latency plus execution accuracy.

    Returns a dict with sample count, EX accuracy in [0, 1], and mean/p50/p90 latency.
    """
    latencies: List[float] = []
    exact_hits = 0

    # Warm up the tokenizer/encoding path on the first item so it does not
    # inflate the first measured latency.
    if items:
        first = items[0]
        _ = encode_prompt(tok, first["question"], first["db_id"], device=device, max_input_tokens=512).unsqueeze(0)

    for item in items:
        db_id = item["db_id"]
        db_path = str(Path("data") / "database" / db_id / f"{db_id}.sqlite")

        input_ids = encode_prompt(tok, item["question"], db_id, device=device, max_input_tokens=512).unsqueeze(0)
        t0 = time.perf_counter()
        out = model.generate(input_ids=input_ids, max_new_tokens=120, num_beams=8, repetition_penalty=1.2)
        latencies.append(time.perf_counter() - t0)

        pred = tok.decode(out[0], skip_special_tokens=True).strip()
        # A reward of 1.0 means the prediction's execution result matched gold.
        if float(execution_reward(pred, db_path, item["query"])) >= 1.0:
            exact_hits += 1

    if latencies:
        mean = float(np.mean(latencies))
        p50 = float(np.percentile(latencies, 50))
        p90 = float(np.percentile(latencies, 90))
    else:
        mean = p50 = p90 = 0.0

    return {
        "n": float(len(items)),
        "ex": float(exact_hits / max(len(items), 1)),
        "lat_mean_s": mean,
        "lat_p50_s": p50,
        "lat_p90_s": p90,
    }
63
+
64
+
65
def main() -> None:
    """Benchmark the fp32 baseline against optional int8 artifacts and dump a JSON report."""
    p = argparse.ArgumentParser(description="Benchmark fp32 vs quantized artifacts (CPU-focused).")
    p.add_argument("--base_model", default=os.environ.get("BASE_MODEL", "Salesforce/codet5-base"))
    p.add_argument("--adapter", default="", help="Optional adapter for fp32 baseline.")
    p.add_argument("--artifact_int8", default="", help="Artifact dir exported by scripts/quantize_export.py")
    p.add_argument("--artifact_int8_decoder", default="", help="Artifact dir for decoder-only int8")
    p.add_argument("--num_samples", type=int, default=100)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--out", default="results/task5_quant_bench.json")
    p.add_argument("--local_only", action="store_true")
    args = p.parse_args()

    device = "cpu"
    items = _load_dev_items(Path("."), args.num_samples, args.seed)

    report: Dict[str, Dict[str, float]] = {}

    # Baseline: full-precision model (optionally with a LoRA adapter applied).
    tok, fp32 = load_fp32_model(
        args.base_model,
        adapter_path=args.adapter.strip() or None,
        device=device,
        local_only=args.local_only,
    )
    report["fp32"] = _bench_variant("fp32", tok, fp32, items, device)

    # Optional quantized variants, benchmarked in the same order as before.
    for label, artifact in (
        ("int8_dynamic", args.artifact_int8),
        ("int8_decoder_dynamic", args.artifact_int8_decoder),
    ):
        if artifact:
            tok_q, model_q, _meta = load_quant_artifact(artifact, device=device, local_only=True)
            report[label] = _bench_variant(label, tok_q, model_q, items, device)

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(report, indent=2))
    print(json.dumps(report, indent=2))


if __name__ == "__main__":
    torch.set_grad_enabled(False)
    main()
+
scripts/benchmark_rollout_generation.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ import os
6
+ import time
7
+ from pathlib import Path
8
+ from typing import List
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ from src.prompting import encode_prompt
14
+ from src.quantization_utils import load_fp32_model, load_quant_artifact
15
+
16
+
17
+ def _load_items(root: Path, n: int, seed: int = 42) -> List[dict]:
18
+ data = json.loads((root / "data" / "dev.json").read_text())
19
+ if n >= len(data):
20
+ return data
21
+ rng = np.random.default_rng(seed)
22
+ idxs = rng.choice(len(data), size=n, replace=False)
23
+ return [data[int(i)] for i in idxs]
24
+
25
+
26
def _bench_generate(tok, model, items: List[dict], device: str) -> float:
    """Return total wall-clock seconds to generate one rollout per item (beam=4, 64 new tokens)."""
    start = time.perf_counter()
    for item in items:
        ids = encode_prompt(tok, item["question"], item["db_id"], device=device, max_input_tokens=512).unsqueeze(0)
        _ = model.generate(input_ids=ids, max_new_tokens=64, num_beams=4)
    return time.perf_counter() - start
32
+
33
+
34
def main() -> None:
    """Measure rollout-generation throughput for fp32 and (optionally) a quantized artifact."""
    p = argparse.ArgumentParser(description="Benchmark rollout generation latency for RL loops.")
    p.add_argument("--base_model", default=os.environ.get("BASE_MODEL", "Salesforce/codet5-base"))
    p.add_argument("--adapter", default="")
    p.add_argument("--artifact", default="", help="Quantized artifact dir (optional).")
    p.add_argument("--num_rollouts", type=int, default=128)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--local_only", action="store_true")
    args = p.parse_args()

    device = "cpu"
    items = _load_items(Path("."), args.num_rollouts, args.seed)

    def report(label: str, elapsed: float) -> None:
        # Throughput guard: avoid division by zero for pathologically fast runs.
        rate = len(items) / max(elapsed, 1e-9)
        print(f"{label}: {elapsed:.2f}s for {len(items)} rollouts ({rate:.2f} rollouts/s)")

    tok, fp32 = load_fp32_model(
        args.base_model,
        adapter_path=args.adapter.strip() or None,
        device=device,
        local_only=args.local_only,
    )
    report("fp32", _bench_generate(tok, fp32, items, device))

    if args.artifact:
        tok_q, model_q, meta = load_quant_artifact(args.artifact, device=device, local_only=True)
        elapsed_q = _bench_generate(tok_q, model_q, items, device)
        report(meta.get("mode", "quant"), elapsed_q)


if __name__ == "__main__":
    torch.set_grad_enabled(False)
    main()
scripts/error_dashboard.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import sys
from collections import Counter

# ==============================
# LOAD LOGS
# ==============================
with open("results/error_logs.json") as f:
    logs = json.load(f)

total_errors = len(logs)

# Bail out early on an empty log: every statistic below divides by the
# error count (previously a ZeroDivisionError / IndexError crash).
if total_errors == 0:
    print("No errors logged — nothing to analyse.")
    sys.exit(0)

# ==============================
# ERROR DISTRIBUTION
# ==============================
error_counts = Counter(e["error_type"] for e in logs)

print("\n" + "="*50)
print("📊 TEXT-to-SQL ERROR DASHBOARD")
print("="*50)

print(f"\n🔢 Total Errors Logged: {total_errors}")

print("\n📊 ERROR DISTRIBUTION:")
print("-"*30)
for error_type, count in error_counts.items():
    percent = (count / total_errors) * 100
    print(f"{error_type:<20} : {count:>4} ({percent:.1f}%)")

# ==============================
# TOP ERROR
# ==============================
top_error = error_counts.most_common(1)[0]

print("\n🔥 MOST COMMON ERROR:")
print("-"*30)
print(f"{top_error[0]} ({top_error[1]} times)")

# ==============================
# SQL OPERATION ANALYSIS
# ==============================
# Count which clauses appear in the failing queries (case-insensitive).
lowered = [e["sql"].lower() for e in logs]
join_count = sum("join" in sql for sql in lowered)
where_count = sum("where" in sql for sql in lowered)
group_count = sum("group by" in sql for sql in lowered)
order_count = sum("order by" in sql for sql in lowered)

print("\n🧠 SQL OPERATION ANALYSIS:")
print("-"*30)
print(f"JOIN used in : {join_count} queries")
print(f"WHERE used in : {where_count} queries")
print(f"GROUP BY used in : {group_count} queries")
print(f"ORDER BY used in : {order_count} queries")

# ==============================
# SAMPLE ERRORS
# ==============================
print("\n🧪 SAMPLE ERROR CASES:")
print("-"*50)

for i, e in enumerate(logs[:3], 1):
    print(f"\nCase {i}:")
    print(f"Q : {e['question']}")
    print(f"SQL : {e['sql']}")
    print(f"Type: {e['error_type']}")

# ==============================
# FINAL INSIGHT
# ==============================
print("\n📌 FINAL INSIGHT:")
print("-"*30)

if top_error[0] == "wrong_column":
    print("⚠️ Model struggles with column selection (schema understanding issue).")
elif top_error[0] == "wrong_table":
    print("⚠️ Model struggles with correct table mapping.")
elif top_error[0] == "syntax_error":
    print("⚠️ Model generates invalid SQL syntax.")
else:
    print("⚠️ Mixed errors — needs general improvement.")

print("\n" + "="*50)
print("✅ DASHBOARD COMPLETE")
print("="*50)
99
+
scripts/evaluate.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import sqlite3
5
+ from contextlib import closing
6
+ from typing import Dict, List
7
+
8
+ import torch
9
+ from datasets import load_dataset
10
+ from peft import PeftModel
11
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
12
+ from trl import AutoModelForSeq2SeqLMWithValueHead
13
+
14
+ import sys
15
+
16
+ PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
17
+ sys.path.append(PROJECT_ROOT)
18
+ from src.execution_reward import execution_reward # noqa: E402
19
+
20
+
21
+ BASE_MODEL = os.environ.get("BASE_MODEL", "t5-small")
22
+ DB_ROOT = os.path.join(PROJECT_ROOT, "data", "database")
23
+
24
+ # Prefer RL best model if present; otherwise fall back.
25
+ RL_DIR = os.path.join(PROJECT_ROOT, "outputs", "rlhf_text2sql", "best_model")
26
+ if not os.path.isdir(RL_DIR):
27
+ RL_DIR = os.path.join(PROJECT_ROOT, "outputs", "rlhf_text2sql")
28
+
29
+ SPLIT = "train[:100]" # quick sanity check
30
+ MAX_NEW_TOKENS = 128
31
+
32
+ PREFIX = "translate English to SQL:"
33
+ MAX_SCHEMA_CHARS = 1500
34
+ MAX_INPUT_TOKENS = 512
35
+
36
+
37
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
38
+ device = "mps" if torch.backends.mps.is_available() else "cpu"
39
+ print("Using device:", device)
40
+
41
+
42
def get_db_path(db_id: str) -> str:
    """Return the on-disk path of the Spider SQLite database for *db_id*."""
    return os.path.join(DB_ROOT, db_id, db_id + ".sqlite")
44
+
45
+
46
+ _SCHEMA_CACHE: Dict[str, str] = {}
47
+
48
+
49
def get_db_schema_text(db_path: str) -> str:
    """Return a compact "table(col, ...) " schema summary for *db_path*, memoised.

    Falls back to an empty string on any database error; results (possibly
    truncated to MAX_SCHEMA_CHARS) are cached per path.
    """
    cached = _SCHEMA_CACHE.get(db_path)
    if cached is not None:
        return cached

    schema_text = ""
    try:
        with closing(sqlite3.connect(db_path)) as conn:
            cur = conn.cursor()
            table_rows = cur.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
            ).fetchall()
            pieces = []
            for (table_name,) in table_rows:
                cols = cur.execute(f'PRAGMA table_info("{table_name}")').fetchall()
                col_names = [c[1] for c in cols if c and isinstance(c[1], str)]
                pieces.append(f"{table_name}({', '.join(col_names)}) ")
            schema_text = "".join(pieces)
    except Exception:
        # Best-effort: a missing/corrupt DB simply yields no schema text.
        schema_text = ""

    schema_text = schema_text[:MAX_SCHEMA_CHARS]
    _SCHEMA_CACHE[db_path] = schema_text
    return schema_text
69
+
70
+
71
def encode_prompt(tokenizer, question: str, schema: str) -> torch.Tensor:
    """Tokenise "prefix + schema + question" into at most MAX_INPUT_TOKENS ids.

    Truncation policy: the schema segment absorbs all cuts first; the fixed
    parts (prefix/mid/suffix) are only trimmed if they alone exceed the budget.
    An EOS token, when the tokenizer defines one, is always appended last.
    """
    schema = (schema or "")[:MAX_SCHEMA_CHARS]
    head = f"{PREFIX}\n\nSchema:\n"
    mid = "\n\nQuestion:\n"
    tail = f"{question}\n\nSQL:"

    head_ids = tokenizer.encode(head, add_special_tokens=False)
    schema_ids = tokenizer.encode(schema, add_special_tokens=False)
    mid_ids = tokenizer.encode(mid, add_special_tokens=False)
    tail_ids = tokenizer.encode(tail, add_special_tokens=False)

    eos_id = tokenizer.eos_token_id
    budget = MAX_INPUT_TOKENS - (1 if eos_id is not None else 0)

    fixed_len = len(head_ids) + len(mid_ids) + len(tail_ids)
    if fixed_len > budget:
        # Even with no schema the prompt is too long: trim the question/SQL tail.
        keep = max(0, budget - (len(head_ids) + len(mid_ids)))
        tail_ids = tail_ids[:keep]
        fixed_len = len(head_ids) + len(mid_ids) + len(tail_ids)

    schema_budget = max(0, budget - fixed_len)
    schema_ids = schema_ids[:schema_budget]

    ids = (head_ids + schema_ids + mid_ids + tail_ids)[:budget]
    if eos_id is not None:
        ids = ids + [eos_id]

    return torch.tensor(ids, dtype=torch.long).to(device)
101
+
102
+
103
def load_model_and_tokenizer():
    """Load the best available checkpoint, in order of preference.

    1. PPO-saved value-head model at RL_DIR.
    2. RL_DIR interpreted as a LoRA adapter over BASE_MODEL.
    3. The SFT adapter under checkpoints/sft_adapter.
    """
    # Preferred: the PPO-saved value-head model.
    try:
        tok = AutoTokenizer.from_pretrained(RL_DIR)
        mdl = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(RL_DIR).to(device)
        return tok, mdl
    except Exception:
        pass

    # Fallback: treat RL_DIR as a LoRA adapter on top of the base model.
    tok = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
    try:
        base = PeftModel.from_pretrained(base, RL_DIR)
    except Exception:
        # Final fallback: use SFT adapter (if RL adapter not found)
        sft_dir = os.path.join(PROJECT_ROOT, "checkpoints", "sft_adapter")
        base = PeftModel.from_pretrained(base, sft_dir)
    return tok, base
124
+
125
+
126
def main() -> None:
    """Evaluate the loaded model on a Spider slice with execution-based metrics."""
    tokenizer, model = load_model_and_tokenizer()
    model.eval()

    ds = load_dataset("spider", split=SPLIT)

    correct = 0
    valid = 0

    for i, ex in enumerate(ds, start=1):
        db_path = get_db_path(ex["db_id"])
        prompt = encode_prompt(tokenizer, ex["question"], get_db_schema_text(db_path))

        with torch.no_grad():
            out = model.generate(
                input_ids=prompt.unsqueeze(0),
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False,
                num_beams=1,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        pred_sql = tokenizer.decode(out[0], skip_special_tokens=True)

        # Reward semantics (from src.execution_reward): > -1.0 means the SQL
        # executed; >= 1.0 means its results matched the gold query.
        reward = execution_reward(pred_sql, db_path, ex["query"])
        if reward > -1.0:
            valid += 1
        if reward >= 1.0:
            correct += 1

        if i % 25 == 0:
            print(f"Evaluated {i}/{len(ds)}")

    n = len(ds)
    print("\nRESULTS")
    print(f"examples: {n}")
    print(f"execution_accuracy: {correct/n:.3f}")
    print(f"valid_sql_rate: {valid/n:.3f}")


if __name__ == "__main__":
    main()
+ main()
scripts/plot_task2.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Render a two-panel error-diagnostic dashboard for Task 2.

Left panel: distribution of failure causes (horizontal bars).
Right panel: SQL clauses present in failed queries (vertical bars).
Output is saved to 'error_diagnostic_plot.png' and shown interactively.
"""
import matplotlib.pyplot as plt
import seaborn as sns

# ==========================================
# 1. EXTRACTED DATA FROM TERMINAL
# ==========================================
# Error Distribution Data (counts sum to 77, matching the plot title)
error_types = ['wrong_column', 'wrong_table', 'ambiguous_column', 'other']
error_counts = [61, 11, 4, 1]

# SQL Operation Analysis Data
sql_ops = ['WHERE', 'JOIN', 'ORDER BY', 'GROUP BY']
op_counts = [55, 36, 20, 14]

# ==========================================
# 2. SET UP THE DASHBOARD LAYOUT
# ==========================================
# Use a clean, modern aesthetic
sns.set_theme(style="whitegrid")
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

# ==========================================
# 3. PLOT 1: ERROR DISTRIBUTION (Horizontal Bar)
# ==========================================
# NOTE(review): seaborn >= 0.13 warns on `palette=` without `hue=`;
# works today, but may need `hue=` + `legend=False` on upgrade — confirm.
sns.barplot(x=error_counts, y=error_types, ax=ax1, palette="flare")
ax1.set_title('Primary Cause of Failure (Total: 77 Errors)', fontsize=14, pad=15, fontweight='bold')
ax1.set_xlabel('Number of Queries')
ax1.set_ylabel('')

# Add actual numbers next to the bars
for i, v in enumerate(error_counts):
    ax1.text(v + 1.5, i, f"{v}", color='#333333', va='center', fontweight='bold')

# ==========================================
# 4. PLOT 2: SQL OPERATIONS (Vertical Bar)
# ==========================================
sns.barplot(x=sql_ops, y=op_counts, ax=ax2, palette="crest")
ax2.set_title('Clauses Present in Failed Queries', fontsize=14, pad=15, fontweight='bold')
ax2.set_ylabel('Frequency')
ax2.set_xlabel('')

# Add actual numbers on top of the bars
for i, v in enumerate(op_counts):
    ax2.text(i, v + 1, str(v), color='#333333', ha='center', fontweight='bold')

# ==========================================
# 5. RENDER AND SAVE
# ==========================================
plt.suptitle('Text-to-SQL Error Diagnostic Dashboard', fontsize=18, fontweight='heavy', y=1.05)
sns.despine(left=True, bottom=True)  # Removes clunky borders
plt.tight_layout()

# Save the plot as a high-res image for your report!
plt.savefig('error_diagnostic_plot.png', dpi=300, bbox_inches='tight')
print("✅ Plot successfully saved as 'error_diagnostic_plot.png'")

# Display the plot
plt.show()
scripts/plot_task3.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Single bar chart: constraint satisfaction with vs. without constrained decoding."""
import matplotlib.pyplot as plt

# Satisfaction percentages observed in Task 3.
bar_labels = ["Without", "With"]
satisfaction_pct = [0, 88]

plt.figure()
plt.bar(bar_labels, satisfaction_pct)
plt.title("Constraint Satisfaction (Task 3)")
plt.ylabel("Percentage")

# Persist the figure, then show it interactively.
plt.savefig("task3_constraint.png")
plt.show()
scripts/plot_task3_plotly.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Interactive Plotly dashboard comparing FP32 vs. INT8 model variants.

Left panel: execution accuracy per variant. Right panel: grouped latency
profile (P50 / mean / P90). Output is written to
'task5_quantization_dashboard.html' and shown in the browser.
"""
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# ==========================================
# 1. YOUR DATA
# ==========================================
models = ['FP32 (Base)', 'INT8 Dynamic', 'INT8 Decoder-Only']

# Accuracy (multiplied by 100 for percentage)
accuracy = [36.0, 36.0, 38.0]

# Latency metrics (seconds per query)
lat_mean = [3.11, 1.65, 1.66]
lat_p50 = [2.94, 1.54, 1.56]
lat_p90 = [4.64, 2.44, 2.48]

# ==========================================
# 2. SET UP THE SIDE-BY-SIDE LAYOUT
# ==========================================
fig = make_subplots(
    rows=1, cols=2,
    subplot_titles=(
        "<b>Model Accuracy (Execution)</b>",
        "<b>Inference Latency Profile</b>"
    ),
    horizontal_spacing=0.1
)

# ==========================================
# 3. LEFT CHART: ACCURACY
# ==========================================
fig.add_trace(go.Bar(
    x=models,
    y=accuracy,
    name="Execution Accuracy",
    marker_color=['#94a3b8', '#38bdf8', '#10b981'],  # Gray, Blue, Green
    text=[f"{val:.1f}%" for val in accuracy],
    textposition='auto',
    textfont=dict(size=14, color='white', family="Arial Black"),
    showlegend=False  # single series; the legend belongs to the latency chart
), row=1, col=1)

# ==========================================
# 4. RIGHT CHART: LATENCY PROFILE
# ==========================================
# P50 Latency
fig.add_trace(go.Bar(
    x=models, y=lat_p50,
    name="Median (P50)",
    marker_color="#ece80a"  # Yellow
), row=1, col=2)

# Mean Latency
fig.add_trace(go.Bar(
    x=models, y=lat_mean,
    name="Mean Latency",
    marker_color="#3b4da9"  # Indigo blue
), row=1, col=2)

# P90 Latency
fig.add_trace(go.Bar(
    x=models, y=lat_p90,
    name="90th Percentile (P90)",
    marker_color="#d974e2"  # Magenta/pink
), row=1, col=2)

# ==========================================
# 5. APPLY ULTRA-MODERN STYLING
# ==========================================
fig.update_layout(
    title=dict(
        text="<b>Task 5: FP32 vs. INT8 Quantization Performance</b>",
        font=dict(size=22, color='#1e293b'),
        x=0.5
    ),
    plot_bgcolor='white',
    paper_bgcolor='white',
    barmode='group',  # group the three latency series side by side per model
    legend=dict(
        orientation="h",
        yanchor="bottom", y=1.05,
        xanchor="center", x=0.8,
        bgcolor='rgba(255,255,255,0.8)'
    ),
    font=dict(family="-apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif"),
    margin=dict(t=120, b=60, l=60, r=40)
)

# Style Left Axes
fig.update_yaxes(title_text="<b>Accuracy (%)</b>", range=[0, 45], gridcolor='#f1f5f9', row=1, col=1)
fig.update_xaxes(tickfont=dict(weight='bold'), row=1, col=1)

# Style Right Axes
fig.update_yaxes(title_text="<b>Seconds per Query</b>", gridcolor='#f1f5f9', row=1, col=2)
fig.update_xaxes(tickfont=dict(weight='bold'), row=1, col=2)

# ==========================================
# 6. RENDER AND SAVE
# ==========================================
html_file = "task5_quantization_dashboard.html"
fig.write_html(html_file)
print(f"✅ Interactive Plotly Dashboard saved to: {html_file}")
fig.show()
scripts/quantize_export.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import os
5
+ from pathlib import Path
6
+
7
+ import torch
8
+
9
+ from src.quantization_utils import (
10
+ load_bnb_quantized_model,
11
+ load_fp32_model,
12
+ quantize_dynamic_int8,
13
+ quantize_dynamic_int8_decoder_only,
14
+ save_quant_artifact,
15
+ )
16
+
17
+
18
def main() -> None:
    """CLI entry point: load the requested model variant and export an artifact.

    Supported modes: fp32 (no quantization), int8_dynamic /
    int8_decoder_dynamic (CPU torch dynamic quantization), int8_bnb /
    int4_bnb (bitsandbytes, CUDA).
    """
    parser = argparse.ArgumentParser(
        description="Export quantized Seq2Seq model artifacts for CPU inference."
    )
    parser.add_argument("--base_model", default=os.environ.get("BASE_MODEL", "Salesforce/codet5-base"))
    parser.add_argument("--adapter", default="", help="Optional LoRA adapter directory.")
    parser.add_argument("--out_dir", required=True, help="Output directory for artifact.")
    parser.add_argument(
        "--mode",
        required=True,
        choices=["fp32", "int8_dynamic", "int8_decoder_dynamic", "int8_bnb", "int4_bnb"],
    )
    parser.add_argument("--device", default="cpu", help="cpu|cuda (bnb requires cuda)")
    parser.add_argument("--local_only", action="store_true", help="Do not hit network; use HF cache only.")
    args = parser.parse_args()

    adapter = args.adapter.strip() or None
    out_dir = Path(args.out_dir)
    mode = args.mode

    if mode in ("fp32", "int8_dynamic", "int8_decoder_dynamic"):
        # Dynamic-quant modes always load on CPU; fp32 honors --device.
        load_device = args.device if mode == "fp32" else "cpu"
        tok, model = load_fp32_model(
            args.base_model,
            adapter_path=adapter,
            device=load_device,
            local_only=args.local_only,
        )
        if mode == "int8_dynamic":
            model = quantize_dynamic_int8(model)
        elif mode == "int8_decoder_dynamic":
            model = quantize_dynamic_int8_decoder_only(model)
    elif mode == "int8_bnb":
        # Note: saving bnb quantized weights in a portable way is non-trivial;
        # we still save state_dict for reference.
        tok, model = load_bnb_quantized_model(
            args.base_model,
            adapter_path=adapter,
            device=args.device,
            local_only=args.local_only,
            load_in_8bit=True,
        )
    else:  # int4_bnb (argparse `choices` guarantees this is exhaustive)
        tok, model = load_bnb_quantized_model(
            args.base_model,
            adapter_path=adapter,
            device=args.device,
            local_only=args.local_only,
            load_in_4bit=True,
        )

    save_quant_artifact(
        out_dir,
        mode=mode,
        base_model=args.base_model,
        adapter_path=adapter,
        tokenizer=tok,
        model=model,
    )


if __name__ == "__main__":
    # Export is inference-only; disable autograd globally.
    torch.set_grad_enabled(False)
    main()
86
+
scripts/quantized_infer_harness.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ import time
6
+ from pathlib import Path
7
+
8
+ from src.quantized_text2sql_engine import QuantizedText2SQLEngine
9
+
10
+
11
def main() -> None:
    """Run a batch of Spider dev questions through a quantized engine.

    Times the batch end-to-end, writes a JSON throughput report (with a
    10-item result sample) to --out, and echoes it to stdout.
    """
    parser = argparse.ArgumentParser(description="Production-style inference harness for quantized artifacts.")
    parser.add_argument("--artifact", required=True, help="Quant artifact dir from scripts/quantize_export.py")
    parser.add_argument("--num_samples", type=int, default=128)
    parser.add_argument("--out", default="results/task5_quant_infer.json")
    args = parser.parse_args()

    dev_path = Path(".") / "data" / "dev.json"
    dev_examples = json.loads(dev_path.read_text())[: args.num_samples]

    engine = QuantizedText2SQLEngine(args.artifact, device="cpu")
    question_db_pairs = [(ex["question"], ex["db_id"]) for ex in dev_examples]

    start = time.perf_counter()
    results = engine.ask_batch_execute(question_db_pairs)
    elapsed = time.perf_counter() - start

    report = {
        "n": len(results),
        "seconds": elapsed,
        # Guard the denominator so a zero-length batch cannot divide by zero.
        "qps": len(results) / max(elapsed, 1e-9),
        "artifact": args.artifact,
        "meta": engine.meta,
        "results": results[:10],  # sample
    }

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(report, indent=2))
    print(json.dumps(report, indent=2))


if __name__ == "__main__":
    main()
46
+
src/__pycache__/execution_reward.cpython-310.pyc ADDED
Binary file (11.2 kB). View file
 
src/__pycache__/quantization_utils.cpython-310.pyc ADDED
Binary file (6.23 kB). View file
 
src/__pycache__/quantized_text2sql_engine.cpython-310.pyc ADDED
Binary file (8.95 kB). View file
 
src/__pycache__/schema_encoder.cpython-310.pyc ADDED
Binary file (1.75 kB). View file
 
src/__pycache__/schema_utils.cpython-310.pyc ADDED
Binary file (3.64 kB). View file
 
src/__pycache__/sql_validator.cpython-310.pyc ADDED
Binary file (5.29 kB). View file
 
src/__pycache__/text2sql_engine.cpython-310.pyc ADDED
Binary file (8.34 kB). View file
 
src/ask.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TERMINAL CHAT WITH DATABASE
3
+ Run:
4
+ python src/ask.py chinook_1
5
+ """
6
+
7
+ import sys
8
+ from text2sql_engine import get_engine
9
+
10
+
11
+ # -------------------------------
12
+ # Pretty table printer
13
+ # -------------------------------
14
def print_table(cols, rows, limit=20):
    """Pretty-print query results as an aligned ASCII table.

    Shows at most `limit` rows. Column widths adapt to the widest value
    among the displayed rows (minimum 12 characters). Ragged rows are
    tolerated: extra cells are dropped, missing cells are left blank.
    """
    if not rows or not cols:
        print("No results\n")
        return

    cols = [str(c) for c in cols]
    n_cols = len(cols)
    widths = [max(len(c), 12) for c in cols]

    shown = rows[:limit]
    for r in shown:
        # BUG FIX: clamp to n_cols — a row longer than the header used to
        # raise IndexError on widths[i].
        for i, val in enumerate(r[:n_cols]):
            widths[i] = max(widths[i], len(str(val)))

    header = " | ".join(cols[i].ljust(widths[i]) for i in range(n_cols))
    print("\n" + header)
    print("-" * len(header))

    for r in shown:
        # Pad short rows with empty strings so every printed line has n_cols cells.
        cells = [str(r[i]) if i < len(r) else "" for i in range(n_cols)]
        print(" | ".join(cells[i].ljust(widths[i]) for i in range(n_cols)))

    if len(rows) > limit:
        print(f"\n... showing first {limit} rows of {len(rows)}")

    print()
+
39
+
40
+ # -------------------------------
41
+ # Main loop
42
+ # -------------------------------
43
def main():
    """Interactive REPL: translate English questions to SQL for one database.

    Usage: python src/ask.py <db_id>. Type 'exit'/'quit' (or Ctrl-C) to leave.
    """
    if len(sys.argv) < 2:
        print("Usage: python src/ask.py <db_id>")
        return

    db_id = sys.argv[1].strip()

    print("Loading model... (first time takes 20-40s)")
    engine = get_engine()

    print(f"\nConnected to database: {db_id}")
    print("Type 'exit' to quit\n")

    while True:
        try:
            question = input("Ask> ").strip()
            if not question:
                continue
            if question.lower() in ["exit", "quit"]:
                break

            result = engine.ask(question, db_id)
            if result is None:
                print("Model returned no output\n")
                continue

            print("\nGenerated SQL:")
            print(result.get("sql", "<no sql>"))

            if result.get("error"):
                print("\nSQL Error:")
                print(result["error"])
            else:
                print_table(result.get("columns", []), result.get("rows", []))

        except KeyboardInterrupt:
            break
        except Exception as exc:
            # Keep the REPL alive on any single-question failure.
            print("\nRuntime error:", exc, "\n")

    print("\nBye!")


if __name__ == "__main__":
    main()
src/component_analysis.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import sqlite3
3
+ import torch
4
+ import re
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ from pathlib import Path
8
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
9
+ from peft import PeftModel
10
+
11
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
12
+ DB_ROOT = PROJECT_ROOT / "data" / "database"
13
+
14
+ # -------------------------------
15
+ # Extract SQL components
16
+ # -------------------------------
17
def extract_components(sql):
    """Return a flag dict marking which major SQL clauses appear in `sql`.

    Matching is case-insensitive substring search; AND/OR are matched with
    surrounding spaces to reduce accidental in-word hits.
    """
    lowered = sql.lower()

    def has(fragment):
        return fragment in lowered

    return {
        "select": has("select"),
        "where": has("where"),
        "group": has("group by"),
        "order": has("order by"),
        "and_or": has(" and ") or has(" or "),
        "join": has("join"),
    }
27
+
28
+ # -------------------------------
29
+ # Fallback Difficulty Estimator
30
+ # -------------------------------
31
def estimate_difficulty(sql):
    """Heuristically bucket a SQL query into easy/medium/hard/extra.

    Fallback used when a dataset example has no 'difficulty' field.
    Buckets: set operations or >2 joins -> extra; 2 joins or GROUP BY with
    conditions -> hard; 1 join or GROUP BY / ORDER BY -> medium; else easy.
    """
    sql = sql.lower()
    # BUG FIX: the previous substring counts matched inside words —
    # count("and") hit "brand", count("or") hit "order"/"score" — inflating
    # the condition count and mis-bucketing queries. Match whole words only.
    joins = len(re.findall(r"\bjoin\b", sql))
    conditions = len(re.findall(r"\b(?:and|or)\b", sql))

    if "intersect" in sql or "except" in sql or "union" in sql or joins > 2:
        return "extra"
    elif joins == 2 or ("group by" in sql and conditions > 0):
        return "hard"
    elif joins == 1 or "group by" in sql or "order by" in sql:
        return "medium"
    else:
        return "easy"
45
+
46
+ # -------------------------------
47
+ # Load schema
48
+ # -------------------------------
49
def load_schema(db_path):
    """Read a SQLite file and return its schema as 'table(col1, col2)' lines.

    Byte-decoding errors in DB text are ignored so malformed text columns
    don't crash schema extraction.
    """
    conn = sqlite3.connect(db_path)
    conn.text_factory = lambda b: b.decode(errors='ignore')
    try:
        cursor = conn.cursor()
        tables = cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()

        schema = ""
        for (table,) in tables:
            # BUG FIX: quote the identifier — unquoted names with spaces or
            # reserved keywords made the PRAGMA a syntax error.
            cols = cursor.execute(f'PRAGMA table_info("{table}");').fetchall()
            col_names = [c[1] for c in cols]  # row layout: (cid, name, type, ...)
            schema += f"{table}({', '.join(col_names)})\n"
    finally:
        # BUG FIX: close the connection even when an execute raises.
        conn.close()
    return schema
66
+
67
+ # -------------------------------
68
+ # Prompt
69
+ # -------------------------------
70
def build_prompt(question, schema):
    """Assemble the schema-conditioned translation prompt fed to the model."""
    segments = [
        "Database Schema:",
        schema,
        "",
        "Translate English to SQL:",
        question,
        "SQL:",
        "",  # trailing newline after "SQL:" so generation starts on a new line
    ]
    return "\n".join(segments)
78
+
79
+ # -------------------------------
80
+ # Main
81
+ # -------------------------------
82
def main():
    """Evaluate component-level accuracy of the RL-tuned model on Spider dev.

    Loads the base CodeT5 model, merges the LoRA adapter, generates SQL for
    up to 1000 dev examples, then tracks (a) overall exact-string accuracy
    per difficulty bucket and (b) per-clause (SELECT/WHERE/GROUP/...)
    presence-match accuracy per difficulty. Saves a grouped bar chart to
    component_by_difficulty_plot.png and prints a summary table.
    """
    adapter = "checkpoints/rl_step_1800"
    base_model = "Salesforce/codet5-base"

    # Prefer Apple MPS, then CUDA, then CPU.
    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")

    print("Loading tokenizer and models...")
    tokenizer = AutoTokenizer.from_pretrained(adapter)
    base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
    model = PeftModel.from_pretrained(base, adapter).to(device)
    # Fold LoRA weights into the base model for faster inference.
    model = model.merge_and_unload()
    model.eval()

    dev_json = PROJECT_ROOT / "data" / "dev.json"

    with open(dev_json) as f:
        dev = json.load(f)[:1000]  # Adjust number to test more/less

    components_list = ["select", "where", "group", "order", "and_or", "join"]
    difficulties_list = ["easy", "medium", "hard", "extra"]

    # Nested dictionary for components:
    # stats[component][difficulty] -> {"correct": hits, "total": gold occurrences}
    stats = {
        comp: {diff: {"correct": 0, "total": 0} for diff in difficulties_list}
        for comp in components_list
    }

    # 🚀 NEW: Trackers for OVERALL accuracy by difficulty
    overall_correct = {diff: 0 for diff in difficulties_list}
    overall_total = {diff: 0 for diff in difficulties_list}

    print(f"\nRunning grouped evaluation on {len(dev)} examples...\n")

    for i, ex in enumerate(dev, 1):
        question = ex["question"]
        gold_sql = ex["query"]
        db_id = ex["db_id"]

        # Determine difficulty: dataset label if present, else heuristic;
        # unknown labels fall back to "medium".
        difficulty = ex.get("difficulty", estimate_difficulty(gold_sql))
        if difficulty not in difficulties_list:
            difficulty = "medium"

        db_path = DB_ROOT / db_id / f"{db_id}.sqlite"
        schema = load_schema(db_path)
        prompt = build_prompt(question, schema)

        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        # Deterministic beam search; large max_new_tokens to avoid truncation.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1000,
                num_beams=4,
                do_sample=False
            )

        pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # If the model echoed the prompt, keep only the text after "SQL:".
        if "SQL:" in pred_sql:
            pred_sql = pred_sql.split("SQL:")[-1]

        # --- 1. Update Overall Accuracy Trackers ---
        overall_total[difficulty] += 1
        # Simple string match for quick overall accuracy
        # (stricter than execution accuracy: equivalent SQL counts as wrong).
        if pred_sql.strip().lower() == gold_sql.strip().lower():
            overall_correct[difficulty] += 1

        # --- 2. Update Component Stats ---
        pred_comp = extract_components(pred_sql)
        gold_comp = extract_components(gold_sql)

        for comp in components_list:
            if gold_comp[comp]:  # If the gold SQL required this component
                stats[comp][difficulty]["total"] += 1
                if pred_comp[comp]:  # If the model successfully generated it
                    stats[comp][difficulty]["correct"] += 1

        if i % 20 == 0:
            print(f"Processed {i}/{len(dev)}")

    # -------------------------------
    # Plotting (Grouped Bar Chart)
    # -------------------------------
    x = np.arange(len(components_list))
    width = 0.2  # four difficulty bars per component group

    def get_acc(diff):
        # Percentage accuracy per component for one difficulty (0 when unseen).
        return [
            (stats[comp][diff]["correct"] / stats[comp][diff]["total"] * 100) if stats[comp][diff]["total"] > 0 else 0
            for comp in components_list
        ]

    acc_easy = get_acc("easy")
    acc_medium = get_acc("medium")
    acc_hard = get_acc("hard")
    acc_extra = get_acc("extra")

    fig, ax = plt.subplots(figsize=(14, 7))

    # Offset each difficulty's bars around the group center at x.
    bars1 = ax.bar(x - 1.5 * width, acc_easy, width, label='Easy', color='#2ecc71')
    bars2 = ax.bar(x - 0.5 * width, acc_medium, width, label='Medium', color='#f1c40f')
    bars3 = ax.bar(x + 0.5 * width, acc_hard, width, label='Hard', color='#e67e22')
    bars4 = ax.bar(x + 1.5 * width, acc_extra, width, label='Extra', color='#e74c3c')

    ax.set_ylabel('Accuracy (%)', fontsize=12)
    ax.set_title('SQL Component Match Accuracy by Difficulty Level', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels([c.upper() for c in components_list], fontsize=11)
    ax.legend(title="Query Difficulty")
    ax.set_ylim(0, 115)  # headroom above 100% for the rotated labels

    def autolabel(rects):
        # Annotate each non-zero bar with its rounded percentage.
        for rect in rects:
            height = rect.get_height()
            if height > 0:
                ax.annotate(f'{int(height)}%',
                            xy=(rect.get_x() + rect.get_width() / 2, height),
                            xytext=(0, 3),
                            textcoords="offset points",
                            ha='center', va='bottom', fontsize=8, rotation=90)

    autolabel(bars1)
    autolabel(bars2)
    autolabel(bars3)
    autolabel(bars4)

    ax.yaxis.grid(True, linestyle='--', alpha=0.7)
    plt.tight_layout()
    plt.savefig("component_by_difficulty_plot.png", dpi=300)

    # -------------------------------
    # 🚀 Terminal Printout
    # -------------------------------
    print("\n✅ Saved merged plot -> component_by_difficulty_plot.png")

    print("\n========================================")
    print("🏆 OVERALL AVERAGE ACCURACY BY DIFFICULTY")
    print("========================================")
    for diff in difficulties_list:
        if overall_total[diff] > 0:
            avg = round((overall_correct[diff] / overall_total[diff]) * 100, 2)
            print(f"{diff.capitalize():<8}: {avg:>5}% ({overall_correct[diff]}/{overall_total[diff]} queries)")
        else:
            print(f"{diff.capitalize():<8}: N/A (0 queries)")
    print("========================================\n")

if __name__ == "__main__":
    main()
src/constrained_decoding.py ADDED
@@ -0,0 +1,1058 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from __future__ import annotations
2
+
3
+ # import re
4
+ # import threading
5
+ # from dataclasses import dataclass
6
+ # from typing import Dict, Iterable, List, Optional, Sequence, Set
7
+
8
+ # import torch
9
+ # from transformers.generation.logits_process import LogitsProcessor
10
+
11
+ # from schema_constraints import ConstraintGraph, build_constraint_graph
12
+
13
+
14
+ # def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
15
+ # s = re.sub(r"\s+", " ", prefix_text.lower())
16
+ # last_from = s.rfind(" from ")
17
+ # last_join = s.rfind(" join ")
18
+ # last_select = s.rfind(" select ")
19
+ # last_where = s.rfind(" where ")
20
+ # last_on = s.rfind(" on ")
21
+ # last_group = s.rfind(" group by ")
22
+ # last_order = s.rfind(" order by ")
23
+ # last_having = s.rfind(" having ")
24
+
25
+ # last_table_kw = max(last_from, last_join)
26
+ # last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
27
+
28
+ # if last_table_kw < 0 and last_col_kw < 0:
29
+ # return None
30
+ # if last_table_kw > last_col_kw:
31
+ # return "table"
32
+ # if last_col_kw > last_table_kw:
33
+ # return "column"
34
+ # return None
35
+
36
+
37
+ # class _TrieNode:
38
+ # __slots__ = ("children", "terminal")
39
+
40
+ # def __init__(self) -> None:
41
+ # self.children: Dict[int, _TrieNode] = {}
42
+ # self.terminal: bool = False
43
+
44
+ # def insert(self, token_ids: Sequence[int]) -> None:
45
+ # node: _TrieNode = self
46
+ # for tid in token_ids:
47
+ # tid_i = int(tid)
48
+ # nxt = node.children.get(tid_i)
49
+ # if nxt is None:
50
+ # nxt = _TrieNode()
51
+ # node.children[tid_i] = nxt
52
+ # node = nxt
53
+ # node.terminal = True
54
+
55
+ # def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
56
+ # node: _TrieNode = self
57
+ # for tid in prefix:
58
+ # node = node.children.get(int(tid)) # type: ignore[assignment]
59
+ # if node is None:
60
+ # return None
61
+ # return node
62
+
63
+
64
+ # def _encode_identifier(tokenizer, name: str) -> List[int]:
65
+ # # Leading space encourages word-start markers (e.g. "Ġ" in RoBERTa BPE).
66
+ # return tokenizer.encode(" " + name, add_special_tokens=False)
67
+
68
+
69
+ # def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
70
+ # trie = _TrieNode()
71
+ # for n in names:
72
+ # if not n:
73
+ # continue
74
+ # try:
75
+ # ids = _encode_identifier(tokenizer, n)
76
+ # except Exception:
77
+ # continue
78
+ # if ids:
79
+ # trie.insert(ids)
80
+ # return trie
81
+
82
+
83
+ # def _allow_always_token_ids(tokenizer) -> torch.Tensor:
84
+ # # Allow common delimiters so the model can end an identifier.
85
+ # toks = [",", ")", "(", "\n", ".", ";"]
86
+ # ids: Set[int] = set()
87
+ # for t in toks:
88
+ # try:
89
+ # for tid in tokenizer.encode(t, add_special_tokens=False):
90
+ # ids.add(int(tid))
91
+ # except Exception:
92
+ # continue
93
+ # return torch.tensor(sorted(ids), dtype=torch.long)
94
+
95
+
96
+ # @dataclass
97
+ # class _PerDbTokenSets:
98
+ # fp: str
99
+ # table_trie: _TrieNode
100
+ # column_trie: _TrieNode
101
+ # allow_always: torch.Tensor
102
+
103
+
104
+ # _DB_TOKENSET_LOCK = threading.Lock()
105
+ # _DB_TOKENSETS: Dict[str, _PerDbTokenSets] = {}
106
+
107
+
108
+ # def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
109
+ # with _DB_TOKENSET_LOCK:
110
+ # cached = _DB_TOKENSETS.get(graph.db_path)
111
+ # if cached is not None and cached.fp == graph.fingerprint:
112
+ # return cached
113
+
114
+ # out = _PerDbTokenSets(
115
+ # fp=graph.fingerprint,
116
+ # table_trie=_build_trie(tokenizer, graph.tables),
117
+ # column_trie=_build_trie(tokenizer, graph.all_columns),
118
+ # allow_always=_allow_always_token_ids(tokenizer),
119
+ # )
120
+ # with _DB_TOKENSET_LOCK:
121
+ # _DB_TOKENSETS[graph.db_path] = out
122
+ # return out
123
+
124
+
125
+ # class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
126
+ # """
127
+ # Schema-aware constrained decoding per item in the generation batch.
128
+ # Uses a tokenizer-based trie so multi-token identifiers can be constrained.
129
+ # """
130
+
131
+ # def __init__(self, tokenizer, db_paths: Sequence[str], *, max_prefix_tokens: int = 48):
132
+ # self.tokenizer = tokenizer
133
+ # self.db_paths = list(db_paths)
134
+ # self.max_prefix_tokens = int(max_prefix_tokens)
135
+
136
+ # self._graphs = [build_constraint_graph(p) for p in self.db_paths]
137
+ # self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
138
+
139
+ # def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
140
+ # if input_ids.dim() != 2 or scores.dim() != 2:
141
+ # return scores
142
+
143
+ # batch = input_ids.size(0)
144
+ # if batch != len(self._graphs):
145
+ # return scores
146
+
147
+ # for i in range(batch):
148
+ # tail_ids = input_ids[i, -self.max_prefix_tokens :].tolist()
149
+ # prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
150
+ # expected = _infer_expected_identifier(prefix_text)
151
+ # if expected is None:
152
+ # continue
153
+
154
+ # if expected == "table":
155
+ # m = re.search(r"(?:from|join)\s+([A-Za-z_][A-Za-z0-9_]*)$", prefix_text, flags=re.I)
156
+ # partial = m.group(1) if m else None
157
+ # if partial is None and not re.search(r"(?:from|join)\s*$", prefix_text, flags=re.I):
158
+ # continue
159
+ # trie = self._token_sets[i].table_trie
160
+ # else:
161
+ # m = re.search(
162
+ # r"(?:select|where|on|group by|order by|having)\s+([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?)$",
163
+ # prefix_text,
164
+ # flags=re.I,
165
+ # )
166
+ # partial = m.group(1) if m else None
167
+ # if partial is None and not re.search(
168
+ # r"(?:select|where|on|group by|order by|having)\s*$", prefix_text, flags=re.I
169
+ # ):
170
+ # continue
171
+ # trie = self._token_sets[i].column_trie
172
+
173
+ # if not partial:
174
+ # prefix_token_ids: List[int] = []
175
+ # else:
176
+ # try:
177
+ # prefix_token_ids = _encode_identifier(self.tokenizer, partial)
178
+ # except Exception:
179
+ # continue
180
+
181
+ # node = trie.walk(prefix_token_ids)
182
+ # if node is None or node.terminal:
183
+ # continue
184
+
185
+ # allowed_next = sorted(node.children.keys())
186
+ # if not allowed_next:
187
+ # continue
188
+
189
+ # allowed_next_t = torch.tensor(allowed_next, dtype=torch.long, device=scores.device)
190
+ # allow_always = self._token_sets[i].allow_always.to(scores.device)
191
+ # keep = torch.cat([allowed_next_t, allow_always]) if allow_always.numel() else allowed_next_t
192
+
193
+ # kept_scores = scores[i, keep].clone()
194
+ # scores[i, :] = -float("inf")
195
+ # scores[i, keep] = kept_scores
196
+
197
+ # return scores
198
+
199
+
200
+ # # Backwards-compatible names used elsewhere in the repo.
201
+ # class SchemaConstraintGraph:
202
+ # def __init__(self, db_path: str):
203
+ # self._graph = build_constraint_graph(db_path)
204
+ # self.tables = sorted(self._graph.tables)
205
+ # self.columns = sorted(self._graph.all_columns)
206
+
207
+
208
+ # class SchemaConstrainedLogitsProcessor(LogitsProcessor):
209
+ # def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
210
+ # self._proc = BatchSchemaConstrainedLogitsProcessor(tokenizer, [schema_graph._graph.db_path])
211
+
212
+ # def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
213
+ # return self._proc(input_ids, scores)
214
+
215
+
216
+
217
+
218
+ # from __future__ import annotations
219
+
220
+ # import re
221
+ # import threading
222
+ # from dataclasses import dataclass
223
+ # from typing import Dict, Iterable, List, Optional, Sequence, Set
224
+
225
+ # import torch
226
+ # from transformers.generation.logits_process import LogitsProcessor
227
+
228
+ # from schema_constraints import ConstraintGraph, build_constraint_graph
229
+
230
+
231
+ # # =========================================================
232
+ # # 🔍 IDENTIFIER TYPE DETECTION
233
+ # # =========================================================
234
+ # def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
235
+ # s = re.sub(r"\s+", " ", prefix_text.lower())
236
+
237
+ # last_from = s.rfind(" from ")
238
+ # last_join = s.rfind(" join ")
239
+ # last_select = s.rfind(" select ")
240
+ # last_where = s.rfind(" where ")
241
+ # last_on = s.rfind(" on ")
242
+ # last_group = s.rfind(" group by ")
243
+ # last_order = s.rfind(" order by ")
244
+ # last_having = s.rfind(" having ")
245
+
246
+ # last_table_kw = max(last_from, last_join)
247
+ # last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
248
+
249
+ # if last_table_kw < 0 and last_col_kw < 0:
250
+ # return None
251
+ # if last_table_kw > last_col_kw:
252
+ # return "table"
253
+ # if last_col_kw > last_table_kw:
254
+ # return "column"
255
+ # return None
256
+
257
+
258
+ # # =========================================================
259
+ # # 🌳 TRIE STRUCTURE
260
+ # # =========================================================
261
+ # class _TrieNode:
262
+ # __slots__ = ("children", "terminal")
263
+
264
+ # def __init__(self) -> None:
265
+ # self.children: Dict[int, _TrieNode] = {}
266
+ # self.terminal: bool = False
267
+
268
+ # def insert(self, token_ids: Sequence[int]) -> None:
269
+ # node = self
270
+ # for tid in token_ids:
271
+ # tid = int(tid)
272
+ # if tid not in node.children:
273
+ # node.children[tid] = _TrieNode()
274
+ # node = node.children[tid]
275
+ # node.terminal = True
276
+
277
+ # def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
278
+ # node = self
279
+ # for tid in prefix:
280
+ # node = node.children.get(int(tid))
281
+ # if node is None:
282
+ # return None
283
+ # return node
284
+
285
+
286
+ # # =========================================================
287
+ # # 🔤 TOKEN ENCODING
288
+ # # =========================================================
289
+ # def _encode_identifier(tokenizer, name: str) -> List[int]:
290
+ # return tokenizer.encode(" " + name, add_special_tokens=False)
291
+
292
+
293
+ # def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
294
+ # trie = _TrieNode()
295
+ # for name in names:
296
+ # try:
297
+ # ids = _encode_identifier(tokenizer, name)
298
+ # if ids:
299
+ # trie.insert(ids)
300
+ # except Exception:
301
+ # continue
302
+ # return trie
303
+
304
+
305
+ # def _allow_always_token_ids(tokenizer) -> torch.Tensor:
306
+ # tokens = [",", ")", "(", ".", ";", "\n"]
307
+ # ids: Set[int] = set()
308
+
309
+ # for t in tokens:
310
+ # try:
311
+ # ids.update(tokenizer.encode(t, add_special_tokens=False))
312
+ # except:
313
+ # pass
314
+
315
+ # return torch.tensor(sorted(ids), dtype=torch.long)
316
+
317
+
318
+ # # =========================================================
319
+ # # 📦 PER-DB CACHE
320
+ # # =========================================================
321
+ # @dataclass
322
+ # class _PerDbTokenSets:
323
+ # fp: str
324
+ # table_trie: _TrieNode
325
+ # column_trie: _TrieNode
326
+ # allow_always: torch.Tensor
327
+
328
+
329
+ # _DB_CACHE: Dict[str, _PerDbTokenSets] = {}
330
+ # _DB_LOCK = threading.Lock()
331
+
332
+
333
+ # def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
334
+ # with _DB_LOCK:
335
+ # cached = _DB_CACHE.get(graph.db_path)
336
+ # if cached and cached.fp == graph.fingerprint:
337
+ # return cached
338
+
339
+ # obj = _PerDbTokenSets(
340
+ # fp=graph.fingerprint,
341
+ # table_trie=_build_trie(tokenizer, graph.tables),
342
+ # column_trie=_build_trie(tokenizer, graph.all_columns),
343
+ # allow_always=_allow_always_token_ids(tokenizer),
344
+ # )
345
+
346
+ # with _DB_LOCK:
347
+ # _DB_CACHE[graph.db_path] = obj
348
+
349
+ # return obj
350
+
351
+
352
+ # # =========================================================
353
+ # # 🚀 MAIN LOGITS PROCESSOR
354
+ # # =========================================================
355
+ # class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
356
+ # def __init__(self, tokenizer, db_paths: Sequence[str], max_prefix_tokens: int = 48):
357
+ # self.tokenizer = tokenizer
358
+ # self.db_paths = list(db_paths)
359
+ # self.max_prefix_tokens = max_prefix_tokens
360
+
361
+ # self._graphs = [build_constraint_graph(p) for p in db_paths]
362
+ # self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
363
+
364
+ # # 📊 Metrics (IMPORTANT FOR REPORT)
365
+ # self.total_steps = 0
366
+ # self.constrained_steps = 0
367
+
368
+ # def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
369
+ # batch = input_ids.size(0)
370
+
371
+ # for i in range(batch):
372
+ # self.total_steps += 1
373
+
374
+ # tail_ids = input_ids[i, -self.max_prefix_tokens:].tolist()
375
+ # prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
376
+
377
+ # expected = _infer_expected_identifier(prefix_text)
378
+ # if expected is None:
379
+ # continue
380
+
381
+ # self.constrained_steps += 1
382
+
383
+ # # =========================
384
+ # # SELECT TRIE
385
+ # # =========================
386
+ # if expected == "table":
387
+ # trie = self._token_sets[i].table_trie
388
+ # else:
389
+ # trie = self._token_sets[i].column_trie
390
+
391
+ # # =========================
392
+ # # PARTIAL TOKEN MATCH
393
+ # # =========================
394
+ # match = re.search(r"([A-Za-z_][A-Za-z0-9_]*)$", prefix_text)
395
+ # partial = match.group(1) if match else ""
396
+
397
+ # try:
398
+ # prefix_ids = _encode_identifier(self.tokenizer, partial) if partial else []
399
+ # except:
400
+ # continue
401
+
402
+ # node = trie.walk(prefix_ids)
403
+ # if node is None or node.terminal:
404
+ # continue
405
+
406
+ # allowed_next = list(node.children.keys())
407
+ # if not allowed_next:
408
+ # continue
409
+
410
+ # allowed_next = torch.tensor(allowed_next, device=scores.device)
411
+ # allow_always = self._token_sets[i].allow_always.to(scores.device)
412
+
413
+ # keep = torch.cat([allowed_next, allow_always])
414
+
415
+ # kept_scores = scores[i, keep].clone()
416
+ # scores[i, :] = -float("inf")
417
+ # scores[i, keep] = kept_scores
418
+
419
+ # return scores
420
+
421
+ # # =========================================================
422
+ # # 📊 METRICS FOR REPORT
423
+ # # =========================================================
424
+ # def get_constraint_stats(self):
425
+ # if self.total_steps == 0:
426
+ # return 0
427
+ # return self.constrained_steps / self.total_steps
428
+
429
+
430
+ # # =========================================================
431
+ # # 🔁 BACKWARD COMPATIBILITY
432
+ # # =========================================================
433
+ # class SchemaConstraintGraph:
434
+ # def __init__(self, db_path: str):
435
+ # self._graph = build_constraint_graph(db_path)
436
+ # self.tables = sorted(self._graph.tables)
437
+ # self.columns = sorted(self._graph.all_columns)
438
+
439
+
440
+ # class SchemaConstrainedLogitsProcessor(LogitsProcessor):
441
+ # def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
442
+ # self.proc = BatchSchemaConstrainedLogitsProcessor(
443
+ # tokenizer, [schema_graph._graph.db_path]
444
+ # )
445
+
446
+ # def __call__(self, input_ids, scores):
447
+ # return self.proc(input_ids, scores)
448
+
449
+
450
+
451
+
452
+
453
+
454
+ # from __future__ import annotations
455
+
456
+ # import re
457
+ # import threading
458
+ # from dataclasses import dataclass
459
+ # from typing import Dict, Iterable, List, Optional, Sequence, Set
460
+
461
+ # import torch
462
+ # from transformers.generation.logits_process import LogitsProcessor
463
+
464
+ # from schema_constraints import ConstraintGraph, build_constraint_graph
465
+
466
+
467
+ # def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
468
+ # s = re.sub(r"\s+", " ", prefix_text.lower())
469
+ # last_from = s.rfind(" from ")
470
+ # last_join = s.rfind(" join ")
471
+ # last_select = s.rfind(" select ")
472
+ # last_where = s.rfind(" where ")
473
+ # last_on = s.rfind(" on ")
474
+ # last_group = s.rfind(" group by ")
475
+ # last_order = s.rfind(" order by ")
476
+ # last_having = s.rfind(" having ")
477
+
478
+ # last_table_kw = max(last_from, last_join)
479
+ # last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
480
+
481
+ # if last_table_kw < 0 and last_col_kw < 0:
482
+ # return None
483
+ # if last_table_kw > last_col_kw:
484
+ # return "table"
485
+ # if last_col_kw > last_table_kw:
486
+ # return "column"
487
+ # return None
488
+
489
+
490
+ # class _TrieNode:
491
+ # __slots__ = ("children", "terminal")
492
+
493
+ # def __init__(self) -> None:
494
+ # self.children: Dict[int, _TrieNode] = {}
495
+ # self.terminal: bool = False
496
+
497
+ # def insert(self, token_ids: Sequence[int]) -> None:
498
+ # node: _TrieNode = self
499
+ # for tid in token_ids:
500
+ # tid_i = int(tid)
501
+ # nxt = node.children.get(tid_i)
502
+ # if nxt is None:
503
+ # nxt = _TrieNode()
504
+ # node.children[tid_i] = nxt
505
+ # node = nxt
506
+ # node.terminal = True
507
+
508
+ # def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
509
+ # node: _TrieNode = self
510
+ # for tid in prefix:
511
+ # node = node.children.get(int(tid)) # type: ignore[assignment]
512
+ # if node is None:
513
+ # return None
514
+ # return node
515
+
516
+
517
+ # def _encode_identifier(tokenizer, name: str) -> List[int]:
518
+ # # Leading space encourages word-start markers (e.g. "Ġ" in RoBERTa BPE).
519
+ # return tokenizer.encode(" " + name, add_special_tokens=False)
520
+
521
+
522
+ # def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
523
+ # trie = _TrieNode()
524
+ # for n in names:
525
+ # if not n:
526
+ # continue
527
+ # try:
528
+ # ids = _encode_identifier(tokenizer, n)
529
+ # except Exception:
530
+ # continue
531
+ # if ids:
532
+ # trie.insert(ids)
533
+ # return trie
534
+
535
+
536
+ # def _allow_always_token_ids(tokenizer) -> torch.Tensor:
537
+ # # Allow common delimiters so the model can end an identifier.
538
+ # toks = [",", ")", "(", "\n", ".", ";"]
539
+ # ids: Set[int] = set()
540
+ # for t in toks:
541
+ # try:
542
+ # for tid in tokenizer.encode(t, add_special_tokens=False):
543
+ # ids.add(int(tid))
544
+ # except Exception:
545
+ # continue
546
+ # return torch.tensor(sorted(ids), dtype=torch.long)
547
+
548
+
549
+ # @dataclass
550
+ # class _PerDbTokenSets:
551
+ # fp: str
552
+ # table_trie: _TrieNode
553
+ # column_trie: _TrieNode
554
+ # allow_always: torch.Tensor
555
+
556
+
557
+ # _DB_TOKENSET_LOCK = threading.Lock()
558
+ # _DB_TOKENSETS: Dict[str, _PerDbTokenSets] = {}
559
+
560
+
561
+ # def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
562
+ # with _DB_TOKENSET_LOCK:
563
+ # cached = _DB_TOKENSETS.get(graph.db_path)
564
+ # if cached is not None and cached.fp == graph.fingerprint:
565
+ # return cached
566
+
567
+ # out = _PerDbTokenSets(
568
+ # fp=graph.fingerprint,
569
+ # table_trie=_build_trie(tokenizer, graph.tables),
570
+ # column_trie=_build_trie(tokenizer, graph.all_columns),
571
+ # allow_always=_allow_always_token_ids(tokenizer),
572
+ # )
573
+ # with _DB_TOKENSET_LOCK:
574
+ # _DB_TOKENSETS[graph.db_path] = out
575
+ # return out
576
+
577
+
578
+ # class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
579
+ # """
580
+ # Schema-aware constrained decoding per item in the generation batch.
581
+ # Uses a tokenizer-based trie so multi-token identifiers can be constrained.
582
+ # """
583
+
584
+ # def __init__(self, tokenizer, db_paths: Sequence[str], *, max_prefix_tokens: int = 48):
585
+ # self.tokenizer = tokenizer
586
+ # self.db_paths = list(db_paths)
587
+ # self.max_prefix_tokens = int(max_prefix_tokens)
588
+
589
+ # self._graphs = [build_constraint_graph(p) for p in self.db_paths]
590
+ # self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
591
+
592
+ # def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
593
+ # if input_ids.dim() != 2 or scores.dim() != 2:
594
+ # return scores
595
+
596
+ # batch = input_ids.size(0)
597
+ # if batch != len(self._graphs):
598
+ # return scores
599
+
600
+ # for i in range(batch):
601
+ # tail_ids = input_ids[i, -self.max_prefix_tokens :].tolist()
602
+ # prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
603
+ # expected = _infer_expected_identifier(prefix_text)
604
+ # if expected is None:
605
+ # continue
606
+
607
+ # if expected == "table":
608
+ # m = re.search(r"(?:from|join)\s+([A-Za-z_][A-Za-z0-9_]*)$", prefix_text, flags=re.I)
609
+ # partial = m.group(1) if m else None
610
+ # if partial is None and not re.search(r"(?:from|join)\s*$", prefix_text, flags=re.I):
611
+ # continue
612
+ # trie = self._token_sets[i].table_trie
613
+ # else:
614
+ # m = re.search(
615
+ # r"(?:select|where|on|group by|order by|having)\s+([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?)$",
616
+ # prefix_text,
617
+ # flags=re.I,
618
+ # )
619
+ # partial = m.group(1) if m else None
620
+ # if partial is None and not re.search(
621
+ # r"(?:select|where|on|group by|order by|having)\s*$", prefix_text, flags=re.I
622
+ # ):
623
+ # continue
624
+ # trie = self._token_sets[i].column_trie
625
+
626
+ # if not partial:
627
+ # prefix_token_ids: List[int] = []
628
+ # else:
629
+ # try:
630
+ # prefix_token_ids = _encode_identifier(self.tokenizer, partial)
631
+ # except Exception:
632
+ # continue
633
+
634
+ # node = trie.walk(prefix_token_ids)
635
+ # if node is None or node.terminal:
636
+ # continue
637
+
638
+ # allowed_next = sorted(node.children.keys())
639
+ # if not allowed_next:
640
+ # continue
641
+
642
+ # allowed_next_t = torch.tensor(allowed_next, dtype=torch.long, device=scores.device)
643
+ # allow_always = self._token_sets[i].allow_always.to(scores.device)
644
+ # keep = torch.cat([allowed_next_t, allow_always]) if allow_always.numel() else allowed_next_t
645
+
646
+ # kept_scores = scores[i, keep].clone()
647
+ # scores[i, :] = -float("inf")
648
+ # scores[i, keep] = kept_scores
649
+
650
+ # return scores
651
+
652
+
653
+ # # Backwards-compatible names used elsewhere in the repo.
654
+ # class SchemaConstraintGraph:
655
+ # def __init__(self, db_path: str):
656
+ # self._graph = build_constraint_graph(db_path)
657
+ # self.tables = sorted(self._graph.tables)
658
+ # self.columns = sorted(self._graph.all_columns)
659
+
660
+
661
+ # class SchemaConstrainedLogitsProcessor(LogitsProcessor):
662
+ # def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
663
+ # self._proc = BatchSchemaConstrainedLogitsProcessor(tokenizer, [schema_graph._graph.db_path])
664
+
665
+ # def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
666
+ # return self._proc(input_ids, scores)
667
+
668
+
669
+
670
+
671
+ # from __future__ import annotations
672
+
673
+ # import re
674
+ # import threading
675
+ # from dataclasses import dataclass
676
+ # from typing import Dict, Iterable, List, Optional, Sequence, Set
677
+
678
+ # import torch
679
+ # from transformers.generation.logits_process import LogitsProcessor
680
+
681
+ # from schema_constraints import ConstraintGraph, build_constraint_graph
682
+
683
+
684
+ # # =========================================================
685
+ # # 🔍 IDENTIFIER TYPE DETECTION
686
+ # # =========================================================
687
+ # def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
688
+ # s = re.sub(r"\s+", " ", prefix_text.lower())
689
+
690
+ # last_from = s.rfind(" from ")
691
+ # last_join = s.rfind(" join ")
692
+ # last_select = s.rfind(" select ")
693
+ # last_where = s.rfind(" where ")
694
+ # last_on = s.rfind(" on ")
695
+ # last_group = s.rfind(" group by ")
696
+ # last_order = s.rfind(" order by ")
697
+ # last_having = s.rfind(" having ")
698
+
699
+ # last_table_kw = max(last_from, last_join)
700
+ # last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
701
+
702
+ # if last_table_kw < 0 and last_col_kw < 0:
703
+ # return None
704
+ # if last_table_kw > last_col_kw:
705
+ # return "table"
706
+ # if last_col_kw > last_table_kw:
707
+ # return "column"
708
+ # return None
709
+
710
+
711
+ # # =========================================================
712
+ # # 🌳 TRIE STRUCTURE
713
+ # # =========================================================
714
+ # class _TrieNode:
715
+ # __slots__ = ("children", "terminal")
716
+
717
+ # def __init__(self) -> None:
718
+ # self.children: Dict[int, _TrieNode] = {}
719
+ # self.terminal: bool = False
720
+
721
+ # def insert(self, token_ids: Sequence[int]) -> None:
722
+ # node = self
723
+ # for tid in token_ids:
724
+ # tid = int(tid)
725
+ # if tid not in node.children:
726
+ # node.children[tid] = _TrieNode()
727
+ # node = node.children[tid]
728
+ # node.terminal = True
729
+
730
+ # def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
731
+ # node = self
732
+ # for tid in prefix:
733
+ # node = node.children.get(int(tid))
734
+ # if node is None:
735
+ # return None
736
+ # return node
737
+
738
+
739
+ # # =========================================================
740
+ # # 🔤 TOKEN ENCODING
741
+ # # =========================================================
742
+ # def _encode_identifier(tokenizer, name: str) -> List[int]:
743
+ # return tokenizer.encode(" " + name, add_special_tokens=False)
744
+
745
+
746
+ # def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
747
+ # trie = _TrieNode()
748
+ # for name in names:
749
+ # try:
750
+ # ids = _encode_identifier(tokenizer, name)
751
+ # if ids:
752
+ # trie.insert(ids)
753
+ # except Exception:
754
+ # continue
755
+ # return trie
756
+
757
+
758
+ # def _allow_always_token_ids(tokenizer) -> torch.Tensor:
759
+ # tokens = [",", ")", "(", ".", ";", "\n"]
760
+ # ids: Set[int] = set()
761
+
762
+ # for t in tokens:
763
+ # try:
764
+ # ids.update(tokenizer.encode(t, add_special_tokens=False))
765
+ # except:
766
+ # pass
767
+
768
+ # return torch.tensor(sorted(ids), dtype=torch.long)
769
+
770
+
771
+ # # =========================================================
772
+ # # 📦 PER-DB CACHE
773
+ # # =========================================================
774
+ # @dataclass
775
+ # class _PerDbTokenSets:
776
+ # fp: str
777
+ # table_trie: _TrieNode
778
+ # column_trie: _TrieNode
779
+ # allow_always: torch.Tensor
780
+
781
+
782
+ # _DB_CACHE: Dict[str, _PerDbTokenSets] = {}
783
+ # _DB_LOCK = threading.Lock()
784
+
785
+
786
+ # def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
787
+ # with _DB_LOCK:
788
+ # cached = _DB_CACHE.get(graph.db_path)
789
+ # if cached and cached.fp == graph.fingerprint:
790
+ # return cached
791
+
792
+ # obj = _PerDbTokenSets(
793
+ # fp=graph.fingerprint,
794
+ # table_trie=_build_trie(tokenizer, graph.tables),
795
+ # column_trie=_build_trie(tokenizer, graph.all_columns),
796
+ # allow_always=_allow_always_token_ids(tokenizer),
797
+ # )
798
+
799
+ # with _DB_LOCK:
800
+ # _DB_CACHE[graph.db_path] = obj
801
+
802
+ # return obj
803
+
804
+
805
+ # # =========================================================
806
+ # # 🚀 MAIN LOGITS PROCESSOR
807
+ # # =========================================================
808
+ # class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
809
+ # def __init__(self, tokenizer, db_paths: Sequence[str], max_prefix_tokens: int = 48):
810
+ # self.tokenizer = tokenizer
811
+ # self.db_paths = list(db_paths)
812
+ # self.max_prefix_tokens = max_prefix_tokens
813
+
814
+ # self._graphs = [build_constraint_graph(p) for p in db_paths]
815
+ # self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
816
+
817
+ # # 📊 Metrics (IMPORTANT FOR REPORT)
818
+ # self.total_steps = 0
819
+ # self.constrained_steps = 0
820
+
821
+ # def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
822
+ # batch = input_ids.size(0)
823
+
824
+ # for i in range(batch):
825
+ # self.total_steps += 1
826
+
827
+ # tail_ids = input_ids[i, -self.max_prefix_tokens:].tolist()
828
+ # prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
829
+
830
+ # expected = _infer_expected_identifier(prefix_text)
831
+ # if expected is None:
832
+ # continue
833
+
834
+ # self.constrained_steps += 1
835
+
836
+ # # =========================
837
+ # # SELECT TRIE
838
+ # # =========================
839
+ # if expected == "table":
840
+ # trie = self._token_sets[i].table_trie
841
+ # else:
842
+ # trie = self._token_sets[i].column_trie
843
+
844
+ # # =========================
845
+ # # PARTIAL TOKEN MATCH
846
+ # # =========================
847
+ # match = re.search(r"([A-Za-z_][A-Za-z0-9_]*)$", prefix_text)
848
+ # partial = match.group(1) if match else ""
849
+
850
+ # try:
851
+ # prefix_ids = _encode_identifier(self.tokenizer, partial) if partial else []
852
+ # except:
853
+ # continue
854
+
855
+ # node = trie.walk(prefix_ids)
856
+ # if node is None or node.terminal:
857
+ # continue
858
+
859
+ # allowed_next = list(node.children.keys())
860
+ # if not allowed_next:
861
+ # continue
862
+
863
+ # allowed_next = torch.tensor(allowed_next, device=scores.device)
864
+ # allow_always = self._token_sets[i].allow_always.to(scores.device)
865
+
866
+ # keep = torch.cat([allowed_next, allow_always])
867
+
868
+ # kept_scores = scores[i, keep].clone()
869
+ # scores[i, :] = -float("inf")
870
+ # scores[i, keep] = kept_scores
871
+
872
+ # return scores
873
+
874
+ # # =========================================================
875
+ # # 📊 METRICS FOR REPORT
876
+ # # =========================================================
877
+ # def get_constraint_stats(self):
878
+ # if self.total_steps == 0:
879
+ # return 0
880
+ # return self.constrained_steps / self.total_steps
881
+
882
+
883
+ # # =========================================================
884
+ # # 🔁 BACKWARD COMPATIBILITY
885
+ # # =========================================================
886
+ # class SchemaConstraintGraph:
887
+ # def __init__(self, db_path: str):
888
+ # self._graph = build_constraint_graph(db_path)
889
+ # self.tables = sorted(self._graph.tables)
890
+ # self.columns = sorted(self._graph.all_columns)
891
+
892
+
893
+ # class SchemaConstrainedLogitsProcessor(LogitsProcessor):
894
+ # def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
895
+ # self.proc = BatchSchemaConstrainedLogitsProcessor(
896
+ # tokenizer, [schema_graph._graph.db_path]
897
+ # )
898
+
899
+ # def __call__(self, input_ids, scores):
900
+ # return self.proc(input_ids, scores)
901
+
902
+
903
+
904
+
905
+
906
+
907
+
908
+
909
# --- Active implementation (rewritten after Task 3); everything above is retained dead code ---
910
+
911
+ import re
912
+ import threading
913
+ from functools import lru_cache
914
+
915
+ import torch
916
+ from transformers import LogitsProcessor
917
+
918
+ from src.schema_utils import get_constraint_graph
919
+
920
+
921
+ _TOKEN_CACHE_LOCK = threading.Lock()
922
+ _TOKEN_ID_CACHE = {} # (id(tokenizer), db_path) -> (allowed_ids_tensor, always_allow_ids_tensor)
923
+
924
+
925
+ def _encode_variants(tokenizer, text: str) -> list[int]:
926
+ ids: list[int] = []
927
+ for variant in (text, " " + text):
928
+ try:
929
+ ids.extend(tokenizer.encode(variant, add_special_tokens=False))
930
+ except Exception:
931
+ continue
932
+ # de-dup while keeping order
933
+ seen = set()
934
+ out = []
935
+ for i in ids:
936
+ if int(i) not in seen:
937
+ seen.add(int(i))
938
+ out.append(int(i))
939
+ return out
940
+
941
+
942
+ def _always_allow_ids(tokenizer) -> list[int]:
943
+ """
944
+ Tokens we should never block, otherwise decoding can get stuck or generate garbage:
945
+ - EOS/PAD
946
+ - punctuation/operators needed for SQL formatting
947
+ - digits/quotes
948
+ """
949
+ ids: list[int] = []
950
+ for special in [getattr(tokenizer, "eos_token_id", None), getattr(tokenizer, "pad_token_id", None)]:
951
+ if special is not None:
952
+ ids.append(int(special))
953
+
954
+ # Common SQL punctuation/operators
955
+ pieces = [
956
+ " ", "\n", "\t",
957
+ ",", ".", "(", ")", ";",
958
+ "=", "!=", "<>", "<", ">", "<=", ">=",
959
+ "*", "+", "-", "/", "%",
960
+ "'", '"',
961
+ ]
962
+ for p in pieces:
963
+ ids.extend(_encode_variants(tokenizer, p))
964
+
965
+ # digits
966
+ for d in "0123456789":
967
+ ids.extend(_encode_variants(tokenizer, d))
968
+
969
+ seen = set()
970
+ out = []
971
+ for i in ids:
972
+ if int(i) not in seen:
973
+ seen.add(int(i))
974
+ out.append(int(i))
975
+ return out
976
+
977
+
978
+ def _infer_expected_identifier_tail(tail_text: str):
979
+ """
980
+ Returns ("table"|"column", partial_or_empty) if the tail looks like it's currently
981
+ emitting a table/column identifier. Otherwise returns None.
982
+ """
983
+ t = re.sub(r"\s+", " ", (tail_text or "")).lower()
984
+
985
+ m = re.search(r"(?:from|join)\s+([a-z_][a-z0-9_]*)?$", t)
986
+ if m:
987
+ partial = m.group(1) or ""
988
+ # ensure we are actually after keyword (not elsewhere)
989
+ if re.search(r"(?:from|join)\s*$", t) or partial:
990
+ return "table", partial
991
+
992
+ m = re.search(
993
+ r"(?:select|where|on|group by|order by|having)\s+([a-z_][a-z0-9_]*(?:\.[a-z_][a-z0-9_]*)?)?$",
994
+ t,
995
+ )
996
+ if m:
997
+ partial = m.group(1) or ""
998
+ if re.search(r"(?:select|where|on|group by|order by|having)\s*$", t) or partial:
999
+ return "column", partial
1000
+
1001
+ return None
1002
+
1003
+
1004
+ class SchemaConstrainedLogitsProcessor(LogitsProcessor):
1005
+ def __init__(self, tokenizer, db_path):
1006
+ self.tokenizer = tokenizer
1007
+
1008
+ graph = get_constraint_graph(db_path)
1009
+
1010
+ key = (id(tokenizer), str(db_path))
1011
+ with _TOKEN_CACHE_LOCK:
1012
+ cached = _TOKEN_ID_CACHE.get(key)
1013
+ if cached is None:
1014
+ allowed_tokens = set(graph.get("tables", set())) | set(graph.get("columns", set()))
1015
+
1016
+ sql_keywords = {
1017
+ "select", "from", "where", "join", "on",
1018
+ "group", "by", "order", "limit", "having",
1019
+ "and", "or", "desc", "asc",
1020
+ "count", "avg", "min", "max", "sum",
1021
+ "distinct", "as", "in", "like", "between",
1022
+ "is", "null",
1023
+ }
1024
+ allowed_tokens |= sql_keywords
1025
+
1026
+ allowed_ids: list[int] = []
1027
+ for tok in sorted(allowed_tokens):
1028
+ allowed_ids.extend(_encode_variants(tokenizer, tok))
1029
+ always_ids = _always_allow_ids(tokenizer)
1030
+
1031
+ allowed_ids_t = torch.tensor(sorted(set(allowed_ids)), dtype=torch.long)
1032
+ always_ids_t = torch.tensor(sorted(set(always_ids)), dtype=torch.long)
1033
+ cached = (allowed_ids_t, always_ids_t)
1034
+ with _TOKEN_CACHE_LOCK:
1035
+ _TOKEN_ID_CACHE[key] = cached
1036
+
1037
+ self._allowed_ids_t, self._always_ids_t = cached
1038
+
1039
+ def __call__(self, input_ids, scores):
1040
+ # Decode only a tail window for speed (beam search calls this a lot).
1041
+ try:
1042
+ tail_ids = input_ids[0][-128:]
1043
+ except Exception:
1044
+ tail_ids = input_ids[0]
1045
+ tail = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
1046
+
1047
+ inferred = _infer_expected_identifier_tail(tail)
1048
+ if inferred is None:
1049
+ return scores
1050
+
1051
+ keep = torch.cat([self._allowed_ids_t.to(scores.device), self._always_ids_t.to(scores.device)])
1052
+ if keep.numel() == 0:
1053
+ return scores
1054
+
1055
+ kept_scores = scores[:, keep].clone()
1056
+ scores[:] = -float("inf")
1057
+ scores[:, keep] = kept_scores
1058
+ return scores
src/constrained_decoding_sample.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from __future__ import annotations
2
+
3
+ # import re
4
+ # import threading
5
+ # from dataclasses import dataclass
6
+ # from typing import Dict, Iterable, List, Optional, Sequence, Set
7
+
8
+ # import torch
9
+ # from transformers.generation.logits_process import LogitsProcessor
10
+
11
+ # from schema_constraints import ConstraintGraph, build_constraint_graph
12
+
13
+
14
+ # def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
15
+ # s = re.sub(r"\s+", " ", prefix_text.lower())
16
+ # last_from = s.rfind(" from ")
17
+ # last_join = s.rfind(" join ")
18
+ # last_select = s.rfind(" select ")
19
+ # last_where = s.rfind(" where ")
20
+ # last_on = s.rfind(" on ")
21
+ # last_group = s.rfind(" group by ")
22
+ # last_order = s.rfind(" order by ")
23
+ # last_having = s.rfind(" having ")
24
+
25
+ # last_table_kw = max(last_from, last_join)
26
+ # last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
27
+
28
+ # if last_table_kw < 0 and last_col_kw < 0:
29
+ # return None
30
+ # if last_table_kw > last_col_kw:
31
+ # return "table"
32
+ # if last_col_kw > last_table_kw:
33
+ # return "column"
34
+ # return None
35
+
36
+
37
+ # class _TrieNode:
38
+ # __slots__ = ("children", "terminal")
39
+
40
+ # def __init__(self) -> None:
41
+ # self.children: Dict[int, _TrieNode] = {}
42
+ # self.terminal: bool = False
43
+
44
+ # def insert(self, token_ids: Sequence[int]) -> None:
45
+ # node: _TrieNode = self
46
+ # for tid in token_ids:
47
+ # tid_i = int(tid)
48
+ # nxt = node.children.get(tid_i)
49
+ # if nxt is None:
50
+ # nxt = _TrieNode()
51
+ # node.children[tid_i] = nxt
52
+ # node = nxt
53
+ # node.terminal = True
54
+
55
+ # def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
56
+ # node: _TrieNode = self
57
+ # for tid in prefix:
58
+ # node = node.children.get(int(tid)) # type: ignore[assignment]
59
+ # if node is None:
60
+ # return None
61
+ # return node
62
+
63
+
64
+ # def _encode_identifier(tokenizer, name: str) -> List[int]:
65
+ # # Leading space encourages word-start markers (e.g. "Ġ" in RoBERTa BPE).
66
+ # return tokenizer.encode(" " + name, add_special_tokens=False)
67
+
68
+
69
+ # def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
70
+ # trie = _TrieNode()
71
+ # for n in names:
72
+ # if not n:
73
+ # continue
74
+ # try:
75
+ # ids = _encode_identifier(tokenizer, n)
76
+ # except Exception:
77
+ # continue
78
+ # if ids:
79
+ # trie.insert(ids)
80
+ # return trie
81
+
82
+
83
+ # def _allow_always_token_ids(tokenizer) -> torch.Tensor:
84
+ # # Allow common delimiters so the model can end an identifier.
85
+ # toks = [",", ")", "(", "\n", ".", ";"]
86
+ # ids: Set[int] = set()
87
+ # for t in toks:
88
+ # try:
89
+ # for tid in tokenizer.encode(t, add_special_tokens=False):
90
+ # ids.add(int(tid))
91
+ # except Exception:
92
+ # continue
93
+ # return torch.tensor(sorted(ids), dtype=torch.long)
94
+
95
+
96
+ # @dataclass
97
+ # class _PerDbTokenSets:
98
+ # fp: str
99
+ # table_trie: _TrieNode
100
+ # column_trie: _TrieNode
101
+ # allow_always: torch.Tensor
102
+
103
+
104
+ # _DB_TOKENSET_LOCK = threading.Lock()
105
+ # _DB_TOKENSETS: Dict[str, _PerDbTokenSets] = {}
106
+
107
+
108
+ # def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
109
+ # with _DB_TOKENSET_LOCK:
110
+ # cached = _DB_TOKENSETS.get(graph.db_path)
111
+ # if cached is not None and cached.fp == graph.fingerprint:
112
+ # return cached
113
+
114
+ # out = _PerDbTokenSets(
115
+ # fp=graph.fingerprint,
116
+ # table_trie=_build_trie(tokenizer, graph.tables),
117
+ # column_trie=_build_trie(tokenizer, graph.all_columns),
118
+ # allow_always=_allow_always_token_ids(tokenizer),
119
+ # )
120
+ # with _DB_TOKENSET_LOCK:
121
+ # _DB_TOKENSETS[graph.db_path] = out
122
+ # return out
123
+
124
+
125
+ # class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
126
+ # """
127
+ # Schema-aware constrained decoding per item in the generation batch.
128
+ # Uses a tokenizer-based trie so multi-token identifiers can be constrained.
129
+ # """
130
+
131
+ # def __init__(self, tokenizer, db_paths: Sequence[str], *, max_prefix_tokens: int = 48):
132
+ # self.tokenizer = tokenizer
133
+ # self.db_paths = list(db_paths)
134
+ # self.max_prefix_tokens = int(max_prefix_tokens)
135
+
136
+ # self._graphs = [build_constraint_graph(p) for p in self.db_paths]
137
+ # self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
138
+
139
+ # def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
140
+ # if input_ids.dim() != 2 or scores.dim() != 2:
141
+ # return scores
142
+
143
+ # batch = input_ids.size(0)
144
+ # if batch != len(self._graphs):
145
+ # return scores
146
+
147
+ # for i in range(batch):
148
+ # tail_ids = input_ids[i, -self.max_prefix_tokens :].tolist()
149
+ # prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
150
+ # expected = _infer_expected_identifier(prefix_text)
151
+ # if expected is None:
152
+ # continue
153
+
154
+ # if expected == "table":
155
+ # m = re.search(r"(?:from|join)\s+([A-Za-z_][A-Za-z0-9_]*)$", prefix_text, flags=re.I)
156
+ # partial = m.group(1) if m else None
157
+ # if partial is None and not re.search(r"(?:from|join)\s*$", prefix_text, flags=re.I):
158
+ # continue
159
+ # trie = self._token_sets[i].table_trie
160
+ # else:
161
+ # m = re.search(
162
+ # r"(?:select|where|on|group by|order by|having)\s+([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?)$",
163
+ # prefix_text,
164
+ # flags=re.I,
165
+ # )
166
+ # partial = m.group(1) if m else None
167
+ # if partial is None and not re.search(
168
+ # r"(?:select|where|on|group by|order by|having)\s*$", prefix_text, flags=re.I
169
+ # ):
170
+ # continue
171
+ # trie = self._token_sets[i].column_trie
172
+
173
+ # if not partial:
174
+ # prefix_token_ids: List[int] = []
175
+ # else:
176
+ # try:
177
+ # prefix_token_ids = _encode_identifier(self.tokenizer, partial)
178
+ # except Exception:
179
+ # continue
180
+
181
+ # node = trie.walk(prefix_token_ids)
182
+ # if node is None or node.terminal:
183
+ # continue
184
+
185
+ # allowed_next = sorted(node.children.keys())
186
+ # if not allowed_next:
187
+ # continue
188
+
189
+ # allowed_next_t = torch.tensor(allowed_next, dtype=torch.long, device=scores.device)
190
+ # allow_always = self._token_sets[i].allow_always.to(scores.device)
191
+ # keep = torch.cat([allowed_next_t, allow_always]) if allow_always.numel() else allowed_next_t
192
+
193
+ # kept_scores = scores[i, keep].clone()
194
+ # scores[i, :] = -float("inf")
195
+ # scores[i, keep] = kept_scores
196
+
197
+ # return scores
198
+
199
+
200
+ # # Backwards-compatible names used elsewhere in the repo.
201
+ # class SchemaConstraintGraph:
202
+ # def __init__(self, db_path: str):
203
+ # self._graph = build_constraint_graph(db_path)
204
+ # self.tables = sorted(self._graph.tables)
205
+ # self.columns = sorted(self._graph.all_columns)
206
+
207
+
208
+ # class SchemaConstrainedLogitsProcessor(LogitsProcessor):
209
+ # def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
210
+ # self._proc = BatchSchemaConstrainedLogitsProcessor(tokenizer, [schema_graph._graph.db_path])
211
+
212
+ # def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
213
+ # return self._proc(input_ids, scores)
214
+
215
+
216
+
217
+
218
+ # from __future__ import annotations
219
+
220
+ # import re
221
+ # import threading
222
+ # from dataclasses import dataclass
223
+ # from typing import Dict, Iterable, List, Optional, Sequence, Set
224
+
225
+ # import torch
226
+ # from transformers.generation.logits_process import LogitsProcessor
227
+
228
+ # from schema_constraints import ConstraintGraph, build_constraint_graph
229
+
230
+
231
+ # # =========================================================
232
+ # # 🔍 IDENTIFIER TYPE DETECTION
233
+ # # =========================================================
234
+ # def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
235
+ # s = re.sub(r"\s+", " ", prefix_text.lower())
236
+
237
+ # last_from = s.rfind(" from ")
238
+ # last_join = s.rfind(" join ")
239
+ # last_select = s.rfind(" select ")
240
+ # last_where = s.rfind(" where ")
241
+ # last_on = s.rfind(" on ")
242
+ # last_group = s.rfind(" group by ")
243
+ # last_order = s.rfind(" order by ")
244
+ # last_having = s.rfind(" having ")
245
+
246
+ # last_table_kw = max(last_from, last_join)
247
+ # last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
248
+
249
+ # if last_table_kw < 0 and last_col_kw < 0:
250
+ # return None
251
+ # if last_table_kw > last_col_kw:
252
+ # return "table"
253
+ # if last_col_kw > last_table_kw:
254
+ # return "column"
255
+ # return None
256
+
257
+
258
+ # # =========================================================
259
+ # # 🌳 TRIE STRUCTURE
260
+ # # =========================================================
261
+ # class _TrieNode:
262
+ # __slots__ = ("children", "terminal")
263
+
264
+ # def __init__(self) -> None:
265
+ # self.children: Dict[int, _TrieNode] = {}
266
+ # self.terminal: bool = False
267
+
268
+ # def insert(self, token_ids: Sequence[int]) -> None:
269
+ # node = self
270
+ # for tid in token_ids:
271
+ # tid = int(tid)
272
+ # if tid not in node.children:
273
+ # node.children[tid] = _TrieNode()
274
+ # node = node.children[tid]
275
+ # node.terminal = True
276
+
277
+ # def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
278
+ # node = self
279
+ # for tid in prefix:
280
+ # node = node.children.get(int(tid))
281
+ # if node is None:
282
+ # return None
283
+ # return node
284
+
285
+
286
+ # # =========================================================
287
+ # # 🔤 TOKEN ENCODING
288
+ # # =========================================================
289
+ # def _encode_identifier(tokenizer, name: str) -> List[int]:
290
+ # return tokenizer.encode(" " + name, add_special_tokens=False)
291
+
292
+
293
+ # def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
294
+ # trie = _TrieNode()
295
+ # for name in names:
296
+ # try:
297
+ # ids = _encode_identifier(tokenizer, name)
298
+ # if ids:
299
+ # trie.insert(ids)
300
+ # except Exception:
301
+ # continue
302
+ # return trie
303
+
304
+
305
+ # def _allow_always_token_ids(tokenizer) -> torch.Tensor:
306
+ # tokens = [",", ")", "(", ".", ";", "\n"]
307
+ # ids: Set[int] = set()
308
+
309
+ # for t in tokens:
310
+ # try:
311
+ # ids.update(tokenizer.encode(t, add_special_tokens=False))
312
+ # except:
313
+ # pass
314
+
315
+ # return torch.tensor(sorted(ids), dtype=torch.long)
316
+
317
+
318
+ # # =========================================================
319
+ # # 📦 PER-DB CACHE
320
+ # # =========================================================
321
+ # @dataclass
322
+ # class _PerDbTokenSets:
323
+ # fp: str
324
+ # table_trie: _TrieNode
325
+ # column_trie: _TrieNode
326
+ # allow_always: torch.Tensor
327
+
328
+
329
+ # _DB_CACHE: Dict[str, _PerDbTokenSets] = {}
330
+ # _DB_LOCK = threading.Lock()
331
+
332
+
333
+ # def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
334
+ # with _DB_LOCK:
335
+ # cached = _DB_CACHE.get(graph.db_path)
336
+ # if cached and cached.fp == graph.fingerprint:
337
+ # return cached
338
+
339
+ # obj = _PerDbTokenSets(
340
+ # fp=graph.fingerprint,
341
+ # table_trie=_build_trie(tokenizer, graph.tables),
342
+ # column_trie=_build_trie(tokenizer, graph.all_columns),
343
+ # allow_always=_allow_always_token_ids(tokenizer),
344
+ # )
345
+
346
+ # with _DB_LOCK:
347
+ # _DB_CACHE[graph.db_path] = obj
348
+
349
+ # return obj
350
+
351
+
352
+ # # =========================================================
353
+ # # 🚀 MAIN LOGITS PROCESSOR
354
+ # # =========================================================
355
+ # class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
356
+ # def __init__(self, tokenizer, db_paths: Sequence[str], max_prefix_tokens: int = 48):
357
+ # self.tokenizer = tokenizer
358
+ # self.db_paths = list(db_paths)
359
+ # self.max_prefix_tokens = max_prefix_tokens
360
+
361
+ # self._graphs = [build_constraint_graph(p) for p in db_paths]
362
+ # self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
363
+
364
+ # # 📊 Metrics (IMPORTANT FOR REPORT)
365
+ # self.total_steps = 0
366
+ # self.constrained_steps = 0
367
+
368
+ # def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
369
+ # batch = input_ids.size(0)
370
+
371
+ # for i in range(batch):
372
+ # self.total_steps += 1
373
+
374
+ # tail_ids = input_ids[i, -self.max_prefix_tokens:].tolist()
375
+ # prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
376
+
377
+ # expected = _infer_expected_identifier(prefix_text)
378
+ # if expected is None:
379
+ # continue
380
+
381
+ # self.constrained_steps += 1
382
+
383
+ # # =========================
384
+ # # SELECT TRIE
385
+ # # =========================
386
+ # if expected == "table":
387
+ # trie = self._token_sets[i].table_trie
388
+ # else:
389
+ # trie = self._token_sets[i].column_trie
390
+
391
+ # # =========================
392
+ # # PARTIAL TOKEN MATCH
393
+ # # =========================
394
+ # match = re.search(r"([A-Za-z_][A-Za-z0-9_]*)$", prefix_text)
395
+ # partial = match.group(1) if match else ""
396
+
397
+ # try:
398
+ # prefix_ids = _encode_identifier(self.tokenizer, partial) if partial else []
399
+ # except:
400
+ # continue
401
+
402
+ # node = trie.walk(prefix_ids)
403
+ # if node is None or node.terminal:
404
+ # continue
405
+
406
+ # allowed_next = list(node.children.keys())
407
+ # if not allowed_next:
408
+ # continue
409
+
410
+ # allowed_next = torch.tensor(allowed_next, device=scores.device)
411
+ # allow_always = self._token_sets[i].allow_always.to(scores.device)
412
+
413
+ # keep = torch.cat([allowed_next, allow_always])
414
+
415
+ # kept_scores = scores[i, keep].clone()
416
+ # scores[i, :] = -float("inf")
417
+ # scores[i, keep] = kept_scores
418
+
419
+ # return scores
420
+
421
+ # # =========================================================
422
+ # # 📊 METRICS FOR REPORT
423
+ # # =========================================================
424
+ # def get_constraint_stats(self):
425
+ # if self.total_steps == 0:
426
+ # return 0
427
+ # return self.constrained_steps / self.total_steps
428
+
429
+
430
+ # # =========================================================
431
+ # # 🔁 BACKWARD COMPATIBILITY
432
+ # # =========================================================
433
+ # class SchemaConstraintGraph:
434
+ # def __init__(self, db_path: str):
435
+ # self._graph = build_constraint_graph(db_path)
436
+ # self.tables = sorted(self._graph.tables)
437
+ # self.columns = sorted(self._graph.all_columns)
438
+
439
+
440
+ # class SchemaConstrainedLogitsProcessor(LogitsProcessor):
441
+ # def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
442
+ # self.proc = BatchSchemaConstrainedLogitsProcessor(
443
+ # tokenizer, [schema_graph._graph.db_path]
444
+ # )
445
+
446
+ # def __call__(self, input_ids, scores):
447
+ # return self.proc(input_ids, scores)
448
+
449
+
450
+
451
+
452
+
453
+
454
+
455
+
456
+ # ********* after task 3
457
+
458
+ import re
459
+ import torch
460
+ from transformers import LogitsProcessor
461
+ from src.schema_utils import get_constraint_graph
462
+
463
+
464
+ def _infer_expected_identifier(prefix_text: str):
465
+ s = prefix_text.lower()
466
+
467
+ if " from " in s or " join " in s:
468
+ return "table"
469
+ if any(k in s for k in ["select", "where", "on", "group by", "order by"]):
470
+ return "column"
471
+
472
+ return None
473
+
474
+
475
class SchemaConstrainedLogitsProcessor(LogitsProcessor):
    """Soft schema-aware decoding constraint.

    Once the decoded prefix is long enough and is expected to emit a
    table/column identifier, only token ids belonging to the database
    schema (plus a core set of SQL keywords) keep their scores; every
    other vocabulary entry is masked to -inf.

    NOTE(review): the constraint is derived from ``input_ids[0]`` only,
    so with batched generation every row gets the mask computed from the
    first row's prefix — confirm callers generate with batch size 1.
    """

    def __init__(self, tokenizer, db_path):
        self.tokenizer = tokenizer

        graph = get_constraint_graph(db_path)

        # Schema identifiers the model is allowed to produce.
        self.allowed_tokens = set(graph["tables"]) | set(graph["columns"])

        # Core SQL vocabulary that must always stay available.
        self.sql_keywords = {
            "select", "from", "where", "join", "on",
            "group", "by", "order", "limit",
            "and", "or", "desc", "asc",
            "count", "avg", "min", "max", "sum", "*"
        }

        self.allowed_tokens |= self.sql_keywords

        # Tokenise once up front. The set is kept for backward
        # compatibility; the sorted tensor lets __call__ mask with a single
        # vectorised indexing op instead of a Python loop over every id on
        # every decoding step.
        self.allowed_token_ids = set()
        for token in self.allowed_tokens:
            for tid in tokenizer.encode(token, add_special_tokens=False):
                self.allowed_token_ids.add(tid)
        self._allowed_ids = torch.tensor(
            sorted(self.allowed_token_ids), dtype=torch.long
        )

    def __call__(self, input_ids, scores):
        prefix = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)

        # Soft constraint: let the first few tokens decode freely so the
        # model can produce the opening of the query unconstrained.
        if len(prefix.strip()) < 10:
            return scores

        expected = _infer_expected_identifier(prefix)

        if expected not in ["table", "column"]:
            return scores

        # Vectorised masking: -inf everywhere except the allowed ids.
        mask = torch.full_like(scores, float("-inf"))
        idx = self._allowed_ids.to(scores.device)
        mask[:, idx] = scores[:, idx]

        return mask
src/convert_to_hf_dataset.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
"""Convert the processed training CSV into a HuggingFace ``datasets`` folder."""
from pathlib import Path

import pandas as pd
from datasets import Dataset

# BUGFIX: paths were hard-coded relative to the CWD ("../data/..."), which
# only worked when the script was launched from inside src/. Resolve them
# relative to this file so the script works from any directory.
_PROJECT_ROOT = Path(__file__).resolve().parents[1]


def main():
    """Read data/processed/train.csv and save it as a HF dataset on disk."""
    csv_path = _PROJECT_ROOT / "data" / "processed" / "train.csv"
    out_path = _PROJECT_ROOT / "data" / "processed" / "train"

    df = pd.read_csv(csv_path)
    ds = Dataset.from_pandas(df)
    ds.save_to_disk(str(out_path))
    print("DONE")


if __name__ == "__main__":
    main()
8
+
src/eval_baseline_codet5.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import sqlite3
3
+ from pathlib import Path
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
+
7
+ # ---------------- PROMPT (same style as training) ----------------
8
def build_prompt(question, schema):
    """Assemble the translate-to-SQL prompt (same layout used in training)."""
    header = "translate English to SQL:"
    return f"{header}\n\nSchema:\n{schema}\n\nQuestion:\n{question}\n\nSQL:"
18
+
19
+ # ---------------- LOAD SCHEMA ----------------
20
def load_schema(db_path):
    """Return a compact text rendering of the SQLite schema.

    One line per table, formatted ``table(col1, col2, ...)``.

    BUGFIX: the connection is now closed in a ``finally`` block so it is
    released even when a query/PRAGMA raises.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        tables = cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()

        schema = ""
        for (table,) in tables:
            cols = cursor.execute(f"PRAGMA table_info({table});").fetchall()
            col_names = [c[1] for c in cols]
            schema += f"{table}({', '.join(col_names)})\n"

        return schema
    finally:
        conn.close()
36
+
37
+ # ---------------- EXECUTION MATCH ----------------
38
def execution_match(pred_sql, gold_sql, db_path):
    """Execute both queries against *db_path* and compare the result rows.

    Returns False on any execution error (invalid SQL, missing table, ...).

    BUGFIX: the original leaked the connection whenever a query raised
    (the ``except`` path returned before ``conn.close()``); the connection
    is now always closed via ``finally``.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cur = conn.cursor()

        cur.execute(pred_sql)
        pred = cur.fetchall()

        cur.execute(gold_sql)
        gold = cur.fetchall()

        return pred == gold

    except Exception:
        # Any failure counts as a non-match (original behaviour).
        return False
    finally:
        if conn is not None:
            conn.close()
54
+
55
+ # ---------------- MAIN ----------------
56
def main():
    """Evaluate the untuned CodeT5 baseline on the first 100 Spider dev examples.

    For each example: build a schema-augmented prompt, beam-decode SQL,
    and score with execution match against the gold query.
    """
    project_root = Path(__file__).resolve().parents[1]

    dev_json = project_root / "data" / "dev.json"
    db_root = project_root / "data" / "database"

    # Apple-Silicon GPU when available; otherwise CPU.
    device = "mps" if torch.backends.mps.is_available() else "cpu"

    print("Loading BASE CodeT5...")
    tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-base")
    model = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/codet5-base").to(device)
    model.eval()

    with open(dev_json) as f:
        dev = json.load(f)[:100]

    correct = 0

    print(f"\nEvaluating {len(dev)} samples...\n")

    for i, ex in enumerate(dev, 1):
        question = ex["question"]
        db_id = ex["db_id"]
        gold_sql = ex["query"]

        db_path = db_root / db_id / f"{db_id}.sqlite"
        schema = load_schema(db_path)

        prompt = build_prompt(question, schema)

        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=80,
                num_beams=4,
                do_sample=False
            )

        pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Keep only the text after the "SQL:" marker if the model echoed the prompt.
        if "SQL:" in pred_sql:
            pred_sql = pred_sql.split("SQL:")[-1].strip()

        if execution_match(pred_sql, gold_sql, db_path):
            correct += 1

        if i % 10 == 0:
            # BUGFIX: denominator was hard-coded to 100 even though the dev
            # slice may contain fewer examples.
            print(f"{i}/{len(dev)} | Accuracy: {correct/i:.3f}")

    total = len(dev)
    accuracy = correct / total if total else 0.0

    print("\n=============================")
    # BUGFIX: the old message printed the raw count with a '%' sign and a
    # hard-coded "/ 100" denominator; print count, total and percentage.
    print(f"BASE MODEL ACCURACY: {correct} / {total} = {accuracy:.1%}")
    print("=============================")

if __name__ == "__main__":
    main()
src/eval_both_metrics.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import sqlite3
3
+ import torch
4
+ import re
5
+ import time
6
+ import argparse
7
+ from pathlib import Path
8
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
9
+ from peft import PeftModel
10
+
11
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
12
+ DB_ROOT = PROJECT_ROOT / "data" / "database"
13
+
14
+ # -------------------------------
15
+ # 1. NORMALIZATION FOR EXACT MATCH
16
+ # -------------------------------
17
def normalize_sql(sql):
    """Canonicalise SQL so Exact-Match grading is fair.

    Standardises double quotes to single quotes, collapses whitespace runs
    to single spaces, lower-cases, and drops trailing semicolons.
    """
    unified = re.sub(r"\s+", " ", sql.replace('"', "'"))
    return unified.strip().lower().rstrip(";")
24
+
25
+ # -------------------------------
26
+ # 2. EXECUTION ACCURACY CHECK
27
+ # -------------------------------
28
def check_execution(pred_sql, gold_sql, db_path):
    """Run predicted and gold SQL on the same DB and compare result rows.

    Returns False on any failure (bad SQL, missing table, timeout).
    A progress handler aborts statements running longer than ~5 seconds.

    BUGFIX: the original leaked the connection whenever a query raised;
    the connection is now always closed via ``finally``.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        # Some Spider DBs contain non-UTF8 bytes; ignore them on decode.
        conn.text_factory = lambda b: b.decode(errors='ignore')

        # A non-zero return from the progress handler interrupts SQLite,
        # giving an effective ~5-second wall-clock timeout per statement.
        start_time = time.monotonic()
        def timeout_handler():
            return 1 if (time.monotonic() - start_time) > 5.0 else 0
        conn.set_progress_handler(timeout_handler, 10000)

        cursor = conn.cursor()

        # Predicted result
        cursor.execute(pred_sql)
        pred_res = cursor.fetchall()

        # Gold result
        cursor.execute(gold_sql)
        gold_res = cursor.fetchall()

        return pred_res == gold_res
    except Exception:
        return False
    finally:
        if conn is not None:
            conn.close()
55
+
56
+ # -------------------------------
57
+ # 3. LOAD SCHEMA
58
+ # -------------------------------
59
def load_schema(db_path):
    """Render the SQLite schema as one ``table(col1, col2, ...)`` line per table.

    BUGFIX: the connection is now closed in a ``finally`` block so it is
    released even when a query/PRAGMA raises.
    """
    conn = sqlite3.connect(db_path)
    try:
        # Some Spider DBs contain non-UTF8 bytes; ignore them on decode.
        conn.text_factory = lambda b: b.decode(errors='ignore')
        cursor = conn.cursor()
        tables = cursor.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall()
        schema = ""
        for (table,) in tables:
            cols = cursor.execute(f"PRAGMA table_info({table});").fetchall()
            col_names = [c[1] for c in cols]
            schema += f"{table}({', '.join(col_names)})\n"
        return schema
    finally:
        conn.close()
71
+
72
+ # -------------------------------
73
+ # 4. MAIN PIPELINE
74
+ # -------------------------------
75
def main():
    """Evaluate an SFT/RLHF adapter on Spider dev with BOTH metrics.

    Loads a PEFT adapter on top of CodeT5-base, generates SQL for each dev
    example, and reports Exact Match (normalized string equality) and
    Execution Accuracy (result-set equality) side by side.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, required=True, help="Path to your SFT or RLHF checkpoint")
    parser.add_argument("--num_samples", type=int, default=1034, help="How many samples to evaluate")
    args = parser.parse_args()

    # Prefer Apple-Silicon GPU, then CUDA, then CPU.
    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
    base_model = "Salesforce/codet5-base"

    print(f"\n🚀 Loading Model from: {args.adapter}")
    tokenizer = AutoTokenizer.from_pretrained(args.adapter)
    base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
    model = PeftModel.from_pretrained(base, args.adapter).to(device)
    # Fold LoRA weights into the base model for plain inference.
    model = model.merge_and_unload()
    model.eval()

    dev_json = PROJECT_ROOT / "data" / "dev.json"
    with open(dev_json) as f:
        dev = json.load(f)[:args.num_samples]

    em_correct = 0
    ex_correct = 0
    # NOTE(review): assumes at least one dev example; total == 0 would
    # divide by zero in the final accuracy computation below.
    total = len(dev)

    print(f"\n📊 Evaluating {total} queries for BOTH Exact Match and Execution Accuracy...\n")

    for i, ex in enumerate(dev, 1):
        question = ex["question"]
        gold_sql = ex["query"]
        db_id = ex["db_id"]
        db_path = DB_ROOT / db_id / f"{db_id}.sqlite"

        # Build the schema-augmented prompt and generate SQL (beam search,
        # deterministic).
        schema = load_schema(db_path)
        prompt = f"Database Schema:\n{schema}\nTranslate English to SQL:\n{question}\nSQL:\n"
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=100, num_beams=4, do_sample=False)

        pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Strip an echoed prompt: keep only the text after the "SQL:" marker.
        if "SQL:" in pred_sql:
            pred_sql = pred_sql.split("SQL:")[-1].strip()

        # --- METRIC 1: EXACT MATCH (normalized string comparison) ---
        is_em = (normalize_sql(pred_sql) == normalize_sql(gold_sql))
        if is_em:
            em_correct += 1

        # --- METRIC 2: EXECUTION ACCURACY (result-set comparison) ---
        is_ex = check_execution(pred_sql, gold_sql, db_path)
        if is_ex:
            ex_correct += 1

        # Periodic progress report (and always on the final example).
        if i % 50 == 0 or i == total:
            print(f"Progress: {i}/{total} | Current EM: {(em_correct/i)*100:.2f}% | Current EX: {(ex_correct/i)*100:.2f}%")

    # Final results as percentages.
    final_em = (em_correct / total) * 100
    final_ex = (ex_correct / total) * 100

    print("\n==========================================")
    print(f"🎯 FINAL RESULTS FOR: {args.adapter}")
    print("==========================================")
    print(f"Exact Match (EM) Accuracy : {final_em:.2f}%")
    print(f"Execution (EX) Accuracy : {final_ex:.2f}%")
    print("==========================================\n")

if __name__ == "__main__":
    main()
src/eval_rl_fixed.py ADDED
@@ -0,0 +1,756 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import json
2
+ # import subprocess
3
+ # import sys
4
+ # import argparse
5
+ # import random
6
+ # import sqlite3
7
+ # import time
8
+ # import re
9
+ # import os
10
+ # from pathlib import Path
11
+
12
+ # import torch
13
+ # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
14
+ # from peft import PeftModel
15
+
16
+ # from prompting import encode_prompt
17
+
18
+ # # -------------------------------
19
+ # # NORMALIZATION
20
+ # # -------------------------------
21
+ # def normalize_sql(sql):
22
+ # sql = sql.replace('"', "'")
23
+ # sql = re.sub(r"\s+", " ", sql)
24
+ # return sql.strip().lower().rstrip(";")
25
+
26
+
27
+ # # -------------------------------
28
+ # # 🔥 SAFE RESULT NORMALIZATION (FIX)
29
+ # # -------------------------------
30
+ # def normalize_result(res):
31
+ # try:
32
+ # return sorted([str(r) for r in res])
33
+ # except:
34
+ # return []
35
+
36
+
37
+ # # -------------------------------
38
+ # # EXECUTION CHECK (FIXED)
39
+ # # -------------------------------
40
+ # def check_execution(pred_sql, gold_sql, db_path):
41
+ # try:
42
+ # conn = sqlite3.connect(db_path)
43
+ # conn.text_factory = lambda b: b.decode(errors='ignore')
44
+
45
+ # start_time = time.monotonic()
46
+
47
+ # def timeout_handler():
48
+ # return 1 if (time.monotonic() - start_time) > 2.0 else 0
49
+
50
+ # conn.set_progress_handler(timeout_handler, 10000)
51
+
52
+ # cursor = conn.cursor()
53
+
54
+ # cursor.execute(pred_sql)
55
+ # pred_res = cursor.fetchall()
56
+
57
+ # cursor.execute(gold_sql)
58
+ # gold_res = cursor.fetchall()
59
+
60
+ # conn.close()
61
+
62
+ # # 🔥 FIXED COMPARISON
63
+ # return normalize_result(pred_res) == normalize_result(gold_res)
64
+
65
+ # except Exception:
66
+ # return False
67
+
68
+
69
+ # # -------------------------------
70
+ # # SPIDER PARSER
71
+ # # -------------------------------
72
+ # def _parse_spider_accuracy(stdout: str, metric_type: str):
73
+ # for line in stdout.splitlines():
74
+ # if metric_type == "exec" and line.strip().startswith("execution"):
75
+ # try:
76
+ # return float(line.split()[-1])
77
+ # except:
78
+ # pass
79
+ # elif metric_type == "match" and line.strip().startswith("exact"):
80
+ # try:
81
+ # return float(line.split()[-1])
82
+ # except:
83
+ # pass
84
+ # return None
85
+
86
+
87
+ # # -------------------------------
88
+ # # MAIN
89
+ # # -------------------------------
90
+ # def main():
91
+ # parser = argparse.ArgumentParser()
92
+ # parser.add_argument("--adapter", type=str, required=True)
93
+ # parser.add_argument("--num_samples", type=int, default=700)
94
+ # parser.add_argument("--shuffle_dev", action="store_true")
95
+ # parser.add_argument("--shuffle_seed", type=int, default=42)
96
+ # args = parser.parse_args()
97
+
98
+ # project_root = Path(__file__).resolve().parents[1]
99
+ # adapter_dir = project_root / args.adapter
100
+
101
+ # db_root = project_root / "data" / "database"
102
+ # table_json = project_root / "data" / "tables.json"
103
+ # dev_json = project_root / "data" / "dev.json"
104
+
105
+ # pred_path = project_root / "temp_predictions.txt"
106
+ # temp_gold_path = project_root / "temp_gold.sql"
107
+
108
+ # if not adapter_dir.exists():
109
+ # raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")
110
+
111
+ # device = "mps" if torch.backends.mps.is_available() else (
112
+ # "cuda" if torch.cuda.is_available() else "cpu"
113
+ # )
114
+ # print(f"Using device: {device}")
115
+
116
+ # BASE_MODEL = "Salesforce/codet5-base"
117
+ # tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
118
+
119
+ # if tokenizer.pad_token is None:
120
+ # tokenizer.pad_token = tokenizer.eos_token
121
+
122
+ # print(f"\n📦 Loading Model: {args.adapter}")
123
+
124
+ # base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
125
+
126
+ # adapter_for_peft = os.path.relpath(adapter_dir, project_root)
127
+
128
+ # model = PeftModel.from_pretrained(
129
+ # base,
130
+ # adapter_for_peft,
131
+ # local_files_only=True
132
+ # ).to(device)
133
+
134
+ # model = model.merge_and_unload()
135
+ # model.eval()
136
+
137
+ # # -------------------------------
138
+ # # LOAD DATA
139
+ # # -------------------------------
140
+ # with dev_json.open() as f:
141
+ # dev = json.load(f)
142
+
143
+ # if args.shuffle_dev:
144
+ # rng = random.Random(args.shuffle_seed)
145
+ # rng.shuffle(dev)
146
+
147
+ # dev = dev[: args.num_samples]
148
+ # total = len(dev)
149
+
150
+ # gen_kwargs = dict(
151
+ # max_new_tokens=160,
152
+ # num_beams=8,
153
+ # length_penalty=0.8,
154
+ # do_sample=False,
155
+ # early_stopping=True,
156
+ # pad_token_id=tokenizer.pad_token_id,
157
+ # eos_token_id=tokenizer.eos_token_id,
158
+ # )
159
+
160
+ # print(f"\n🚀 Evaluating {total} samples...\n")
161
+
162
+ # em_correct = 0
163
+ # ex_correct = 0
164
+
165
+ # with pred_path.open("w") as out_pred, temp_gold_path.open("w") as out_gold, torch.no_grad():
166
+ # for i, ex in enumerate(dev, start=1):
167
+
168
+ # db_id = ex["db_id"]
169
+ # question = ex["question"]
170
+ # gold_query = ex["query"]
171
+ # db_path = db_root / db_id / f"{db_id}.sqlite"
172
+
173
+ # # -------------------------------
174
+ # # GENERATE SQL
175
+ # # -------------------------------
176
+ # input_ids = encode_prompt(
177
+ # tokenizer,
178
+ # question,
179
+ # db_id,
180
+ # device=device,
181
+ # max_input_tokens=512
182
+ # )
183
+
184
+ # input_ids = input_ids.unsqueeze(0).to(device)
185
+ # attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)
186
+
187
+ # outputs = model.generate(
188
+ # input_ids=input_ids,
189
+ # attention_mask=attention_mask,
190
+ # **gen_kwargs
191
+ # )
192
+
193
+ # pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
194
+
195
+ # # -------------------------------
196
+ # # SAVE FOR SPIDER EVAL
197
+ # # -------------------------------
198
+ # out_pred.write(f"{pred_sql}\n")
199
+ # out_gold.write(f"{gold_query}\t{db_id}\n")
200
+
201
+ # # -------------------------------
202
+ # # LIVE METRICS
203
+ # # -------------------------------
204
+ # if normalize_sql(pred_sql) == normalize_sql(gold_query):
205
+ # em_correct += 1
206
+
207
+ # if check_execution(pred_sql, gold_query, db_path):
208
+ # ex_correct += 1
209
+
210
+ # if i % 20 == 0 or i == total:
211
+ # print(
212
+ # f"Progress: {i}/{total} | "
213
+ # f"EM: {(em_correct/i)*100:.2f}% | "
214
+ # f"EX: {(ex_correct/i)*100:.2f}%"
215
+ # )
216
+
217
+ # print("\n🚀 Running Official Spider Evaluation...\n")
218
+
219
+ # eval_script = project_root / "spider_eval" / "evaluation.py"
220
+
221
+ # # EXACT MATCH
222
+ # cmd_match = [
223
+ # sys.executable, str(eval_script),
224
+ # "--gold", str(temp_gold_path),
225
+ # "--pred", str(pred_path),
226
+ # "--etype", "match",
227
+ # "--db", str(db_root),
228
+ # "--table", str(table_json),
229
+ # ]
230
+
231
+ # proc_match = subprocess.run(cmd_match, capture_output=True, text=True)
232
+ # exact_acc = _parse_spider_accuracy(proc_match.stdout, "match")
233
+
234
+ # # EXECUTION
235
+ # cmd_exec = [
236
+ # sys.executable, str(eval_script),
237
+ # "--gold", str(temp_gold_path),
238
+ # "--pred", str(pred_path),
239
+ # "--etype", "exec",
240
+ # "--db", str(db_root),
241
+ # "--table", str(table_json),
242
+ # ]
243
+
244
+ # proc_exec = subprocess.run(cmd_exec, capture_output=True, text=True)
245
+ # exec_acc = _parse_spider_accuracy(proc_exec.stdout, "exec")
246
+
247
+ # print("==========================================")
248
+ # print(f"🎯 OFFICIAL SPIDER RESULTS FOR: {args.adapter}")
249
+ # print("==========================================")
250
+
251
+ # print(f"Exact Match Accuracy : {exact_acc*100:.2f}%" if exact_acc else "EM parsing failed")
252
+ # print(f"Execution Accuracy : {exec_acc*100:.2f}%" if exec_acc else "EX parsing failed")
253
+
254
+ # print("==========================================\n")
255
+
256
+
257
+ # if __name__ == "__main__":
258
+ # main()
259
+
260
+
261
+
262
+
263
+
264
+
265
+ # import json
266
+ # import sqlite3
267
+ # import re
268
+ # import time
269
+ # import sys
270
+ # import argparse
271
+ # from pathlib import Path
272
+
273
+ # # ==========================================
274
+ # # PATH SETUP
275
+ # # ==========================================
276
+ # PROJECT_ROOT = Path(__file__).resolve().parents[1]
277
+ # if str(PROJECT_ROOT) not in sys.path:
278
+ # sys.path.insert(0, str(PROJECT_ROOT))
279
+
280
+ # from src.text2sql_engine import get_engine
281
+ # from src.sql_validator import validate_sql_schema
282
+
283
+ # # ==========================================
284
+ # # CONFIG
285
+ # # ==========================================
286
+ # DATA_PATH = PROJECT_ROOT / "data" / "dev.json"
287
+ # DB_ROOT = PROJECT_ROOT / "data" / "database"
288
+
289
+ # # ==========================================
290
+ # # NORMALIZATION
291
+ # # ==========================================
292
+ # def normalize_sql(sql):
293
+ # if not isinstance(sql, str):
294
+ # return ""
295
+ # sql = sql.replace('"', "'")
296
+ # sql = re.sub(r"\s+", " ", sql)
297
+ # return sql.strip().lower().rstrip(";")
298
+
299
+ # def normalize_result(res):
300
+ # try:
301
+ # return sorted([tuple(map(str, r)) for r in res])
302
+ # except:
303
+ # return []
304
+
305
+ # # ==========================================
306
+ # # EXECUTION
307
+ # # ==========================================
308
+ # def execute_sql(db_path, sql):
309
+ # try:
310
+ # conn = sqlite3.connect(db_path)
311
+
312
+ # start = time.time()
313
+ # def timeout():
314
+ # return 1 if (time.time() - start) > 2 else 0
315
+
316
+ # conn.set_progress_handler(timeout, 10000)
317
+
318
+ # cur = conn.cursor()
319
+ # cur.execute(sql)
320
+ # res = cur.fetchall()
321
+
322
+ # conn.close()
323
+ # return res
324
+
325
+ # except Exception:
326
+ # return None
327
+
328
+ # # ==========================================
329
+ # # EVALUATION
330
+ # # ==========================================
331
+ # def evaluate(engine, data, is_constrained=False, debug=False):
332
+
333
+ # attempted = 0
334
+ # total = 0
335
+ # exact_match = 0
336
+ # execution_match = 0
337
+ # constraint_ok = 0
338
+
339
+ # skipped_missing_db = 0
340
+ # skipped_exception = 0
341
+ # skipped_no_sql = 0
342
+
343
+ # total_time = 0
344
+
345
+ # for i, item in enumerate(data, 1):
346
+
347
+ # question = item.get("question", "")
348
+ # gold_sql = item.get("query", "")
349
+ # db_id = item.get("db_id", "")
350
+
351
+ # db_path = DB_ROOT / db_id / f"{db_id}.sqlite"
352
+
353
+ # if not db_path.exists():
354
+ # skipped_missing_db += 1
355
+ # continue
356
+
357
+ # try:
358
+ # start = time.time()
359
+ # result = engine.ask(question, db_id)
360
+ # total_time += (time.time() - start)
361
+ # except Exception:
362
+ # skipped_exception += 1
363
+ # continue
364
+
365
+ # if not isinstance(result, dict):
366
+ # continue
367
+
368
+ # pred_sql = result.get("sql", "")
369
+
370
+ # # DEBUG
371
+ # if debug:
372
+ # print(f"\nQ: {question}")
373
+ # print(f"PRED: {pred_sql}")
374
+ # print(f"GOLD: {gold_sql}")
375
+
376
+ # if not pred_sql:
377
+ # skipped_no_sql += 1
378
+ # continue
379
+
380
+ # attempted += 1
381
+ # total += 1
382
+
383
+ # # CONSTRAINT CHECK
384
+ # if is_constrained:
385
+ # try:
386
+ # is_valid, _ = validate_sql_schema(pred_sql, str(db_path))
387
+ # if is_valid:
388
+ # constraint_ok += 1
389
+ # except:
390
+ # pass
391
+
392
+ # # EXACT MATCH
393
+ # if normalize_sql(pred_sql) == normalize_sql(gold_sql):
394
+ # exact_match += 1
395
+
396
+ # # EXECUTION MATCH
397
+ # pred_res = execute_sql(str(db_path), pred_sql)
398
+ # gold_res = execute_sql(str(db_path), gold_sql)
399
+
400
+ # if pred_res is not None and gold_res is not None:
401
+ # if normalize_result(pred_res) == normalize_result(gold_res):
402
+ # execution_match += 1
403
+
404
+ # # PROGRESS
405
+ # if i % 10 == 0:
406
+ # print(
407
+ # f"[{i}/{len(data)}] "
408
+ # f"EM: {exact_match/max(total,1):.3f} | "
409
+ # f"EX: {execution_match/max(total,1):.3f} | "
410
+ # f"Constraint: {(constraint_ok/max(total,1)) if is_constrained else 0:.3f}"
411
+ # )
412
+
413
+ # avg_latency = total_time / max(attempted, 1)
414
+
415
+ # return {
416
+ # "exact_match": exact_match / total if total > 0 else 0,
417
+ # "execution_accuracy": execution_match / total if total > 0 else 0,
418
+ # "constraint_rate": (constraint_ok / total if (is_constrained and total > 0) else 0),
419
+ # "avg_latency": avg_latency,
420
+ # "total": total,
421
+ # "attempted": attempted,
422
+ # "skipped_missing_db": skipped_missing_db,
423
+ # "skipped_exception": skipped_exception,
424
+ # "skipped_no_sql": skipped_no_sql,
425
+ # }
426
+
427
+ # # ==========================================
428
+ # # MAIN
429
+ # # ==========================================
430
+ # if __name__ == "__main__":
431
+
432
+ # ap = argparse.ArgumentParser()
433
+ # ap.add_argument("--num-samples", type=int, default=100)
434
+ # ap.add_argument("--adapter", type=str, default="checkpoints/best_rlhf_model")
435
+ # ap.add_argument("--debug", action="store_true")
436
+ # args = ap.parse_args()
437
+
438
+ # print(f"\n📥 Loading dataset from {DATA_PATH}...")
439
+
440
+ # with open(str(DATA_PATH)) as f:
441
+ # data = json.load(f)[: args.num_samples]
442
+
443
+ # # ==========================================
444
+ # # 🔴 BASE MODEL
445
+ # # ==========================================
446
+ # print("\n🚀 Running BASE MODEL...\n")
447
+
448
+ # engine_base = get_engine(
449
+ # adapter_path="checkpoints/sft_adapter_codet5" , # 🔥 change this
450
+ # use_lora=True,
451
+ # use_constrained=False
452
+ # )
453
+
454
+ # res_base = evaluate(engine_base, data, is_constrained=False, debug=args.debug)
455
+
456
+ # # ==========================================
457
+ # # 🟡 RLHF (NO CONSTRAINT)
458
+ # # ==========================================
459
+ # print("\n🚀 Running RLHF (NO CONSTRAINT)...\n")
460
+
461
+ # engine_rlhf = get_engine(
462
+ # adapter_path="checkpoints/best_rlhf_model",
463
+ # use_lora=True,
464
+ # use_constrained=False
465
+ # )
466
+
467
+ # res_rlhf = evaluate(engine_rlhf, data, is_constrained=False, debug=args.debug)
468
+
469
+ # # ==========================================
470
+ # # 🟢 RLHF + CONSTRAINT
471
+ # # ==========================================
472
+ # print("\n🚀 Running RLHF + CONSTRAINED...\n")
473
+
474
+ # engine_const = get_engine(
475
+ # adapter_path="checkpoints/best_rlhf_model_2",
476
+ # use_lora=True,
477
+ # use_constrained=True
478
+ # )
479
+
480
+ # res_const = evaluate(engine_const, data, is_constrained=True, debug=args.debug)
481
+
482
+ # # ==========================================
483
+ # # FINAL RESULTS
484
+ # # ==========================================
485
+ # print("\n==========================================")
486
+ # print("🎯 FINAL RESULTS (3-WAY COMPARISON)")
487
+ # print("==========================================")
488
+
489
+ # print(f"Base Model → EM: {res_base['exact_match']*100:.2f}% | "
490
+ # f"EX: {res_base['execution_accuracy']*100:.2f}%")
491
+
492
+ # print(f"RLHF → EM: {res_rlhf['exact_match']*100:.2f}% | "
493
+ # f"EX: {res_rlhf['execution_accuracy']*100:.2f}%")
494
+
495
+ # print(f"RLHF + Constrain → EM: {res_const['exact_match']*100:.2f}% | "
496
+ # f"EX: {res_const['execution_accuracy']*100:.2f}% | "
497
+ # f"Constraint: {res_const['constraint_rate']*100:.2f}%")
498
+
499
+ # print("==========================================\n")
500
+
501
+
502
+ import json
503
+ import argparse
504
+ import sqlite3
505
+ import time
506
+ import re
507
+ import os
508
+ from pathlib import Path
509
+
510
+ import torch
511
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
512
+ from peft import PeftModel
513
+
514
+ # Import handling
515
+ try:
516
+ from prompting import encode_prompt
517
+ from src.sql_validator import validate_sql_schema
518
+ except ImportError:
519
+ import sys
520
+ sys.path.append(str(Path(__file__).resolve().parents[1]))
521
+ from src.prompting import encode_prompt
522
+ from src.sql_validator import validate_sql_schema
523
+
524
+ # =========================================================
525
+ # ERROR LOGGING
526
+ # =========================================================
527
+ ERROR_LOG_FILE = "results/error_logs.json"
528
+
529
def classify_error(sql, error_msg=""):
    """Map a failed SQL query and its database error message to a coarse category."""
    lowered_sql = sql.lower()
    lowered_msg = str(error_msg).lower()

    # Message-based categories, checked in priority order.
    message_categories = (
        ("no such column", "wrong_column"),
        ("no such table", "wrong_table"),
        ("syntax error", "syntax_error"),
        ("ambiguous column", "ambiguous_column"),
    )
    for needle, label in message_categories:
        if needle in lowered_msg:
            return label

    # A JOIN without any ON clause is detected from the SQL text itself.
    if "join" in lowered_sql and " on " not in lowered_sql:
        return "missing_join"

    return "other"
545
+
546
def log_error(question, sql, error, error_type):
    """Append one structured error record to the shared JSON error log.

    Args:
        question: Natural-language question that produced the SQL.
        sql: Predicted SQL string that failed.
        error: Exception or message describing the failure.
        error_type: Category label produced by classify_error().
    """
    os.makedirs(os.path.dirname(ERROR_LOG_FILE), exist_ok=True)

    entry = {
        "question": question,
        "sql": sql,
        "error": str(error),
        "error_type": error_type,
        "timestamp": time.time(),
    }

    # Read-modify-write of the whole log; acceptable for evaluation-sized runs.
    logs = []
    if os.path.exists(ERROR_LOG_FILE):
        try:
            with open(ERROR_LOG_FILE, "r") as f:
                content = f.read().strip()
            if content:
                logs = json.loads(content)
        except (OSError, ValueError):
            # Corrupt or unreadable log: start fresh rather than crash the eval.
            # (Was a bare `except:` — it also swallowed KeyboardInterrupt.)
            logs = []

    logs.append(entry)

    with open(ERROR_LOG_FILE, "w") as f:
        json.dump(logs, f, indent=2)
571
+
572
+ # =========================================================
573
+ # 🔥 FINAL FIX_SQL (BALANCED VERSION)
574
+ # =========================================================
575
def fix_sql(sql):
    """Best-effort repair of a model-generated SQL string.

    Extracts the SELECT/WITH statement from surrounding text, normalizes
    NULL comparisons, cleans stray commas, and applies light column/JOIN
    safety heuristics. Always returns something executable-looking; falls
    back to "SELECT 1" when no SQL can be salvaged.
    """
    if not sql:
        return "SELECT 1"

    s = str(sql).strip()

    # Keep only the text starting at the first SELECT/WITH keyword.
    match = re.search(r"(?i)(select|with)[\s\S]*", s)
    if match:
        s = match.group(0)

    # Drop anything after the first statement terminator.
    s = s.split(";")[0].strip()

    # NULL fixes. "!= null" must be rewritten BEFORE "= null" — otherwise the
    # "=" pattern consumes the tail of "!= null" and leaves invalid "!IS NULL".
    s = re.sub(r'(?i)!=\s*null', 'IS NOT NULL', s)
    s = re.sub(r'(?i)=\s*null', 'IS NULL', s)

    # Fix duplicated / dangling commas.
    s = re.sub(r',\s*,+', ',', s)
    s = re.sub(r'(?i),\s*from', ' FROM', s)

    # Light column safety: many qualified columns usually means a hallucinated
    # projection, so fall back to SELECT *.
    if "select" in s.lower():
        if len(re.findall(r'\w+\.\w+', s)) > 3:
            s = re.sub(r'(?i)select\s+.*?\s+from', 'SELECT * FROM', s)

    # JOIN without ON: add a trivially-true ON so the query still executes.
    if "join" in s.lower() and " on " not in s.lower():
        s = re.sub(r'join\s+(\w+)', r'JOIN \1 ON 1=1', s, flags=re.I)

    # Final guard: only SELECT/WITH statements are considered valid.
    if not s.lower().startswith(("select", "with")):
        return "SELECT 1"

    return s.strip()
610
+
611
+ # =========================================================
612
+ # NORMALIZATION
613
+ # =========================================================
614
def normalize_sql(sql):
    """Collapse whitespace runs and lowercase a SQL string for loose comparison."""
    if not sql:
        return ""
    squeezed = re.sub(r"\s+", " ", str(sql))
    return squeezed.strip().lower()
618
+
619
def normalize_result(res):
    """Normalize database rows for order-insensitive comparison.

    Each row's values are stringified and sorted within the row, and the
    rows themselves are sorted, so result sets compare equal regardless of
    row or column order. Returns [] for empty/None input.
    """
    if not res:
        return []
    try:
        normalized = [tuple(sorted(str(x) for x in row)) for row in res]
        return sorted(normalized)
    except Exception:
        # Rows were not iterable (e.g. scalars): compare their string forms.
        # (Was a bare `except:` — narrowed so KeyboardInterrupt etc. propagate.)
        return sorted([str(r) for r in res])
627
+
628
+ # =========================================================
629
+ # EXECUTION HELPERS
630
+ # =========================================================
631
def is_executable(sql, db_path):
    """Return True if `sql` executes without error against the SQLite DB at `db_path`."""
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        conn.cursor().execute(sql)
        return True
    except Exception:
        # Narrowed from a bare `except:`; any execution failure means "not executable".
        return False
    finally:
        # Original leaked the connection when execute() raised; always close it.
        if conn is not None:
            conn.close()
640
+
641
def check_execution(pred_sql, gold_sql, db_path, question):
    """Execute predicted and gold SQL and compare their normalized result sets.

    Any failure (bad SQL, missing table, decode issue) is classified, appended
    to the error log, and counted as a mismatch (returns False).
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        # Tolerate non-UTF-8 blobs stored in the Spider databases.
        conn.text_factory = lambda b: b.decode(errors='ignore')
        cur = conn.cursor()

        cur.execute(gold_sql)
        gold_res = cur.fetchall()

        cur.execute(pred_sql)
        pred_res = cur.fetchall()

        return normalize_result(pred_res) == normalize_result(gold_res)

    except Exception as e:
        error_type = classify_error(pred_sql, str(e))
        log_error(question, pred_sql, str(e), error_type)
        return False
    finally:
        # Original leaked the connection when a query raised; always close it.
        if conn is not None:
            conn.close()
661
+
662
+ # =========================================================
663
+ # MAIN
664
+ # =========================================================
665
def main():
    """Evaluate a CodeT5+LoRA Text-to-SQL adapter on the Spider dev set.

    Generates 8 beam candidates per question, picks the first candidate that
    executes (execution-guided selection), and reports exact-match, execution
    accuracy, and schema-constraint rate. Errors are appended to the JSON
    error log via check_execution().
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, required=True)
    parser.add_argument("--num_samples", type=int, default=700)
    args = parser.parse_args()

    # Resolve the project root whether this script lives in scripts/, src/, or the root.
    project_root = Path(__file__).resolve().parent
    if project_root.name in ["scripts", "src"]:
        project_root = project_root.parent

    db_root = project_root / "data" / "database"
    dev_json = project_root / "data" / "dev.json"

    # NOTE(review): only MPS/CPU are considered here; CUDA is not checked — confirm intended.
    device = "mps" if torch.backends.mps.is_available() else "cpu"

    print(f"Loading model on {device}...")

    tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-base")
    base_model = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/codet5-base").to(device)

    # Merge the LoRA adapter into the base weights for faster inference.
    model = PeftModel.from_pretrained(base_model, args.adapter).to(device)
    model = model.merge_and_unload()
    model.eval()

    with open(dev_json, "r") as f:
        dev_data = json.load(f)[:args.num_samples]

    em_correct = 0
    ex_correct = 0
    constraint_ok = 0

    print(f"\n🚀 Evaluating {len(dev_data)} samples...\n")

    for i, ex in enumerate(dev_data, 1):
        db_id = ex["db_id"]
        question = ex["question"]
        gold_query = ex["query"]

        db_path = db_root / db_id / f"{db_id}.sqlite"

        # encode_prompt returns a 1-D token tensor; add the batch dimension.
        input_tensor = encode_prompt(tokenizer, question, db_id, device=device).unsqueeze(0)

        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_tensor,
                max_new_tokens=128,
                num_beams=8,
                num_return_sequences=8
            )

        best_sql = ""

        # 🔥 EXECUTION-GUIDED SELECTION
        # Take the first beam candidate (in beam order) that actually executes.
        for out in outputs:
            raw_pred = tokenizer.decode(out, skip_special_tokens=True)
            candidate_sql = fix_sql(raw_pred)

            if is_executable(candidate_sql, str(db_path)):
                best_sql = candidate_sql
                break

        # Fall back to the top beam when no candidate executed.
        if not best_sql:
            best_sql = fix_sql(tokenizer.decode(outputs[0], skip_special_tokens=True))

        # Schema-constraint check; validator failures count as invalid.
        try:
            is_valid, _ = validate_sql_schema(best_sql, str(db_path))
        except:
            is_valid = False

        if is_valid:
            constraint_ok += 1

        if normalize_sql(best_sql) == normalize_sql(gold_query):
            em_correct += 1

        if check_execution(best_sql, gold_query, str(db_path), question):
            ex_correct += 1

        if i % 50 == 0:
            print(f"{i}/{len(dev_data)} done")

    print("\n========================================")
    print("🎯 FINAL EVALUATION RESULTS")
    print("========================================")
    print(f"Exact Match (EM): {(em_correct/len(dev_data))*100:.2f}%")
    print(f"Execution Acc (EX): {(ex_correct/len(dev_data))*100:.2f}%")
    print(f"Constraint Rate: {(constraint_ok/len(dev_data))*100:.2f}%")
    print("========================================")
    print(f"Errors logged to: {ERROR_LOG_FILE}")

if __name__ == "__main__":
    main()
src/eval_rl_t5.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import sys
2
+ # import os
3
+ # sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
4
+ # import json
5
+
6
+ # import subprocess
7
+
8
+ # import argparse
9
+ # from pathlib import Path
10
+
11
+ # import torch
12
+ # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
13
+ # from peft import PeftModel
14
+
15
+ # # IMPORTANT: must match training prompt format
16
+ # from prompting import build_prompt
17
+ # from schema_utils import get_schema as get_db_schema
18
+
19
+
20
+ # def _parse_exec_accuracy(stdout: str):
21
+ # for line in stdout.splitlines():
22
+ # if line.strip().startswith("execution"):
23
+ # parts = line.split()
24
+ # try:
25
+ # return float(parts[-1])
26
+ # except Exception:
27
+ # return None
28
+ # return None
29
+
30
+
31
+ # def main():
32
+ # parser = argparse.ArgumentParser()
33
+ # parser.add_argument("--adapter", type=str, default="checkpoints/best_rlhf_model")
34
+ # parser.add_argument("--num_samples", type=int, default=200)
35
+ # args = parser.parse_args()
36
+
37
+ # project_root = Path(__file__).resolve().parents[1]
38
+ # adapter_dir = project_root / args.adapter
39
+
40
+ # if not adapter_dir.exists():
41
+ # raise FileNotFoundError(f"Adapter not found: {adapter_dir}")
42
+
43
+ # db_root = project_root / "data" / "database"
44
+ # table_json = project_root / "data" / "tables.json"
45
+ # dev_json = project_root / "data" / "dev.json"
46
+ # gold_sql = project_root / "data" / "dev_gold.sql"
47
+ # pred_path = project_root / "predictions_rl.txt"
48
+
49
+ # device = "mps" if torch.backends.mps.is_available() else "cpu"
50
+
51
+ # # ---- LOAD MODEL (CodeT5 + LoRA) ----
52
+ # base_model = "Salesforce/codet5-base"
53
+
54
+ # tokenizer = AutoTokenizer.from_pretrained(str(adapter_dir))
55
+ # base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
56
+ # model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
57
+
58
+ # # merge LoRA for faster inference
59
+ # model = model.merge_and_unload()
60
+ # model.eval()
61
+ # model.config.use_cache = True
62
+
63
+ # if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
64
+ # tokenizer.pad_token = tokenizer.eos_token
65
+
66
+ # # ---- LOAD DATA ----
67
+ # with dev_json.open() as f:
68
+ # dev = json.load(f)
69
+
70
+ # dev = dev[: args.num_samples]
71
+
72
+ # gen_kwargs = dict(
73
+ # max_new_tokens=120,
74
+ # do_sample=False,
75
+ # num_beams=1,
76
+ # pad_token_id=tokenizer.pad_token_id,
77
+ # eos_token_id=tokenizer.eos_token_id,
78
+ # )
79
+
80
+ # print(f"Generating {len(dev)} predictions...")
81
+
82
+ # with pred_path.open("w") as out_f, torch.no_grad():
83
+ # for i, ex in enumerate(dev, start=1):
84
+ # db_id = ex["db_id"]
85
+ # question = ex["question"]
86
+
87
+ # db_path = db_root / db_id / f"{db_id}.sqlite"
88
+ # schema = get_db_schema(str(db_path))
89
+ # prompt = build_prompt(question, schema, use_schema=True)
90
+
91
+ # inputs = tokenizer(
92
+ # prompt,
93
+ # return_tensors="pt",
94
+ # truncation=True,
95
+ # max_length=512
96
+ # ).to(device)
97
+
98
+ # out = model.generate(**inputs, **gen_kwargs)
99
+ # pred_sql = tokenizer.decode(out[0], skip_special_tokens=True).strip()
100
+
101
+ # out_f.write(f"{pred_sql}\t{db_id}\n")
102
+
103
+ # if i % 20 == 0 or i == len(dev):
104
+ # print(f"{i}/{len(dev)} done")
105
+
106
+ # # ---- SPIDER OFFICIAL EVAL ----
107
+ # eval_script = project_root / "spider_eval" / "evaluation.py"
108
+
109
+ # cmd = [
110
+ # sys.executable,
111
+ # str(eval_script),
112
+ # "--gold",
113
+ # str(gold_sql),
114
+ # "--pred",
115
+ # str(pred_path),
116
+ # "--etype",
117
+ # "exec",
118
+ # "--db",
119
+ # str(db_root),
120
+ # "--table",
121
+ # str(table_json),
122
+ # ]
123
+
124
+ # print("\nRunning Spider execution evaluation...\n")
125
+ # proc = subprocess.run(cmd, capture_output=True, text=True)
126
+
127
+ # if proc.returncode != 0:
128
+ # print(proc.stdout)
129
+ # print(proc.stderr)
130
+ # sys.exit(proc.returncode)
131
+
132
+ # print(proc.stdout)
133
+
134
+ # acc = _parse_exec_accuracy(proc.stdout)
135
+ # if acc is not None:
136
+ # print(f"\nFINAL EXECUTION ACCURACY: {acc*100:.2f}%")
137
+ # else:
138
+ # print("Could not parse execution accuracy")
139
+
140
+
141
+ # if __name__ == "__main__":
142
+ # main()
143
+
144
+
145
+ import json
146
+ import sqlite3
147
+ import argparse
148
+ import time
149
+ from pathlib import Path
150
+ import torch
151
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
152
+ from peft import PeftModel
153
+
154
+ # ---------------- PROMPT (FIXED TO PERFECTLY MATCH RLHF TRAINING) ----------------
155
def build_prompt(question, schema):
    """Build the English→SQL prompt in the exact format used during RLHF training."""
    sections = (
        "translate English to SQL:",
        f"Schema:\n{schema}",
        f"Question:\n{question}",
        "SQL:",
    )
    return "\n\n".join(sections)
157
+
158
+ # ---------------- LOAD SCHEMA (FIXED TO MATCH TRAINING FORMAT) ----------------
159
def load_schema(db_path):
    """Serialize a SQLite database schema as 'table(col1, col2) table2(...)'.

    Space-separated table signatures, matching the schema encoding used by
    the RLHF training script. Returns an empty string for a DB with no tables.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        tables = cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()

        parts = []
        for (table,) in tables:
            # Quote the identifier so unusual table names can't break the PRAGMA.
            quoted = table.replace('"', '""')
            cols = cursor.execute(f'PRAGMA table_info("{quoted}");').fetchall()
            col_names = [c[1] for c in cols]  # column name is field 1 of table_info rows
            parts.append(f"{table}({', '.join(col_names)})")

        # Space-separated, not newline-separated, just like the RLHF script.
        return " ".join(parts)
    finally:
        # Original leaked the connection if a PRAGMA raised; always close it.
        conn.close()
176
+
177
+
178
+ # ---------------- EXECUTION CHECK WITH TIMEOUT ----------------
179
def execution_match(pred_sql, gold_sql, db_path):
    """Return True when predicted and gold SQL yield identical row lists.

    Each query is aborted after ~5 seconds via a SQLite progress handler so a
    pathological query cannot hang the evaluation loop. Any execution error
    (including a timeout-triggered abort) counts as a mismatch.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)

        # The handler is polled every 10000 VM instructions; a non-zero
        # return value interrupts the running statement.
        start_time = time.monotonic()

        def timeout_handler():
            return 1 if (time.monotonic() - start_time) > 5.0 else 0

        conn.set_progress_handler(timeout_handler, 10000)

        cur = conn.cursor()

        cur.execute(pred_sql)
        pred = cur.fetchall()

        cur.execute(gold_sql)
        gold = cur.fetchall()

        # Strict, order-sensitive comparison of the fetched rows.
        return pred == gold

    except Exception:
        return False
    finally:
        # Original leaked the connection when a query raised; always close it.
        if conn is not None:
            conn.close()
202
+
203
+
204
+ # ---------------- MAIN ----------------
205
def main():
    """Evaluate a t5-small + LoRA adapter on Spider by execution accuracy.

    Loads the adapter, merges it into the base model, greedily generates SQL
    for up to --num_samples dev questions, and reports the fraction whose
    execution results match the gold query (via execution_match).
    """
    parser = argparse.ArgumentParser()
    # 🎯 Set the default directly to your best RLHF model!
    parser.add_argument("--adapter", type=str, default="checkpoints/rlhf_t5_best")
    parser.add_argument("--num_samples", type=int, default=1000)
    args = parser.parse_args()

    project_root = Path(__file__).resolve().parents[1]

    # Resolve adapter path safely
    adapter_path = project_root / args.adapter

    dev_json = project_root / "data" / "dev.json"
    db_root = project_root / "data" / "database"

    # 🎯 Added CUDA support — prefer MPS, then CUDA, then CPU.
    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")

    # load model (tokenizer comes from the adapter dir; weights from the hub)
    base_model = "t5-small"
    print(f"Loading Base: {base_model}")
    print(f"Loading Adapter: {adapter_path}")

    tokenizer = AutoTokenizer.from_pretrained(str(adapter_path))
    base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
    model = PeftModel.from_pretrained(base, str(adapter_path)).to(device)
    # Merge LoRA weights into the base model for faster inference.
    model = model.merge_and_unload()

    with open(dev_json) as f:
        dev = json.load(f)[: args.num_samples]

    correct = 0

    print(f"Evaluating {len(dev)} examples...\n")

    for i, ex in enumerate(dev, 1):
        question = ex["question"]
        db_id = ex["db_id"]
        gold_sql = ex["query"]

        db_path = db_root / db_id / f"{db_id}.sqlite"
        # Schema string must match the RLHF training format (see load_schema).
        schema = load_schema(db_path)

        prompt = build_prompt(question, schema)

        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=80,
                do_sample=False,
                num_beams=4,
            )

        pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Strip any echoed prompt: keep only the text after the last "SQL:" marker.
        if "SQL:" in pred_sql:
            pred_sql = pred_sql.split("SQL:")[-1].strip()

        match = execution_match(pred_sql, gold_sql, db_path)

        if match:
            correct += 1

        if i % 10 == 0:
            print(f"{i}/{len(dev)} | Acc: {correct/i:.3f}")

    print("\n=============================")
    print(f"FINAL EXECUTION ACCURACY: {correct/len(dev)*100:.2f}%")
    print("=============================")


if __name__ == "__main__":
    main()
src/eval_single_model.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import subprocess
3
+ import sys
4
+ import argparse
5
+ import random
6
+ import sqlite3
7
+ import time
8
+ import re
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+ from pathlib import Path
12
+
13
+ import torch
14
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
15
+ from peft import PeftModel
16
+
17
+ # Assuming you have a prompting.py that has encode_prompt
18
+ from prompting import encode_prompt
19
+
20
+ # -------------------------------
21
+ # LIVE CHECK HELPERS
22
+ # -------------------------------
23
def normalize_sql(sql):
    """Canonicalize SQL for loose exact-match: unify quotes, squeeze whitespace,
    lowercase, and drop a trailing semicolon."""
    unified = re.sub(r"\s+", " ", sql.replace('"', "'"))
    return unified.strip().lower().rstrip(";")
27
+
28
def check_execution(pred_sql, gold_sql, db_path):
    """Return True when predicted and gold SQL produce the same rows (order-insensitive).

    A ~2-second progress-handler timeout prevents pathological queries from
    hanging the evaluation loop; any error or timeout counts as a mismatch.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        # Tolerate non-UTF-8 blobs stored in the Spider databases.
        conn.text_factory = lambda b: b.decode(errors='ignore')

        # Abort statements running longer than ~2 seconds: the handler is
        # polled every 10000 VM ops and a non-zero return interrupts the query.
        start_time = time.monotonic()

        def timeout_handler():
            return 1 if (time.monotonic() - start_time) > 2.0 else 0

        conn.set_progress_handler(timeout_handler, 10000)

        cursor = conn.cursor()
        cursor.execute(pred_sql)
        pred_res = cursor.fetchall()

        cursor.execute(gold_sql)
        gold_res = cursor.fetchall()

        # Sorting makes the comparison row-order-insensitive; rows with
        # mutually uncomparable types will raise here and count as a mismatch.
        return sorted(pred_res) == sorted(gold_res)
    except Exception:
        return False
    finally:
        # Original leaked the connection when a query raised; always close it.
        if conn is not None:
            conn.close()
49
+
50
+ # -------------------------------
51
+ # SPIDER PARSER
52
+ # -------------------------------
53
+ def _parse_spider_accuracy(stdout: str, metric_type: str) -> float | None:
54
+ for line in stdout.splitlines():
55
+ if metric_type == "exec" and line.strip().startswith("execution"):
56
+ try: return float(line.split()[-1])
57
+ except: pass
58
+ elif metric_type == "match" and line.strip().startswith("exact"):
59
+ try: return float(line.split()[-1])
60
+ except: pass
61
+ return None
62
+
63
+ # -------------------------------
64
+ # MAIN
65
+ # -------------------------------
66
def main():
    """Evaluate one seq2seq+LoRA checkpoint on Spider and update the comparison plot.

    Generates predictions with live EM/EX tracking, runs the official Spider
    evaluation script twice (match + exec) via subprocess, appends the scores
    to comparison_plots/all_metrics.json, and regenerates the grouped bar chart
    comparing all evaluated models.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, required=True, help="Path to your checkpoint")
    parser.add_argument("--base_model", type=str, required=True, help="E.g., facebook/bart-base, t5-small")
    parser.add_argument("--model_name", type=str, required=True, help="Name for the plot label (e.g., 'BART RLHF')")
    parser.add_argument("--num_samples", type=int, default=700)
    args = parser.parse_args()

    project_root = Path(__file__).resolve().parents[1]
    adapter_dir = project_root / args.adapter

    db_root = project_root / "data" / "database"
    table_json = project_root / "data" / "tables.json"
    dev_json = project_root / "data" / "dev.json"

    # Scratch files consumed by the official Spider evaluation script below.
    pred_path = project_root / "temp_predictions.txt"
    temp_gold_path = project_root / "temp_gold.sql"

    # NEW: Plot directory setup
    plot_dir = project_root / "comparison_plots"
    plot_dir.mkdir(parents=True, exist_ok=True)
    results_json_path = plot_dir / "all_metrics.json"

    if not adapter_dir.exists():
        raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")

    # Prefer MPS, then CUDA, then CPU.
    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Loading Base Model: {args.base_model} on {device}...")

    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    # Some bases (e.g. GPT-style tokenizers) ship without a pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    base = AutoModelForSeq2SeqLM.from_pretrained(args.base_model).to(device)
    model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
    # Merge LoRA weights into the base model for faster inference.
    model = model.merge_and_unload()
    model.eval()

    with dev_json.open() as f:
        dev = json.load(f)[: args.num_samples]
    total = len(dev)

    gen_kwargs = dict(
        max_new_tokens=160,
        num_beams=4,
        do_sample=False,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    print(f"\n🚀 Generating and live-tracking {total} samples...\n")

    em_correct = 0
    ex_correct = 0

    with pred_path.open("w") as out_pred, temp_gold_path.open("w") as out_gold, torch.no_grad():
        for i, ex in enumerate(dev, start=1):
            db_id = ex["db_id"]
            question = ex["question"]
            gold_query = ex["query"]
            db_path = db_root / db_id / f"{db_id}.sqlite"

            # Generate — encode_prompt returns a 1-D token tensor; add a batch dim.
            input_ids = encode_prompt(tokenizer, question, db_id, device=device, max_input_tokens=512)
            input_ids = input_ids.unsqueeze(0).to(device)
            attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)

            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
            pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

            # One prediction per line; gold file is "query<TAB>db_id" (Spider format).
            out_pred.write(f"{pred_sql}\n")
            out_gold.write(f"{gold_query}\t{db_id}\n")

            # --- PRINT FIRST 3 EXAMPLES ---
            if i <= 3:
                print(f"--- 🔍 Example {i} ---")
                print(f"Q : {question}")
                print(f"Gold: {gold_query}")
                print(f"Pred: {pred_sql}")
                print("-" * 25)

            # --- LIVE TRACKING CHECKS (informal; official numbers come from Spider eval) ---
            if normalize_sql(pred_sql) == normalize_sql(gold_query):
                em_correct += 1
            if check_execution(pred_sql, gold_query, db_path):
                ex_correct += 1

            if i % 50 == 0 or i == total:
                print(f"Progress: {i}/{total} | Current EM: {(em_correct/i)*100:.2f}% | Current EX: {(ex_correct/i)*100:.2f}%")

    print("\nRunning Official Spider Evaluations...")
    eval_script = project_root / "spider_eval" / "evaluation.py"

    # Run the official evaluator twice: once for exact match, once for execution.
    proc_match = subprocess.run([sys.executable, str(eval_script), "--gold", str(temp_gold_path), "--pred", str(pred_path), "--etype", "match", "--db", str(db_root), "--table", str(table_json)], capture_output=True, text=True)
    exact_acc = _parse_spider_accuracy(proc_match.stdout, "match")

    proc_exec = subprocess.run([sys.executable, str(eval_script), "--gold", str(temp_gold_path), "--pred", str(pred_path), "--etype", "exec", "--db", str(db_root), "--table", str(table_json)], capture_output=True, text=True)
    exec_acc = _parse_spider_accuracy(proc_exec.stdout, "exec")

    print("\n==========================================")
    print(f"🎯 RESULTS FOR: {args.model_name}")
    print("==========================================")
    # NOTE(review): truthiness check means a parsed 0.0 is treated like a parse
    # failure (reported as 0 either way) — confirm this is acceptable.
    exact_val = exact_acc * 100 if exact_acc else 0
    exec_val = exec_acc * 100 if exec_acc else 0
    print(f"Exact Match : {exact_val:.2f}%")
    print(f"Execution : {exec_val:.2f}%")
    print("==========================================\n")

    # -------------------------------
    # SAVE JSON & GENERATE PLOT
    # -------------------------------
    if results_json_path.exists():
        with open(results_json_path, 'r') as f:
            all_results = json.load(f)
    else:
        all_results = {}

    all_results[args.model_name] = {"EM": exact_val, "EX": exec_val}

    with open(results_json_path, 'w') as f:
        json.dump(all_results, f, indent=4)

    # Grouped bar chart: one EM/EX pair per model recorded so far.
    labels = list(all_results.keys())
    em_vals = [all_results[k]["EM"] for k in labels]
    ex_vals = [all_results[k]["EX"] for k in labels]

    x = np.arange(len(labels))
    width = 0.35

    plt.figure(figsize=(max(8, len(labels) * 1.5), 6))
    plt.bar(x - width/2, em_vals, width, label='Exact Match', color='#3498db')
    plt.bar(x + width/2, ex_vals, width, label='Execution', color='#2ecc71')

    plt.ylabel('Accuracy (%)', fontweight='bold')
    plt.title('Model Comparison: Exact Match vs Execution Accuracy', fontweight='bold', fontsize=14)
    plt.xticks(x, labels, rotation=45, ha="right")
    plt.legend()
    plt.ylim(0, max(max(em_vals, default=0), max(ex_vals, default=0)) + 15)
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    # Attach labels to bars
    for i in range(len(labels)):
        plt.text(x[i] - width/2, em_vals[i] + 1, f"{em_vals[i]:.1f}%", ha='center', fontsize=9)
        plt.text(x[i] + width/2, ex_vals[i] + 1, f"{ex_vals[i]:.1f}%", ha='center', fontsize=9)

    plt.tight_layout()
    plot_path = plot_dir / "accuracy_comparison.png"
    plt.savefig(plot_path, dpi=300)
    print(f"📈 Updated comparison plot saved to: {plot_path}")

if __name__ == "__main__":
    main()
src/evaluate_model_codet5.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from __future__ import annotations
2
+
3
+ # import json
4
+ # import subprocess
5
+ # import sys
6
+ # import argparse
7
+ # import sqlite3
8
+ # import random
9
+ # from pathlib import Path
10
+
11
+ # import torch
12
+ # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
13
+ # from peft import PeftModel
14
+
15
+ # from prompting import encode_prompt
16
+
17
+
18
+ # def _parse_exec_accuracy(stdout: str) -> float | None:
19
+ # for line in stdout.splitlines():
20
+ # if line.strip().startswith("execution"):
21
+ # try:
22
+ # return float(line.split()[-1])
23
+ # except:
24
+ # return None
25
+ # return None
26
+
27
+
28
+ # def main():
29
+
30
+ # # ---------------- ARGUMENTS ----------------
31
+ # parser = argparse.ArgumentParser()
32
+ # parser.add_argument("--adapter", type=str, default="checkpoints/sft_adapter_codet5")
33
+ # parser.add_argument("--num_samples", type=int, default=1000)
34
+ # parser.add_argument("--shuffle_dev", action="store_true")
35
+ # parser.add_argument("--shuffle_seed", type=int, default=42)
36
+ # parser.add_argument("--accuracy_log", type=str, default="")
37
+ # args = parser.parse_args()
38
+
39
+ # project_root = Path(__file__).resolve().parents[1]
40
+ # adapter_dir = project_root / args.adapter
41
+
42
+ # db_root = project_root / "data" / "database"
43
+ # table_json = project_root / "data" / "tables.json"
44
+ # dev_json = project_root / "data" / "dev.json"
45
+ # gold_sql = project_root / "data" / "dev_gold.sql"
46
+ # pred_path = project_root / "predictions.txt"
47
+
48
+ # if not adapter_dir.exists():
49
+ # raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")
50
+
51
+ # # ---------------- DEVICE ----------------
52
+ # device = "mps" if torch.backends.mps.is_available() else (
53
+ # "cuda" if torch.cuda.is_available() else "cpu"
54
+ # )
55
+ # print("Using device:", device)
56
+
57
+ # # ---------------- LOAD MODEL ----------------
58
+ # BASE_MODEL = "Salesforce/codet5-base"
59
+
60
+ # tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
61
+
62
+ # if tokenizer.pad_token is None:
63
+ # tokenizer.pad_token = tokenizer.eos_token
64
+
65
+ # base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
66
+ # model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
67
+
68
+ # model = model.merge_and_unload()
69
+ # model.eval()
70
+
71
+ # # ---------------- LOAD DATA ----------------
72
+ # with dev_json.open() as f:
73
+ # dev = json.load(f)
74
+
75
+ # if args.shuffle_dev:
76
+ # rng = random.Random(args.shuffle_seed)
77
+ # rng.shuffle(dev)
78
+
79
+ # dev = dev[: args.num_samples]
80
+
81
+ # # ---------------- GENERATION CONFIG ----------------
82
+ # gen_kwargs = dict(
83
+ # max_new_tokens=160,
84
+ # num_beams=4,
85
+ # do_sample=False,
86
+ # early_stopping=True,
87
+ # pad_token_id=tokenizer.pad_token_id,
88
+ # eos_token_id=tokenizer.eos_token_id,
89
+ # )
90
+
91
+ # print("Generating predictions...\n")
92
+
93
+ # correct = 0
94
+ # total = len(dev)
95
+ # accuracy_log_fh = None
96
+
97
+ # if args.accuracy_log:
98
+ # accuracy_log_path = (project_root / args.accuracy_log).resolve()
99
+ # accuracy_log_path.parent.mkdir(parents=True, exist_ok=True)
100
+ # accuracy_log_fh = accuracy_log_path.open("w")
101
+ # print(f"Writing running accuracy log to: {accuracy_log_path}")
102
+
103
+ # with pred_path.open("w") as out_f, torch.no_grad():
104
+
105
+ # for i, ex in enumerate(dev, start=1):
106
+
107
+ # db_id = ex["db_id"]
108
+ # question = ex["question"]
109
+ # gold_query = ex["query"]
110
+
111
+ # input_ids = encode_prompt(
112
+ # tokenizer,
113
+ # question,
114
+ # db_id,
115
+ # device=device,
116
+ # max_input_tokens=512,
117
+ # )
118
+
119
+ # input_ids = input_ids.unsqueeze(0).to(device)
120
+ # attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)
121
+
122
+ # outputs = model.generate(
123
+ # input_ids=input_ids,
124
+ # attention_mask=attention_mask,
125
+ # **gen_kwargs
126
+ # )
127
+
128
+ # pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
129
+ # out_f.write(f"{pred_sql}\t{db_id}\n")
130
+
131
+ # # ---------------- LIVE EXECUTION CHECK ----------------
132
+ # try:
133
+ # db_path = db_root / db_id / f"{db_id}.sqlite"
134
+
135
+ # conn = sqlite3.connect(db_path)
136
+ # cursor = conn.cursor()
137
+
138
+ # cursor.execute(pred_sql)
139
+ # pred_rows = cursor.fetchall()
140
+
141
+ # cursor.execute(gold_query)
142
+ # gold_rows = cursor.fetchall()
143
+
144
+ # conn.close()
145
+
146
+ # if sorted(pred_rows) == sorted(gold_rows):
147
+ # correct += 1
148
+
149
+ # except Exception:
150
+ # pass # execution failed
151
+
152
+ # # 🔥 PRINT EVERY 10
153
+ # if i % 10 == 0 or i == total:
154
+ # current_acc = correct / i
155
+ # line = f"{i}/{total} | Acc: {current_acc:.3f}"
156
+ # print(line)
157
+ # if accuracy_log_fh is not None:
158
+ # accuracy_log_fh.write(line + "\n")
159
+
160
+ # if accuracy_log_fh is not None:
161
+ # accuracy_log_fh.close()
162
+
163
+ # print("\nGeneration finished.\n")
164
+
165
+ # # ---------------- OFFICIAL SPIDER EVAL ----------------
166
+ # eval_script = project_root / "spider_eval" / "evaluation.py"
167
+
168
+ # cmd = [
169
+ # sys.executable,
170
+ # str(eval_script),
171
+ # "--gold", str(gold_sql),
172
+ # "--pred", str(pred_path),
173
+ # "--etype", "exec",
174
+ # "--db", str(db_root),
175
+ # "--table", str(table_json),
176
+ # ]
177
+
178
+ # print("Running Spider evaluation...")
179
+ # proc = subprocess.run(cmd, capture_output=True, text=True)
180
+
181
+ # print(proc.stdout)
182
+
183
+ # exec_acc = _parse_exec_accuracy(proc.stdout)
184
+ # if exec_acc is not None:
185
+ # print(f"\n🎯 Official Execution Accuracy: {exec_acc*100:.2f}%")
186
+ # else:
187
+ # print("Could not parse accuracy.")
188
+
189
+
190
+ # if __name__ == "__main__":
191
+ # main()
192
+
193
+ import json
194
+ import subprocess
195
+ import sys
196
+ import argparse
197
+ import random
198
+ import sqlite3
199
+ import time
200
+ import re
201
+ from pathlib import Path
202
+
203
+ import torch
204
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
205
+ from peft import PeftModel
206
+
207
+ # Assuming you have a prompting.py that has encode_prompt
208
+ from prompting import encode_prompt
209
+
210
+ # -------------------------------
211
+ # LIVE CHECK HELPERS
212
+ # -------------------------------
213
def normalize_sql(sql):
    """Canonicalize a SQL string for the live exact-match check.

    Double quotes become single quotes, runs of whitespace collapse into a
    single space, and the result is lower-cased with any trailing
    semicolon removed.
    """
    unified = sql.replace('"', "'")
    collapsed = re.sub(r"\s+", " ", unified)
    return collapsed.strip().lower().rstrip(";")
218
+
219
def check_execution(pred_sql, gold_sql, db_path):
    """Execution-accuracy check used by the live progress bar.

    Runs both the predicted and gold queries against the SQLite database at
    *db_path* and compares their result sets order-insensitively.

    Returns True when both queries execute and yield the same rows; False on
    any failure (invalid SQL, missing database, timeout, ...).
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        # Tolerate non-UTF8 text stored in some Spider databases.
        conn.text_factory = lambda b: b.decode(errors='ignore')

        # 2-second timeout so the live tracker doesn't freeze forever: a
        # non-zero return from the progress handler makes SQLite abort the
        # query with OperationalError, caught below.
        start_time = time.monotonic()

        def timeout_handler():
            return 1 if (time.monotonic() - start_time) > 2.0 else 0

        conn.set_progress_handler(timeout_handler, 10000)

        cursor = conn.cursor()
        cursor.execute(pred_sql)
        pred_res = cursor.fetchall()

        cursor.execute(gold_sql)
        gold_res = cursor.fetchall()

        # Simple sorted check for the live tracker (order-insensitive).
        return sorted(pred_res) == sorted(gold_res)
    except Exception:
        return False
    finally:
        # BUGFIX: the original leaked the connection whenever a query raised
        # (close() was only reached on the success path); always close it.
        if conn is not None:
            conn.close()
243
+
244
+ # -------------------------------
245
+ # SPIDER PARSER
246
+ # -------------------------------
247
+ def _parse_spider_accuracy(stdout: str, metric_type: str) -> float | None:
248
+ for line in stdout.splitlines():
249
+ if metric_type == "exec" and line.strip().startswith("execution"):
250
+ try: return float(line.split()[-1])
251
+ except: pass
252
+ elif metric_type == "match" and line.strip().startswith("exact"):
253
+ try: return float(line.split()[-1])
254
+ except: pass
255
+ return None
256
+
257
+ # -------------------------------
258
+ # MAIN
259
+ # -------------------------------
260
def main() -> None:
    """Evaluate a CodeT5 LoRA adapter on the Spider dev set.

    Generates SQL for each dev example with beam search, tracks live
    exact-match (EM) and execution (EX) accuracy, then runs the official
    Spider evaluator twice (match + exec) on the written prediction/gold
    files and prints the final numbers.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, required=True, help="Path to your SFT or RLHF checkpoint")
    parser.add_argument("--num_samples", type=int, default=1034, help="Number of samples to evaluate")
    parser.add_argument("--shuffle_dev", action="store_true")
    parser.add_argument("--shuffle_seed", type=int, default=42)
    args = parser.parse_args()

    # Paths are resolved relative to the repository root (parent of src/).
    project_root = Path(__file__).resolve().parents[1]
    adapter_dir = project_root / args.adapter

    db_root = project_root / "data" / "database"
    table_json = project_root / "data" / "tables.json"
    dev_json = project_root / "data" / "dev.json"

    # Scratch files consumed by the official Spider evaluator below.
    pred_path = project_root / "temp_predictions.txt"
    temp_gold_path = project_root / "temp_gold.sql"

    if not adapter_dir.exists():
        raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")

    # Prefer Apple MPS, then CUDA, then CPU.
    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    BASE_MODEL = "Salesforce/codet5-base"
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load the frozen base model, apply the LoRA adapter, then merge the
    # adapter weights for faster inference.
    print(f"Loading Model: {args.adapter}...")
    base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
    model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
    model = model.merge_and_unload()
    model.eval()

    with dev_json.open() as f:
        dev = json.load(f)

    # Optional deterministic shuffle so subsets are representative.
    if args.shuffle_dev:
        rng = random.Random(args.shuffle_seed)
        rng.shuffle(dev)

    dev = dev[: args.num_samples]
    total = len(dev)

    # Deterministic beam-search decoding config.
    gen_kwargs = dict(
        max_new_tokens=160,
        num_beams=4,
        do_sample=False,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    print(f"\n🚀 Generating and live-tracking {total} samples...\n")

    em_correct = 0
    ex_correct = 0

    with pred_path.open("w") as out_pred, temp_gold_path.open("w") as out_gold, torch.no_grad():
        for i, ex in enumerate(dev, start=1):
            db_id = ex["db_id"]
            question = ex["question"]
            gold_query = ex["query"]
            db_path = db_root / db_id / f"{db_id}.sqlite"

            # Generate
            input_ids = encode_prompt(tokenizer, question, db_id, device=device, max_input_tokens=512)
            input_ids = input_ids.unsqueeze(0).to(device)
            attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)

            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
            pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

            # Write to files for official spider eval later
            out_pred.write(f"{pred_sql}\n")
            out_gold.write(f"{gold_query}\t{db_id}\n")

            # --- LIVE TRACKING CHECKS ---
            # Approximate EM/EX here; official numbers come from spider_eval.
            if normalize_sql(pred_sql) == normalize_sql(gold_query):
                em_correct += 1
            if check_execution(pred_sql, gold_query, db_path):
                ex_correct += 1

            # Print progress every 50 loops
            if i % 50 == 0 or i == total:
                print(f"Progress: {i}/{total} | Current EM: {(em_correct/i)*100:.2f}% | Current EX: {(ex_correct/i)*100:.2f}%")

    print("\nGeneration finished. Running Official Spider Evaluations for final numbers...\n")

    eval_script = project_root / "spider_eval" / "evaluation.py"

    # 1. RUN EXACT MATCH EVAL
    cmd_match = [
        sys.executable, str(eval_script),
        "--gold", str(temp_gold_path),
        "--pred", str(pred_path),
        "--etype", "match",
        "--db", str(db_root),
        "--table", str(table_json),
    ]
    proc_match = subprocess.run(cmd_match, capture_output=True, text=True)
    exact_acc = _parse_spider_accuracy(proc_match.stdout, "match")

    # 2. RUN EXECUTION EVAL
    cmd_exec = [
        sys.executable, str(eval_script),
        "--gold", str(temp_gold_path),
        "--pred", str(pred_path),
        "--etype", "exec",
        "--db", str(db_root),
        "--table", str(table_json),
    ]
    proc_exec = subprocess.run(cmd_exec, capture_output=True, text=True)
    exec_acc = _parse_spider_accuracy(proc_exec.stdout, "exec")

    print("==========================================")
    print(f"🎯 OFFICIAL SPIDER RESULTS FOR: {args.adapter}")
    print("==========================================")

    if exact_acc is not None:
        print(f"Exact Set Match Accuracy : {exact_acc*100:.2f}%")
    else:
        print("Exact Set Match Accuracy : Could not parse output")

    if exec_acc is not None:
        print(f"Execution Accuracy : {exec_acc*100:.2f}%")
    else:
        print("Execution Accuracy : Could not parse output")
    print("==========================================\n")

if __name__ == "__main__":
    main()
src/evaluate_model_t5_small_sft.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import subprocess
5
+ import sys
6
+ import argparse
7
+ import re
8
+ import sqlite3
9
+ from pathlib import Path
10
+
11
+ import torch
12
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
13
+ from peft import PeftModel
14
+ from prompting import encode_prompt
15
+
16
+
17
+ # ---------------- PARSE ACC ----------------
18
+ def _parse_exec_accuracy(stdout: str) -> float | None:
19
+ for line in stdout.splitlines():
20
+ if line.strip().startswith("execution"):
21
+ try:
22
+ return float(line.split()[-1])
23
+ except:
24
+ return None
25
+ return None
26
+
27
+
28
+ # ---------------- CLEAN SQL ----------------
29
def clean_prediction(pred_sql: str) -> str:
    """Normalize a raw model prediction into a single-line SQL statement.

    Strips any leading "SQL:" prompt echo, unifies quote characters,
    collapses whitespace, and guarantees a terminating semicolon.
    """
    cleaned = pred_sql.strip()

    if "SQL:" in cleaned:
        cleaned = cleaned.split("SQL:")[-1]

    cleaned = re.sub(r"\s+", " ", cleaned.replace('"', "'")).strip()

    return cleaned if cleaned.endswith(";") else cleaned + ";"
42
+
43
+
44
def main() -> None:
    """Evaluate a t5-small LoRA adapter on the Spider dev set.

    Generates SQL with beam search, keeps a live execution-accuracy
    counter while writing predictions to pred.sql, then runs the official
    Spider execution evaluation on the written file.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, default="checkpoints/sft_t5")
    parser.add_argument("--num_samples", type=int, default=1000)
    args = parser.parse_args()

    # Repository root (parent of src/) anchors all data/checkpoint paths.
    project_root = Path(__file__).resolve().parents[1]
    adapter_dir = project_root / args.adapter

    db_root = project_root / "data/database"
    table_json = project_root / "data/tables.json"
    dev_json = project_root / "data/dev.json"
    gold_sql = project_root / "data/dev_gold.sql"
    pred_path = project_root / "pred.sql"

    if not adapter_dir.exists():
        raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")

    # ---------------- DEVICE ----------------
    # Prefer Apple MPS, then CUDA, then CPU.
    device = "mps" if torch.backends.mps.is_available() else (
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    print("Using device:", device)

    # ---------------- LOAD MODEL ----------------
    BASE_MODEL = "t5-small"

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)

    # Apply the LoRA adapter and merge its weights for faster inference.
    model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
    model = model.merge_and_unload()
    model.eval()

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    # ---------------- LOAD DATA ----------------
    with dev_json.open() as f:
        dev = json.load(f)[: args.num_samples]

    print("Generating predictions...\n")

    correct = 0
    total = len(dev)

    # ---------------- GENERATE + LIVE EXEC ----------------
    with pred_path.open("w") as out_f, torch.no_grad():

        for i, ex in enumerate(dev, start=1):

            db_id = ex["db_id"]
            question = ex["question"]
            gold_query = ex["query"]

            prompt_ids = encode_prompt(
                tokenizer,
                question,
                db_id,
                device=device,
                max_input_tokens=512,
            )

            input_ids = prompt_ids.unsqueeze(0).to(device)
            attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)

            # Deterministic beam-search decoding.
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=160,
                num_beams=4,
                do_sample=False,
                early_stopping=True,
            )

            pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
            pred_sql = clean_prediction(pred_sql)

            out_f.write(pred_sql + "\n")

            # -------- LIVE EXECUTION CHECK --------
            # Approximate metric only; the official number comes from
            # spider_eval below.
            try:
                db_path = db_root / db_id / f"{db_id}.sqlite"

                conn = sqlite3.connect(db_path)
                cursor = conn.cursor()

                cursor.execute(pred_sql)
                pred_rows = cursor.fetchall()

                cursor.execute(gold_query)
                gold_rows = cursor.fetchall()

                conn.close()

                # Order-insensitive row comparison.
                if sorted(pred_rows) == sorted(gold_rows):
                    correct += 1

            except Exception:
                pass  # execution failed

            # 🔥 PRINT EVERY 10
            if i % 10 == 0 or i == total:
                current_acc = correct / i
                print(f"{i}/{total} | Acc: {current_acc:.3f}")

    print("\nGeneration finished.\n")

    # ---------------- SPIDER EVAL ----------------
    eval_script = project_root / "spider_eval/evaluation.py"

    cmd = [
        sys.executable,
        str(eval_script),
        "--gold", str(gold_sql),
        "--pred", str(pred_path),
        "--etype", "exec",
        "--db", str(db_root),
        "--table", str(table_json),
    ]

    print("Running Spider evaluation...")
    proc = subprocess.run(cmd, capture_output=True, text=True)

    print(proc.stdout)

    exec_acc = _parse_exec_accuracy(proc.stdout)
    if exec_acc is not None:
        print(f"\n🎯 Official Execution Accuracy: {exec_acc*100:.2f}%")
    else:
        print("Could not parse accuracy.")


if __name__ == "__main__":
    main()
src/evaluate_rl_bart.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import json
3
+ import sqlite3
4
+ import argparse
5
+ import time
6
+ from pathlib import Path
7
+ import torch
8
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
9
+ from peft import PeftModel
10
+
11
+ # ---------------- PROMPT (IDENTICAL TO TRAINING) ----------------
12
def build_prompt(question, schema):
    """Assemble the evaluation prompt.

    NOTE: the wording and line breaks must stay identical to the prompt
    used during training, so the pieces below are concatenated verbatim.
    """
    return (
        "\nDatabase Schema:\n"
        + schema
        + "\n\nTranslate English to SQL:\n"
        + question
        + "\nSQL:\n"
    )
21
+
22
+ # ---------------- LOAD SCHEMA ----------------
23
def load_schema(db_path):
    """Return a compact one-line-per-table description of a SQLite schema.

    Reads table names from sqlite_master and their columns via
    PRAGMA table_info, producing lines like ``table(col1, col2)``.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        tables = cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()

        lines = []
        for (table,) in tables:
            # BUGFIX: quote the identifier so table names containing spaces
            # or SQL keywords don't break the PRAGMA.
            cols = cursor.execute(f'PRAGMA table_info("{table}");').fetchall()
            col_names = [c[1] for c in cols]
            lines.append(f"{table}({', '.join(col_names)})\n")

        # join() instead of repeated `+=` avoids quadratic string building.
        return "".join(lines)
    finally:
        # BUGFIX: the original leaked the connection if any query raised.
        conn.close()
39
+
40
+
41
+ # ---------------- EXECUTION CHECK WITH TIMEOUT ----------------
42
def execution_match(pred_sql, gold_sql, db_path):
    """Compare predicted and gold SQL by executing both on the database.

    Returns True when both queries run and produce identical (order-
    sensitive) result lists; False on any failure or timeout.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)

        # --- 5-SECOND TIMEOUT SO EVALUATION DOESN'T FREEZE ---
        # A non-zero return from the progress handler makes SQLite abort the
        # running query with OperationalError, caught below.
        start_time = time.monotonic()

        def timeout_handler():
            return 1 if (time.monotonic() - start_time) > 5.0 else 0

        conn.set_progress_handler(timeout_handler, 10000)

        cur = conn.cursor()

        cur.execute(pred_sql)
        pred = cur.fetchall()

        cur.execute(gold_sql)
        gold = cur.fetchall()

        # NOTE: deliberately order-sensitive, matching the original metric.
        return pred == gold

    except Exception:
        return False
    finally:
        # BUGFIX: the original leaked the connection whenever a query raised;
        # always close it.
        if conn is not None:
            conn.close()
65
+
66
+
67
+ # ---------------- MAIN ----------------
68
def main() -> None:
    """Evaluate a BART LoRA adapter on Spider using execution accuracy only.

    Builds a schema-augmented prompt per example, generates SQL with beam
    search, and counts order-sensitive execution matches against the gold
    query, printing running accuracy every 10 examples.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, required=True)
    parser.add_argument("--num_samples", type=int, default=1034)
    args = parser.parse_args()

    # Repository root (parent of src/) anchors the data paths.
    project_root = Path(__file__).resolve().parents[1]

    dev_json = project_root / "data" / "dev.json"
    db_root = project_root / "data" / "database"

    # 🎯 Added CUDA support for Nvidia GPUs
    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")

    # load model
    base_model = "facebook/bart-base"
    print(f"Loading Base: {base_model}")
    print(f"Loading Adapter: {args.adapter}")

    # Tokenizer is loaded from the adapter directory (it may carry extra
    # tokens); the LoRA weights are merged for faster inference.
    tokenizer = AutoTokenizer.from_pretrained(args.adapter)
    base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
    model = PeftModel.from_pretrained(base, args.adapter).to(device)
    model = model.merge_and_unload()

    with open(dev_json) as f:
        dev = json.load(f)[: args.num_samples]

    correct = 0

    print(f"Evaluating {len(dev)} examples...\n")

    for i, ex in enumerate(dev, 1):
        question = ex["question"]
        db_id = ex["db_id"]
        gold_sql = ex["query"]

        db_path = db_root / db_id / f"{db_id}.sqlite"
        # Schema is re-read per example; prompt layout must match training.
        schema = load_schema(db_path)

        prompt = build_prompt(question, schema)

        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        # Deterministic beam-search decoding.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=80,
                do_sample=False,
                num_beams=4,
            )

        pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Drop any echoed prompt up to the last "SQL:" marker.
        if "SQL:" in pred_sql:
            pred_sql = pred_sql.split("SQL:")[-1].strip()

        match = execution_match(pred_sql, gold_sql, db_path)

        if match:
            correct += 1

        if i % 10 == 0:
            print(f"{i}/{len(dev)} | Acc: {correct/i:.3f}")

    print("\n=============================")
    print(f"FINAL EXECUTION ACCURACY: {correct/len(dev)*100:.2f}%")
    print("=============================")


if __name__ == "__main__":
    main()
src/evaluate_sft_bart.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import subprocess
5
+ import sys
6
+ import argparse
7
+ import re
8
+ import sqlite3
9
+ from pathlib import Path
10
+
11
+ import torch
12
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
13
+ from peft import PeftModel
14
+ from prompting import encode_prompt
15
+
16
+
17
+ # ---------------- SQL CLEAN ----------------
18
def extract_sql(text: str) -> str:
    """Pull a normalized SQL statement out of raw model output.

    Takes everything after the last "SQL:" marker (if present), keeps the
    text from the first SELECT onward, unifies quote characters, collapses
    whitespace, and guarantees a trailing semicolon.
    """
    candidate = text.strip()

    if "SQL:" in candidate:
        candidate = candidate.split("SQL:")[-1]

    select_match = re.search(r"(SELECT .*?)(?:$)", candidate, re.IGNORECASE | re.DOTALL)
    if select_match is not None:
        candidate = select_match.group(1)

    candidate = re.sub(r"\s+", " ", candidate.replace('"', "'")).strip()

    return candidate if candidate.endswith(";") else candidate + ";"
35
+
36
+
37
+ # ---------------- ROBUST ACC PARSER ----------------
38
def parse_exec_accuracy(stdout: str):
    """Scan Spider-eval output for the execution-accuracy value.

    Returns the last decimal number found on the first line mentioning
    "execution" (case-insensitive) that contains one, or None otherwise.
    """
    for raw_line in stdout.splitlines():
        if "execution" not in raw_line.lower():
            continue
        decimals = re.findall(r"\d+\.\d+", raw_line)
        if decimals:
            return float(decimals[-1])
    return None
45
+
46
+
47
def main() -> None:
    """Evaluate a BART SFT LoRA adapter on Spider.

    Generates SQL with beam search while tracking a live execution-accuracy
    counter, writes ``pred.sql`` in "query<TAB>db_id" format, then runs the
    official Spider execution evaluation and prints the parsed accuracy.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, default="checkpoints/sft_best_bart_2")
    parser.add_argument("--num_samples", type=int, default=1000)
    args = parser.parse_args()

    # Repository root (parent of src/) anchors all data/checkpoint paths.
    project_root = Path(__file__).resolve().parents[1]
    adapter_dir = project_root / args.adapter

    if not adapter_dir.exists():
        raise FileNotFoundError(f"Adapter not found: {adapter_dir}")

    db_root = project_root / "data/database"
    table_json = project_root / "data/tables.json"
    dev_json = project_root / "data/dev.json"
    gold_sql_file = project_root / "data/dev_gold.sql"
    pred_sql_file = project_root / "pred.sql"

    # Prefer Apple MPS, then CUDA, then CPU.
    device = "mps" if torch.backends.mps.is_available() else (
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    print("Using device:", device)

    # -------- LOAD MODEL --------
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(adapter_dir)

    BASE_MODEL = "facebook/bart-base"
    print(f"Loading base model {BASE_MODEL}...")
    base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)

    # Apply the LoRA adapter and merge its weights for faster inference.
    print("Loading LoRA adapter...")
    model = PeftModel.from_pretrained(base_model, adapter_dir).to(device)
    model = model.merge_and_unload()
    model.eval()

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    # -------- LOAD DATA --------
    with open(dev_json) as f:
        dev = json.load(f)[: args.num_samples]

    print("Generating SQL predictions...\n")

    correct = 0
    total = len(dev)

    with open(pred_sql_file, "w") as f, torch.no_grad():

        for i, ex in enumerate(dev, 1):

            question = ex["question"]
            db_id = ex["db_id"]
            gold_query = ex["query"]

            prompt_ids = encode_prompt(
                tokenizer,
                question,
                db_id,
                device=device,
                max_input_tokens=512,
            )

            input_ids = prompt_ids.unsqueeze(0).to(device)
            attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)

            # Deterministic beam-search decoding.
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=160,
                num_beams=4,
                do_sample=False,
            )

            pred = tokenizer.decode(outputs[0], skip_special_tokens=True)
            pred_sql = extract_sql(pred)

            f.write(f"{pred_sql}\t{db_id}\n")

            # -------- LIVE EXECUTION CHECK --------
            # Approximate metric only; the official number comes from
            # spider_eval below.
            try:
                db_path = db_root / db_id / f"{db_id}.sqlite"

                conn = sqlite3.connect(db_path)
                cursor = conn.cursor()

                cursor.execute(pred_sql)
                pred_rows = cursor.fetchall()

                cursor.execute(gold_query)
                gold_rows = cursor.fetchall()

                conn.close()

                # order insensitive comparison
                if sorted(pred_rows) == sorted(gold_rows):
                    correct += 1

            except Exception:
                pass  # execution failed

            if i % 10 == 0 or i == total:
                current_acc = correct / i
                print(f"{i}/{total} | Acc: {current_acc:.3f}")

    print("\nGeneration finished.\n")

    # -------- RUN OFFICIAL SPIDER EVAL --------
    # Prefer the BART-specific evaluator when present.
    eval_script = project_root / "spider_eval/evaluation.py"
    if (project_root / "spider_eval/evaluation_bart.py").exists():
        eval_script = project_root / "spider_eval/evaluation_bart.py"

    cmd = [
        sys.executable,
        str(eval_script),
        "--gold", str(gold_sql_file),
        "--pred", str(pred_sql_file),
        "--etype", "exec",
        "--db", str(db_root),
        "--table", str(table_json),
    ]

    print(f"\nRunning Spider evaluation using {eval_script.name}...")
    proc = subprocess.run(cmd, capture_output=True, text=True, errors="ignore")

    if proc.returncode != 0:
        print("\nSpider evaluation crashed.")
        print(proc.stderr)
        return

    # Show only the tail of the evaluator's (very verbose) output.
    print("\n--- Spider Eval Output ---")
    print("\n".join(proc.stdout.splitlines()[-20:]))

    acc = parse_exec_accuracy(proc.stdout)
    if acc is not None:
        print(f"\n🎯 Official Execution Accuracy: {acc*100:.2f}%")
    else:
        print("\nCould not parse official accuracy.")


if __name__ == "__main__":
    main()
src/evaluate_without_constraied.py ADDED
@@ -0,0 +1,503 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # *********** code till task 3 ************
3
+
4
+ # import json
5
+ # import subprocess
6
+ # import sys
7
+ # import argparse
8
+ # import random
9
+ # import sqlite3
10
+ # import time
11
+ # import re
12
+ # import os
13
+ # from pathlib import Path
14
+
15
+ # import torch
16
+ # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
17
+ # from peft import PeftModel
18
+
19
+ # from prompting import encode_prompt
20
+
21
+ # # -------------------------------
22
+ # # NORMALIZATION
23
+ # # -------------------------------
24
+ # def normalize_sql(sql):
25
+ # sql = sql.replace('"', "'")
26
+ # sql = re.sub(r"\s+", " ", sql)
27
+ # return sql.strip().lower().rstrip(";")
28
+
29
+
30
+ # # -------------------------------
31
+ # # 🔥 SAFE RESULT NORMALIZATION (FIX)
32
+ # # -------------------------------
33
+ # def normalize_result(res):
34
+ # try:
35
+ # return sorted([str(r) for r in res])
36
+ # except:
37
+ # return []
38
+
39
+
40
+ # # -------------------------------
41
+ # # EXECUTION CHECK (FIXED)
42
+ # # -------------------------------
43
+ # def check_execution(pred_sql, gold_sql, db_path):
44
+ # try:
45
+ # conn = sqlite3.connect(db_path)
46
+ # conn.text_factory = lambda b: b.decode(errors='ignore')
47
+
48
+ # start_time = time.monotonic()
49
+
50
+ # def timeout_handler():
51
+ # return 1 if (time.monotonic() - start_time) > 2.0 else 0
52
+
53
+ # conn.set_progress_handler(timeout_handler, 10000)
54
+
55
+ # cursor = conn.cursor()
56
+
57
+ # cursor.execute(pred_sql)
58
+ # pred_res = cursor.fetchall()
59
+
60
+ # cursor.execute(gold_sql)
61
+ # gold_res = cursor.fetchall()
62
+
63
+ # conn.close()
64
+
65
+ # # 🔥 FIXED COMPARISON
66
+ # return normalize_result(pred_res) == normalize_result(gold_res)
67
+
68
+ # except Exception:
69
+ # return False
70
+
71
+
72
+ # # -------------------------------
73
+ # # SPIDER PARSER
74
+ # # -------------------------------
75
+ # def _parse_spider_accuracy(stdout: str, metric_type: str):
76
+ # for line in stdout.splitlines():
77
+ # if metric_type == "exec" and line.strip().startswith("execution"):
78
+ # try:
79
+ # return float(line.split()[-1])
80
+ # except:
81
+ # pass
82
+ # elif metric_type == "match" and line.strip().startswith("exact"):
83
+ # try:
84
+ # return float(line.split()[-1])
85
+ # except:
86
+ # pass
87
+ # return None
88
+
89
+
90
+ # # -------------------------------
91
+ # # MAIN
92
+ # # -------------------------------
93
+ # def main():
94
+ # parser = argparse.ArgumentParser()
95
+ # parser.add_argument("--adapter", type=str, required=True)
96
+ # parser.add_argument("--num_samples", type=int, default= 500)
97
+ # parser.add_argument("--shuffle_dev", action="store_true")
98
+ # parser.add_argument("--shuffle_seed", type=int, default=42)
99
+ # args = parser.parse_args()
100
+
101
+ # project_root = Path(__file__).resolve().parents[1]
102
+ # adapter_dir = project_root / args.adapter
103
+
104
+ # db_root = project_root / "data" / "database"
105
+ # table_json = project_root / "data" / "tables.json"
106
+ # dev_json = project_root / "data" / "dev.json"
107
+
108
+ # pred_path = project_root / "temp_predictions.txt"
109
+ # temp_gold_path = project_root / "temp_gold.sql"
110
+
111
+ # if not adapter_dir.exists():
112
+ # raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")
113
+
114
+ # device = "mps" if torch.backends.mps.is_available() else (
115
+ # "cuda" if torch.cuda.is_available() else "cpu"
116
+ # )
117
+ # print(f"Using device: {device}")
118
+
119
+ # BASE_MODEL = "Salesforce/codet5-base"
120
+ # tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
121
+
122
+ # if tokenizer.pad_token is None:
123
+ # tokenizer.pad_token = tokenizer.eos_token
124
+
125
+ # print(f"\n📦 Loading Model: {args.adapter}")
126
+
127
+ # base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
128
+
129
+ # adapter_for_peft = os.path.relpath(adapter_dir, project_root)
130
+
131
+ # model = PeftModel.from_pretrained(
132
+ # base,
133
+ # adapter_for_peft,
134
+ # local_files_only=True
135
+ # ).to(device)
136
+
137
+ # model = model.merge_and_unload()
138
+ # model.eval()
139
+
140
+ # # -------------------------------
141
+ # # LOAD DATA
142
+ # # -------------------------------
143
+ # with dev_json.open() as f:
144
+ # dev = json.load(f)
145
+
146
+ # if args.shuffle_dev:
147
+ # rng = random.Random(args.shuffle_seed)
148
+ # rng.shuffle(dev)
149
+
150
+ # dev = dev[: args.num_samples]
151
+ # total = len(dev)
152
+
153
+ # gen_kwargs = dict(
154
+ # max_new_tokens=160,
155
+ # num_beams=8,
156
+ # length_penalty=0.8,
157
+ # do_sample=False,
158
+ # early_stopping=True,
159
+ # pad_token_id=tokenizer.pad_token_id,
160
+ # eos_token_id=tokenizer.eos_token_id,
161
+ # )
162
+
163
+ # print(f"\n🚀 Evaluating {total} samples...\n")
164
+
165
+ # em_correct = 0
166
+ # ex_correct = 0
167
+
168
+ # with pred_path.open("w") as out_pred, temp_gold_path.open("w") as out_gold, torch.no_grad():
169
+ # for i, ex in enumerate(dev, start=1):
170
+
171
+ # db_id = ex["db_id"]
172
+ # question = ex["question"]
173
+ # gold_query = ex["query"]
174
+ # db_path = db_root / db_id / f"{db_id}.sqlite"
175
+
176
+ # # -------------------------------
177
+ # # GENERATE SQL
178
+ # # -------------------------------
179
+ # input_ids = encode_prompt(
180
+ # tokenizer,
181
+ # question,
182
+ # db_id,
183
+ # device=device,
184
+ # max_input_tokens=512
185
+ # )
186
+
187
+ # input_ids = input_ids.unsqueeze(0).to(device)
188
+ # attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)
189
+
190
+ # outputs = model.generate(
191
+ # input_ids=input_ids,
192
+ # attention_mask=attention_mask,
193
+ # **gen_kwargs
194
+ # )
195
+
196
+ # pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
197
+
198
+ # # -------------------------------
199
+ # # SAVE FOR SPIDER EVAL
200
+ # # -------------------------------
201
+ # out_pred.write(f"{pred_sql}\n")
202
+ # out_gold.write(f"{gold_query}\t{db_id}\n")
203
+
204
+ # # -------------------------------
205
+ # # LIVE METRICS
206
+ # # -------------------------------
207
+ # if normalize_sql(pred_sql) == normalize_sql(gold_query):
208
+ # em_correct += 1
209
+
210
+ # if check_execution(pred_sql, gold_query, db_path):
211
+ # ex_correct += 1
212
+
213
+ # if i % 20 == 0 or i == total:
214
+ # print(
215
+ # f"Progress: {i}/{total} | "
216
+ # f"EM: {(em_correct/i)*100:.2f}% | "
217
+ # f"EX: {(ex_correct/i)*100:.2f}%"
218
+ # )
219
+
220
+ # print("\n🚀 Running Official Spider Evaluation...\n")
221
+
222
+ # eval_script = project_root / "spider_eval" / "evaluation.py"
223
+
224
+ # # EXACT MATCH
225
+ # cmd_match = [
226
+ # sys.executable, str(eval_script),
227
+ # "--gold", str(temp_gold_path),
228
+ # "--pred", str(pred_path),
229
+ # "--etype", "match",
230
+ # "--db", str(db_root),
231
+ # "--table", str(table_json),
232
+ # ]
233
+
234
+ # proc_match = subprocess.run(cmd_match, capture_output=True, text=True)
235
+ # exact_acc = _parse_spider_accuracy(proc_match.stdout, "match")
236
+
237
+ # # EXECUTION
238
+ # cmd_exec = [
239
+ # sys.executable, str(eval_script),
240
+ # "--gold", str(temp_gold_path),
241
+ # "--pred", str(pred_path),
242
+ # "--etype", "exec",
243
+ # "--db", str(db_root),
244
+ # "--table", str(table_json),
245
+ # ]
246
+
247
+ # proc_exec = subprocess.run(cmd_exec, capture_output=True, text=True)
248
+ # exec_acc = _parse_spider_accuracy(proc_exec.stdout, "exec")
249
+
250
+ # print("==========================================")
251
+ # print(f"🎯 OFFICIAL SPIDER RESULTS FOR: {args.adapter}")
252
+ # print("==========================================")
253
+
254
+ # print(f"Exact Match Accuracy : {exact_acc*100:.2f}%" if exact_acc else "EM parsing failed")
255
+ # print(f"Execution Accuracy : {exec_acc*100:.2f}%" if exec_acc else "EX parsing failed")
256
+
257
+ # print("==========================================\n")
258
+
259
+
260
+ # if __name__ == "__main__":
261
+ # main()
262
+
263
+
264
+
265
+
266
+ # *********** for task 2 ****************************************
267
+ import json
268
+ import argparse
269
+ import random
270
+ import sqlite3
271
+ import re
272
+ import os
273
+ from pathlib import Path
274
+ from collections import defaultdict
275
+
276
+ import torch
277
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
278
+ from peft import PeftModel
279
+
280
+ from prompting import encode_prompt
281
+
282
+ # -------------------------------
283
+ # NORMALIZATION
284
+ # -------------------------------
285
def normalize_sql(sql):
    """Canonicalise a SQL string for exact-match (EM) comparison.

    Double quotes become single quotes, whitespace runs collapse to a single
    space, and the result is lower-cased with any trailing semicolon removed.
    """
    canonical = re.sub(r"\s+", " ", sql.replace('"', "'"))
    return canonical.strip().lower().rstrip(";")
289
+
290
def normalize_result(res):
    """Return sorted row strings so result sets compare order-insensitively.

    Falls back to an empty list when *res* is not iterable (e.g. ``None``
    from a failed execution), so callers can compare unconditionally.
    """
    try:
        return sorted(str(row) for row in res)
    except Exception:
        # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt
        # are no longer swallowed.
        return []
295
+
296
+ # -------------------------------
297
+ # STEP 1: EXECUTION
298
+ # -------------------------------
299
def execute_with_error(sql, db_path):
    """Run *sql* against the SQLite database at *db_path*.

    Returns ``(rows, None)`` on success or ``(None, error_message)`` on
    failure. The connection is always closed, even when execution raises
    (the original leaked it on the error path).
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cur = conn.cursor()
        cur.execute(sql)
        return cur.fetchall(), None
    except Exception as e:
        return None, str(e)
    finally:
        if conn is not None:
            conn.close()
309
+
310
+ # -------------------------------
311
+ # STEP 2: ERROR CLASSIFICATION
312
+ # -------------------------------
313
def classify_error(sql, error_msg):
    """Map an execution error message (or its absence) to a coarse category.

    ``None`` means the query ran; otherwise the first matching substring
    rule on the error text wins, with a final lexical heuristic on the SQL.
    """
    if error_msg is None:
        return "correct"

    message = error_msg.lower()
    query = sql.lower()

    substring_rules = (
        ("syntax", "syntax_error"),
        ("no such table", "wrong_table"),
        ("no such column", "wrong_column"),
        ("ambiguous", "missing_join"),
        ("datatype mismatch", "type_error"),
    )
    for needle, label in substring_rules:
        if needle in message:
            return label

    # Heuristic: comparison operators without a WHERE clause suggest a
    # dropped filter condition.
    if "where" not in query and any(op in query for op in ("=", ">", "<")):
        return "missing_where"

    return "other"
334
+
335
+ # -------------------------------
336
+ # STEP 4: HINTS
337
+ # -------------------------------
338
def generate_hint(error_type):
    """Return a one-line repair hint for *error_type* ("" when unknown)."""
    if error_type == "missing_join":
        return "Try using JOIN between related tables."
    if error_type == "wrong_column":
        return "Check column names in schema."
    if error_type == "missing_where":
        return "Add WHERE condition."
    if error_type == "syntax_error":
        return "Fix SQL syntax."
    if error_type == "wrong_table":
        return "Verify table names."
    if error_type == "type_error":
        return "Check data types."
    if error_type == "other":
        return "Review SQL logic."
    return ""
349
+
350
+ # -------------------------------
351
+ # STEP 2 EXTRA: LIGHT ATTRIBUTION
352
+ # -------------------------------
353
def extract_keywords(question):
    """Extract lower-cased word tokens longer than three characters."""
    tokens = re.findall(r"\w+", question.lower())
    return [token for token in tokens if len(token) > 3]
355
+
356
+ # -------------------------------
357
+ # MAIN
358
+ # -------------------------------
359
def main():
    """Run Task-2 error analysis on the Spider dev set.

    Loads a LoRA-adapted CodeT5 model, generates SQL for the first
    ``--num_samples`` dev questions, executes predictions against each
    question's SQLite database, and prints live/final EM-EX metrics plus
    error, hint, attribution and SQL-operation summaries, ending with a few
    fixed adversarial probes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, required=True)
    parser.add_argument("--num_samples", type=int, default=200)
    args = parser.parse_args()

    # Paths resolve relative to the project root (parent of this script's dir).
    project_root = Path(__file__).resolve().parents[1]
    db_root = project_root / "data" / "database"
    dev_json = project_root / "data" / "dev.json"

    # Prefer Apple-Silicon GPU when available, else CPU (no CUDA branch here).
    device = "mps" if torch.backends.mps.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-base")
    base = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/codet5-base").to(device)

    # Attach the LoRA adapter from a project-root-relative path, offline only.
    model = PeftModel.from_pretrained(
        base,
        os.path.relpath(project_root / args.adapter, project_root),
        local_files_only=True
    ).to(device)

    # Fold adapter weights into the base model for plain inference.
    model = model.merge_and_unload()
    model.eval()

    with open(dev_json) as f:
        dev = json.load(f)

    dev = dev[:args.num_samples]

    # STORAGE: per-category counters and a few capped example lists.
    error_counter = defaultdict(int)
    error_examples = defaultdict(list)
    success_examples = []
    hint_examples = defaultdict(list)
    operation_counter = defaultdict(int)
    attribution_map = defaultdict(list)

    em, ex = 0, 0  # exact-match / execution-match counts

    print(f"\n🚀 Evaluating {len(dev)} samples...\n")

    # NOTE(review): generation runs without torch.no_grad(); consider
    # wrapping the loop to avoid building autograd state during eval.
    for i, sample in enumerate(dev, 1):

        db_id = sample["db_id"]
        q = sample["question"]
        gold = sample["query"]
        db_path = db_root / db_id / f"{db_id}.sqlite"

        input_ids = encode_prompt(tokenizer, q, db_id, device=device).unsqueeze(0)

        out = model.generate(input_ids=input_ids, max_new_tokens=120, num_beams=8)
        pred = tokenizer.decode(out[0], skip_special_tokens=True).strip()

        # operation analysis: count which SQL constructs the prediction uses
        s = pred.lower()
        if "select" in s: operation_counter["SELECT"] += 1
        if "where" in s: operation_counter["WHERE"] += 1
        if "join" in s: operation_counter["JOIN"] += 1
        if "group by" in s: operation_counter["GROUP_BY"] += 1
        if "order by" in s: operation_counter["ORDER_BY"] += 1

        pred_res, err = execute_with_error(pred, db_path)
        gold_res, _ = execute_with_error(gold, db_path)

        error_type = classify_error(pred, err)
        error_counter[error_type] += 1

        # attribution: keep question keywords that co-occur with each error
        if err:
            attribution_map[error_type].append(extract_keywords(q))

        # examples: keep at most 3 predictions per error category
        if len(error_examples[error_type]) < 3:
            error_examples[error_type].append(pred)

        # hints: pair failing predictions with a canned repair suggestion
        if error_type != "correct":
            hint = generate_hint(error_type)
            if len(hint_examples[error_type]) < 3:
                hint_examples[error_type].append((pred, hint))

        # metrics: EM on normalized text, EX on normalized result sets
        if normalize_sql(pred) == normalize_sql(gold):
            em += 1

        # NOTE(review): truthiness check means empty-but-matching result
        # sets are not counted as execution matches — confirm intended.
        if pred_res and gold_res and normalize_result(pred_res) == normalize_result(gold_res):
            ex += 1
            if len(success_examples) < 5:
                success_examples.append(pred)

        if i % 20 == 0:
            print(f"[{i}] EM: {em/i:.2f} | EX: {ex/i:.2f}")

    # -------------------------------
    # OUTPUT: final metrics and summaries
    # -------------------------------
    print("\n🎯 FINAL RESULTS")
    print(f"EM: {em/len(dev)*100:.2f}%")
    print(f"EX: {ex/len(dev)*100:.2f}%")

    print("\n🔥 ERROR SUMMARY")
    for k, v in error_counter.items():
        print(k, ":", v)

    print("\n🔥 ERROR EXAMPLES")
    for k in error_examples:
        print("\n", k)
        for e in error_examples[k]:
            print(" ", e)

    print("\n🔥 HINTS")
    for k in hint_examples:
        print("\n", k)
        for sql, h in hint_examples[k]:
            print(" ", sql)
            print(" →", h)

    print("\n🔥 ATTRIBUTION (KEYWORDS)")
    for k in attribution_map:
        print(k, ":", attribution_map[k][:3])

    print("\n🔥 SQL OPERATIONS")
    for k, v in operation_counter.items():
        print(k, ":", v)

    # -------------------------------
    # ADVERSARIAL: fixed out-of-distribution prompts against the first dev DB
    # -------------------------------
    print("\n🔥 ADVERSARIAL TESTS")

    adv = [
        "Find most expensive product",
        "Top 3 students by marks",
        "Average salary per department"
    ]

    for q in adv:
        inp = encode_prompt(tokenizer, q, dev[0]["db_id"], device=device).unsqueeze(0)
        out = model.generate(input_ids=inp, max_new_tokens=120)
        print("\nQ:", q)
        print("SQL:", tokenizer.decode(out[0], skip_special_tokens=True))
500
+
501
+
502
# Script entry point: run the full evaluation only when executed directly.
if __name__ == "__main__":
    main()
src/execution_reward copy.py ADDED
@@ -0,0 +1,831 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ # from __future__ import annotations
4
+
5
+ # import hashlib
6
+ # import os
7
+ # import queue
8
+ # import re
9
+ # import sqlite3
10
+ # import threading
11
+ # import time
12
+ # from concurrent.futures import ThreadPoolExecutor, as_completed
13
+ # from dataclasses import dataclass
14
+ # from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
15
+
16
+ # # --- CACHE CONTROL ---
17
+ # USE_CACHE = True
18
+ # _REWARD_CACHE: Dict[str, float] = {}
19
+
20
+ # def set_use_cache(enabled: bool):
21
+ # """Dynamically toggle the reward cache for benchmarks."""
22
+ # global USE_CACHE
23
+ # USE_CACHE = enabled
24
+
25
+ # def _normalize_sql(sql: str) -> str:
26
+ # if not isinstance(sql, str):
27
+ # return ""
28
+ # s = sql.strip()
29
+ # if s.startswith("```"):
30
+ # s = re.sub(r"^```[a-zA-Z0-9_+-]*\n?", "", s).strip()
31
+ # s = re.sub(r"\n?```$", "", s).strip()
32
+ # if s.lower().startswith("sql:"):
33
+ # s = s[4:].strip()
34
+ # if ";" in s:
35
+ # s = s.split(";", 1)[0].strip()
36
+ # return s
37
+
38
+ # def _connect_readonly(db_path: str) -> sqlite3.Connection:
39
+ # uri = f"file:{os.path.abspath(db_path)}?mode=ro"
40
+ # conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
41
+ # conn.execute("PRAGMA query_only = ON;")
42
+ # conn.execute("PRAGMA foreign_keys = ON;")
43
+ # return conn
44
+
45
+ # DEFAULT_QUERY_TIMEOUT_S = 2.0
46
+
47
+ # def _with_timeout(conn: sqlite3.Connection, timeout_s: float = DEFAULT_QUERY_TIMEOUT_S) -> None:
48
+ # start = time.monotonic()
49
+ # def _handler() -> int:
50
+ # return 1 if (time.monotonic() - start) > timeout_s else 0
51
+ # conn.set_progress_handler(_handler, 10_000)
52
+
53
+ # def _list_tables(conn: sqlite3.Connection) -> List[str]:
54
+ # try:
55
+ # cur = conn.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';")
56
+ # return [r[0] for r in cur.fetchall() if r and isinstance(r[0], str)]
57
+ # except sqlite3.Error:
58
+ # return []
59
+
60
+ # def _contains_table_name(sql: str, table_names: Sequence[str]) -> bool:
61
+ # s = sql.lower()
62
+ # for t in table_names:
63
+ # tl = t.lower()
64
+ # if not tl:
65
+ # continue
66
+ # if re.search(rf"\b{re.escape(tl)}\b", s):
67
+ # return True
68
+ # return False
69
+
70
+ # def _explain_query_plan(conn: sqlite3.Connection, sql: str) -> bool:
71
+ # try:
72
+ # _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
73
+ # conn.execute(f"EXPLAIN QUERY PLAN {sql}")
74
+ # return True
75
+ # except sqlite3.Error:
76
+ # return False
77
+
78
+ # def _execute(conn: sqlite3.Connection, sql: str, max_rows: int = 1000) -> Tuple[bool, List[Tuple], Optional[str]]:
79
+ # try:
80
+ # _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
81
+ # cur = conn.execute(sql)
82
+ # rows = cur.fetchmany(max_rows)
83
+ # norm_rows = [tuple(r) for r in rows]
84
+ # return True, norm_rows, None
85
+ # except sqlite3.Error as e:
86
+ # return False, [], str(e)
87
+
88
+ # _SQL_KEYWORDS_TO_IGNORE = {
89
+ # "select", "from", "where", "join", "inner", "left", "right", "full", "outer",
90
+ # "on", "group", "by", "order", "limit", "having", "distinct", "union", "intersect",
91
+ # "except", "as", "and", "or", "not", "in", "is", "null", "like", "between", "case",
92
+ # "when", "then", "else", "end", "asc", "desc"
93
+ # }
94
+
95
+ # _SQL_FUNCTIONS_TO_IGNORE = {
96
+ # "count", "avg", "min", "max", "sum", "lower", "upper", "substr", "coalesce",
97
+ # "round", "date", "datetime", "strftime"
98
+ # }
99
+
100
+ # # --- LIGHTWEIGHT PARSING ---
101
+ # def is_valid_select(sql: str):
102
+ # sql = sql.strip().lower()
103
+ # return sql.startswith("select") or sql.startswith("with")
104
+
105
+ # def extract_tables(sql: str) -> List[str]:
106
+ # sql = sql.lower()
107
+ # if "join" not in sql:
108
+ # tables = re.findall(r'from\s+(\w+)', sql)
109
+ # return list(set(tables))
110
+
111
+ # tables = re.findall(r'from\s+([a-zA-Z_][a-zA-Z0-9_]*)', sql)
112
+ # joins = re.findall(r'join\s+([a-zA-Z_][a-zA-Z0-9_]*)', sql)
113
+ # return list(set(tables + joins))
114
+
115
+ # def extract_columns(sql: str) -> List[str]:
116
+ # sql = sql.lower()
117
+ # match = re.search(r'select\s+(.*?)\s+from', sql)
118
+ # if not match:
119
+ # return []
120
+ # cols = match.group(1)
121
+ # if cols.strip() == "*":
122
+ # return ["*"]
123
+ # return [c.strip() for c in cols.split(",")]
124
+
125
+ # def _get_db_tables_and_columns(conn: sqlite3.Connection) -> Tuple[Set[str], Set[str]]:
126
+ # tables = set()
127
+ # columns = set()
128
+ # for t in _list_tables(conn):
129
+ # tl = t.lower()
130
+ # if not tl:
131
+ # continue
132
+ # tables.add(tl)
133
+ # try:
134
+ # cur = conn.execute(f'PRAGMA table_info("{t}")')
135
+ # for row in cur.fetchall():
136
+ # if row and isinstance(row[1], str):
137
+ # columns.add(row[1].lower())
138
+ # except sqlite3.Error:
139
+ # continue
140
+ # return tables, columns
141
+
142
+ # def _safe_results_equal(a: List[Tuple], b: List[Tuple]) -> bool:
143
+ # return a == b
144
+
145
+ # @dataclass
146
+ # class RewardDebugStats:
147
+ # total: int = 0
148
+ # parsed_ok: int = 0
149
+ # table_match: int = 0
150
+ # column_match: int = 0
151
+ # executed_ok: int = 0
152
+ # exact_match: int = 0
153
+
154
+ # _DEBUG = RewardDebugStats()
155
+
156
+ # def reset_debug_metrics() -> None:
157
+ # global _DEBUG
158
+ # _DEBUG = RewardDebugStats()
159
+
160
+ # def get_debug_metrics() -> dict:
161
+ # denom = max(_DEBUG.total, 1)
162
+ # return {
163
+ # "valid_sql_rate": _DEBUG.parsed_ok / denom,
164
+ # "table_match_rate": _DEBUG.table_match / denom,
165
+ # "column_match_rate": _DEBUG.column_match / denom,
166
+ # "execution_accuracy": _DEBUG.exact_match / denom,
167
+ # }
168
+
169
+ # EXECUTION_ERROR = "EXECUTION_ERROR"
170
+
171
+ # _RESULT_CACHE_LOCK = threading.Lock()
172
+ # _RESULT_CACHE: "Dict[str, Union[List[Tuple], str]]" = {}
173
+ # _RESULT_CACHE_MAX = 100_000
174
+
175
+ # def clear_result_cache() -> None:
176
+ # """Clear both DB query cache and reward cache."""
177
+ # with _RESULT_CACHE_LOCK:
178
+ # _RESULT_CACHE.clear()
179
+ # _REWARD_CACHE.clear()
180
+
181
+ # def _db_state_fingerprint(db_path: str) -> str:
182
+ # try:
183
+ # st = os.stat(db_path)
184
+ # return f"{st.st_mtime_ns}:{st.st_size}"
185
+ # except OSError:
186
+ # return "missing"
187
+
188
+ # def _result_cache_key(db_path: str, sql: str) -> str:
189
+ # fp = _db_state_fingerprint(db_path)
190
+ # payload = f"{fp}\0{sql}".encode("utf-8", errors="ignore")
191
+ # return hashlib.sha256(payload).hexdigest()
192
+
193
+ # class _ConnectionPool:
194
+ # def __init__(self, db_path: str, maxsize: int = 1) -> None:
195
+ # self.db_path = db_path
196
+ # self.pool = queue.LifoQueue(maxsize=maxsize)
197
+ # self.lock = threading.Lock()
198
+
199
+ # def acquire(self) -> sqlite3.Connection:
200
+ # try:
201
+ # return self.pool.get_nowait()
202
+ # except queue.Empty:
203
+ # with self.lock:
204
+ # try:
205
+ # return self.pool.get_nowait()
206
+ # except queue.Empty:
207
+ # return _connect_readonly(self.db_path)
208
+
209
+ # def release(self, conn: sqlite3.Connection) -> None:
210
+ # try:
211
+ # self.pool.put_nowait(conn)
212
+ # except queue.Full:
213
+ # try:
214
+ # conn.close()
215
+ # except Exception:
216
+ # pass
217
+
218
+ # _POOL_LOCK = threading.Lock()
219
+ # _POOLS: Dict[str, _ConnectionPool] = {}
220
+
221
+ # def _get_pool(db_path: str) -> _ConnectionPool:
222
+ # with _POOL_LOCK:
223
+ # pool = _POOLS.get(db_path)
224
+ # if pool is None:
225
+ # pool = _ConnectionPool(db_path=db_path, maxsize=1)
226
+ # _POOLS[db_path] = pool
227
+ # return pool
228
+
229
+ # class _PooledConnection:
230
+ # def __init__(self, db_path: str) -> None:
231
+ # self.db_path = db_path
232
+ # self.pool = _get_pool(db_path)
233
+ # self.conn: Optional[sqlite3.Connection] = None
234
+
235
+ # def __enter__(self) -> sqlite3.Connection:
236
+ # self.conn = self.pool.acquire()
237
+ # return self.conn
238
+
239
+ # def __exit__(self, exc_type, exc, tb) -> None:
240
+ # if self.conn is not None:
241
+ # self.pool.release(self.conn)
242
+ # self.conn = None
243
+
244
+ # def _cache_get(key: str) -> Optional[Union[List[Tuple], str]]:
245
+ # with _RESULT_CACHE_LOCK:
246
+ # return _RESULT_CACHE.get(key)
247
+
248
+ # def _cache_put(key: str, value: Union[List[Tuple], str]) -> None:
249
+ # with _RESULT_CACHE_LOCK:
250
+ # if len(_RESULT_CACHE) >= _RESULT_CACHE_MAX:
251
+ # _RESULT_CACHE.clear()
252
+ # _RESULT_CACHE[key] = value
253
+
254
+ # def execute_sql(conn: sqlite3.Connection, sql: str, *, max_rows: int = 1000) -> Union[List[Tuple], str]:
255
+ # try:
256
+ # _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
257
+ # cur = conn.execute(sql)
258
+ # rows = cur.fetchmany(max_rows)
259
+ # return [tuple(r) for r in rows]
260
+ # except Exception:
261
+ # return EXECUTION_ERROR
262
+
263
+ # def execute_sql_cached(db_path: str, sql: str, *, max_rows: int = 1000) -> Union[List[Tuple], str]:
264
+ # if not USE_CACHE:
265
+ # with _PooledConnection(db_path) as conn:
266
+ # return execute_sql(conn, sql, max_rows=max_rows)
267
+
268
+ # key = _result_cache_key(db_path, sql)
269
+ # cached = _cache_get(key)
270
+ # if cached is not None:
271
+ # return cached
272
+ # with _PooledConnection(db_path) as conn:
273
+ # res = execute_sql(conn, sql, max_rows=max_rows)
274
+ # _cache_put(key, res)
275
+ # return res
276
+
277
+ # def execution_reward_timed(
278
+ # pred_sql: str, db_path: str, gold_sql: str, *, measure_plan: bool = False,
279
+ # ) -> Tuple[float, Dict[str, float]]:
280
+ # timings = {"parse_s": 0.0, "plan_s": 0.0, "exec_s": 0.0}
281
+ # t0 = time.perf_counter()
282
+ # sql = _normalize_sql(pred_sql)
283
+ # gold = _normalize_sql(gold_sql)
284
+
285
+ # if not is_valid_select(sql):
286
+ # timings["parse_s"] = time.perf_counter() - t0
287
+ # return 0.0, timings
288
+
289
+ # t1 = time.perf_counter()
290
+ # timings["parse_s"] = t1 - t0
291
+
292
+ # if measure_plan:
293
+ # with _PooledConnection(db_path) as conn:
294
+ # p0 = time.perf_counter()
295
+ # _explain_query_plan(conn, sql)
296
+ # _explain_query_plan(conn, gold)
297
+ # timings["plan_s"] = time.perf_counter() - p0
298
+
299
+ # e0 = time.perf_counter()
300
+ # pred_res = execute_sql_cached(db_path, sql)
301
+ # if pred_res == EXECUTION_ERROR:
302
+ # timings["exec_s"] = time.perf_counter() - e0
303
+ # return 0.0, timings
304
+ # gold_res = execute_sql_cached(db_path, gold)
305
+ # timings["exec_s"] = time.perf_counter() - e0
306
+ # if gold_res == EXECUTION_ERROR:
307
+ # return 0.0, timings
308
+
309
+ # reward = -0.2
310
+ # reward += 0.2
311
+ # if _safe_results_equal(pred_res, gold_res):
312
+ # return 1.0, timings
313
+ # return max(-1.0, min(1.0, reward)), timings
314
+
315
+ # def execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
316
+ # try:
317
+ # sql = _normalize_sql(pred_sql)
318
+ # gold = _normalize_sql(gold_sql)
319
+
320
+ # if not is_valid_select(sql):
321
+ # return -1.0
322
+
323
+ # reward = -0.2
324
+
325
+ # pred_tables = set(extract_tables(sql))
326
+ # gold_tables = set(extract_tables(gold))
327
+
328
+ # if pred_tables == gold_tables and len(gold_tables) > 0:
329
+ # reward += 0.3
330
+
331
+ # pred_cols = set(extract_columns(sql))
332
+ # gold_cols = set(extract_columns(gold))
333
+
334
+ # if gold_cols:
335
+ # overlap = len(pred_cols & gold_cols) / len(gold_cols)
336
+ # reward += 0.3 * overlap
337
+
338
+ # pred_res = execute_sql_cached(db_path, sql)
339
+ # if pred_res == EXECUTION_ERROR:
340
+ # return 0.0
341
+ # reward += 0.2
342
+
343
+ # gold_res = execute_sql_cached(db_path, gold)
344
+ # if gold_res == EXECUTION_ERROR:
345
+ # return 0.0
346
+ # if _safe_results_equal(pred_res, gold_res):
347
+ # return 1.0
348
+
349
+ # return max(-1.0, min(1.0, reward))
350
+
351
+ # except Exception:
352
+ # return 0.0
353
+
354
+ # def cached_execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
355
+ # if not USE_CACHE:
356
+ # return execution_reward(pred_sql, db_path, gold_sql)
357
+
358
+ # key = f"{db_path}|{pred_sql}|{gold_sql}"
359
+ # if key not in _REWARD_CACHE:
360
+ # _REWARD_CACHE[key] = execution_reward(pred_sql, db_path, gold_sql)
361
+ # return _REWARD_CACHE[key]
362
+
363
+ # def execution_reward_batch_sequential(rollouts: Sequence[Tuple[str, str, str]]) -> List[float]:
364
+ # return [cached_execution_reward(pred_sql, db_path, gold_sql) for pred_sql, db_path, gold_sql in rollouts]
365
+
366
+ # def execution_reward_batch_parallel(rollouts: Sequence[Tuple[str, str, str]], *, max_workers: int = 20) -> List[float]:
367
+ # if not rollouts:
368
+ # return []
369
+
370
+ # unique_dbs = {db_path for _, db_path, _ in rollouts}
371
+ # worker_count = max(1, min(max_workers, len(unique_dbs)))
372
+ # results: List[Optional[float]] = [None] * len(rollouts)
373
+
374
+ # with ThreadPoolExecutor(max_workers=worker_count) as executor:
375
+ # futures = {
376
+ # executor.submit(cached_execution_reward, pred_sql, db_path, gold_sql): i
377
+ # for i, (pred_sql, db_path, gold_sql) in enumerate(rollouts)
378
+ # }
379
+ # for fut in as_completed(futures):
380
+ # idx = futures[fut]
381
+ # try:
382
+ # results[idx] = float(fut.result())
383
+ # except Exception:
384
+ # results[idx] = 0.0
385
+
386
+ # return [r if r is not None else 0.0 for r in results]
387
+
388
+ from __future__ import annotations
389
+
390
+ import os
391
+ import re
392
+ import sqlite3
393
+ import threading
394
+ import time
395
+ import json
396
+ from concurrent.futures import ThreadPoolExecutor, as_completed
397
+ from dataclasses import dataclass
398
+ from typing import Dict, List
399
+
400
+ from src.sql_validator import validate_sql_schema
401
+
402
# =========================================================
# 🔥 CONFIG FLAGS
# =========================================================
# Module-wide toggles; flipped at runtime via the set_use_* helpers below.
USE_SCHEMA_VALIDATION = True   # gate schema checks (validate_sql_schema)
USE_CACHE = True               # enable query/reward result memoisation
DEFAULT_QUERY_TIMEOUT_S = 2.0  # per-statement SQLite timeout, in seconds

# Sentinel string returned by execute_sql* when a query fails for any reason.
EXECUTION_ERROR = "EXECUTION_ERROR"

# Memoised rewards keyed by "db_path|pred_sql|gold_sql".
_REWARD_CACHE: Dict[str, float] = {}

# =========================================================
# 🔥 TASK 2: ERROR ANALYSIS + LOGGING
# =========================================================
# JSON file (relative to the working directory) that log_error() appends to.
ERROR_LOG_FILE = "results/error_logs.json"
417
+
418
+
419
def classify_error(sql: str) -> str:
    """Bucket a generated SQL string into a coarse failure category.

    Purely lexical heuristics over the lower-cased query text; the first
    matching rule wins, otherwise "other".
    """
    lowered = sql.lower()

    if "join" in lowered and " on " not in lowered:
        return "missing_join"
    if "where" in lowered and not any(op in lowered for op in ("=", ">", "<")):
        return "wrong_where"
    if "null" in lowered:
        return "null_handling"
    if "group by" in lowered and "count" not in lowered:
        return "wrong_groupby"
    return "other"
435
+
436
+
437
def get_hint(error_type: str) -> str:
    """Return a one-line repair suggestion for an error category."""
    if error_type == "missing_join":
        return "Add proper JOIN condition using ON."
    if error_type == "wrong_where":
        return "Check WHERE clause conditions."
    if error_type == "null_handling":
        return "Handle NULL values using IS NULL."
    if error_type == "wrong_groupby":
        return "Use aggregation functions with GROUP BY."
    if error_type == "other":
        return "Check SQL syntax and logic."
    # Unknown category: generic fallback.
    return "Check query."
446
+
447
+
448
def log_error(question: str, sql: str, error: str, error_type: str):
    """Append one structured error record to ERROR_LOG_FILE.

    The record carries the question, generated SQL, raw error text, the
    classified error type and a wall-clock timestamp. A missing, corrupt or
    non-list log file now starts a fresh list instead of crashing the run
    (the original raised on invalid JSON).
    """
    os.makedirs("results", exist_ok=True)

    entry = {
        "question": question,
        "sql": sql,
        "error": error,
        "error_type": error_type,
        "timestamp": time.time()
    }

    logs = []
    if os.path.exists(ERROR_LOG_FILE):
        try:
            with open(ERROR_LOG_FILE, "r") as f:
                logs = json.load(f)
        except (json.JSONDecodeError, OSError):
            # Corrupt or unreadable log file: restart rather than abort.
            logs = []
        if not isinstance(logs, list):
            logs = []

    logs.append(entry)

    with open(ERROR_LOG_FILE, "w") as f:
        json.dump(logs, f, indent=2)
469
+
470
+ # =========================================================
471
+ # CACHE/VALIDATION TOGGLES (Task 1)
472
+ # =========================================================
473
def set_use_cache(enabled: bool) -> None:
    """Toggle the module-wide result/reward caching flag."""
    global USE_CACHE
    USE_CACHE = True if enabled else False
476
+
477
+
478
def set_use_schema_validation(enabled: bool) -> None:
    """Toggle the module-wide schema-validation flag."""
    global USE_SCHEMA_VALIDATION
    USE_SCHEMA_VALIDATION = True if enabled else False
481
+
482
+
483
+ # =========================================================
484
+ # SQL CLEANING
485
+ # =========================================================
486
def _normalize_sql(sql: str) -> str:
    """Strip markdown fences, a leading "sql:" tag, and extra statements.

    Non-string input yields "". Only the text before the first semicolon is
    kept, so at most one statement survives normalisation.
    """
    if not isinstance(sql, str):
        return ""

    cleaned = sql.strip()

    # Drop a surrounding markdown code fence, e.g. ```sql ... ```
    if cleaned.startswith("```"):
        cleaned = re.sub(r"^```[a-zA-Z0-9_+-]*\n?", "", cleaned).strip()
        cleaned = re.sub(r"\n?```$", "", cleaned).strip()

    # Remove a leading "sql:" prefix (case-insensitive).
    if cleaned.lower().startswith("sql:"):
        cleaned = cleaned[4:].strip()

    # Keep only the first statement.
    head, sep, _ = cleaned.partition(";")
    return head.strip() if sep else cleaned
502
+
503
+
504
+ # =========================================================
505
+ # DB EXECUTION
506
+ # =========================================================
507
+ def _connect_readonly(db_path: str):
508
+ uri = f"file:{os.path.abspath(db_path)}?mode=ro"
509
+ conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
510
+ conn.execute("PRAGMA query_only = ON;")
511
+ conn.execute("PRAGMA foreign_keys = ON;")
512
+ return conn
513
+
514
+
515
def _with_timeout(conn: sqlite3.Connection, timeout_s: float = DEFAULT_QUERY_TIMEOUT_S):
    """Install a progress handler that aborts statements running past *timeout_s*.

    SQLite invokes the handler every 10k VM ops; a non-zero return value
    interrupts the currently running statement.
    """
    deadline = time.monotonic() + timeout_s

    def _abort_if_late():
        return 1 if time.monotonic() > deadline else 0

    conn.set_progress_handler(_abort_if_late, 10_000)
522
+
523
+
524
def execute_sql(conn, sql):
    """Run *sql* on *conn* under the default timeout.

    Returns the fetched rows, or the EXECUTION_ERROR sentinel on any failure
    (syntax error, missing table, timeout interrupt, ...).
    """
    try:
        _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
        return conn.execute(sql).fetchall()
    except Exception:
        # Deliberate catch-all: callers only care about success vs failure.
        return EXECUTION_ERROR
531
+
532
+
533
# Process-wide memo of executed query results, keyed by "db_path|sql".
# Guarded by _RESULT_LOCK because rewards are computed from thread pools.
_RESULT_CACHE = {}
_RESULT_LOCK = threading.Lock()
535
+
536
+
537
def execute_sql_cached(db_path, sql):
    """Execute *sql* against *db_path* on a fresh read-only connection.

    When USE_CACHE is on, results are memoised in the shared _RESULT_CACHE
    under the key "db_path|sql".
    """
    cache_key = f"{db_path}|{sql}"

    if USE_CACHE:
        with _RESULT_LOCK:
            try:
                return _RESULT_CACHE[cache_key]
            except KeyError:
                pass

    connection = _connect_readonly(db_path)
    try:
        rows = execute_sql(connection, sql)
    finally:
        connection.close()

    if USE_CACHE:
        with _RESULT_LOCK:
            _RESULT_CACHE[cache_key] = rows

    return rows
554
+
555
+
556
def execute_sql_cached_conn(conn: sqlite3.Connection, db_path: str, sql: str):
    """Cached execution that reuses an already-open connection.

    Same caching contract as execute_sql_cached(); intended for the
    one-thread-per-DB workloads of Task 1 where each worker owns *conn*.
    """
    cache_key = f"{db_path}|{sql}"

    if USE_CACHE:
        with _RESULT_LOCK:
            try:
                return _RESULT_CACHE[cache_key]
            except KeyError:
                pass

    rows = execute_sql(conn, sql)

    if USE_CACHE:
        with _RESULT_LOCK:
            _RESULT_CACHE[cache_key] = rows

    return rows
574
+
575
+
576
def clear_result_cache() -> None:
    """Drop every memoised execution result and cached reward.

    Note: the original `global` declaration was unnecessary (the dicts are
    mutated in place, never rebound) and has been removed.
    """
    with _RESULT_LOCK:
        for cache in (_RESULT_CACHE, _REWARD_CACHE):
            cache.clear()
581
+
582
+
583
+ # =========================================================
584
+ # SQL PARSING
585
+ # =========================================================
586
def is_valid_select(sql):
    """True when the statement starts with SELECT or WITH (a CTE)."""
    return sql.lower().startswith(("select", "with"))
588
+
589
+
590
def extract_tables(sql):
    """Return every identifier that directly follows a FROM keyword.

    Note: this is a heuristic — JOINed tables are intentionally not matched.
    """
    from_pattern = re.compile(r'from\s+(\w+)')
    return from_pattern.findall(sql.lower())
592
+
593
+
594
def extract_columns(sql):
    """Return the projection list of the first SELECT ... FROM clause.

    Yields ["*"] for a bare star, [] when no SELECT/FROM pair is found.
    """
    projection = re.search(r'select\s+(.*?)\s+from', sql.lower())
    if projection is None:
        return []
    raw = projection.group(1)
    if raw.strip() == "*":
        return ["*"]
    return [column.strip() for column in raw.split(",")]
600
+
601
+
602
def get_sql_operations(sql: str):
    """List the coarse SQL operations present in *sql*, in a fixed order.

    Detection is by case-insensitive substring match, so e.g. "whereabouts"
    would register as WHERE — acceptable for component analysis.
    """
    lowered = sql.lower()
    keyword_to_op = [
        ("select", "SELECT"),
        ("where", "WHERE"),
        ("join", "JOIN"),
        ("group by", "GROUP_BY"),
        ("order by", "ORDER_BY"),
    ]
    return [op for keyword, op in keyword_to_op if keyword in lowered]
613
+
614
+
615
def _explain_query_plan(conn: sqlite3.Connection, sql: str) -> bool:
    """Check whether SQLite can plan *sql* without fully executing it."""
    try:
        _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
        conn.execute(f"EXPLAIN QUERY PLAN {sql}")
    except Exception:
        return False
    return True
622
+
623
+
624
def execution_reward_timed(pred_sql: str, db_path: str, gold_sql: str, measure_plan: bool = False):
    """Compute the execution reward while profiling its phases.

    Returns (reward, timings) where timings has keys parse_s, plan_s, exec_s.
    Used by the Task-1 benchmark to locate bottlenecks.
    """
    timings = {"parse_s": 0.0, "plan_s": 0.0, "exec_s": 0.0}
    parse_start = time.perf_counter()

    pred = _normalize_sql(pred_sql)
    gold = _normalize_sql(gold_sql)

    if not is_valid_select(pred):
        timings["parse_s"] = time.perf_counter() - parse_start
        return 0.0, timings

    timings["parse_s"] = time.perf_counter() - parse_start

    conn = _connect_readonly(db_path)
    try:
        if measure_plan:
            plan_start = time.perf_counter()
            _explain_query_plan(conn, pred)
            _explain_query_plan(conn, gold)
            timings["plan_s"] = time.perf_counter() - plan_start

        exec_start = time.perf_counter()
        pred_rows = execute_sql_cached_conn(conn, db_path, pred)
        if pred_rows == EXECUTION_ERROR:
            timings["exec_s"] = time.perf_counter() - exec_start
            return 0.0, timings

        gold_rows = execute_sql_cached_conn(conn, db_path, gold)
        timings["exec_s"] = time.perf_counter() - exec_start
        if gold_rows == EXECUTION_ERROR:
            return 0.0, timings

        if pred_rows == gold_rows:
            return 1.0, timings
        # Base penalty (-0.2) plus executable bonus (+0.2) nets out to 0.0,
        # and the clamp to [-1, 1] is a no-op — return the constant directly.
        return 0.0, timings
    finally:
        try:
            conn.close()
        except Exception:
            pass
669
+
670
+
671
+ # =========================================================
672
+ # 🔥 FINAL REWARD FUNCTION (TASK 2 INTEGRATED)
673
+ # =========================================================
674
def execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
    """Score a predicted SQL query against the gold query by execution.

    Reward scale:
      -1.0  prediction is not a SELECT/WITH statement
       0.1  prediction fails schema validation or execution (a failing gold
            query yields the same neutral score)
       1.0  predicted and gold result sets match exactly
       0.0  both executed but results differ (base -0.2 + executable +0.2)

    Any unexpected exception is logged via log_error() and scored 0.0.
    """
    try:
        sql = _normalize_sql(pred_sql)
        gold = _normalize_sql(gold_sql)

        if not is_valid_select(sql):
            return -1.0

        reward = -0.2

        # =========================
        # SCHEMA VALIDATION (Task 3)
        # =========================
        if USE_SCHEMA_VALIDATION:
            valid, _ = validate_sql_schema(sql, db_path)
            if not valid:
                error_type = classify_error(sql)
                log_error("UNKNOWN", sql, "schema_invalid", error_type)
                return 0.1

        # =========================
        # EXECUTION
        # =========================
        pred_res = execute_sql_cached(db_path, sql)

        # FIX: compare against the EXECUTION_ERROR sentinel, as every other
        # function in this module does — the original compared against the
        # string literal "EXECUTION_ERROR", which silently never matches if
        # the sentinel is not that exact string.
        if pred_res == EXECUTION_ERROR:
            error_type = classify_error(sql)

            log_error(
                question="UNKNOWN",
                sql=sql,
                error="execution_error",
                error_type=error_type
            )

            print(f"[ERROR] {error_type}")
            print(f"[HINT] {get_hint(error_type)}")

            return 0.1

        reward += 0.2

        gold_res = execute_sql_cached(db_path, gold)

        if gold_res == EXECUTION_ERROR:
            return 0.1

        if pred_res == gold_res:
            return 1.0

        return max(-1.0, min(1.0, reward))

    except Exception as e:
        log_error("UNKNOWN", pred_sql, str(e), "runtime_error")
        return 0.0
729
+
730
+
731
+ # =========================================================
732
+ # BATCH EXECUTION (Task 1)
733
+ # =========================================================
734
def cached_execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
    """Reward computation memoised per (db_path, pred, gold) triple.

    Falls through to a plain execution_reward() call when USE_CACHE is off.
    """
    if not USE_CACHE:
        return float(execution_reward(pred_sql, db_path, gold_sql))

    memo_key = f"{db_path}|{pred_sql}|{gold_sql}"
    try:
        return float(_REWARD_CACHE[memo_key])
    except KeyError:
        value = float(execution_reward(pred_sql, db_path, gold_sql))
        _REWARD_CACHE[memo_key] = value
        return value
743
+
744
+
745
def execution_reward_batch_sequential(rollouts):
    """Score (pred_sql, db_path, gold_sql) rollouts one after another."""
    rewards = []
    for pred, db_path, gold in rollouts:
        rewards.append(cached_execution_reward(pred, db_path, gold))
    return rewards
747
+
748
+
749
def execution_reward_batch_parallel(rollouts, max_workers=10):
    """Score rollouts concurrently; result order matches the input order.

    A worker that raises contributes a neutral 0.0 reward.
    """
    rewards = [0.0] * len(rollouts)

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = {}
        for position, (pred, db_path, gold) in enumerate(rollouts):
            future = pool.submit(cached_execution_reward, pred, db_path, gold)
            pending[future] = position

        for done in as_completed(pending):
            position = pending[done]
            try:
                rewards[position] = done.result()
            except Exception:
                rewards[position] = 0.0

    return rewards
766
+
767
+
768
def execution_reward_batch_parallel_by_db(rollouts, max_workers: int = 20):
    """
    1 thread per DB path. Reuses a single readonly connection per DB worker.
    Preserves input order.

    rollouts: iterable of (pred_sql, db_path, gold_sql) triples.
    Returns a list of float rewards aligned with the input order.

    NOTE(review): the reward logic in _reward_with_conn mirrors
    execution_reward() and must be kept in sync with it; the only intended
    difference is connection reuse (execute_sql_cached_conn) and the absence
    of the console ERROR/HINT prints.
    """
    if not rollouts:
        return []

    # Group rollouts by database so each worker owns exactly one connection.
    # Each entry keeps its original index so output order can be preserved.
    by_db = {}
    for idx, (pred_sql, db_path, gold_sql) in enumerate(rollouts):
        by_db.setdefault(db_path, []).append((idx, pred_sql, gold_sql))

    # Pre-sized result slots; workers write disjoint indices, so no lock needed.
    results = [0.0 for _ in range(len(rollouts))]

    def _reward_with_conn(conn: sqlite3.Connection, pred_sql: str, db_path: str, gold_sql: str) -> float:
        # Same reward scale as execution_reward(): -1.0 invalid statement,
        # 0.1 schema/execution failure, 1.0 exact result match, 0.0 mismatch.
        try:
            sql = _normalize_sql(pred_sql)
            gold = _normalize_sql(gold_sql)

            if not is_valid_select(sql):
                return -1.0

            reward = -0.2

            if USE_SCHEMA_VALIDATION:
                valid, _ = validate_sql_schema(sql, db_path)
                if not valid:
                    error_type = classify_error(sql)
                    log_error("UNKNOWN", sql, "schema_invalid", error_type)
                    return 0.1

            pred_res = execute_sql_cached_conn(conn, db_path, sql)
            if pred_res == EXECUTION_ERROR:
                error_type = classify_error(sql)
                log_error("UNKNOWN", sql, "execution_error", error_type)
                return 0.1

            reward += 0.2
            gold_res = execute_sql_cached_conn(conn, db_path, gold)
            if gold_res == EXECUTION_ERROR:
                return 0.1
            if pred_res == gold_res:
                return 1.0
            return max(-1.0, min(1.0, reward))
        except Exception:
            # Any unexpected failure scores a neutral 0.0 for this rollout.
            return 0.0

    def _worker(db_path: str, items):
        # One read-only connection per database, shared by all its rollouts.
        conn = _connect_readonly(db_path)
        try:
            for idx, pred, gold in items:
                results[idx] = _reward_with_conn(conn, pred, db_path, gold)
        finally:
            try:
                conn.close()
            except Exception:
                pass

    with ThreadPoolExecutor(max_workers=int(max_workers)) as ex:
        futures = [ex.submit(_worker, db_path, items) for db_path, items in by_db.items()]
        # .result() re-raises any worker exception instead of hiding it.
        for fut in as_completed(futures):
            fut.result()

    return results