DuckDB-SQL-Eval / evaluation_logic.py
tdoehmen's picture
no subprocess
b9dc6d6
raw
history blame
7.44 kB
import os
import sys
from pathlib import Path
from datetime import datetime
import json
import traceback
# Make this directory, the bundled duckdb-nsql checkout and its eval package
# importable by appending all three to the module search path.
current_dir = Path(__file__).resolve().parent
duckdb_nsql_dir = current_dir.joinpath('duckdb-nsql')
eval_dir = duckdb_nsql_dir.joinpath('eval')
sys.path.extend(map(str, (current_dir, duckdb_nsql_dir, eval_dir)))
# Import necessary functions and classes
from eval.predict import get_manifest, DefaultLoader, PROMPT_FORMATTERS, generate_sql
from eval.evaluate import evaluate, compute_metrics, get_to_print
from eval.evaluate import test_suite_evaluation, read_tables_json
from eval.schema import TextToSQLParams, Table
# Names of every registered prompt formatter, in registry order; these are the
# values accepted as `prompt_format` by run_prediction / run_evaluation.
AVAILABLE_PROMPT_FORMATS = [*PROMPT_FORMATTERS.keys()]
def run_prediction(model_name, prompt_format, output_file):
    """Generate SQL predictions for the dev set and save them as JSON lines.

    This is a generator: progress, warnings and failures are reported by
    yielding human-readable strings. Any exception raised after the first
    message is caught and reported as a message instead of propagating.

    Args:
        model_name: Engine name passed to ``get_manifest`` as
            ``manifest_engine``.
        prompt_format: Key into ``PROMPT_FORMATTERS`` selecting the prompt
            formatter to instantiate.
        output_file: ``pathlib.Path`` of the file to write. Each line is one
            JSON object: the input record plus a ``"pred"`` key holding the
            generated SQL.

    Yields:
        str: Status, warning and error messages.
    """
    dataset_path = str(eval_dir / "data/dev.json")
    table_meta_path = str(eval_dir / "data/tables.json")
    stop_tokens = [';']
    max_tokens = 30000
    temperature = 0.1
    num_beams = -1
    manifest_client = "openrouter"
    manifest_engine = model_name
    manifest_connection = "http://localhost:5000"
    overwrite_manifest = True
    parallel = False

    yield "Starting prediction..."

    try:
        # Initialize necessary components
        data_formatter = DefaultLoader()
        prompt_formatter = PROMPT_FORMATTERS[prompt_format]()

        # Load manifest
        manifest = get_manifest(
            manifest_client=manifest_client,
            manifest_connection=manifest_connection,
            manifest_engine=manifest_engine,
        )

        # Load data
        data = data_formatter.load_data(dataset_path)
        db_to_tables = data_formatter.load_table_metadata(table_meta_path)

        # Prepare input for generate_sql
        text_to_sql_inputs = []
        for input_question in data:
            question = input_question["question"]
            db_id = input_question.get("db_id", "none")
            if db_id != "none":
                table_params = list(db_to_tables.get(db_id, {}).values())
            else:
                table_params = []

            if not table_params:
                yield f"[red] WARNING: No tables found for {db_id} [/red]"

            text_to_sql_inputs.append(TextToSQLParams(
                instruction=question,
                database=db_id,
                tables=table_params,
            ))

        # Generate SQL
        generated_sqls = generate_sql(
            manifest=manifest,
            text_to_sql_in=text_to_sql_inputs,
            retrieved_docs=[[] for _ in text_to_sql_inputs],  # Assuming no retrieved docs
            prompt_formatter=prompt_formatter,
            stop_tokens=stop_tokens,
            overwrite_manifest=overwrite_manifest,
            max_tokens=max_tokens,
            temperature=temperature,
            num_beams=num_beams,
            parallel=parallel
        )

        # Serialize every record BEFORE creating the file. If unpacking or
        # serialization fails, no empty/truncated file is left behind that a
        # later run would mistake for a finished prediction and reuse.
        lines = [
            json.dumps({**original_data, "pred": sql}) + '\n'
            for original_data, (sql, _) in zip(data, generated_sqls)
        ]

        # Save results
        with output_file.open('w', encoding="utf-8") as f:
            f.writelines(lines)

        yield f"Prediction completed. Results saved to {output_file}"
    except Exception as e:
        yield f"Prediction failed with error: {str(e)}"
        yield f"Error traceback: {traceback.format_exc()}"
def run_evaluation(model_name, prompt_format="duckdbinstgraniteshort"):
    """Run prediction (if needed) and evaluation for a model, yielding progress.

    This is a generator: every status line, result line and error is yielded
    as a string. Exceptions raised after the API-key check are caught and
    reported as messages instead of propagating.

    The prediction file lives in ``eval_dir / "output"`` and is named from the
    prompt format, the model name (surrounding whitespace stripped, ``/``
    replaced by ``_``) and today's date. If that file already exists the
    prediction step is skipped and the existing file is evaluated.

    Args:
        model_name: Model identifier; forwarded to ``run_prediction`` and
            ``compute_metrics``.
        prompt_format: Prompt format name forwarded to ``run_prediction``.

    Yields:
        str: Status messages, metric lines, or error descriptions.
    """
    if "OPENROUTER_API_KEY" not in os.environ:
        yield "Error: OPENROUTER_API_KEY not found in environment variables."
        return

    try:
        # Set up the arguments
        dataset_path = str(eval_dir / "data/dev.json")
        table_meta_path = str(eval_dir / "data/tables.json")
        output_dir = eval_dir / "output"

        yield f"Using model: {model_name}"
        yield f"Using prompt format: {prompt_format}"

        # Python strings have no .trim(); strip() removes the surrounding
        # whitespace. '/' is replaced so the model name is a single path part.
        safe_model_name = model_name.strip().replace('/', '_')
        date_stamp = datetime.now().strftime('%y-%m-%d')
        output_file = output_dir / f"{prompt_format}_0docs_{safe_model_name}_dev_{date_stamp}.json"

        # Ensure the output directory exists
        output_dir.mkdir(parents=True, exist_ok=True)

        if output_file.exists():
            yield f"Prediction file already exists: {output_file}"
            yield "Skipping prediction step and proceeding to evaluation."
        else:
            # Run prediction
            yield from run_prediction(model_name, prompt_format, output_file)
            # run_prediction reports failures as messages rather than raising,
            # so check that it actually produced something to evaluate.
            if not output_file.exists():
                yield "Error: prediction output file was not created; skipping evaluation."
                return

        # Run evaluation
        yield "Starting evaluation..."

        # Set up evaluation arguments
        gold_path = Path(dataset_path)
        db_dir = str(eval_dir / "data/databases/")
        tables_path = Path(table_meta_path)

        kmaps = test_suite_evaluation.build_foreign_key_map_from_json(str(tables_path))
        db_schemas = read_tables_json(str(tables_path))

        # Use context managers so both file handles are closed promptly.
        with gold_path.open("r", encoding="utf-8") as gold_f:
            gold_sqls_dict = json.load(gold_f)
        with output_file.open("r", encoding="utf-8") as pred_f:
            # One JSON object per line; tolerate stray blank lines.
            pred_sqls_dict = [json.loads(line) for line in pred_f if line.strip()]

        gold_sqls = [p.get("query", p.get("sql", "")) for p in gold_sqls_dict]
        setup_sqls = [p["setup_sql"] for p in gold_sqls_dict]
        validate_sqls = [p["validation_sql"] for p in gold_sqls_dict]
        gold_dbs = [p.get("db_id", p.get("db", "")) for p in gold_sqls_dict]
        pred_sqls = [p["pred"] for p in pred_sqls_dict]
        categories = [p.get("category", "") for p in gold_sqls_dict]

        yield "Computing metrics..."
        metrics = compute_metrics(
            gold_sqls=gold_sqls,
            pred_sqls=pred_sqls,
            gold_dbs=gold_dbs,
            setup_sqls=setup_sqls,
            validate_sqls=validate_sqls,
            kmaps=kmaps,
            db_schemas=db_schemas,
            database_dir=db_dir,
            lowercase_schema_match=False,
            model_name=model_name,
            categories=categories,
        )

        yield "Evaluation completed."

        if metrics:
            yield "Overall Results:"
            overall_metrics = metrics['exec']['all']
            yield f"Count: {overall_metrics['count']}"
            yield f"Execution Accuracy: {overall_metrics['exec']:.3f}"
            yield f"Exact Match Accuracy: {overall_metrics['exact']:.3f}"
            yield f"Equality: {metrics['equality']['equality']:.3f}"
            yield f"Edit Distance: {metrics['edit_distance']['edit_distance']:.3f}"

            yield "\nResults by Category:"
            # Fixed report order; kept separate from the per-example
            # `categories` list passed to compute_metrics above.
            report_categories = ['easy', 'medium', 'hard', 'duckdb', 'ddl', 'all']
            for category in report_categories:
                if category in metrics['exec']:
                    yield f"\n{category}:"
                    category_metrics = metrics['exec'][category]
                    yield f"Count: {category_metrics['count']}"
                    yield f"Execution Accuracy: {category_metrics['exec']:.3f}"
                else:
                    yield f"\n{category}: No data available"
        else:
            yield "No evaluation metrics returned."

    except Exception as e:
        yield f"An unexpected error occurred: {str(e)}"
        yield f"Error traceback: {traceback.format_exc()}"
# Interactive entry point: prompt for the model and prompt format, then stream
# every message produced by run_evaluation to stdout as it arrives.
if __name__ == "__main__":
    model_name = input("Enter the model name: ")
    prompt_format = input("Enter the prompt format (default is duckdbinstgraniteshort): ")
    if not prompt_format:
        # Empty answer selects the default format.
        prompt_format = "duckdbinstgraniteshort"
    for result in run_evaluation(model_name, prompt_format):
        print(result, flush=True)