# text2sql-demo / src / eval_rl_t5.py
# tjhalanigrid's picture
# Add src folder
# dc59b01
# import sys
# import os
# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# import json
# import subprocess
# import argparse
# from pathlib import Path
# import torch
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# from peft import PeftModel
# # IMPORTANT: must match training prompt format
# from prompting import build_prompt
# from schema_utils import get_schema as get_db_schema
# def _parse_exec_accuracy(stdout: str):
# for line in stdout.splitlines():
# if line.strip().startswith("execution"):
# parts = line.split()
# try:
# return float(parts[-1])
# except Exception:
# return None
# return None
# def main():
# parser = argparse.ArgumentParser()
# parser.add_argument("--adapter", type=str, default="checkpoints/best_rlhf_model")
# parser.add_argument("--num_samples", type=int, default=200)
# args = parser.parse_args()
# project_root = Path(__file__).resolve().parents[1]
# adapter_dir = project_root / args.adapter
# if not adapter_dir.exists():
# raise FileNotFoundError(f"Adapter not found: {adapter_dir}")
# db_root = project_root / "data" / "database"
# table_json = project_root / "data" / "tables.json"
# dev_json = project_root / "data" / "dev.json"
# gold_sql = project_root / "data" / "dev_gold.sql"
# pred_path = project_root / "predictions_rl.txt"
# device = "mps" if torch.backends.mps.is_available() else "cpu"
# # ---- LOAD MODEL (CodeT5 + LoRA) ----
# base_model = "Salesforce/codet5-base"
# tokenizer = AutoTokenizer.from_pretrained(str(adapter_dir))
# base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
# model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
# # merge LoRA for faster inference
# model = model.merge_and_unload()
# model.eval()
# model.config.use_cache = True
# if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
# tokenizer.pad_token = tokenizer.eos_token
# # ---- LOAD DATA ----
# with dev_json.open() as f:
# dev = json.load(f)
# dev = dev[: args.num_samples]
# gen_kwargs = dict(
# max_new_tokens=120,
# do_sample=False,
# num_beams=1,
# pad_token_id=tokenizer.pad_token_id,
# eos_token_id=tokenizer.eos_token_id,
# )
# print(f"Generating {len(dev)} predictions...")
# with pred_path.open("w") as out_f, torch.no_grad():
# for i, ex in enumerate(dev, start=1):
# db_id = ex["db_id"]
# question = ex["question"]
# db_path = db_root / db_id / f"{db_id}.sqlite"
# schema = get_db_schema(str(db_path))
# prompt = build_prompt(question, schema, use_schema=True)
# inputs = tokenizer(
# prompt,
# return_tensors="pt",
# truncation=True,
# max_length=512
# ).to(device)
# out = model.generate(**inputs, **gen_kwargs)
# pred_sql = tokenizer.decode(out[0], skip_special_tokens=True).strip()
# out_f.write(f"{pred_sql}\t{db_id}\n")
# if i % 20 == 0 or i == len(dev):
# print(f"{i}/{len(dev)} done")
# # ---- SPIDER OFFICIAL EVAL ----
# eval_script = project_root / "spider_eval" / "evaluation.py"
# cmd = [
# sys.executable,
# str(eval_script),
# "--gold",
# str(gold_sql),
# "--pred",
# str(pred_path),
# "--etype",
# "exec",
# "--db",
# str(db_root),
# "--table",
# str(table_json),
# ]
# print("\nRunning Spider execution evaluation...\n")
# proc = subprocess.run(cmd, capture_output=True, text=True)
# if proc.returncode != 0:
# print(proc.stdout)
# print(proc.stderr)
# sys.exit(proc.returncode)
# print(proc.stdout)
# acc = _parse_exec_accuracy(proc.stdout)
# if acc is not None:
# print(f"\nFINAL EXECUTION ACCURACY: {acc*100:.2f}%")
# else:
# print("Could not parse execution accuracy")
# if __name__ == "__main__":
# main()
import json
import sqlite3
import argparse
import time
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
# ---------------- PROMPT (FIXED TO PERFECTLY MATCH RLHF TRAINING) ----------------
def build_prompt(question, schema):
    """Assemble the text-to-SQL prompt for one example.

    The prompt is an instruction line, the schema and the question each under
    their own header, and a trailing ``SQL:`` header left open for the model
    to complete.

    Args:
        question: Natural-language question to translate.
        schema: Textual schema description of the target database.

    Returns:
        The fully formatted prompt string.
    """
    template = "translate English to SQL:\n\nSchema:\n{schema}\n\nQuestion:\n{question}\n\nSQL:"
    return template.format(schema=schema, question=question)
# ---------------- LOAD SCHEMA (FIXED TO MATCH TRAINING FORMAT) ----------------
def load_schema(db_path):
    """Return a one-line textual schema for the SQLite database at *db_path*.

    Every table listed in ``sqlite_master`` is rendered as
    ``table(col1, col2, ...)`` and the entries are joined by single spaces
    (space-separated rather than newline-separated).

    Args:
        db_path: Location of the SQLite file, in any form accepted by
            ``sqlite3.connect``.

    Returns:
        The schema string; empty when the database contains no tables.

    Raises:
        sqlite3.Error: If the database cannot be opened or queried.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        tables = cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()
        entries = []
        for (table,) in tables:
            # Quote the identifier (doubling embedded quotes) so tables named
            # after SQL keywords or containing spaces do not break the PRAGMA.
            quoted = '"' + table.replace('"', '""') + '"'
            cols = cursor.execute(f"PRAGMA table_info({quoted});").fetchall()
            # Index 1 of each table_info row is the column name.
            col_names = [c[1] for c in cols]
            entries.append(f"{table}({', '.join(col_names)})")
    finally:
        # Always release the handle, even when a query above fails.
        conn.close()
    return " ".join(entries).strip()
# ---------------- EXECUTION CHECK WITH TIMEOUT ----------------
def execution_match(pred_sql, gold_sql, db_path, timeout=5.0):
    """Report whether two SQL queries return identical rows on a database.

    Both queries are executed against the SQLite file at *db_path* and their
    fetched row lists are compared with ``==``, so row order matters.

    Args:
        pred_sql: Predicted SQL query.
        gold_sql: Reference SQL query.
        db_path: Location of the SQLite file, in any form accepted by
            ``sqlite3.connect``.
        timeout: Wall-clock budget in seconds granted to each query
            individually before it is interrupted.

    Returns:
        True when both queries succeed and yield equal row lists. False when
        the results differ or when anything fails (invalid SQL, timeout,
        unreadable database).
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)

        # Moved forward before each query, so a slow prediction cannot use up
        # the time available to the gold query.
        deadline = 0.0

        def timeout_handler():
            # A non-zero return makes SQLite abort the running statement.
            return 1 if time.monotonic() > deadline else 0

        # The handler is polled every 10000 SQLite VM instructions.
        conn.set_progress_handler(timeout_handler, 10000)
        cur = conn.cursor()

        deadline = time.monotonic() + timeout
        cur.execute(pred_sql)
        pred = cur.fetchall()

        deadline = time.monotonic() + timeout
        cur.execute(gold_sql)
        gold = cur.fetchall()

        return pred == gold
    except Exception:
        # Deliberate best-effort: any failure counts as a non-match so a
        # single bad query cannot abort a whole evaluation run.
        return False
    finally:
        # Close on every path; failed predictions are common and would
        # otherwise leak one connection each.
        if conn is not None:
            conn.close()
# ---------------- MAIN ----------------
def main():
    """Evaluate a LoRA-adapted T5 model on dev examples by execution accuracy.

    Command-line options:
        --adapter: Adapter directory, resolved against the project root
            (the parent of this file's directory).
        --num_samples: Number of leading examples from ``data/dev.json``
            to evaluate.

    For each example the database schema is loaded, a prompt is built, SQL is
    generated with beam search, and the prediction is scored by comparing its
    execution result with that of the gold query. A running accuracy is
    printed every 10 examples, followed by the final accuracy.

    Raises:
        FileNotFoundError: If the adapter directory does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, default="checkpoints/rlhf_t5_best")
    parser.add_argument("--num_samples", type=int, default=1000)
    args = parser.parse_args()

    project_root = Path(__file__).resolve().parents[1]
    adapter_path = project_root / args.adapter
    # Fail early with a clear message instead of an opaque loader error.
    if not adapter_path.exists():
        raise FileNotFoundError(f"Adapter not found: {adapter_path}")

    dev_json = project_root / "data" / "dev.json"
    db_root = project_root / "data" / "database"

    # Prefer Apple MPS, then CUDA, then fall back to CPU.
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # Load the base model, attach the adapter, then merge the adapter weights
    # into the base model.
    base_model = "t5-small"
    print(f"Loading Base: {base_model}")
    print(f"Loading Adapter: {adapter_path}")
    tokenizer = AutoTokenizer.from_pretrained(str(adapter_path))
    base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
    model = PeftModel.from_pretrained(base, str(adapter_path)).to(device)
    model = model.merge_and_unload()
    # Make inference mode explicit (disables dropout).
    model.eval()

    with open(dev_json, encoding="utf-8") as f:
        dev = json.load(f)[: args.num_samples]

    correct = 0
    # Examples share databases, so read each schema once.
    schema_cache = {}
    print(f"Evaluating {len(dev)} examples...\n")
    for i, ex in enumerate(dev, 1):
        question = ex["question"]
        db_id = ex["db_id"]
        gold_sql = ex["query"]
        db_path = db_root / db_id / f"{db_id}.sqlite"

        schema = schema_cache.get(db_id)
        if schema is None:
            schema = schema_cache[db_id] = load_schema(db_path)

        prompt = build_prompt(question, schema)
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=80,
                do_sample=False,
                num_beams=4,
            )
        pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # If the output repeats the prompt header, keep only what follows the
        # last "SQL:" marker.
        if "SQL:" in pred_sql:
            pred_sql = pred_sql.split("SQL:")[-1].strip()

        if execution_match(pred_sql, gold_sql, db_path):
            correct += 1
        if i % 10 == 0:
            print(f"{i}/{len(dev)} | Acc: {correct/i:.3f}")

    # Avoid ZeroDivisionError when no examples were selected.
    accuracy = correct / len(dev) if dev else 0.0
    print("\n=============================")
    print(f"FINAL EXECUTION ACCURACY: {accuracy*100:.2f}%")
    print("=============================")
# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()