# import sys
# import os
# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# import json
# import subprocess
# import argparse
# from pathlib import Path
#
# import torch
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# from peft import PeftModel
#
# # IMPORTANT: must match training prompt format
# from prompting import build_prompt
# from schema_utils import get_schema as get_db_schema
#
#
# def _parse_exec_accuracy(stdout: str):
#     for line in stdout.splitlines():
#         if line.strip().startswith("execution"):
#             parts = line.split()
#             try:
#                 return float(parts[-1])
#             except Exception:
#                 return None
#     return None
#
#
# def main():
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--adapter", type=str, default="checkpoints/best_rlhf_model")
#     parser.add_argument("--num_samples", type=int, default=200)
#     args = parser.parse_args()
#
#     project_root = Path(__file__).resolve().parents[1]
#     adapter_dir = project_root / args.adapter
#     if not adapter_dir.exists():
#         raise FileNotFoundError(f"Adapter not found: {adapter_dir}")
#
#     db_root = project_root / "data" / "database"
#     table_json = project_root / "data" / "tables.json"
#     dev_json = project_root / "data" / "dev.json"
#     gold_sql = project_root / "data" / "dev_gold.sql"
#     pred_path = project_root / "predictions_rl.txt"
#
#     device = "mps" if torch.backends.mps.is_available() else "cpu"
#
#     # ---- LOAD MODEL (CodeT5 + LoRA) ----
#     base_model = "Salesforce/codet5-base"
#     tokenizer = AutoTokenizer.from_pretrained(str(adapter_dir))
#     base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
#     model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
#
#     # merge LoRA for faster inference
#     model = model.merge_and_unload()
#     model.eval()
#     model.config.use_cache = True
#
#     if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
#         tokenizer.pad_token = tokenizer.eos_token
#
#     # ---- LOAD DATA ----
#     with dev_json.open() as f:
#         dev = json.load(f)
#     dev = dev[: args.num_samples]
#
#     gen_kwargs = dict(
#         max_new_tokens=120,
#         do_sample=False,
#         num_beams=1,
#         pad_token_id=tokenizer.pad_token_id,
#         eos_token_id=tokenizer.eos_token_id,
#     )
#
#     print(f"Generating {len(dev)} predictions...")
#     with pred_path.open("w") as out_f, torch.no_grad():
#         for i, ex in enumerate(dev, start=1):
#             db_id = ex["db_id"]
#             question = ex["question"]
#             db_path = db_root / db_id / f"{db_id}.sqlite"
#
#             schema = get_db_schema(str(db_path))
#             prompt = build_prompt(question, schema, use_schema=True)
#
#             inputs = tokenizer(
#                 prompt,
#                 return_tensors="pt",
#                 truncation=True,
#                 max_length=512,
#             ).to(device)
#             out = model.generate(**inputs, **gen_kwargs)
#             pred_sql = tokenizer.decode(out[0], skip_special_tokens=True).strip()
#             out_f.write(f"{pred_sql}\t{db_id}\n")
#
#             if i % 20 == 0 or i == len(dev):
#                 print(f"{i}/{len(dev)} done")
#
#     # ---- SPIDER OFFICIAL EVAL ----
#     eval_script = project_root / "spider_eval" / "evaluation.py"
#     cmd = [
#         sys.executable,
#         str(eval_script),
#         "--gold", str(gold_sql),
#         "--pred", str(pred_path),
#         "--etype", "exec",
#         "--db", str(db_root),
#         "--table", str(table_json),
#     ]
#
#     print("\nRunning Spider execution evaluation...\n")
#     proc = subprocess.run(cmd, capture_output=True, text=True)
#     if proc.returncode != 0:
#         print(proc.stdout)
#         print(proc.stderr)
#         sys.exit(proc.returncode)
#
#     print(proc.stdout)
#     acc = _parse_exec_accuracy(proc.stdout)
#     if acc is not None:
#         print(f"\nFINAL EXECUTION ACCURACY: {acc*100:.2f}%")
#     else:
#         print("Could not parse execution accuracy")
#
#
# if __name__ == "__main__":
#     main()
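#
# (The commented-out script above is the earlier CodeT5 + LoRA pipeline that
# wrote predictions to a file and scored them with the official Spider
# evaluation.py. The active script below evaluates a t5-small + LoRA adapter
# in-process with its own execution-match check.)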

import json
import sqlite3
import argparse
import time
from pathlib import Path

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel

# ---------------- PROMPT (FIXED TO PERFECTLY MATCH RLHF TRAINING) ----------------
def build_prompt(question, schema):
    return f"translate English to SQL:\n\nSchema:\n{schema}\n\nQuestion:\n{question}\n\nSQL:"
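
# For illustration: with a hypothetical one-table schema (names made up, not
# taken from Spider), the prompt built above renders as:
#
#   translate English to SQL:
#
#   Schema:
#   singer(singer_id, name, country)
#
#   Question:
#   How many singers are there?
#
#   SQL: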

# ---------------- LOAD SCHEMA (FIXED TO MATCH TRAINING FORMAT) ----------------
def load_schema(db_path):
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    tables = cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table';"
    ).fetchall()
    schema = ""
    for (table,) in tables:
        cols = cursor.execute(f"PRAGMA table_info({table});").fetchall()
        col_names = [c[1] for c in cols]
        # Space-separated, not newline-separated, just like the RLHF script
        schema += f"{table}({', '.join(col_names)}) "
    conn.close()
    return schema.strip()
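
# For illustration (hypothetical tables, not from Spider), load_schema returns
# a single space-separated line such as:
#   "singer(singer_id, name, country) concert(concert_id, venue, year)"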

# ---------------- EXECUTION CHECK WITH TIMEOUT ----------------
def execution_match(pred_sql, gold_sql, db_path):
    try:
        conn = sqlite3.connect(db_path)
    except Exception:
        return False
    try:
        # --- 5-SECOND TIMEOUT SO THE SCRIPT DOESN'T HANG ---
        # SQLite invokes the progress handler every 10,000 VM instructions;
        # returning a non-zero value aborts the running query with an
        # OperationalError, which the except clause below treats as a miss.
        start_time = time.monotonic()

        def timeout_handler():
            return 1 if (time.monotonic() - start_time) > 5.0 else 0

        conn.set_progress_handler(timeout_handler, 10000)
        cur = conn.cursor()
        cur.execute(pred_sql)
        pred = cur.fetchall()
        cur.execute(gold_sql)
        gold = cur.fetchall()
        # Order-sensitive comparison: identical rows returned in a different
        # order count as a mismatch.
        return pred == gold
    except Exception:
        # Invalid SQL, timeout, or any runtime error counts as a miss
        return False
    finally:
        conn.close()
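
# Illustrative usage (hypothetical database path and queries):
#   db = "data/database/concert_singer/concert_singer.sqlite"
#   execution_match("SELECT count(*) FROM singer",
#                   "SELECT count(*) FROM singer", db)  # -> True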

# ---------------- MAIN ----------------
def main():
    parser = argparse.ArgumentParser()
    # 🎯 Set the default directly to your best RLHF model!
    parser.add_argument("--adapter", type=str, default="checkpoints/rlhf_t5_best")
    parser.add_argument("--num_samples", type=int, default=1000)
    args = parser.parse_args()

    project_root = Path(__file__).resolve().parents[1]
    # Resolve adapter path safely, relative to the project root
    adapter_path = project_root / args.adapter
    dev_json = project_root / "data" / "dev.json"
    db_root = project_root / "data" / "database"

    # 🎯 Device priority: Apple MPS, then CUDA, then CPU
    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")

    # load model
    base_model = "t5-small"
    print(f"Loading Base: {base_model}")
    print(f"Loading Adapter: {adapter_path}")
    tokenizer = AutoTokenizer.from_pretrained(str(adapter_path))
    base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
    model = PeftModel.from_pretrained(base, str(adapter_path)).to(device)
    # Merge LoRA weights into the base model for faster inference
    model = model.merge_and_unload()
    model.eval()  # disable dropout for deterministic generation

    with open(dev_json) as f:
        dev = json.load(f)[: args.num_samples]

    correct = 0
    print(f"Evaluating {len(dev)} examples...\n")
    for i, ex in enumerate(dev, 1):
        question = ex["question"]
        db_id = ex["db_id"]
        gold_sql = ex["query"]
        db_path = db_root / db_id / f"{db_id}.sqlite"

        schema = load_schema(db_path)
        prompt = build_prompt(question, schema)
        # Truncate long schemas so the encoder input stays within 512 tokens
        inputs = tokenizer(
            prompt, return_tensors="pt", truncation=True, max_length=512
        ).to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=80,
                do_sample=False,
                num_beams=4,
            )
        pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # The model sometimes echoes the prompt; keep only the part after "SQL:"
        if "SQL:" in pred_sql:
            pred_sql = pred_sql.split("SQL:")[-1].strip()

        if execution_match(pred_sql, gold_sql, db_path):
            correct += 1
        if i % 10 == 0:
            print(f"{i}/{len(dev)} | Acc: {correct/i:.3f}")

    print("\n=============================")
    print(f"FINAL EXECUTION ACCURACY: {correct/len(dev)*100:.2f}%")
    print("=============================")


if __name__ == "__main__":
    main()
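
# Example invocation (the script filename below is illustrative; use whatever
# this file is named in your repo, placed one directory below the project root):
#   python scripts/evaluate_rlhf.py --adapter checkpoints/rlhf_t5_best --num_samples 200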