Commit ·
cf17729
1
Parent(s): b70f6fd
Added full project
Browse files. This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full list.
- README.md +13 -3
- app.py +569 -0
- db.zip +3 -0
- int8_dynamic/meta.json +7 -0
- int8_dynamic/model.pt +3 -0
- int8_dynamic/tokenizer/merges.txt +0 -0
- int8_dynamic/tokenizer/special_tokens_map.json +753 -0
- int8_dynamic/tokenizer/tokenizer.json +0 -0
- int8_dynamic/tokenizer/tokenizer_config.json +959 -0
- int8_dynamic/tokenizer/vocab.json +0 -0
- requirements.txt +10 -0
- scripts/__pycache__/benchmark_parallel_reward.cpython-310.pyc +0 -0
- scripts/__pycache__/benchmark_parallel_reward.cpython-313.pyc +0 -0
- scripts/__pycache__/benchmark_quantization.cpython-310.pyc +0 -0
- scripts/__pycache__/benchmark_rollout_generation.cpython-310.pyc +0 -0
- scripts/__pycache__/quantize_export.cpython-310.pyc +0 -0
- scripts/__pycache__/quantized_infer_harness.cpython-310.pyc +0 -0
- scripts/benchmark_parallel_reward.py +202 -0
- scripts/benchmark_quantization.py +108 -0
- scripts/benchmark_rollout_generation.py +66 -0
- scripts/error_dashboard.py +99 -0
- scripts/evaluate.py +170 -0
- scripts/plot_task2.py +58 -0
- scripts/plot_task3.py +15 -0
- scripts/plot_task3_plotly.py +103 -0
- scripts/quantize_export.py +86 -0
- scripts/quantized_infer_harness.py +46 -0
- src/__pycache__/execution_reward.cpython-310.pyc +0 -0
- src/__pycache__/quantization_utils.cpython-310.pyc +0 -0
- src/__pycache__/quantized_text2sql_engine.cpython-310.pyc +0 -0
- src/__pycache__/schema_encoder.cpython-310.pyc +0 -0
- src/__pycache__/schema_utils.cpython-310.pyc +0 -0
- src/__pycache__/sql_validator.cpython-310.pyc +0 -0
- src/__pycache__/text2sql_engine.cpython-310.pyc +0 -0
- src/ask.py +93 -0
- src/component_analysis.py +229 -0
- src/constrained_decoding.py +1058 -0
- src/constrained_decoding_sample.py +516 -0
- src/convert_to_hf_dataset.py +8 -0
- src/eval_baseline_codet5.py +112 -0
- src/eval_both_metrics.py +144 -0
- src/eval_rl_fixed.py +756 -0
- src/eval_rl_t5.py +279 -0
- src/eval_single_model.py +218 -0
- src/evaluate_model_codet5.py +392 -0
- src/evaluate_model_t5_small_sft.py +179 -0
- src/evaluate_rl_bart.py +138 -0
- src/evaluate_sft_bart.py +190 -0
- src/evaluate_without_constraied.py +503 -0
- src/execution_reward copy.py +831 -0
README.md
CHANGED
|
@@ -1,3 +1,13 @@
|
|
| 1 |
-
---
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Text2sql Demo
|
| 3 |
+
emoji: 📊
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.8.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
python_version: 3.10.13
|
| 12 |
+
short_description: 'Text to SQL with RLHF'
|
| 13 |
+
---
|
app.py
ADDED
|
@@ -0,0 +1,569 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GRADIO DEMO UI - LAZY LOADING EDITION
|
| 3 |
+
NL → SQL → Result Table
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import re
|
| 9 |
+
import time
|
| 10 |
+
import os
|
| 11 |
+
import torch
|
| 12 |
+
import sys
|
| 13 |
+
import json
|
| 14 |
+
import subprocess
|
| 15 |
+
import base64
|
| 16 |
+
import io
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Iterator
|
| 19 |
+
|
| 20 |
+
# ==========================================
# RELATIVE PATH RESOLUTION (GLOBAL)
# ==========================================
try:
    # Anchor all relative paths at the directory containing this file.
    PROJECT_ROOT = Path(__file__).resolve().parent
except NameError:
    # __file__ is undefined in interactive sessions; fall back to the CWD.
    PROJECT_ROOT = Path(".").resolve()

# Databases live either under data/database/ or, if that layout is absent,
# under final_databases/ at the project root.
if (PROJECT_ROOT / "data" / "database").exists():
    DB_ROOT = PROJECT_ROOT / "data" / "database"
else:
    DB_ROOT = PROJECT_ROOT / "final_databases"
+
def get_db_path(db_id: str) -> str:
    """Resolve the on-disk SQLite file for *db_id*.

    Prefers the nested layout ``DB_ROOT/<db_id>/<db_id>.sqlite``; falls back
    to the flat layout ``DB_ROOT/<db_id>.sqlite`` when the nested file does
    not exist.
    """
    nested = DB_ROOT / db_id / f"{db_id}.sqlite"
    if nested.exists():
        return str(nested)
    return str(DB_ROOT / f"{db_id}.sqlite")
| 38 |
+
# ==========================================
# 🔥 CUDA MOCK PATCH FOR MAC (MPS) / CPU
# ==========================================
# On machines without CUDA, patch torch.cuda so downstream code written
# against CUDA timing events still runs.
if not torch.cuda.is_available():
    class MockCUDAEvent:
        # Drop-in stand-in for torch.cuda.Event backed by time.perf_counter.
        def __init__(self, enable_timing=False, blocking=False, interprocess=False):
            self.t = 0.0  # timestamp captured by record()
        def record(self, stream=None):
            self.t = time.perf_counter()
        def elapsed_time(self, end_event):
            # CUDA events report milliseconds; perf_counter returns seconds.
            return (end_event.t - self.t) * 1000.0

    torch.cuda.Event = MockCUDAEvent
    # NOTE(review): torch.cuda normally always exposes synchronize, so this
    # guard rarely fires — presumably defensive; confirm against the torch
    # versions targeted.
    if not hasattr(torch.cuda, 'synchronize'):
        torch.cuda.synchronize = lambda: None
| 54 |
+
# ==========================================
|
| 55 |
+
# IMPORTS & ENGINE SETUP
|
| 56 |
+
# ==========================================
|
| 57 |
+
from src.quantized_text2sql_engine import QuantizedText2SQLEngine
|
| 58 |
+
from src.schema_encoder import SchemaEncoder
|
| 59 |
+
|
| 60 |
+
# Path to the int8 dynamically-quantized model artifact shipped with the app.
DEFAULT_QUANT_ARTIFACT = str(PROJECT_ROOT / "int8_dynamic")

# In-process state shared across Gradio callbacks (no persistence).
_ENGINE_CACHE = {}   # (artifact_dir, constrained, workers, cache) -> engine instance
_QUERY_LOG = []      # blocked/failed query records; feeds the Task 2 dashboard
_PERF_LOG = []       # per-request latency/validator stats (trimmed in run_query)
_SUCCESS_LOG = []    # successfully executed queries

# Success/failure tallies per SQL operation, updated by run_query.
_OP_STATS = {
    "SELECT": {"ok": 0, "fail": 0}, "WHERE": {"ok": 0, "fail": 0}, "JOIN": {"ok": 0, "fail": 0},
    "GROUP_BY": {"ok": 0, "fail": 0}, "ORDER_BY": {"ok": 0, "fail": 0}, "HAVING": {"ok": 0, "fail": 0}, "LIMIT": {"ok": 0, "fail": 0},
}
| 72 |
+
def get_quant_engine(artifact_dir: str, use_constrained: bool = False, exec_workers: int = 8, use_cache: bool = True):
    """Return a process-wide cached QuantizedText2SQLEngine for this config.

    Engines are keyed by the full configuration tuple so different settings
    coexist. Construction falls back to the single-argument constructor for
    engine builds that do not accept the extra keyword arguments.
    """
    cache_key = (artifact_dir, bool(use_constrained), int(exec_workers), bool(use_cache))
    engine = _ENGINE_CACHE.get(cache_key)
    if engine is None:
        try:
            engine = QuantizedText2SQLEngine(
                artifact_dir,
                device="cpu",
                use_constrained=bool(use_constrained),
                exec_workers=int(exec_workers),
                use_cache=bool(use_cache),
            )
        except TypeError:
            # Older engine builds accept only the artifact path.
            engine = QuantizedText2SQLEngine(artifact_dir)
        _ENGINE_CACHE[cache_key] = engine
    return engine
| 81 |
+
# 🚨 LAZY LOADING: We DO NOT load the model here! We only load the fast Schema Encoder.
quant_engine = None  # populated on first request inside run_query
try:
    schema_encoder = SchemaEncoder(DB_ROOT)
except Exception as e:
    # Demo stays usable without schemas; relevance checks then pass everything.
    print(f"Warning: SchemaEncoder failed to load: {e}")
    schema_encoder = None

# (question, db_id) pairs offered in the sample picker.
SAMPLES = [
    ("Show 10 distinct employee first names.", "chinook_1"), ("Which artist has the most albums?", "chinook_1"),
    ("List all the tracks that belong to the 'Rock' genre.", "chinook_1"), ("What are the names of all the cities?", "flight_1"),
    ("Find the flight number and cost of the cheapest flight.", "flight_1"), ("List the airlines that fly out of New York.", "flight_1"),
    ("Which campus was opened between 1935 and 1939?", "csu_1"), ("Count the number of students in each department.", "college_2"),
    ("List the names of all clubs.", "club_1"), ("How many members does each club have?", "club_1"),
    ("Show the names of all cinemas.", "cinema"), ("Which cinema has the most screens?", "cinema")
]
# Question strings only, for the dropdown choices.
SAMPLE_QUESTIONS = [q[0] for q in SAMPLES]
|
| 99 |
+
def explain_sql(sql):
    """Build a short plain-English summary of the clauses present in *sql*.

    Returns "" for empty/None input; otherwise a base sentence plus one
    bullet per recognised clause keyword found in the lowercased SQL.
    """
    if not sql:
        return ""
    lowered = sql.lower()
    clause_notes = (
        ("join", "\n• It combines data from multiple tables using JOIN."),
        ("where", "\n• It filters rows using a WHERE condition."),
        ("group by", "\n• It groups results using GROUP BY."),
        ("order by", "\n• It sorts the results using ORDER BY."),
        ("limit", "\n• It limits the number of returned rows."),
    )
    parts = ["This SQL query retrieves information from the database."]
    parts.extend(note for keyword, note in clause_notes if keyword in lowered)
    return "".join(parts)
|
| 110 |
+
def sql_ops(sql: str) -> list[str]:
    """List the SQL operations detected in *sql*.

    "SELECT" is always included; the other labels are appended in a fixed
    order when their keyword occurs as a space-delimited token (the query is
    padded with spaces so keywords at either end still match).
    """
    padded = f" {(sql or '').lower()} "
    detected = ["SELECT"]
    for needle, label in (
        (" where ", "WHERE"),
        (" join ", "JOIN"),
        (" group by ", "GROUP_BY"),
        (" order by ", "ORDER_BY"),
        (" having ", "HAVING"),
        (" limit ", "LIMIT"),
    ):
        if needle in padded:
            detected.append(label)
    return detected
|
| 121 |
+
def classify_error(sql: str, error_msg: str | None = None, *, timed_out: bool = False):
|
| 122 |
+
s = (sql or "").lower()
|
| 123 |
+
m = (error_msg or "").lower()
|
| 124 |
+
if timed_out or "interrupted" in m or "timeout" in m: return "timeout"
|
| 125 |
+
if not s.strip().startswith(("select", "with")): return "syntax_error"
|
| 126 |
+
if " join " in f" {s} " and " on " not in f" {s} ": return "missing_join"
|
| 127 |
+
if " where " in f" {s} " and not any(op in s for op in ["=", ">", "<", " in ", " like ", " between ", " is null", " is not null"]): return "wrong_where"
|
| 128 |
+
if ("is null" in s or "is not null" in s) and ("no such column" in m or "misuse" in m): return "null_handling"
|
| 129 |
+
if "no such table" in m: return "missing_table"
|
| 130 |
+
if "no such column" in m: return "missing_column"
|
| 131 |
+
if "ambiguous column name" in m: return "ambiguous_column"
|
| 132 |
+
if "datatype mismatch" in m or "type mismatch" in m: return "type_mismatch"
|
| 133 |
+
if "misuse of aggregate" in m or "misuse of aggregate function" in m: return "wrong_aggregation"
|
| 134 |
+
if "syntax error" in m: return "syntax_error"
|
| 135 |
+
if "near" in m and "syntax error" in m: return "syntax_error"
|
| 136 |
+
if "runtime" in m or "constraint failed" in m: return "runtime_error"
|
| 137 |
+
return "other"
|
| 138 |
+
|
| 139 |
+
def get_hint(error_type):
    """Return a one-line remediation hint for a classify_error() category.

    Unknown categories fall back to the generic "Review query." hint.
    Fix: added hints for "null_handling", "type_mismatch" and
    "runtime_error" — categories classify_error() can emit that previously
    fell through to the generic fallback.
    """
    hints = {
        "missing_join": "Check JOIN conditions between tables.",
        "wrong_aggregation": "Use proper aggregation like avg(column).",
        "wrong_where": "Check WHERE condition syntax.",
        "syntax_error": "Ensure SQL starts with SELECT.",
        "missing_table": "Use only tables from the provided schema.",
        "missing_column": "Use only columns from the provided schema.",
        "ambiguous_column": "Disambiguate by using table.column.",
        "timeout": "Query took too long; simplify joins.",
        "null_handling": "Use IS NULL / IS NOT NULL for NULL comparisons.",
        "type_mismatch": "Compare columns against values of a matching type.",
        "runtime_error": "A constraint or runtime check failed; verify the data being queried.",
        "other": "Review SQL logic.",
    }
    return hints.get(error_type, "Review query.")
|
| 148 |
+
def is_relevant_to_schema(question, db_id):
    """Cheap lexical guard: does *question* mention anything in the schema?

    Returns True ("let it through") whenever the check cannot be performed —
    no schema encoder, schema lookup failure, or a question consisting only
    of stop words/digits. Otherwise requires at least one meaningful
    question word (or its naive singular form) to appear in the schema text.

    Fix: the bare ``except:`` around the schema lookup is narrowed to
    ``except Exception:`` so SystemExit/KeyboardInterrupt are not swallowed.
    """
    if schema_encoder is None:
        return True
    try:
        raw_schema = schema_encoder.structured_schema(db_id).lower()
    except Exception:
        # Unknown db_id / encoder failure: fail open rather than block the user.
        return True
    schema_words = set(re.findall(r'[a-z0-9_]+', raw_schema))
    q_words = re.findall(r'[a-z0-9_]+', question.lower())
    stop_words = {"show", "list", "all", "what", "is", "the", "how", "many", "count", "find", "get", "me", "a", "an", "of", "in", "for", "from", "with", "which", "are", "there", "give", "tell", "details", "info", "data", "everything"}
    meaningful_q_words = [w for w in q_words if w not in stop_words and not w.isdigit()]
    if not meaningful_q_words:
        return True
    for word in meaningful_q_words:
        # Naive plural stripping so "names" also matches a "name" column.
        singular_word = word[:-1] if word.endswith('s') else word
        if word in schema_words or singular_word in schema_words:
            return True
    return False
|
| 162 |
+
def run_query(method, sample_q, custom_q, db_id):
    """Main Gradio callback: natural-language question -> SQL -> result table.

    Always returns a 3-tuple (sql_text, pandas.DataFrame, explanation_text)
    matching the three output widgets; every guard/early-exit keeps that shape.
    """
    global quant_engine

    # 🚨 LAZY LOADING: load the heavy AI model only on the first request.
    if quant_engine is None:
        print(f"First request detected! Loading AI model from {DEFAULT_QUANT_ARTIFACT}...", flush=True)
        try:
            quant_engine = get_quant_engine(DEFAULT_QUANT_ARTIFACT, use_constrained=False, exec_workers=8, use_cache=True)
            if quant_engine is None:
                return "-- ❌ ENGINE CRASH", pd.DataFrame(columns=["Error"]), "Failed to load model. Did you move the tokenizer files and add config.json to int8_dynamic/?"
        except Exception as e:
            return f"-- ❌ ENGINE CRASH\n-- {str(e)}", pd.DataFrame(columns=["Error Status"]), f"Critical failure loading model: {e}"

    def _log(error_type: str, *, question: str, db_id_val: str, sql: str = "", error_msg: str = "") -> None:
        # One record per blocked/failed attempt; feeds the Task 2 dashboard.
        _QUERY_LOG.append({"t": time.time(), "db_id": str(db_id_val), "question": str(question), "sql": str(sql), "error_type": str(error_type), "error_msg": str(error_msg)})

    def _perf_log(payload: dict) -> None:
        # Bounded perf log: drop the oldest 200 entries once past 1000.
        _PERF_LOG.append(payload)
        if len(_PERF_LOG) > 1000: del _PERF_LOG[:200]

    # Pick the question from whichever input mode is active.
    raw_question = sample_q if method == "💡 Pick a Sample" else custom_q

    # --- Input guards ---------------------------------------------------
    if not raw_question or str(raw_question).strip() == "":
        return "-- No input provided", pd.DataFrame(columns=["Warning"]), "⚠️ Please enter a question."
    if not db_id or str(db_id).strip() == "":
        return "-- No database selected", pd.DataFrame(columns=["Warning"]), "⚠️ Please select a database."

    # Normalise a handful of common typos before the model sees the text.
    typo_corrections = [(r'\bshaw\b', 'show'), (r'\bshw\b', 'show'), (r'\bsho\b', 'show'), (r'\blsit\b', 'list'), (r'\blis\b', 'list'), (r'\bfidn\b', 'find'), (r'\bfnd\b', 'find'), (r'\bgte\b', 'get')]
    question = str(raw_question)
    for bad, good in typo_corrections: question = re.sub(bad, good, question, flags=re.IGNORECASE)
    q_lower = question.strip().lower()

    # Single-word inputs are treated as gibberish.
    if len(q_lower.split()) < 2:
        _log("gibberish", question=question, db_id_val=str(db_id), error_msg="gibberish filtered")
        return "-- Input Blocked", pd.DataFrame(columns=["Warning"]), "⚠️ Please enter a clear, meaningful natural language question (more than one word)."

    # Hard block on anything that sounds like data modification.
    if re.search(r'\b(delete|update|insert|drop|alter|truncate)\b', q_lower):
        _log("blocked_dml", question=question, db_id_val=str(db_id), error_msg="DML blocked")
        return "-- ❌ BLOCKED: Data Modification", pd.DataFrame(columns=["Security Alert"]), "🛑 Security Alert: Modifying or deleting data is strictly prohibited."

    # Reject questions with no lexical overlap with the chosen schema.
    if not is_relevant_to_schema(question, db_id):
        _log("out_of_domain", question=question, db_id_val=str(db_id), error_msg="out of domain")
        return "-- ❌ BLOCKED: Out of Domain", pd.DataFrame(columns=["Domain Alert"]), f"🛑 Relevance Alert: I don't see anything related to your question in the '{db_id}' schema."

    start_time = time.time()
    t0 = time.perf_counter()
    ui_warnings = ""

    # --- Generation -----------------------------------------------------
    try:
        try:
            result = quant_engine.ask(question, str(db_id), num_beams=4, max_new_tokens=120, timeout_s=2.0)
        except TypeError:
            # Older engine builds take only (question, db_id).
            result = quant_engine.ask(question, str(db_id))
    except Exception as e:
        _log("backend_crash", question=question, db_id_val=str(db_id), error_msg=str(e))
        return f"-- ❌ BACKEND CRASH\n-- {str(e)}", pd.DataFrame(columns=["Error Status"]), f"❌ CRITICAL BACKEND CRASH:\n{str(e)}"

    final_sql = str(result.get("sql", ""))
    model_sql = final_sql  # keep the untouched model output as a fallback

    # --- Post-processing: honour "show/list N ..." as a LIMIT -----------
    num_match = re.search(r'\b(?:show|list|top|limit|get|first|last|sample|of)\s+(?:[a-zA-Z_]+\s+)?(\d+)\b', q_lower)
    if not num_match and q_lower.startswith(("show", "list", "get")):
        num_match = re.search(r'\b(\d+)\b', q_lower)

    if num_match and final_sql:
        limit_val = num_match.group(1)
        # Strip spurious equality filters the model built from the number.
        final_sql = re.sub(rf"(?i)\s*(?:where|having|and)?\s*count\s*\(\s*\*\s*\)\s*=\s*{limit_val}", "", final_sql)
        final_sql = re.sub(rf"(?i)\s*(?:where|and)\s+[a-zA-Z0-9_.]+\s*=\s*['\"]?{limit_val}['\"]?", "", final_sql)
        final_sql = re.sub(r"(?i)\s*where\s*$", "", final_sql)
        final_sql = re.sub(r"(?i)\s*where\s+(group by|order by|limit)", r" \1", final_sql)

        # Drop aggregation clauses unless the question actually asked for them.
        agg_kws = ["most", "top", "highest", "lowest", "count", "many", "group", "frequent", "popular"]
        if not any(k in q_lower for k in agg_kws):
            final_sql = re.sub(r"(?i)\s*group by\s+[a-zA-Z0-9_.]+\s*order by\s+count\(\*\)\s*(?:desc|asc)?", "", final_sql)
            final_sql = re.sub(r"(?i)\s*order by\s+count\(\*\)\s*(?:desc|asc)?", "", final_sql)
            final_sql = re.sub(r"(?i),\s*count\(\*\)", "", final_sql)
            final_sql = re.sub(r"(?i)count\(\*\)\s*,", "", final_sql)

        # GROUP BY without any aggregate function is meaningless; remove it.
        if "group by" in final_sql.lower() and not re.search(r'(?i)\b(count|sum|avg|max|min)\b\(', final_sql):
            final_sql = re.sub(r"(?i)\s*group by\s+[a-zA-Z0-9_.]+", "", final_sql)

        if "limit" not in final_sql.lower():
            final_sql = f"{final_sql.strip().rstrip(';')} LIMIT {limit_val}"

    # Execution
    from src.sql_validator import validate_sql_schema
    db_path = get_db_path(str(db_id))

    # Strict schema validation is advisory only; its result is reported in
    # the perf block, not used to block execution.
    try: strict_valid, _ = validate_sql_schema(final_sql, db_path)
    except Exception: strict_valid = False

    error_msg = None
    rows, cols = [], []
    sqlite_success = False

    try:
        rows, cols = quant_engine._execute_one(final_sql, db_path, timeout_s=2.0)
        sqlite_success = True
    except Exception as e:
        error_msg = str(e)
        sqlite_success = False

    # If the rewritten SQL failed, retry with the raw model output.
    if not sqlite_success and model_sql and model_sql != final_sql:
        try:
            alt_rows, alt_cols = quant_engine._execute_one(model_sql, db_path, timeout_s=2.0)
            final_sql = model_sql
            rows, cols = alt_rows, alt_cols
            sqlite_success = True
            error_msg = None
        except Exception: pass

    valid = sqlite_success

    # Log the classified failure before building perf stats.
    if error_msg or not valid:
        et = classify_error(final_sql, str(error_msg or ""), timed_out=("interrupted" in str(error_msg or "").lower()))
        _log(et, question=str(question), db_id_val=str(db_id), sql=str(final_sql), error_msg=str(error_msg or "Execution failed"))

    latency = round(time.time() - start_time, 3)
    t1 = time.perf_counter()

    engine_stats_after = quant_engine.stats() if hasattr(quant_engine, 'stats') else {}

    perf = {
        "db_id": str(db_id), "use_constrained_decoding": False, "num_beams": 4,
        "latency_total_ms": round((t1 - t0) * 1000.0, 2), "constraint_ok": bool(strict_valid), "has_error": bool(error_msg),
        "exec_cache_hit_rate": float(engine_stats_after.get("exec_cache_hit_rate", 0.0) or 0.0),
    }
    _perf_log(perf)

    # Rolling stats over the last 50 requests for the perf footer.
    window = _PERF_LOG[-50:]
    avg_ms = sum(float(x.get("latency_total_ms", 0.0) or 0.0) for x in window) / len(window) if window else 0.0
    constraint_rate = sum(1 for x in window if x.get("constraint_ok")) / len(window) if window else 0.0

    perf_block = (
        "\n\n---\nPerformance (task impact)\n"
        f"- Total latency (ms): {perf['latency_total_ms']}\n"
        f"- Strict Python Validator OK (Task 3): {perf['constraint_ok']}\n"
        f"- Exec cache hit-rate (Task 1/5): {round(perf['exec_cache_hit_rate'], 3)}\n"
        f"- Rolling avg latency last 50 (ms): {round(avg_ms, 2)}\n"
        f"- Rolling constraint rate last 50: {round(constraint_rate, 3)}\n"
    )

    # --- Failure path ---------------------------------------------------
    if error_msg or not valid:
        display_sql = final_sql if final_sql.strip() else "-- ❌ INVALID SQL"
        explanation = f"{ui_warnings}❌ Error Details:\n\n"
        if error_msg: explanation += f"{error_msg}\n\n"

        error_type = classify_error(final_sql, str(error_msg or ""))
        explanation += f"Error Type: {error_type}\nHint: {get_hint(error_type)}"
        explanation += perf_block
        ops = sql_ops(final_sql)
        for op in ops:
            if op in _OP_STATS: _OP_STATS[op]["fail"] += 1
        return display_sql, pd.DataFrame(columns=["Execution Notice"]), explanation

    # --- Success path ---------------------------------------------------
    safe_cols = cols if cols else ["Result"]
    explanation = f"{ui_warnings}✅ Query executed successfully\n\nRows returned: {len(rows)}\nExecution Time: {latency} sec\n\n{explain_sql(final_sql)}{perf_block}"

    ops = sql_ops(final_sql)
    for op in ops:
        if op in _OP_STATS: _OP_STATS[op]["ok"] += 1
    _SUCCESS_LOG.append({"t": time.time(), "db_id": str(db_id), "question": question, "sql": final_sql, "ops": ops})

    # Tell the user when LIMIT allowed more rows than actually matched.
    limit_match = re.search(r'LIMIT\s+(\d+)', final_sql, re.IGNORECASE)
    if limit_match and len(rows) < int(limit_match.group(1)):
        explanation += f"\n\nℹ️ Query allowed up to {int(limit_match.group(1))} rows but only {len(rows)} matched."

    return final_sql, pd.DataFrame(rows, columns=safe_cols), explanation
|
| 331 |
+
def task1_benchmark(n_rollouts: int, max_workers: int) -> Iterator[tuple[str, str]]:
    """Run the Task 1 parallel-reward benchmark script as a subprocess.

    Generator used by a streaming Gradio handler: yields (log_text, html)
    pairs while the script runs, then a final pair where html is either the
    generated plot inlined as base64 or a placeholder.
    """
    project_root = str(PROJECT_ROOT)
    env = os.environ.copy()
    # Prepend the project root so the script can import the src/ package.
    env["PYTHONPATH"] = project_root + (os.pathsep + env["PYTHONPATH"] if env.get("PYTHONPATH") else "")
    # Headless matplotlib with a writable config dir (needed in containers).
    env.setdefault("MPLBACKEND", "Agg")
    env.setdefault("MPLCONFIGDIR", "/tmp/mplconfig")
    try: os.makedirs(env["MPLCONFIGDIR"], exist_ok=True)
    except Exception: pass

    # -u keeps the child's stdout unbuffered so output streams promptly.
    cmd = [sys.executable, "-u", "scripts/benchmark_parallel_reward.py", "--n", str(int(n_rollouts)), "--max-workers", str(int(max_workers)), "--skip-profile"]
    proc = subprocess.Popen(cmd, cwd=project_root, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
    last_yield = time.perf_counter()
    lines: list[str] = []
    yield "Running Task 1 benchmark...\n", "<i>Running...</i>"

    assert proc.stdout is not None
    for line in proc.stdout:
        lines.append(line)
        now = time.perf_counter()
        # Throttle UI updates to at most one every 0.5 s, last 200 lines only.
        if now - last_yield >= 0.5:
            last_yield = now
            yield "".join(lines[-200:]).strip(), "<i>Running...</i>"

    proc.wait()
    out = "".join(lines).strip()

    plot_path = str(PROJECT_ROOT / "results" / "task1_plot.png")
    if os.path.exists(plot_path):
        try:
            # Inline the plot so the HTML component needs no file serving.
            b64 = base64.b64encode(Path(plot_path).read_bytes()).decode("ascii")
            yield out, f"<img src='data:image/png;base64,{b64}' style='max-width: 100%; border: 1px solid #e2e8f0; border-radius: 8px;' />"
            return
        except Exception:
            # Reading the image failed; at least show where it should be.
            yield out, f"<pre>{plot_path}</pre>"
            return

    yield out, "<i>No plot generated</i>"
|
| 369 |
+
def task2_dashboard_structured():
    """Summarise the in-memory query log for the Task 2 error dashboard.

    Returns (error-count table, recent-events table, dropdown update).
    """
    if not _QUERY_LOG:
        # Nothing logged yet: empty frames and a cleared dropdown.
        return (
            pd.DataFrame(columns=["error_type", "count", "hint"]),
            pd.DataFrame(columns=["time", "db_id", "error_type", "question", "error_msg"]),
            gr.update(choices=[], value=None),
        )

    # Tally error types over (at most) the last 1000 log entries.
    tally = {}
    for entry in _QUERY_LOG[-1000:]:
        label = entry.get("error_type") or "other"
        tally[label] = tally.get(label, 0) + 1

    # Most frequent first; ties broken alphabetically.
    ordered = sorted(tally.items(), key=lambda kv: (-kv[1], kv[0]))
    summary_rows = [
        {"error_type": label, "count": int(n), "hint": get_hint(label)}
        for label, n in ordered
    ]
    counts_df = pd.DataFrame(summary_rows)

    # The 100 most recent events with wall-clock timestamps when parseable.
    recent_rows = []
    for entry in _QUERY_LOG[-100:]:
        stamp = entry.get("t")
        try:
            when = time.strftime("%H:%M:%S", time.localtime(float(stamp))) if stamp else ""
        except Exception:
            when = ""
        recent_rows.append({
            "time": when,
            "db_id": entry.get("db_id", ""),
            "error_type": entry.get("error_type", ""),
            "question": entry.get("question", ""),
            "error_msg": entry.get("error_msg", ""),
        })
    recent_df = pd.DataFrame(recent_rows)

    dropdown = [str(row["error_type"]) for row in summary_rows]
    return counts_df, recent_df, gr.update(choices=dropdown, value=dropdown[0] if dropdown else None)
|
| 394 |
+
def task2_error_examples(error_type: str) -> str:
    """Render up to three recent logged examples of *error_type* as text."""
    if not error_type:
        return ""
    hint = get_hint(error_type)
    wanted = str(error_type)
    # Walk the log newest-first and stop after three matches.
    examples = []
    for record in reversed(_QUERY_LOG):
        if (record.get("error_type") or "") == wanted:
            examples.append(record)
            if len(examples) == 3:
                break
    if not examples:
        return f"Error type: {error_type}\nHint: {hint}\n\nNo examples yet."
    chunks = [f"Error type: {error_type}", f"Hint: {hint}", ""]
    for idx, record in enumerate(examples, 1):
        chunks.append(f"Example {idx}")
        chunks.append(f"DB: {record.get('db_id','')}")
        chunks.append(f"Q: {record.get('question','')}")
        chunks.append(f"SQL: {record.get('sql','')}")
        chunks.append(f"Msg: {record.get('error_msg','')}")
        chunks.append("")
    return "\n".join(chunks).strip()
|
| 404 |
+
def _plot_op_stats_html() -> str:
    """Render _OP_STATS as a stacked ok/fail bar chart, returned as an
    inline base64 <img> tag (so the Gradio HTML widget needs no file serving).
    Falls back to a <pre> error message if plotting fails."""
    try:
        import matplotlib.pyplot as plt
        labels = list(_OP_STATS.keys())
        oks = [int(_OP_STATS[k]["ok"]) for k in labels]
        fails = [int(_OP_STATS[k]["fail"]) for k in labels]

        fig, ax = plt.subplots(figsize=(9, 3.5))
        x = list(range(len(labels)))
        ax.bar(x, oks, label="ok", color="#16a34a")
        # Failures stacked on top of the success bars.
        ax.bar(x, fails, bottom=oks, label="fail", color="#dc2626")
        ax.set_xticks(x)
        ax.set_xticklabels(labels, rotation=30, ha="right")
        ax.set_title("Success/Failure by SQL operation")
        ax.legend()
        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=160)
        plt.close(fig)  # free the figure between dashboard refreshes
        b64 = base64.b64encode(buf.getvalue()).decode("ascii")
        return f"<img src='data:image/png;base64,{b64}' style='max-width: 100%; border: 1px solid #e2e8f0; border-radius: 8px;' />"
    except Exception as e: return f"<pre>Plot error: {e}</pre>"
|
| 428 |
+
def task2_ops_table():
    """Tabulate per-operation success counts plus the rendered stats plot."""
    records = []
    for op_name, stats in _OP_STATS.items():
        succeeded = int(stats.get("ok", 0))
        failed = int(stats.get("fail", 0))
        attempts = succeeded + failed
        records.append({
            "op": op_name,
            "ok": succeeded,
            "fail": failed,
            "total": attempts,
            "success_rate": succeeded / attempts if attempts else 0.0,
        })
    return pd.DataFrame(records), _plot_op_stats_html()
|
| 437 |
+
def toggle_input_method(method, current_sample):
    """Switch the UI between sample-question mode and free-text mode."""
    if method != "💡 Pick a Sample":
        # Free-text mode: hide the picker, show the textbox, unlock the DB choice.
        return (
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=True),
            gr.update(interactive=True),
        )
    # Sample mode: lock the DB dropdown to the selected sample's database.
    matched_db = "chinook_1"
    for sample_q, sample_db in SAMPLES:
        if sample_q == current_sample:
            matched_db = sample_db
            break
    return (
        gr.update(visible=True),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(value=matched_db, interactive=False),
    )
+
|
| 443 |
+
def load_sample(selected_question):
    """Point the DB dropdown at the database backing *selected_question*."""
    if not selected_question:
        return gr.update()
    for sample_q, sample_db in SAMPLES:
        if sample_q == selected_question:
            return gr.update(value=sample_db)
    return gr.update(value="chinook_1")
|
| 447 |
+
def clear_inputs():
    """Reset every input/output widget to its initial state."""
    return (
        gr.update(value="💡 Pick a Sample"),
        gr.update(value=SAMPLE_QUESTIONS[0], visible=True),
        gr.update(visible=False),
        gr.update(value="", visible=False),
        gr.update(value="chinook_1", interactive=False),
        "",
        pd.DataFrame(),
        "",
    )
|
| 450 |
+
def update_schema(db_id):
    """Render the structured schema of *db_id* as a scrollable HTML panel.

    Table headers (``name(cols)`` lines) are emphasized; any other line is
    shown in a muted style. Returns "" when no DB is selected or the encoder
    is unavailable, and an inline error div if schema loading fails.
    """
    # Short-circuit before touching schema_encoder so a missing selection
    # never raises.
    if not db_id or schema_encoder is None:
        return ""
    try:
        raw_schema = schema_encoder.structured_schema(db_id)
        parts = [
            "<div style='max-height: 250px; overflow-y: auto; background: #f8fafc; padding: 12px; border-radius: 8px; border: 1px solid #e2e8f0; font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace; font-size: 0.9em; line-height: 1.6;'>"
        ]
        # Matches "table_name(col1, col2, ...)" at the start of a line.
        table_re = re.compile(r'^([a-zA-Z0-9_]+)\s*\((.*)\)')
        for raw_line in raw_schema.strip().split('\n'):
            entry = raw_line.strip()
            if not entry:
                continue
            m = table_re.search(entry)
            if m:
                parts.append(f"<div style='margin-bottom: 8px;'><strong style='color: #0f172a; font-size: 1.05em; font-weight: 800;'>{m.group(1).upper()}</strong> <span style='color: #64748b;'>( {m.group(2).lower()} )</span></div>")
            else:
                parts.append(f"<div style='color: #475569;'>{entry}</div>")
        parts.append("</div>")
        return "".join(parts)
    except Exception as e:
        return f"<div style='color: red;'>Error loading schema: {str(e)}</div>"
|
| 464 |
+
|
| 465 |
+
# =========================
|
| 466 |
+
# UI LAYOUT
|
| 467 |
+
# =========================
|
| 468 |
+
# UI layout: the entire Gradio app is declared inside this Blocks context.
# Widget construction order defines on-screen order; event bindings at the
# bottom wire the callbacks defined earlier in the file.
with gr.Blocks(title="Text-to-SQL RLHF") as demo:
    # Static page header banner.
    gr.HTML("""
    <div style="text-align: center; background-color: #e0e7ff; padding: 20px; border-radius: 10px; margin-bottom: 20px; border: 1px solid #c7d2fe;">
        <h1 style="color: #3730a3; margin-top: 0; margin-bottom: 10px; font-size: 2.2em;"> Text-to-SQL using RLHF + Execution Reward</h1>
        <p style="color: #4f46e5; font-size: 1.1em; margin: 0;">Convert Natural Language to SQL, strictly validated and safely executed on local SQLite databases.</p>
    </div>
    """)

    # Alphabetical list of SQLite databases the user may query against.
    DBS = sorted(["flight_1", "student_assessment", "store_1", "bike_1", "book_2", "chinook_1", "academic", "aircraft", "car_1", "cinema", "club_1", "csu_1", "college_1", "college_2", "company_1", "company_employee", "customer_complaints", "department_store", "employee_hire_evaluation", "museum_visit", "products_for_hire", "restaurant_1", "school_finance", "shop_membership", "small_bank_1", "student_1", "tvshow", "voter_1", "world_1"])

    with gr.Tabs():
        with gr.Tab("Inference"):
            with gr.Row():
                # Left column: input configuration (question + database).
                with gr.Column(scale=1):
                    gr.Markdown("### 1. Configuration & Input")
                    input_method = gr.Radio(choices=["💡 Pick a Sample", "✍️ Type my own"], value="💡 Pick a Sample", label="How do you want to ask?")
                    sample_dropdown = gr.Dropdown(choices=SAMPLE_QUESTIONS, value=SAMPLE_QUESTIONS[0], label="Select a Sample Question", info="The database will be selected automatically.", visible=True)
                    type_own_warning = gr.Markdown("**⚠️ Please select a Database first, then type your custom question below:**", visible=False)
                    gr.Markdown("---")
                    # db_id starts non-interactive: in sample mode the DB is
                    # auto-selected; toggle_input_method unlocks it.
                    db_id = gr.Dropdown(choices=DBS, value="chinook_1", label="Select Database", interactive=False)
                    custom_question = gr.Textbox(label="Ask your Custom Question", placeholder="Type your own question here...", lines=3, visible=False)

                    gr.Markdown("#### 📋 Database Structure")
                    gr.HTML("<p style='font-size: 0.85em; color: #64748b; margin-top: -10px; margin-bottom: 5px;'>Use these exact names! Table names are <strong>Dark</strong>, Column names are <span style='color: #94a3b8;'>Light</span>.</p>")
                    # Pre-render the default database's schema on page load.
                    schema_display = gr.HTML(value=update_schema("chinook_1"))

                    with gr.Row():
                        clear_btn = gr.Button("🗑️ Clear", variant="secondary")
                        run_btn = gr.Button(" Generate & Run SQL", variant="primary")

                # Right column: SQL + execution output.
                with gr.Column(scale=2):
                    gr.Markdown("### 2. Execution Results")
                    final_sql = gr.Code(language="sql", label="Final Executed SQL")
                    result_table = gr.Dataframe(label="Query Result Table", interactive=False, wrap=True)
                    explanation = gr.Textbox(label="AI Explanation + Execution Details", lines=8)

        with gr.Tab("Diagnostics"):
            gr.Markdown("## Diagnostics & Telemetry")

            # Task 1: concurrency stress benchmark for the reward executor.
            with gr.Accordion("Task 1: Parallel Reward Benchmark", open=False):
                gr.Markdown("*(Simulates the heavy RLHF training workload by running hundreds of complex SQL queries concurrently to test SQLite multi-threading performance.)*")
                t1_n = gr.Number(value=20, precision=0, label="Rollouts (n)")
                t1_workers = gr.Number(value=10, precision=0, label="Max workers")
                t1_run = gr.Button("Run Task 1 benchmark")
                t1_out = gr.Textbox(label="Output", lines=12)
                t1_plot = gr.HTML(label="Plot (if generated)")
                t1_run.click(fn=task1_benchmark, inputs=[t1_n, t1_workers], outputs=[t1_out, t1_plot])

            # Task 2: live error telemetry populated by failed inference runs.
            with gr.Accordion("Task 2: Error Dashboard", open=True):
                gr.Markdown("*(Live telemetry tracking the most common SQL failures. Populates automatically when queries fail in the Inference tab.)*")
                t2_refresh = gr.Button("Refresh dashboard")
                t2_counts = gr.Dataframe(label="Error counts", interactive=False, wrap=True)
                t2_recent = gr.Dataframe(label="Recent errors", interactive=False, wrap=True)
                t2_type = gr.Dropdown(choices=[], value=None, label="Select error type")
                t2_examples = gr.Textbox(label="Examples + hint", lines=10)

                t2_refresh.click(fn=task2_dashboard_structured, inputs=[], outputs=[t2_counts, t2_recent, t2_type])
                t2_type.change(fn=task2_error_examples, inputs=[t2_type], outputs=[t2_examples])

            # Task 2 (cont.): per-SQL-clause success/failure breakdown.
            with gr.Accordion("Task 2: Clause Telemetry", open=False):
                gr.Markdown("*(Analyzes which specific SQL clauses—SELECT, WHERE, JOIN, etc.—are most prone to errors during natural language generation.)*")
                t2_ops_refresh = gr.Button("Refresh SQL-op stats")
                t2_ops_tbl = gr.Dataframe(label="Success/failure by op", interactive=False, wrap=True)
                t2_ops_plot = gr.HTML(label="Op plot")
                t2_ops_refresh.click(fn=task2_ops_table, inputs=[], outputs=[t2_ops_tbl, t2_ops_plot])

    # EVENT BINDING: The .then() forces the diagnostic tab to update live in the background!
    input_method.change(fn=toggle_input_method, inputs=[input_method, sample_dropdown], outputs=[sample_dropdown, type_own_warning, custom_question, db_id])
    sample_dropdown.change(fn=load_sample, inputs=[sample_dropdown], outputs=[db_id])
    db_id.change(fn=update_schema, inputs=[db_id], outputs=[schema_display])

    # Main pipeline: run the query, then chain two background refreshes of the
    # diagnostics widgets so telemetry stays current after every run.
    run_btn.click(
        fn=run_query,
        inputs=[input_method, sample_dropdown, custom_question, db_id],
        outputs=[final_sql, result_table, explanation]
    ).then(
        fn=task2_dashboard_structured, inputs=[], outputs=[t2_counts, t2_recent, t2_type]
    ).then(
        fn=task2_ops_table, inputs=[], outputs=[t2_ops_tbl, t2_ops_plot]
    )

    clear_btn.click(fn=clear_inputs, inputs=[], outputs=[input_method, sample_dropdown, type_own_warning, custom_question, db_id, final_sql, result_table, explanation])
|
| 550 |
+
|
| 551 |
+
if __name__ == "__main__":
    # Launch the Gradio app, walking forward through a small window of ports
    # so a stale process holding the default port does not block startup.
    server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
    base_port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
    max_retries = 10

    launched = False
    port = base_port
    while port < base_port + max_retries:
        try:
            print(f"Attempting to start Gradio UI on {server_name}:{port}...", flush=True)
            demo.launch(server_name=server_name, server_port=port)
            launched = True  # launch() returned normally — we are done
            break
        except OSError as e:
            port_in_use = "Cannot find empty port" in str(e) or "Address already in use" in str(e)
            if not port_in_use:
                # Any other OSError is unexpected — propagate it.
                raise e
            print(f"⚠️ Port {port} is in use, trying next port...")
            port += 1
    if not launched:
        print(f"❌ Could not find an open port between {base_port} and {base_port + max_retries - 1}.")
|
db.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bb0e6ef12110c4c9808205cb210a35b5c4412397e15f47ed437e739e161d4213
|
| 3 |
+
size 53803466
|
int8_dynamic/meta.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"mode": "int8_dynamic",
|
| 3 |
+
"base_model": "Salesforce/codet5-base",
|
| 4 |
+
"adapter_path": "checkpoints/best_rlhf_model_2",
|
| 5 |
+
"created_at_s": 1774418718.320342,
|
| 6 |
+
"estimated_model_bytes": 98804736
|
| 7 |
+
}
|
int8_dynamic/model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f398e044cd49fc84553b746d26ad79beb1dd565d90cf8f6f5e50d27f48d08228
|
| 3 |
+
size 322871519
|
int8_dynamic/tokenizer/merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
int8_dynamic/tokenizer/special_tokens_map.json
ADDED
|
@@ -0,0 +1,753 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"additional_special_tokens": [
|
| 3 |
+
{
|
| 4 |
+
"content": "<extra_id_99>",
|
| 5 |
+
"lstrip": true,
|
| 6 |
+
"normalized": true,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"content": "<extra_id_98>",
|
| 12 |
+
"lstrip": true,
|
| 13 |
+
"normalized": true,
|
| 14 |
+
"rstrip": false,
|
| 15 |
+
"single_word": false
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"content": "<extra_id_97>",
|
| 19 |
+
"lstrip": true,
|
| 20 |
+
"normalized": true,
|
| 21 |
+
"rstrip": false,
|
| 22 |
+
"single_word": false
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"content": "<extra_id_96>",
|
| 26 |
+
"lstrip": true,
|
| 27 |
+
"normalized": true,
|
| 28 |
+
"rstrip": false,
|
| 29 |
+
"single_word": false
|
| 30 |
+
},
|
| 31 |
+
{
|
| 32 |
+
"content": "<extra_id_95>",
|
| 33 |
+
"lstrip": true,
|
| 34 |
+
"normalized": true,
|
| 35 |
+
"rstrip": false,
|
| 36 |
+
"single_word": false
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"content": "<extra_id_94>",
|
| 40 |
+
"lstrip": true,
|
| 41 |
+
"normalized": true,
|
| 42 |
+
"rstrip": false,
|
| 43 |
+
"single_word": false
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"content": "<extra_id_93>",
|
| 47 |
+
"lstrip": true,
|
| 48 |
+
"normalized": true,
|
| 49 |
+
"rstrip": false,
|
| 50 |
+
"single_word": false
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"content": "<extra_id_92>",
|
| 54 |
+
"lstrip": true,
|
| 55 |
+
"normalized": true,
|
| 56 |
+
"rstrip": false,
|
| 57 |
+
"single_word": false
|
| 58 |
+
},
|
| 59 |
+
{
|
| 60 |
+
"content": "<extra_id_91>",
|
| 61 |
+
"lstrip": true,
|
| 62 |
+
"normalized": true,
|
| 63 |
+
"rstrip": false,
|
| 64 |
+
"single_word": false
|
| 65 |
+
},
|
| 66 |
+
{
|
| 67 |
+
"content": "<extra_id_90>",
|
| 68 |
+
"lstrip": true,
|
| 69 |
+
"normalized": true,
|
| 70 |
+
"rstrip": false,
|
| 71 |
+
"single_word": false
|
| 72 |
+
},
|
| 73 |
+
{
|
| 74 |
+
"content": "<extra_id_89>",
|
| 75 |
+
"lstrip": true,
|
| 76 |
+
"normalized": true,
|
| 77 |
+
"rstrip": false,
|
| 78 |
+
"single_word": false
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"content": "<extra_id_88>",
|
| 82 |
+
"lstrip": true,
|
| 83 |
+
"normalized": true,
|
| 84 |
+
"rstrip": false,
|
| 85 |
+
"single_word": false
|
| 86 |
+
},
|
| 87 |
+
{
|
| 88 |
+
"content": "<extra_id_87>",
|
| 89 |
+
"lstrip": true,
|
| 90 |
+
"normalized": true,
|
| 91 |
+
"rstrip": false,
|
| 92 |
+
"single_word": false
|
| 93 |
+
},
|
| 94 |
+
{
|
| 95 |
+
"content": "<extra_id_86>",
|
| 96 |
+
"lstrip": true,
|
| 97 |
+
"normalized": true,
|
| 98 |
+
"rstrip": false,
|
| 99 |
+
"single_word": false
|
| 100 |
+
},
|
| 101 |
+
{
|
| 102 |
+
"content": "<extra_id_85>",
|
| 103 |
+
"lstrip": true,
|
| 104 |
+
"normalized": true,
|
| 105 |
+
"rstrip": false,
|
| 106 |
+
"single_word": false
|
| 107 |
+
},
|
| 108 |
+
{
|
| 109 |
+
"content": "<extra_id_84>",
|
| 110 |
+
"lstrip": true,
|
| 111 |
+
"normalized": true,
|
| 112 |
+
"rstrip": false,
|
| 113 |
+
"single_word": false
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
"content": "<extra_id_83>",
|
| 117 |
+
"lstrip": true,
|
| 118 |
+
"normalized": true,
|
| 119 |
+
"rstrip": false,
|
| 120 |
+
"single_word": false
|
| 121 |
+
},
|
| 122 |
+
{
|
| 123 |
+
"content": "<extra_id_82>",
|
| 124 |
+
"lstrip": true,
|
| 125 |
+
"normalized": true,
|
| 126 |
+
"rstrip": false,
|
| 127 |
+
"single_word": false
|
| 128 |
+
},
|
| 129 |
+
{
|
| 130 |
+
"content": "<extra_id_81>",
|
| 131 |
+
"lstrip": true,
|
| 132 |
+
"normalized": true,
|
| 133 |
+
"rstrip": false,
|
| 134 |
+
"single_word": false
|
| 135 |
+
},
|
| 136 |
+
{
|
| 137 |
+
"content": "<extra_id_80>",
|
| 138 |
+
"lstrip": true,
|
| 139 |
+
"normalized": true,
|
| 140 |
+
"rstrip": false,
|
| 141 |
+
"single_word": false
|
| 142 |
+
},
|
| 143 |
+
{
|
| 144 |
+
"content": "<extra_id_79>",
|
| 145 |
+
"lstrip": true,
|
| 146 |
+
"normalized": true,
|
| 147 |
+
"rstrip": false,
|
| 148 |
+
"single_word": false
|
| 149 |
+
},
|
| 150 |
+
{
|
| 151 |
+
"content": "<extra_id_78>",
|
| 152 |
+
"lstrip": true,
|
| 153 |
+
"normalized": true,
|
| 154 |
+
"rstrip": false,
|
| 155 |
+
"single_word": false
|
| 156 |
+
},
|
| 157 |
+
{
|
| 158 |
+
"content": "<extra_id_77>",
|
| 159 |
+
"lstrip": true,
|
| 160 |
+
"normalized": true,
|
| 161 |
+
"rstrip": false,
|
| 162 |
+
"single_word": false
|
| 163 |
+
},
|
| 164 |
+
{
|
| 165 |
+
"content": "<extra_id_76>",
|
| 166 |
+
"lstrip": true,
|
| 167 |
+
"normalized": true,
|
| 168 |
+
"rstrip": false,
|
| 169 |
+
"single_word": false
|
| 170 |
+
},
|
| 171 |
+
{
|
| 172 |
+
"content": "<extra_id_75>",
|
| 173 |
+
"lstrip": true,
|
| 174 |
+
"normalized": true,
|
| 175 |
+
"rstrip": false,
|
| 176 |
+
"single_word": false
|
| 177 |
+
},
|
| 178 |
+
{
|
| 179 |
+
"content": "<extra_id_74>",
|
| 180 |
+
"lstrip": true,
|
| 181 |
+
"normalized": true,
|
| 182 |
+
"rstrip": false,
|
| 183 |
+
"single_word": false
|
| 184 |
+
},
|
| 185 |
+
{
|
| 186 |
+
"content": "<extra_id_73>",
|
| 187 |
+
"lstrip": true,
|
| 188 |
+
"normalized": true,
|
| 189 |
+
"rstrip": false,
|
| 190 |
+
"single_word": false
|
| 191 |
+
},
|
| 192 |
+
{
|
| 193 |
+
"content": "<extra_id_72>",
|
| 194 |
+
"lstrip": true,
|
| 195 |
+
"normalized": true,
|
| 196 |
+
"rstrip": false,
|
| 197 |
+
"single_word": false
|
| 198 |
+
},
|
| 199 |
+
{
|
| 200 |
+
"content": "<extra_id_71>",
|
| 201 |
+
"lstrip": true,
|
| 202 |
+
"normalized": true,
|
| 203 |
+
"rstrip": false,
|
| 204 |
+
"single_word": false
|
| 205 |
+
},
|
| 206 |
+
{
|
| 207 |
+
"content": "<extra_id_70>",
|
| 208 |
+
"lstrip": true,
|
| 209 |
+
"normalized": true,
|
| 210 |
+
"rstrip": false,
|
| 211 |
+
"single_word": false
|
| 212 |
+
},
|
| 213 |
+
{
|
| 214 |
+
"content": "<extra_id_69>",
|
| 215 |
+
"lstrip": true,
|
| 216 |
+
"normalized": true,
|
| 217 |
+
"rstrip": false,
|
| 218 |
+
"single_word": false
|
| 219 |
+
},
|
| 220 |
+
{
|
| 221 |
+
"content": "<extra_id_68>",
|
| 222 |
+
"lstrip": true,
|
| 223 |
+
"normalized": true,
|
| 224 |
+
"rstrip": false,
|
| 225 |
+
"single_word": false
|
| 226 |
+
},
|
| 227 |
+
{
|
| 228 |
+
"content": "<extra_id_67>",
|
| 229 |
+
"lstrip": true,
|
| 230 |
+
"normalized": true,
|
| 231 |
+
"rstrip": false,
|
| 232 |
+
"single_word": false
|
| 233 |
+
},
|
| 234 |
+
{
|
| 235 |
+
"content": "<extra_id_66>",
|
| 236 |
+
"lstrip": true,
|
| 237 |
+
"normalized": true,
|
| 238 |
+
"rstrip": false,
|
| 239 |
+
"single_word": false
|
| 240 |
+
},
|
| 241 |
+
{
|
| 242 |
+
"content": "<extra_id_65>",
|
| 243 |
+
"lstrip": true,
|
| 244 |
+
"normalized": true,
|
| 245 |
+
"rstrip": false,
|
| 246 |
+
"single_word": false
|
| 247 |
+
},
|
| 248 |
+
{
|
| 249 |
+
"content": "<extra_id_64>",
|
| 250 |
+
"lstrip": true,
|
| 251 |
+
"normalized": true,
|
| 252 |
+
"rstrip": false,
|
| 253 |
+
"single_word": false
|
| 254 |
+
},
|
| 255 |
+
{
|
| 256 |
+
"content": "<extra_id_63>",
|
| 257 |
+
"lstrip": true,
|
| 258 |
+
"normalized": true,
|
| 259 |
+
"rstrip": false,
|
| 260 |
+
"single_word": false
|
| 261 |
+
},
|
| 262 |
+
{
|
| 263 |
+
"content": "<extra_id_62>",
|
| 264 |
+
"lstrip": true,
|
| 265 |
+
"normalized": true,
|
| 266 |
+
"rstrip": false,
|
| 267 |
+
"single_word": false
|
| 268 |
+
},
|
| 269 |
+
{
|
| 270 |
+
"content": "<extra_id_61>",
|
| 271 |
+
"lstrip": true,
|
| 272 |
+
"normalized": true,
|
| 273 |
+
"rstrip": false,
|
| 274 |
+
"single_word": false
|
| 275 |
+
},
|
| 276 |
+
{
|
| 277 |
+
"content": "<extra_id_60>",
|
| 278 |
+
"lstrip": true,
|
| 279 |
+
"normalized": true,
|
| 280 |
+
"rstrip": false,
|
| 281 |
+
"single_word": false
|
| 282 |
+
},
|
| 283 |
+
{
|
| 284 |
+
"content": "<extra_id_59>",
|
| 285 |
+
"lstrip": true,
|
| 286 |
+
"normalized": true,
|
| 287 |
+
"rstrip": false,
|
| 288 |
+
"single_word": false
|
| 289 |
+
},
|
| 290 |
+
{
|
| 291 |
+
"content": "<extra_id_58>",
|
| 292 |
+
"lstrip": true,
|
| 293 |
+
"normalized": true,
|
| 294 |
+
"rstrip": false,
|
| 295 |
+
"single_word": false
|
| 296 |
+
},
|
| 297 |
+
{
|
| 298 |
+
"content": "<extra_id_57>",
|
| 299 |
+
"lstrip": true,
|
| 300 |
+
"normalized": true,
|
| 301 |
+
"rstrip": false,
|
| 302 |
+
"single_word": false
|
| 303 |
+
},
|
| 304 |
+
{
|
| 305 |
+
"content": "<extra_id_56>",
|
| 306 |
+
"lstrip": true,
|
| 307 |
+
"normalized": true,
|
| 308 |
+
"rstrip": false,
|
| 309 |
+
"single_word": false
|
| 310 |
+
},
|
| 311 |
+
{
|
| 312 |
+
"content": "<extra_id_55>",
|
| 313 |
+
"lstrip": true,
|
| 314 |
+
"normalized": true,
|
| 315 |
+
"rstrip": false,
|
| 316 |
+
"single_word": false
|
| 317 |
+
},
|
| 318 |
+
{
|
| 319 |
+
"content": "<extra_id_54>",
|
| 320 |
+
"lstrip": true,
|
| 321 |
+
"normalized": true,
|
| 322 |
+
"rstrip": false,
|
| 323 |
+
"single_word": false
|
| 324 |
+
},
|
| 325 |
+
{
|
| 326 |
+
"content": "<extra_id_53>",
|
| 327 |
+
"lstrip": true,
|
| 328 |
+
"normalized": true,
|
| 329 |
+
"rstrip": false,
|
| 330 |
+
"single_word": false
|
| 331 |
+
},
|
| 332 |
+
{
|
| 333 |
+
"content": "<extra_id_52>",
|
| 334 |
+
"lstrip": true,
|
| 335 |
+
"normalized": true,
|
| 336 |
+
"rstrip": false,
|
| 337 |
+
"single_word": false
|
| 338 |
+
},
|
| 339 |
+
{
|
| 340 |
+
"content": "<extra_id_51>",
|
| 341 |
+
"lstrip": true,
|
| 342 |
+
"normalized": true,
|
| 343 |
+
"rstrip": false,
|
| 344 |
+
"single_word": false
|
| 345 |
+
},
|
| 346 |
+
{
|
| 347 |
+
"content": "<extra_id_50>",
|
| 348 |
+
"lstrip": true,
|
| 349 |
+
"normalized": true,
|
| 350 |
+
"rstrip": false,
|
| 351 |
+
"single_word": false
|
| 352 |
+
},
|
| 353 |
+
{
|
| 354 |
+
"content": "<extra_id_49>",
|
| 355 |
+
"lstrip": true,
|
| 356 |
+
"normalized": true,
|
| 357 |
+
"rstrip": false,
|
| 358 |
+
"single_word": false
|
| 359 |
+
},
|
| 360 |
+
{
|
| 361 |
+
"content": "<extra_id_48>",
|
| 362 |
+
"lstrip": true,
|
| 363 |
+
"normalized": true,
|
| 364 |
+
"rstrip": false,
|
| 365 |
+
"single_word": false
|
| 366 |
+
},
|
| 367 |
+
{
|
| 368 |
+
"content": "<extra_id_47>",
|
| 369 |
+
"lstrip": true,
|
| 370 |
+
"normalized": true,
|
| 371 |
+
"rstrip": false,
|
| 372 |
+
"single_word": false
|
| 373 |
+
},
|
| 374 |
+
{
|
| 375 |
+
"content": "<extra_id_46>",
|
| 376 |
+
"lstrip": true,
|
| 377 |
+
"normalized": true,
|
| 378 |
+
"rstrip": false,
|
| 379 |
+
"single_word": false
|
| 380 |
+
},
|
| 381 |
+
{
|
| 382 |
+
"content": "<extra_id_45>",
|
| 383 |
+
"lstrip": true,
|
| 384 |
+
"normalized": true,
|
| 385 |
+
"rstrip": false,
|
| 386 |
+
"single_word": false
|
| 387 |
+
},
|
| 388 |
+
{
|
| 389 |
+
"content": "<extra_id_44>",
|
| 390 |
+
"lstrip": true,
|
| 391 |
+
"normalized": true,
|
| 392 |
+
"rstrip": false,
|
| 393 |
+
"single_word": false
|
| 394 |
+
},
|
| 395 |
+
{
|
| 396 |
+
"content": "<extra_id_43>",
|
| 397 |
+
"lstrip": true,
|
| 398 |
+
"normalized": true,
|
| 399 |
+
"rstrip": false,
|
| 400 |
+
"single_word": false
|
| 401 |
+
},
|
| 402 |
+
{
|
| 403 |
+
"content": "<extra_id_42>",
|
| 404 |
+
"lstrip": true,
|
| 405 |
+
"normalized": true,
|
| 406 |
+
"rstrip": false,
|
| 407 |
+
"single_word": false
|
| 408 |
+
},
|
| 409 |
+
{
|
| 410 |
+
"content": "<extra_id_41>",
|
| 411 |
+
"lstrip": true,
|
| 412 |
+
"normalized": true,
|
| 413 |
+
"rstrip": false,
|
| 414 |
+
"single_word": false
|
| 415 |
+
},
|
| 416 |
+
{
|
| 417 |
+
"content": "<extra_id_40>",
|
| 418 |
+
"lstrip": true,
|
| 419 |
+
"normalized": true,
|
| 420 |
+
"rstrip": false,
|
| 421 |
+
"single_word": false
|
| 422 |
+
},
|
| 423 |
+
{
|
| 424 |
+
"content": "<extra_id_39>",
|
| 425 |
+
"lstrip": true,
|
| 426 |
+
"normalized": true,
|
| 427 |
+
"rstrip": false,
|
| 428 |
+
"single_word": false
|
| 429 |
+
},
|
| 430 |
+
{
|
| 431 |
+
"content": "<extra_id_38>",
|
| 432 |
+
"lstrip": true,
|
| 433 |
+
"normalized": true,
|
| 434 |
+
"rstrip": false,
|
| 435 |
+
"single_word": false
|
| 436 |
+
},
|
| 437 |
+
{
|
| 438 |
+
"content": "<extra_id_37>",
|
| 439 |
+
"lstrip": true,
|
| 440 |
+
"normalized": true,
|
| 441 |
+
"rstrip": false,
|
| 442 |
+
"single_word": false
|
| 443 |
+
},
|
| 444 |
+
{
|
| 445 |
+
"content": "<extra_id_36>",
|
| 446 |
+
"lstrip": true,
|
| 447 |
+
"normalized": true,
|
| 448 |
+
"rstrip": false,
|
| 449 |
+
"single_word": false
|
| 450 |
+
},
|
| 451 |
+
{
|
| 452 |
+
"content": "<extra_id_35>",
|
| 453 |
+
"lstrip": true,
|
| 454 |
+
"normalized": true,
|
| 455 |
+
"rstrip": false,
|
| 456 |
+
"single_word": false
|
| 457 |
+
},
|
| 458 |
+
{
|
| 459 |
+
"content": "<extra_id_34>",
|
| 460 |
+
"lstrip": true,
|
| 461 |
+
"normalized": true,
|
| 462 |
+
"rstrip": false,
|
| 463 |
+
"single_word": false
|
| 464 |
+
},
|
| 465 |
+
{
|
| 466 |
+
"content": "<extra_id_33>",
|
| 467 |
+
"lstrip": true,
|
| 468 |
+
"normalized": true,
|
| 469 |
+
"rstrip": false,
|
| 470 |
+
"single_word": false
|
| 471 |
+
},
|
| 472 |
+
{
|
| 473 |
+
"content": "<extra_id_32>",
|
| 474 |
+
"lstrip": true,
|
| 475 |
+
"normalized": true,
|
| 476 |
+
"rstrip": false,
|
| 477 |
+
"single_word": false
|
| 478 |
+
},
|
| 479 |
+
{
|
| 480 |
+
"content": "<extra_id_31>",
|
| 481 |
+
"lstrip": true,
|
| 482 |
+
"normalized": true,
|
| 483 |
+
"rstrip": false,
|
| 484 |
+
"single_word": false
|
| 485 |
+
},
|
| 486 |
+
{
|
| 487 |
+
"content": "<extra_id_30>",
|
| 488 |
+
"lstrip": true,
|
| 489 |
+
"normalized": true,
|
| 490 |
+
"rstrip": false,
|
| 491 |
+
"single_word": false
|
| 492 |
+
},
|
| 493 |
+
{
|
| 494 |
+
"content": "<extra_id_29>",
|
| 495 |
+
"lstrip": true,
|
| 496 |
+
"normalized": true,
|
| 497 |
+
"rstrip": false,
|
| 498 |
+
"single_word": false
|
| 499 |
+
},
|
| 500 |
+
{
|
| 501 |
+
"content": "<extra_id_28>",
|
| 502 |
+
"lstrip": true,
|
| 503 |
+
"normalized": true,
|
| 504 |
+
"rstrip": false,
|
| 505 |
+
"single_word": false
|
| 506 |
+
},
|
| 507 |
+
{
|
| 508 |
+
"content": "<extra_id_27>",
|
| 509 |
+
"lstrip": true,
|
| 510 |
+
"normalized": true,
|
| 511 |
+
"rstrip": false,
|
| 512 |
+
"single_word": false
|
| 513 |
+
},
|
| 514 |
+
{
|
| 515 |
+
"content": "<extra_id_26>",
|
| 516 |
+
"lstrip": true,
|
| 517 |
+
"normalized": true,
|
| 518 |
+
"rstrip": false,
|
| 519 |
+
"single_word": false
|
| 520 |
+
},
|
| 521 |
+
{
|
| 522 |
+
"content": "<extra_id_25>",
|
| 523 |
+
"lstrip": true,
|
| 524 |
+
"normalized": true,
|
| 525 |
+
"rstrip": false,
|
| 526 |
+
"single_word": false
|
| 527 |
+
},
|
| 528 |
+
{
|
| 529 |
+
"content": "<extra_id_24>",
|
| 530 |
+
"lstrip": true,
|
| 531 |
+
"normalized": true,
|
| 532 |
+
"rstrip": false,
|
| 533 |
+
"single_word": false
|
| 534 |
+
},
|
| 535 |
+
{
|
| 536 |
+
"content": "<extra_id_23>",
|
| 537 |
+
"lstrip": true,
|
| 538 |
+
"normalized": true,
|
| 539 |
+
"rstrip": false,
|
| 540 |
+
"single_word": false
|
| 541 |
+
},
|
| 542 |
+
{
|
| 543 |
+
"content": "<extra_id_22>",
|
| 544 |
+
"lstrip": true,
|
| 545 |
+
"normalized": true,
|
| 546 |
+
"rstrip": false,
|
| 547 |
+
"single_word": false
|
| 548 |
+
},
|
| 549 |
+
{
|
| 550 |
+
"content": "<extra_id_21>",
|
| 551 |
+
"lstrip": true,
|
| 552 |
+
"normalized": true,
|
| 553 |
+
"rstrip": false,
|
| 554 |
+
"single_word": false
|
| 555 |
+
},
|
| 556 |
+
{
|
| 557 |
+
"content": "<extra_id_20>",
|
| 558 |
+
"lstrip": true,
|
| 559 |
+
"normalized": true,
|
| 560 |
+
"rstrip": false,
|
| 561 |
+
"single_word": false
|
| 562 |
+
},
|
| 563 |
+
{
|
| 564 |
+
"content": "<extra_id_19>",
|
| 565 |
+
"lstrip": true,
|
| 566 |
+
"normalized": true,
|
| 567 |
+
"rstrip": false,
|
| 568 |
+
"single_word": false
|
| 569 |
+
},
|
| 570 |
+
{
|
| 571 |
+
"content": "<extra_id_18>",
|
| 572 |
+
"lstrip": true,
|
| 573 |
+
"normalized": true,
|
| 574 |
+
"rstrip": false,
|
| 575 |
+
"single_word": false
|
| 576 |
+
},
|
| 577 |
+
{
|
| 578 |
+
"content": "<extra_id_17>",
|
| 579 |
+
"lstrip": true,
|
| 580 |
+
"normalized": true,
|
| 581 |
+
"rstrip": false,
|
| 582 |
+
"single_word": false
|
| 583 |
+
},
|
| 584 |
+
{
|
| 585 |
+
"content": "<extra_id_16>",
|
| 586 |
+
"lstrip": true,
|
| 587 |
+
"normalized": true,
|
| 588 |
+
"rstrip": false,
|
| 589 |
+
"single_word": false
|
| 590 |
+
},
|
| 591 |
+
{
|
| 592 |
+
"content": "<extra_id_15>",
|
| 593 |
+
"lstrip": true,
|
| 594 |
+
"normalized": true,
|
| 595 |
+
"rstrip": false,
|
| 596 |
+
"single_word": false
|
| 597 |
+
},
|
| 598 |
+
{
|
| 599 |
+
"content": "<extra_id_14>",
|
| 600 |
+
"lstrip": true,
|
| 601 |
+
"normalized": true,
|
| 602 |
+
"rstrip": false,
|
| 603 |
+
"single_word": false
|
| 604 |
+
},
|
| 605 |
+
{
|
| 606 |
+
"content": "<extra_id_13>",
|
| 607 |
+
"lstrip": true,
|
| 608 |
+
"normalized": true,
|
| 609 |
+
"rstrip": false,
|
| 610 |
+
"single_word": false
|
| 611 |
+
},
|
| 612 |
+
{
|
| 613 |
+
"content": "<extra_id_12>",
|
| 614 |
+
"lstrip": true,
|
| 615 |
+
"normalized": true,
|
| 616 |
+
"rstrip": false,
|
| 617 |
+
"single_word": false
|
| 618 |
+
},
|
| 619 |
+
{
|
| 620 |
+
"content": "<extra_id_11>",
|
| 621 |
+
"lstrip": true,
|
| 622 |
+
"normalized": true,
|
| 623 |
+
"rstrip": false,
|
| 624 |
+
"single_word": false
|
| 625 |
+
},
|
| 626 |
+
{
|
| 627 |
+
"content": "<extra_id_10>",
|
| 628 |
+
"lstrip": true,
|
| 629 |
+
"normalized": true,
|
| 630 |
+
"rstrip": false,
|
| 631 |
+
"single_word": false
|
| 632 |
+
},
|
| 633 |
+
{
|
| 634 |
+
"content": "<extra_id_9>",
|
| 635 |
+
"lstrip": true,
|
| 636 |
+
"normalized": true,
|
| 637 |
+
"rstrip": false,
|
| 638 |
+
"single_word": false
|
| 639 |
+
},
|
| 640 |
+
{
|
| 641 |
+
"content": "<extra_id_8>",
|
| 642 |
+
"lstrip": true,
|
| 643 |
+
"normalized": true,
|
| 644 |
+
"rstrip": false,
|
| 645 |
+
"single_word": false
|
| 646 |
+
},
|
| 647 |
+
{
|
| 648 |
+
"content": "<extra_id_7>",
|
| 649 |
+
"lstrip": true,
|
| 650 |
+
"normalized": true,
|
| 651 |
+
"rstrip": false,
|
| 652 |
+
"single_word": false
|
| 653 |
+
},
|
| 654 |
+
{
|
| 655 |
+
"content": "<extra_id_6>",
|
| 656 |
+
"lstrip": true,
|
| 657 |
+
"normalized": true,
|
| 658 |
+
"rstrip": false,
|
| 659 |
+
"single_word": false
|
| 660 |
+
},
|
| 661 |
+
{
|
| 662 |
+
"content": "<extra_id_5>",
|
| 663 |
+
"lstrip": true,
|
| 664 |
+
"normalized": true,
|
| 665 |
+
"rstrip": false,
|
| 666 |
+
"single_word": false
|
| 667 |
+
},
|
| 668 |
+
{
|
| 669 |
+
"content": "<extra_id_4>",
|
| 670 |
+
"lstrip": true,
|
| 671 |
+
"normalized": true,
|
| 672 |
+
"rstrip": false,
|
| 673 |
+
"single_word": false
|
| 674 |
+
},
|
| 675 |
+
{
|
| 676 |
+
"content": "<extra_id_3>",
|
| 677 |
+
"lstrip": true,
|
| 678 |
+
"normalized": true,
|
| 679 |
+
"rstrip": false,
|
| 680 |
+
"single_word": false
|
| 681 |
+
},
|
| 682 |
+
{
|
| 683 |
+
"content": "<extra_id_2>",
|
| 684 |
+
"lstrip": true,
|
| 685 |
+
"normalized": true,
|
| 686 |
+
"rstrip": false,
|
| 687 |
+
"single_word": false
|
| 688 |
+
},
|
| 689 |
+
{
|
| 690 |
+
"content": "<extra_id_1>",
|
| 691 |
+
"lstrip": true,
|
| 692 |
+
"normalized": true,
|
| 693 |
+
"rstrip": false,
|
| 694 |
+
"single_word": false
|
| 695 |
+
},
|
| 696 |
+
{
|
| 697 |
+
"content": "<extra_id_0>",
|
| 698 |
+
"lstrip": true,
|
| 699 |
+
"normalized": true,
|
| 700 |
+
"rstrip": false,
|
| 701 |
+
"single_word": false
|
| 702 |
+
}
|
| 703 |
+
],
|
| 704 |
+
"bos_token": {
|
| 705 |
+
"content": "<s>",
|
| 706 |
+
"lstrip": false,
|
| 707 |
+
"normalized": true,
|
| 708 |
+
"rstrip": false,
|
| 709 |
+
"single_word": false
|
| 710 |
+
},
|
| 711 |
+
"cls_token": {
|
| 712 |
+
"content": "<s>",
|
| 713 |
+
"lstrip": false,
|
| 714 |
+
"normalized": true,
|
| 715 |
+
"rstrip": false,
|
| 716 |
+
"single_word": false
|
| 717 |
+
},
|
| 718 |
+
"eos_token": {
|
| 719 |
+
"content": "</s>",
|
| 720 |
+
"lstrip": false,
|
| 721 |
+
"normalized": true,
|
| 722 |
+
"rstrip": false,
|
| 723 |
+
"single_word": false
|
| 724 |
+
},
|
| 725 |
+
"mask_token": {
|
| 726 |
+
"content": "<mask>",
|
| 727 |
+
"lstrip": true,
|
| 728 |
+
"normalized": true,
|
| 729 |
+
"rstrip": false,
|
| 730 |
+
"single_word": false
|
| 731 |
+
},
|
| 732 |
+
"pad_token": {
|
| 733 |
+
"content": "<pad>",
|
| 734 |
+
"lstrip": false,
|
| 735 |
+
"normalized": true,
|
| 736 |
+
"rstrip": false,
|
| 737 |
+
"single_word": false
|
| 738 |
+
},
|
| 739 |
+
"sep_token": {
|
| 740 |
+
"content": "</s>",
|
| 741 |
+
"lstrip": false,
|
| 742 |
+
"normalized": true,
|
| 743 |
+
"rstrip": false,
|
| 744 |
+
"single_word": false
|
| 745 |
+
},
|
| 746 |
+
"unk_token": {
|
| 747 |
+
"content": "<unk>",
|
| 748 |
+
"lstrip": false,
|
| 749 |
+
"normalized": true,
|
| 750 |
+
"rstrip": false,
|
| 751 |
+
"single_word": false
|
| 752 |
+
}
|
| 753 |
+
}
|
int8_dynamic/tokenizer/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
int8_dynamic/tokenizer/tokenizer_config.json
ADDED
|
@@ -0,0 +1,959 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": false,
|
| 3 |
+
"added_tokens_decoder": {
|
| 4 |
+
"0": {
|
| 5 |
+
"content": "<pad>",
|
| 6 |
+
"lstrip": false,
|
| 7 |
+
"normalized": true,
|
| 8 |
+
"rstrip": false,
|
| 9 |
+
"single_word": false,
|
| 10 |
+
"special": true
|
| 11 |
+
},
|
| 12 |
+
"1": {
|
| 13 |
+
"content": "<s>",
|
| 14 |
+
"lstrip": false,
|
| 15 |
+
"normalized": true,
|
| 16 |
+
"rstrip": false,
|
| 17 |
+
"single_word": false,
|
| 18 |
+
"special": true
|
| 19 |
+
},
|
| 20 |
+
"2": {
|
| 21 |
+
"content": "</s>",
|
| 22 |
+
"lstrip": false,
|
| 23 |
+
"normalized": true,
|
| 24 |
+
"rstrip": false,
|
| 25 |
+
"single_word": false,
|
| 26 |
+
"special": true
|
| 27 |
+
},
|
| 28 |
+
"3": {
|
| 29 |
+
"content": "<unk>",
|
| 30 |
+
"lstrip": false,
|
| 31 |
+
"normalized": true,
|
| 32 |
+
"rstrip": false,
|
| 33 |
+
"single_word": false,
|
| 34 |
+
"special": true
|
| 35 |
+
},
|
| 36 |
+
"4": {
|
| 37 |
+
"content": "<mask>",
|
| 38 |
+
"lstrip": true,
|
| 39 |
+
"normalized": true,
|
| 40 |
+
"rstrip": false,
|
| 41 |
+
"single_word": false,
|
| 42 |
+
"special": true
|
| 43 |
+
},
|
| 44 |
+
"32000": {
|
| 45 |
+
"content": "<extra_id_99>",
|
| 46 |
+
"lstrip": true,
|
| 47 |
+
"normalized": true,
|
| 48 |
+
"rstrip": false,
|
| 49 |
+
"single_word": false,
|
| 50 |
+
"special": true
|
| 51 |
+
},
|
| 52 |
+
"32001": {
|
| 53 |
+
"content": "<extra_id_98>",
|
| 54 |
+
"lstrip": true,
|
| 55 |
+
"normalized": true,
|
| 56 |
+
"rstrip": false,
|
| 57 |
+
"single_word": false,
|
| 58 |
+
"special": true
|
| 59 |
+
},
|
| 60 |
+
"32002": {
|
| 61 |
+
"content": "<extra_id_97>",
|
| 62 |
+
"lstrip": true,
|
| 63 |
+
"normalized": true,
|
| 64 |
+
"rstrip": false,
|
| 65 |
+
"single_word": false,
|
| 66 |
+
"special": true
|
| 67 |
+
},
|
| 68 |
+
"32003": {
|
| 69 |
+
"content": "<extra_id_96>",
|
| 70 |
+
"lstrip": true,
|
| 71 |
+
"normalized": true,
|
| 72 |
+
"rstrip": false,
|
| 73 |
+
"single_word": false,
|
| 74 |
+
"special": true
|
| 75 |
+
},
|
| 76 |
+
"32004": {
|
| 77 |
+
"content": "<extra_id_95>",
|
| 78 |
+
"lstrip": true,
|
| 79 |
+
"normalized": true,
|
| 80 |
+
"rstrip": false,
|
| 81 |
+
"single_word": false,
|
| 82 |
+
"special": true
|
| 83 |
+
},
|
| 84 |
+
"32005": {
|
| 85 |
+
"content": "<extra_id_94>",
|
| 86 |
+
"lstrip": true,
|
| 87 |
+
"normalized": true,
|
| 88 |
+
"rstrip": false,
|
| 89 |
+
"single_word": false,
|
| 90 |
+
"special": true
|
| 91 |
+
},
|
| 92 |
+
"32006": {
|
| 93 |
+
"content": "<extra_id_93>",
|
| 94 |
+
"lstrip": true,
|
| 95 |
+
"normalized": true,
|
| 96 |
+
"rstrip": false,
|
| 97 |
+
"single_word": false,
|
| 98 |
+
"special": true
|
| 99 |
+
},
|
| 100 |
+
"32007": {
|
| 101 |
+
"content": "<extra_id_92>",
|
| 102 |
+
"lstrip": true,
|
| 103 |
+
"normalized": true,
|
| 104 |
+
"rstrip": false,
|
| 105 |
+
"single_word": false,
|
| 106 |
+
"special": true
|
| 107 |
+
},
|
| 108 |
+
"32008": {
|
| 109 |
+
"content": "<extra_id_91>",
|
| 110 |
+
"lstrip": true,
|
| 111 |
+
"normalized": true,
|
| 112 |
+
"rstrip": false,
|
| 113 |
+
"single_word": false,
|
| 114 |
+
"special": true
|
| 115 |
+
},
|
| 116 |
+
"32009": {
|
| 117 |
+
"content": "<extra_id_90>",
|
| 118 |
+
"lstrip": true,
|
| 119 |
+
"normalized": true,
|
| 120 |
+
"rstrip": false,
|
| 121 |
+
"single_word": false,
|
| 122 |
+
"special": true
|
| 123 |
+
},
|
| 124 |
+
"32010": {
|
| 125 |
+
"content": "<extra_id_89>",
|
| 126 |
+
"lstrip": true,
|
| 127 |
+
"normalized": true,
|
| 128 |
+
"rstrip": false,
|
| 129 |
+
"single_word": false,
|
| 130 |
+
"special": true
|
| 131 |
+
},
|
| 132 |
+
"32011": {
|
| 133 |
+
"content": "<extra_id_88>",
|
| 134 |
+
"lstrip": true,
|
| 135 |
+
"normalized": true,
|
| 136 |
+
"rstrip": false,
|
| 137 |
+
"single_word": false,
|
| 138 |
+
"special": true
|
| 139 |
+
},
|
| 140 |
+
"32012": {
|
| 141 |
+
"content": "<extra_id_87>",
|
| 142 |
+
"lstrip": true,
|
| 143 |
+
"normalized": true,
|
| 144 |
+
"rstrip": false,
|
| 145 |
+
"single_word": false,
|
| 146 |
+
"special": true
|
| 147 |
+
},
|
| 148 |
+
"32013": {
|
| 149 |
+
"content": "<extra_id_86>",
|
| 150 |
+
"lstrip": true,
|
| 151 |
+
"normalized": true,
|
| 152 |
+
"rstrip": false,
|
| 153 |
+
"single_word": false,
|
| 154 |
+
"special": true
|
| 155 |
+
},
|
| 156 |
+
"32014": {
|
| 157 |
+
"content": "<extra_id_85>",
|
| 158 |
+
"lstrip": true,
|
| 159 |
+
"normalized": true,
|
| 160 |
+
"rstrip": false,
|
| 161 |
+
"single_word": false,
|
| 162 |
+
"special": true
|
| 163 |
+
},
|
| 164 |
+
"32015": {
|
| 165 |
+
"content": "<extra_id_84>",
|
| 166 |
+
"lstrip": true,
|
| 167 |
+
"normalized": true,
|
| 168 |
+
"rstrip": false,
|
| 169 |
+
"single_word": false,
|
| 170 |
+
"special": true
|
| 171 |
+
},
|
| 172 |
+
"32016": {
|
| 173 |
+
"content": "<extra_id_83>",
|
| 174 |
+
"lstrip": true,
|
| 175 |
+
"normalized": true,
|
| 176 |
+
"rstrip": false,
|
| 177 |
+
"single_word": false,
|
| 178 |
+
"special": true
|
| 179 |
+
},
|
| 180 |
+
"32017": {
|
| 181 |
+
"content": "<extra_id_82>",
|
| 182 |
+
"lstrip": true,
|
| 183 |
+
"normalized": true,
|
| 184 |
+
"rstrip": false,
|
| 185 |
+
"single_word": false,
|
| 186 |
+
"special": true
|
| 187 |
+
},
|
| 188 |
+
"32018": {
|
| 189 |
+
"content": "<extra_id_81>",
|
| 190 |
+
"lstrip": true,
|
| 191 |
+
"normalized": true,
|
| 192 |
+
"rstrip": false,
|
| 193 |
+
"single_word": false,
|
| 194 |
+
"special": true
|
| 195 |
+
},
|
| 196 |
+
"32019": {
|
| 197 |
+
"content": "<extra_id_80>",
|
| 198 |
+
"lstrip": true,
|
| 199 |
+
"normalized": true,
|
| 200 |
+
"rstrip": false,
|
| 201 |
+
"single_word": false,
|
| 202 |
+
"special": true
|
| 203 |
+
},
|
| 204 |
+
"32020": {
|
| 205 |
+
"content": "<extra_id_79>",
|
| 206 |
+
"lstrip": true,
|
| 207 |
+
"normalized": true,
|
| 208 |
+
"rstrip": false,
|
| 209 |
+
"single_word": false,
|
| 210 |
+
"special": true
|
| 211 |
+
},
|
| 212 |
+
"32021": {
|
| 213 |
+
"content": "<extra_id_78>",
|
| 214 |
+
"lstrip": true,
|
| 215 |
+
"normalized": true,
|
| 216 |
+
"rstrip": false,
|
| 217 |
+
"single_word": false,
|
| 218 |
+
"special": true
|
| 219 |
+
},
|
| 220 |
+
"32022": {
|
| 221 |
+
"content": "<extra_id_77>",
|
| 222 |
+
"lstrip": true,
|
| 223 |
+
"normalized": true,
|
| 224 |
+
"rstrip": false,
|
| 225 |
+
"single_word": false,
|
| 226 |
+
"special": true
|
| 227 |
+
},
|
| 228 |
+
"32023": {
|
| 229 |
+
"content": "<extra_id_76>",
|
| 230 |
+
"lstrip": true,
|
| 231 |
+
"normalized": true,
|
| 232 |
+
"rstrip": false,
|
| 233 |
+
"single_word": false,
|
| 234 |
+
"special": true
|
| 235 |
+
},
|
| 236 |
+
"32024": {
|
| 237 |
+
"content": "<extra_id_75>",
|
| 238 |
+
"lstrip": true,
|
| 239 |
+
"normalized": true,
|
| 240 |
+
"rstrip": false,
|
| 241 |
+
"single_word": false,
|
| 242 |
+
"special": true
|
| 243 |
+
},
|
| 244 |
+
"32025": {
|
| 245 |
+
"content": "<extra_id_74>",
|
| 246 |
+
"lstrip": true,
|
| 247 |
+
"normalized": true,
|
| 248 |
+
"rstrip": false,
|
| 249 |
+
"single_word": false,
|
| 250 |
+
"special": true
|
| 251 |
+
},
|
| 252 |
+
"32026": {
|
| 253 |
+
"content": "<extra_id_73>",
|
| 254 |
+
"lstrip": true,
|
| 255 |
+
"normalized": true,
|
| 256 |
+
"rstrip": false,
|
| 257 |
+
"single_word": false,
|
| 258 |
+
"special": true
|
| 259 |
+
},
|
| 260 |
+
"32027": {
|
| 261 |
+
"content": "<extra_id_72>",
|
| 262 |
+
"lstrip": true,
|
| 263 |
+
"normalized": true,
|
| 264 |
+
"rstrip": false,
|
| 265 |
+
"single_word": false,
|
| 266 |
+
"special": true
|
| 267 |
+
},
|
| 268 |
+
"32028": {
|
| 269 |
+
"content": "<extra_id_71>",
|
| 270 |
+
"lstrip": true,
|
| 271 |
+
"normalized": true,
|
| 272 |
+
"rstrip": false,
|
| 273 |
+
"single_word": false,
|
| 274 |
+
"special": true
|
| 275 |
+
},
|
| 276 |
+
"32029": {
|
| 277 |
+
"content": "<extra_id_70>",
|
| 278 |
+
"lstrip": true,
|
| 279 |
+
"normalized": true,
|
| 280 |
+
"rstrip": false,
|
| 281 |
+
"single_word": false,
|
| 282 |
+
"special": true
|
| 283 |
+
},
|
| 284 |
+
"32030": {
|
| 285 |
+
"content": "<extra_id_69>",
|
| 286 |
+
"lstrip": true,
|
| 287 |
+
"normalized": true,
|
| 288 |
+
"rstrip": false,
|
| 289 |
+
"single_word": false,
|
| 290 |
+
"special": true
|
| 291 |
+
},
|
| 292 |
+
"32031": {
|
| 293 |
+
"content": "<extra_id_68>",
|
| 294 |
+
"lstrip": true,
|
| 295 |
+
"normalized": true,
|
| 296 |
+
"rstrip": false,
|
| 297 |
+
"single_word": false,
|
| 298 |
+
"special": true
|
| 299 |
+
},
|
| 300 |
+
"32032": {
|
| 301 |
+
"content": "<extra_id_67>",
|
| 302 |
+
"lstrip": true,
|
| 303 |
+
"normalized": true,
|
| 304 |
+
"rstrip": false,
|
| 305 |
+
"single_word": false,
|
| 306 |
+
"special": true
|
| 307 |
+
},
|
| 308 |
+
"32033": {
|
| 309 |
+
"content": "<extra_id_66>",
|
| 310 |
+
"lstrip": true,
|
| 311 |
+
"normalized": true,
|
| 312 |
+
"rstrip": false,
|
| 313 |
+
"single_word": false,
|
| 314 |
+
"special": true
|
| 315 |
+
},
|
| 316 |
+
"32034": {
|
| 317 |
+
"content": "<extra_id_65>",
|
| 318 |
+
"lstrip": true,
|
| 319 |
+
"normalized": true,
|
| 320 |
+
"rstrip": false,
|
| 321 |
+
"single_word": false,
|
| 322 |
+
"special": true
|
| 323 |
+
},
|
| 324 |
+
"32035": {
|
| 325 |
+
"content": "<extra_id_64>",
|
| 326 |
+
"lstrip": true,
|
| 327 |
+
"normalized": true,
|
| 328 |
+
"rstrip": false,
|
| 329 |
+
"single_word": false,
|
| 330 |
+
"special": true
|
| 331 |
+
},
|
| 332 |
+
"32036": {
|
| 333 |
+
"content": "<extra_id_63>",
|
| 334 |
+
"lstrip": true,
|
| 335 |
+
"normalized": true,
|
| 336 |
+
"rstrip": false,
|
| 337 |
+
"single_word": false,
|
| 338 |
+
"special": true
|
| 339 |
+
},
|
| 340 |
+
"32037": {
|
| 341 |
+
"content": "<extra_id_62>",
|
| 342 |
+
"lstrip": true,
|
| 343 |
+
"normalized": true,
|
| 344 |
+
"rstrip": false,
|
| 345 |
+
"single_word": false,
|
| 346 |
+
"special": true
|
| 347 |
+
},
|
| 348 |
+
"32038": {
|
| 349 |
+
"content": "<extra_id_61>",
|
| 350 |
+
"lstrip": true,
|
| 351 |
+
"normalized": true,
|
| 352 |
+
"rstrip": false,
|
| 353 |
+
"single_word": false,
|
| 354 |
+
"special": true
|
| 355 |
+
},
|
| 356 |
+
"32039": {
|
| 357 |
+
"content": "<extra_id_60>",
|
| 358 |
+
"lstrip": true,
|
| 359 |
+
"normalized": true,
|
| 360 |
+
"rstrip": false,
|
| 361 |
+
"single_word": false,
|
| 362 |
+
"special": true
|
| 363 |
+
},
|
| 364 |
+
"32040": {
|
| 365 |
+
"content": "<extra_id_59>",
|
| 366 |
+
"lstrip": true,
|
| 367 |
+
"normalized": true,
|
| 368 |
+
"rstrip": false,
|
| 369 |
+
"single_word": false,
|
| 370 |
+
"special": true
|
| 371 |
+
},
|
| 372 |
+
"32041": {
|
| 373 |
+
"content": "<extra_id_58>",
|
| 374 |
+
"lstrip": true,
|
| 375 |
+
"normalized": true,
|
| 376 |
+
"rstrip": false,
|
| 377 |
+
"single_word": false,
|
| 378 |
+
"special": true
|
| 379 |
+
},
|
| 380 |
+
"32042": {
|
| 381 |
+
"content": "<extra_id_57>",
|
| 382 |
+
"lstrip": true,
|
| 383 |
+
"normalized": true,
|
| 384 |
+
"rstrip": false,
|
| 385 |
+
"single_word": false,
|
| 386 |
+
"special": true
|
| 387 |
+
},
|
| 388 |
+
"32043": {
|
| 389 |
+
"content": "<extra_id_56>",
|
| 390 |
+
"lstrip": true,
|
| 391 |
+
"normalized": true,
|
| 392 |
+
"rstrip": false,
|
| 393 |
+
"single_word": false,
|
| 394 |
+
"special": true
|
| 395 |
+
},
|
| 396 |
+
"32044": {
|
| 397 |
+
"content": "<extra_id_55>",
|
| 398 |
+
"lstrip": true,
|
| 399 |
+
"normalized": true,
|
| 400 |
+
"rstrip": false,
|
| 401 |
+
"single_word": false,
|
| 402 |
+
"special": true
|
| 403 |
+
},
|
| 404 |
+
"32045": {
|
| 405 |
+
"content": "<extra_id_54>",
|
| 406 |
+
"lstrip": true,
|
| 407 |
+
"normalized": true,
|
| 408 |
+
"rstrip": false,
|
| 409 |
+
"single_word": false,
|
| 410 |
+
"special": true
|
| 411 |
+
},
|
| 412 |
+
"32046": {
|
| 413 |
+
"content": "<extra_id_53>",
|
| 414 |
+
"lstrip": true,
|
| 415 |
+
"normalized": true,
|
| 416 |
+
"rstrip": false,
|
| 417 |
+
"single_word": false,
|
| 418 |
+
"special": true
|
| 419 |
+
},
|
| 420 |
+
"32047": {
|
| 421 |
+
"content": "<extra_id_52>",
|
| 422 |
+
"lstrip": true,
|
| 423 |
+
"normalized": true,
|
| 424 |
+
"rstrip": false,
|
| 425 |
+
"single_word": false,
|
| 426 |
+
"special": true
|
| 427 |
+
},
|
| 428 |
+
"32048": {
|
| 429 |
+
"content": "<extra_id_51>",
|
| 430 |
+
"lstrip": true,
|
| 431 |
+
"normalized": true,
|
| 432 |
+
"rstrip": false,
|
| 433 |
+
"single_word": false,
|
| 434 |
+
"special": true
|
| 435 |
+
},
|
| 436 |
+
"32049": {
|
| 437 |
+
"content": "<extra_id_50>",
|
| 438 |
+
"lstrip": true,
|
| 439 |
+
"normalized": true,
|
| 440 |
+
"rstrip": false,
|
| 441 |
+
"single_word": false,
|
| 442 |
+
"special": true
|
| 443 |
+
},
|
| 444 |
+
"32050": {
|
| 445 |
+
"content": "<extra_id_49>",
|
| 446 |
+
"lstrip": true,
|
| 447 |
+
"normalized": true,
|
| 448 |
+
"rstrip": false,
|
| 449 |
+
"single_word": false,
|
| 450 |
+
"special": true
|
| 451 |
+
},
|
| 452 |
+
"32051": {
|
| 453 |
+
"content": "<extra_id_48>",
|
| 454 |
+
"lstrip": true,
|
| 455 |
+
"normalized": true,
|
| 456 |
+
"rstrip": false,
|
| 457 |
+
"single_word": false,
|
| 458 |
+
"special": true
|
| 459 |
+
},
|
| 460 |
+
"32052": {
|
| 461 |
+
"content": "<extra_id_47>",
|
| 462 |
+
"lstrip": true,
|
| 463 |
+
"normalized": true,
|
| 464 |
+
"rstrip": false,
|
| 465 |
+
"single_word": false,
|
| 466 |
+
"special": true
|
| 467 |
+
},
|
| 468 |
+
"32053": {
|
| 469 |
+
"content": "<extra_id_46>",
|
| 470 |
+
"lstrip": true,
|
| 471 |
+
"normalized": true,
|
| 472 |
+
"rstrip": false,
|
| 473 |
+
"single_word": false,
|
| 474 |
+
"special": true
|
| 475 |
+
},
|
| 476 |
+
"32054": {
|
| 477 |
+
"content": "<extra_id_45>",
|
| 478 |
+
"lstrip": true,
|
| 479 |
+
"normalized": true,
|
| 480 |
+
"rstrip": false,
|
| 481 |
+
"single_word": false,
|
| 482 |
+
"special": true
|
| 483 |
+
},
|
| 484 |
+
"32055": {
|
| 485 |
+
"content": "<extra_id_44>",
|
| 486 |
+
"lstrip": true,
|
| 487 |
+
"normalized": true,
|
| 488 |
+
"rstrip": false,
|
| 489 |
+
"single_word": false,
|
| 490 |
+
"special": true
|
| 491 |
+
},
|
| 492 |
+
"32056": {
|
| 493 |
+
"content": "<extra_id_43>",
|
| 494 |
+
"lstrip": true,
|
| 495 |
+
"normalized": true,
|
| 496 |
+
"rstrip": false,
|
| 497 |
+
"single_word": false,
|
| 498 |
+
"special": true
|
| 499 |
+
},
|
| 500 |
+
"32057": {
|
| 501 |
+
"content": "<extra_id_42>",
|
| 502 |
+
"lstrip": true,
|
| 503 |
+
"normalized": true,
|
| 504 |
+
"rstrip": false,
|
| 505 |
+
"single_word": false,
|
| 506 |
+
"special": true
|
| 507 |
+
},
|
| 508 |
+
"32058": {
|
| 509 |
+
"content": "<extra_id_41>",
|
| 510 |
+
"lstrip": true,
|
| 511 |
+
"normalized": true,
|
| 512 |
+
"rstrip": false,
|
| 513 |
+
"single_word": false,
|
| 514 |
+
"special": true
|
| 515 |
+
},
|
| 516 |
+
"32059": {
|
| 517 |
+
"content": "<extra_id_40>",
|
| 518 |
+
"lstrip": true,
|
| 519 |
+
"normalized": true,
|
| 520 |
+
"rstrip": false,
|
| 521 |
+
"single_word": false,
|
| 522 |
+
"special": true
|
| 523 |
+
},
|
| 524 |
+
"32060": {
|
| 525 |
+
"content": "<extra_id_39>",
|
| 526 |
+
"lstrip": true,
|
| 527 |
+
"normalized": true,
|
| 528 |
+
"rstrip": false,
|
| 529 |
+
"single_word": false,
|
| 530 |
+
"special": true
|
| 531 |
+
},
|
| 532 |
+
"32061": {
|
| 533 |
+
"content": "<extra_id_38>",
|
| 534 |
+
"lstrip": true,
|
| 535 |
+
"normalized": true,
|
| 536 |
+
"rstrip": false,
|
| 537 |
+
"single_word": false,
|
| 538 |
+
"special": true
|
| 539 |
+
},
|
| 540 |
+
"32062": {
|
| 541 |
+
"content": "<extra_id_37>",
|
| 542 |
+
"lstrip": true,
|
| 543 |
+
"normalized": true,
|
| 544 |
+
"rstrip": false,
|
| 545 |
+
"single_word": false,
|
| 546 |
+
"special": true
|
| 547 |
+
},
|
| 548 |
+
"32063": {
|
| 549 |
+
"content": "<extra_id_36>",
|
| 550 |
+
"lstrip": true,
|
| 551 |
+
"normalized": true,
|
| 552 |
+
"rstrip": false,
|
| 553 |
+
"single_word": false,
|
| 554 |
+
"special": true
|
| 555 |
+
},
|
| 556 |
+
"32064": {
|
| 557 |
+
"content": "<extra_id_35>",
|
| 558 |
+
"lstrip": true,
|
| 559 |
+
"normalized": true,
|
| 560 |
+
"rstrip": false,
|
| 561 |
+
"single_word": false,
|
| 562 |
+
"special": true
|
| 563 |
+
},
|
| 564 |
+
"32065": {
|
| 565 |
+
"content": "<extra_id_34>",
|
| 566 |
+
"lstrip": true,
|
| 567 |
+
"normalized": true,
|
| 568 |
+
"rstrip": false,
|
| 569 |
+
"single_word": false,
|
| 570 |
+
"special": true
|
| 571 |
+
},
|
| 572 |
+
"32066": {
|
| 573 |
+
"content": "<extra_id_33>",
|
| 574 |
+
"lstrip": true,
|
| 575 |
+
"normalized": true,
|
| 576 |
+
"rstrip": false,
|
| 577 |
+
"single_word": false,
|
| 578 |
+
"special": true
|
| 579 |
+
},
|
| 580 |
+
"32067": {
|
| 581 |
+
"content": "<extra_id_32>",
|
| 582 |
+
"lstrip": true,
|
| 583 |
+
"normalized": true,
|
| 584 |
+
"rstrip": false,
|
| 585 |
+
"single_word": false,
|
| 586 |
+
"special": true
|
| 587 |
+
},
|
| 588 |
+
"32068": {
|
| 589 |
+
"content": "<extra_id_31>",
|
| 590 |
+
"lstrip": true,
|
| 591 |
+
"normalized": true,
|
| 592 |
+
"rstrip": false,
|
| 593 |
+
"single_word": false,
|
| 594 |
+
"special": true
|
| 595 |
+
},
|
| 596 |
+
"32069": {
|
| 597 |
+
"content": "<extra_id_30>",
|
| 598 |
+
"lstrip": true,
|
| 599 |
+
"normalized": true,
|
| 600 |
+
"rstrip": false,
|
| 601 |
+
"single_word": false,
|
| 602 |
+
"special": true
|
| 603 |
+
},
|
| 604 |
+
"32070": {
|
| 605 |
+
"content": "<extra_id_29>",
|
| 606 |
+
"lstrip": true,
|
| 607 |
+
"normalized": true,
|
| 608 |
+
"rstrip": false,
|
| 609 |
+
"single_word": false,
|
| 610 |
+
"special": true
|
| 611 |
+
},
|
| 612 |
+
"32071": {
|
| 613 |
+
"content": "<extra_id_28>",
|
| 614 |
+
"lstrip": true,
|
| 615 |
+
"normalized": true,
|
| 616 |
+
"rstrip": false,
|
| 617 |
+
"single_word": false,
|
| 618 |
+
"special": true
|
| 619 |
+
},
|
| 620 |
+
"32072": {
|
| 621 |
+
"content": "<extra_id_27>",
|
| 622 |
+
"lstrip": true,
|
| 623 |
+
"normalized": true,
|
| 624 |
+
"rstrip": false,
|
| 625 |
+
"single_word": false,
|
| 626 |
+
"special": true
|
| 627 |
+
},
|
| 628 |
+
"32073": {
|
| 629 |
+
"content": "<extra_id_26>",
|
| 630 |
+
"lstrip": true,
|
| 631 |
+
"normalized": true,
|
| 632 |
+
"rstrip": false,
|
| 633 |
+
"single_word": false,
|
| 634 |
+
"special": true
|
| 635 |
+
},
|
| 636 |
+
"32074": {
|
| 637 |
+
"content": "<extra_id_25>",
|
| 638 |
+
"lstrip": true,
|
| 639 |
+
"normalized": true,
|
| 640 |
+
"rstrip": false,
|
| 641 |
+
"single_word": false,
|
| 642 |
+
"special": true
|
| 643 |
+
},
|
| 644 |
+
"32075": {
|
| 645 |
+
"content": "<extra_id_24>",
|
| 646 |
+
"lstrip": true,
|
| 647 |
+
"normalized": true,
|
| 648 |
+
"rstrip": false,
|
| 649 |
+
"single_word": false,
|
| 650 |
+
"special": true
|
| 651 |
+
},
|
| 652 |
+
"32076": {
|
| 653 |
+
"content": "<extra_id_23>",
|
| 654 |
+
"lstrip": true,
|
| 655 |
+
"normalized": true,
|
| 656 |
+
"rstrip": false,
|
| 657 |
+
"single_word": false,
|
| 658 |
+
"special": true
|
| 659 |
+
},
|
| 660 |
+
"32077": {
|
| 661 |
+
"content": "<extra_id_22>",
|
| 662 |
+
"lstrip": true,
|
| 663 |
+
"normalized": true,
|
| 664 |
+
"rstrip": false,
|
| 665 |
+
"single_word": false,
|
| 666 |
+
"special": true
|
| 667 |
+
},
|
| 668 |
+
"32078": {
|
| 669 |
+
"content": "<extra_id_21>",
|
| 670 |
+
"lstrip": true,
|
| 671 |
+
"normalized": true,
|
| 672 |
+
"rstrip": false,
|
| 673 |
+
"single_word": false,
|
| 674 |
+
"special": true
|
| 675 |
+
},
|
| 676 |
+
"32079": {
|
| 677 |
+
"content": "<extra_id_20>",
|
| 678 |
+
"lstrip": true,
|
| 679 |
+
"normalized": true,
|
| 680 |
+
"rstrip": false,
|
| 681 |
+
"single_word": false,
|
| 682 |
+
"special": true
|
| 683 |
+
},
|
| 684 |
+
"32080": {
|
| 685 |
+
"content": "<extra_id_19>",
|
| 686 |
+
"lstrip": true,
|
| 687 |
+
"normalized": true,
|
| 688 |
+
"rstrip": false,
|
| 689 |
+
"single_word": false,
|
| 690 |
+
"special": true
|
| 691 |
+
},
|
| 692 |
+
"32081": {
|
| 693 |
+
"content": "<extra_id_18>",
|
| 694 |
+
"lstrip": true,
|
| 695 |
+
"normalized": true,
|
| 696 |
+
"rstrip": false,
|
| 697 |
+
"single_word": false,
|
| 698 |
+
"special": true
|
| 699 |
+
},
|
| 700 |
+
"32082": {
|
| 701 |
+
"content": "<extra_id_17>",
|
| 702 |
+
"lstrip": true,
|
| 703 |
+
"normalized": true,
|
| 704 |
+
"rstrip": false,
|
| 705 |
+
"single_word": false,
|
| 706 |
+
"special": true
|
| 707 |
+
},
|
| 708 |
+
"32083": {
|
| 709 |
+
"content": "<extra_id_16>",
|
| 710 |
+
"lstrip": true,
|
| 711 |
+
"normalized": true,
|
| 712 |
+
"rstrip": false,
|
| 713 |
+
"single_word": false,
|
| 714 |
+
"special": true
|
| 715 |
+
},
|
| 716 |
+
"32084": {
|
| 717 |
+
"content": "<extra_id_15>",
|
| 718 |
+
"lstrip": true,
|
| 719 |
+
"normalized": true,
|
| 720 |
+
"rstrip": false,
|
| 721 |
+
"single_word": false,
|
| 722 |
+
"special": true
|
| 723 |
+
},
|
| 724 |
+
"32085": {
|
| 725 |
+
"content": "<extra_id_14>",
|
| 726 |
+
"lstrip": true,
|
| 727 |
+
"normalized": true,
|
| 728 |
+
"rstrip": false,
|
| 729 |
+
"single_word": false,
|
| 730 |
+
"special": true
|
| 731 |
+
},
|
| 732 |
+
"32086": {
|
| 733 |
+
"content": "<extra_id_13>",
|
| 734 |
+
"lstrip": true,
|
| 735 |
+
"normalized": true,
|
| 736 |
+
"rstrip": false,
|
| 737 |
+
"single_word": false,
|
| 738 |
+
"special": true
|
| 739 |
+
},
|
| 740 |
+
"32087": {
|
| 741 |
+
"content": "<extra_id_12>",
|
| 742 |
+
"lstrip": true,
|
| 743 |
+
"normalized": true,
|
| 744 |
+
"rstrip": false,
|
| 745 |
+
"single_word": false,
|
| 746 |
+
"special": true
|
| 747 |
+
},
|
| 748 |
+
"32088": {
|
| 749 |
+
"content": "<extra_id_11>",
|
| 750 |
+
"lstrip": true,
|
| 751 |
+
"normalized": true,
|
| 752 |
+
"rstrip": false,
|
| 753 |
+
"single_word": false,
|
| 754 |
+
"special": true
|
| 755 |
+
},
|
| 756 |
+
"32089": {
|
| 757 |
+
"content": "<extra_id_10>",
|
| 758 |
+
"lstrip": true,
|
| 759 |
+
"normalized": true,
|
| 760 |
+
"rstrip": false,
|
| 761 |
+
"single_word": false,
|
| 762 |
+
"special": true
|
| 763 |
+
},
|
| 764 |
+
"32090": {
|
| 765 |
+
"content": "<extra_id_9>",
|
| 766 |
+
"lstrip": true,
|
| 767 |
+
"normalized": true,
|
| 768 |
+
"rstrip": false,
|
| 769 |
+
"single_word": false,
|
| 770 |
+
"special": true
|
| 771 |
+
},
|
| 772 |
+
"32091": {
|
| 773 |
+
"content": "<extra_id_8>",
|
| 774 |
+
"lstrip": true,
|
| 775 |
+
"normalized": true,
|
| 776 |
+
"rstrip": false,
|
| 777 |
+
"single_word": false,
|
| 778 |
+
"special": true
|
| 779 |
+
},
|
| 780 |
+
"32092": {
|
| 781 |
+
"content": "<extra_id_7>",
|
| 782 |
+
"lstrip": true,
|
| 783 |
+
"normalized": true,
|
| 784 |
+
"rstrip": false,
|
| 785 |
+
"single_word": false,
|
| 786 |
+
"special": true
|
| 787 |
+
},
|
| 788 |
+
"32093": {
|
| 789 |
+
"content": "<extra_id_6>",
|
| 790 |
+
"lstrip": true,
|
| 791 |
+
"normalized": true,
|
| 792 |
+
"rstrip": false,
|
| 793 |
+
"single_word": false,
|
| 794 |
+
"special": true
|
| 795 |
+
},
|
| 796 |
+
"32094": {
|
| 797 |
+
"content": "<extra_id_5>",
|
| 798 |
+
"lstrip": true,
|
| 799 |
+
"normalized": true,
|
| 800 |
+
"rstrip": false,
|
| 801 |
+
"single_word": false,
|
| 802 |
+
"special": true
|
| 803 |
+
},
|
| 804 |
+
"32095": {
|
| 805 |
+
"content": "<extra_id_4>",
|
| 806 |
+
"lstrip": true,
|
| 807 |
+
"normalized": true,
|
| 808 |
+
"rstrip": false,
|
| 809 |
+
"single_word": false,
|
| 810 |
+
"special": true
|
| 811 |
+
},
|
| 812 |
+
"32096": {
|
| 813 |
+
"content": "<extra_id_3>",
|
| 814 |
+
"lstrip": true,
|
| 815 |
+
"normalized": true,
|
| 816 |
+
"rstrip": false,
|
| 817 |
+
"single_word": false,
|
| 818 |
+
"special": true
|
| 819 |
+
},
|
| 820 |
+
"32097": {
|
| 821 |
+
"content": "<extra_id_2>",
|
| 822 |
+
"lstrip": true,
|
| 823 |
+
"normalized": true,
|
| 824 |
+
"rstrip": false,
|
| 825 |
+
"single_word": false,
|
| 826 |
+
"special": true
|
| 827 |
+
},
|
| 828 |
+
"32098": {
|
| 829 |
+
"content": "<extra_id_1>",
|
| 830 |
+
"lstrip": true,
|
| 831 |
+
"normalized": true,
|
| 832 |
+
"rstrip": false,
|
| 833 |
+
"single_word": false,
|
| 834 |
+
"special": true
|
| 835 |
+
},
|
| 836 |
+
"32099": {
|
| 837 |
+
"content": "<extra_id_0>",
|
| 838 |
+
"lstrip": true,
|
| 839 |
+
"normalized": true,
|
| 840 |
+
"rstrip": false,
|
| 841 |
+
"single_word": false,
|
| 842 |
+
"special": true
|
| 843 |
+
}
|
| 844 |
+
},
|
| 845 |
+
"additional_special_tokens": [
|
| 846 |
+
"<extra_id_99>",
|
| 847 |
+
"<extra_id_98>",
|
| 848 |
+
"<extra_id_97>",
|
| 849 |
+
"<extra_id_96>",
|
| 850 |
+
"<extra_id_95>",
|
| 851 |
+
"<extra_id_94>",
|
| 852 |
+
"<extra_id_93>",
|
| 853 |
+
"<extra_id_92>",
|
| 854 |
+
"<extra_id_91>",
|
| 855 |
+
"<extra_id_90>",
|
| 856 |
+
"<extra_id_89>",
|
| 857 |
+
"<extra_id_88>",
|
| 858 |
+
"<extra_id_87>",
|
| 859 |
+
"<extra_id_86>",
|
| 860 |
+
"<extra_id_85>",
|
| 861 |
+
"<extra_id_84>",
|
| 862 |
+
"<extra_id_83>",
|
| 863 |
+
"<extra_id_82>",
|
| 864 |
+
"<extra_id_81>",
|
| 865 |
+
"<extra_id_80>",
|
| 866 |
+
"<extra_id_79>",
|
| 867 |
+
"<extra_id_78>",
|
| 868 |
+
"<extra_id_77>",
|
| 869 |
+
"<extra_id_76>",
|
| 870 |
+
"<extra_id_75>",
|
| 871 |
+
"<extra_id_74>",
|
| 872 |
+
"<extra_id_73>",
|
| 873 |
+
"<extra_id_72>",
|
| 874 |
+
"<extra_id_71>",
|
| 875 |
+
"<extra_id_70>",
|
| 876 |
+
"<extra_id_69>",
|
| 877 |
+
"<extra_id_68>",
|
| 878 |
+
"<extra_id_67>",
|
| 879 |
+
"<extra_id_66>",
|
| 880 |
+
"<extra_id_65>",
|
| 881 |
+
"<extra_id_64>",
|
| 882 |
+
"<extra_id_63>",
|
| 883 |
+
"<extra_id_62>",
|
| 884 |
+
"<extra_id_61>",
|
| 885 |
+
"<extra_id_60>",
|
| 886 |
+
"<extra_id_59>",
|
| 887 |
+
"<extra_id_58>",
|
| 888 |
+
"<extra_id_57>",
|
| 889 |
+
"<extra_id_56>",
|
| 890 |
+
"<extra_id_55>",
|
| 891 |
+
"<extra_id_54>",
|
| 892 |
+
"<extra_id_53>",
|
| 893 |
+
"<extra_id_52>",
|
| 894 |
+
"<extra_id_51>",
|
| 895 |
+
"<extra_id_50>",
|
| 896 |
+
"<extra_id_49>",
|
| 897 |
+
"<extra_id_48>",
|
| 898 |
+
"<extra_id_47>",
|
| 899 |
+
"<extra_id_46>",
|
| 900 |
+
"<extra_id_45>",
|
| 901 |
+
"<extra_id_44>",
|
| 902 |
+
"<extra_id_43>",
|
| 903 |
+
"<extra_id_42>",
|
| 904 |
+
"<extra_id_41>",
|
| 905 |
+
"<extra_id_40>",
|
| 906 |
+
"<extra_id_39>",
|
| 907 |
+
"<extra_id_38>",
|
| 908 |
+
"<extra_id_37>",
|
| 909 |
+
"<extra_id_36>",
|
| 910 |
+
"<extra_id_35>",
|
| 911 |
+
"<extra_id_34>",
|
| 912 |
+
"<extra_id_33>",
|
| 913 |
+
"<extra_id_32>",
|
| 914 |
+
"<extra_id_31>",
|
| 915 |
+
"<extra_id_30>",
|
| 916 |
+
"<extra_id_29>",
|
| 917 |
+
"<extra_id_28>",
|
| 918 |
+
"<extra_id_27>",
|
| 919 |
+
"<extra_id_26>",
|
| 920 |
+
"<extra_id_25>",
|
| 921 |
+
"<extra_id_24>",
|
| 922 |
+
"<extra_id_23>",
|
| 923 |
+
"<extra_id_22>",
|
| 924 |
+
"<extra_id_21>",
|
| 925 |
+
"<extra_id_20>",
|
| 926 |
+
"<extra_id_19>",
|
| 927 |
+
"<extra_id_18>",
|
| 928 |
+
"<extra_id_17>",
|
| 929 |
+
"<extra_id_16>",
|
| 930 |
+
"<extra_id_15>",
|
| 931 |
+
"<extra_id_14>",
|
| 932 |
+
"<extra_id_13>",
|
| 933 |
+
"<extra_id_12>",
|
| 934 |
+
"<extra_id_11>",
|
| 935 |
+
"<extra_id_10>",
|
| 936 |
+
"<extra_id_9>",
|
| 937 |
+
"<extra_id_8>",
|
| 938 |
+
"<extra_id_7>",
|
| 939 |
+
"<extra_id_6>",
|
| 940 |
+
"<extra_id_5>",
|
| 941 |
+
"<extra_id_4>",
|
| 942 |
+
"<extra_id_3>",
|
| 943 |
+
"<extra_id_2>",
|
| 944 |
+
"<extra_id_1>",
|
| 945 |
+
"<extra_id_0>"
|
| 946 |
+
],
|
| 947 |
+
"bos_token": "<s>",
|
| 948 |
+
"clean_up_tokenization_spaces": true,
|
| 949 |
+
"cls_token": "<s>",
|
| 950 |
+
"eos_token": "</s>",
|
| 951 |
+
"errors": "replace",
|
| 952 |
+
"mask_token": "<mask>",
|
| 953 |
+
"model_max_length": 512,
|
| 954 |
+
"pad_token": "<pad>",
|
| 955 |
+
"sep_token": "</s>",
|
| 956 |
+
"tokenizer_class": "RobertaTokenizer",
|
| 957 |
+
"trim_offsets": true,
|
| 958 |
+
"unk_token": "<unk>"
|
| 959 |
+
}
|
int8_dynamic/tokenizer/vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio==5.8.0
|
| 2 |
+
pandas
|
| 3 |
+
sqlparse
|
| 4 |
+
transformers
|
| 5 |
+
torch --index-url https://download.pytorch.org/whl/cpu
|
| 6 |
+
peft
|
| 7 |
+
trl
|
| 8 |
+
sentencepiece
|
| 9 |
+
matplotlib
|
| 10 |
+
huggingface_hub
|
scripts/__pycache__/benchmark_parallel_reward.cpython-310.pyc
ADDED
|
Binary file (6.31 kB). View file
|
|
|
scripts/__pycache__/benchmark_parallel_reward.cpython-313.pyc
ADDED
|
Binary file (10.3 kB). View file
|
|
|
scripts/__pycache__/benchmark_quantization.cpython-310.pyc
ADDED
|
Binary file (3.79 kB). View file
|
|
|
scripts/__pycache__/benchmark_rollout_generation.cpython-310.pyc
ADDED
|
Binary file (2.75 kB). View file
|
|
|
scripts/__pycache__/quantize_export.cpython-310.pyc
ADDED
|
Binary file (2.05 kB). View file
|
|
|
scripts/__pycache__/quantized_infer_harness.cpython-310.pyc
ADDED
|
Binary file (1.62 kB). View file
|
|
|
scripts/benchmark_parallel_reward.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
# Ensure headless-safe matplotlib + writable cache when called from Gradio/subprocess.
|
| 3 |
+
os.environ.setdefault("MPLBACKEND", "Agg")
|
| 4 |
+
os.environ.setdefault("MPLCONFIGDIR", os.environ.get("MPLCONFIGDIR", "/tmp/mplconfig"))
|
| 5 |
+
import time
|
| 6 |
+
import json
|
| 7 |
+
import argparse
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import numpy as np
|
| 10 |
+
import sys
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
# ==========================================
|
| 14 |
+
# RELATIVE PATH RESOLUTION
|
| 15 |
+
# ==========================================
|
| 16 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 17 |
+
sys.path.append(str(PROJECT_ROOT))
|
| 18 |
+
|
| 19 |
+
# Dynamically resolve where the databases are kept
|
| 20 |
+
if (PROJECT_ROOT / "data" / "database").exists() and list((PROJECT_ROOT / "data" / "database").rglob("*.sqlite")):
|
| 21 |
+
DB_ROOT = PROJECT_ROOT / "data" / "database"
|
| 22 |
+
else:
|
| 23 |
+
DB_ROOT = PROJECT_ROOT / "final_databases"
|
| 24 |
+
|
| 25 |
+
from src.execution_reward import (
|
| 26 |
+
execution_reward_batch_sequential,
|
| 27 |
+
execution_reward_batch_parallel,
|
| 28 |
+
execution_reward_batch_parallel_by_db,
|
| 29 |
+
execution_reward_timed,
|
| 30 |
+
set_use_cache,
|
| 31 |
+
set_use_schema_validation,
|
| 32 |
+
clear_result_cache
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
def generate_mock_rollouts(num_rollouts: int = 100, heavy_n: int = 500_000):
    """Build heavy (pred_sql, db_path, gold_sql) triples for the reward benchmark.

    Distributes a recursive-CTE workload round-robin over every real SQLite
    database found under DB_ROOT so parallel reward execution is exercised
    across multiple database files. Exits the process when no database exists.
    """
    print(f"\nGenerating {num_rollouts} heavy rollouts to simulate RLHF query workload...", flush=True)

    # Locate every *.sqlite file under the resolved database root.
    real_dbs = [str(p) for p in DB_ROOT.rglob("*.sqlite")]

    if not real_dbs:
        print(f"❌ CRITICAL ERROR: No real databases found in {DB_ROOT}. Cannot run benchmark.", flush=True)
        sys.exit(1)
    print(f"Found {len(real_dbs)} real SQLite databases in {DB_ROOT}. Distributing workload...", flush=True)

    rollouts = []
    for idx in range(num_rollouts):
        target_db = real_dbs[idx % len(real_dbs)]

        # Deterministic CPU-heavy recursive CTE; the upper bound varies per
        # rollout so a result cache cannot short-circuit every query.
        # (May be cut off by the 2s timeout depending on machine.)
        heavy_sql = f"""
        WITH RECURSIVE cnt(x) AS (
            SELECT 1
            UNION ALL
            SELECT x+1 FROM cnt WHERE x < {heavy_n + (idx % 10_000)}
        )
        SELECT sum(x) FROM cnt;
        """
        flat_sql = heavy_sql.replace("\n", " ").strip()
        # Prediction and gold are the same query: the benchmark measures
        # execution cost, not model correctness.
        rollouts.append((flat_sql, target_db, flat_sql))
        if num_rollouts >= 500 and (idx + 1) % 250 == 0:
            print(f" generated {idx + 1}/{num_rollouts}...", flush=True)

    return rollouts
|
| 67 |
+
|
| 68 |
+
def profile_bottlenecks(rollouts, sample_size: int = 20, print_every: int = 5):
    """Profile where reward-computation CPU time goes (parse vs plan vs execute).

    Runs ``execution_reward_timed`` over the first ``sample_size`` rollouts with
    caching and schema validation disabled, then prints a per-phase table of
    average latency (ms) and share of total CPU time.

    Args:
        rollouts: list of (pred_sql, db_path, gold_sql) triples.
        sample_size: number of rollouts to profile (clamped to len(rollouts)).
        print_every: progress-print cadence; a falsy value disables progress.
    """
    print("\n" + "="*65)
    print(" 🔍 CPU PROFILING: IDENTIFYING BOTTLENECKS (100 Rollouts)")
    print("="*65)

    clear_result_cache()
    set_use_cache(False)              # Disable cache to force real work
    set_use_schema_validation(False)  # CTE-heavy benchmark queries may fail schema validation

    # Profile a small subset by default so the script prints quickly.
    sample_size = min(int(sample_size), len(rollouts))
    if sample_size <= 0:
        # BUGFIX: an empty rollout list previously crashed with
        # ZeroDivisionError in the per-phase averages below.
        print("No rollouts available to profile.", flush=True)
        print("="*65 + "\n")
        return

    total_parse = 0.0
    total_plan = 0.0
    total_exec = 0.0

    sample_rollouts = rollouts[:sample_size]
    for i, (pred, db, gold) in enumerate(sample_rollouts, 1):
        _, timings = execution_reward_timed(pred, db, gold, measure_plan=True)
        total_parse += timings['parse_s']
        total_plan += timings['plan_s']
        total_exec += timings['exec_s']
        if print_every and (i % int(print_every) == 0 or i == sample_size):
            print(f" profiled {i}/{sample_size}...", flush=True)

    total_time = total_parse + total_plan + total_exec
    if total_time == 0:
        total_time = 0.0001  # Prevent div by zero in the percentage columns

    print(f"{'Phase':<15} | {'Avg Time (ms)':<15} | {'% of Total CPU':<15}")
    print("-" * 65)
    print(f"{'Regex Parsing':<15} | {(total_parse/sample_size)*1000:<15.2f} | {(total_parse/total_time)*100:<14.1f}%")
    print(f"{'Query Planning':<15} | {(total_plan/sample_size)*1000:<15.2f} | {(total_plan/total_time)*100:<14.1f}%")
    print(f"{'DB Execution':<15} | {(total_exec/sample_size)*1000:<15.2f} | {(total_exec/total_time)*100:<14.1f}%")
    print("="*65 + "\n")
|
| 103 |
+
|
| 104 |
+
def run_benchmark_for_setting(rollouts, use_cache: bool, max_workers: int):
    """Time one sequential and one parallel reward pass over *rollouts*.

    The shared result cache is cleared before each pass so both start cold;
    schema validation stays off because the benchmark targets raw execution
    speed. Returns a dict with both wall times and the resulting speedup.
    """
    set_use_cache(use_cache)
    set_use_schema_validation(False)  # benchmark focuses on execution speed

    def _cold_timed(batch_fn, **kwargs):
        # Drop cached results from the previous pass, then time one run.
        clear_result_cache()
        t0 = time.perf_counter()
        batch_fn(rollouts, **kwargs)
        return time.perf_counter() - t0

    sequential_s = _cold_timed(execution_reward_batch_sequential)
    # 1 thread per DB (recommended)
    parallel_s = _cold_timed(execution_reward_batch_parallel_by_db, max_workers=max_workers)

    speedup = sequential_s / parallel_s if parallel_s > 0 else 0
    return {
        "sequential_s": sequential_s,
        "parallel_s": parallel_s,
        "speedup": speedup,
    }
|
| 128 |
+
|
| 129 |
+
def print_comparison_table(results):
    """Pretty-print sequential vs parallel timings for both cache settings.

    Expects ``results`` with keys ``with_cache`` / ``without_cache``, each
    holding ``sequential_s``, ``parallel_s`` and ``speedup`` floats.
    """
    print("="*65)
    print(f"{'Setting':<16} | {'Sequential (s)':<14} | {'Parallel (s)':<14} | {'Speedup':<10}")
    print("-" * 65)
    rows = (("With Cache", results["with_cache"]), ("Without Cache", results["without_cache"]))
    for label, stats in rows:
        print(f"{label:<16} | {stats['sequential_s']:<14.4f} | {stats['parallel_s']:<14.4f} | {stats['speedup']:<9.2f}x")
    print("="*65 + "\n")
|
| 139 |
+
|
| 140 |
+
def plot_results(results, output_path: str):
    """Save a grouped bar chart (sequential vs parallel) to *output_path*.

    One bar pair per cache setting; each bar is annotated with its timing.
    """
    labels = ['With Cache', 'Without Cache']
    seq_times = [results[key]['sequential_s'] for key in ('with_cache', 'without_cache')]
    par_times = [results[key]['parallel_s'] for key in ('with_cache', 'without_cache')]

    positions = np.arange(len(labels))
    bar_w = 0.35

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.bar(positions - bar_w/2, seq_times, bar_w, label='Sequential', color='#4C72B0')
    ax.bar(positions + bar_w/2, par_times, bar_w, label='Parallel', color='#DD8452')

    ax.set_ylabel('Execution Time (seconds)')
    ax.set_title('Text2SQL Reward Execution: Sequential vs Parallel')
    ax.set_xticks(positions)
    ax.set_xticklabels(labels)
    ax.legend()

    # Annotate every bar with its timing value.
    for container in ax.containers:
        ax.bar_label(container, fmt='%.2f', padding=3)

    fig.tight_layout()
    plt.savefig(output_path, dpi=300)
    plt.close()
|
| 164 |
+
|
| 165 |
+
def main():
    """CLI entry: profile reward bottlenecks, then benchmark cached vs uncached runs.

    Writes results/task1_results.json and results/task1_plot.png under the
    project root and prints a comparison table to stdout.
    """
    parser = argparse.ArgumentParser(description="Benchmark SQL Execution Reward")
    parser.add_argument("--n", type=int, default=1000, help="Number of rollouts to benchmark")
    parser.add_argument("--max-workers", type=int, default=20, help="Max workers for parallel execution")
    parser.add_argument("--heavy-n", type=int, default=200_000, help="Recursive CTE upper bound (controls heaviness)")
    parser.add_argument("--skip-profile", action="store_true", help="Skip the CPU profiling section for faster startup")
    parser.add_argument("--profile-n", type=int, default=20, help="Number of rollouts to use for CPU profiling")
    args = parser.parse_args()

    # Ensure the output directory exists before any writes.
    os.makedirs(str(PROJECT_ROOT / "results"), exist_ok=True)

    rollouts = generate_mock_rollouts(args.n, heavy_n=args.heavy_n)

    if not args.skip_profile:
        profile_bottlenecks(rollouts, sample_size=args.profile_n)

    print("Starting Main Scalability Benchmarks...")

    print("Running Experiment A: Cache ENABLED...")
    results_with_cache = run_benchmark_for_setting(rollouts, use_cache=True, max_workers=args.max_workers)

    print("Running Experiment B: Cache DISABLED...")
    results_without_cache = run_benchmark_for_setting(rollouts, use_cache=False, max_workers=args.max_workers)

    final_results = {
        "with_cache": results_with_cache,
        "without_cache": results_without_cache
    }

    # Persist the raw numbers, then render the table and plot from the same dict.
    json_path = str(PROJECT_ROOT / "results" / "task1_results.json")
    with open(json_path, 'w') as f:
        json.dump(final_results, f, indent=4)

    print_comparison_table(final_results)
    plot_results(final_results, str(PROJECT_ROOT / "results" / "task1_plot.png"))

if __name__ == "__main__":
    main()
|
scripts/benchmark_quantization.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Dict, List, Tuple
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from src.execution_reward import execution_reward
|
| 14 |
+
from src.prompting import encode_prompt
|
| 15 |
+
from src.quantization_utils import load_fp32_model, load_quant_artifact
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _load_dev_items(root: Path, n: int, seed: int = 42) -> List[dict]:
|
| 19 |
+
data = json.loads((root / "data" / "dev.json").read_text())
|
| 20 |
+
if n >= len(data):
|
| 21 |
+
return data
|
| 22 |
+
rng = np.random.default_rng(seed)
|
| 23 |
+
idxs = rng.choice(len(data), size=n, replace=False)
|
| 24 |
+
return [data[int(i)] for i in idxs]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _bench_variant(name: str, tok, model, items: List[dict], device: str) -> Dict[str, float]:
    """Measure per-item generation latency and execution accuracy for one model variant.

    Returns n, execution-accuracy ratio, and mean/p50/p90 latency in seconds.
    Only ``model.generate`` is inside the timed section; reward scoring is not.
    """
    latencies: List[float] = []
    exact = 0

    # Warmup (1 item): prime tokenizer/model state outside the timed loop.
    if items:
        first = items[0]
        _ = encode_prompt(tok, first["question"], first["db_id"], device=device, max_input_tokens=512).unsqueeze(0)

    for item in items:
        db_id = item["db_id"]
        db_path = str(Path("data") / "database" / db_id / f"{db_id}.sqlite")
        input_ids = encode_prompt(tok, item["question"], db_id, device=device, max_input_tokens=512).unsqueeze(0)

        t0 = time.perf_counter()
        out = model.generate(input_ids=input_ids, max_new_tokens=120, num_beams=8, repetition_penalty=1.2)
        latencies.append(time.perf_counter() - t0)

        pred = tok.decode(out[0], skip_special_tokens=True).strip()
        # A reward of 1.0 means the prediction's execution result matches gold.
        if float(execution_reward(pred, db_path, item["query"])) >= 1.0:
            exact += 1

    if latencies:
        p50, p90 = (float(v) for v in np.percentile(latencies, [50, 90]))
        mean = float(np.mean(latencies))
    else:
        p50 = p90 = mean = 0.0
    return {
        "n": float(len(items)),
        "ex": float(exact / max(len(items), 1)),
        "lat_mean_s": mean,
        "lat_p50_s": p50,
        "lat_p90_s": p90,
    }
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def main() -> None:
    """CLI: benchmark fp32 vs int8-quantized model artifacts on the dev set (CPU).

    Loads the fp32 baseline (optionally with an adapter), then each provided
    quantized artifact, benchmarks them on the same sampled items, and writes
    a JSON report to --out (also echoed to stdout).
    """
    p = argparse.ArgumentParser(description="Benchmark fp32 vs quantized artifacts (CPU-focused).")
    p.add_argument("--base_model", default=os.environ.get("BASE_MODEL", "Salesforce/codet5-base"))
    p.add_argument("--adapter", default="", help="Optional adapter for fp32 baseline.")
    p.add_argument("--artifact_int8", default="", help="Artifact dir exported by scripts/quantize_export.py")
    p.add_argument("--artifact_int8_decoder", default="", help="Artifact dir for decoder-only int8")
    p.add_argument("--num_samples", type=int, default=100)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--out", default="results/task5_quant_bench.json")
    p.add_argument("--local_only", action="store_true")
    args = p.parse_args()

    # Dynamic int8 quantization targets CPU, so all variants run there for a
    # fair comparison.
    device = "cpu"
    root = Path(".")
    items = _load_dev_items(root, args.num_samples, args.seed)

    report: Dict[str, Dict[str, float]] = {}

    # fp32 baseline (optionally with an adapter applied).
    tok, fp32 = load_fp32_model(
        args.base_model,
        adapter_path=args.adapter.strip() or None,
        device=device,
        local_only=args.local_only,
    )
    report["fp32"] = _bench_variant("fp32", tok, fp32, items, device)

    # Optional quantized variants; each artifact bundles its own tokenizer.
    if args.artifact_int8:
        tok8, m8, _meta = load_quant_artifact(args.artifact_int8, device=device, local_only=True)
        report["int8_dynamic"] = _bench_variant("int8_dynamic", tok8, m8, items, device)

    if args.artifact_int8_decoder:
        tokd, md, _meta = load_quant_artifact(args.artifact_int8_decoder, device=device, local_only=True)
        report["int8_decoder_dynamic"] = _bench_variant("int8_decoder_dynamic", tokd, md, items, device)

    # Persist and echo the report.
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(report, indent=2))
    print(json.dumps(report, indent=2))


if __name__ == "__main__":
    # Inference-only workload: disable autograd globally before running.
    torch.set_grad_enabled(False)
    main()
|
| 108 |
+
|
scripts/benchmark_rollout_generation.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import List
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from src.prompting import encode_prompt
|
| 14 |
+
from src.quantization_utils import load_fp32_model, load_quant_artifact
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _load_items(root: Path, n: int, seed: int = 42) -> List[dict]:
|
| 18 |
+
data = json.loads((root / "data" / "dev.json").read_text())
|
| 19 |
+
if n >= len(data):
|
| 20 |
+
return data
|
| 21 |
+
rng = np.random.default_rng(seed)
|
| 22 |
+
idxs = rng.choice(len(data), size=n, replace=False)
|
| 23 |
+
return [data[int(i)] for i in idxs]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _bench_generate(tok, model, items: List[dict], device: str) -> float:
    """Return total wall-clock seconds to run generation over every item.

    Encoding and generation are both inside the timed window, mirroring what
    an RL rollout loop would pay per sample.
    """
    start = time.perf_counter()
    for item in items:
        prompt = encode_prompt(tok, item["question"], item["db_id"], device=device, max_input_tokens=512)
        _ = model.generate(input_ids=prompt.unsqueeze(0), max_new_tokens=64, num_beams=4)
    return time.perf_counter() - start
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def main() -> None:
    """CLI: compare fp32 vs quantized rollout-generation throughput.

    Prints total seconds and rollouts/s for the fp32 model and, when an
    artifact dir is given, for the quantized model on the same items.
    """
    p = argparse.ArgumentParser(description="Benchmark rollout generation latency for RL loops.")
    p.add_argument("--base_model", default=os.environ.get("BASE_MODEL", "Salesforce/codet5-base"))
    p.add_argument("--adapter", default="")
    p.add_argument("--artifact", default="", help="Quantized artifact dir (optional).")
    p.add_argument("--num_rollouts", type=int, default=128)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--local_only", action="store_true")
    args = p.parse_args()

    # Dynamic-int8 models are CPU-only; keep both variants on CPU so the
    # comparison is apples-to-apples.
    device = "cpu"
    root = Path(".")
    items = _load_items(root, args.num_rollouts, args.seed)

    tok, fp32 = load_fp32_model(
        args.base_model,
        adapter_path=args.adapter.strip() or None,
        device=device,
        local_only=args.local_only,
    )
    t_fp32 = _bench_generate(tok, fp32, items, device)
    print(f"fp32: {t_fp32:.2f}s for {len(items)} rollouts ({len(items)/max(t_fp32,1e-9):.2f} rollouts/s)")

    if args.artifact:
        tokq, mq, meta = load_quant_artifact(args.artifact, device=device, local_only=True)
        t_q = _bench_generate(tokq, mq, items, device)
        mode = meta.get("mode", "quant")
        print(f"{mode}: {t_q:.2f}s for {len(items)} rollouts ({len(items)/max(t_q,1e-9):.2f} rollouts/s)")


if __name__ == "__main__":
    # Inference-only: disable autograd before benchmarking.
    torch.set_grad_enabled(False)
    main()
|
scripts/error_dashboard.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import json
from collections import Counter

# ==============================
# LOAD LOGS
# ==============================
# Each log entry is expected to carry "question", "sql" and "error_type" keys
# (written by the evaluation pipeline).
with open("results/error_logs.json") as f:
    logs = json.load(f)

total_errors = len(logs)

# BUGFIX: an empty log file previously crashed the dashboard (ZeroDivisionError
# in the percentage math and IndexError on most_common(1)[0]).
if total_errors == 0:
    print("No errors logged — nothing to report.")
    raise SystemExit(0)

# ==============================
# ERROR DISTRIBUTION
# ==============================
error_counts = Counter(e["error_type"] for e in logs)

print("\n" + "="*50)
print("📊 TEXT-to-SQL ERROR DASHBOARD")
print("="*50)

print(f"\n🔢 Total Errors Logged: {total_errors}")

print("\n📊 ERROR DISTRIBUTION:")
print("-"*30)
for error_type, count in error_counts.items():
    percent = (count / total_errors) * 100
    print(f"{error_type:<20} : {count:>4} ({percent:.1f}%)")

# ==============================
# TOP ERROR
# ==============================
top_error = error_counts.most_common(1)[0]

print("\n🔥 MOST COMMON ERROR:")
print("-"*30)
print(f"{top_error[0]} ({top_error[1]} times)")

# ==============================
# SQL OPERATION ANALYSIS
# ==============================
# Count how often each SQL construct appears in the failing queries.
op_counts = Counter()
for e in logs:
    sql = e["sql"].lower()
    for op in ("join", "where", "group by", "order by"):
        if op in sql:
            op_counts[op] += 1

print("\n🧠 SQL OPERATION ANALYSIS:")
print("-"*30)
print(f"JOIN used in : {op_counts['join']} queries")
print(f"WHERE used in : {op_counts['where']} queries")
print(f"GROUP BY used in : {op_counts['group by']} queries")
print(f"ORDER BY used in : {op_counts['order by']} queries")

# ==============================
# SAMPLE ERRORS
# ==============================
print("\n🧪 SAMPLE ERROR CASES:")
print("-"*50)

for i, e in enumerate(logs[:3], 1):
    print(f"\nCase {i}:")
    print(f"Q : {e['question']}")
    print(f"SQL : {e['sql']}")
    print(f"Type: {e['error_type']}")

# ==============================
# FINAL INSIGHT
# ==============================
print("\n📌 FINAL INSIGHT:")
print("-"*30)

# Map the dominant error type to a one-line diagnosis.
if top_error[0] == "wrong_column":
    print("⚠️ Model struggles with column selection (schema understanding issue).")
elif top_error[0] == "wrong_table":
    print("⚠️ Model struggles with correct table mapping.")
elif top_error[0] == "syntax_error":
    print("⚠️ Model generates invalid SQL syntax.")
else:
    print("⚠️ Mixed errors — needs general improvement.")

print("\n" + "="*50)
print("✅ DASHBOARD COMPLETE")
print("="*50)
|
| 99 |
+
|
scripts/evaluate.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import sqlite3
|
| 5 |
+
from contextlib import closing
|
| 6 |
+
from typing import Dict, List
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from datasets import load_dataset
|
| 10 |
+
from peft import PeftModel
|
| 11 |
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
| 12 |
+
from trl import AutoModelForSeq2SeqLMWithValueHead
|
| 13 |
+
|
| 14 |
+
import sys
|
| 15 |
+
|
| 16 |
+
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 17 |
+
sys.path.append(PROJECT_ROOT)
|
| 18 |
+
from src.execution_reward import execution_reward # noqa: E402
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
BASE_MODEL = os.environ.get("BASE_MODEL", "t5-small")
|
| 22 |
+
DB_ROOT = os.path.join(PROJECT_ROOT, "data", "database")
|
| 23 |
+
|
| 24 |
+
# Prefer RL best model if present; otherwise fall back.
|
| 25 |
+
RL_DIR = os.path.join(PROJECT_ROOT, "outputs", "rlhf_text2sql", "best_model")
|
| 26 |
+
if not os.path.isdir(RL_DIR):
|
| 27 |
+
RL_DIR = os.path.join(PROJECT_ROOT, "outputs", "rlhf_text2sql")
|
| 28 |
+
|
| 29 |
+
SPLIT = "train[:100]" # quick sanity check
|
| 30 |
+
MAX_NEW_TOKENS = 128
|
| 31 |
+
|
| 32 |
+
PREFIX = "translate English to SQL:"
|
| 33 |
+
MAX_SCHEMA_CHARS = 1500
|
| 34 |
+
MAX_INPUT_TOKENS = 512
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
| 38 |
+
device = "mps" if torch.backends.mps.is_available() else "cpu"
|
| 39 |
+
print("Using device:", device)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def get_db_path(db_id: str) -> str:
    """Return the SQLite path data/database/<db_id>/<db_id>.sqlite under the project root."""
    return os.path.join(DB_ROOT, db_id, f"{db_id}.sqlite")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
_SCHEMA_CACHE: Dict[str, str] = {}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def get_db_schema_text(db_path: str) -> str:
    """Return a compact "table(col1, col2) " description of a SQLite schema.

    Results are memoized in the module-level _SCHEMA_CACHE; any sqlite failure
    yields "". The text is clamped to MAX_SCHEMA_CHARS before caching.
    """
    cached = _SCHEMA_CACHE.get(db_path)
    if cached is not None:
        return cached

    try:
        parts = []
        with closing(sqlite3.connect(db_path)) as conn:
            cur = conn.cursor()
            # Skip sqlite's internal bookkeeping tables.
            tables = cur.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
            ).fetchall()
            for (tname,) in tables:
                cols = cur.execute(f'PRAGMA table_info("{tname}")').fetchall()
                col_names = [c[1] for c in cols if c and isinstance(c[1], str)]
                parts.append(f"{tname}({', '.join(col_names)}) ")
        schema_text = "".join(parts)
    except Exception:
        schema_text = ""

    schema_text = schema_text[:MAX_SCHEMA_CHARS]
    _SCHEMA_CACHE[db_path] = schema_text
    return schema_text
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def encode_prompt(tokenizer, question: str, schema: str) -> torch.Tensor:
    """Tokenize "<PREFIX> Schema: <schema> Question: <question> SQL:" into ids.

    The schema segment is truncated first so the full prompt never exceeds
    MAX_INPUT_TOKENS (including one trailing EOS when the tokenizer has one).
    Returns a 1-D LongTensor on the module-level ``device``.
    """
    schema = (schema or "")[:MAX_SCHEMA_CHARS]
    prefix_schema = f"{PREFIX}\n\nSchema:\n"
    mid = "\n\nQuestion:\n"
    suffix = f"{question}\n\nSQL:"

    # Encode each segment separately so only the schema (and, as a last
    # resort, the question suffix) gets trimmed to fit the token budget.
    prefix_ids = tokenizer.encode(prefix_schema, add_special_tokens=False)
    schema_ids = tokenizer.encode(schema, add_special_tokens=False)
    mid_ids = tokenizer.encode(mid, add_special_tokens=False)
    suffix_ids = tokenizer.encode(suffix, add_special_tokens=False)

    # Reserve one slot for EOS when the tokenizer defines one.
    eos_id = tokenizer.eos_token_id
    max_without_eos = MAX_INPUT_TOKENS - (1 if eos_id is not None else 0)

    # If even the fixed parts (prefix + separators + question) overflow,
    # trim the question suffix; the schema will then receive zero tokens.
    fixed_len = len(prefix_ids) + len(mid_ids) + len(suffix_ids)
    if fixed_len > max_without_eos:
        keep = max(0, max_without_eos - (len(prefix_ids) + len(mid_ids)))
        suffix_ids = suffix_ids[:keep]
        fixed_len = len(prefix_ids) + len(mid_ids) + len(suffix_ids)

    # Whatever budget remains goes to the schema segment.
    remaining_for_schema = max_without_eos - fixed_len
    if remaining_for_schema < 0:
        remaining_for_schema = 0
    schema_ids = schema_ids[:remaining_for_schema]

    ids = (prefix_ids + schema_ids + mid_ids + suffix_ids)[:max_without_eos]
    if eos_id is not None:
        ids = ids + [eos_id]

    return torch.tensor(ids, dtype=torch.long).to(device)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def load_model_and_tokenizer():
    """Load the evaluation model, preferring the PPO checkpoint in RL_DIR.

    Fallback order:
      1. RL_DIR as a full TRL value-head checkpoint (tokenizer + model).
      2. RL_DIR as a LoRA adapter applied to BASE_MODEL.
      3. checkpoints/sft_adapter applied to BASE_MODEL.
    Returns a (tokenizer, model) pair with the model moved to ``device``.
    """
    # Try loading the PPO-saved value-head model directly.
    try:
        tok = AutoTokenizer.from_pretrained(RL_DIR)
        mdl = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(RL_DIR).to(device)
        return tok, mdl
    except Exception:
        # RL_DIR is not a full checkpoint; fall through to the adapter path.
        pass

    # Fallback: treat RL_DIR as a LoRA adapter directory.
    tok = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tok.pad_token_id is None:
        # Some base tokenizers ship without a pad token; reuse EOS.
        tok.pad_token = tok.eos_token
    base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
    try:
        base = PeftModel.from_pretrained(base, RL_DIR)
    except Exception:
        # Final fallback: use SFT adapter (if RL adapter not found)
        sft_dir = os.path.join(PROJECT_ROOT, "checkpoints", "sft_adapter")
        base = PeftModel.from_pretrained(base, sft_dir)
    return tok, base
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def main() -> None:
    """Evaluate the RL model on a Spider slice using execution-based metrics.

    Greedily decodes SQL for each example, scores it with execution_reward,
    and prints execution accuracy plus the valid-SQL rate.
    """
    tokenizer, model = load_model_and_tokenizer()
    model.eval()

    ds = load_dataset("spider", split=SPLIT)

    correct = 0  # reward >= 1.0: execution result matches gold
    valid = 0    # reward > -1.0: SQL at least executed without error — TODO confirm reward scale against src.execution_reward

    for i, ex in enumerate(ds, start=1):
        question = ex["question"]
        gold_sql = ex["query"]
        db_id = ex["db_id"]
        db_path = get_db_path(db_id)
        schema = get_db_schema_text(db_path)

        inp = encode_prompt(tokenizer, question, schema)
        with torch.no_grad():
            out = model.generate(
                input_ids=inp.unsqueeze(0),
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False,  # greedy decoding keeps evaluation deterministic
                num_beams=1,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        pred_sql = tokenizer.decode(out[0], skip_special_tokens=True)
        r = execution_reward(pred_sql, db_path, gold_sql)
        if r > -1.0:
            valid += 1
        if r >= 1.0:
            correct += 1

        if i % 25 == 0:
            print(f"Evaluated {i}/{len(ds)}")

    n = len(ds)
    print("\nRESULTS")
    print(f"examples: {n}")
    print(f"execution_accuracy: {correct/n:.3f}")
    print(f"valid_sql_rate: {valid/n:.3f}")


if __name__ == "__main__":
    main()
|
scripts/plot_task2.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Render the Task 2 error-diagnostic dashboard: two bar charts (error-type
distribution and SQL-clause frequency in failed queries) saved as a PNG.

The numbers below were copied manually from a terminal run of the error
analysis; they are not recomputed here.
"""
import matplotlib.pyplot as plt
import seaborn as sns

# ==========================================
# 1. EXTRACTED DATA FROM TERMINAL
# ==========================================
# Error Distribution Data (counts sum to 77, matching the plot title)
error_types = ['wrong_column', 'wrong_table', 'ambiguous_column', 'other']
error_counts = [61, 11, 4, 1]

# SQL Operation Analysis Data
sql_ops = ['WHERE', 'JOIN', 'ORDER BY', 'GROUP BY']
op_counts = [55, 36, 20, 14]

# ==========================================
# 2. SET UP THE DASHBOARD LAYOUT
# ==========================================
# Use a clean, modern aesthetic
sns.set_theme(style="whitegrid")
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

# ==========================================
# 3. PLOT 1: ERROR DISTRIBUTION (Horizontal Bar)
# ==========================================
# NOTE(review): passing palette= without hue= is deprecated in seaborn >= 0.13
# and emits a FutureWarning — works today, but worth confirming the pinned
# seaborn version.
sns.barplot(x=error_counts, y=error_types, ax=ax1, palette="flare")
ax1.set_title('Primary Cause of Failure (Total: 77 Errors)', fontsize=14, pad=15, fontweight='bold')
ax1.set_xlabel('Number of Queries')
ax1.set_ylabel('')

# Add actual numbers next to the bars
for i, v in enumerate(error_counts):
    ax1.text(v + 1.5, i, f"{v}", color='#333333', va='center', fontweight='bold')

# ==========================================
# 4. PLOT 2: SQL OPERATIONS (Vertical Bar)
# ==========================================
sns.barplot(x=sql_ops, y=op_counts, ax=ax2, palette="crest")
ax2.set_title('Clauses Present in Failed Queries', fontsize=14, pad=15, fontweight='bold')
ax2.set_ylabel('Frequency')
ax2.set_xlabel('')

# Add actual numbers on top of the bars
for i, v in enumerate(op_counts):
    ax2.text(i, v + 1, str(v), color='#333333', ha='center', fontweight='bold')

# ==========================================
# 5. RENDER AND SAVE
# ==========================================
plt.suptitle('Text-to-SQL Error Diagnostic Dashboard', fontsize=18, fontweight='heavy', y=1.05)
sns.despine(left=True, bottom=True)  # Removes clunky borders
plt.tight_layout()

# Save the plot as a high-res image for your report!
plt.savefig('error_diagnostic_plot.png', dpi=300, bbox_inches='tight')
print("✅ Plot successfully saved as 'error_diagnostic_plot.png'")

# Display the plot
plt.show()
|
scripts/plot_task3.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Task 3 figure: constraint-satisfaction percentage with vs. without
constrained decoding, saved to task3_constraint.png."""
import matplotlib.pyplot as plt

# Category labels and their satisfaction percentages (0% unconstrained,
# 88% with constrained decoding).
labels = ["Without", "With"]
constraint = [0, 88]

fig = plt.figure()
plt.bar(labels, constraint)
plt.ylabel("Percentage")
plt.title("Constraint Satisfaction (Task 3)")
plt.savefig("task3_constraint.png")
plt.show()
|
| 14 |
+
|
| 15 |
+
|
scripts/plot_task3_plotly.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Interactive Plotly dashboard comparing FP32 vs. INT8 quantized models
(execution accuracy and latency profile), written to an HTML file.

All metrics below are hard-coded results from the Task 5 benchmark runs.
"""
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# ==========================================
# 1. YOUR DATA
# ==========================================
models = ['FP32 (Base)', 'INT8 Dynamic', 'INT8 Decoder-Only']

# Accuracy (multiplied by 100 for percentage)
accuracy = [36.0, 36.0, 38.0]

# Latency metrics (seconds per query)
lat_mean = [3.11, 1.65, 1.66]
lat_p50 = [2.94, 1.54, 1.56]
lat_p90 = [4.64, 2.44, 2.48]

# ==========================================
# 2. SET UP THE SIDE-BY-SIDE LAYOUT
# ==========================================
fig = make_subplots(
    rows=1, cols=2,
    subplot_titles=(
        "<b>Model Accuracy (Execution)</b>",
        "<b>Inference Latency Profile</b>"
    ),
    horizontal_spacing=0.1
)

# ==========================================
# 3. LEFT CHART: ACCURACY
# ==========================================
fig.add_trace(go.Bar(
    x=models,
    y=accuracy,
    name="Execution Accuracy",
    marker_color=['#94a3b8', '#38bdf8', '#10b981'],  # Gray, Blue, Green
    text=[f"{val:.1f}%" for val in accuracy],
    textposition='auto',
    textfont=dict(size=14, color='white', family="Arial Black"),
    showlegend=False  # single series; the legend belongs to the latency chart
), row=1, col=1)

# ==========================================
# 4. RIGHT CHART: LATENCY PROFILE
# ==========================================
# P50 Latency
fig.add_trace(go.Bar(
    x=models, y=lat_p50,
    name="Median (P50)",
    marker_color="#ece80a"  # Yellow
), row=1, col=2)

# Mean Latency
fig.add_trace(go.Bar(
    x=models, y=lat_mean,
    name="Mean Latency",
    marker_color="#3b4da9"  # Indigo blue
), row=1, col=2)

# P90 Latency
fig.add_trace(go.Bar(
    x=models, y=lat_p90,
    name="90th Percentile (P90)",
    marker_color="#d974e2"  # Magenta/pink
), row=1, col=2)

# ==========================================
# 5. APPLY ULTRA-MODERN STYLING
# ==========================================
fig.update_layout(
    title=dict(
        text="<b>Task 5: FP32 vs. INT8 Quantization Performance</b>",
        font=dict(size=22, color='#1e293b'),
        x=0.5  # center the title over both subplots
    ),
    plot_bgcolor='white',
    paper_bgcolor='white',
    barmode='group',
    legend=dict(
        orientation="h",
        yanchor="bottom", y=1.05,
        xanchor="center", x=0.8,
        bgcolor='rgba(255,255,255,0.8)'
    ),
    font=dict(family="-apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif"),
    margin=dict(t=120, b=60, l=60, r=40)  # extra top margin for title + legend
)

# Style Left Axes
fig.update_yaxes(title_text="<b>Accuracy (%)</b>", range=[0, 45], gridcolor='#f1f5f9', row=1, col=1)
fig.update_xaxes(tickfont=dict(weight='bold'), row=1, col=1)

# Style Right Axes
fig.update_yaxes(title_text="<b>Seconds per Query</b>", gridcolor='#f1f5f9', row=1, col=2)
fig.update_xaxes(tickfont=dict(weight='bold'), row=1, col=2)

# ==========================================
# 6. RENDER AND SAVE
# ==========================================
html_file = "task5_quantization_dashboard.html"
fig.write_html(html_file)
print(f"✅ Interactive Plotly Dashboard saved to: {html_file}")
fig.show()
|
scripts/quantize_export.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from src.quantization_utils import (
|
| 10 |
+
load_bnb_quantized_model,
|
| 11 |
+
load_fp32_model,
|
| 12 |
+
quantize_dynamic_int8,
|
| 13 |
+
quantize_dynamic_int8_decoder_only,
|
| 14 |
+
save_quant_artifact,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def main() -> None:
    """CLI entry point: load the requested model variant (optionally with a
    LoRA adapter), quantize it per --mode, and save a portable artifact.

    Modes:
      fp32                 plain float32 model, saved as-is
      int8_dynamic         torch dynamic int8 over the whole model (CPU only)
      int8_decoder_dynamic torch dynamic int8 over the decoder only (CPU only)
      int8_bnb / int4_bnb  bitsandbytes load-time quantization (needs CUDA)
    """
    parser = argparse.ArgumentParser(description="Export quantized Seq2Seq model artifacts for CPU inference.")
    parser.add_argument("--base_model", default=os.environ.get("BASE_MODEL", "Salesforce/codet5-base"))
    parser.add_argument("--adapter", default="", help="Optional LoRA adapter directory.")
    parser.add_argument("--out_dir", required=True, help="Output directory for artifact.")
    parser.add_argument(
        "--mode",
        required=True,
        choices=["fp32", "int8_dynamic", "int8_decoder_dynamic", "int8_bnb", "int4_bnb"],
    )
    parser.add_argument("--device", default="cpu", help="cpu|cuda (bnb requires cuda)")
    parser.add_argument("--local_only", action="store_true", help="Do not hit network; use HF cache only.")
    args = parser.parse_args()

    # Empty string means "no adapter".
    adapter = args.adapter.strip() or None
    out_dir = Path(args.out_dir)
    mode = args.mode

    if mode == "fp32":
        tok, model = load_fp32_model(args.base_model, adapter_path=adapter, device=args.device, local_only=args.local_only)
    elif mode == "int8_dynamic":
        # Dynamic int8 quantization runs on CPU, so the model is loaded there
        # regardless of --device.
        tok, model = load_fp32_model(args.base_model, adapter_path=adapter, device="cpu", local_only=args.local_only)
        model = quantize_dynamic_int8(model)
    elif mode == "int8_decoder_dynamic":
        tok, model = load_fp32_model(args.base_model, adapter_path=adapter, device="cpu", local_only=args.local_only)
        model = quantize_dynamic_int8_decoder_only(model)
    elif mode == "int8_bnb":
        # Note: saving bnb quantized weights in a portable way is non-trivial;
        # we still save state_dict for reference.
        tok, model = load_bnb_quantized_model(
            args.base_model,
            adapter_path=adapter,
            device=args.device,
            local_only=args.local_only,
            load_in_8bit=True,
        )
    else:  # int4_bnb — argparse choices guarantees no other value reaches here
        tok, model = load_bnb_quantized_model(
            args.base_model,
            adapter_path=adapter,
            device=args.device,
            local_only=args.local_only,
            load_in_4bit=True,
        )

    save_quant_artifact(out_dir, mode=mode, base_model=args.base_model, adapter_path=adapter, tokenizer=tok, model=model)


if __name__ == "__main__":
    # Export is inference-only; disable autograd for the whole process.
    torch.set_grad_enabled(False)
    main()
|
| 86 |
+
|
scripts/quantized_infer_harness.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import time
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
from src.quantized_text2sql_engine import QuantizedText2SQLEngine
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def main() -> None:
    """Benchmark a quantized Text2SQL artifact on the first N dev questions
    and write timing + a result sample to a JSON report."""
    parser = argparse.ArgumentParser(description="Production-style inference harness for quantized artifacts.")
    parser.add_argument("--artifact", required=True, help="Quant artifact dir from scripts/quantize_export.py")
    parser.add_argument("--num_samples", type=int, default=128)
    parser.add_argument("--out", default="results/task5_quant_infer.json")
    args = parser.parse_args()

    # The workload is the first --num_samples Spider dev examples.
    dev_examples = json.loads((Path(".") / "data" / "dev.json").read_text())[: args.num_samples]
    workload = [(ex["question"], ex["db_id"]) for ex in dev_examples]

    engine = QuantizedText2SQLEngine(args.artifact, device="cpu")

    start = time.perf_counter()
    results = engine.ask_batch_execute(workload)
    elapsed = time.perf_counter() - start

    report = {
        "n": len(results),
        "seconds": elapsed,
        "qps": len(results) / max(elapsed, 1e-9),  # avoid division by zero
        "artifact": args.artifact,
        "meta": engine.meta,
        "results": results[:10],  # sample
    }

    destination = Path(args.out)
    destination.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(report, indent=2)
    destination.write_text(serialized)
    print(serialized)


if __name__ == "__main__":
    main()
|
| 46 |
+
|
src/__pycache__/execution_reward.cpython-310.pyc
ADDED
|
Binary file (11.2 kB). View file
|
|
|
src/__pycache__/quantization_utils.cpython-310.pyc
ADDED
|
Binary file (6.23 kB). View file
|
|
|
src/__pycache__/quantized_text2sql_engine.cpython-310.pyc
ADDED
|
Binary file (8.95 kB). View file
|
|
|
src/__pycache__/schema_encoder.cpython-310.pyc
ADDED
|
Binary file (1.75 kB). View file
|
|
|
src/__pycache__/schema_utils.cpython-310.pyc
ADDED
|
Binary file (3.64 kB). View file
|
|
|
src/__pycache__/sql_validator.cpython-310.pyc
ADDED
|
Binary file (5.29 kB). View file
|
|
|
src/__pycache__/text2sql_engine.cpython-310.pyc
ADDED
|
Binary file (8.34 kB). View file
|
|
|
src/ask.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TERMINAL CHAT WITH DATABASE
|
| 3 |
+
Run:
|
| 4 |
+
python src/ask.py chinook_1
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import sys
|
| 8 |
+
from text2sql_engine import get_engine
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# -------------------------------
|
| 12 |
+
# Pretty table printer
|
| 13 |
+
# -------------------------------
|
| 14 |
+
def print_table(cols, rows, limit=20):
    """Pretty-print query results as a left-justified ASCII table.

    Shows at most *limit* rows and a trailing note when rows were truncated.
    Prints "No results" when either cols or rows is empty.
    """
    if not rows or not cols:
        print("No results\n")
        return

    headers = [str(c) for c in cols]
    visible = rows[:limit]

    # Each column is at least 12 chars wide, or wider if a header/cell needs it.
    widths = [max(len(h), 12) for h in headers]
    for row in visible:
        for i, cell in enumerate(row):
            widths[i] = max(widths[i], len(str(cell)))

    header_line = " | ".join(h.ljust(w) for h, w in zip(headers, widths))
    print("\n" + header_line)
    print("-" * len(header_line))

    for row in visible:
        print(" | ".join(str(row[i]).ljust(widths[i]) for i in range(len(headers))))

    if len(rows) > limit:
        print(f"\n... showing first {limit} rows of {len(rows)}")

    print()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# -------------------------------
|
| 41 |
+
# Main loop
|
| 42 |
+
# -------------------------------
|
| 43 |
+
def main():
    """Interactive terminal loop: take a db_id from argv, then repeatedly
    prompt for a natural-language question, translate it to SQL via the
    engine, execute it, and print the results as a table."""
    # Require the database id as the single positional argument.
    if len(sys.argv) < 2:
        print("Usage: python src/ask.py <db_id>")
        return

    db_id = sys.argv[1].strip()

    print("Loading model... (first time takes 20-40s)")
    # get_engine() loads/caches the text2sql model (see text2sql_engine).
    engine = get_engine()

    print(f"\nConnected to database: {db_id}")
    print("Type 'exit' to quit\n")

    while True:
        try:
            q = input("Ask> ").strip()

            # Ignore empty input; 'exit'/'quit' ends the session.
            if not q:
                continue

            if q.lower() in ["exit", "quit"]:
                break

            result = engine.ask(q, db_id)

            if result is None:
                print("Model returned no output\n")
                continue

            print("\nGenerated SQL:")
            print(result.get("sql", "<no sql>"))

            # A non-empty "error" entry means the generated SQL failed to
            # execute; otherwise print the result set.
            if result.get("error"):
                print("\nSQL Error:")
                print(result["error"])
            else:
                print_table(
                    result.get("columns", []),
                    result.get("rows", []),
                )

        except KeyboardInterrupt:
            # Ctrl-C exits the loop cleanly rather than crashing.
            break
        except Exception as e:
            # Keep the REPL alive on unexpected errors; report and re-prompt.
            print("\nRuntime error:", e, "\n")

    print("\nBye!")


if __name__ == "__main__":
    main()
|
src/component_analysis.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import sqlite3
|
| 3 |
+
import torch
|
| 4 |
+
import re
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import numpy as np
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 9 |
+
from peft import PeftModel
|
| 10 |
+
|
| 11 |
+
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 12 |
+
DB_ROOT = PROJECT_ROOT / "data" / "database"
|
| 13 |
+
|
| 14 |
+
# -------------------------------
|
| 15 |
+
# Extract SQL components
|
| 16 |
+
# -------------------------------
|
| 17 |
+
def extract_components(sql):
    """Flag which SQL clause kinds appear in *sql*.

    Case-insensitive substring tests; returns a dict with boolean values for
    the keys select/where/group/order/and_or/join.
    """
    lowered = sql.lower()
    has = lowered.__contains__
    return {
        "select": has("select"),
        "where": has("where"),
        "group": has("group by"),
        "order": has("order by"),
        "and_or": has(" and ") or has(" or "),
        "join": has("join"),
    }
|
| 27 |
+
|
| 28 |
+
# -------------------------------
|
| 29 |
+
# Fallback Difficulty Estimator
|
| 30 |
+
# -------------------------------
|
| 31 |
+
def estimate_difficulty(sql):
    """Heuristically bucket a SQL query into easy/medium/hard/extra.

    Fallback used when 'difficulty' is missing from the dataset JSON.
    Buckets by set operations, join count, grouping/ordering, and the number
    of boolean connectives.
    """
    sql = sql.lower()
    joins = sql.count("join")
    # Count AND/OR as whole words. A plain substring count would also match
    # "or" inside "order by" (and "and"/"or" inside identifiers), inflating
    # the condition count and e.g. misclassifying a simple
    # GROUP BY ... ORDER BY query as "hard".
    conditions = len(re.findall(r"\band\b", sql)) + len(re.findall(r"\bor\b", sql))

    if "intersect" in sql or "except" in sql or "union" in sql or joins > 2:
        return "extra"
    elif joins == 2 or ("group by" in sql and conditions > 0):
        return "hard"
    elif joins == 1 or "group by" in sql or "order by" in sql:
        return "medium"
    else:
        return "easy"
|
| 45 |
+
|
| 46 |
+
# -------------------------------
|
| 47 |
+
# Load schema
|
| 48 |
+
# -------------------------------
|
| 49 |
+
def load_schema(db_path):
    """Return a compact textual schema for a SQLite database.

    One line per table, formatted as ``table(col1, col2, ...)``.

    Fixes over the original:
    - the table name is quoted in the PRAGMA statement, so names containing
      spaces or reserved words no longer raise OperationalError;
    - the connection is closed even if a query fails (try/finally).
    """
    conn = sqlite3.connect(db_path)
    # Some databases contain non-UTF-8 text; decode leniently instead of raising.
    conn.text_factory = lambda b: b.decode(errors='ignore')
    try:
        cursor = conn.cursor()

        tables = cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()

        schema = ""
        for (table,) in tables:
            # Quote the identifier so unusual table names don't break the PRAGMA.
            cols = cursor.execute(f'PRAGMA table_info("{table}");').fetchall()
            col_names = [c[1] for c in cols]  # PRAGMA row: (cid, name, type, ...)
            schema += f"{table}({', '.join(col_names)})\n"
    finally:
        conn.close()
    return schema
|
| 66 |
+
|
| 67 |
+
# -------------------------------
|
| 68 |
+
# Prompt
|
| 69 |
+
# -------------------------------
|
| 70 |
+
def build_prompt(question, schema):
    """Assemble the text-to-SQL prompt: schema block, instruction, question,
    and a trailing ``SQL:`` cue for the model to complete."""
    template = (
        "Database Schema:\n"
        "{schema}\n"
        "\n"
        "Translate English to SQL:\n"
        "{question}\n"
        "SQL:\n"
    )
    return template.format(schema=schema, question=question)
|
| 78 |
+
|
| 79 |
+
# -------------------------------
|
| 80 |
+
# Main
|
| 81 |
+
# -------------------------------
|
| 82 |
+
def main():
    """Evaluate the RL-finetuned model on Spider dev, broken down by SQL
    component and query difficulty, then save a grouped bar chart and print
    overall exact-match accuracy per difficulty bucket."""
    adapter = "checkpoints/rl_step_1800"
    base_model = "Salesforce/codet5-base"

    # Prefer Apple MPS, then CUDA, then CPU.
    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")

    print("Loading tokenizer and models...")
    tokenizer = AutoTokenizer.from_pretrained(adapter)
    base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
    model = PeftModel.from_pretrained(base, adapter).to(device)
    # Fold the LoRA weights into the base model for faster inference.
    model = model.merge_and_unload()
    model.eval()

    dev_json = PROJECT_ROOT / "data" / "dev.json"

    with open(dev_json) as f:
        dev = json.load(f)[:1000]  # Adjust number to test more/less

    components_list = ["select", "where", "group", "order", "and_or", "join"]
    difficulties_list = ["easy", "medium", "hard", "extra"]

    # Nested dictionary for components
    # stats[component][difficulty] -> {"correct": int, "total": int}
    stats = {
        comp: {diff: {"correct": 0, "total": 0} for diff in difficulties_list}
        for comp in components_list
    }

    # 🚀 NEW: Trackers for OVERALL accuracy by difficulty
    overall_correct = {diff: 0 for diff in difficulties_list}
    overall_total = {diff: 0 for diff in difficulties_list}

    print(f"\nRunning grouped evaluation on {len(dev)} examples...\n")

    for i, ex in enumerate(dev, 1):
        question = ex["question"]
        gold_sql = ex["query"]
        db_id = ex["db_id"]

        # Determine difficulty
        # Use the dataset label when present, otherwise the heuristic fallback.
        difficulty = ex.get("difficulty", estimate_difficulty(gold_sql))
        if difficulty not in difficulties_list:
            difficulty = "medium"

        db_path = DB_ROOT / db_id / f"{db_id}.sqlite"
        schema = load_schema(db_path)
        prompt = build_prompt(question, schema)

        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        # Deterministic beam search decoding.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1000,
                num_beams=4,
                do_sample=False
            )

        pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # If the model echoed the prompt, keep only the text after "SQL:".
        if "SQL:" in pred_sql:
            pred_sql = pred_sql.split("SQL:")[-1]

        # --- 1. Update Overall Accuracy Trackers ---
        overall_total[difficulty] += 1
        # Simple string match for quick overall accuracy
        if pred_sql.strip().lower() == gold_sql.strip().lower():
            overall_correct[difficulty] += 1

        # --- 2. Update Component Stats ---
        pred_comp = extract_components(pred_sql)
        gold_comp = extract_components(gold_sql)

        for comp in components_list:
            if gold_comp[comp]:  # If the gold SQL required this component
                stats[comp][difficulty]["total"] += 1
                if pred_comp[comp]:  # If the model successfully generated it
                    stats[comp][difficulty]["correct"] += 1

        if i % 20 == 0:
            print(f"Processed {i}/{len(dev)}")

    # -------------------------------
    # Plotting (Grouped Bar Chart)
    # -------------------------------
    x = np.arange(len(components_list))
    width = 0.2  # four bars per component group

    def get_acc(diff):
        """Per-component accuracy (%) for one difficulty; 0 when no samples."""
        return [
            (stats[comp][diff]["correct"] / stats[comp][diff]["total"] * 100) if stats[comp][diff]["total"] > 0 else 0
            for comp in components_list
        ]

    acc_easy = get_acc("easy")
    acc_medium = get_acc("medium")
    acc_hard = get_acc("hard")
    acc_extra = get_acc("extra")

    fig, ax = plt.subplots(figsize=(14, 7))

    # Four side-by-side bars per component, offset around each tick.
    bars1 = ax.bar(x - 1.5 * width, acc_easy, width, label='Easy', color='#2ecc71')
    bars2 = ax.bar(x - 0.5 * width, acc_medium, width, label='Medium', color='#f1c40f')
    bars3 = ax.bar(x + 0.5 * width, acc_hard, width, label='Hard', color='#e67e22')
    bars4 = ax.bar(x + 1.5 * width, acc_extra, width, label='Extra', color='#e74c3c')

    ax.set_ylabel('Accuracy (%)', fontsize=12)
    ax.set_title('SQL Component Match Accuracy by Difficulty Level', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels([c.upper() for c in components_list], fontsize=11)
    ax.legend(title="Query Difficulty")
    ax.set_ylim(0, 115)  # headroom above 100% for the rotated value labels

    def autolabel(rects):
        """Annotate each non-zero bar with its percentage, rotated vertically."""
        for rect in rects:
            height = rect.get_height()
            if height > 0:
                ax.annotate(f'{int(height)}%',
                            xy=(rect.get_x() + rect.get_width() / 2, height),
                            xytext=(0, 3),
                            textcoords="offset points",
                            ha='center', va='bottom', fontsize=8, rotation=90)

    autolabel(bars1)
    autolabel(bars2)
    autolabel(bars3)
    autolabel(bars4)

    ax.yaxis.grid(True, linestyle='--', alpha=0.7)
    plt.tight_layout()
    plt.savefig("component_by_difficulty_plot.png", dpi=300)

    # -------------------------------
    # 🚀 Terminal Printout
    # -------------------------------
    print("\n✅ Saved merged plot -> component_by_difficulty_plot.png")

    print("\n========================================")
    print("🏆 OVERALL AVERAGE ACCURACY BY DIFFICULTY")
    print("========================================")
    for diff in difficulties_list:
        if overall_total[diff] > 0:
            avg = round((overall_correct[diff] / overall_total[diff]) * 100, 2)
            print(f"{diff.capitalize():<8}: {avg:>5}% ({overall_correct[diff]}/{overall_total[diff]} queries)")
        else:
            print(f"{diff.capitalize():<8}: N/A (0 queries)")
    print("========================================\n")


if __name__ == "__main__":
    main()
|
src/constrained_decoding.py
ADDED
|
@@ -0,0 +1,1058 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
# import re
|
| 4 |
+
# import threading
|
| 5 |
+
# from dataclasses import dataclass
|
| 6 |
+
# from typing import Dict, Iterable, List, Optional, Sequence, Set
|
| 7 |
+
|
| 8 |
+
# import torch
|
| 9 |
+
# from transformers.generation.logits_process import LogitsProcessor
|
| 10 |
+
|
| 11 |
+
# from schema_constraints import ConstraintGraph, build_constraint_graph
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
|
| 15 |
+
# s = re.sub(r"\s+", " ", prefix_text.lower())
|
| 16 |
+
# last_from = s.rfind(" from ")
|
| 17 |
+
# last_join = s.rfind(" join ")
|
| 18 |
+
# last_select = s.rfind(" select ")
|
| 19 |
+
# last_where = s.rfind(" where ")
|
| 20 |
+
# last_on = s.rfind(" on ")
|
| 21 |
+
# last_group = s.rfind(" group by ")
|
| 22 |
+
# last_order = s.rfind(" order by ")
|
| 23 |
+
# last_having = s.rfind(" having ")
|
| 24 |
+
|
| 25 |
+
# last_table_kw = max(last_from, last_join)
|
| 26 |
+
# last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
|
| 27 |
+
|
| 28 |
+
# if last_table_kw < 0 and last_col_kw < 0:
|
| 29 |
+
# return None
|
| 30 |
+
# if last_table_kw > last_col_kw:
|
| 31 |
+
# return "table"
|
| 32 |
+
# if last_col_kw > last_table_kw:
|
| 33 |
+
# return "column"
|
| 34 |
+
# return None
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# class _TrieNode:
|
| 38 |
+
# __slots__ = ("children", "terminal")
|
| 39 |
+
|
| 40 |
+
# def __init__(self) -> None:
|
| 41 |
+
# self.children: Dict[int, _TrieNode] = {}
|
| 42 |
+
# self.terminal: bool = False
|
| 43 |
+
|
| 44 |
+
# def insert(self, token_ids: Sequence[int]) -> None:
|
| 45 |
+
# node: _TrieNode = self
|
| 46 |
+
# for tid in token_ids:
|
| 47 |
+
# tid_i = int(tid)
|
| 48 |
+
# nxt = node.children.get(tid_i)
|
| 49 |
+
# if nxt is None:
|
| 50 |
+
# nxt = _TrieNode()
|
| 51 |
+
# node.children[tid_i] = nxt
|
| 52 |
+
# node = nxt
|
| 53 |
+
# node.terminal = True
|
| 54 |
+
|
| 55 |
+
# def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
|
| 56 |
+
# node: _TrieNode = self
|
| 57 |
+
# for tid in prefix:
|
| 58 |
+
# node = node.children.get(int(tid)) # type: ignore[assignment]
|
| 59 |
+
# if node is None:
|
| 60 |
+
# return None
|
| 61 |
+
# return node
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# def _encode_identifier(tokenizer, name: str) -> List[int]:
|
| 65 |
+
# # Leading space encourages word-start markers (e.g. "Ġ" in RoBERTa BPE).
|
| 66 |
+
# return tokenizer.encode(" " + name, add_special_tokens=False)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
|
| 70 |
+
# trie = _TrieNode()
|
| 71 |
+
# for n in names:
|
| 72 |
+
# if not n:
|
| 73 |
+
# continue
|
| 74 |
+
# try:
|
| 75 |
+
# ids = _encode_identifier(tokenizer, n)
|
| 76 |
+
# except Exception:
|
| 77 |
+
# continue
|
| 78 |
+
# if ids:
|
| 79 |
+
# trie.insert(ids)
|
| 80 |
+
# return trie
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# def _allow_always_token_ids(tokenizer) -> torch.Tensor:
|
| 84 |
+
# # Allow common delimiters so the model can end an identifier.
|
| 85 |
+
# toks = [",", ")", "(", "\n", ".", ";"]
|
| 86 |
+
# ids: Set[int] = set()
|
| 87 |
+
# for t in toks:
|
| 88 |
+
# try:
|
| 89 |
+
# for tid in tokenizer.encode(t, add_special_tokens=False):
|
| 90 |
+
# ids.add(int(tid))
|
| 91 |
+
# except Exception:
|
| 92 |
+
# continue
|
| 93 |
+
# return torch.tensor(sorted(ids), dtype=torch.long)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# @dataclass
|
| 97 |
+
# class _PerDbTokenSets:
|
| 98 |
+
# fp: str
|
| 99 |
+
# table_trie: _TrieNode
|
| 100 |
+
# column_trie: _TrieNode
|
| 101 |
+
# allow_always: torch.Tensor
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# _DB_TOKENSET_LOCK = threading.Lock()
|
| 105 |
+
# _DB_TOKENSETS: Dict[str, _PerDbTokenSets] = {}
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
|
| 109 |
+
# with _DB_TOKENSET_LOCK:
|
| 110 |
+
# cached = _DB_TOKENSETS.get(graph.db_path)
|
| 111 |
+
# if cached is not None and cached.fp == graph.fingerprint:
|
| 112 |
+
# return cached
|
| 113 |
+
|
| 114 |
+
# out = _PerDbTokenSets(
|
| 115 |
+
# fp=graph.fingerprint,
|
| 116 |
+
# table_trie=_build_trie(tokenizer, graph.tables),
|
| 117 |
+
# column_trie=_build_trie(tokenizer, graph.all_columns),
|
| 118 |
+
# allow_always=_allow_always_token_ids(tokenizer),
|
| 119 |
+
# )
|
| 120 |
+
# with _DB_TOKENSET_LOCK:
|
| 121 |
+
# _DB_TOKENSETS[graph.db_path] = out
|
| 122 |
+
# return out
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
|
| 126 |
+
# """
|
| 127 |
+
# Schema-aware constrained decoding per item in the generation batch.
|
| 128 |
+
# Uses a tokenizer-based trie so multi-token identifiers can be constrained.
|
| 129 |
+
# """
|
| 130 |
+
|
| 131 |
+
# def __init__(self, tokenizer, db_paths: Sequence[str], *, max_prefix_tokens: int = 48):
|
| 132 |
+
# self.tokenizer = tokenizer
|
| 133 |
+
# self.db_paths = list(db_paths)
|
| 134 |
+
# self.max_prefix_tokens = int(max_prefix_tokens)
|
| 135 |
+
|
| 136 |
+
# self._graphs = [build_constraint_graph(p) for p in self.db_paths]
|
| 137 |
+
# self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
|
| 138 |
+
|
| 139 |
+
# def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
| 140 |
+
# if input_ids.dim() != 2 or scores.dim() != 2:
|
| 141 |
+
# return scores
|
| 142 |
+
|
| 143 |
+
# batch = input_ids.size(0)
|
| 144 |
+
# if batch != len(self._graphs):
|
| 145 |
+
# return scores
|
| 146 |
+
|
| 147 |
+
# for i in range(batch):
|
| 148 |
+
# tail_ids = input_ids[i, -self.max_prefix_tokens :].tolist()
|
| 149 |
+
# prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
|
| 150 |
+
# expected = _infer_expected_identifier(prefix_text)
|
| 151 |
+
# if expected is None:
|
| 152 |
+
# continue
|
| 153 |
+
|
| 154 |
+
# if expected == "table":
|
| 155 |
+
# m = re.search(r"(?:from|join)\s+([A-Za-z_][A-Za-z0-9_]*)$", prefix_text, flags=re.I)
|
| 156 |
+
# partial = m.group(1) if m else None
|
| 157 |
+
# if partial is None and not re.search(r"(?:from|join)\s*$", prefix_text, flags=re.I):
|
| 158 |
+
# continue
|
| 159 |
+
# trie = self._token_sets[i].table_trie
|
| 160 |
+
# else:
|
| 161 |
+
# m = re.search(
|
| 162 |
+
# r"(?:select|where|on|group by|order by|having)\s+([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?)$",
|
| 163 |
+
# prefix_text,
|
| 164 |
+
# flags=re.I,
|
| 165 |
+
# )
|
| 166 |
+
# partial = m.group(1) if m else None
|
| 167 |
+
# if partial is None and not re.search(
|
| 168 |
+
# r"(?:select|where|on|group by|order by|having)\s*$", prefix_text, flags=re.I
|
| 169 |
+
# ):
|
| 170 |
+
# continue
|
| 171 |
+
# trie = self._token_sets[i].column_trie
|
| 172 |
+
|
| 173 |
+
# if not partial:
|
| 174 |
+
# prefix_token_ids: List[int] = []
|
| 175 |
+
# else:
|
| 176 |
+
# try:
|
| 177 |
+
# prefix_token_ids = _encode_identifier(self.tokenizer, partial)
|
| 178 |
+
# except Exception:
|
| 179 |
+
# continue
|
| 180 |
+
|
| 181 |
+
# node = trie.walk(prefix_token_ids)
|
| 182 |
+
# if node is None or node.terminal:
|
| 183 |
+
# continue
|
| 184 |
+
|
| 185 |
+
# allowed_next = sorted(node.children.keys())
|
| 186 |
+
# if not allowed_next:
|
| 187 |
+
# continue
|
| 188 |
+
|
| 189 |
+
# allowed_next_t = torch.tensor(allowed_next, dtype=torch.long, device=scores.device)
|
| 190 |
+
# allow_always = self._token_sets[i].allow_always.to(scores.device)
|
| 191 |
+
# keep = torch.cat([allowed_next_t, allow_always]) if allow_always.numel() else allowed_next_t
|
| 192 |
+
|
| 193 |
+
# kept_scores = scores[i, keep].clone()
|
| 194 |
+
# scores[i, :] = -float("inf")
|
| 195 |
+
# scores[i, keep] = kept_scores
|
| 196 |
+
|
| 197 |
+
# return scores
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
# # Backwards-compatible names used elsewhere in the repo.
|
| 201 |
+
# class SchemaConstraintGraph:
|
| 202 |
+
# def __init__(self, db_path: str):
|
| 203 |
+
# self._graph = build_constraint_graph(db_path)
|
| 204 |
+
# self.tables = sorted(self._graph.tables)
|
| 205 |
+
# self.columns = sorted(self._graph.all_columns)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
# class SchemaConstrainedLogitsProcessor(LogitsProcessor):
|
| 209 |
+
# def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
|
| 210 |
+
# self._proc = BatchSchemaConstrainedLogitsProcessor(tokenizer, [schema_graph._graph.db_path])
|
| 211 |
+
|
| 212 |
+
# def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
| 213 |
+
# return self._proc(input_ids, scores)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
# from __future__ import annotations
|
| 219 |
+
|
| 220 |
+
# import re
|
| 221 |
+
# import threading
|
| 222 |
+
# from dataclasses import dataclass
|
| 223 |
+
# from typing import Dict, Iterable, List, Optional, Sequence, Set
|
| 224 |
+
|
| 225 |
+
# import torch
|
| 226 |
+
# from transformers.generation.logits_process import LogitsProcessor
|
| 227 |
+
|
| 228 |
+
# from schema_constraints import ConstraintGraph, build_constraint_graph
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
# # =========================================================
|
| 232 |
+
# # 🔍 IDENTIFIER TYPE DETECTION
|
| 233 |
+
# # =========================================================
|
| 234 |
+
# def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
|
| 235 |
+
# s = re.sub(r"\s+", " ", prefix_text.lower())
|
| 236 |
+
|
| 237 |
+
# last_from = s.rfind(" from ")
|
| 238 |
+
# last_join = s.rfind(" join ")
|
| 239 |
+
# last_select = s.rfind(" select ")
|
| 240 |
+
# last_where = s.rfind(" where ")
|
| 241 |
+
# last_on = s.rfind(" on ")
|
| 242 |
+
# last_group = s.rfind(" group by ")
|
| 243 |
+
# last_order = s.rfind(" order by ")
|
| 244 |
+
# last_having = s.rfind(" having ")
|
| 245 |
+
|
| 246 |
+
# last_table_kw = max(last_from, last_join)
|
| 247 |
+
# last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
|
| 248 |
+
|
| 249 |
+
# if last_table_kw < 0 and last_col_kw < 0:
|
| 250 |
+
# return None
|
| 251 |
+
# if last_table_kw > last_col_kw:
|
| 252 |
+
# return "table"
|
| 253 |
+
# if last_col_kw > last_table_kw:
|
| 254 |
+
# return "column"
|
| 255 |
+
# return None
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
# # =========================================================
|
| 259 |
+
# # 🌳 TRIE STRUCTURE
|
| 260 |
+
# # =========================================================
|
| 261 |
+
# class _TrieNode:
|
| 262 |
+
# __slots__ = ("children", "terminal")
|
| 263 |
+
|
| 264 |
+
# def __init__(self) -> None:
|
| 265 |
+
# self.children: Dict[int, _TrieNode] = {}
|
| 266 |
+
# self.terminal: bool = False
|
| 267 |
+
|
| 268 |
+
# def insert(self, token_ids: Sequence[int]) -> None:
|
| 269 |
+
# node = self
|
| 270 |
+
# for tid in token_ids:
|
| 271 |
+
# tid = int(tid)
|
| 272 |
+
# if tid not in node.children:
|
| 273 |
+
# node.children[tid] = _TrieNode()
|
| 274 |
+
# node = node.children[tid]
|
| 275 |
+
# node.terminal = True
|
| 276 |
+
|
| 277 |
+
# def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
|
| 278 |
+
# node = self
|
| 279 |
+
# for tid in prefix:
|
| 280 |
+
# node = node.children.get(int(tid))
|
| 281 |
+
# if node is None:
|
| 282 |
+
# return None
|
| 283 |
+
# return node
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
# # =========================================================
|
| 287 |
+
# # 🔤 TOKEN ENCODING
|
| 288 |
+
# # =========================================================
|
| 289 |
+
# def _encode_identifier(tokenizer, name: str) -> List[int]:
|
| 290 |
+
# return tokenizer.encode(" " + name, add_special_tokens=False)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
# def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
|
| 294 |
+
# trie = _TrieNode()
|
| 295 |
+
# for name in names:
|
| 296 |
+
# try:
|
| 297 |
+
# ids = _encode_identifier(tokenizer, name)
|
| 298 |
+
# if ids:
|
| 299 |
+
# trie.insert(ids)
|
| 300 |
+
# except Exception:
|
| 301 |
+
# continue
|
| 302 |
+
# return trie
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
# def _allow_always_token_ids(tokenizer) -> torch.Tensor:
|
| 306 |
+
# tokens = [",", ")", "(", ".", ";", "\n"]
|
| 307 |
+
# ids: Set[int] = set()
|
| 308 |
+
|
| 309 |
+
# for t in tokens:
|
| 310 |
+
# try:
|
| 311 |
+
# ids.update(tokenizer.encode(t, add_special_tokens=False))
|
| 312 |
+
# except:
|
| 313 |
+
# pass
|
| 314 |
+
|
| 315 |
+
# return torch.tensor(sorted(ids), dtype=torch.long)
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
# # =========================================================
|
| 319 |
+
# # 📦 PER-DB CACHE
|
| 320 |
+
# # =========================================================
|
| 321 |
+
# @dataclass
|
| 322 |
+
# class _PerDbTokenSets:
|
| 323 |
+
# fp: str
|
| 324 |
+
# table_trie: _TrieNode
|
| 325 |
+
# column_trie: _TrieNode
|
| 326 |
+
# allow_always: torch.Tensor
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
# _DB_CACHE: Dict[str, _PerDbTokenSets] = {}
|
| 330 |
+
# _DB_LOCK = threading.Lock()
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
# def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
|
| 334 |
+
# with _DB_LOCK:
|
| 335 |
+
# cached = _DB_CACHE.get(graph.db_path)
|
| 336 |
+
# if cached and cached.fp == graph.fingerprint:
|
| 337 |
+
# return cached
|
| 338 |
+
|
| 339 |
+
# obj = _PerDbTokenSets(
|
| 340 |
+
# fp=graph.fingerprint,
|
| 341 |
+
# table_trie=_build_trie(tokenizer, graph.tables),
|
| 342 |
+
# column_trie=_build_trie(tokenizer, graph.all_columns),
|
| 343 |
+
# allow_always=_allow_always_token_ids(tokenizer),
|
| 344 |
+
# )
|
| 345 |
+
|
| 346 |
+
# with _DB_LOCK:
|
| 347 |
+
# _DB_CACHE[graph.db_path] = obj
|
| 348 |
+
|
| 349 |
+
# return obj
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
# # =========================================================
|
| 353 |
+
# # 🚀 MAIN LOGITS PROCESSOR
|
| 354 |
+
# # =========================================================
|
| 355 |
+
# class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
|
| 356 |
+
# def __init__(self, tokenizer, db_paths: Sequence[str], max_prefix_tokens: int = 48):
|
| 357 |
+
# self.tokenizer = tokenizer
|
| 358 |
+
# self.db_paths = list(db_paths)
|
| 359 |
+
# self.max_prefix_tokens = max_prefix_tokens
|
| 360 |
+
|
| 361 |
+
# self._graphs = [build_constraint_graph(p) for p in db_paths]
|
| 362 |
+
# self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
|
| 363 |
+
|
| 364 |
+
# # 📊 Metrics (IMPORTANT FOR REPORT)
|
| 365 |
+
# self.total_steps = 0
|
| 366 |
+
# self.constrained_steps = 0
|
| 367 |
+
|
| 368 |
+
# def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
| 369 |
+
# batch = input_ids.size(0)
|
| 370 |
+
|
| 371 |
+
# for i in range(batch):
|
| 372 |
+
# self.total_steps += 1
|
| 373 |
+
|
| 374 |
+
# tail_ids = input_ids[i, -self.max_prefix_tokens:].tolist()
|
| 375 |
+
# prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
|
| 376 |
+
|
| 377 |
+
# expected = _infer_expected_identifier(prefix_text)
|
| 378 |
+
# if expected is None:
|
| 379 |
+
# continue
|
| 380 |
+
|
| 381 |
+
# self.constrained_steps += 1
|
| 382 |
+
|
| 383 |
+
# # =========================
|
| 384 |
+
# # SELECT TRIE
|
| 385 |
+
# # =========================
|
| 386 |
+
# if expected == "table":
|
| 387 |
+
# trie = self._token_sets[i].table_trie
|
| 388 |
+
# else:
|
| 389 |
+
# trie = self._token_sets[i].column_trie
|
| 390 |
+
|
| 391 |
+
# # =========================
|
| 392 |
+
# # PARTIAL TOKEN MATCH
|
| 393 |
+
# # =========================
|
| 394 |
+
# match = re.search(r"([A-Za-z_][A-Za-z0-9_]*)$", prefix_text)
|
| 395 |
+
# partial = match.group(1) if match else ""
|
| 396 |
+
|
| 397 |
+
# try:
|
| 398 |
+
# prefix_ids = _encode_identifier(self.tokenizer, partial) if partial else []
|
| 399 |
+
# except:
|
| 400 |
+
# continue
|
| 401 |
+
|
| 402 |
+
# node = trie.walk(prefix_ids)
|
| 403 |
+
# if node is None or node.terminal:
|
| 404 |
+
# continue
|
| 405 |
+
|
| 406 |
+
# allowed_next = list(node.children.keys())
|
| 407 |
+
# if not allowed_next:
|
| 408 |
+
# continue
|
| 409 |
+
|
| 410 |
+
# allowed_next = torch.tensor(allowed_next, device=scores.device)
|
| 411 |
+
# allow_always = self._token_sets[i].allow_always.to(scores.device)
|
| 412 |
+
|
| 413 |
+
# keep = torch.cat([allowed_next, allow_always])
|
| 414 |
+
|
| 415 |
+
# kept_scores = scores[i, keep].clone()
|
| 416 |
+
# scores[i, :] = -float("inf")
|
| 417 |
+
# scores[i, keep] = kept_scores
|
| 418 |
+
|
| 419 |
+
# return scores
|
| 420 |
+
|
| 421 |
+
# # =========================================================
|
| 422 |
+
# # 📊 METRICS FOR REPORT
|
| 423 |
+
# # =========================================================
|
| 424 |
+
# def get_constraint_stats(self):
|
| 425 |
+
# if self.total_steps == 0:
|
| 426 |
+
# return 0
|
| 427 |
+
# return self.constrained_steps / self.total_steps
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
# # =========================================================
|
| 431 |
+
# # 🔁 BACKWARD COMPATIBILITY
|
| 432 |
+
# # =========================================================
|
| 433 |
+
# class SchemaConstraintGraph:
|
| 434 |
+
# def __init__(self, db_path: str):
|
| 435 |
+
# self._graph = build_constraint_graph(db_path)
|
| 436 |
+
# self.tables = sorted(self._graph.tables)
|
| 437 |
+
# self.columns = sorted(self._graph.all_columns)
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
# class SchemaConstrainedLogitsProcessor(LogitsProcessor):
|
| 441 |
+
# def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
|
| 442 |
+
# self.proc = BatchSchemaConstrainedLogitsProcessor(
|
| 443 |
+
# tokenizer, [schema_graph._graph.db_path]
|
| 444 |
+
# )
|
| 445 |
+
|
| 446 |
+
# def __call__(self, input_ids, scores):
|
| 447 |
+
# return self.proc(input_ids, scores)
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
# from __future__ import annotations
|
| 455 |
+
|
| 456 |
+
# import re
|
| 457 |
+
# import threading
|
| 458 |
+
# from dataclasses import dataclass
|
| 459 |
+
# from typing import Dict, Iterable, List, Optional, Sequence, Set
|
| 460 |
+
|
| 461 |
+
# import torch
|
| 462 |
+
# from transformers.generation.logits_process import LogitsProcessor
|
| 463 |
+
|
| 464 |
+
# from schema_constraints import ConstraintGraph, build_constraint_graph
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
# def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
|
| 468 |
+
# s = re.sub(r"\s+", " ", prefix_text.lower())
|
| 469 |
+
# last_from = s.rfind(" from ")
|
| 470 |
+
# last_join = s.rfind(" join ")
|
| 471 |
+
# last_select = s.rfind(" select ")
|
| 472 |
+
# last_where = s.rfind(" where ")
|
| 473 |
+
# last_on = s.rfind(" on ")
|
| 474 |
+
# last_group = s.rfind(" group by ")
|
| 475 |
+
# last_order = s.rfind(" order by ")
|
| 476 |
+
# last_having = s.rfind(" having ")
|
| 477 |
+
|
| 478 |
+
# last_table_kw = max(last_from, last_join)
|
| 479 |
+
# last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
|
| 480 |
+
|
| 481 |
+
# if last_table_kw < 0 and last_col_kw < 0:
|
| 482 |
+
# return None
|
| 483 |
+
# if last_table_kw > last_col_kw:
|
| 484 |
+
# return "table"
|
| 485 |
+
# if last_col_kw > last_table_kw:
|
| 486 |
+
# return "column"
|
| 487 |
+
# return None
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
# class _TrieNode:
|
| 491 |
+
# __slots__ = ("children", "terminal")
|
| 492 |
+
|
| 493 |
+
# def __init__(self) -> None:
|
| 494 |
+
# self.children: Dict[int, _TrieNode] = {}
|
| 495 |
+
# self.terminal: bool = False
|
| 496 |
+
|
| 497 |
+
# def insert(self, token_ids: Sequence[int]) -> None:
|
| 498 |
+
# node: _TrieNode = self
|
| 499 |
+
# for tid in token_ids:
|
| 500 |
+
# tid_i = int(tid)
|
| 501 |
+
# nxt = node.children.get(tid_i)
|
| 502 |
+
# if nxt is None:
|
| 503 |
+
# nxt = _TrieNode()
|
| 504 |
+
# node.children[tid_i] = nxt
|
| 505 |
+
# node = nxt
|
| 506 |
+
# node.terminal = True
|
| 507 |
+
|
| 508 |
+
# def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
|
| 509 |
+
# node: _TrieNode = self
|
| 510 |
+
# for tid in prefix:
|
| 511 |
+
# node = node.children.get(int(tid)) # type: ignore[assignment]
|
| 512 |
+
# if node is None:
|
| 513 |
+
# return None
|
| 514 |
+
# return node
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
# def _encode_identifier(tokenizer, name: str) -> List[int]:
|
| 518 |
+
# # Leading space encourages word-start markers (e.g. "Ġ" in RoBERTa BPE).
|
| 519 |
+
# return tokenizer.encode(" " + name, add_special_tokens=False)
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
# def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
|
| 523 |
+
# trie = _TrieNode()
|
| 524 |
+
# for n in names:
|
| 525 |
+
# if not n:
|
| 526 |
+
# continue
|
| 527 |
+
# try:
|
| 528 |
+
# ids = _encode_identifier(tokenizer, n)
|
| 529 |
+
# except Exception:
|
| 530 |
+
# continue
|
| 531 |
+
# if ids:
|
| 532 |
+
# trie.insert(ids)
|
| 533 |
+
# return trie
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
# def _allow_always_token_ids(tokenizer) -> torch.Tensor:
|
| 537 |
+
# # Allow common delimiters so the model can end an identifier.
|
| 538 |
+
# toks = [",", ")", "(", "\n", ".", ";"]
|
| 539 |
+
# ids: Set[int] = set()
|
| 540 |
+
# for t in toks:
|
| 541 |
+
# try:
|
| 542 |
+
# for tid in tokenizer.encode(t, add_special_tokens=False):
|
| 543 |
+
# ids.add(int(tid))
|
| 544 |
+
# except Exception:
|
| 545 |
+
# continue
|
| 546 |
+
# return torch.tensor(sorted(ids), dtype=torch.long)
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
# @dataclass
|
| 550 |
+
# class _PerDbTokenSets:
|
| 551 |
+
# fp: str
|
| 552 |
+
# table_trie: _TrieNode
|
| 553 |
+
# column_trie: _TrieNode
|
| 554 |
+
# allow_always: torch.Tensor
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
# _DB_TOKENSET_LOCK = threading.Lock()
|
| 558 |
+
# _DB_TOKENSETS: Dict[str, _PerDbTokenSets] = {}
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
# def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
|
| 562 |
+
# with _DB_TOKENSET_LOCK:
|
| 563 |
+
# cached = _DB_TOKENSETS.get(graph.db_path)
|
| 564 |
+
# if cached is not None and cached.fp == graph.fingerprint:
|
| 565 |
+
# return cached
|
| 566 |
+
|
| 567 |
+
# out = _PerDbTokenSets(
|
| 568 |
+
# fp=graph.fingerprint,
|
| 569 |
+
# table_trie=_build_trie(tokenizer, graph.tables),
|
| 570 |
+
# column_trie=_build_trie(tokenizer, graph.all_columns),
|
| 571 |
+
# allow_always=_allow_always_token_ids(tokenizer),
|
| 572 |
+
# )
|
| 573 |
+
# with _DB_TOKENSET_LOCK:
|
| 574 |
+
# _DB_TOKENSETS[graph.db_path] = out
|
| 575 |
+
# return out
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
# class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
|
| 579 |
+
# """
|
| 580 |
+
# Schema-aware constrained decoding per item in the generation batch.
|
| 581 |
+
# Uses a tokenizer-based trie so multi-token identifiers can be constrained.
|
| 582 |
+
# """
|
| 583 |
+
|
| 584 |
+
# def __init__(self, tokenizer, db_paths: Sequence[str], *, max_prefix_tokens: int = 48):
|
| 585 |
+
# self.tokenizer = tokenizer
|
| 586 |
+
# self.db_paths = list(db_paths)
|
| 587 |
+
# self.max_prefix_tokens = int(max_prefix_tokens)
|
| 588 |
+
|
| 589 |
+
# self._graphs = [build_constraint_graph(p) for p in self.db_paths]
|
| 590 |
+
# self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
|
| 591 |
+
|
| 592 |
+
# def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
| 593 |
+
# if input_ids.dim() != 2 or scores.dim() != 2:
|
| 594 |
+
# return scores
|
| 595 |
+
|
| 596 |
+
# batch = input_ids.size(0)
|
| 597 |
+
# if batch != len(self._graphs):
|
| 598 |
+
# return scores
|
| 599 |
+
|
| 600 |
+
# for i in range(batch):
|
| 601 |
+
# tail_ids = input_ids[i, -self.max_prefix_tokens :].tolist()
|
| 602 |
+
# prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
|
| 603 |
+
# expected = _infer_expected_identifier(prefix_text)
|
| 604 |
+
# if expected is None:
|
| 605 |
+
# continue
|
| 606 |
+
|
| 607 |
+
# if expected == "table":
|
| 608 |
+
# m = re.search(r"(?:from|join)\s+([A-Za-z_][A-Za-z0-9_]*)$", prefix_text, flags=re.I)
|
| 609 |
+
# partial = m.group(1) if m else None
|
| 610 |
+
# if partial is None and not re.search(r"(?:from|join)\s*$", prefix_text, flags=re.I):
|
| 611 |
+
# continue
|
| 612 |
+
# trie = self._token_sets[i].table_trie
|
| 613 |
+
# else:
|
| 614 |
+
# m = re.search(
|
| 615 |
+
# r"(?:select|where|on|group by|order by|having)\s+([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?)$",
|
| 616 |
+
# prefix_text,
|
| 617 |
+
# flags=re.I,
|
| 618 |
+
# )
|
| 619 |
+
# partial = m.group(1) if m else None
|
| 620 |
+
# if partial is None and not re.search(
|
| 621 |
+
# r"(?:select|where|on|group by|order by|having)\s*$", prefix_text, flags=re.I
|
| 622 |
+
# ):
|
| 623 |
+
# continue
|
| 624 |
+
# trie = self._token_sets[i].column_trie
|
| 625 |
+
|
| 626 |
+
# if not partial:
|
| 627 |
+
# prefix_token_ids: List[int] = []
|
| 628 |
+
# else:
|
| 629 |
+
# try:
|
| 630 |
+
# prefix_token_ids = _encode_identifier(self.tokenizer, partial)
|
| 631 |
+
# except Exception:
|
| 632 |
+
# continue
|
| 633 |
+
|
| 634 |
+
# node = trie.walk(prefix_token_ids)
|
| 635 |
+
# if node is None or node.terminal:
|
| 636 |
+
# continue
|
| 637 |
+
|
| 638 |
+
# allowed_next = sorted(node.children.keys())
|
| 639 |
+
# if not allowed_next:
|
| 640 |
+
# continue
|
| 641 |
+
|
| 642 |
+
# allowed_next_t = torch.tensor(allowed_next, dtype=torch.long, device=scores.device)
|
| 643 |
+
# allow_always = self._token_sets[i].allow_always.to(scores.device)
|
| 644 |
+
# keep = torch.cat([allowed_next_t, allow_always]) if allow_always.numel() else allowed_next_t
|
| 645 |
+
|
| 646 |
+
# kept_scores = scores[i, keep].clone()
|
| 647 |
+
# scores[i, :] = -float("inf")
|
| 648 |
+
# scores[i, keep] = kept_scores
|
| 649 |
+
|
| 650 |
+
# return scores
|
| 651 |
+
|
| 652 |
+
|
| 653 |
+
# # Backwards-compatible names used elsewhere in the repo.
|
| 654 |
+
# class SchemaConstraintGraph:
|
| 655 |
+
# def __init__(self, db_path: str):
|
| 656 |
+
# self._graph = build_constraint_graph(db_path)
|
| 657 |
+
# self.tables = sorted(self._graph.tables)
|
| 658 |
+
# self.columns = sorted(self._graph.all_columns)
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
# class SchemaConstrainedLogitsProcessor(LogitsProcessor):
|
| 662 |
+
# def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
|
| 663 |
+
# self._proc = BatchSchemaConstrainedLogitsProcessor(tokenizer, [schema_graph._graph.db_path])
|
| 664 |
+
|
| 665 |
+
# def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
| 666 |
+
# return self._proc(input_ids, scores)
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
# from __future__ import annotations
|
| 672 |
+
|
| 673 |
+
# import re
|
| 674 |
+
# import threading
|
| 675 |
+
# from dataclasses import dataclass
|
| 676 |
+
# from typing import Dict, Iterable, List, Optional, Sequence, Set
|
| 677 |
+
|
| 678 |
+
# import torch
|
| 679 |
+
# from transformers.generation.logits_process import LogitsProcessor
|
| 680 |
+
|
| 681 |
+
# from schema_constraints import ConstraintGraph, build_constraint_graph
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
# # =========================================================
|
| 685 |
+
# # 🔍 IDENTIFIER TYPE DETECTION
|
| 686 |
+
# # =========================================================
|
| 687 |
+
# def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
|
| 688 |
+
# s = re.sub(r"\s+", " ", prefix_text.lower())
|
| 689 |
+
|
| 690 |
+
# last_from = s.rfind(" from ")
|
| 691 |
+
# last_join = s.rfind(" join ")
|
| 692 |
+
# last_select = s.rfind(" select ")
|
| 693 |
+
# last_where = s.rfind(" where ")
|
| 694 |
+
# last_on = s.rfind(" on ")
|
| 695 |
+
# last_group = s.rfind(" group by ")
|
| 696 |
+
# last_order = s.rfind(" order by ")
|
| 697 |
+
# last_having = s.rfind(" having ")
|
| 698 |
+
|
| 699 |
+
# last_table_kw = max(last_from, last_join)
|
| 700 |
+
# last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
|
| 701 |
+
|
| 702 |
+
# if last_table_kw < 0 and last_col_kw < 0:
|
| 703 |
+
# return None
|
| 704 |
+
# if last_table_kw > last_col_kw:
|
| 705 |
+
# return "table"
|
| 706 |
+
# if last_col_kw > last_table_kw:
|
| 707 |
+
# return "column"
|
| 708 |
+
# return None
|
| 709 |
+
|
| 710 |
+
|
| 711 |
+
# # =========================================================
|
| 712 |
+
# # 🌳 TRIE STRUCTURE
|
| 713 |
+
# # =========================================================
|
| 714 |
+
# class _TrieNode:
|
| 715 |
+
# __slots__ = ("children", "terminal")
|
| 716 |
+
|
| 717 |
+
# def __init__(self) -> None:
|
| 718 |
+
# self.children: Dict[int, _TrieNode] = {}
|
| 719 |
+
# self.terminal: bool = False
|
| 720 |
+
|
| 721 |
+
# def insert(self, token_ids: Sequence[int]) -> None:
|
| 722 |
+
# node = self
|
| 723 |
+
# for tid in token_ids:
|
| 724 |
+
# tid = int(tid)
|
| 725 |
+
# if tid not in node.children:
|
| 726 |
+
# node.children[tid] = _TrieNode()
|
| 727 |
+
# node = node.children[tid]
|
| 728 |
+
# node.terminal = True
|
| 729 |
+
|
| 730 |
+
# def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
|
| 731 |
+
# node = self
|
| 732 |
+
# for tid in prefix:
|
| 733 |
+
# node = node.children.get(int(tid))
|
| 734 |
+
# if node is None:
|
| 735 |
+
# return None
|
| 736 |
+
# return node
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
# # =========================================================
|
| 740 |
+
# # 🔤 TOKEN ENCODING
|
| 741 |
+
# # =========================================================
|
| 742 |
+
# def _encode_identifier(tokenizer, name: str) -> List[int]:
|
| 743 |
+
# return tokenizer.encode(" " + name, add_special_tokens=False)
|
| 744 |
+
|
| 745 |
+
|
| 746 |
+
# def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
|
| 747 |
+
# trie = _TrieNode()
|
| 748 |
+
# for name in names:
|
| 749 |
+
# try:
|
| 750 |
+
# ids = _encode_identifier(tokenizer, name)
|
| 751 |
+
# if ids:
|
| 752 |
+
# trie.insert(ids)
|
| 753 |
+
# except Exception:
|
| 754 |
+
# continue
|
| 755 |
+
# return trie
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
# def _allow_always_token_ids(tokenizer) -> torch.Tensor:
|
| 759 |
+
# tokens = [",", ")", "(", ".", ";", "\n"]
|
| 760 |
+
# ids: Set[int] = set()
|
| 761 |
+
|
| 762 |
+
# for t in tokens:
|
| 763 |
+
# try:
|
| 764 |
+
# ids.update(tokenizer.encode(t, add_special_tokens=False))
|
| 765 |
+
# except:
|
| 766 |
+
# pass
|
| 767 |
+
|
| 768 |
+
# return torch.tensor(sorted(ids), dtype=torch.long)
|
| 769 |
+
|
| 770 |
+
|
| 771 |
+
# # =========================================================
|
| 772 |
+
# # 📦 PER-DB CACHE
|
| 773 |
+
# # =========================================================
|
| 774 |
+
# @dataclass
|
| 775 |
+
# class _PerDbTokenSets:
|
| 776 |
+
# fp: str
|
| 777 |
+
# table_trie: _TrieNode
|
| 778 |
+
# column_trie: _TrieNode
|
| 779 |
+
# allow_always: torch.Tensor
|
| 780 |
+
|
| 781 |
+
|
| 782 |
+
# _DB_CACHE: Dict[str, _PerDbTokenSets] = {}
|
| 783 |
+
# _DB_LOCK = threading.Lock()
|
| 784 |
+
|
| 785 |
+
|
| 786 |
+
# def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
|
| 787 |
+
# with _DB_LOCK:
|
| 788 |
+
# cached = _DB_CACHE.get(graph.db_path)
|
| 789 |
+
# if cached and cached.fp == graph.fingerprint:
|
| 790 |
+
# return cached
|
| 791 |
+
|
| 792 |
+
# obj = _PerDbTokenSets(
|
| 793 |
+
# fp=graph.fingerprint,
|
| 794 |
+
# table_trie=_build_trie(tokenizer, graph.tables),
|
| 795 |
+
# column_trie=_build_trie(tokenizer, graph.all_columns),
|
| 796 |
+
# allow_always=_allow_always_token_ids(tokenizer),
|
| 797 |
+
# )
|
| 798 |
+
|
| 799 |
+
# with _DB_LOCK:
|
| 800 |
+
# _DB_CACHE[graph.db_path] = obj
|
| 801 |
+
|
| 802 |
+
# return obj
|
| 803 |
+
|
| 804 |
+
|
| 805 |
+
# # =========================================================
|
| 806 |
+
# # 🚀 MAIN LOGITS PROCESSOR
|
| 807 |
+
# # =========================================================
|
| 808 |
+
# class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
|
| 809 |
+
# def __init__(self, tokenizer, db_paths: Sequence[str], max_prefix_tokens: int = 48):
|
| 810 |
+
# self.tokenizer = tokenizer
|
| 811 |
+
# self.db_paths = list(db_paths)
|
| 812 |
+
# self.max_prefix_tokens = max_prefix_tokens
|
| 813 |
+
|
| 814 |
+
# self._graphs = [build_constraint_graph(p) for p in db_paths]
|
| 815 |
+
# self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
|
| 816 |
+
|
| 817 |
+
# # 📊 Metrics (IMPORTANT FOR REPORT)
|
| 818 |
+
# self.total_steps = 0
|
| 819 |
+
# self.constrained_steps = 0
|
| 820 |
+
|
| 821 |
+
# def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
| 822 |
+
# batch = input_ids.size(0)
|
| 823 |
+
|
| 824 |
+
# for i in range(batch):
|
| 825 |
+
# self.total_steps += 1
|
| 826 |
+
|
| 827 |
+
# tail_ids = input_ids[i, -self.max_prefix_tokens:].tolist()
|
| 828 |
+
# prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
|
| 829 |
+
|
| 830 |
+
# expected = _infer_expected_identifier(prefix_text)
|
| 831 |
+
# if expected is None:
|
| 832 |
+
# continue
|
| 833 |
+
|
| 834 |
+
# self.constrained_steps += 1
|
| 835 |
+
|
| 836 |
+
# # =========================
|
| 837 |
+
# # SELECT TRIE
|
| 838 |
+
# # =========================
|
| 839 |
+
# if expected == "table":
|
| 840 |
+
# trie = self._token_sets[i].table_trie
|
| 841 |
+
# else:
|
| 842 |
+
# trie = self._token_sets[i].column_trie
|
| 843 |
+
|
| 844 |
+
# # =========================
|
| 845 |
+
# # PARTIAL TOKEN MATCH
|
| 846 |
+
# # =========================
|
| 847 |
+
# match = re.search(r"([A-Za-z_][A-Za-z0-9_]*)$", prefix_text)
|
| 848 |
+
# partial = match.group(1) if match else ""
|
| 849 |
+
|
| 850 |
+
# try:
|
| 851 |
+
# prefix_ids = _encode_identifier(self.tokenizer, partial) if partial else []
|
| 852 |
+
# except:
|
| 853 |
+
# continue
|
| 854 |
+
|
| 855 |
+
# node = trie.walk(prefix_ids)
|
| 856 |
+
# if node is None or node.terminal:
|
| 857 |
+
# continue
|
| 858 |
+
|
| 859 |
+
# allowed_next = list(node.children.keys())
|
| 860 |
+
# if not allowed_next:
|
| 861 |
+
# continue
|
| 862 |
+
|
| 863 |
+
# allowed_next = torch.tensor(allowed_next, device=scores.device)
|
| 864 |
+
# allow_always = self._token_sets[i].allow_always.to(scores.device)
|
| 865 |
+
|
| 866 |
+
# keep = torch.cat([allowed_next, allow_always])
|
| 867 |
+
|
| 868 |
+
# kept_scores = scores[i, keep].clone()
|
| 869 |
+
# scores[i, :] = -float("inf")
|
| 870 |
+
# scores[i, keep] = kept_scores
|
| 871 |
+
|
| 872 |
+
# return scores
|
| 873 |
+
|
| 874 |
+
# # =========================================================
|
| 875 |
+
# # 📊 METRICS FOR REPORT
|
| 876 |
+
# # =========================================================
|
| 877 |
+
# def get_constraint_stats(self):
|
| 878 |
+
# if self.total_steps == 0:
|
| 879 |
+
# return 0
|
| 880 |
+
# return self.constrained_steps / self.total_steps
|
| 881 |
+
|
| 882 |
+
|
| 883 |
+
# # =========================================================
|
| 884 |
+
# # 🔁 BACKWARD COMPATIBILITY
|
| 885 |
+
# # =========================================================
|
| 886 |
+
# class SchemaConstraintGraph:
|
| 887 |
+
# def __init__(self, db_path: str):
|
| 888 |
+
# self._graph = build_constraint_graph(db_path)
|
| 889 |
+
# self.tables = sorted(self._graph.tables)
|
| 890 |
+
# self.columns = sorted(self._graph.all_columns)
|
| 891 |
+
|
| 892 |
+
|
| 893 |
+
# class SchemaConstrainedLogitsProcessor(LogitsProcessor):
|
| 894 |
+
# def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
|
| 895 |
+
# self.proc = BatchSchemaConstrainedLogitsProcessor(
|
| 896 |
+
# tokenizer, [schema_graph._graph.db_path]
|
| 897 |
+
# )
|
| 898 |
+
|
| 899 |
+
# def __call__(self, input_ids, scores):
|
| 900 |
+
# return self.proc(input_ids, scores)
|
| 901 |
+
|
| 902 |
+
|
| 903 |
+
|
| 904 |
+
|
| 905 |
+
|
| 906 |
+
|
| 907 |
+
|
| 908 |
+
|
| 909 |
+
# ********* after task 3
|
| 910 |
+
|
| 911 |
+
import re
|
| 912 |
+
import threading
|
| 913 |
+
from functools import lru_cache
|
| 914 |
+
|
| 915 |
+
import torch
|
| 916 |
+
from transformers import LogitsProcessor
|
| 917 |
+
|
| 918 |
+
from src.schema_utils import get_constraint_graph
|
| 919 |
+
|
| 920 |
+
|
| 921 |
+
# Guards _TOKEN_ID_CACHE against concurrent SchemaConstrainedLogitsProcessor construction.
_TOKEN_CACHE_LOCK = threading.Lock()
# Process-wide cache of precomputed token-id tensors, keyed per tokenizer+database.
_TOKEN_ID_CACHE = {}  # (id(tokenizer), db_path) -> (allowed_ids_tensor, always_allow_ids_tensor)
|
| 923 |
+
|
| 924 |
+
|
| 925 |
+
def _encode_variants(tokenizer, text: str) -> list[int]:
|
| 926 |
+
ids: list[int] = []
|
| 927 |
+
for variant in (text, " " + text):
|
| 928 |
+
try:
|
| 929 |
+
ids.extend(tokenizer.encode(variant, add_special_tokens=False))
|
| 930 |
+
except Exception:
|
| 931 |
+
continue
|
| 932 |
+
# de-dup while keeping order
|
| 933 |
+
seen = set()
|
| 934 |
+
out = []
|
| 935 |
+
for i in ids:
|
| 936 |
+
if int(i) not in seen:
|
| 937 |
+
seen.add(int(i))
|
| 938 |
+
out.append(int(i))
|
| 939 |
+
return out
|
| 940 |
+
|
| 941 |
+
|
| 942 |
+
def _always_allow_ids(tokenizer) -> list[int]:
    """
    Tokens we should never block, otherwise decoding can get stuck or generate garbage:
    - EOS/PAD
    - punctuation/operators needed for SQL formatting
    - digits/quotes
    """
    collected: list[int] = []

    # Special control tokens first (either may be absent on a given tokenizer).
    eos_id = getattr(tokenizer, "eos_token_id", None)
    pad_id = getattr(tokenizer, "pad_token_id", None)
    for special_id in (eos_id, pad_id):
        if special_id is not None:
            collected.append(int(special_id))

    # Common SQL punctuation/operators, then digits — same order as before.
    punctuation = [
        " ", "\n", "\t",
        ",", ".", "(", ")", ";",
        "=", "!=", "<>", "<", ">", "<=", ">=",
        "*", "+", "-", "/", "%",
        "'", '"',
    ]
    for piece in punctuation + list("0123456789"):
        collected.extend(_encode_variants(tokenizer, piece))

    # De-duplicate while preserving first-seen order.
    return list(dict.fromkeys(int(tid) for tid in collected))
|
| 976 |
+
|
| 977 |
+
|
| 978 |
+
def _infer_expected_identifier_tail(tail_text: str):
|
| 979 |
+
"""
|
| 980 |
+
Returns ("table"|"column", partial_or_empty) if the tail looks like it's currently
|
| 981 |
+
emitting a table/column identifier. Otherwise returns None.
|
| 982 |
+
"""
|
| 983 |
+
t = re.sub(r"\s+", " ", (tail_text or "")).lower()
|
| 984 |
+
|
| 985 |
+
m = re.search(r"(?:from|join)\s+([a-z_][a-z0-9_]*)?$", t)
|
| 986 |
+
if m:
|
| 987 |
+
partial = m.group(1) or ""
|
| 988 |
+
# ensure we are actually after keyword (not elsewhere)
|
| 989 |
+
if re.search(r"(?:from|join)\s*$", t) or partial:
|
| 990 |
+
return "table", partial
|
| 991 |
+
|
| 992 |
+
m = re.search(
|
| 993 |
+
r"(?:select|where|on|group by|order by|having)\s+([a-z_][a-z0-9_]*(?:\.[a-z_][a-z0-9_]*)?)?$",
|
| 994 |
+
t,
|
| 995 |
+
)
|
| 996 |
+
if m:
|
| 997 |
+
partial = m.group(1) or ""
|
| 998 |
+
if re.search(r"(?:select|where|on|group by|order by|having)\s*$", t) or partial:
|
| 999 |
+
return "column", partial
|
| 1000 |
+
|
| 1001 |
+
return None
|
| 1002 |
+
|
| 1003 |
+
|
| 1004 |
+
class SchemaConstrainedLogitsProcessor(LogitsProcessor):
    """Soft schema-aware constrained decoding for text-to-SQL generation.

    When the decoded tail looks like the model is emitting a table/column
    identifier, every logit outside a precomputed allow-list (schema names,
    SQL keywords, punctuation, digits, EOS/PAD) is masked to -inf.  The
    token-id tensors are cached per (tokenizer, db_path) pair in the
    module-level _TOKEN_ID_CACHE so repeated construction is cheap.
    """

    def __init__(self, tokenizer, db_path):
        self.tokenizer = tokenizer

        # graph is expected to be a mapping with "tables"/"columns" entries
        # (see get_constraint_graph) — .get() below tolerates missing keys.
        graph = get_constraint_graph(db_path)

        # NOTE(review): id(tokenizer) can be recycled after GC frees a
        # tokenizer, which could in principle serve stale ids to a different
        # tokenizer instance — acceptable for long-lived tokenizers.
        key = (id(tokenizer), str(db_path))
        with _TOKEN_CACHE_LOCK:
            cached = _TOKEN_ID_CACHE.get(key)
        if cached is None:
            # Schema identifiers form the core of the allow-list.
            allowed_tokens = set(graph.get("tables", set())) | set(graph.get("columns", set()))

            sql_keywords = {
                "select", "from", "where", "join", "on",
                "group", "by", "order", "limit", "having",
                "and", "or", "desc", "asc",
                "count", "avg", "min", "max", "sum",
                "distinct", "as", "in", "like", "between",
                "is", "null",
            }
            allowed_tokens |= sql_keywords

            # Encode each allowed surface form in both bare and
            # space-prefixed variants to cover BPE word-start markers.
            allowed_ids: list[int] = []
            for tok in sorted(allowed_tokens):
                allowed_ids.extend(_encode_variants(tokenizer, tok))
            always_ids = _always_allow_ids(tokenizer)

            allowed_ids_t = torch.tensor(sorted(set(allowed_ids)), dtype=torch.long)
            always_ids_t = torch.tensor(sorted(set(always_ids)), dtype=torch.long)
            cached = (allowed_ids_t, always_ids_t)
            # Build happens outside the lock; a concurrent builder may race,
            # but both compute the same value so last-writer-wins is benign.
            with _TOKEN_CACHE_LOCK:
                _TOKEN_ID_CACHE[key] = cached

        self._allowed_ids_t, self._always_ids_t = cached

    def __call__(self, input_ids, scores):
        # Decode only a tail window for speed (beam search calls this a lot).
        try:
            tail_ids = input_ids[0][-128:]
        except Exception:
            tail_ids = input_ids[0]
        # NOTE(review): only row 0 of the batch is decoded; the same mask is
        # then applied to every row — confirm this is intended for batch > 1.
        tail = self.tokenizer.decode(tail_ids, skip_special_tokens=True)

        inferred = _infer_expected_identifier_tail(tail)
        if inferred is None:
            # Not currently emitting an identifier: leave logits untouched.
            return scores

        keep = torch.cat([self._allowed_ids_t.to(scores.device), self._always_ids_t.to(scores.device)])
        if keep.numel() == 0:
            return scores

        # Mask everything to -inf, then restore the allowed positions.
        kept_scores = scores[:, keep].clone()
        scores[:] = -float("inf")
        scores[:, keep] = kept_scores
        return scores
|
src/constrained_decoding_sample.py
ADDED
|
@@ -0,0 +1,516 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
# import re
|
| 4 |
+
# import threading
|
| 5 |
+
# from dataclasses import dataclass
|
| 6 |
+
# from typing import Dict, Iterable, List, Optional, Sequence, Set
|
| 7 |
+
|
| 8 |
+
# import torch
|
| 9 |
+
# from transformers.generation.logits_process import LogitsProcessor
|
| 10 |
+
|
| 11 |
+
# from schema_constraints import ConstraintGraph, build_constraint_graph
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
|
| 15 |
+
# s = re.sub(r"\s+", " ", prefix_text.lower())
|
| 16 |
+
# last_from = s.rfind(" from ")
|
| 17 |
+
# last_join = s.rfind(" join ")
|
| 18 |
+
# last_select = s.rfind(" select ")
|
| 19 |
+
# last_where = s.rfind(" where ")
|
| 20 |
+
# last_on = s.rfind(" on ")
|
| 21 |
+
# last_group = s.rfind(" group by ")
|
| 22 |
+
# last_order = s.rfind(" order by ")
|
| 23 |
+
# last_having = s.rfind(" having ")
|
| 24 |
+
|
| 25 |
+
# last_table_kw = max(last_from, last_join)
|
| 26 |
+
# last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
|
| 27 |
+
|
| 28 |
+
# if last_table_kw < 0 and last_col_kw < 0:
|
| 29 |
+
# return None
|
| 30 |
+
# if last_table_kw > last_col_kw:
|
| 31 |
+
# return "table"
|
| 32 |
+
# if last_col_kw > last_table_kw:
|
| 33 |
+
# return "column"
|
| 34 |
+
# return None
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# class _TrieNode:
|
| 38 |
+
# __slots__ = ("children", "terminal")
|
| 39 |
+
|
| 40 |
+
# def __init__(self) -> None:
|
| 41 |
+
# self.children: Dict[int, _TrieNode] = {}
|
| 42 |
+
# self.terminal: bool = False
|
| 43 |
+
|
| 44 |
+
# def insert(self, token_ids: Sequence[int]) -> None:
|
| 45 |
+
# node: _TrieNode = self
|
| 46 |
+
# for tid in token_ids:
|
| 47 |
+
# tid_i = int(tid)
|
| 48 |
+
# nxt = node.children.get(tid_i)
|
| 49 |
+
# if nxt is None:
|
| 50 |
+
# nxt = _TrieNode()
|
| 51 |
+
# node.children[tid_i] = nxt
|
| 52 |
+
# node = nxt
|
| 53 |
+
# node.terminal = True
|
| 54 |
+
|
| 55 |
+
# def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
|
| 56 |
+
# node: _TrieNode = self
|
| 57 |
+
# for tid in prefix:
|
| 58 |
+
# node = node.children.get(int(tid)) # type: ignore[assignment]
|
| 59 |
+
# if node is None:
|
| 60 |
+
# return None
|
| 61 |
+
# return node
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# def _encode_identifier(tokenizer, name: str) -> List[int]:
|
| 65 |
+
# # Leading space encourages word-start markers (e.g. "Ġ" in RoBERTa BPE).
|
| 66 |
+
# return tokenizer.encode(" " + name, add_special_tokens=False)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
|
| 70 |
+
# trie = _TrieNode()
|
| 71 |
+
# for n in names:
|
| 72 |
+
# if not n:
|
| 73 |
+
# continue
|
| 74 |
+
# try:
|
| 75 |
+
# ids = _encode_identifier(tokenizer, n)
|
| 76 |
+
# except Exception:
|
| 77 |
+
# continue
|
| 78 |
+
# if ids:
|
| 79 |
+
# trie.insert(ids)
|
| 80 |
+
# return trie
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# def _allow_always_token_ids(tokenizer) -> torch.Tensor:
|
| 84 |
+
# # Allow common delimiters so the model can end an identifier.
|
| 85 |
+
# toks = [",", ")", "(", "\n", ".", ";"]
|
| 86 |
+
# ids: Set[int] = set()
|
| 87 |
+
# for t in toks:
|
| 88 |
+
# try:
|
| 89 |
+
# for tid in tokenizer.encode(t, add_special_tokens=False):
|
| 90 |
+
# ids.add(int(tid))
|
| 91 |
+
# except Exception:
|
| 92 |
+
# continue
|
| 93 |
+
# return torch.tensor(sorted(ids), dtype=torch.long)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# @dataclass
|
| 97 |
+
# class _PerDbTokenSets:
|
| 98 |
+
# fp: str
|
| 99 |
+
# table_trie: _TrieNode
|
| 100 |
+
# column_trie: _TrieNode
|
| 101 |
+
# allow_always: torch.Tensor
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# _DB_TOKENSET_LOCK = threading.Lock()
|
| 105 |
+
# _DB_TOKENSETS: Dict[str, _PerDbTokenSets] = {}
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
|
| 109 |
+
# with _DB_TOKENSET_LOCK:
|
| 110 |
+
# cached = _DB_TOKENSETS.get(graph.db_path)
|
| 111 |
+
# if cached is not None and cached.fp == graph.fingerprint:
|
| 112 |
+
# return cached
|
| 113 |
+
|
| 114 |
+
# out = _PerDbTokenSets(
|
| 115 |
+
# fp=graph.fingerprint,
|
| 116 |
+
# table_trie=_build_trie(tokenizer, graph.tables),
|
| 117 |
+
# column_trie=_build_trie(tokenizer, graph.all_columns),
|
| 118 |
+
# allow_always=_allow_always_token_ids(tokenizer),
|
| 119 |
+
# )
|
| 120 |
+
# with _DB_TOKENSET_LOCK:
|
| 121 |
+
# _DB_TOKENSETS[graph.db_path] = out
|
| 122 |
+
# return out
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
|
| 126 |
+
# """
|
| 127 |
+
# Schema-aware constrained decoding per item in the generation batch.
|
| 128 |
+
# Uses a tokenizer-based trie so multi-token identifiers can be constrained.
|
| 129 |
+
# """
|
| 130 |
+
|
| 131 |
+
# def __init__(self, tokenizer, db_paths: Sequence[str], *, max_prefix_tokens: int = 48):
|
| 132 |
+
# self.tokenizer = tokenizer
|
| 133 |
+
# self.db_paths = list(db_paths)
|
| 134 |
+
# self.max_prefix_tokens = int(max_prefix_tokens)
|
| 135 |
+
|
| 136 |
+
# self._graphs = [build_constraint_graph(p) for p in self.db_paths]
|
| 137 |
+
# self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
|
| 138 |
+
|
| 139 |
+
# def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
| 140 |
+
# if input_ids.dim() != 2 or scores.dim() != 2:
|
| 141 |
+
# return scores
|
| 142 |
+
|
| 143 |
+
# batch = input_ids.size(0)
|
| 144 |
+
# if batch != len(self._graphs):
|
| 145 |
+
# return scores
|
| 146 |
+
|
| 147 |
+
# for i in range(batch):
|
| 148 |
+
# tail_ids = input_ids[i, -self.max_prefix_tokens :].tolist()
|
| 149 |
+
# prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
|
| 150 |
+
# expected = _infer_expected_identifier(prefix_text)
|
| 151 |
+
# if expected is None:
|
| 152 |
+
# continue
|
| 153 |
+
|
| 154 |
+
# if expected == "table":
|
| 155 |
+
# m = re.search(r"(?:from|join)\s+([A-Za-z_][A-Za-z0-9_]*)$", prefix_text, flags=re.I)
|
| 156 |
+
# partial = m.group(1) if m else None
|
| 157 |
+
# if partial is None and not re.search(r"(?:from|join)\s*$", prefix_text, flags=re.I):
|
| 158 |
+
# continue
|
| 159 |
+
# trie = self._token_sets[i].table_trie
|
| 160 |
+
# else:
|
| 161 |
+
# m = re.search(
|
| 162 |
+
# r"(?:select|where|on|group by|order by|having)\s+([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?)$",
|
| 163 |
+
# prefix_text,
|
| 164 |
+
# flags=re.I,
|
| 165 |
+
# )
|
| 166 |
+
# partial = m.group(1) if m else None
|
| 167 |
+
# if partial is None and not re.search(
|
| 168 |
+
# r"(?:select|where|on|group by|order by|having)\s*$", prefix_text, flags=re.I
|
| 169 |
+
# ):
|
| 170 |
+
# continue
|
| 171 |
+
# trie = self._token_sets[i].column_trie
|
| 172 |
+
|
| 173 |
+
# if not partial:
|
| 174 |
+
# prefix_token_ids: List[int] = []
|
| 175 |
+
# else:
|
| 176 |
+
# try:
|
| 177 |
+
# prefix_token_ids = _encode_identifier(self.tokenizer, partial)
|
| 178 |
+
# except Exception:
|
| 179 |
+
# continue
|
| 180 |
+
|
| 181 |
+
# node = trie.walk(prefix_token_ids)
|
| 182 |
+
# if node is None or node.terminal:
|
| 183 |
+
# continue
|
| 184 |
+
|
| 185 |
+
# allowed_next = sorted(node.children.keys())
|
| 186 |
+
# if not allowed_next:
|
| 187 |
+
# continue
|
| 188 |
+
|
| 189 |
+
# allowed_next_t = torch.tensor(allowed_next, dtype=torch.long, device=scores.device)
|
| 190 |
+
# allow_always = self._token_sets[i].allow_always.to(scores.device)
|
| 191 |
+
# keep = torch.cat([allowed_next_t, allow_always]) if allow_always.numel() else allowed_next_t
|
| 192 |
+
|
| 193 |
+
# kept_scores = scores[i, keep].clone()
|
| 194 |
+
# scores[i, :] = -float("inf")
|
| 195 |
+
# scores[i, keep] = kept_scores
|
| 196 |
+
|
| 197 |
+
# return scores
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
# # Backwards-compatible names used elsewhere in the repo.
|
| 201 |
+
# class SchemaConstraintGraph:
|
| 202 |
+
# def __init__(self, db_path: str):
|
| 203 |
+
# self._graph = build_constraint_graph(db_path)
|
| 204 |
+
# self.tables = sorted(self._graph.tables)
|
| 205 |
+
# self.columns = sorted(self._graph.all_columns)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
# class SchemaConstrainedLogitsProcessor(LogitsProcessor):
|
| 209 |
+
# def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
|
| 210 |
+
# self._proc = BatchSchemaConstrainedLogitsProcessor(tokenizer, [schema_graph._graph.db_path])
|
| 211 |
+
|
| 212 |
+
# def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
| 213 |
+
# return self._proc(input_ids, scores)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
# from __future__ import annotations
|
| 219 |
+
|
| 220 |
+
# import re
|
| 221 |
+
# import threading
|
| 222 |
+
# from dataclasses import dataclass
|
| 223 |
+
# from typing import Dict, Iterable, List, Optional, Sequence, Set
|
| 224 |
+
|
| 225 |
+
# import torch
|
| 226 |
+
# from transformers.generation.logits_process import LogitsProcessor
|
| 227 |
+
|
| 228 |
+
# from schema_constraints import ConstraintGraph, build_constraint_graph
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
# # =========================================================
|
| 232 |
+
# # 🔍 IDENTIFIER TYPE DETECTION
|
| 233 |
+
# # =========================================================
|
| 234 |
+
# def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
|
| 235 |
+
# s = re.sub(r"\s+", " ", prefix_text.lower())
|
| 236 |
+
|
| 237 |
+
# last_from = s.rfind(" from ")
|
| 238 |
+
# last_join = s.rfind(" join ")
|
| 239 |
+
# last_select = s.rfind(" select ")
|
| 240 |
+
# last_where = s.rfind(" where ")
|
| 241 |
+
# last_on = s.rfind(" on ")
|
| 242 |
+
# last_group = s.rfind(" group by ")
|
| 243 |
+
# last_order = s.rfind(" order by ")
|
| 244 |
+
# last_having = s.rfind(" having ")
|
| 245 |
+
|
| 246 |
+
# last_table_kw = max(last_from, last_join)
|
| 247 |
+
# last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
|
| 248 |
+
|
| 249 |
+
# if last_table_kw < 0 and last_col_kw < 0:
|
| 250 |
+
# return None
|
| 251 |
+
# if last_table_kw > last_col_kw:
|
| 252 |
+
# return "table"
|
| 253 |
+
# if last_col_kw > last_table_kw:
|
| 254 |
+
# return "column"
|
| 255 |
+
# return None
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
# # =========================================================
|
| 259 |
+
# # 🌳 TRIE STRUCTURE
|
| 260 |
+
# # =========================================================
|
| 261 |
+
# class _TrieNode:
|
| 262 |
+
# __slots__ = ("children", "terminal")
|
| 263 |
+
|
| 264 |
+
# def __init__(self) -> None:
|
| 265 |
+
# self.children: Dict[int, _TrieNode] = {}
|
| 266 |
+
# self.terminal: bool = False
|
| 267 |
+
|
| 268 |
+
# def insert(self, token_ids: Sequence[int]) -> None:
|
| 269 |
+
# node = self
|
| 270 |
+
# for tid in token_ids:
|
| 271 |
+
# tid = int(tid)
|
| 272 |
+
# if tid not in node.children:
|
| 273 |
+
# node.children[tid] = _TrieNode()
|
| 274 |
+
# node = node.children[tid]
|
| 275 |
+
# node.terminal = True
|
| 276 |
+
|
| 277 |
+
# def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
|
| 278 |
+
# node = self
|
| 279 |
+
# for tid in prefix:
|
| 280 |
+
# node = node.children.get(int(tid))
|
| 281 |
+
# if node is None:
|
| 282 |
+
# return None
|
| 283 |
+
# return node
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
# # =========================================================
|
| 287 |
+
# # 🔤 TOKEN ENCODING
|
| 288 |
+
# # =========================================================
|
| 289 |
+
# def _encode_identifier(tokenizer, name: str) -> List[int]:
|
| 290 |
+
# return tokenizer.encode(" " + name, add_special_tokens=False)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
# def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
|
| 294 |
+
# trie = _TrieNode()
|
| 295 |
+
# for name in names:
|
| 296 |
+
# try:
|
| 297 |
+
# ids = _encode_identifier(tokenizer, name)
|
| 298 |
+
# if ids:
|
| 299 |
+
# trie.insert(ids)
|
| 300 |
+
# except Exception:
|
| 301 |
+
# continue
|
| 302 |
+
# return trie
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
# def _allow_always_token_ids(tokenizer) -> torch.Tensor:
|
| 306 |
+
# tokens = [",", ")", "(", ".", ";", "\n"]
|
| 307 |
+
# ids: Set[int] = set()
|
| 308 |
+
|
| 309 |
+
# for t in tokens:
|
| 310 |
+
# try:
|
| 311 |
+
# ids.update(tokenizer.encode(t, add_special_tokens=False))
|
| 312 |
+
# except:
|
| 313 |
+
# pass
|
| 314 |
+
|
| 315 |
+
# return torch.tensor(sorted(ids), dtype=torch.long)
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
# # =========================================================
|
| 319 |
+
# # 📦 PER-DB CACHE
|
| 320 |
+
# # =========================================================
|
| 321 |
+
# @dataclass
|
| 322 |
+
# class _PerDbTokenSets:
|
| 323 |
+
# fp: str
|
| 324 |
+
# table_trie: _TrieNode
|
| 325 |
+
# column_trie: _TrieNode
|
| 326 |
+
# allow_always: torch.Tensor
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
# _DB_CACHE: Dict[str, _PerDbTokenSets] = {}
|
| 330 |
+
# _DB_LOCK = threading.Lock()
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
# def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
|
| 334 |
+
# with _DB_LOCK:
|
| 335 |
+
# cached = _DB_CACHE.get(graph.db_path)
|
| 336 |
+
# if cached and cached.fp == graph.fingerprint:
|
| 337 |
+
# return cached
|
| 338 |
+
|
| 339 |
+
# obj = _PerDbTokenSets(
|
| 340 |
+
# fp=graph.fingerprint,
|
| 341 |
+
# table_trie=_build_trie(tokenizer, graph.tables),
|
| 342 |
+
# column_trie=_build_trie(tokenizer, graph.all_columns),
|
| 343 |
+
# allow_always=_allow_always_token_ids(tokenizer),
|
| 344 |
+
# )
|
| 345 |
+
|
| 346 |
+
# with _DB_LOCK:
|
| 347 |
+
# _DB_CACHE[graph.db_path] = obj
|
| 348 |
+
|
| 349 |
+
# return obj
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
# # =========================================================
|
| 353 |
+
# # 🚀 MAIN LOGITS PROCESSOR
|
| 354 |
+
# # =========================================================
|
| 355 |
+
# class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
|
| 356 |
+
# def __init__(self, tokenizer, db_paths: Sequence[str], max_prefix_tokens: int = 48):
|
| 357 |
+
# self.tokenizer = tokenizer
|
| 358 |
+
# self.db_paths = list(db_paths)
|
| 359 |
+
# self.max_prefix_tokens = max_prefix_tokens
|
| 360 |
+
|
| 361 |
+
# self._graphs = [build_constraint_graph(p) for p in db_paths]
|
| 362 |
+
# self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
|
| 363 |
+
|
| 364 |
+
# # 📊 Metrics (IMPORTANT FOR REPORT)
|
| 365 |
+
# self.total_steps = 0
|
| 366 |
+
# self.constrained_steps = 0
|
| 367 |
+
|
| 368 |
+
# def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
| 369 |
+
# batch = input_ids.size(0)
|
| 370 |
+
|
| 371 |
+
# for i in range(batch):
|
| 372 |
+
# self.total_steps += 1
|
| 373 |
+
|
| 374 |
+
# tail_ids = input_ids[i, -self.max_prefix_tokens:].tolist()
|
| 375 |
+
# prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
|
| 376 |
+
|
| 377 |
+
# expected = _infer_expected_identifier(prefix_text)
|
| 378 |
+
# if expected is None:
|
| 379 |
+
# continue
|
| 380 |
+
|
| 381 |
+
# self.constrained_steps += 1
|
| 382 |
+
|
| 383 |
+
# # =========================
|
| 384 |
+
# # SELECT TRIE
|
| 385 |
+
# # =========================
|
| 386 |
+
# if expected == "table":
|
| 387 |
+
# trie = self._token_sets[i].table_trie
|
| 388 |
+
# else:
|
| 389 |
+
# trie = self._token_sets[i].column_trie
|
| 390 |
+
|
| 391 |
+
# # =========================
|
| 392 |
+
# # PARTIAL TOKEN MATCH
|
| 393 |
+
# # =========================
|
| 394 |
+
# match = re.search(r"([A-Za-z_][A-Za-z0-9_]*)$", prefix_text)
|
| 395 |
+
# partial = match.group(1) if match else ""
|
| 396 |
+
|
| 397 |
+
# try:
|
| 398 |
+
# prefix_ids = _encode_identifier(self.tokenizer, partial) if partial else []
|
| 399 |
+
# except:
|
| 400 |
+
# continue
|
| 401 |
+
|
| 402 |
+
# node = trie.walk(prefix_ids)
|
| 403 |
+
# if node is None or node.terminal:
|
| 404 |
+
# continue
|
| 405 |
+
|
| 406 |
+
# allowed_next = list(node.children.keys())
|
| 407 |
+
# if not allowed_next:
|
| 408 |
+
# continue
|
| 409 |
+
|
| 410 |
+
# allowed_next = torch.tensor(allowed_next, device=scores.device)
|
| 411 |
+
# allow_always = self._token_sets[i].allow_always.to(scores.device)
|
| 412 |
+
|
| 413 |
+
# keep = torch.cat([allowed_next, allow_always])
|
| 414 |
+
|
| 415 |
+
# kept_scores = scores[i, keep].clone()
|
| 416 |
+
# scores[i, :] = -float("inf")
|
| 417 |
+
# scores[i, keep] = kept_scores
|
| 418 |
+
|
| 419 |
+
# return scores
|
| 420 |
+
|
| 421 |
+
# # =========================================================
|
| 422 |
+
# # 📊 METRICS FOR REPORT
|
| 423 |
+
# # =========================================================
|
| 424 |
+
# def get_constraint_stats(self):
|
| 425 |
+
# if self.total_steps == 0:
|
| 426 |
+
# return 0
|
| 427 |
+
# return self.constrained_steps / self.total_steps
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
# # =========================================================
|
| 431 |
+
# # 🔁 BACKWARD COMPATIBILITY
|
| 432 |
+
# # =========================================================
|
| 433 |
+
# class SchemaConstraintGraph:
|
| 434 |
+
# def __init__(self, db_path: str):
|
| 435 |
+
# self._graph = build_constraint_graph(db_path)
|
| 436 |
+
# self.tables = sorted(self._graph.tables)
|
| 437 |
+
# self.columns = sorted(self._graph.all_columns)
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
# class SchemaConstrainedLogitsProcessor(LogitsProcessor):
|
| 441 |
+
# def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
|
| 442 |
+
# self.proc = BatchSchemaConstrainedLogitsProcessor(
|
| 443 |
+
# tokenizer, [schema_graph._graph.db_path]
|
| 444 |
+
# )
|
| 445 |
+
|
| 446 |
+
# def __call__(self, input_ids, scores):
|
| 447 |
+
# return self.proc(input_ids, scores)
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
# ********* after task 3
|
| 457 |
+
|
| 458 |
+
import re
|
| 459 |
+
import torch
|
| 460 |
+
from transformers import LogitsProcessor
|
| 461 |
+
from src.schema_utils import get_constraint_graph
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def _infer_expected_identifier(prefix_text: str):
|
| 465 |
+
s = prefix_text.lower()
|
| 466 |
+
|
| 467 |
+
if " from " in s or " join " in s:
|
| 468 |
+
return "table"
|
| 469 |
+
if any(k in s for k in ["select", "where", "on", "group by", "order by"]):
|
| 470 |
+
return "column"
|
| 471 |
+
|
| 472 |
+
return None
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
class SchemaConstrainedLogitsProcessor(LogitsProcessor):
    """Soft schema-aware logits processor.

    When the decoded prefix suggests the next token should be a table or
    column identifier, every vocabulary entry outside the database schema
    (plus a small SQL keyword whitelist) is masked to -inf.  Very short
    prefixes are left unconstrained so generation can get started — this is
    the "soft" part of the constraint.
    """

    def __init__(self, tokenizer, db_path):
        self.tokenizer = tokenizer

        graph = get_constraint_graph(db_path)

        # Identifiers that are legal for this database.
        self.allowed_tokens = set(graph["tables"]) | set(graph["columns"])

        # Keywords/operators that are always acceptable in SQL output.
        self.sql_keywords = {
            "select", "from", "where", "join", "on",
            "group", "by", "order", "limit",
            "and", "or", "desc", "asc",
            "count", "avg", "min", "max", "sum", "*"
        }
        self.allowed_tokens |= self.sql_keywords

        # Expand every allowed surface string into its sub-token ids.
        self.allowed_token_ids = set()
        for token in self.allowed_tokens:
            ids = tokenizer.encode(token, add_special_tokens=False)
            self.allowed_token_ids.update(ids)

        # PERF: precompute an index tensor once instead of running a Python
        # loop over the whole id set on every single decoding step.
        self._allowed_index = torch.tensor(
            sorted(self.allowed_token_ids), dtype=torch.long
        )

    def __call__(self, input_ids, scores):
        # NOTE(review): only row 0 of the batch is decoded to build the
        # prefix, but the mask is applied to every row — this assumes batch
        # size 1 (or identical prefixes); confirm against callers.
        prefix = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)

        # 🔥 SOFT CONSTRAINT (FIX): leave very short prefixes unconstrained.
        if len(prefix.strip()) < 10:
            return scores

        expected = _infer_expected_identifier(prefix)
        if expected not in ["table", "column"]:
            return scores

        # Keep the original scores only at allowed ids; everything else -inf.
        mask = torch.full_like(scores, float("-inf"))
        idx = self._allowed_index.to(scores.device)
        mask[:, idx] = scores[:, idx]
        return mask
|
src/convert_to_hf_dataset.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Convert the processed training CSV into an on-disk HuggingFace dataset."""
from pathlib import Path

import pandas as pd
from datasets import Dataset

# BUGFIX: resolve paths relative to this file instead of the current working
# directory — the original "../data/..." paths only worked when the script
# was launched from inside src/.
DATA_DIR = Path(__file__).resolve().parents[1] / "data" / "processed"

df = pd.read_csv(DATA_DIR / "train.csv")
ds = Dataset.from_pandas(df)
ds.save_to_disk(str(DATA_DIR / "train"))
print("DONE")
|
| 8 |
+
|
src/eval_baseline_codet5.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import sqlite3
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import torch
|
| 5 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 6 |
+
|
| 7 |
+
# ---------------- PROMPT (same style as training) ----------------
|
| 8 |
+
def build_prompt(question, schema):
    """Render the inference prompt (identical template to training)."""
    template = (
        "translate English to SQL:\n"
        "\n"
        "Schema:\n"
        "{schema}\n"
        "\n"
        "Question:\n"
        "{question}\n"
        "\n"
        "SQL:"
    )
    return template.format(schema=schema, question=question)
|
| 18 |
+
|
| 19 |
+
# ---------------- LOAD SCHEMA ----------------
|
| 20 |
+
def load_schema(db_path):
    """Return a compact textual schema: one "table(col1, col2, ...)" line per table."""
    connection = sqlite3.connect(db_path)
    cur = connection.cursor()

    table_rows = cur.execute(
        "SELECT name FROM sqlite_master WHERE type='table';"
    ).fetchall()

    lines = []
    for (name,) in table_rows:
        # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
        columns = [row[1] for row in cur.execute(f"PRAGMA table_info({name});").fetchall()]
        lines.append(f"{name}({', '.join(columns)})\n")

    connection.close()
    return "".join(lines)
|
| 36 |
+
|
| 37 |
+
# ---------------- EXECUTION MATCH ----------------
|
| 38 |
+
def execution_match(pred_sql, gold_sql, db_path):
    """Return True iff both queries execute and yield identical result rows.

    Any failure (syntax error, missing table, ...) counts as a mismatch.
    BUGFIX: the connection is now closed in a ``finally`` block — previously
    it leaked whenever either query raised.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cur = conn.cursor()

        cur.execute(pred_sql)
        pred = cur.fetchall()

        cur.execute(gold_sql)
        gold = cur.fetchall()

        return pred == gold

    except Exception:
        return False
    finally:
        if conn is not None:
            conn.close()
| 54 |
+
|
| 55 |
+
# ---------------- MAIN ----------------
|
| 56 |
+
def main():
    """Zero-shot baseline: run base CodeT5 on the first 100 Spider dev
    examples and report execution accuracy."""
    project_root = Path(__file__).resolve().parents[1]

    dev_json = project_root / "data" / "dev.json"
    db_root = project_root / "data" / "database"

    device = "mps" if torch.backends.mps.is_available() else "cpu"

    print("Loading BASE CodeT5...")
    tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-base")
    model = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/codet5-base").to(device)
    model.eval()

    with open(dev_json) as f:
        dev = json.load(f)[:100]

    total = len(dev)  # may be fewer than 100 if the dev set is small
    correct = 0

    print(f"\nEvaluating {total} samples...\n")

    for i, ex in enumerate(dev, 1):
        question = ex["question"]
        db_id = ex["db_id"]
        gold_sql = ex["query"]

        db_path = db_root / db_id / f"{db_id}.sqlite"
        schema = load_schema(db_path)

        prompt = build_prompt(question, schema)

        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=80,
                num_beams=4,
                do_sample=False
            )

        pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # The model sometimes echoes the prompt; keep only the part after "SQL:".
        if "SQL:" in pred_sql:
            pred_sql = pred_sql.split("SQL:")[-1].strip()

        if execution_match(pred_sql, gold_sql, db_path):
            correct += 1

        if i % 10 == 0:
            # BUGFIX: progress line previously hardcoded "/100".
            print(f"{i}/{total} | Accuracy: {correct/i:.3f}")

    # BUGFIX: the summary previously printed "{correct}% / 100 = {correct}%",
    # which put a stray '%' on the raw count, repeated the same number on
    # both sides, and assumed exactly 100 samples.
    accuracy = (correct / total * 100) if total else 0.0
    print("\n=============================")
    print(f"BASE MODEL ACCURACY: {correct} / {total} = {accuracy:.1f}%")
    print("=============================")
|
| 110 |
+
|
| 111 |
+
if __name__ == "__main__":
|
| 112 |
+
main()
|
src/eval_both_metrics.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import sqlite3
|
| 3 |
+
import torch
|
| 4 |
+
import re
|
| 5 |
+
import time
|
| 6 |
+
import argparse
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 9 |
+
from peft import PeftModel
|
| 10 |
+
|
| 11 |
+
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 12 |
+
DB_ROOT = PROJECT_ROOT / "data" / "database"
|
| 13 |
+
|
| 14 |
+
# -------------------------------
|
| 15 |
+
# 1. NORMALIZATION FOR EXACT MATCH
|
| 16 |
+
# -------------------------------
|
| 17 |
+
def normalize_sql(sql):
    """Canonicalize SQL so Exact Match ignores spacing, case and quote style."""
    canonical = re.sub(r"\s+", " ", sql.replace('"', "'"))
    return canonical.strip().lower().rstrip(";")
|
| 24 |
+
|
| 25 |
+
# -------------------------------
|
| 26 |
+
# 2. EXECUTION ACCURACY CHECK
|
| 27 |
+
# -------------------------------
|
| 28 |
+
def check_execution(pred_sql, gold_sql, db_path):
    """Run both queries against the same DB; True iff result rows match.

    Any failure (bad SQL, missing table, query aborted by the ~5s progress
    handler) counts as a mismatch.  BUGFIX: the connection is now closed in a
    ``finally`` block — previously it leaked whenever either query raised.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        # Handle bad characters in Spider DBs
        conn.text_factory = lambda b: b.decode(errors='ignore')

        # 5-second timeout: a non-zero return from the progress handler
        # aborts the running statement (raises OperationalError).
        start_time = time.monotonic()
        def timeout_handler():
            return 1 if (time.monotonic() - start_time) > 5.0 else 0
        conn.set_progress_handler(timeout_handler, 10000)

        cursor = conn.cursor()

        # Get Predicted Result
        cursor.execute(pred_sql)
        pred_res = cursor.fetchall()

        # Get Gold Result
        cursor.execute(gold_sql)
        gold_res = cursor.fetchall()

        return pred_res == gold_res
    except Exception:
        return False
    finally:
        if conn is not None:
            conn.close()
|
| 55 |
+
|
| 56 |
+
# -------------------------------
|
| 57 |
+
# 3. LOAD SCHEMA
|
| 58 |
+
# -------------------------------
|
| 59 |
+
def load_schema(db_path):
    """Return a one-line-per-table textual schema: "table(col1, col2, ...)"."""
    connection = sqlite3.connect(db_path)
    # Spider databases contain some non-UTF8 bytes; ignore them on decode.
    connection.text_factory = lambda b: b.decode(errors='ignore')
    cur = connection.cursor()

    lines = []
    for (name,) in cur.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall():
        columns = [row[1] for row in cur.execute(f"PRAGMA table_info({name});").fetchall()]
        lines.append(f"{name}({', '.join(columns)})\n")

    connection.close()
    return "".join(lines)
|
| 71 |
+
|
| 72 |
+
# -------------------------------
|
| 73 |
+
# 4. MAIN PIPELINE
|
| 74 |
+
# -------------------------------
|
| 75 |
+
def main():
|
| 76 |
+
parser = argparse.ArgumentParser()
|
| 77 |
+
parser.add_argument("--adapter", type=str, required=True, help="Path to your SFT or RLHF checkpoint")
|
| 78 |
+
parser.add_argument("--num_samples", type=int, default=1034, help="How many samples to evaluate")
|
| 79 |
+
args = parser.parse_args()
|
| 80 |
+
|
| 81 |
+
device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
|
| 82 |
+
base_model = "Salesforce/codet5-base"
|
| 83 |
+
|
| 84 |
+
print(f"\n🚀 Loading Model from: {args.adapter}")
|
| 85 |
+
tokenizer = AutoTokenizer.from_pretrained(args.adapter)
|
| 86 |
+
base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
|
| 87 |
+
model = PeftModel.from_pretrained(base, args.adapter).to(device)
|
| 88 |
+
model = model.merge_and_unload()
|
| 89 |
+
model.eval()
|
| 90 |
+
|
| 91 |
+
dev_json = PROJECT_ROOT / "data" / "dev.json"
|
| 92 |
+
with open(dev_json) as f:
|
| 93 |
+
dev = json.load(f)[:args.num_samples]
|
| 94 |
+
|
| 95 |
+
em_correct = 0
|
| 96 |
+
ex_correct = 0
|
| 97 |
+
total = len(dev)
|
| 98 |
+
|
| 99 |
+
print(f"\n📊 Evaluating {total} queries for BOTH Exact Match and Execution Accuracy...\n")
|
| 100 |
+
|
| 101 |
+
for i, ex in enumerate(dev, 1):
|
| 102 |
+
question = ex["question"]
|
| 103 |
+
gold_sql = ex["query"]
|
| 104 |
+
db_id = ex["db_id"]
|
| 105 |
+
db_path = DB_ROOT / db_id / f"{db_id}.sqlite"
|
| 106 |
+
|
| 107 |
+
# Generate SQL
|
| 108 |
+
schema = load_schema(db_path)
|
| 109 |
+
prompt = f"Database Schema:\n{schema}\nTranslate English to SQL:\n{question}\nSQL:\n"
|
| 110 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
| 111 |
+
|
| 112 |
+
with torch.no_grad():
|
| 113 |
+
outputs = model.generate(**inputs, max_new_tokens=100, num_beams=4, do_sample=False)
|
| 114 |
+
|
| 115 |
+
pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 116 |
+
if "SQL:" in pred_sql:
|
| 117 |
+
pred_sql = pred_sql.split("SQL:")[-1].strip()
|
| 118 |
+
|
| 119 |
+
# --- METRIC 1: EXACT MATCH ---
|
| 120 |
+
is_em = (normalize_sql(pred_sql) == normalize_sql(gold_sql))
|
| 121 |
+
if is_em:
|
| 122 |
+
em_correct += 1
|
| 123 |
+
|
| 124 |
+
# --- METRIC 2: EXECUTION ACCURACY ---
|
| 125 |
+
is_ex = check_execution(pred_sql, gold_sql, db_path)
|
| 126 |
+
if is_ex:
|
| 127 |
+
ex_correct += 1
|
| 128 |
+
|
| 129 |
+
if i % 50 == 0 or i == total:
|
| 130 |
+
print(f"Progress: {i}/{total} | Current EM: {(em_correct/i)*100:.2f}% | Current EX: {(ex_correct/i)*100:.2f}%")
|
| 131 |
+
|
| 132 |
+
# Final Results
|
| 133 |
+
final_em = (em_correct / total) * 100
|
| 134 |
+
final_ex = (ex_correct / total) * 100
|
| 135 |
+
|
| 136 |
+
print("\n==========================================")
|
| 137 |
+
print(f"🎯 FINAL RESULTS FOR: {args.adapter}")
|
| 138 |
+
print("==========================================")
|
| 139 |
+
print(f"Exact Match (EM) Accuracy : {final_em:.2f}%")
|
| 140 |
+
print(f"Execution (EX) Accuracy : {final_ex:.2f}%")
|
| 141 |
+
print("==========================================\n")
|
| 142 |
+
|
| 143 |
+
if __name__ == "__main__":
|
| 144 |
+
main()
|
src/eval_rl_fixed.py
ADDED
|
@@ -0,0 +1,756 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# import json
|
| 2 |
+
# import subprocess
|
| 3 |
+
# import sys
|
| 4 |
+
# import argparse
|
| 5 |
+
# import random
|
| 6 |
+
# import sqlite3
|
| 7 |
+
# import time
|
| 8 |
+
# import re
|
| 9 |
+
# import os
|
| 10 |
+
# from pathlib import Path
|
| 11 |
+
|
| 12 |
+
# import torch
|
| 13 |
+
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 14 |
+
# from peft import PeftModel
|
| 15 |
+
|
| 16 |
+
# from prompting import encode_prompt
|
| 17 |
+
|
| 18 |
+
# # -------------------------------
|
| 19 |
+
# # NORMALIZATION
|
| 20 |
+
# # -------------------------------
|
| 21 |
+
# def normalize_sql(sql):
|
| 22 |
+
# sql = sql.replace('"', "'")
|
| 23 |
+
# sql = re.sub(r"\s+", " ", sql)
|
| 24 |
+
# return sql.strip().lower().rstrip(";")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# # -------------------------------
|
| 28 |
+
# # 🔥 SAFE RESULT NORMALIZATION (FIX)
|
| 29 |
+
# # -------------------------------
|
| 30 |
+
# def normalize_result(res):
|
| 31 |
+
# try:
|
| 32 |
+
# return sorted([str(r) for r in res])
|
| 33 |
+
# except:
|
| 34 |
+
# return []
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# # -------------------------------
|
| 38 |
+
# # EXECUTION CHECK (FIXED)
|
| 39 |
+
# # -------------------------------
|
| 40 |
+
# def check_execution(pred_sql, gold_sql, db_path):
|
| 41 |
+
# try:
|
| 42 |
+
# conn = sqlite3.connect(db_path)
|
| 43 |
+
# conn.text_factory = lambda b: b.decode(errors='ignore')
|
| 44 |
+
|
| 45 |
+
# start_time = time.monotonic()
|
| 46 |
+
|
| 47 |
+
# def timeout_handler():
|
| 48 |
+
# return 1 if (time.monotonic() - start_time) > 2.0 else 0
|
| 49 |
+
|
| 50 |
+
# conn.set_progress_handler(timeout_handler, 10000)
|
| 51 |
+
|
| 52 |
+
# cursor = conn.cursor()
|
| 53 |
+
|
| 54 |
+
# cursor.execute(pred_sql)
|
| 55 |
+
# pred_res = cursor.fetchall()
|
| 56 |
+
|
| 57 |
+
# cursor.execute(gold_sql)
|
| 58 |
+
# gold_res = cursor.fetchall()
|
| 59 |
+
|
| 60 |
+
# conn.close()
|
| 61 |
+
|
| 62 |
+
# # 🔥 FIXED COMPARISON
|
| 63 |
+
# return normalize_result(pred_res) == normalize_result(gold_res)
|
| 64 |
+
|
| 65 |
+
# except Exception:
|
| 66 |
+
# return False
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# # -------------------------------
|
| 70 |
+
# # SPIDER PARSER
|
| 71 |
+
# # -------------------------------
|
| 72 |
+
# def _parse_spider_accuracy(stdout: str, metric_type: str):
|
| 73 |
+
# for line in stdout.splitlines():
|
| 74 |
+
# if metric_type == "exec" and line.strip().startswith("execution"):
|
| 75 |
+
# try:
|
| 76 |
+
# return float(line.split()[-1])
|
| 77 |
+
# except:
|
| 78 |
+
# pass
|
| 79 |
+
# elif metric_type == "match" and line.strip().startswith("exact"):
|
| 80 |
+
# try:
|
| 81 |
+
# return float(line.split()[-1])
|
| 82 |
+
# except:
|
| 83 |
+
# pass
|
| 84 |
+
# return None
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# # -------------------------------
|
| 88 |
+
# # MAIN
|
| 89 |
+
# # -------------------------------
|
| 90 |
+
# def main():
|
| 91 |
+
# parser = argparse.ArgumentParser()
|
| 92 |
+
# parser.add_argument("--adapter", type=str, required=True)
|
| 93 |
+
# parser.add_argument("--num_samples", type=int, default=700)
|
| 94 |
+
# parser.add_argument("--shuffle_dev", action="store_true")
|
| 95 |
+
# parser.add_argument("--shuffle_seed", type=int, default=42)
|
| 96 |
+
# args = parser.parse_args()
|
| 97 |
+
|
| 98 |
+
# project_root = Path(__file__).resolve().parents[1]
|
| 99 |
+
# adapter_dir = project_root / args.adapter
|
| 100 |
+
|
| 101 |
+
# db_root = project_root / "data" / "database"
|
| 102 |
+
# table_json = project_root / "data" / "tables.json"
|
| 103 |
+
# dev_json = project_root / "data" / "dev.json"
|
| 104 |
+
|
| 105 |
+
# pred_path = project_root / "temp_predictions.txt"
|
| 106 |
+
# temp_gold_path = project_root / "temp_gold.sql"
|
| 107 |
+
|
| 108 |
+
# if not adapter_dir.exists():
|
| 109 |
+
# raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")
|
| 110 |
+
|
| 111 |
+
# device = "mps" if torch.backends.mps.is_available() else (
|
| 112 |
+
# "cuda" if torch.cuda.is_available() else "cpu"
|
| 113 |
+
# )
|
| 114 |
+
# print(f"Using device: {device}")
|
| 115 |
+
|
| 116 |
+
# BASE_MODEL = "Salesforce/codet5-base"
|
| 117 |
+
# tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
|
| 118 |
+
|
| 119 |
+
# if tokenizer.pad_token is None:
|
| 120 |
+
# tokenizer.pad_token = tokenizer.eos_token
|
| 121 |
+
|
| 122 |
+
# print(f"\n📦 Loading Model: {args.adapter}")
|
| 123 |
+
|
| 124 |
+
# base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
|
| 125 |
+
|
| 126 |
+
# adapter_for_peft = os.path.relpath(adapter_dir, project_root)
|
| 127 |
+
|
| 128 |
+
# model = PeftModel.from_pretrained(
|
| 129 |
+
# base,
|
| 130 |
+
# adapter_for_peft,
|
| 131 |
+
# local_files_only=True
|
| 132 |
+
# ).to(device)
|
| 133 |
+
|
| 134 |
+
# model = model.merge_and_unload()
|
| 135 |
+
# model.eval()
|
| 136 |
+
|
| 137 |
+
# # -------------------------------
|
| 138 |
+
# # LOAD DATA
|
| 139 |
+
# # -------------------------------
|
| 140 |
+
# with dev_json.open() as f:
|
| 141 |
+
# dev = json.load(f)
|
| 142 |
+
|
| 143 |
+
# if args.shuffle_dev:
|
| 144 |
+
# rng = random.Random(args.shuffle_seed)
|
| 145 |
+
# rng.shuffle(dev)
|
| 146 |
+
|
| 147 |
+
# dev = dev[: args.num_samples]
|
| 148 |
+
# total = len(dev)
|
| 149 |
+
|
| 150 |
+
# gen_kwargs = dict(
|
| 151 |
+
# max_new_tokens=160,
|
| 152 |
+
# num_beams=8,
|
| 153 |
+
# length_penalty=0.8,
|
| 154 |
+
# do_sample=False,
|
| 155 |
+
# early_stopping=True,
|
| 156 |
+
# pad_token_id=tokenizer.pad_token_id,
|
| 157 |
+
# eos_token_id=tokenizer.eos_token_id,
|
| 158 |
+
# )
|
| 159 |
+
|
| 160 |
+
# print(f"\n🚀 Evaluating {total} samples...\n")
|
| 161 |
+
|
| 162 |
+
# em_correct = 0
|
| 163 |
+
# ex_correct = 0
|
| 164 |
+
|
| 165 |
+
# with pred_path.open("w") as out_pred, temp_gold_path.open("w") as out_gold, torch.no_grad():
|
| 166 |
+
# for i, ex in enumerate(dev, start=1):
|
| 167 |
+
|
| 168 |
+
# db_id = ex["db_id"]
|
| 169 |
+
# question = ex["question"]
|
| 170 |
+
# gold_query = ex["query"]
|
| 171 |
+
# db_path = db_root / db_id / f"{db_id}.sqlite"
|
| 172 |
+
|
| 173 |
+
# # -------------------------------
|
| 174 |
+
# # GENERATE SQL
|
| 175 |
+
# # -------------------------------
|
| 176 |
+
# input_ids = encode_prompt(
|
| 177 |
+
# tokenizer,
|
| 178 |
+
# question,
|
| 179 |
+
# db_id,
|
| 180 |
+
# device=device,
|
| 181 |
+
# max_input_tokens=512
|
| 182 |
+
# )
|
| 183 |
+
|
| 184 |
+
# input_ids = input_ids.unsqueeze(0).to(device)
|
| 185 |
+
# attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)
|
| 186 |
+
|
| 187 |
+
# outputs = model.generate(
|
| 188 |
+
# input_ids=input_ids,
|
| 189 |
+
# attention_mask=attention_mask,
|
| 190 |
+
# **gen_kwargs
|
| 191 |
+
# )
|
| 192 |
+
|
| 193 |
+
# pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
|
| 194 |
+
|
| 195 |
+
# # -------------------------------
|
| 196 |
+
# # SAVE FOR SPIDER EVAL
|
| 197 |
+
# # -------------------------------
|
| 198 |
+
# out_pred.write(f"{pred_sql}\n")
|
| 199 |
+
# out_gold.write(f"{gold_query}\t{db_id}\n")
|
| 200 |
+
|
| 201 |
+
# # -------------------------------
|
| 202 |
+
# # LIVE METRICS
|
| 203 |
+
# # -------------------------------
|
| 204 |
+
# if normalize_sql(pred_sql) == normalize_sql(gold_query):
|
| 205 |
+
# em_correct += 1
|
| 206 |
+
|
| 207 |
+
# if check_execution(pred_sql, gold_query, db_path):
|
| 208 |
+
# ex_correct += 1
|
| 209 |
+
|
| 210 |
+
# if i % 20 == 0 or i == total:
|
| 211 |
+
# print(
|
| 212 |
+
# f"Progress: {i}/{total} | "
|
| 213 |
+
# f"EM: {(em_correct/i)*100:.2f}% | "
|
| 214 |
+
# f"EX: {(ex_correct/i)*100:.2f}%"
|
| 215 |
+
# )
|
| 216 |
+
|
| 217 |
+
# print("\n🚀 Running Official Spider Evaluation...\n")
|
| 218 |
+
|
| 219 |
+
# eval_script = project_root / "spider_eval" / "evaluation.py"
|
| 220 |
+
|
| 221 |
+
# # EXACT MATCH
|
| 222 |
+
# cmd_match = [
|
| 223 |
+
# sys.executable, str(eval_script),
|
| 224 |
+
# "--gold", str(temp_gold_path),
|
| 225 |
+
# "--pred", str(pred_path),
|
| 226 |
+
# "--etype", "match",
|
| 227 |
+
# "--db", str(db_root),
|
| 228 |
+
# "--table", str(table_json),
|
| 229 |
+
# ]
|
| 230 |
+
|
| 231 |
+
# proc_match = subprocess.run(cmd_match, capture_output=True, text=True)
|
| 232 |
+
# exact_acc = _parse_spider_accuracy(proc_match.stdout, "match")
|
| 233 |
+
|
| 234 |
+
# # EXECUTION
|
| 235 |
+
# cmd_exec = [
|
| 236 |
+
# sys.executable, str(eval_script),
|
| 237 |
+
# "--gold", str(temp_gold_path),
|
| 238 |
+
# "--pred", str(pred_path),
|
| 239 |
+
# "--etype", "exec",
|
| 240 |
+
# "--db", str(db_root),
|
| 241 |
+
# "--table", str(table_json),
|
| 242 |
+
# ]
|
| 243 |
+
|
| 244 |
+
# proc_exec = subprocess.run(cmd_exec, capture_output=True, text=True)
|
| 245 |
+
# exec_acc = _parse_spider_accuracy(proc_exec.stdout, "exec")
|
| 246 |
+
|
| 247 |
+
# print("==========================================")
|
| 248 |
+
# print(f"🎯 OFFICIAL SPIDER RESULTS FOR: {args.adapter}")
|
| 249 |
+
# print("==========================================")
|
| 250 |
+
|
| 251 |
+
# print(f"Exact Match Accuracy : {exact_acc*100:.2f}%" if exact_acc else "EM parsing failed")
|
| 252 |
+
# print(f"Execution Accuracy : {exec_acc*100:.2f}%" if exec_acc else "EX parsing failed")
|
| 253 |
+
|
| 254 |
+
# print("==========================================\n")
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
# if __name__ == "__main__":
|
| 258 |
+
# main()
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
# import json
|
| 266 |
+
# import sqlite3
|
| 267 |
+
# import re
|
| 268 |
+
# import time
|
| 269 |
+
# import sys
|
| 270 |
+
# import argparse
|
| 271 |
+
# from pathlib import Path
|
| 272 |
+
|
| 273 |
+
# # ==========================================
|
| 274 |
+
# # PATH SETUP
|
| 275 |
+
# # ==========================================
|
| 276 |
+
# PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 277 |
+
# if str(PROJECT_ROOT) not in sys.path:
|
| 278 |
+
# sys.path.insert(0, str(PROJECT_ROOT))
|
| 279 |
+
|
| 280 |
+
# from src.text2sql_engine import get_engine
|
| 281 |
+
# from src.sql_validator import validate_sql_schema
|
| 282 |
+
|
| 283 |
+
# # ==========================================
|
| 284 |
+
# # CONFIG
|
| 285 |
+
# # ==========================================
|
| 286 |
+
# DATA_PATH = PROJECT_ROOT / "data" / "dev.json"
|
| 287 |
+
# DB_ROOT = PROJECT_ROOT / "data" / "database"
|
| 288 |
+
|
| 289 |
+
# # ==========================================
|
| 290 |
+
# # NORMALIZATION
|
| 291 |
+
# # ==========================================
|
| 292 |
+
# def normalize_sql(sql):
|
| 293 |
+
# if not isinstance(sql, str):
|
| 294 |
+
# return ""
|
| 295 |
+
# sql = sql.replace('"', "'")
|
| 296 |
+
# sql = re.sub(r"\s+", " ", sql)
|
| 297 |
+
# return sql.strip().lower().rstrip(";")
|
| 298 |
+
|
| 299 |
+
# def normalize_result(res):
|
| 300 |
+
# try:
|
| 301 |
+
# return sorted([tuple(map(str, r)) for r in res])
|
| 302 |
+
# except:
|
| 303 |
+
# return []
|
| 304 |
+
|
| 305 |
+
# # ==========================================
|
| 306 |
+
# # EXECUTION
|
| 307 |
+
# # ==========================================
|
| 308 |
+
# def execute_sql(db_path, sql):
|
| 309 |
+
# try:
|
| 310 |
+
# conn = sqlite3.connect(db_path)
|
| 311 |
+
|
| 312 |
+
# start = time.time()
|
| 313 |
+
# def timeout():
|
| 314 |
+
# return 1 if (time.time() - start) > 2 else 0
|
| 315 |
+
|
| 316 |
+
# conn.set_progress_handler(timeout, 10000)
|
| 317 |
+
|
| 318 |
+
# cur = conn.cursor()
|
| 319 |
+
# cur.execute(sql)
|
| 320 |
+
# res = cur.fetchall()
|
| 321 |
+
|
| 322 |
+
# conn.close()
|
| 323 |
+
# return res
|
| 324 |
+
|
| 325 |
+
# except Exception:
|
| 326 |
+
# return None
|
| 327 |
+
|
| 328 |
+
# # ==========================================
|
| 329 |
+
# # EVALUATION
|
| 330 |
+
# # ==========================================
|
| 331 |
+
# def evaluate(engine, data, is_constrained=False, debug=False):
|
| 332 |
+
|
| 333 |
+
# attempted = 0
|
| 334 |
+
# total = 0
|
| 335 |
+
# exact_match = 0
|
| 336 |
+
# execution_match = 0
|
| 337 |
+
# constraint_ok = 0
|
| 338 |
+
|
| 339 |
+
# skipped_missing_db = 0
|
| 340 |
+
# skipped_exception = 0
|
| 341 |
+
# skipped_no_sql = 0
|
| 342 |
+
|
| 343 |
+
# total_time = 0
|
| 344 |
+
|
| 345 |
+
# for i, item in enumerate(data, 1):
|
| 346 |
+
|
| 347 |
+
# question = item.get("question", "")
|
| 348 |
+
# gold_sql = item.get("query", "")
|
| 349 |
+
# db_id = item.get("db_id", "")
|
| 350 |
+
|
| 351 |
+
# db_path = DB_ROOT / db_id / f"{db_id}.sqlite"
|
| 352 |
+
|
| 353 |
+
# if not db_path.exists():
|
| 354 |
+
# skipped_missing_db += 1
|
| 355 |
+
# continue
|
| 356 |
+
|
| 357 |
+
# try:
|
| 358 |
+
# start = time.time()
|
| 359 |
+
# result = engine.ask(question, db_id)
|
| 360 |
+
# total_time += (time.time() - start)
|
| 361 |
+
# except Exception:
|
| 362 |
+
# skipped_exception += 1
|
| 363 |
+
# continue
|
| 364 |
+
|
| 365 |
+
# if not isinstance(result, dict):
|
| 366 |
+
# continue
|
| 367 |
+
|
| 368 |
+
# pred_sql = result.get("sql", "")
|
| 369 |
+
|
| 370 |
+
# # DEBUG
|
| 371 |
+
# if debug:
|
| 372 |
+
# print(f"\nQ: {question}")
|
| 373 |
+
# print(f"PRED: {pred_sql}")
|
| 374 |
+
# print(f"GOLD: {gold_sql}")
|
| 375 |
+
|
| 376 |
+
# if not pred_sql:
|
| 377 |
+
# skipped_no_sql += 1
|
| 378 |
+
# continue
|
| 379 |
+
|
| 380 |
+
# attempted += 1
|
| 381 |
+
# total += 1
|
| 382 |
+
|
| 383 |
+
# # CONSTRAINT CHECK
|
| 384 |
+
# if is_constrained:
|
| 385 |
+
# try:
|
| 386 |
+
# is_valid, _ = validate_sql_schema(pred_sql, str(db_path))
|
| 387 |
+
# if is_valid:
|
| 388 |
+
# constraint_ok += 1
|
| 389 |
+
# except:
|
| 390 |
+
# pass
|
| 391 |
+
|
| 392 |
+
# # EXACT MATCH
|
| 393 |
+
# if normalize_sql(pred_sql) == normalize_sql(gold_sql):
|
| 394 |
+
# exact_match += 1
|
| 395 |
+
|
| 396 |
+
# # EXECUTION MATCH
|
| 397 |
+
# pred_res = execute_sql(str(db_path), pred_sql)
|
| 398 |
+
# gold_res = execute_sql(str(db_path), gold_sql)
|
| 399 |
+
|
| 400 |
+
# if pred_res is not None and gold_res is not None:
|
| 401 |
+
# if normalize_result(pred_res) == normalize_result(gold_res):
|
| 402 |
+
# execution_match += 1
|
| 403 |
+
|
| 404 |
+
# # PROGRESS
|
| 405 |
+
# if i % 10 == 0:
|
| 406 |
+
# print(
|
| 407 |
+
# f"[{i}/{len(data)}] "
|
| 408 |
+
# f"EM: {exact_match/max(total,1):.3f} | "
|
| 409 |
+
# f"EX: {execution_match/max(total,1):.3f} | "
|
| 410 |
+
# f"Constraint: {(constraint_ok/max(total,1)) if is_constrained else 0:.3f}"
|
| 411 |
+
# )
|
| 412 |
+
|
| 413 |
+
# avg_latency = total_time / max(attempted, 1)
|
| 414 |
+
|
| 415 |
+
# return {
|
| 416 |
+
# "exact_match": exact_match / total if total > 0 else 0,
|
| 417 |
+
# "execution_accuracy": execution_match / total if total > 0 else 0,
|
| 418 |
+
# "constraint_rate": (constraint_ok / total if (is_constrained and total > 0) else 0),
|
| 419 |
+
# "avg_latency": avg_latency,
|
| 420 |
+
# "total": total,
|
| 421 |
+
# "attempted": attempted,
|
| 422 |
+
# "skipped_missing_db": skipped_missing_db,
|
| 423 |
+
# "skipped_exception": skipped_exception,
|
| 424 |
+
# "skipped_no_sql": skipped_no_sql,
|
| 425 |
+
# }
|
| 426 |
+
|
| 427 |
+
# # ==========================================
|
| 428 |
+
# # MAIN
|
| 429 |
+
# # ==========================================
|
| 430 |
+
# if __name__ == "__main__":
|
| 431 |
+
|
| 432 |
+
# ap = argparse.ArgumentParser()
|
| 433 |
+
# ap.add_argument("--num-samples", type=int, default=100)
|
| 434 |
+
# ap.add_argument("--adapter", type=str, default="checkpoints/best_rlhf_model")
|
| 435 |
+
# ap.add_argument("--debug", action="store_true")
|
| 436 |
+
# args = ap.parse_args()
|
| 437 |
+
|
| 438 |
+
# print(f"\n📥 Loading dataset from {DATA_PATH}...")
|
| 439 |
+
|
| 440 |
+
# with open(str(DATA_PATH)) as f:
|
| 441 |
+
# data = json.load(f)[: args.num_samples]
|
| 442 |
+
|
| 443 |
+
# # ==========================================
|
| 444 |
+
# # 🔴 BASE MODEL
|
| 445 |
+
# # ==========================================
|
| 446 |
+
# print("\n🚀 Running BASE MODEL...\n")
|
| 447 |
+
|
| 448 |
+
# engine_base = get_engine(
|
| 449 |
+
# adapter_path="checkpoints/sft_adapter_codet5" , # 🔥 change this
|
| 450 |
+
# use_lora=True,
|
| 451 |
+
# use_constrained=False
|
| 452 |
+
# )
|
| 453 |
+
|
| 454 |
+
# res_base = evaluate(engine_base, data, is_constrained=False, debug=args.debug)
|
| 455 |
+
|
| 456 |
+
# # ==========================================
|
| 457 |
+
# # 🟡 RLHF (NO CONSTRAINT)
|
| 458 |
+
# # ==========================================
|
| 459 |
+
# print("\n🚀 Running RLHF (NO CONSTRAINT)...\n")
|
| 460 |
+
|
| 461 |
+
# engine_rlhf = get_engine(
|
| 462 |
+
# adapter_path="checkpoints/best_rlhf_model",
|
| 463 |
+
# use_lora=True,
|
| 464 |
+
# use_constrained=False
|
| 465 |
+
# )
|
| 466 |
+
|
| 467 |
+
# res_rlhf = evaluate(engine_rlhf, data, is_constrained=False, debug=args.debug)
|
| 468 |
+
|
| 469 |
+
# # ==========================================
|
| 470 |
+
# # 🟢 RLHF + CONSTRAINT
|
| 471 |
+
# # ==========================================
|
| 472 |
+
# print("\n🚀 Running RLHF + CONSTRAINED...\n")
|
| 473 |
+
|
| 474 |
+
# engine_const = get_engine(
|
| 475 |
+
# adapter_path="checkpoints/best_rlhf_model_2",
|
| 476 |
+
# use_lora=True,
|
| 477 |
+
# use_constrained=True
|
| 478 |
+
# )
|
| 479 |
+
|
| 480 |
+
# res_const = evaluate(engine_const, data, is_constrained=True, debug=args.debug)
|
| 481 |
+
|
| 482 |
+
# # ==========================================
|
| 483 |
+
# # FINAL RESULTS
|
| 484 |
+
# # ==========================================
|
| 485 |
+
# print("\n==========================================")
|
| 486 |
+
# print("🎯 FINAL RESULTS (3-WAY COMPARISON)")
|
| 487 |
+
# print("==========================================")
|
| 488 |
+
|
| 489 |
+
# print(f"Base Model → EM: {res_base['exact_match']*100:.2f}% | "
|
| 490 |
+
# f"EX: {res_base['execution_accuracy']*100:.2f}%")
|
| 491 |
+
|
| 492 |
+
# print(f"RLHF → EM: {res_rlhf['exact_match']*100:.2f}% | "
|
| 493 |
+
# f"EX: {res_rlhf['execution_accuracy']*100:.2f}%")
|
| 494 |
+
|
| 495 |
+
# print(f"RLHF + Constrain → EM: {res_const['exact_match']*100:.2f}% | "
|
| 496 |
+
# f"EX: {res_const['execution_accuracy']*100:.2f}% | "
|
| 497 |
+
# f"Constraint: {res_const['constraint_rate']*100:.2f}%")
|
| 498 |
+
|
| 499 |
+
# print("==========================================\n")
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
import json
|
| 503 |
+
import argparse
|
| 504 |
+
import sqlite3
|
| 505 |
+
import time
|
| 506 |
+
import re
|
| 507 |
+
import os
|
| 508 |
+
from pathlib import Path
|
| 509 |
+
|
| 510 |
+
import torch
|
| 511 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 512 |
+
from peft import PeftModel
|
| 513 |
+
|
| 514 |
+
# Import handling
|
| 515 |
+
try:
|
| 516 |
+
from prompting import encode_prompt
|
| 517 |
+
from src.sql_validator import validate_sql_schema
|
| 518 |
+
except ImportError:
|
| 519 |
+
import sys
|
| 520 |
+
sys.path.append(str(Path(__file__).resolve().parents[1]))
|
| 521 |
+
from src.prompting import encode_prompt
|
| 522 |
+
from src.sql_validator import validate_sql_schema
|
| 523 |
+
|
| 524 |
+
# =========================================================
|
| 525 |
+
# ERROR LOGGING
|
| 526 |
+
# =========================================================
|
| 527 |
+
ERROR_LOG_FILE = "results/error_logs.json"
|
| 528 |
+
|
| 529 |
+
def classify_error(sql, error_msg=""):
    """Bucket a failed prediction into a coarse error category.

    The DB error message is inspected first (most specific signal); if it
    matches nothing known, fall back to a structural heuristic on the SQL
    itself (a JOIN with no ON clause).
    """
    sql = sql.lower()
    error_msg = str(error_msg).lower()

    # Ordered (needle, label) pairs — first hit wins, mirroring the
    # specificity of SQLite's error strings.
    message_buckets = (
        ("no such column", "wrong_column"),
        ("no such table", "wrong_table"),
        ("syntax error", "syntax_error"),
        ("ambiguous column", "ambiguous_column"),
    )
    for needle, label in message_buckets:
        if needle in error_msg:
            return label

    if "join" in sql and " on " not in sql:
        return "missing_join"

    return "other"
|
| 545 |
+
|
| 546 |
+
def log_error(question, sql, error, error_type):
    """Append one structured error record to ERROR_LOG_FILE (a JSON list).

    Best-effort logging: a missing, unreadable or corrupt existing log is
    discarded rather than aborting the evaluation run.
    """
    os.makedirs(os.path.dirname(ERROR_LOG_FILE), exist_ok=True)

    entry = {
        "question": question,
        "sql": sql,
        "error": str(error),
        "error_type": error_type,
        "timestamp": time.time()
    }

    logs = []
    if os.path.exists(ERROR_LOG_FILE):
        # FIX: was a bare `except:` (would also swallow KeyboardInterrupt /
        # SystemExit). Only I/O and JSON-decoding failures are tolerated;
        # json.JSONDecodeError subclasses ValueError.
        try:
            with open(ERROR_LOG_FILE, "r") as f:
                content = f.read().strip()
            if content:
                loaded = json.loads(content)
                # FIX: guard against a log file holding a non-list JSON value,
                # which would previously make logs.append() blow up later.
                if isinstance(loaded, list):
                    logs = loaded
        except (OSError, ValueError):
            logs = []

    logs.append(entry)

    with open(ERROR_LOG_FILE, "w") as f:
        json.dump(logs, f, indent=2)
|
| 571 |
+
|
| 572 |
+
# =========================================================
|
| 573 |
+
# 🔥 FINAL FIX_SQL (BALANCED VERSION)
|
| 574 |
+
# =========================================================
|
| 575 |
+
def fix_sql(sql):
    """Heuristically repair a model-generated SQL string.

    Extracts the first SELECT/WITH statement, normalises NULL comparisons,
    repairs comma and JOIN glitches, and falls back to "SELECT 1" when no
    usable statement can be recovered.
    """
    if not sql:
        return "SELECT 1"

    s = str(sql).strip()

    # Keep only the text from the first SELECT/WITH keyword onward
    # (drops any natural-language preamble the model emitted).
    match = re.search(r"(?i)(select|with)[\s\S]*", s)
    if match:
        s = match.group(0)

    # Only the first statement is evaluated.
    s = s.split(";")[0].strip()

    # NULL fixes.
    # FIX: rewrite `!= null` BEFORE `= null`. The previous order let the
    # `= null` pattern consume the tail of `!= null`, producing the invalid
    # token `!IS NULL` and leaving the NOT-NULL rule unreachable.
    s = re.sub(r'(?i)!=\s*null', 'IS NOT NULL', s)
    s = re.sub(r'(?i)=\s*null', 'IS NULL', s)

    # Collapse duplicated commas and drop a dangling comma before FROM.
    s = re.sub(r',\s*,+', ',', s)
    s = re.sub(r'(?i),\s*from', ' FROM', s)

    # 🔥 LIGHT COLUMN SAFETY: many dotted (table.column) references usually
    # mean hallucinated projections, so fall back to SELECT *.
    if "select" in s.lower():
        if len(re.findall(r'\w+\.\w+', s)) > 3:
            s = re.sub(r'(?i)select\s+.*?\s+from', 'SELECT * FROM', s)

    # 🔥 JOIN fix: a JOIN without an ON clause is invalid — patch in a
    # trivially-true condition so the statement at least executes.
    if "join" in s.lower() and " on " not in s.lower():
        s = re.sub(r'join\s+(\w+)', r'JOIN \1 ON 1=1', s, flags=re.I)

    # Anything that still does not look like a query is unusable.
    if not s.lower().startswith(("select", "with")):
        return "SELECT 1"

    return s.strip()
|
| 610 |
+
|
| 611 |
+
# =========================================================
|
| 612 |
+
# NORMALIZATION
|
| 613 |
+
# =========================================================
|
| 614 |
+
def normalize_sql(sql):
    """Collapse whitespace and lowercase SQL for lenient exact-match comparison.

    Falsy input (None, empty string) normalises to the empty string.
    """
    if not sql:
        return ""
    collapsed = re.sub(r"\s+", " ", str(sql))
    return collapsed.strip().lower()
|
| 618 |
+
|
| 619 |
+
def normalize_result(res):
    """Canonicalise a cursor.fetchall() result for order-insensitive comparison.

    Each row's cells are stringified and sorted within the row, then rows are
    sorted — so both row order and column order are ignored.
    """
    if not res:
        return []
    try:
        normalized = [tuple(sorted(str(x) for x in row)) for row in res]
        return sorted(normalized)
    except TypeError:
        # FIX: was a bare `except:`. TypeError is what actually occurs here
        # (a "row" that is not iterable); fall back to whole-row string forms.
        return sorted([str(r) for r in res])
|
| 627 |
+
|
| 628 |
+
# =========================================================
|
| 629 |
+
# EXECUTION HELPERS
|
| 630 |
+
# =========================================================
|
| 631 |
+
def is_executable(sql, db_path):
    """Return True iff `sql` executes without error against the SQLite DB.

    FIX: the connection is now closed in a `finally` block — previously a
    failing statement skipped `conn.close()`, leaking a handle for every bad
    beam candidate during execution-guided selection.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        conn.cursor().execute(sql)
        return True
    except Exception:
        return False
    finally:
        if conn is not None:
            conn.close()
|
| 640 |
+
|
| 641 |
+
def check_execution(pred_sql, gold_sql, db_path, question):
    """Execute predicted and gold SQL on one connection and compare their
    order-insensitive result sets.

    On any failure the error is classified and logged (for the error
    dashboard) and False is returned.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        # Some Spider databases contain non-UTF8 text; decode leniently.
        conn.text_factory = lambda b: b.decode(errors='ignore')
        cur = conn.cursor()

        cur.execute(gold_sql)
        gold_res = cur.fetchall()

        cur.execute(pred_sql)
        pred_res = cur.fetchall()

        return normalize_result(pred_res) == normalize_result(gold_res)

    except Exception as e:
        error_type = classify_error(pred_sql, str(e))
        log_error(question, pred_sql, str(e), error_type)
        return False
    finally:
        # FIX: previously the connection leaked whenever either query raised,
        # since the original `conn.close()` sat before the except path.
        if conn is not None:
            conn.close()
|
| 661 |
+
|
| 662 |
+
# =========================================================
|
| 663 |
+
# MAIN
|
| 664 |
+
# =========================================================
|
| 665 |
+
def main():
    """Evaluate a LoRA-adapted CodeT5 text-to-SQL model on the Spider dev set.

    Reports exact match (EM), execution accuracy (EX) and schema-constraint
    rate, selecting among 8 beam candidates with execution-guided re-ranking.
    Command line: --adapter (required LoRA checkpoint dir), --num_samples.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, required=True)
    parser.add_argument("--num_samples", type=int, default=700)
    args = parser.parse_args()

    # Locate the project root whether this file runs from scripts/, src/ or
    # the repository top level.
    project_root = Path(__file__).resolve().parent
    if project_root.name in ["scripts", "src"]:
        project_root = project_root.parent

    db_root = project_root / "data" / "database"
    dev_json = project_root / "data" / "dev.json"

    # NOTE(review): prefers Apple MPS, else CPU — CUDA is not considered here.
    device = "mps" if torch.backends.mps.is_available() else "cpu"

    print(f"Loading model on {device}...")

    tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-base")
    base_model = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/codet5-base").to(device)

    # Merge the LoRA adapter into the base weights for faster inference.
    model = PeftModel.from_pretrained(base_model, args.adapter).to(device)
    model = model.merge_and_unload()
    model.eval()

    with open(dev_json, "r") as f:
        dev_data = json.load(f)[:args.num_samples]

    em_correct = 0
    ex_correct = 0
    constraint_ok = 0

    print(f"\n🚀 Evaluating {len(dev_data)} samples...\n")

    for i, ex in enumerate(dev_data, 1):
        db_id = ex["db_id"]
        question = ex["question"]
        gold_query = ex["query"]

        db_path = db_root / db_id / f"{db_id}.sqlite"

        # encode_prompt must match the training-time prompt format exactly.
        input_tensor = encode_prompt(tokenizer, question, db_id, device=device).unsqueeze(0)

        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_tensor,
                max_new_tokens=128,
                num_beams=8,
                num_return_sequences=8
            )

        best_sql = ""

        # 🔥 EXECUTION-GUIDED SELECTION: take the first (highest-scoring) beam
        # candidate that actually executes against the target database.
        for out in outputs:
            raw_pred = tokenizer.decode(out, skip_special_tokens=True)
            candidate_sql = fix_sql(raw_pred)

            if is_executable(candidate_sql, str(db_path)):
                best_sql = candidate_sql
                break

        # Fall back to the top beam when no candidate executed.
        if not best_sql:
            best_sql = fix_sql(tokenizer.decode(outputs[0], skip_special_tokens=True))

        # Schema-constraint check (best effort — a validator crash counts as
        # invalid rather than aborting evaluation).
        try:
            is_valid, _ = validate_sql_schema(best_sql, str(db_path))
        except:
            is_valid = False

        if is_valid:
            constraint_ok += 1

        # Exact match on whitespace/case-normalised SQL text.
        if normalize_sql(best_sql) == normalize_sql(gold_query):
            em_correct += 1

        # Execution match on order-insensitive result sets; failures are
        # classified and logged inside check_execution.
        if check_execution(best_sql, gold_query, str(db_path), question):
            ex_correct += 1

        if i % 50 == 0:
            print(f"{i}/{len(dev_data)} done")

    print("\n========================================")
    print("🎯 FINAL EVALUATION RESULTS")
    print("========================================")
    print(f"Exact Match (EM): {(em_correct/len(dev_data))*100:.2f}%")
    print(f"Execution Acc (EX): {(ex_correct/len(dev_data))*100:.2f}%")
    print(f"Constraint Rate: {(constraint_ok/len(dev_data))*100:.2f}%")
    print("========================================")
    print(f"Errors logged to: {ERROR_LOG_FILE}")

if __name__ == "__main__":
    main()
|
src/eval_rl_t5.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# import sys
|
| 2 |
+
# import os
|
| 3 |
+
# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 4 |
+
# import json
|
| 5 |
+
|
| 6 |
+
# import subprocess
|
| 7 |
+
|
| 8 |
+
# import argparse
|
| 9 |
+
# from pathlib import Path
|
| 10 |
+
|
| 11 |
+
# import torch
|
| 12 |
+
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 13 |
+
# from peft import PeftModel
|
| 14 |
+
|
| 15 |
+
# # IMPORTANT: must match training prompt format
|
| 16 |
+
# from prompting import build_prompt
|
| 17 |
+
# from schema_utils import get_schema as get_db_schema
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# def _parse_exec_accuracy(stdout: str):
|
| 21 |
+
# for line in stdout.splitlines():
|
| 22 |
+
# if line.strip().startswith("execution"):
|
| 23 |
+
# parts = line.split()
|
| 24 |
+
# try:
|
| 25 |
+
# return float(parts[-1])
|
| 26 |
+
# except Exception:
|
| 27 |
+
# return None
|
| 28 |
+
# return None
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# def main():
|
| 32 |
+
# parser = argparse.ArgumentParser()
|
| 33 |
+
# parser.add_argument("--adapter", type=str, default="checkpoints/best_rlhf_model")
|
| 34 |
+
# parser.add_argument("--num_samples", type=int, default=200)
|
| 35 |
+
# args = parser.parse_args()
|
| 36 |
+
|
| 37 |
+
# project_root = Path(__file__).resolve().parents[1]
|
| 38 |
+
# adapter_dir = project_root / args.adapter
|
| 39 |
+
|
| 40 |
+
# if not adapter_dir.exists():
|
| 41 |
+
# raise FileNotFoundError(f"Adapter not found: {adapter_dir}")
|
| 42 |
+
|
| 43 |
+
# db_root = project_root / "data" / "database"
|
| 44 |
+
# table_json = project_root / "data" / "tables.json"
|
| 45 |
+
# dev_json = project_root / "data" / "dev.json"
|
| 46 |
+
# gold_sql = project_root / "data" / "dev_gold.sql"
|
| 47 |
+
# pred_path = project_root / "predictions_rl.txt"
|
| 48 |
+
|
| 49 |
+
# device = "mps" if torch.backends.mps.is_available() else "cpu"
|
| 50 |
+
|
| 51 |
+
# # ---- LOAD MODEL (CodeT5 + LoRA) ----
|
| 52 |
+
# base_model = "Salesforce/codet5-base"
|
| 53 |
+
|
| 54 |
+
# tokenizer = AutoTokenizer.from_pretrained(str(adapter_dir))
|
| 55 |
+
# base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
|
| 56 |
+
# model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
|
| 57 |
+
|
| 58 |
+
# # merge LoRA for faster inference
|
| 59 |
+
# model = model.merge_and_unload()
|
| 60 |
+
# model.eval()
|
| 61 |
+
# model.config.use_cache = True
|
| 62 |
+
|
| 63 |
+
# if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
|
| 64 |
+
# tokenizer.pad_token = tokenizer.eos_token
|
| 65 |
+
|
| 66 |
+
# # ---- LOAD DATA ----
|
| 67 |
+
# with dev_json.open() as f:
|
| 68 |
+
# dev = json.load(f)
|
| 69 |
+
|
| 70 |
+
# dev = dev[: args.num_samples]
|
| 71 |
+
|
| 72 |
+
# gen_kwargs = dict(
|
| 73 |
+
# max_new_tokens=120,
|
| 74 |
+
# do_sample=False,
|
| 75 |
+
# num_beams=1,
|
| 76 |
+
# pad_token_id=tokenizer.pad_token_id,
|
| 77 |
+
# eos_token_id=tokenizer.eos_token_id,
|
| 78 |
+
# )
|
| 79 |
+
|
| 80 |
+
# print(f"Generating {len(dev)} predictions...")
|
| 81 |
+
|
| 82 |
+
# with pred_path.open("w") as out_f, torch.no_grad():
|
| 83 |
+
# for i, ex in enumerate(dev, start=1):
|
| 84 |
+
# db_id = ex["db_id"]
|
| 85 |
+
# question = ex["question"]
|
| 86 |
+
|
| 87 |
+
# db_path = db_root / db_id / f"{db_id}.sqlite"
|
| 88 |
+
# schema = get_db_schema(str(db_path))
|
| 89 |
+
# prompt = build_prompt(question, schema, use_schema=True)
|
| 90 |
+
|
| 91 |
+
# inputs = tokenizer(
|
| 92 |
+
# prompt,
|
| 93 |
+
# return_tensors="pt",
|
| 94 |
+
# truncation=True,
|
| 95 |
+
# max_length=512
|
| 96 |
+
# ).to(device)
|
| 97 |
+
|
| 98 |
+
# out = model.generate(**inputs, **gen_kwargs)
|
| 99 |
+
# pred_sql = tokenizer.decode(out[0], skip_special_tokens=True).strip()
|
| 100 |
+
|
| 101 |
+
# out_f.write(f"{pred_sql}\t{db_id}\n")
|
| 102 |
+
|
| 103 |
+
# if i % 20 == 0 or i == len(dev):
|
| 104 |
+
# print(f"{i}/{len(dev)} done")
|
| 105 |
+
|
| 106 |
+
# # ---- SPIDER OFFICIAL EVAL ----
|
| 107 |
+
# eval_script = project_root / "spider_eval" / "evaluation.py"
|
| 108 |
+
|
| 109 |
+
# cmd = [
|
| 110 |
+
# sys.executable,
|
| 111 |
+
# str(eval_script),
|
| 112 |
+
# "--gold",
|
| 113 |
+
# str(gold_sql),
|
| 114 |
+
# "--pred",
|
| 115 |
+
# str(pred_path),
|
| 116 |
+
# "--etype",
|
| 117 |
+
# "exec",
|
| 118 |
+
# "--db",
|
| 119 |
+
# str(db_root),
|
| 120 |
+
# "--table",
|
| 121 |
+
# str(table_json),
|
| 122 |
+
# ]
|
| 123 |
+
|
| 124 |
+
# print("\nRunning Spider execution evaluation...\n")
|
| 125 |
+
# proc = subprocess.run(cmd, capture_output=True, text=True)
|
| 126 |
+
|
| 127 |
+
# if proc.returncode != 0:
|
| 128 |
+
# print(proc.stdout)
|
| 129 |
+
# print(proc.stderr)
|
| 130 |
+
# sys.exit(proc.returncode)
|
| 131 |
+
|
| 132 |
+
# print(proc.stdout)
|
| 133 |
+
|
| 134 |
+
# acc = _parse_exec_accuracy(proc.stdout)
|
| 135 |
+
# if acc is not None:
|
| 136 |
+
# print(f"\nFINAL EXECUTION ACCURACY: {acc*100:.2f}%")
|
| 137 |
+
# else:
|
| 138 |
+
# print("Could not parse execution accuracy")
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# if __name__ == "__main__":
|
| 142 |
+
# main()
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
import json
|
| 146 |
+
import sqlite3
|
| 147 |
+
import argparse
|
| 148 |
+
import time
|
| 149 |
+
from pathlib import Path
|
| 150 |
+
import torch
|
| 151 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 152 |
+
from peft import PeftModel
|
| 153 |
+
|
| 154 |
+
# ---------------- PROMPT (FIXED TO PERFECTLY MATCH RLHF TRAINING) ----------------
|
| 155 |
+
def build_prompt(question, schema):
    """Assemble the prompt in the exact layout used during RLHF training."""
    sections = (
        "translate English to SQL:",
        f"Schema:\n{schema}",
        f"Question:\n{question}",
        "SQL:",
    )
    return "\n\n".join(sections)
|
| 157 |
+
|
| 158 |
+
# ---------------- LOAD SCHEMA (FIXED TO MATCH TRAINING FORMAT) ----------------
|
| 159 |
+
def load_schema(db_path):
    """Serialise a SQLite database schema as "table(col1, col2) table2(...)".

    Space-separated, single-line format — deliberately identical to what the
    RLHF training script produced, so prompts match training.

    FIX: the connection is now closed in a `finally` block; previously any
    failing query left the handle open.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        tables = cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()

        schema = ""
        for (table,) in tables:
            # Table names come from sqlite_master itself, not user input.
            cols = cursor.execute(f"PRAGMA table_info({table});").fetchall()
            col_names = [c[1] for c in cols]
            # Space-separated, not newline-separated, just like the RLHF script
            schema += f"{table}({', '.join(col_names)}) "

        return schema.strip()
    finally:
        conn.close()
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# ---------------- EXECUTION CHECK WITH TIMEOUT ----------------
|
| 179 |
+
def execution_match(pred_sql, gold_sql, db_path):
    """Return True iff pred and gold SQL yield identical (row-ordered) results.

    A SQLite progress handler aborts any statement running longer than ~5 s so
    a pathological prediction cannot hang the evaluation loop.

    FIX: the connection is now closed in `finally`; previously it leaked
    whenever either statement raised (the except path skipped close()).
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)

        # --- 5-SECOND TIMEOUT SO THE SCRIPT DOESN'T HANG ---
        start_time = time.monotonic()
        def timeout_handler():
            # A non-zero return tells SQLite to abort the running statement.
            return 1 if (time.monotonic() - start_time) > 5.0 else 0
        conn.set_progress_handler(timeout_handler, 10000)

        cur = conn.cursor()

        cur.execute(pred_sql)
        pred = cur.fetchall()

        cur.execute(gold_sql)
        gold = cur.fetchall()

        # Order-sensitive comparison (rows must match in sequence).
        return pred == gold

    except Exception:
        return False
    finally:
        if conn is not None:
            conn.close()
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
# ---------------- MAIN ----------------
|
| 205 |
+
def main():
    """Evaluate a LoRA-adapted t5-small text-to-SQL model on Spider dev,
    reporting execution accuracy only.

    Command line: --adapter (LoRA checkpoint dir, defaults to the best RLHF
    run), --num_samples (dev-set prefix size).
    """
    parser = argparse.ArgumentParser()
    # 🎯 Set the default directly to your best RLHF model!
    parser.add_argument("--adapter", type=str, default="checkpoints/rlhf_t5_best")
    parser.add_argument("--num_samples", type=int, default=1000)
    args = parser.parse_args()

    # Project root is one level above this file (src/ or scripts/).
    project_root = Path(__file__).resolve().parents[1]

    # Resolve adapter path safely
    adapter_path = project_root / args.adapter

    dev_json = project_root / "data" / "dev.json"
    db_root = project_root / "data" / "database"

    # 🎯 Added CUDA support (MPS preferred, then CUDA, then CPU).
    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")

    # Load model: tokenizer from the adapter dir, base weights from the hub.
    base_model = "t5-small"
    print(f"Loading Base: {base_model}")
    print(f"Loading Adapter: {adapter_path}")

    tokenizer = AutoTokenizer.from_pretrained(str(adapter_path))
    base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
    model = PeftModel.from_pretrained(base, str(adapter_path)).to(device)
    # Merge LoRA weights into the base model for faster inference.
    model = model.merge_and_unload()

    with open(dev_json) as f:
        dev = json.load(f)[: args.num_samples]

    correct = 0

    print(f"Evaluating {len(dev)} examples...\n")

    for i, ex in enumerate(dev, 1):
        question = ex["question"]
        db_id = ex["db_id"]
        gold_sql = ex["query"]

        db_path = db_root / db_id / f"{db_id}.sqlite"
        schema = load_schema(db_path)

        # Prompt must match the RLHF training format exactly.
        prompt = build_prompt(question, schema)

        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=80,
                do_sample=False,
                num_beams=4,
            )

        pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # If the model echoed the prompt, keep only the text after "SQL:".
        if "SQL:" in pred_sql:
            pred_sql = pred_sql.split("SQL:")[-1].strip()

        match = execution_match(pred_sql, gold_sql, db_path)

        if match:
            correct += 1

        if i % 10 == 0:
            print(f"{i}/{len(dev)} | Acc: {correct/i:.3f}")

    print("\n=============================")
    print(f"FINAL EXECUTION ACCURACY: {correct/len(dev)*100:.2f}%")
    print("=============================")


if __name__ == "__main__":
    main()
|
src/eval_single_model.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import subprocess
|
| 3 |
+
import sys
|
| 4 |
+
import argparse
|
| 5 |
+
import random
|
| 6 |
+
import sqlite3
|
| 7 |
+
import time
|
| 8 |
+
import re
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
import numpy as np
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 15 |
+
from peft import PeftModel
|
| 16 |
+
|
| 17 |
+
# Assuming you have a prompting.py that has encode_prompt
|
| 18 |
+
from prompting import encode_prompt
|
| 19 |
+
|
| 20 |
+
# -------------------------------
|
| 21 |
+
# LIVE CHECK HELPERS
|
| 22 |
+
# -------------------------------
|
| 23 |
+
def normalize_sql(sql):
    """Canonicalise SQL text: unify quote style, squeeze whitespace,
    lowercase, and drop a trailing semicolon."""
    unified = sql.replace('"', "'")
    unified = re.sub(r"\s+", " ", unified)
    return unified.strip().lower().rstrip(";")
|
| 27 |
+
|
| 28 |
+
def check_execution(pred_sql, gold_sql, db_path):
    """Run predicted and gold SQL on one connection and compare sorted rows.

    Uses a ~2 s progress-handler timeout so a runaway query cannot hang the
    loop, and lenient text decoding for non-UTF8 Spider databases.

    FIX: the connection is now closed in `finally`; previously any execution
    error skipped close() and leaked the handle.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        conn.text_factory = lambda b: b.decode(errors='ignore')

        start_time = time.monotonic()
        def timeout_handler():
            # Non-zero return aborts the running statement after ~2 s.
            return 1 if (time.monotonic() - start_time) > 2.0 else 0
        conn.set_progress_handler(timeout_handler, 10000)

        cursor = conn.cursor()
        cursor.execute(pred_sql)
        pred_res = cursor.fetchall()

        cursor.execute(gold_sql)
        gold_res = cursor.fetchall()

        # Row order is ignored; column order within each row still matters.
        return sorted(pred_res) == sorted(gold_res)
    except Exception:
        return False
    finally:
        if conn is not None:
            conn.close()
|
| 49 |
+
|
| 50 |
+
# -------------------------------
|
| 51 |
+
# SPIDER PARSER
|
| 52 |
+
# -------------------------------
|
| 53 |
+
def _parse_spider_accuracy(stdout: str, metric_type: str) -> float | None:
|
| 54 |
+
for line in stdout.splitlines():
|
| 55 |
+
if metric_type == "exec" and line.strip().startswith("execution"):
|
| 56 |
+
try: return float(line.split()[-1])
|
| 57 |
+
except: pass
|
| 58 |
+
elif metric_type == "match" and line.strip().startswith("exact"):
|
| 59 |
+
try: return float(line.split()[-1])
|
| 60 |
+
except: pass
|
| 61 |
+
return None
|
| 62 |
+
|
| 63 |
+
# -------------------------------
|
| 64 |
+
# MAIN
|
| 65 |
+
# -------------------------------
|
| 66 |
+
def main():
    """Evaluate a PEFT adapter on Spider, live-track EM/EX, and update the
    cross-model comparison plot.

    Pipeline:
      1. Load the base seq2seq model plus LoRA adapter; merge for speed.
      2. Beam-decode SQL for the first --num_samples dev examples while
         printing a running (approximate) exact-match / execution accuracy.
      3. Run the official Spider evaluator twice ("match" and "exec").
      4. Persist the scores to comparison_plots/all_metrics.json and redraw
         the grouped bar chart across all models evaluated so far.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, required=True, help="Path to your checkpoint")
    parser.add_argument("--base_model", type=str, required=True, help="E.g., facebook/bart-base, t5-small")
    parser.add_argument("--model_name", type=str, required=True, help="Name for the plot label (e.g., 'BART RLHF')")
    parser.add_argument("--num_samples", type=int, default=700)
    args = parser.parse_args()

    project_root = Path(__file__).resolve().parents[1]
    adapter_dir = project_root / args.adapter

    db_root = project_root / "data" / "database"
    table_json = project_root / "data" / "tables.json"
    dev_json = project_root / "data" / "dev.json"

    pred_path = project_root / "temp_predictions.txt"
    temp_gold_path = project_root / "temp_gold.sql"

    # Plot directory setup: metrics accumulate across runs in one JSON file,
    # so each invocation adds/overwrites a single model entry.
    plot_dir = project_root / "comparison_plots"
    plot_dir.mkdir(parents=True, exist_ok=True)
    results_json_path = plot_dir / "all_metrics.json"

    if not adapter_dir.exists():
        raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")

    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Loading Base Model: {args.base_model} on {device}...")

    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    base = AutoModelForSeq2SeqLM.from_pretrained(args.base_model).to(device)
    model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
    # Fold the LoRA weights into the base model so generation runs at full speed.
    model = model.merge_and_unload()
    model.eval()

    with dev_json.open() as f:
        dev = json.load(f)[: args.num_samples]
    total = len(dev)

    gen_kwargs = dict(
        max_new_tokens=160,
        num_beams=4,
        do_sample=False,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    print(f"\n🚀 Generating and live-tracking {total} samples...\n")

    em_correct = 0
    ex_correct = 0

    with pred_path.open("w") as out_pred, temp_gold_path.open("w") as out_gold, torch.no_grad():
        for i, ex in enumerate(dev, start=1):
            db_id = ex["db_id"]
            question = ex["question"]
            gold_query = ex["query"]
            db_path = db_root / db_id / f"{db_id}.sqlite"

            # Generate the SQL prediction for this example.
            input_ids = encode_prompt(tokenizer, question, db_id, device=device, max_input_tokens=512)
            input_ids = input_ids.unsqueeze(0).to(device)
            attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)

            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
            pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

            # Write to files for the official Spider eval run below.
            out_pred.write(f"{pred_sql}\n")
            out_gold.write(f"{gold_query}\t{db_id}\n")

            # --- PRINT FIRST 3 EXAMPLES (sanity check of decoding) ---
            if i <= 3:
                print(f"--- 🔍 Example {i} ---")
                print(f"Q : {question}")
                print(f"Gold: {gold_query}")
                print(f"Pred: {pred_sql}")
                print("-" * 25)

            # --- LIVE TRACKING CHECKS (approximate; official eval is final) ---
            if normalize_sql(pred_sql) == normalize_sql(gold_query):
                em_correct += 1
            if check_execution(pred_sql, gold_query, db_path):
                ex_correct += 1

            if i % 50 == 0 or i == total:
                print(f"Progress: {i}/{total} | Current EM: {(em_correct/i)*100:.2f}% | Current EX: {(ex_correct/i)*100:.2f}%")

    print("\nRunning Official Spider Evaluations...")
    eval_script = project_root / "spider_eval" / "evaluation.py"

    proc_match = subprocess.run([sys.executable, str(eval_script), "--gold", str(temp_gold_path), "--pred", str(pred_path), "--etype", "match", "--db", str(db_root), "--table", str(table_json)], capture_output=True, text=True)
    exact_acc = _parse_spider_accuracy(proc_match.stdout, "match")

    proc_exec = subprocess.run([sys.executable, str(eval_script), "--gold", str(temp_gold_path), "--pred", str(pred_path), "--etype", "exec", "--db", str(db_root), "--table", str(table_json)], capture_output=True, text=True)
    exec_acc = _parse_spider_accuracy(proc_exec.stdout, "exec")

    print("\n==========================================")
    print(f"🎯 RESULTS FOR: {args.model_name}")
    print("==========================================")
    # BUGFIX: compare against None explicitly — a legitimate accuracy of
    # exactly 0.0 is falsy and was previously indistinguishable from a
    # parse failure.
    exact_val = exact_acc * 100 if exact_acc is not None else 0
    exec_val = exec_acc * 100 if exec_acc is not None else 0
    print(f"Exact Match : {exact_val:.2f}%")
    print(f"Execution : {exec_val:.2f}%")
    print("==========================================\n")

    # -------------------------------
    # SAVE JSON & GENERATE PLOT
    # -------------------------------
    if results_json_path.exists():
        with open(results_json_path, 'r') as f:
            all_results = json.load(f)
    else:
        all_results = {}

    all_results[args.model_name] = {"EM": exact_val, "EX": exec_val}

    with open(results_json_path, 'w') as f:
        json.dump(all_results, f, indent=4)

    labels = list(all_results.keys())
    em_vals = [all_results[k]["EM"] for k in labels]
    ex_vals = [all_results[k]["EX"] for k in labels]

    x = np.arange(len(labels))
    width = 0.35

    # Grouped bar chart: one EM bar and one EX bar per evaluated model.
    plt.figure(figsize=(max(8, len(labels) * 1.5), 6))
    plt.bar(x - width/2, em_vals, width, label='Exact Match', color='#3498db')
    plt.bar(x + width/2, ex_vals, width, label='Execution', color='#2ecc71')

    plt.ylabel('Accuracy (%)', fontweight='bold')
    plt.title('Model Comparison: Exact Match vs Execution Accuracy', fontweight='bold', fontsize=14)
    plt.xticks(x, labels, rotation=45, ha="right")
    plt.legend()
    plt.ylim(0, max(max(em_vals, default=0), max(ex_vals, default=0)) + 15)
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    # Attach value labels to bars.
    for i in range(len(labels)):
        plt.text(x[i] - width/2, em_vals[i] + 1, f"{em_vals[i]:.1f}%", ha='center', fontsize=9)
        plt.text(x[i] + width/2, ex_vals[i] + 1, f"{ex_vals[i]:.1f}%", ha='center', fontsize=9)

    plt.tight_layout()
    plot_path = plot_dir / "accuracy_comparison.png"
    plt.savefig(plot_path, dpi=300)
    print(f"📈 Updated comparison plot saved to: {plot_path}")
|
| 216 |
+
|
| 217 |
+
if __name__ == "__main__":
|
| 218 |
+
main()
|
src/evaluate_model_codet5.py
ADDED
|
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
# import json
|
| 4 |
+
# import subprocess
|
| 5 |
+
# import sys
|
| 6 |
+
# import argparse
|
| 7 |
+
# import sqlite3
|
| 8 |
+
# import random
|
| 9 |
+
# from pathlib import Path
|
| 10 |
+
|
| 11 |
+
# import torch
|
| 12 |
+
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 13 |
+
# from peft import PeftModel
|
| 14 |
+
|
| 15 |
+
# from prompting import encode_prompt
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# def _parse_exec_accuracy(stdout: str) -> float | None:
|
| 19 |
+
# for line in stdout.splitlines():
|
| 20 |
+
# if line.strip().startswith("execution"):
|
| 21 |
+
# try:
|
| 22 |
+
# return float(line.split()[-1])
|
| 23 |
+
# except:
|
| 24 |
+
# return None
|
| 25 |
+
# return None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# def main():
|
| 29 |
+
|
| 30 |
+
# # ---------------- ARGUMENTS ----------------
|
| 31 |
+
# parser = argparse.ArgumentParser()
|
| 32 |
+
# parser.add_argument("--adapter", type=str, default="checkpoints/sft_adapter_codet5")
|
| 33 |
+
# parser.add_argument("--num_samples", type=int, default=1000)
|
| 34 |
+
# parser.add_argument("--shuffle_dev", action="store_true")
|
| 35 |
+
# parser.add_argument("--shuffle_seed", type=int, default=42)
|
| 36 |
+
# parser.add_argument("--accuracy_log", type=str, default="")
|
| 37 |
+
# args = parser.parse_args()
|
| 38 |
+
|
| 39 |
+
# project_root = Path(__file__).resolve().parents[1]
|
| 40 |
+
# adapter_dir = project_root / args.adapter
|
| 41 |
+
|
| 42 |
+
# db_root = project_root / "data" / "database"
|
| 43 |
+
# table_json = project_root / "data" / "tables.json"
|
| 44 |
+
# dev_json = project_root / "data" / "dev.json"
|
| 45 |
+
# gold_sql = project_root / "data" / "dev_gold.sql"
|
| 46 |
+
# pred_path = project_root / "predictions.txt"
|
| 47 |
+
|
| 48 |
+
# if not adapter_dir.exists():
|
| 49 |
+
# raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")
|
| 50 |
+
|
| 51 |
+
# # ---------------- DEVICE ----------------
|
| 52 |
+
# device = "mps" if torch.backends.mps.is_available() else (
|
| 53 |
+
# "cuda" if torch.cuda.is_available() else "cpu"
|
| 54 |
+
# )
|
| 55 |
+
# print("Using device:", device)
|
| 56 |
+
|
| 57 |
+
# # ---------------- LOAD MODEL ----------------
|
| 58 |
+
# BASE_MODEL = "Salesforce/codet5-base"
|
| 59 |
+
|
| 60 |
+
# tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
|
| 61 |
+
|
| 62 |
+
# if tokenizer.pad_token is None:
|
| 63 |
+
# tokenizer.pad_token = tokenizer.eos_token
|
| 64 |
+
|
| 65 |
+
# base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
|
| 66 |
+
# model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
|
| 67 |
+
|
| 68 |
+
# model = model.merge_and_unload()
|
| 69 |
+
# model.eval()
|
| 70 |
+
|
| 71 |
+
# # ---------------- LOAD DATA ----------------
|
| 72 |
+
# with dev_json.open() as f:
|
| 73 |
+
# dev = json.load(f)
|
| 74 |
+
|
| 75 |
+
# if args.shuffle_dev:
|
| 76 |
+
# rng = random.Random(args.shuffle_seed)
|
| 77 |
+
# rng.shuffle(dev)
|
| 78 |
+
|
| 79 |
+
# dev = dev[: args.num_samples]
|
| 80 |
+
|
| 81 |
+
# # ---------------- GENERATION CONFIG ----------------
|
| 82 |
+
# gen_kwargs = dict(
|
| 83 |
+
# max_new_tokens=160,
|
| 84 |
+
# num_beams=4,
|
| 85 |
+
# do_sample=False,
|
| 86 |
+
# early_stopping=True,
|
| 87 |
+
# pad_token_id=tokenizer.pad_token_id,
|
| 88 |
+
# eos_token_id=tokenizer.eos_token_id,
|
| 89 |
+
# )
|
| 90 |
+
|
| 91 |
+
# print("Generating predictions...\n")
|
| 92 |
+
|
| 93 |
+
# correct = 0
|
| 94 |
+
# total = len(dev)
|
| 95 |
+
# accuracy_log_fh = None
|
| 96 |
+
|
| 97 |
+
# if args.accuracy_log:
|
| 98 |
+
# accuracy_log_path = (project_root / args.accuracy_log).resolve()
|
| 99 |
+
# accuracy_log_path.parent.mkdir(parents=True, exist_ok=True)
|
| 100 |
+
# accuracy_log_fh = accuracy_log_path.open("w")
|
| 101 |
+
# print(f"Writing running accuracy log to: {accuracy_log_path}")
|
| 102 |
+
|
| 103 |
+
# with pred_path.open("w") as out_f, torch.no_grad():
|
| 104 |
+
|
| 105 |
+
# for i, ex in enumerate(dev, start=1):
|
| 106 |
+
|
| 107 |
+
# db_id = ex["db_id"]
|
| 108 |
+
# question = ex["question"]
|
| 109 |
+
# gold_query = ex["query"]
|
| 110 |
+
|
| 111 |
+
# input_ids = encode_prompt(
|
| 112 |
+
# tokenizer,
|
| 113 |
+
# question,
|
| 114 |
+
# db_id,
|
| 115 |
+
# device=device,
|
| 116 |
+
# max_input_tokens=512,
|
| 117 |
+
# )
|
| 118 |
+
|
| 119 |
+
# input_ids = input_ids.unsqueeze(0).to(device)
|
| 120 |
+
# attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)
|
| 121 |
+
|
| 122 |
+
# outputs = model.generate(
|
| 123 |
+
# input_ids=input_ids,
|
| 124 |
+
# attention_mask=attention_mask,
|
| 125 |
+
# **gen_kwargs
|
| 126 |
+
# )
|
| 127 |
+
|
| 128 |
+
# pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
|
| 129 |
+
# out_f.write(f"{pred_sql}\t{db_id}\n")
|
| 130 |
+
|
| 131 |
+
# # ---------------- LIVE EXECUTION CHECK ----------------
|
| 132 |
+
# try:
|
| 133 |
+
# db_path = db_root / db_id / f"{db_id}.sqlite"
|
| 134 |
+
|
| 135 |
+
# conn = sqlite3.connect(db_path)
|
| 136 |
+
# cursor = conn.cursor()
|
| 137 |
+
|
| 138 |
+
# cursor.execute(pred_sql)
|
| 139 |
+
# pred_rows = cursor.fetchall()
|
| 140 |
+
|
| 141 |
+
# cursor.execute(gold_query)
|
| 142 |
+
# gold_rows = cursor.fetchall()
|
| 143 |
+
|
| 144 |
+
# conn.close()
|
| 145 |
+
|
| 146 |
+
# if sorted(pred_rows) == sorted(gold_rows):
|
| 147 |
+
# correct += 1
|
| 148 |
+
|
| 149 |
+
# except Exception:
|
| 150 |
+
# pass # execution failed
|
| 151 |
+
|
| 152 |
+
# # 🔥 PRINT EVERY 10
|
| 153 |
+
# if i % 10 == 0 or i == total:
|
| 154 |
+
# current_acc = correct / i
|
| 155 |
+
# line = f"{i}/{total} | Acc: {current_acc:.3f}"
|
| 156 |
+
# print(line)
|
| 157 |
+
# if accuracy_log_fh is not None:
|
| 158 |
+
# accuracy_log_fh.write(line + "\n")
|
| 159 |
+
|
| 160 |
+
# if accuracy_log_fh is not None:
|
| 161 |
+
# accuracy_log_fh.close()
|
| 162 |
+
|
| 163 |
+
# print("\nGeneration finished.\n")
|
| 164 |
+
|
| 165 |
+
# # ---------------- OFFICIAL SPIDER EVAL ----------------
|
| 166 |
+
# eval_script = project_root / "spider_eval" / "evaluation.py"
|
| 167 |
+
|
| 168 |
+
# cmd = [
|
| 169 |
+
# sys.executable,
|
| 170 |
+
# str(eval_script),
|
| 171 |
+
# "--gold", str(gold_sql),
|
| 172 |
+
# "--pred", str(pred_path),
|
| 173 |
+
# "--etype", "exec",
|
| 174 |
+
# "--db", str(db_root),
|
| 175 |
+
# "--table", str(table_json),
|
| 176 |
+
# ]
|
| 177 |
+
|
| 178 |
+
# print("Running Spider evaluation...")
|
| 179 |
+
# proc = subprocess.run(cmd, capture_output=True, text=True)
|
| 180 |
+
|
| 181 |
+
# print(proc.stdout)
|
| 182 |
+
|
| 183 |
+
# exec_acc = _parse_exec_accuracy(proc.stdout)
|
| 184 |
+
# if exec_acc is not None:
|
| 185 |
+
# print(f"\n🎯 Official Execution Accuracy: {exec_acc*100:.2f}%")
|
| 186 |
+
# else:
|
| 187 |
+
# print("Could not parse accuracy.")
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
# if __name__ == "__main__":
|
| 191 |
+
# main()
|
| 192 |
+
|
| 193 |
+
import json
|
| 194 |
+
import subprocess
|
| 195 |
+
import sys
|
| 196 |
+
import argparse
|
| 197 |
+
import random
|
| 198 |
+
import sqlite3
|
| 199 |
+
import time
|
| 200 |
+
import re
|
| 201 |
+
from pathlib import Path
|
| 202 |
+
|
| 203 |
+
import torch
|
| 204 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 205 |
+
from peft import PeftModel
|
| 206 |
+
|
| 207 |
+
# Assuming you have a prompting.py that has encode_prompt
|
| 208 |
+
from prompting import encode_prompt
|
| 209 |
+
|
| 210 |
+
# -------------------------------
|
| 211 |
+
# LIVE CHECK HELPERS
|
| 212 |
+
# -------------------------------
|
| 213 |
+
def normalize_sql(sql):
    """Basic normalization for the live progress bar.

    Unifies quote characters, collapses whitespace, lower-cases, and
    drops trailing semicolons so that trivially-different spellings of
    the same query compare equal.
    """
    return re.sub(r"\s+", " ", sql.replace('"', "'")).strip().lower().rstrip(";")
|
| 218 |
+
|
| 219 |
+
def check_execution(pred_sql, gold_sql, db_path):
    """Basic execution check for the live progress bar.

    Runs both queries against the given SQLite database and compares the
    result rows order-insensitively. Any failure (bad SQL, missing DB,
    timeout) counts as a mismatch; the official Spider evaluator produces
    the final numbers.

    Args:
        pred_sql: Predicted SQL string.
        gold_sql: Gold (reference) SQL string.
        db_path: Path to the SQLite database file.

    Returns:
        bool: True when both queries execute and return the same rows.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        # Some Spider databases contain non-UTF8 text; decode leniently.
        conn.text_factory = lambda b: b.decode(errors='ignore')

        # 2-second timeout so the live tracker doesn't freeze forever.
        start_time = time.monotonic()

        def timeout_handler():
            # Returning non-zero makes SQLite abort the running statement.
            return 1 if (time.monotonic() - start_time) > 2.0 else 0

        conn.set_progress_handler(timeout_handler, 10000)

        cursor = conn.cursor()
        cursor.execute(pred_sql)
        pred_res = cursor.fetchall()

        cursor.execute(gold_sql)
        gold_res = cursor.fetchall()

        # Simple sorted check for the live tracker.
        return sorted(pred_res) == sorted(gold_res)
    except Exception:
        return False
    finally:
        # BUGFIX: the original only closed the connection on the success
        # path, leaking it whenever either query raised.
        if conn is not None:
            conn.close()
|
| 243 |
+
|
| 244 |
+
# -------------------------------
|
| 245 |
+
# SPIDER PARSER
|
| 246 |
+
# -------------------------------
|
| 247 |
+
def _parse_spider_accuracy(stdout: str, metric_type: str) -> float | None:
|
| 248 |
+
for line in stdout.splitlines():
|
| 249 |
+
if metric_type == "exec" and line.strip().startswith("execution"):
|
| 250 |
+
try: return float(line.split()[-1])
|
| 251 |
+
except: pass
|
| 252 |
+
elif metric_type == "match" and line.strip().startswith("exact"):
|
| 253 |
+
try: return float(line.split()[-1])
|
| 254 |
+
except: pass
|
| 255 |
+
return None
|
| 256 |
+
|
| 257 |
+
# -------------------------------
|
| 258 |
+
# MAIN
|
| 259 |
+
# -------------------------------
|
| 260 |
+
def main():
    """Evaluate a CodeT5 PEFT adapter on the Spider dev set.

    Loads Salesforce/codet5-base plus the LoRA adapter given by --adapter,
    merges them, beam-decodes SQL for up to --num_samples dev examples
    while printing a running (approximate) exact-match / execution
    accuracy, then runs the official Spider evaluator twice ("match" and
    "exec") and prints the final numbers.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, required=True, help="Path to your SFT or RLHF checkpoint")
    parser.add_argument("--num_samples", type=int, default=1034, help="Number of samples to evaluate")
    parser.add_argument("--shuffle_dev", action="store_true")
    parser.add_argument("--shuffle_seed", type=int, default=42)
    args = parser.parse_args()

    # All data/output paths are resolved relative to the repository root.
    project_root = Path(__file__).resolve().parents[1]
    adapter_dir = project_root / args.adapter

    db_root = project_root / "data" / "database"
    table_json = project_root / "data" / "tables.json"
    dev_json = project_root / "data" / "dev.json"

    # Temp files fed to the official Spider evaluator below.
    pred_path = project_root / "temp_predictions.txt"
    temp_gold_path = project_root / "temp_gold.sql"

    if not adapter_dir.exists():
        raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")

    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    BASE_MODEL = "Salesforce/codet5-base"
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading Model: {args.adapter}...")
    base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
    model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
    # Fold the LoRA weights into the base model for faster generation.
    model = model.merge_and_unload()
    model.eval()

    with dev_json.open() as f:
        dev = json.load(f)

    # Optional deterministic shuffle before truncating to num_samples.
    if args.shuffle_dev:
        rng = random.Random(args.shuffle_seed)
        rng.shuffle(dev)

    dev = dev[: args.num_samples]
    total = len(dev)

    gen_kwargs = dict(
        max_new_tokens=160,
        num_beams=4,
        do_sample=False,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    print(f"\n🚀 Generating and live-tracking {total} samples...\n")

    em_correct = 0
    ex_correct = 0

    with pred_path.open("w") as out_pred, temp_gold_path.open("w") as out_gold, torch.no_grad():
        for i, ex in enumerate(dev, start=1):
            db_id = ex["db_id"]
            question = ex["question"]
            gold_query = ex["query"]
            db_path = db_root / db_id / f"{db_id}.sqlite"

            # Generate the SQL prediction for this example.
            input_ids = encode_prompt(tokenizer, question, db_id, device=device, max_input_tokens=512)
            input_ids = input_ids.unsqueeze(0).to(device)
            attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)

            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
            pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

            # Write to files for official spider eval later.
            out_pred.write(f"{pred_sql}\n")
            out_gold.write(f"{gold_query}\t{db_id}\n")

            # --- LIVE TRACKING CHECKS (approximate; official eval is final) ---
            if normalize_sql(pred_sql) == normalize_sql(gold_query):
                em_correct += 1
            if check_execution(pred_sql, gold_query, db_path):
                ex_correct += 1

            # Print progress every 50 loops.
            if i % 50 == 0 or i == total:
                print(f"Progress: {i}/{total} | Current EM: {(em_correct/i)*100:.2f}% | Current EX: {(ex_correct/i)*100:.2f}%")

    print("\nGeneration finished. Running Official Spider Evaluations for final numbers...\n")

    eval_script = project_root / "spider_eval" / "evaluation.py"

    # 1. RUN EXACT MATCH EVAL
    cmd_match = [
        sys.executable, str(eval_script),
        "--gold", str(temp_gold_path),
        "--pred", str(pred_path),
        "--etype", "match",
        "--db", str(db_root),
        "--table", str(table_json),
    ]
    proc_match = subprocess.run(cmd_match, capture_output=True, text=True)
    exact_acc = _parse_spider_accuracy(proc_match.stdout, "match")

    # 2. RUN EXECUTION EVAL
    cmd_exec = [
        sys.executable, str(eval_script),
        "--gold", str(temp_gold_path),
        "--pred", str(pred_path),
        "--etype", "exec",
        "--db", str(db_root),
        "--table", str(table_json),
    ]
    proc_exec = subprocess.run(cmd_exec, capture_output=True, text=True)
    exec_acc = _parse_spider_accuracy(proc_exec.stdout, "exec")

    print("==========================================")
    print(f"🎯 OFFICIAL SPIDER RESULTS FOR: {args.adapter}")
    print("==========================================")

    # None means the evaluator output could not be parsed (not a 0 score).
    if exact_acc is not None:
        print(f"Exact Set Match Accuracy : {exact_acc*100:.2f}%")
    else:
        print("Exact Set Match Accuracy : Could not parse output")

    if exec_acc is not None:
        print(f"Execution Accuracy : {exec_acc*100:.2f}%")
    else:
        print("Execution Accuracy : Could not parse output")
    print("==========================================\n")
|
| 390 |
+
|
| 391 |
+
if __name__ == "__main__":
|
| 392 |
+
main()
|
src/evaluate_model_t5_small_sft.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import subprocess
|
| 5 |
+
import sys
|
| 6 |
+
import argparse
|
| 7 |
+
import re
|
| 8 |
+
import sqlite3
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 13 |
+
from peft import PeftModel
|
| 14 |
+
from prompting import encode_prompt
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ---------------- PARSE ACC ----------------
|
| 18 |
+
def _parse_exec_accuracy(stdout: str) -> float | None:
|
| 19 |
+
for line in stdout.splitlines():
|
| 20 |
+
if line.strip().startswith("execution"):
|
| 21 |
+
try:
|
| 22 |
+
return float(line.split()[-1])
|
| 23 |
+
except:
|
| 24 |
+
return None
|
| 25 |
+
return None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ---------------- CLEAN SQL ----------------
|
| 29 |
+
def clean_prediction(pred_sql: str) -> str:
    """Tidy a raw model generation into a single-line SQL statement.

    Strips any leading "SQL:" prompt echo, unifies double quotes to
    single quotes, collapses whitespace, and guarantees a trailing
    semicolon.
    """
    text = pred_sql.strip()

    # Keep only what follows the last "SQL:" marker, if the model echoed it.
    if "SQL:" in text:
        text = text.split("SQL:")[-1]

    text = re.sub(r"\s+", " ", text.replace('"', "'")).strip()

    return text if text.endswith(";") else text + ";"
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def main():
    """Evaluate a t5-small PEFT adapter on the Spider dev set.

    Loads t5-small plus the LoRA adapter given by --adapter, merges them,
    beam-decodes SQL for up to --num_samples dev examples with a live
    execution-accuracy printout, then runs the official Spider evaluator
    ("exec") and prints the final number.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, default="checkpoints/sft_t5")
    parser.add_argument("--num_samples", type=int, default=1000)
    args = parser.parse_args()

    # All data/output paths are resolved relative to the repository root.
    project_root = Path(__file__).resolve().parents[1]
    adapter_dir = project_root / args.adapter

    db_root = project_root / "data/database"
    table_json = project_root / "data/tables.json"
    dev_json = project_root / "data/dev.json"
    gold_sql = project_root / "data/dev_gold.sql"
    pred_path = project_root / "pred.sql"

    if not adapter_dir.exists():
        raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")

    # ---------------- DEVICE ----------------
    device = "mps" if torch.backends.mps.is_available() else (
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    print("Using device:", device)

    # ---------------- LOAD MODEL ----------------
    BASE_MODEL = "t5-small"

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)

    model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
    # Fold the LoRA weights into the base model for faster generation.
    model = model.merge_and_unload()
    model.eval()

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    # ---------------- LOAD DATA ----------------
    with dev_json.open() as f:
        dev = json.load(f)[: args.num_samples]

    print("Generating predictions...\n")

    correct = 0
    total = len(dev)

    # ---------------- GENERATE + LIVE EXEC ----------------
    with pred_path.open("w") as out_f, torch.no_grad():

        for i, ex in enumerate(dev, start=1):

            db_id = ex["db_id"]
            question = ex["question"]
            gold_query = ex["query"]

            prompt_ids = encode_prompt(
                tokenizer,
                question,
                db_id,
                device=device,
                max_input_tokens=512,
            )

            input_ids = prompt_ids.unsqueeze(0).to(device)
            attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)

            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=160,
                num_beams=4,
                do_sample=False,
                early_stopping=True,
            )

            pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
            pred_sql = clean_prediction(pred_sql)

            out_f.write(pred_sql + "\n")

            # -------- LIVE EXECUTION CHECK --------
            # NOTE(review): the connection is not closed when execute()
            # raises, leaking it on every failed prediction — consider
            # try/finally here.
            try:
                db_path = db_root / db_id / f"{db_id}.sqlite"

                conn = sqlite3.connect(db_path)
                cursor = conn.cursor()

                cursor.execute(pred_sql)
                pred_rows = cursor.fetchall()

                cursor.execute(gold_query)
                gold_rows = cursor.fetchall()

                conn.close()

                # Order-insensitive comparison of result rows.
                if sorted(pred_rows) == sorted(gold_rows):
                    correct += 1

            except Exception:
                pass  # execution failed

            # 🔥 PRINT EVERY 10
            if i % 10 == 0 or i == total:
                current_acc = correct / i
                print(f"{i}/{total} | Acc: {current_acc:.3f}")

    print("\nGeneration finished.\n")

    # ---------------- SPIDER EVAL ----------------
    # NOTE(review): gold_sql covers the full dev set while pred.sql holds
    # only num_samples predictions — verify the official evaluator
    # tolerates (or truncates on) this length mismatch.
    eval_script = project_root / "spider_eval/evaluation.py"

    cmd = [
        sys.executable,
        str(eval_script),
        "--gold", str(gold_sql),
        "--pred", str(pred_path),
        "--etype", "exec",
        "--db", str(db_root),
        "--table", str(table_json),
    ]

    print("Running Spider evaluation...")
    proc = subprocess.run(cmd, capture_output=True, text=True)

    print(proc.stdout)

    exec_acc = _parse_exec_accuracy(proc.stdout)
    if exec_acc is not None:
        print(f"\n🎯 Official Execution Accuracy: {exec_acc*100:.2f}%")
    else:
        print("Could not parse accuracy.")
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
if __name__ == "__main__":
|
| 179 |
+
main()
|
src/evaluate_rl_bart.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import json
|
| 3 |
+
import sqlite3
|
| 4 |
+
import argparse
|
| 5 |
+
import time
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import torch
|
| 8 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 9 |
+
from peft import PeftModel
|
| 10 |
+
|
| 11 |
+
# ---------------- PROMPT (IDENTICAL TO TRAINING) ----------------
|
| 12 |
+
def build_prompt(question, schema):
    """Assemble the schema-grounded translation prompt (identical to training)."""
    return (
        "\nDatabase Schema:\n"
        f"{schema}\n\n"
        "Translate English to SQL:\n"
        f"{question}\nSQL:\n"
    )
|
| 21 |
+
|
| 22 |
+
# ---------------- LOAD SCHEMA ----------------
|
| 23 |
+
def load_schema(db_path):
    """Return a compact textual schema: one ``table(col1, col2)`` line per table.

    Args:
        db_path: Path to a SQLite database file.

    Returns:
        A newline-terminated string with one line per user table.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        tables = cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()

        schema = ""
        for (table,) in tables:
            # Quote the identifier so table names containing spaces or SQL
            # keywords do not break the PRAGMA (fix: was unquoted f-string).
            cols = cursor.execute(f'PRAGMA table_info("{table}");').fetchall()
            col_names = [c[1] for c in cols]
            schema += f"{table}({', '.join(col_names)})\n"
    finally:
        # Fix: previously the connection leaked if any query above raised.
        conn.close()

    return schema
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# ---------------- EXECUTION CHECK WITH TIMEOUT ----------------
|
| 42 |
+
def execution_match(pred_sql, gold_sql, db_path):
    """Execute both queries against the same database and compare raw rows.

    A SQLite progress handler aborts either statement after ~5 seconds so
    evaluation cannot freeze on a pathological query.

    Args:
        pred_sql: Predicted SQL string.
        gold_sql: Gold SQL string.
        db_path: Path to the SQLite database to execute against.

    Returns:
        True when both queries execute and return identical (order-sensitive)
        row lists; False on any error or timeout.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)

        # --- 5-SECOND TIMEOUT SO EVALUATION DOESN'T FREEZE ---
        start_time = time.monotonic()

        def timeout_handler():
            # A non-zero return value tells SQLite to abort the statement.
            return 1 if (time.monotonic() - start_time) > 5.0 else 0

        conn.set_progress_handler(timeout_handler, 10000)

        cur = conn.cursor()

        cur.execute(pred_sql)
        pred = cur.fetchall()

        cur.execute(gold_sql)
        gold = cur.fetchall()

        return pred == gold

    except Exception:
        # Any execution error (syntax, missing table, timeout abort) is a miss.
        return False
    finally:
        # Fix: previously the connection leaked on every error path.
        if conn is not None:
            conn.close()
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# ---------------- MAIN ----------------
|
| 68 |
+
def main():
    """Evaluate a LoRA-adapted BART text-to-SQL model on Spider dev.

    Generates SQL per dev example, then scores by execution match
    (order-sensitive row comparison) via execution_match().
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, required=True)
    parser.add_argument("--num_samples", type=int, default=1034)
    args = parser.parse_args()

    project_root = Path(__file__).resolve().parents[1]

    dev_json = project_root / "data" / "dev.json"
    db_root = project_root / "data" / "database"

    # Device preference: Apple MPS, then CUDA, then CPU.
    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")

    # load model
    base_model = "facebook/bart-base"
    print(f"Loading Base: {base_model}")
    print(f"Loading Adapter: {args.adapter}")

    tokenizer = AutoTokenizer.from_pretrained(args.adapter)
    base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
    model = PeftModel.from_pretrained(base, args.adapter).to(device)
    # Fold the LoRA weights into the base model for plain inference.
    model = model.merge_and_unload()

    with open(dev_json) as f:
        dev = json.load(f)[: args.num_samples]

    correct = 0

    print(f"Evaluating {len(dev)} examples...\n")

    for i, ex in enumerate(dev, 1):
        question = ex["question"]
        db_id = ex["db_id"]
        gold_sql = ex["query"]

        db_path = db_root / db_id / f"{db_id}.sqlite"
        schema = load_schema(db_path)

        prompt = build_prompt(question, schema)

        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=80,
                do_sample=False,
                num_beams=4,
            )

        pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # If the model echoed the prompt, keep only the text after "SQL:".
        if "SQL:" in pred_sql:
            pred_sql = pred_sql.split("SQL:")[-1].strip()

        match = execution_match(pred_sql, gold_sql, db_path)

        if match:
            correct += 1

        # Progress report every 10 examples.
        if i % 10 == 0:
            print(f"{i}/{len(dev)} | Acc: {correct/i:.3f}")

    print("\n=============================")
    print(f"FINAL EXECUTION ACCURACY: {correct/len(dev)*100:.2f}%")
    print("=============================")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
src/evaluate_sft_bart.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import subprocess
|
| 5 |
+
import sys
|
| 6 |
+
import argparse
|
| 7 |
+
import re
|
| 8 |
+
import sqlite3
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 13 |
+
from peft import PeftModel
|
| 14 |
+
from prompting import encode_prompt
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ---------------- SQL CLEAN ----------------
|
| 18 |
+
def extract_sql(text: str) -> str:
    """Normalize raw model output into one semicolon-terminated SQL line."""
    cleaned = text.strip()

    # Drop any echoed prompt text up to the last "SQL:" marker.
    if "SQL:" in cleaned:
        cleaned = cleaned.rsplit("SQL:", 1)[-1]

    # Keep everything from the first SELECT to the end, if present.
    found = re.search(r"(SELECT .*?)(?:$)", cleaned, re.IGNORECASE | re.DOTALL)
    if found is not None:
        cleaned = found.group(1)

    # Unify quote style and collapse all whitespace runs.
    cleaned = re.sub(r"\s+", " ", cleaned.replace('"', "'")).strip()

    return cleaned if cleaned.endswith(";") else cleaned + ";"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# ---------------- ROBUST ACC PARSER ----------------
|
| 38 |
+
def parse_exec_accuracy(stdout: str):
    """Pull the last float from the first 'execution' line; None when absent."""
    for row in stdout.splitlines():
        if "execution" not in row.lower():
            continue
        floats = re.findall(r"\d+\.\d+", row)
        if floats:
            return float(floats[-1])
    return None
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def main():
    """Generate SQL for Spider dev with a LoRA-tuned BART model, score it
    live by execution match, then run the official Spider evaluation script.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, default="checkpoints/sft_best_bart_2")
    parser.add_argument("--num_samples", type=int, default=1000)
    args = parser.parse_args()

    project_root = Path(__file__).resolve().parents[1]
    adapter_dir = project_root / args.adapter

    if not adapter_dir.exists():
        raise FileNotFoundError(f"Adapter not found: {adapter_dir}")

    db_root = project_root / "data/database"
    table_json = project_root / "data/tables.json"
    dev_json = project_root / "data/dev.json"
    gold_sql_file = project_root / "data/dev_gold.sql"
    pred_sql_file = project_root / "pred.sql"

    # Device preference: Apple MPS, then CUDA, then CPU.
    device = "mps" if torch.backends.mps.is_available() else (
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    print("Using device:", device)

    # -------- LOAD MODEL --------
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(adapter_dir)

    BASE_MODEL = "facebook/bart-base"
    print(f"Loading base model {BASE_MODEL}...")
    base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)

    print("Loading LoRA adapter...")
    model = PeftModel.from_pretrained(base_model, adapter_dir).to(device)
    # Fold the LoRA weights into the base model for plain inference.
    model = model.merge_and_unload()
    model.eval()

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    # -------- LOAD DATA --------
    with open(dev_json) as f:
        dev = json.load(f)[: args.num_samples]

    print("Generating SQL predictions...\n")

    correct = 0
    total = len(dev)

    with open(pred_sql_file, "w") as f, torch.no_grad():

        for i, ex in enumerate(dev, 1):

            question = ex["question"]
            db_id = ex["db_id"]
            gold_query = ex["query"]

            prompt_ids = encode_prompt(
                tokenizer,
                question,
                db_id,
                device=device,
                max_input_tokens=512,
            )

            input_ids = prompt_ids.unsqueeze(0).to(device)
            attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)

            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=160,
                num_beams=4,
                do_sample=False,
            )

            pred = tokenizer.decode(outputs[0], skip_special_tokens=True)
            pred_sql = extract_sql(pred)

            # Spider prediction format: "<sql>\t<db_id>" per line.
            f.write(f"{pred_sql}\t{db_id}\n")

            # -------- LIVE EXECUTION CHECK --------
            try:
                db_path = db_root / db_id / f"{db_id}.sqlite"

                conn = sqlite3.connect(db_path)
                cursor = conn.cursor()

                cursor.execute(pred_sql)
                pred_rows = cursor.fetchall()

                cursor.execute(gold_query)
                gold_rows = cursor.fetchall()

                conn.close()

                # order insensitive comparison
                if sorted(pred_rows) == sorted(gold_rows):
                    correct += 1

            except Exception:
                pass  # execution failed

            if i % 10 == 0 or i == total:
                current_acc = correct / i
                print(f"{i}/{total} | Acc: {current_acc:.3f}")

    print("\nGeneration finished.\n")

    # -------- RUN OFFICIAL SPIDER EVAL --------
    eval_script = project_root / "spider_eval/evaluation.py"
    # Prefer the BART-specific evaluation script when it exists.
    if (project_root / "spider_eval/evaluation_bart.py").exists():
        eval_script = project_root / "spider_eval/evaluation_bart.py"

    cmd = [
        sys.executable,
        str(eval_script),
        "--gold", str(gold_sql_file),
        "--pred", str(pred_sql_file),
        "--etype", "exec",
        "--db", str(db_root),
        "--table", str(table_json),
    ]

    print(f"\nRunning Spider evaluation using {eval_script.name}...")
    proc = subprocess.run(cmd, capture_output=True, text=True, errors="ignore")

    if proc.returncode != 0:
        print("\nSpider evaluation crashed.")
        print(proc.stderr)
        return

    print("\n--- Spider Eval Output ---")
    print("\n".join(proc.stdout.splitlines()[-20:]))

    acc = parse_exec_accuracy(proc.stdout)
    if acc is not None:
        print(f"\n🎯 Official Execution Accuracy: {acc*100:.2f}%")
    else:
        print("\nCould not parse official accuracy.")
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
src/evaluate_without_constraied.py
ADDED
|
@@ -0,0 +1,503 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# *********** code till task 3 ************
|
| 3 |
+
|
| 4 |
+
# import json
|
| 5 |
+
# import subprocess
|
| 6 |
+
# import sys
|
| 7 |
+
# import argparse
|
| 8 |
+
# import random
|
| 9 |
+
# import sqlite3
|
| 10 |
+
# import time
|
| 11 |
+
# import re
|
| 12 |
+
# import os
|
| 13 |
+
# from pathlib import Path
|
| 14 |
+
|
| 15 |
+
# import torch
|
| 16 |
+
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 17 |
+
# from peft import PeftModel
|
| 18 |
+
|
| 19 |
+
# from prompting import encode_prompt
|
| 20 |
+
|
| 21 |
+
# # -------------------------------
|
| 22 |
+
# # NORMALIZATION
|
| 23 |
+
# # -------------------------------
|
| 24 |
+
# def normalize_sql(sql):
|
| 25 |
+
# sql = sql.replace('"', "'")
|
| 26 |
+
# sql = re.sub(r"\s+", " ", sql)
|
| 27 |
+
# return sql.strip().lower().rstrip(";")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# # -------------------------------
|
| 31 |
+
# # 🔥 SAFE RESULT NORMALIZATION (FIX)
|
| 32 |
+
# # -------------------------------
|
| 33 |
+
# def normalize_result(res):
|
| 34 |
+
# try:
|
| 35 |
+
# return sorted([str(r) for r in res])
|
| 36 |
+
# except:
|
| 37 |
+
# return []
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# # -------------------------------
|
| 41 |
+
# # EXECUTION CHECK (FIXED)
|
| 42 |
+
# # -------------------------------
|
| 43 |
+
# def check_execution(pred_sql, gold_sql, db_path):
|
| 44 |
+
# try:
|
| 45 |
+
# conn = sqlite3.connect(db_path)
|
| 46 |
+
# conn.text_factory = lambda b: b.decode(errors='ignore')
|
| 47 |
+
|
| 48 |
+
# start_time = time.monotonic()
|
| 49 |
+
|
| 50 |
+
# def timeout_handler():
|
| 51 |
+
# return 1 if (time.monotonic() - start_time) > 2.0 else 0
|
| 52 |
+
|
| 53 |
+
# conn.set_progress_handler(timeout_handler, 10000)
|
| 54 |
+
|
| 55 |
+
# cursor = conn.cursor()
|
| 56 |
+
|
| 57 |
+
# cursor.execute(pred_sql)
|
| 58 |
+
# pred_res = cursor.fetchall()
|
| 59 |
+
|
| 60 |
+
# cursor.execute(gold_sql)
|
| 61 |
+
# gold_res = cursor.fetchall()
|
| 62 |
+
|
| 63 |
+
# conn.close()
|
| 64 |
+
|
| 65 |
+
# # 🔥 FIXED COMPARISON
|
| 66 |
+
# return normalize_result(pred_res) == normalize_result(gold_res)
|
| 67 |
+
|
| 68 |
+
# except Exception:
|
| 69 |
+
# return False
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# # -------------------------------
|
| 73 |
+
# # SPIDER PARSER
|
| 74 |
+
# # -------------------------------
|
| 75 |
+
# def _parse_spider_accuracy(stdout: str, metric_type: str):
|
| 76 |
+
# for line in stdout.splitlines():
|
| 77 |
+
# if metric_type == "exec" and line.strip().startswith("execution"):
|
| 78 |
+
# try:
|
| 79 |
+
# return float(line.split()[-1])
|
| 80 |
+
# except:
|
| 81 |
+
# pass
|
| 82 |
+
# elif metric_type == "match" and line.strip().startswith("exact"):
|
| 83 |
+
# try:
|
| 84 |
+
# return float(line.split()[-1])
|
| 85 |
+
# except:
|
| 86 |
+
# pass
|
| 87 |
+
# return None
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# # -------------------------------
|
| 91 |
+
# # MAIN
|
| 92 |
+
# # -------------------------------
|
| 93 |
+
# def main():
|
| 94 |
+
# parser = argparse.ArgumentParser()
|
| 95 |
+
# parser.add_argument("--adapter", type=str, required=True)
|
| 96 |
+
# parser.add_argument("--num_samples", type=int, default= 500)
|
| 97 |
+
# parser.add_argument("--shuffle_dev", action="store_true")
|
| 98 |
+
# parser.add_argument("--shuffle_seed", type=int, default=42)
|
| 99 |
+
# args = parser.parse_args()
|
| 100 |
+
|
| 101 |
+
# project_root = Path(__file__).resolve().parents[1]
|
| 102 |
+
# adapter_dir = project_root / args.adapter
|
| 103 |
+
|
| 104 |
+
# db_root = project_root / "data" / "database"
|
| 105 |
+
# table_json = project_root / "data" / "tables.json"
|
| 106 |
+
# dev_json = project_root / "data" / "dev.json"
|
| 107 |
+
|
| 108 |
+
# pred_path = project_root / "temp_predictions.txt"
|
| 109 |
+
# temp_gold_path = project_root / "temp_gold.sql"
|
| 110 |
+
|
| 111 |
+
# if not adapter_dir.exists():
|
| 112 |
+
# raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")
|
| 113 |
+
|
| 114 |
+
# device = "mps" if torch.backends.mps.is_available() else (
|
| 115 |
+
# "cuda" if torch.cuda.is_available() else "cpu"
|
| 116 |
+
# )
|
| 117 |
+
# print(f"Using device: {device}")
|
| 118 |
+
|
| 119 |
+
# BASE_MODEL = "Salesforce/codet5-base"
|
| 120 |
+
# tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
|
| 121 |
+
|
| 122 |
+
# if tokenizer.pad_token is None:
|
| 123 |
+
# tokenizer.pad_token = tokenizer.eos_token
|
| 124 |
+
|
| 125 |
+
# print(f"\n📦 Loading Model: {args.adapter}")
|
| 126 |
+
|
| 127 |
+
# base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
|
| 128 |
+
|
| 129 |
+
# adapter_for_peft = os.path.relpath(adapter_dir, project_root)
|
| 130 |
+
|
| 131 |
+
# model = PeftModel.from_pretrained(
|
| 132 |
+
# base,
|
| 133 |
+
# adapter_for_peft,
|
| 134 |
+
# local_files_only=True
|
| 135 |
+
# ).to(device)
|
| 136 |
+
|
| 137 |
+
# model = model.merge_and_unload()
|
| 138 |
+
# model.eval()
|
| 139 |
+
|
| 140 |
+
# # -------------------------------
|
| 141 |
+
# # LOAD DATA
|
| 142 |
+
# # -------------------------------
|
| 143 |
+
# with dev_json.open() as f:
|
| 144 |
+
# dev = json.load(f)
|
| 145 |
+
|
| 146 |
+
# if args.shuffle_dev:
|
| 147 |
+
# rng = random.Random(args.shuffle_seed)
|
| 148 |
+
# rng.shuffle(dev)
|
| 149 |
+
|
| 150 |
+
# dev = dev[: args.num_samples]
|
| 151 |
+
# total = len(dev)
|
| 152 |
+
|
| 153 |
+
# gen_kwargs = dict(
|
| 154 |
+
# max_new_tokens=160,
|
| 155 |
+
# num_beams=8,
|
| 156 |
+
# length_penalty=0.8,
|
| 157 |
+
# do_sample=False,
|
| 158 |
+
# early_stopping=True,
|
| 159 |
+
# pad_token_id=tokenizer.pad_token_id,
|
| 160 |
+
# eos_token_id=tokenizer.eos_token_id,
|
| 161 |
+
# )
|
| 162 |
+
|
| 163 |
+
# print(f"\n🚀 Evaluating {total} samples...\n")
|
| 164 |
+
|
| 165 |
+
# em_correct = 0
|
| 166 |
+
# ex_correct = 0
|
| 167 |
+
|
| 168 |
+
# with pred_path.open("w") as out_pred, temp_gold_path.open("w") as out_gold, torch.no_grad():
|
| 169 |
+
# for i, ex in enumerate(dev, start=1):
|
| 170 |
+
|
| 171 |
+
# db_id = ex["db_id"]
|
| 172 |
+
# question = ex["question"]
|
| 173 |
+
# gold_query = ex["query"]
|
| 174 |
+
# db_path = db_root / db_id / f"{db_id}.sqlite"
|
| 175 |
+
|
| 176 |
+
# # -------------------------------
|
| 177 |
+
# # GENERATE SQL
|
| 178 |
+
# # -------------------------------
|
| 179 |
+
# input_ids = encode_prompt(
|
| 180 |
+
# tokenizer,
|
| 181 |
+
# question,
|
| 182 |
+
# db_id,
|
| 183 |
+
# device=device,
|
| 184 |
+
# max_input_tokens=512
|
| 185 |
+
# )
|
| 186 |
+
|
| 187 |
+
# input_ids = input_ids.unsqueeze(0).to(device)
|
| 188 |
+
# attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)
|
| 189 |
+
|
| 190 |
+
# outputs = model.generate(
|
| 191 |
+
# input_ids=input_ids,
|
| 192 |
+
# attention_mask=attention_mask,
|
| 193 |
+
# **gen_kwargs
|
| 194 |
+
# )
|
| 195 |
+
|
| 196 |
+
# pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
|
| 197 |
+
|
| 198 |
+
# # -------------------------------
|
| 199 |
+
# # SAVE FOR SPIDER EVAL
|
| 200 |
+
# # -------------------------------
|
| 201 |
+
# out_pred.write(f"{pred_sql}\n")
|
| 202 |
+
# out_gold.write(f"{gold_query}\t{db_id}\n")
|
| 203 |
+
|
| 204 |
+
# # -------------------------------
|
| 205 |
+
# # LIVE METRICS
|
| 206 |
+
# # -------------------------------
|
| 207 |
+
# if normalize_sql(pred_sql) == normalize_sql(gold_query):
|
| 208 |
+
# em_correct += 1
|
| 209 |
+
|
| 210 |
+
# if check_execution(pred_sql, gold_query, db_path):
|
| 211 |
+
# ex_correct += 1
|
| 212 |
+
|
| 213 |
+
# if i % 20 == 0 or i == total:
|
| 214 |
+
# print(
|
| 215 |
+
# f"Progress: {i}/{total} | "
|
| 216 |
+
# f"EM: {(em_correct/i)*100:.2f}% | "
|
| 217 |
+
# f"EX: {(ex_correct/i)*100:.2f}%"
|
| 218 |
+
# )
|
| 219 |
+
|
| 220 |
+
# print("\n🚀 Running Official Spider Evaluation...\n")
|
| 221 |
+
|
| 222 |
+
# eval_script = project_root / "spider_eval" / "evaluation.py"
|
| 223 |
+
|
| 224 |
+
# # EXACT MATCH
|
| 225 |
+
# cmd_match = [
|
| 226 |
+
# sys.executable, str(eval_script),
|
| 227 |
+
# "--gold", str(temp_gold_path),
|
| 228 |
+
# "--pred", str(pred_path),
|
| 229 |
+
# "--etype", "match",
|
| 230 |
+
# "--db", str(db_root),
|
| 231 |
+
# "--table", str(table_json),
|
| 232 |
+
# ]
|
| 233 |
+
|
| 234 |
+
# proc_match = subprocess.run(cmd_match, capture_output=True, text=True)
|
| 235 |
+
# exact_acc = _parse_spider_accuracy(proc_match.stdout, "match")
|
| 236 |
+
|
| 237 |
+
# # EXECUTION
|
| 238 |
+
# cmd_exec = [
|
| 239 |
+
# sys.executable, str(eval_script),
|
| 240 |
+
# "--gold", str(temp_gold_path),
|
| 241 |
+
# "--pred", str(pred_path),
|
| 242 |
+
# "--etype", "exec",
|
| 243 |
+
# "--db", str(db_root),
|
| 244 |
+
# "--table", str(table_json),
|
| 245 |
+
# ]
|
| 246 |
+
|
| 247 |
+
# proc_exec = subprocess.run(cmd_exec, capture_output=True, text=True)
|
| 248 |
+
# exec_acc = _parse_spider_accuracy(proc_exec.stdout, "exec")
|
| 249 |
+
|
| 250 |
+
# print("==========================================")
|
| 251 |
+
# print(f"🎯 OFFICIAL SPIDER RESULTS FOR: {args.adapter}")
|
| 252 |
+
# print("==========================================")
|
| 253 |
+
|
| 254 |
+
# print(f"Exact Match Accuracy : {exact_acc*100:.2f}%" if exact_acc else "EM parsing failed")
|
| 255 |
+
# print(f"Execution Accuracy : {exec_acc*100:.2f}%" if exec_acc else "EX parsing failed")
|
| 256 |
+
|
| 257 |
+
# print("==========================================\n")
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
# if __name__ == "__main__":
|
| 261 |
+
# main()
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
# *********** for task 2 ****************************************
|
| 267 |
+
import json
|
| 268 |
+
import argparse
|
| 269 |
+
import random
|
| 270 |
+
import sqlite3
|
| 271 |
+
import re
|
| 272 |
+
import os
|
| 273 |
+
from pathlib import Path
|
| 274 |
+
from collections import defaultdict
|
| 275 |
+
|
| 276 |
+
import torch
|
| 277 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 278 |
+
from peft import PeftModel
|
| 279 |
+
|
| 280 |
+
from prompting import encode_prompt
|
| 281 |
+
|
| 282 |
+
# -------------------------------
|
| 283 |
+
# NORMALIZATION
|
| 284 |
+
# -------------------------------
|
| 285 |
+
def normalize_sql(sql):
    """Canonicalize SQL for textual comparison: quotes, whitespace, case."""
    collapsed = re.sub(r"\s+", " ", sql.replace('"', "'"))
    return collapsed.strip().lower().rstrip(";")
|
| 289 |
+
|
| 290 |
+
def normalize_result(res):
    """Map an execution result to a sorted list of row strings.

    Sorting row *strings* makes the comparison order-insensitive and robust
    to rows with mixed, non-comparable column types.

    Returns [] when ``res`` is not iterable (e.g. None from a failed query).
    """
    try:
        return sorted(str(row) for row in res)
    except TypeError:
        # Fix: was a bare ``except`` that also swallowed KeyboardInterrupt.
        return []
|
| 295 |
+
|
| 296 |
+
# -------------------------------
|
| 297 |
+
# STEP 1: EXECUTION
|
| 298 |
+
# -------------------------------
|
| 299 |
+
def execute_with_error(sql, db_path):
    """Run ``sql`` on the SQLite database at ``db_path``.

    Returns:
        (rows, None) on success, or (None, error_message) on failure.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cur = conn.cursor()
        cur.execute(sql)
        res = cur.fetchall()
        return res, None
    except Exception as e:
        return None, str(e)
    finally:
        # Fix: previously the connection leaked when execution raised.
        if conn is not None:
            conn.close()
|
| 309 |
+
|
| 310 |
+
# -------------------------------
|
| 311 |
+
# STEP 2: ERROR CLASSIFICATION
|
| 312 |
+
# -------------------------------
|
| 313 |
+
def classify_error(sql, error_msg):
    """Bucket a SQLite error message into a coarse error category.

    Returns "correct" when there was no error, otherwise one of:
    syntax_error, wrong_table, wrong_column, missing_join, type_error,
    missing_where, other.
    """
    if error_msg is None:
        return "correct"

    lowered_err = error_msg.lower()
    lowered_sql = sql.lower()

    # Ordered substring patterns: first match wins.
    patterns = (
        ("syntax", "syntax_error"),
        ("no such table", "wrong_table"),
        ("no such column", "wrong_column"),
        ("ambiguous", "missing_join"),
        ("datatype mismatch", "type_error"),
    )
    for needle, label in patterns:
        if needle in lowered_err:
            return label

    # Heuristic: comparison operators present but no WHERE clause.
    if "where" not in lowered_sql and any(op in lowered_sql for op in ("=", ">", "<")):
        return "missing_where"

    return "other"
|
| 334 |
+
|
| 335 |
+
# -------------------------------
|
| 336 |
+
# STEP 4: HINTS
|
| 337 |
+
# -------------------------------
|
| 338 |
+
def generate_hint(error_type):
    """Return a short remediation hint for an error category ('' if unknown)."""
    return {
        "missing_join": "Try using JOIN between related tables.",
        "wrong_column": "Check column names in schema.",
        "missing_where": "Add WHERE condition.",
        "syntax_error": "Fix SQL syntax.",
        "wrong_table": "Verify table names.",
        "type_error": "Check data types.",
        "other": "Review SQL logic.",
    }.get(error_type, "")
|
| 349 |
+
|
| 350 |
+
# -------------------------------
|
| 351 |
+
# STEP 2 EXTRA: LIGHT ATTRIBUTION
|
| 352 |
+
# -------------------------------
|
| 353 |
+
def extract_keywords(question):
    """Lower-cased word tokens of the question longer than three characters."""
    tokens = re.findall(r"\w+", question.lower())
    return [tok for tok in tokens if len(tok) > 3]
|
| 355 |
+
|
| 356 |
+
# -------------------------------
|
| 357 |
+
# MAIN
|
| 358 |
+
# -------------------------------
|
| 359 |
+
def main():
    """Evaluate a LoRA-adapted CodeT5 text-to-SQL model on Spider dev samples.

    Reports exact-match (EM) and execution-match (EX) accuracy, plus an
    error-type breakdown, per-error hints, keyword attribution for failures,
    SQL-operation usage counts, and a few adversarial prompts at the end.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, required=True)
    parser.add_argument("--num_samples", type=int, default=200)
    args = parser.parse_args()

    project_root = Path(__file__).resolve().parents[1]
    db_root = project_root / "data" / "database"
    dev_json = project_root / "data" / "dev.json"

    # Prefer the Apple-silicon GPU when available; otherwise fall back to CPU.
    device = "mps" if torch.backends.mps.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-base")
    base = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/codet5-base").to(device)

    model = PeftModel.from_pretrained(
        base,
        os.path.relpath(project_root / args.adapter, project_root),
        local_files_only=True
    ).to(device)

    # Fold the LoRA weights into the base model for faster inference.
    model = model.merge_and_unload()
    model.eval()

    with open(dev_json) as f:
        dev = json.load(f)

    dev = dev[:args.num_samples]

    # Guard against an empty dev split: every summary below divides by
    # len(dev) and the adversarial section indexes dev[0].
    if not dev:
        print("No dev samples found; nothing to evaluate.")
        return

    # Aggregation storage.
    error_counter = defaultdict(int)
    error_examples = defaultdict(list)
    success_examples = []
    hint_examples = defaultdict(list)
    operation_counter = defaultdict(int)
    attribution_map = defaultdict(list)

    em, ex = 0, 0

    print(f"\n🚀 Evaluating {len(dev)} samples...\n")

    for i, sample in enumerate(dev, 1):

        db_id = sample["db_id"]
        q = sample["question"]
        gold = sample["query"]
        db_path = db_root / db_id / f"{db_id}.sqlite"

        input_ids = encode_prompt(tokenizer, q, db_id, device=device).unsqueeze(0)

        out = model.generate(input_ids=input_ids, max_new_tokens=120, num_beams=8)
        pred = tokenizer.decode(out[0], skip_special_tokens=True).strip()

        # Count which SQL operations the prediction uses.
        s = pred.lower()
        if "select" in s: operation_counter["SELECT"] += 1
        if "where" in s: operation_counter["WHERE"] += 1
        if "join" in s: operation_counter["JOIN"] += 1
        if "group by" in s: operation_counter["GROUP_BY"] += 1
        if "order by" in s: operation_counter["ORDER_BY"] += 1

        pred_res, err = execute_with_error(pred, db_path)
        gold_res, _ = execute_with_error(gold, db_path)

        error_type = classify_error(pred, err)
        error_counter[error_type] += 1

        # Keyword attribution: which question words co-occur with each error type.
        if err:
            attribution_map[error_type].append(extract_keywords(q))

        # Keep up to 3 example predictions per error type.
        if len(error_examples[error_type]) < 3:
            error_examples[error_type].append(pred)

        # Hints for failed predictions.
        if error_type != "correct":
            hint = generate_hint(error_type)
            if len(hint_examples[error_type]) < 3:
                hint_examples[error_type].append((pred, hint))

        # Metrics.
        if normalize_sql(pred) == normalize_sql(gold):
            em += 1

        # BUG FIX: compare against None, not truthiness. execute_with_error()
        # returns (result, None) on success, and an empty result set ([]) is
        # falsy but still a valid, comparable execution result — the old
        # `if pred_res and gold_res` silently dropped correct empty-set matches.
        if pred_res is not None and gold_res is not None and normalize_result(pred_res) == normalize_result(gold_res):
            ex += 1
            if len(success_examples) < 5:
                success_examples.append(pred)

        if i % 20 == 0:
            print(f"[{i}] EM: {em/i:.2f} | EX: {ex/i:.2f}")

    # -------------------------------
    # OUTPUT
    # -------------------------------
    print("\n🎯 FINAL RESULTS")
    print(f"EM: {em/len(dev)*100:.2f}%")
    print(f"EX: {ex/len(dev)*100:.2f}%")

    print("\n🔥 ERROR SUMMARY")
    for k, v in error_counter.items():
        print(k, ":", v)

    print("\n🔥 ERROR EXAMPLES")
    for k in error_examples:
        print("\n", k)
        for e in error_examples[k]:
            print(" ", e)

    print("\n🔥 HINTS")
    for k in hint_examples:
        print("\n", k)
        for sql, h in hint_examples[k]:
            print(" ", sql)
            print(" →", h)

    print("\n🔥 ATTRIBUTION (KEYWORDS)")
    for k in attribution_map:
        print(k, ":", attribution_map[k][:3])

    print("\n🔥 SQL OPERATIONS")
    for k, v in operation_counter.items():
        print(k, ":", v)

    # -------------------------------
    # ADVERSARIAL
    # -------------------------------
    print("\n🔥 ADVERSARIAL TESTS")

    adv = [
        "Find most expensive product",
        "Top 3 students by marks",
        "Average salary per department"
    ]

    # Out-of-distribution prompts, all generated against the first dev database.
    for q in adv:
        inp = encode_prompt(tokenizer, q, dev[0]["db_id"], device=device).unsqueeze(0)
        out = model.generate(input_ids=inp, max_new_tokens=120)
        print("\nQ:", q)
        print("SQL:", tokenizer.decode(out[0], skip_special_tokens=True))
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
# Script entry point: run the full evaluation when executed directly.
if __name__ == "__main__":
    main()
|
src/execution_reward copy.py
ADDED
|
@@ -0,0 +1,831 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
# from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
# import hashlib
|
| 6 |
+
# import os
|
| 7 |
+
# import queue
|
| 8 |
+
# import re
|
| 9 |
+
# import sqlite3
|
| 10 |
+
# import threading
|
| 11 |
+
# import time
|
| 12 |
+
# from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 13 |
+
# from dataclasses import dataclass
|
| 14 |
+
# from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
|
| 15 |
+
|
| 16 |
+
# # --- CACHE CONTROL ---
|
| 17 |
+
# USE_CACHE = True
|
| 18 |
+
# _REWARD_CACHE: Dict[str, float] = {}
|
| 19 |
+
|
| 20 |
+
# def set_use_cache(enabled: bool):
|
| 21 |
+
# """Dynamically toggle the reward cache for benchmarks."""
|
| 22 |
+
# global USE_CACHE
|
| 23 |
+
# USE_CACHE = enabled
|
| 24 |
+
|
| 25 |
+
# def _normalize_sql(sql: str) -> str:
|
| 26 |
+
# if not isinstance(sql, str):
|
| 27 |
+
# return ""
|
| 28 |
+
# s = sql.strip()
|
| 29 |
+
# if s.startswith("```"):
|
| 30 |
+
# s = re.sub(r"^```[a-zA-Z0-9_+-]*\n?", "", s).strip()
|
| 31 |
+
# s = re.sub(r"\n?```$", "", s).strip()
|
| 32 |
+
# if s.lower().startswith("sql:"):
|
| 33 |
+
# s = s[4:].strip()
|
| 34 |
+
# if ";" in s:
|
| 35 |
+
# s = s.split(";", 1)[0].strip()
|
| 36 |
+
# return s
|
| 37 |
+
|
| 38 |
+
# def _connect_readonly(db_path: str) -> sqlite3.Connection:
|
| 39 |
+
# uri = f"file:{os.path.abspath(db_path)}?mode=ro"
|
| 40 |
+
# conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
|
| 41 |
+
# conn.execute("PRAGMA query_only = ON;")
|
| 42 |
+
# conn.execute("PRAGMA foreign_keys = ON;")
|
| 43 |
+
# return conn
|
| 44 |
+
|
| 45 |
+
# DEFAULT_QUERY_TIMEOUT_S = 2.0
|
| 46 |
+
|
| 47 |
+
# def _with_timeout(conn: sqlite3.Connection, timeout_s: float = DEFAULT_QUERY_TIMEOUT_S) -> None:
|
| 48 |
+
# start = time.monotonic()
|
| 49 |
+
# def _handler() -> int:
|
| 50 |
+
# return 1 if (time.monotonic() - start) > timeout_s else 0
|
| 51 |
+
# conn.set_progress_handler(_handler, 10_000)
|
| 52 |
+
|
| 53 |
+
# def _list_tables(conn: sqlite3.Connection) -> List[str]:
|
| 54 |
+
# try:
|
| 55 |
+
# cur = conn.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';")
|
| 56 |
+
# return [r[0] for r in cur.fetchall() if r and isinstance(r[0], str)]
|
| 57 |
+
# except sqlite3.Error:
|
| 58 |
+
# return []
|
| 59 |
+
|
| 60 |
+
# def _contains_table_name(sql: str, table_names: Sequence[str]) -> bool:
|
| 61 |
+
# s = sql.lower()
|
| 62 |
+
# for t in table_names:
|
| 63 |
+
# tl = t.lower()
|
| 64 |
+
# if not tl:
|
| 65 |
+
# continue
|
| 66 |
+
# if re.search(rf"\b{re.escape(tl)}\b", s):
|
| 67 |
+
# return True
|
| 68 |
+
# return False
|
| 69 |
+
|
| 70 |
+
# def _explain_query_plan(conn: sqlite3.Connection, sql: str) -> bool:
|
| 71 |
+
# try:
|
| 72 |
+
# _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
|
| 73 |
+
# conn.execute(f"EXPLAIN QUERY PLAN {sql}")
|
| 74 |
+
# return True
|
| 75 |
+
# except sqlite3.Error:
|
| 76 |
+
# return False
|
| 77 |
+
|
| 78 |
+
# def _execute(conn: sqlite3.Connection, sql: str, max_rows: int = 1000) -> Tuple[bool, List[Tuple], Optional[str]]:
|
| 79 |
+
# try:
|
| 80 |
+
# _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
|
| 81 |
+
# cur = conn.execute(sql)
|
| 82 |
+
# rows = cur.fetchmany(max_rows)
|
| 83 |
+
# norm_rows = [tuple(r) for r in rows]
|
| 84 |
+
# return True, norm_rows, None
|
| 85 |
+
# except sqlite3.Error as e:
|
| 86 |
+
# return False, [], str(e)
|
| 87 |
+
|
| 88 |
+
# _SQL_KEYWORDS_TO_IGNORE = {
|
| 89 |
+
# "select", "from", "where", "join", "inner", "left", "right", "full", "outer",
|
| 90 |
+
# "on", "group", "by", "order", "limit", "having", "distinct", "union", "intersect",
|
| 91 |
+
# "except", "as", "and", "or", "not", "in", "is", "null", "like", "between", "case",
|
| 92 |
+
# "when", "then", "else", "end", "asc", "desc"
|
| 93 |
+
# }
|
| 94 |
+
|
| 95 |
+
# _SQL_FUNCTIONS_TO_IGNORE = {
|
| 96 |
+
# "count", "avg", "min", "max", "sum", "lower", "upper", "substr", "coalesce",
|
| 97 |
+
# "round", "date", "datetime", "strftime"
|
| 98 |
+
# }
|
| 99 |
+
|
| 100 |
+
# # --- LIGHTWEIGHT PARSING ---
|
| 101 |
+
# def is_valid_select(sql: str):
|
| 102 |
+
# sql = sql.strip().lower()
|
| 103 |
+
# return sql.startswith("select") or sql.startswith("with")
|
| 104 |
+
|
| 105 |
+
# def extract_tables(sql: str) -> List[str]:
|
| 106 |
+
# sql = sql.lower()
|
| 107 |
+
# if "join" not in sql:
|
| 108 |
+
# tables = re.findall(r'from\s+(\w+)', sql)
|
| 109 |
+
# return list(set(tables))
|
| 110 |
+
|
| 111 |
+
# tables = re.findall(r'from\s+([a-zA-Z_][a-zA-Z0-9_]*)', sql)
|
| 112 |
+
# joins = re.findall(r'join\s+([a-zA-Z_][a-zA-Z0-9_]*)', sql)
|
| 113 |
+
# return list(set(tables + joins))
|
| 114 |
+
|
| 115 |
+
# def extract_columns(sql: str) -> List[str]:
|
| 116 |
+
# sql = sql.lower()
|
| 117 |
+
# match = re.search(r'select\s+(.*?)\s+from', sql)
|
| 118 |
+
# if not match:
|
| 119 |
+
# return []
|
| 120 |
+
# cols = match.group(1)
|
| 121 |
+
# if cols.strip() == "*":
|
| 122 |
+
# return ["*"]
|
| 123 |
+
# return [c.strip() for c in cols.split(",")]
|
| 124 |
+
|
| 125 |
+
# def _get_db_tables_and_columns(conn: sqlite3.Connection) -> Tuple[Set[str], Set[str]]:
|
| 126 |
+
# tables = set()
|
| 127 |
+
# columns = set()
|
| 128 |
+
# for t in _list_tables(conn):
|
| 129 |
+
# tl = t.lower()
|
| 130 |
+
# if not tl:
|
| 131 |
+
# continue
|
| 132 |
+
# tables.add(tl)
|
| 133 |
+
# try:
|
| 134 |
+
# cur = conn.execute(f'PRAGMA table_info("{t}")')
|
| 135 |
+
# for row in cur.fetchall():
|
| 136 |
+
# if row and isinstance(row[1], str):
|
| 137 |
+
# columns.add(row[1].lower())
|
| 138 |
+
# except sqlite3.Error:
|
| 139 |
+
# continue
|
| 140 |
+
# return tables, columns
|
| 141 |
+
|
| 142 |
+
# def _safe_results_equal(a: List[Tuple], b: List[Tuple]) -> bool:
|
| 143 |
+
# return a == b
|
| 144 |
+
|
| 145 |
+
# @dataclass
|
| 146 |
+
# class RewardDebugStats:
|
| 147 |
+
# total: int = 0
|
| 148 |
+
# parsed_ok: int = 0
|
| 149 |
+
# table_match: int = 0
|
| 150 |
+
# column_match: int = 0
|
| 151 |
+
# executed_ok: int = 0
|
| 152 |
+
# exact_match: int = 0
|
| 153 |
+
|
| 154 |
+
# _DEBUG = RewardDebugStats()
|
| 155 |
+
|
| 156 |
+
# def reset_debug_metrics() -> None:
|
| 157 |
+
# global _DEBUG
|
| 158 |
+
# _DEBUG = RewardDebugStats()
|
| 159 |
+
|
| 160 |
+
# def get_debug_metrics() -> dict:
|
| 161 |
+
# denom = max(_DEBUG.total, 1)
|
| 162 |
+
# return {
|
| 163 |
+
# "valid_sql_rate": _DEBUG.parsed_ok / denom,
|
| 164 |
+
# "table_match_rate": _DEBUG.table_match / denom,
|
| 165 |
+
# "column_match_rate": _DEBUG.column_match / denom,
|
| 166 |
+
# "execution_accuracy": _DEBUG.exact_match / denom,
|
| 167 |
+
# }
|
| 168 |
+
|
| 169 |
+
# EXECUTION_ERROR = "EXECUTION_ERROR"
|
| 170 |
+
|
| 171 |
+
# _RESULT_CACHE_LOCK = threading.Lock()
|
| 172 |
+
# _RESULT_CACHE: "Dict[str, Union[List[Tuple], str]]" = {}
|
| 173 |
+
# _RESULT_CACHE_MAX = 100_000
|
| 174 |
+
|
| 175 |
+
# def clear_result_cache() -> None:
|
| 176 |
+
# """Clear both DB query cache and reward cache."""
|
| 177 |
+
# with _RESULT_CACHE_LOCK:
|
| 178 |
+
# _RESULT_CACHE.clear()
|
| 179 |
+
# _REWARD_CACHE.clear()
|
| 180 |
+
|
| 181 |
+
# def _db_state_fingerprint(db_path: str) -> str:
|
| 182 |
+
# try:
|
| 183 |
+
# st = os.stat(db_path)
|
| 184 |
+
# return f"{st.st_mtime_ns}:{st.st_size}"
|
| 185 |
+
# except OSError:
|
| 186 |
+
# return "missing"
|
| 187 |
+
|
| 188 |
+
# def _result_cache_key(db_path: str, sql: str) -> str:
|
| 189 |
+
# fp = _db_state_fingerprint(db_path)
|
| 190 |
+
# payload = f"{fp}\0{sql}".encode("utf-8", errors="ignore")
|
| 191 |
+
# return hashlib.sha256(payload).hexdigest()
|
| 192 |
+
|
| 193 |
+
# class _ConnectionPool:
|
| 194 |
+
# def __init__(self, db_path: str, maxsize: int = 1) -> None:
|
| 195 |
+
# self.db_path = db_path
|
| 196 |
+
# self.pool = queue.LifoQueue(maxsize=maxsize)
|
| 197 |
+
# self.lock = threading.Lock()
|
| 198 |
+
|
| 199 |
+
# def acquire(self) -> sqlite3.Connection:
|
| 200 |
+
# try:
|
| 201 |
+
# return self.pool.get_nowait()
|
| 202 |
+
# except queue.Empty:
|
| 203 |
+
# with self.lock:
|
| 204 |
+
# try:
|
| 205 |
+
# return self.pool.get_nowait()
|
| 206 |
+
# except queue.Empty:
|
| 207 |
+
# return _connect_readonly(self.db_path)
|
| 208 |
+
|
| 209 |
+
# def release(self, conn: sqlite3.Connection) -> None:
|
| 210 |
+
# try:
|
| 211 |
+
# self.pool.put_nowait(conn)
|
| 212 |
+
# except queue.Full:
|
| 213 |
+
# try:
|
| 214 |
+
# conn.close()
|
| 215 |
+
# except Exception:
|
| 216 |
+
# pass
|
| 217 |
+
|
| 218 |
+
# _POOL_LOCK = threading.Lock()
|
| 219 |
+
# _POOLS: Dict[str, _ConnectionPool] = {}
|
| 220 |
+
|
| 221 |
+
# def _get_pool(db_path: str) -> _ConnectionPool:
|
| 222 |
+
# with _POOL_LOCK:
|
| 223 |
+
# pool = _POOLS.get(db_path)
|
| 224 |
+
# if pool is None:
|
| 225 |
+
# pool = _ConnectionPool(db_path=db_path, maxsize=1)
|
| 226 |
+
# _POOLS[db_path] = pool
|
| 227 |
+
# return pool
|
| 228 |
+
|
| 229 |
+
# class _PooledConnection:
|
| 230 |
+
# def __init__(self, db_path: str) -> None:
|
| 231 |
+
# self.db_path = db_path
|
| 232 |
+
# self.pool = _get_pool(db_path)
|
| 233 |
+
# self.conn: Optional[sqlite3.Connection] = None
|
| 234 |
+
|
| 235 |
+
# def __enter__(self) -> sqlite3.Connection:
|
| 236 |
+
# self.conn = self.pool.acquire()
|
| 237 |
+
# return self.conn
|
| 238 |
+
|
| 239 |
+
# def __exit__(self, exc_type, exc, tb) -> None:
|
| 240 |
+
# if self.conn is not None:
|
| 241 |
+
# self.pool.release(self.conn)
|
| 242 |
+
# self.conn = None
|
| 243 |
+
|
| 244 |
+
# def _cache_get(key: str) -> Optional[Union[List[Tuple], str]]:
|
| 245 |
+
# with _RESULT_CACHE_LOCK:
|
| 246 |
+
# return _RESULT_CACHE.get(key)
|
| 247 |
+
|
| 248 |
+
# def _cache_put(key: str, value: Union[List[Tuple], str]) -> None:
|
| 249 |
+
# with _RESULT_CACHE_LOCK:
|
| 250 |
+
# if len(_RESULT_CACHE) >= _RESULT_CACHE_MAX:
|
| 251 |
+
# _RESULT_CACHE.clear()
|
| 252 |
+
# _RESULT_CACHE[key] = value
|
| 253 |
+
|
| 254 |
+
# def execute_sql(conn: sqlite3.Connection, sql: str, *, max_rows: int = 1000) -> Union[List[Tuple], str]:
|
| 255 |
+
# try:
|
| 256 |
+
# _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
|
| 257 |
+
# cur = conn.execute(sql)
|
| 258 |
+
# rows = cur.fetchmany(max_rows)
|
| 259 |
+
# return [tuple(r) for r in rows]
|
| 260 |
+
# except Exception:
|
| 261 |
+
# return EXECUTION_ERROR
|
| 262 |
+
|
| 263 |
+
# def execute_sql_cached(db_path: str, sql: str, *, max_rows: int = 1000) -> Union[List[Tuple], str]:
|
| 264 |
+
# if not USE_CACHE:
|
| 265 |
+
# with _PooledConnection(db_path) as conn:
|
| 266 |
+
# return execute_sql(conn, sql, max_rows=max_rows)
|
| 267 |
+
|
| 268 |
+
# key = _result_cache_key(db_path, sql)
|
| 269 |
+
# cached = _cache_get(key)
|
| 270 |
+
# if cached is not None:
|
| 271 |
+
# return cached
|
| 272 |
+
# with _PooledConnection(db_path) as conn:
|
| 273 |
+
# res = execute_sql(conn, sql, max_rows=max_rows)
|
| 274 |
+
# _cache_put(key, res)
|
| 275 |
+
# return res
|
| 276 |
+
|
| 277 |
+
# def execution_reward_timed(
|
| 278 |
+
# pred_sql: str, db_path: str, gold_sql: str, *, measure_plan: bool = False,
|
| 279 |
+
# ) -> Tuple[float, Dict[str, float]]:
|
| 280 |
+
# timings = {"parse_s": 0.0, "plan_s": 0.0, "exec_s": 0.0}
|
| 281 |
+
# t0 = time.perf_counter()
|
| 282 |
+
# sql = _normalize_sql(pred_sql)
|
| 283 |
+
# gold = _normalize_sql(gold_sql)
|
| 284 |
+
|
| 285 |
+
# if not is_valid_select(sql):
|
| 286 |
+
# timings["parse_s"] = time.perf_counter() - t0
|
| 287 |
+
# return 0.0, timings
|
| 288 |
+
|
| 289 |
+
# t1 = time.perf_counter()
|
| 290 |
+
# timings["parse_s"] = t1 - t0
|
| 291 |
+
|
| 292 |
+
# if measure_plan:
|
| 293 |
+
# with _PooledConnection(db_path) as conn:
|
| 294 |
+
# p0 = time.perf_counter()
|
| 295 |
+
# _explain_query_plan(conn, sql)
|
| 296 |
+
# _explain_query_plan(conn, gold)
|
| 297 |
+
# timings["plan_s"] = time.perf_counter() - p0
|
| 298 |
+
|
| 299 |
+
# e0 = time.perf_counter()
|
| 300 |
+
# pred_res = execute_sql_cached(db_path, sql)
|
| 301 |
+
# if pred_res == EXECUTION_ERROR:
|
| 302 |
+
# timings["exec_s"] = time.perf_counter() - e0
|
| 303 |
+
# return 0.0, timings
|
| 304 |
+
# gold_res = execute_sql_cached(db_path, gold)
|
| 305 |
+
# timings["exec_s"] = time.perf_counter() - e0
|
| 306 |
+
# if gold_res == EXECUTION_ERROR:
|
| 307 |
+
# return 0.0, timings
|
| 308 |
+
|
| 309 |
+
# reward = -0.2
|
| 310 |
+
# reward += 0.2
|
| 311 |
+
# if _safe_results_equal(pred_res, gold_res):
|
| 312 |
+
# return 1.0, timings
|
| 313 |
+
# return max(-1.0, min(1.0, reward)), timings
|
| 314 |
+
|
| 315 |
+
# def execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
|
| 316 |
+
# try:
|
| 317 |
+
# sql = _normalize_sql(pred_sql)
|
| 318 |
+
# gold = _normalize_sql(gold_sql)
|
| 319 |
+
|
| 320 |
+
# if not is_valid_select(sql):
|
| 321 |
+
# return -1.0
|
| 322 |
+
|
| 323 |
+
# reward = -0.2
|
| 324 |
+
|
| 325 |
+
# pred_tables = set(extract_tables(sql))
|
| 326 |
+
# gold_tables = set(extract_tables(gold))
|
| 327 |
+
|
| 328 |
+
# if pred_tables == gold_tables and len(gold_tables) > 0:
|
| 329 |
+
# reward += 0.3
|
| 330 |
+
|
| 331 |
+
# pred_cols = set(extract_columns(sql))
|
| 332 |
+
# gold_cols = set(extract_columns(gold))
|
| 333 |
+
|
| 334 |
+
# if gold_cols:
|
| 335 |
+
# overlap = len(pred_cols & gold_cols) / len(gold_cols)
|
| 336 |
+
# reward += 0.3 * overlap
|
| 337 |
+
|
| 338 |
+
# pred_res = execute_sql_cached(db_path, sql)
|
| 339 |
+
# if pred_res == EXECUTION_ERROR:
|
| 340 |
+
# return 0.0
|
| 341 |
+
# reward += 0.2
|
| 342 |
+
|
| 343 |
+
# gold_res = execute_sql_cached(db_path, gold)
|
| 344 |
+
# if gold_res == EXECUTION_ERROR:
|
| 345 |
+
# return 0.0
|
| 346 |
+
# if _safe_results_equal(pred_res, gold_res):
|
| 347 |
+
# return 1.0
|
| 348 |
+
|
| 349 |
+
# return max(-1.0, min(1.0, reward))
|
| 350 |
+
|
| 351 |
+
# except Exception:
|
| 352 |
+
# return 0.0
|
| 353 |
+
|
| 354 |
+
# def cached_execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
|
| 355 |
+
# if not USE_CACHE:
|
| 356 |
+
# return execution_reward(pred_sql, db_path, gold_sql)
|
| 357 |
+
|
| 358 |
+
# key = f"{db_path}|{pred_sql}|{gold_sql}"
|
| 359 |
+
# if key not in _REWARD_CACHE:
|
| 360 |
+
# _REWARD_CACHE[key] = execution_reward(pred_sql, db_path, gold_sql)
|
| 361 |
+
# return _REWARD_CACHE[key]
|
| 362 |
+
|
| 363 |
+
# def execution_reward_batch_sequential(rollouts: Sequence[Tuple[str, str, str]]) -> List[float]:
|
| 364 |
+
# return [cached_execution_reward(pred_sql, db_path, gold_sql) for pred_sql, db_path, gold_sql in rollouts]
|
| 365 |
+
|
| 366 |
+
# def execution_reward_batch_parallel(rollouts: Sequence[Tuple[str, str, str]], *, max_workers: int = 20) -> List[float]:
|
| 367 |
+
# if not rollouts:
|
| 368 |
+
# return []
|
| 369 |
+
|
| 370 |
+
# unique_dbs = {db_path for _, db_path, _ in rollouts}
|
| 371 |
+
# worker_count = max(1, min(max_workers, len(unique_dbs)))
|
| 372 |
+
# results: List[Optional[float]] = [None] * len(rollouts)
|
| 373 |
+
|
| 374 |
+
# with ThreadPoolExecutor(max_workers=worker_count) as executor:
|
| 375 |
+
# futures = {
|
| 376 |
+
# executor.submit(cached_execution_reward, pred_sql, db_path, gold_sql): i
|
| 377 |
+
# for i, (pred_sql, db_path, gold_sql) in enumerate(rollouts)
|
| 378 |
+
# }
|
| 379 |
+
# for fut in as_completed(futures):
|
| 380 |
+
# idx = futures[fut]
|
| 381 |
+
# try:
|
| 382 |
+
# results[idx] = float(fut.result())
|
| 383 |
+
# except Exception:
|
| 384 |
+
# results[idx] = 0.0
|
| 385 |
+
|
| 386 |
+
# return [r if r is not None else 0.0 for r in results]
|
| 387 |
+
|
| 388 |
+
from __future__ import annotations
|
| 389 |
+
|
| 390 |
+
import os
|
| 391 |
+
import re
|
| 392 |
+
import sqlite3
|
| 393 |
+
import threading
|
| 394 |
+
import time
|
| 395 |
+
import json
|
| 396 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 397 |
+
from dataclasses import dataclass
|
| 398 |
+
from typing import Dict, List
|
| 399 |
+
|
| 400 |
+
from src.sql_validator import validate_sql_schema
|
| 401 |
+
|
| 402 |
+
# =========================================================
|
| 403 |
+
# 🔥 CONFIG FLAGS
|
| 404 |
+
# =========================================================
|
| 405 |
+
# Toggle schema-level validation of generated SQL (see set_use_schema_validation).
USE_SCHEMA_VALIDATION = True
# Toggle the in-memory result/reward caches; benchmarks flip this off.
USE_CACHE = True
# Per-query wall-clock budget in seconds, enforced via SQLite's progress handler.
DEFAULT_QUERY_TIMEOUT_S = 2.0

# Sentinel string returned by execute_sql() when a query fails for any reason.
EXECUTION_ERROR = "EXECUTION_ERROR"

# Reward memo; cleared by clear_result_cache(). Populated by reward code
# elsewhere in this module — TODO confirm against the full file.
_REWARD_CACHE: Dict[str, float] = {}

# =========================================================
# 🔥 TASK 2: ERROR ANALYSIS + LOGGING
# =========================================================
# JSON file (a single array) where classified errors are appended by log_error().
ERROR_LOG_FILE = "results/error_logs.json"
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def classify_error(sql: str) -> str:
    """Heuristically bucket a SQL string into a coarse failure category.

    Purely lexical checks on the lowercased text; categories are checked in
    a fixed priority order and the first match wins.
    """
    text = sql.lower()

    # A JOIN with no ON clause is almost certainly missing its join condition.
    if "join" in text and " on " not in text:
        return "missing_join"

    # A WHERE clause carrying no comparison operator filters nothing useful.
    has_comparison = any(op in text for op in ("=", ">", "<"))
    if "where" in text and not has_comparison:
        return "wrong_where"

    if "null" in text:
        return "null_handling"

    # GROUP BY without an aggregate (count) usually signals a malformed grouping.
    if "group by" in text and "count" not in text:
        return "wrong_groupby"

    return "other"
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def get_hint(error_type: str) -> str:
    """Return a short fix-it suggestion for a classify_error() category.

    Falls back to a generic "Check query." message for unknown categories.
    """
    suggestion_pairs = (
        ("missing_join", "Add proper JOIN condition using ON."),
        ("wrong_where", "Check WHERE clause conditions."),
        ("null_handling", "Handle NULL values using IS NULL."),
        ("wrong_groupby", "Use aggregation functions with GROUP BY."),
        ("other", "Check SQL syntax and logic."),
    )
    return dict(suggestion_pairs).get(error_type, "Check query.")
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def log_error(question: str, sql: str, error: str, error_type: str):
    """Append one classified generation error to the JSON log at ERROR_LOG_FILE.

    Each entry records the natural-language question, the generated SQL, the
    raw error text, the classified error type, and a Unix timestamp. The log
    is a single JSON array that is read, extended, and rewritten per call.

    Robustness fix: a corrupt or truncated log file previously made
    ``json.load`` raise and crash the caller mid-run; now the log is simply
    restarted instead.
    """
    os.makedirs("results", exist_ok=True)

    entry = {
        "question": question,
        "sql": sql,
        "error": error,
        "error_type": error_type,
        "timestamp": time.time()
    }

    logs = []
    if os.path.exists(ERROR_LOG_FILE):
        try:
            with open(ERROR_LOG_FILE, "r") as f:
                logs = json.load(f)
            # Anything other than a list cannot be appended to; start over.
            if not isinstance(logs, list):
                logs = []
        except (json.JSONDecodeError, OSError):
            # Unreadable/garbled log — drop it rather than crash the run.
            logs = []

    logs.append(entry)

    with open(ERROR_LOG_FILE, "w") as f:
        json.dump(logs, f, indent=2)
|
| 469 |
+
|
| 470 |
+
# =========================================================
|
| 471 |
+
# CACHE/VALIDATION TOGGLES (Task 1)
|
| 472 |
+
# =========================================================
|
| 473 |
+
def set_use_cache(enabled: bool) -> None:
    """Globally enable or disable result caching (value is coerced to bool).

    Used by benchmark scripts to compare cached vs. uncached execution.
    """
    global USE_CACHE
    USE_CACHE = True if enabled else False
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
def set_use_schema_validation(enabled: bool) -> None:
    """Globally enable or disable schema validation (value is coerced to bool)."""
    global USE_SCHEMA_VALIDATION
    USE_SCHEMA_VALIDATION = True if enabled else False
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
# =========================================================
|
| 484 |
+
# SQL CLEANING
|
| 485 |
+
# =========================================================
|
| 486 |
+
def _normalize_sql(sql: str) -> str:
|
| 487 |
+
if not isinstance(sql, str):
|
| 488 |
+
return ""
|
| 489 |
+
s = sql.strip()
|
| 490 |
+
|
| 491 |
+
if s.startswith("```"):
|
| 492 |
+
s = re.sub(r"^```[a-zA-Z0-9_+-]*\n?", "", s).strip()
|
| 493 |
+
s = re.sub(r"\n?```$", "", s).strip()
|
| 494 |
+
|
| 495 |
+
if s.lower().startswith("sql:"):
|
| 496 |
+
s = s[4:].strip()
|
| 497 |
+
|
| 498 |
+
if ";" in s:
|
| 499 |
+
s = s.split(";", 1)[0].strip()
|
| 500 |
+
|
| 501 |
+
return s
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
# =========================================================
|
| 505 |
+
# DB EXECUTION
|
| 506 |
+
# =========================================================
|
| 507 |
+
def _connect_readonly(db_path: str):
|
| 508 |
+
uri = f"file:{os.path.abspath(db_path)}?mode=ro"
|
| 509 |
+
conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
|
| 510 |
+
conn.execute("PRAGMA query_only = ON;")
|
| 511 |
+
conn.execute("PRAGMA foreign_keys = ON;")
|
| 512 |
+
return conn
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
def _with_timeout(conn: sqlite3.Connection, timeout_s: float = DEFAULT_QUERY_TIMEOUT_S):
    """Install a progress handler that aborts statements running past timeout_s."""
    started = time.monotonic()

    def _check():
        # A non-zero return tells SQLite to interrupt the running statement.
        elapsed = time.monotonic() - started
        return 1 if elapsed > timeout_s else 0

    # The handler is invoked roughly every 10k VM opcodes.
    conn.set_progress_handler(_check, 10_000)
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
def execute_sql(conn, sql):
    """Execute sql on conn under the default timeout.

    Returns the fully fetched row list on success, or the EXECUTION_ERROR
    sentinel on any failure (bad SQL, missing table, timeout interrupt, ...).
    """
    try:
        _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
        rows = conn.execute(sql).fetchall()
    except Exception:
        # Best-effort contract: callers test for the sentinel, never for raises.
        return EXECUTION_ERROR
    return rows
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
# Query-result memo keyed by "db_path|sql"; guarded by _RESULT_LOCK so the
# parallel reward workers can share it safely.
_RESULT_CACHE = {}
_RESULT_LOCK = threading.Lock()
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
def execute_sql_cached(db_path, sql):
    """Execute sql against db_path on a fresh read-only connection.

    When USE_CACHE is on, results are memoized per (db_path, sql) pair so
    repeated rollouts against the same database skip the round trip.
    """
    cache_key = f"{db_path}|{sql}"

    if USE_CACHE:
        with _RESULT_LOCK:
            try:
                return _RESULT_CACHE[cache_key]
            except KeyError:
                pass

    connection = _connect_readonly(db_path)
    outcome = execute_sql(connection, sql)
    connection.close()

    if USE_CACHE:
        with _RESULT_LOCK:
            _RESULT_CACHE[cache_key] = outcome

    return outcome
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
def execute_sql_cached_conn(conn: sqlite3.Connection, db_path: str, sql: str):
    """
    Like execute_sql_cached(), but reuses an existing connection.
    Intended for 1-thread-per-DB workloads (Task 1).
    """
    cache_key = f"{db_path}|{sql}"

    if USE_CACHE:
        with _RESULT_LOCK:
            hit = cache_key in _RESULT_CACHE
            if hit:
                return _RESULT_CACHE[cache_key]

    rows = execute_sql(conn, sql)

    if USE_CACHE:
        with _RESULT_LOCK:
            _RESULT_CACHE[cache_key] = rows

    return rows
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
def clear_result_cache() -> None:
    """Empty both the raw-result cache and the reward cache.

    Intended for use between evaluation runs so stale results cannot leak
    across experiments.
    """
    # FIX: dropped the unnecessary `global` declaration -- both dicts are
    # mutated in place (.clear()), never rebound, so `global` had no effect.
    with _RESULT_LOCK:
        _RESULT_CACHE.clear()
        # _REWARD_CACHE shares _RESULT_LOCK here; its readers elsewhere
        # should use the same lock for full consistency.
        _REWARD_CACHE.clear()
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
# =========================================================
|
| 584 |
+
# SQL PARSING
|
| 585 |
+
# =========================================================
|
| 586 |
+
def is_valid_select(sql):
    """Return True when *sql* reads like a query: starts with SELECT or a WITH-CTE."""
    return sql.lower().startswith(("select", "with"))
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
def extract_tables(sql):
    """Return every identifier that directly follows a FROM keyword (heuristic)."""
    lowered = sql.lower()
    return [m.group(1) for m in re.finditer(r'from\s+(\w+)', lowered)]
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
def extract_columns(sql):
    """Return the column expressions of the SELECT clause, or [] if none found.

    A bare '*' is returned as the single-element list ["*"].
    """
    found = re.search(r'select\s+(.*?)\s+from', sql.lower())
    if found is None:
        return []
    clause = found.group(1)
    if clause.strip() == "*":
        return ["*"]
    return [part.strip() for part in clause.split(",")]
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
def get_sql_operations(sql: str):
    """List the coarse SQL operations mentioned anywhere in *sql* (substring test)."""
    lowered = sql.lower()
    keyword_to_op = (
        ("select", "SELECT"),
        ("where", "WHERE"),
        ("join", "JOIN"),
        ("group by", "GROUP_BY"),
        ("order by", "ORDER_BY"),
    )
    return [op for keyword, op in keyword_to_op if keyword in lowered]
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
def _explain_query_plan(conn: sqlite3.Connection, sql: str) -> bool:
    """Return True iff SQLite can plan *sql* -- a cheap validity probe that
    does not actually run the query."""
    try:
        _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
        conn.execute(f"EXPLAIN QUERY PLAN {sql}")
    except Exception:
        return False
    return True
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
def execution_reward_timed(pred_sql: str, db_path: str, gold_sql: str, measure_plan: bool = False):
    """
    Returns (reward, timings) where timings keys: parse_s, plan_s, exec_s.
    Used by Task-1 benchmark to profile bottlenecks.
    """
    timings = {"parse_s": 0.0, "plan_s": 0.0, "exec_s": 0.0}
    t0 = time.perf_counter()

    sql = _normalize_sql(pred_sql)
    gold = _normalize_sql(gold_sql)

    # Non-SELECT/WITH predictions are rejected before touching the DB;
    # only parse time is charged in that case.
    if not is_valid_select(sql):
        timings["parse_s"] = time.perf_counter() - t0
        return 0.0, timings

    t1 = time.perf_counter()
    timings["parse_s"] = t1 - t0

    conn = _connect_readonly(db_path)
    try:
        # Optional: time EXPLAIN QUERY PLAN for both queries separately
        # from actual execution, to attribute planning cost.
        if measure_plan:
            p0 = time.perf_counter()
            _explain_query_plan(conn, sql)
            _explain_query_plan(conn, gold)
            timings["plan_s"] = time.perf_counter() - p0

        e0 = time.perf_counter()
        pred_res = execute_sql_cached_conn(conn, db_path, sql)
        if pred_res == EXECUTION_ERROR:
            timings["exec_s"] = time.perf_counter() - e0
            return 0.0, timings
        gold_res = execute_sql_cached_conn(conn, db_path, gold)
        timings["exec_s"] = time.perf_counter() - e0
        if gold_res == EXECUTION_ERROR:
            return 0.0, timings

        # Mirrors execution_reward()'s staging: -0.2 base plus +0.2 for a
        # successful execution -- net 0.0 when results don't match.
        reward = -0.2 + 0.2
        if pred_res == gold_res:
            return 1.0, timings
        return max(-1.0, min(1.0, reward)), timings
    finally:
        # Always release the connection, even if reward logic raised.
        try:
            conn.close()
        except Exception:
            pass
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
# =========================================================
|
| 672 |
+
# 🔥 FINAL REWARD FUNCTION (TASK 2 INTEGRATED)
|
| 673 |
+
# =========================================================
|
| 674 |
+
def execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
    """
    Score a predicted SQL string against the gold query by executing both.

    Reward scale:
      -1.0  prediction is not a SELECT/WITH query
       0.1  schema-invalid, or either query fails to execute
       0.0  prediction executes but result sets differ (-0.2 base + 0.2 exec)
       1.0  result sets match exactly
       0.0  unexpected runtime error (logged as "runtime_error")
    """
    try:
        sql = _normalize_sql(pred_sql)
        gold = _normalize_sql(gold_sql)

        if not is_valid_select(sql):
            return -1.0

        reward = -0.2

        # =========================
        # SCHEMA VALIDATION (Task 3)
        # =========================
        if USE_SCHEMA_VALIDATION:
            valid, _ = validate_sql_schema(sql, db_path)
            if not valid:
                error_type = classify_error(sql)
                log_error("UNKNOWN", sql, "schema_invalid", error_type)
                return 0.1

        # =========================
        # EXECUTION
        # =========================
        pred_res = execute_sql_cached(db_path, sql)

        # FIX: compare against the EXECUTION_ERROR sentinel, as every other
        # reward path does (execute_sql, execution_reward_timed,
        # _reward_with_conn) -- the previous hard-coded string literal would
        # silently miss the error branch if the sentinel's value differs.
        if pred_res == EXECUTION_ERROR:
            error_type = classify_error(sql)

            log_error(
                question="UNKNOWN",
                sql=sql,
                error="execution_error",
                error_type=error_type
            )

            print(f"[ERROR] {error_type}")
            print(f"[HINT] {get_hint(error_type)}")

            return 0.1

        reward += 0.2  # executed successfully

        gold_res = execute_sql_cached(db_path, gold)

        if gold_res == EXECUTION_ERROR:
            # Gold query itself is broken; neutral-ish reward, don't punish.
            return 0.1

        if pred_res == gold_res:
            return 1.0

        return max(-1.0, min(1.0, reward))

    except Exception as e:
        log_error("UNKNOWN", pred_sql, str(e), "runtime_error")
        return 0.0
|
| 729 |
+
|
| 730 |
+
|
| 731 |
+
# =========================================================
|
| 732 |
+
# BATCH EXECUTION (Task 1)
|
| 733 |
+
# =========================================================
|
| 734 |
+
def cached_execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
    """execution_reward() with memoization keyed on (db_path, pred, gold).

    FIX: guard _REWARD_CACHE with _RESULT_LOCK. This function is submitted
    concurrently by the thread-pool batch helpers, yet the cache was read
    and written without any lock while _RESULT_CACHE access is consistently
    locked elsewhere in this module.
    """
    if not USE_CACHE:
        return float(execution_reward(pred_sql, db_path, gold_sql))
    key = f"{db_path}|{pred_sql}|{gold_sql}"
    with _RESULT_LOCK:
        if key in _REWARD_CACHE:
            return float(_REWARD_CACHE[key])
    # Compute outside the lock so slow query executions don't serialize
    # the worker threads; a duplicate computation is harmless.
    r = float(execution_reward(pred_sql, db_path, gold_sql))
    with _RESULT_LOCK:
        _REWARD_CACHE[key] = r
    return r
|
| 743 |
+
|
| 744 |
+
|
| 745 |
+
def execution_reward_batch_sequential(rollouts):
    """Score (pred_sql, db_path, gold_sql) triples one at a time, preserving order."""
    rewards = []
    for pred_sql, db_path, gold_sql in rollouts:
        rewards.append(cached_execution_reward(pred_sql, db_path, gold_sql))
    return rewards
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
def execution_reward_batch_parallel(rollouts, max_workers=10):
    """Score rollout triples concurrently with a thread pool.

    Preserves input order; any task that raises contributes 0.0 instead of
    propagating the exception.
    """
    rewards = [0.0] * len(rollouts)

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = {}
        for position, (pred_sql, db_path, gold_sql) in enumerate(rollouts):
            future = pool.submit(cached_execution_reward, pred_sql, db_path, gold_sql)
            pending[future] = position

        for done in as_completed(pending):
            position = pending[done]
            try:
                rewards[position] = done.result()
            except Exception:
                rewards[position] = 0.0

    return rewards
|
| 766 |
+
|
| 767 |
+
|
| 768 |
+
def execution_reward_batch_parallel_by_db(rollouts, max_workers: int = 20):
    """
    1 thread per DB path. Reuses a single readonly connection per DB worker.
    Preserves input order.
    """
    if not rollouts:
        return []

    # Group rollouts by database so each worker thread owns exactly one
    # connection; original index is kept so order can be restored.
    by_db = {}
    for idx, (pred_sql, db_path, gold_sql) in enumerate(rollouts):
        by_db.setdefault(db_path, []).append((idx, pred_sql, gold_sql))

    # Shared output slots; each index is written by exactly one worker,
    # so no lock is needed here.
    results = [0.0 for _ in range(len(rollouts))]

    def _reward_with_conn(conn: sqlite3.Connection, pred_sql: str, db_path: str, gold_sql: str) -> float:
        # Connection-reusing clone of execution_reward(); any unexpected
        # failure yields a neutral 0.0 rather than killing the worker.
        try:
            sql = _normalize_sql(pred_sql)
            gold = _normalize_sql(gold_sql)

            if not is_valid_select(sql):
                return -1.0

            reward = -0.2

            if USE_SCHEMA_VALIDATION:
                valid, _ = validate_sql_schema(sql, db_path)
                if not valid:
                    error_type = classify_error(sql)
                    log_error("UNKNOWN", sql, "schema_invalid", error_type)
                    return 0.1

            pred_res = execute_sql_cached_conn(conn, db_path, sql)
            if pred_res == EXECUTION_ERROR:
                error_type = classify_error(sql)
                log_error("UNKNOWN", sql, "execution_error", error_type)
                return 0.1

            reward += 0.2  # prediction executed successfully
            gold_res = execute_sql_cached_conn(conn, db_path, gold)
            if gold_res == EXECUTION_ERROR:
                return 0.1
            if pred_res == gold_res:
                return 1.0
            return max(-1.0, min(1.0, reward))
        except Exception:
            return 0.0

    def _worker(db_path: str, items):
        # One readonly connection per DB, shared across that DB's items.
        conn = _connect_readonly(db_path)
        try:
            for idx, pred, gold in items:
                results[idx] = _reward_with_conn(conn, pred, db_path, gold)
        finally:
            try:
                conn.close()
            except Exception:
                pass

    with ThreadPoolExecutor(max_workers=int(max_workers)) as ex:
        futures = [ex.submit(_worker, db_path, items) for db_path, items in by_db.items()]
        for fut in as_completed(futures):
            # Re-raise any worker-level failure (connection open, etc.).
            fut.result()

    return results
|