Spaces:
Running
Running
| # import json | |
| # import sqlite3 | |
| # import argparse | |
| # from pathlib import Path | |
| # import torch | |
| # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| # from peft import PeftModel | |
| # # ---------------- PROMPT (IDENTICAL TO TRAINING) ---------------- | |
| # def build_prompt(question, schema): | |
| # return f""" | |
| # Database Schema: | |
| # {schema} | |
| # Translate English to SQL: | |
| # {question} | |
| # SQL: | |
| # """ | |
| # # ---------------- LOAD SCHEMA ---------------- | |
| # def load_schema(db_path): | |
| # conn = sqlite3.connect(db_path) | |
| # cursor = conn.cursor() | |
| # tables = cursor.execute( | |
| # "SELECT name FROM sqlite_master WHERE type='table';" | |
| # ).fetchall() | |
| # schema = "" | |
| # for (table,) in tables: | |
| # cols = cursor.execute(f"PRAGMA table_info({table});").fetchall() | |
| # col_names = [c[1] for c in cols] | |
| # schema += f"{table}({', '.join(col_names)})\n" | |
| # conn.close() | |
| # return schema | |
| # # ---------------- EXECUTION CHECK ---------------- | |
| # def execution_match(pred_sql, gold_sql, db_path): | |
| # try: | |
| # conn = sqlite3.connect(db_path) | |
| # cur = conn.cursor() | |
| # cur.execute(pred_sql) | |
| # pred = cur.fetchall() | |
| # cur.execute(gold_sql) | |
| # gold = cur.fetchall() | |
| # conn.close() | |
| # return pred == gold | |
| # except Exception: | |
| # return False | |
| # # ---------------- MAIN ---------------- | |
| # def main(): | |
| # parser = argparse.ArgumentParser() | |
| # parser.add_argument("--adapter", type=str, required=True) | |
| # parser.add_argument("--num_samples", type=int, default=1034) | |
| # args = parser.parse_args() | |
| # project_root = Path(__file__).resolve().parents[1] | |
| # dev_json = project_root / "data" / "dev.json" | |
| # db_root = project_root / "data" / "database" | |
| # device = "mps" if torch.backends.mps.is_available() else "cpu" | |
| # # load model | |
| # base_model = "Salesforce/codet5-base" | |
| # tokenizer = AutoTokenizer.from_pretrained(args.adapter) | |
| # base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device) | |
| # model = PeftModel.from_pretrained(base, args.adapter).to(device) | |
| # model = model.merge_and_unload() | |
| # with open(dev_json) as f: | |
| # dev = json.load(f)[: args.num_samples] | |
| # correct = 0 | |
| # print(f"Evaluating {len(dev)} examples...\n") | |
| # for i, ex in enumerate(dev, 1): | |
| # question = ex["question"] | |
| # db_id = ex["db_id"] | |
| # gold_sql = ex["query"] | |
| # db_path = db_root / db_id / f"{db_id}.sqlite" | |
| # schema = load_schema(db_path) | |
| # prompt = build_prompt(question, schema) | |
| # inputs = tokenizer(prompt, return_tensors="pt").to(device) | |
| # with torch.no_grad(): | |
| # outputs = model.generate( | |
| # **inputs, | |
| # max_new_tokens=80, | |
| # do_sample=False, | |
| # num_beams=4, | |
| # ) | |
| # pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| # if "SQL:" in pred_sql: | |
| # pred_sql = pred_sql.split("SQL:")[-1].strip() | |
| # match = execution_match(pred_sql, gold_sql, db_path) | |
| # if match: | |
| # correct += 1 | |
| # if i % 10 == 0: | |
| # print(f"{i}/{len(dev)} | Acc: {correct/i:.3f}") | |
| # print("\n=============================") | |
| # print(f"FINAL EXECUTION ACCURACY: {correct/len(dev)*100:.2f}%") | |
| # print("=============================") | |
| # if __name__ == "__main__": | |
| # main() | |
| # import json | |
| # import sqlite3 | |
| # import argparse | |
| # import time | |
| # from pathlib import Path | |
| # import torch | |
| # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| # from peft import PeftModel | |
| # # ---------------- PROMPT (IDENTICAL TO TRAINING) ---------------- | |
| # def build_prompt(question, schema): | |
| # return f""" | |
| # Database Schema: | |
| # {schema} | |
| # Translate English to SQL: | |
| # {question} | |
| # SQL: | |
| # """ | |
| # # ---------------- LOAD SCHEMA ---------------- | |
| # def load_schema(db_path): | |
| # conn = sqlite3.connect(db_path) | |
| # cursor = conn.cursor() | |
| # tables = cursor.execute( | |
| # "SELECT name FROM sqlite_master WHERE type='table';" | |
| # ).fetchall() | |
| # schema = "" | |
| # for (table,) in tables: | |
| # cols = cursor.execute(f"PRAGMA table_info({table});").fetchall() | |
| # col_names = [c[1] for c in cols] | |
| # schema += f"{table}({', '.join(col_names)})\n" | |
| # conn.close() | |
| # return schema | |
| # # ---------------- EXECUTION CHECK WITH TIMEOUT ---------------- | |
| # def execution_match(pred_sql, gold_sql, db_path): | |
| # try: | |
| # conn = sqlite3.connect(db_path) | |
| # # --- 5-SECOND TIMEOUT SO EVALUATION DOESN'T FREEZE --- | |
| # start_time = time.monotonic() | |
| # def timeout_handler(): | |
| # return 1 if (time.monotonic() - start_time) > 5.0 else 0 | |
| # conn.set_progress_handler(timeout_handler, 10000) | |
| # cur = conn.cursor() | |
| # cur.execute(pred_sql) | |
| # pred = cur.fetchall() | |
| # cur.execute(gold_sql) | |
| # gold = cur.fetchall() | |
| # conn.close() | |
| # return pred == gold | |
| # except Exception: | |
| # return False | |
| # # ---------------- MAIN ---------------- | |
| # def main(): | |
| # parser = argparse.ArgumentParser() | |
| # parser.add_argument("--adapter", type=str, required=True) | |
| # parser.add_argument("--num_samples", type=int, default=1034) | |
| # args = parser.parse_args() | |
| # project_root = Path(__file__).resolve().parents[1] | |
| # dev_json = project_root / "data" / "dev.json" | |
| # db_root = project_root / "data" / "database" | |
| # # 🎯 Added CUDA support for Nvidia GPUs | |
| # device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu") | |
| # # load model | |
| # base_model = "Salesforce/codet5-base" | |
| # print(f"Loading Base: {base_model}") | |
| # print(f"Loading Adapter: {args.adapter}") | |
| # tokenizer = AutoTokenizer.from_pretrained(args.adapter) | |
| # base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device) | |
| # model = PeftModel.from_pretrained(base, args.adapter).to(device) | |
| # model = model.merge_and_unload() | |
| # with open(dev_json) as f: | |
| # dev = json.load(f)[: args.num_samples] | |
| # correct = 0 | |
| # print(f"Evaluating {len(dev)} examples...\n") | |
| # for i, ex in enumerate(dev, 1): | |
| # question = ex["question"] | |
| # db_id = ex["db_id"] | |
| # gold_sql = ex["query"] | |
| # db_path = db_root / db_id / f"{db_id}.sqlite" | |
| # schema = load_schema(db_path) | |
| # prompt = build_prompt(question, schema) | |
| # inputs = tokenizer(prompt, return_tensors="pt").to(device) | |
| # with torch.no_grad(): | |
| # outputs = model.generate( | |
| # **inputs, | |
| # max_new_tokens=80, | |
| # do_sample=False, | |
| # num_beams=4, | |
| # ) | |
| # pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| # if "SQL:" in pred_sql: | |
| # pred_sql = pred_sql.split("SQL:")[-1].strip() | |
| # match = execution_match(pred_sql, gold_sql, db_path) | |
| # if match: | |
| # correct += 1 | |
| # if i % 10 == 0: | |
| # print(f"{i}/{len(dev)} | Acc: {correct/i:.3f}") | |
| # print("\n=============================") | |
| # print(f"FINAL EXECUTION ACCURACY: {correct/len(dev)*100:.2f}%") | |
| # print("=============================") | |
| # if __name__ == "__main__": | |
| # main() | |
| import json | |
| import subprocess | |
| import sys | |
| import argparse | |
| import random | |
| import sqlite3 | |
| import time | |
| import re | |
| from pathlib import Path | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| from peft import PeftModel | |
| # Assuming you have a prompting.py that has encode_prompt | |
| from prompting import encode_prompt | |
| # ------------------------------- | |
| # LIVE CHECK HELPERS | |
| # ------------------------------- | |
def normalize_sql(sql):
    """Return a canonical form of *sql* for the live exact-match counter.

    Double quotes are replaced with single quotes, every run of whitespace is
    collapsed to one space, the text is lower-cased, and trailing semicolons
    are removed together with any whitespace around them.

    This is a purely textual comparison key: lower-casing also changes the
    contents of string literals inside the query.
    """
    sql = sql.replace('"', "'")
    sql = re.sub(r"\s+", " ", sql).strip().lower()
    # Whitespace is already collapsed to plain spaces, so stripping the set
    # "; " removes terminators *and* the space that may precede them
    # ("select 1 ;" must normalize to "select 1", not "select 1 ").
    return sql.rstrip("; ")
def check_execution(pred_sql, gold_sql, db_path, timeout=2.0):
    """Return True when *pred_sql* and *gold_sql* yield the same rows.

    Both statements are executed against the SQLite database at *db_path* and
    their result sets are compared as multisets, i.e. row order is ignored
    but duplicate rows are counted.

    Args:
        pred_sql: Predicted SQL statement.
        gold_sql: Reference SQL statement.
        db_path: Path of the SQLite database file.
        timeout: Wall-clock budget in seconds granted to *each* statement.

    Returns:
        True if both statements ran and produced equal multisets of rows;
        False on any mismatch, SQL error, or timeout. Failures are swallowed
        on purpose: this check only feeds a live progress counter.
    """
    # Local import keeps the block self-contained (file-level imports are
    # not touched).
    from collections import Counter

    conn = None
    try:
        conn = sqlite3.connect(db_path)
        conn.text_factory = lambda b: b.decode(errors="ignore")
        # The predicted statement is model output; never let it modify the
        # benchmark database (e.g. a generated DROP TABLE or DELETE).
        conn.execute("PRAGMA query_only = ON")

        def run(sql):
            # Restart the clock for every statement so a slow prediction
            # cannot eat the reference query's budget.
            started = time.monotonic()

            def timeout_handler():
                # A non-zero return value makes SQLite abort the statement.
                return 1 if (time.monotonic() - started) > timeout else 0

            conn.set_progress_handler(timeout_handler, 10000)
            return conn.execute(sql).fetchall()

        pred_res = run(pred_sql)
        gold_res = run(gold_sql)
        # Counter instead of sorted(): sorting rows that mix NULL (None)
        # with numbers or text raises TypeError, which used to be reported
        # as a mismatch even for identical result sets.
        return Counter(pred_res) == Counter(gold_res)
    except Exception:
        return False
    finally:
        # Close on every path, including SQL errors and timeouts.
        if conn is not None:
            conn.close()
| # ------------------------------- | |
| # SPIDER PARSER | |
| # ------------------------------- | |
| def _parse_spider_accuracy(stdout: str, metric_type: str) -> float | None: | |
| for line in stdout.splitlines(): | |
| if metric_type == "exec" and line.strip().startswith("execution"): | |
| try: return float(line.split()[-1]) | |
| except: pass | |
| elif metric_type == "match" and line.strip().startswith("exact"): | |
| try: return float(line.split()[-1]) | |
| except: pass | |
| return None | |
| # ------------------------------- | |
| # MAIN | |
| # ------------------------------- | |
def _run_spider_eval(eval_script, etype, gold_path, pred_path, db_root, table_json):
    """Run the Spider evaluation script for one metric; return (accuracy, process)."""
    cmd = [
        sys.executable, str(eval_script),
        "--gold", str(gold_path),
        "--pred", str(pred_path),
        "--etype", etype,
        "--db", str(db_root),
        "--table", str(table_json),
    ]
    proc = subprocess.run(cmd, capture_output=True, text=True)
    return _parse_spider_accuracy(proc.stdout, etype), proc


def _print_metric(label, accuracy, proc):
    """Print one result line; on a parse failure also surface the subprocess diagnostics."""
    if accuracy is not None:
        print(f"{label} : {accuracy*100:.2f}%")
        return
    print(f"{label} : Could not parse output")
    # Without this the reason for the failure (missing script, traceback in
    # the evaluator, ...) is silently discarded.
    print(f"  evaluation exited with code {proc.returncode}", file=sys.stderr)
    stderr_tail = (proc.stderr or "").strip().splitlines()[-5:]
    for line in stderr_tail:
        print(f"  {line}", file=sys.stderr)


def main():
    """Generate SQL for Spider dev questions and report EM / execution accuracy.

    Steps:
      1. Parse the CLI arguments (--adapter, --num_samples, --shuffle_dev,
         --shuffle_seed).
      2. Load the base model, apply the PEFT adapter and merge it.
      3. Generate one prediction per dev example, writing predictions and
         gold queries to temp files while printing running EM/EX estimates.
      4. Run ``spider_eval/evaluation.py`` twice (match, exec) and print the
         parsed accuracies.

    Raises:
        FileNotFoundError: If the adapter directory does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--adapter", type=str, required=True, help="Path to your SFT or RLHF checkpoint")
    parser.add_argument("--num_samples", type=int, default=700, help="Number of samples to evaluate")
    parser.add_argument("--shuffle_dev", action="store_true")
    parser.add_argument("--shuffle_seed", type=int, default=42)
    args = parser.parse_args()

    project_root = Path(__file__).resolve().parents[1]
    adapter_dir = project_root / args.adapter
    db_root = project_root / "data" / "database"
    table_json = project_root / "data" / "tables.json"
    dev_json = project_root / "data" / "dev.json"
    pred_path = project_root / "temp_predictions.txt"
    temp_gold_path = project_root / "temp_gold.sql"

    if not adapter_dir.exists():
        raise FileNotFoundError(f"Missing adapter dir: {adapter_dir}")

    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    BASE_MODEL = "Salesforce/codet5-base"
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading Model: {args.adapter}...")
    base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
    model = PeftModel.from_pretrained(base, str(adapter_dir)).to(device)
    model = model.merge_and_unload()
    model.eval()

    # Explicit encoding: the platform default is not UTF-8 everywhere.
    with dev_json.open(encoding="utf-8") as f:
        dev = json.load(f)

    if args.shuffle_dev:
        rng = random.Random(args.shuffle_seed)
        rng.shuffle(dev)

    dev = dev[: args.num_samples]
    total = len(dev)

    gen_kwargs = dict(
        max_new_tokens=160,
        num_beams=4,
        do_sample=False,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    print(f"\n🚀 Generating and live-tracking {total} samples...\n")

    em_correct = 0
    ex_correct = 0

    with (
        pred_path.open("w", encoding="utf-8") as out_pred,
        temp_gold_path.open("w", encoding="utf-8") as out_gold,
        torch.no_grad(),
    ):
        for i, ex in enumerate(dev, start=1):
            db_id = ex["db_id"]
            question = ex["question"]
            gold_query = ex["query"]
            db_path = db_root / db_id / f"{db_id}.sqlite"

            # Generate
            input_ids = encode_prompt(tokenizer, question, db_id, device=device, max_input_tokens=512)
            input_ids = input_ids.unsqueeze(0).to(device)
            attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)
            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
            pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

            # Write to files for official spider eval later. The prediction
            # file is line-oriented (line k pairs with gold line k), so a
            # prediction containing newlines or tabs must be flattened.
            # NOTE(review): the upstream Spider evaluation.py drops blank
            # lines when reading both files, which would shift every later
            # prediction by one - confirm against the local copy. An empty
            # generation is therefore written as a placeholder query.
            pred_line = " ".join(pred_sql.split()) or "SELECT 1"
            out_pred.write(f"{pred_line}\n")
            out_gold.write(f"{gold_query}\t{db_id}\n")

            # --- LIVE TRACKING CHECKS ---
            if normalize_sql(pred_sql) == normalize_sql(gold_query):
                em_correct += 1
            if check_execution(pred_sql, gold_query, db_path):
                ex_correct += 1

            # Print progress every 10 samples and once more at the end
            if i % 10 == 0 or i == total:
                print(f"Progress: {i}/{total} | Current EM: {(em_correct/i)*100:.2f}% | Current EX: {(ex_correct/i)*100:.2f}%")

    print("\nGeneration finished. Running Official Spider Evaluations for final numbers...\n")

    eval_script = project_root / "spider_eval" / "evaluation.py"

    # 1. RUN EXACT MATCH EVAL
    exact_acc, proc_match = _run_spider_eval(
        eval_script, "match", temp_gold_path, pred_path, db_root, table_json
    )

    # 2. RUN EXECUTION EVAL
    exec_acc, proc_exec = _run_spider_eval(
        eval_script, "exec", temp_gold_path, pred_path, db_root, table_json
    )

    print("==========================================")
    print(f"🎯 OFFICIAL SPIDER RESULTS FOR: {args.adapter}")
    print("==========================================")
    _print_metric("Exact Set Match Accuracy", exact_acc, proc_match)
    _print_metric("Execution Accuracy", exec_acc, proc_exec)
    print("==========================================\n")
# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()