# DuckDB-SQL-Eval / evaluation_logic.py
import os
import sys
from pathlib import Path
from datetime import datetime
import json
import traceback
import uuid
from huggingface_hub import CommitScheduler
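
# Make the bundled duckdb-nsql checkout and its eval/ package importable before
# the eval.* imports below.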
current_dir = Path(__file__).resolve().parent
duckdb_nsql_dir = current_dir / 'duckdb-nsql'
eval_dir = duckdb_nsql_dir / 'eval'
sys.path.extend([str(current_dir), str(duckdb_nsql_dir), str(eval_dir)])
from eval.predict import get_manifest, DefaultLoader, PROMPT_FORMATTERS, generate_sql
from eval.evaluate import evaluate, compute_metrics, get_to_print
from eval.evaluate import test_suite_evaluation, read_tables_json
from eval.schema import TextToSQLParams, Table
AVAILABLE_PROMPT_FORMATS = list(PROMPT_FORMATTERS.keys())
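
# Local folders where prediction/evaluation records are written before upload;
# a per-process UUID keeps files from different sessions from colliding.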
prediction_folder = Path("prediction_results/")
evaluation_folder = Path("evaluation_results/")
file_uuid = uuid.uuid4()
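
# Background schedulers that periodically commit the result folders to the
# Hugging Face dataset repos (every=10 means every 10 minutes).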
prediction_scheduler = CommitScheduler(
    repo_id="sql-console/duckdb-nsql-predictions",
    repo_type="dataset",
    folder_path=prediction_folder,
    path_in_repo="data",
    every=10,
)
evaluation_scheduler = CommitScheduler(
    repo_id="sql-console/duckdb-nsql-scores",
    repo_type="dataset",
    folder_path=evaluation_folder,
    path_in_repo="data",
    every=10,
)

def save_prediction(inference_api, model_name, prompt_format, question, generated_sql):
    """Append a single prediction record to the scheduled upload folder (JSON Lines)."""
    prediction_file = prediction_folder / f"prediction_{file_uuid}.json"
    prediction_folder.mkdir(parents=True, exist_ok=True)
    with prediction_scheduler.lock:
        with prediction_file.open("a") as f:
            json.dump({
                "inference_api": inference_api,
                "model_name": model_name,
                "prompt_format": prompt_format,
                "question": question,
                "generated_sql": generated_sql,
                "timestamp": datetime.now().isoformat()
            }, f)
            f.write('\n')  # keep one JSON record per line (matches save_evaluation)

def save_evaluation(inference_api, model_name, prompt_format, custom_prompt, metrics):
    """Append a flattened evaluation-metrics record to the scheduled upload folder (JSON Lines)."""
    evaluation_file = evaluation_folder / f"evaluation_{file_uuid}.json"
    evaluation_folder.mkdir(parents=True, exist_ok=True)
    # Extract and flatten the category-specific execution metrics
    categories = ['easy', 'medium', 'hard', 'duckdb', 'ddl', 'all']
    flattened_metrics = {
        "inference_api": inference_api,
        "model_name": model_name,
        "prompt_format": prompt_format,
        "custom_prompt": str(custom_prompt) if prompt_format.startswith("custom") else "",
        "timestamp": datetime.now().isoformat()
    }
    # Flatten each category's metrics into separate columns
    for category in categories:
        if category in metrics['exec']:
            category_metrics = metrics['exec'][category]
            flattened_metrics[f"{category}_count"] = category_metrics['count']
            flattened_metrics[f"{category}_execution_accuracy"] = category_metrics['exec']
        else:
            flattened_metrics[f"{category}_count"] = 0
            flattened_metrics[f"{category}_execution_accuracy"] = 0.0
    with evaluation_scheduler.lock:
        with evaluation_file.open("a") as f:
            json.dump(flattened_metrics, f)
            f.write('\n')
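
# For reference, each line written by save_evaluation looks roughly like the record
# below (keys taken from the code above; the values are made up for illustration):
#   {"inference_api": "...", "model_name": "...", "prompt_format": "...",
#    "custom_prompt": "", "timestamp": "...",
#    "easy_count": 0, "easy_execution_accuracy": 0.0, ..., "all_count": 0, "all_execution_accuracy": 0.0}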

def run_prediction(inference_api, model_name, prompt_format, custom_prompt, output_file):
    """Generate SQL for the dev set and write predictions to output_file, yielding progress messages."""
    dataset_path = str(eval_dir / "data/dev.json")
    table_meta_path = str(eval_dir / "data/tables.json")
    stop_tokens = ['`<|dummy|>`']
    max_tokens = 1000
    temperature = 0
    num_beams = -1
    manifest_client = inference_api
    manifest_engine = model_name
    manifest_connection = "http://localhost:5000"
    overwrite_manifest = True
    parallel = False

    yield "Starting prediction..."

    try:
        # Initialize necessary components
        data_formatter = DefaultLoader()
        if prompt_format.startswith("custom"):
            prompt_formatter_cls = PROMPT_FORMATTERS["custom"]
            prompt_formatter_cls.PROMPT_TEMPLATE = custom_prompt
            prompt_formatter = prompt_formatter_cls()
        else:
            prompt_formatter = PROMPT_FORMATTERS[prompt_format]()

        # Load manifest
        manifest = get_manifest(
            manifest_client=manifest_client,
            manifest_connection=manifest_connection,
            manifest_engine=manifest_engine,
        )

        # Load data
        data = data_formatter.load_data(dataset_path)
        db_to_tables = data_formatter.load_table_metadata(table_meta_path)

        # Prepare input for generate_sql
        text_to_sql_inputs = []
        for input_question in data:
            question = input_question["question"]
            db_id = input_question.get("db_id", "none")
            if db_id != "none":
                table_params = list(db_to_tables.get(db_id, {}).values())
            else:
                table_params = []
            text_to_sql_inputs.append(TextToSQLParams(
                instruction=question,
                database=db_id,
                tables=table_params,
            ))

        # Generate SQL
        generated_sqls = generate_sql(
            manifest=manifest,
            text_to_sql_in=text_to_sql_inputs,
            retrieved_docs=[[] for _ in text_to_sql_inputs],
            prompt_formatter=prompt_formatter,
            stop_tokens=stop_tokens,
            overwrite_manifest=overwrite_manifest,
            max_tokens=max_tokens,
            temperature=temperature,
            num_beams=num_beams,
            parallel=parallel
        )

        # Save results
        output_file.parent.mkdir(parents=True, exist_ok=True)
        with output_file.open('w') as f:
            for original_data, (sql, _) in zip(data, generated_sqls):
                output = {**original_data, "pred": sql}
                json.dump(output, f)
                f.write('\n')
                # Save prediction to dataset
                save_prediction(inference_api, model_name, prompt_format, original_data["question"], sql)

        yield f"Prediction completed. Results saved to {output_file}"
    except Exception as e:
        yield f"Prediction failed with error: {str(e)}"
        yield f"Error traceback: {traceback.format_exc()}"

def run_evaluation(inference_api, model_name, prompt_format="duckdbinstgraniteshort", custom_prompt=None):
    """Run prediction (if needed) and evaluation for a model, yielding progress and result messages."""
    if "OPENROUTER_API_KEY" not in os.environ:
        yield "Error: OPENROUTER_API_KEY not found in environment variables."
        return
    if "HF_TOKEN" not in os.environ:
        yield "Error: HF_TOKEN not found in environment variables."
        return

    try:
        # Set up the arguments
        dataset_path = str(eval_dir / "data/dev.json")
        table_meta_path = str(eval_dir / "data/tables.json")
        output_dir = eval_dir / "output"

        yield f"Using model: {model_name}"
        yield f"Using prompt format: {prompt_format}"

        if prompt_format == "custom":
            prompt_format = prompt_format + "_" + str(abs(hash(custom_prompt)) % (10 ** 8))

        output_file = output_dir / f"{prompt_format}_0docs_{model_name.replace('/', '_')}_dev_{datetime.now().strftime('%y-%m-%d')}.json"

        # Ensure the output directory exists
        output_dir.mkdir(parents=True, exist_ok=True)

        if output_file.exists():
            yield f"Prediction file already exists: {output_file}"
            yield "Skipping prediction step and proceeding to evaluation."
        else:
            # Run prediction
            for output in run_prediction(inference_api, model_name, prompt_format, custom_prompt, output_file):
                yield output

        # Run evaluation
        yield "Starting evaluation..."

        # Set up evaluation arguments
        gold_path = Path(dataset_path)
        db_dir = str(eval_dir / "data/databases/")
        tables_path = Path(table_meta_path)

        kmaps = test_suite_evaluation.build_foreign_key_map_from_json(str(tables_path))
        db_schemas = read_tables_json(str(tables_path))

        gold_sqls_dict = json.load(gold_path.open("r", encoding="utf-8"))
        pred_sqls_dict = [json.loads(l) for l in output_file.open("r").readlines()]

        gold_sqls = [p.get("query", p.get("sql", "")) for p in gold_sqls_dict]
        setup_sqls = [p["setup_sql"] for p in gold_sqls_dict]
        validate_sqls = [p["validation_sql"] for p in gold_sqls_dict]
        gold_dbs = [p.get("db_id", p.get("db", "")) for p in gold_sqls_dict]
        pred_sqls = [p["pred"] for p in pred_sqls_dict]
        categories = [p.get("category", "") for p in gold_sqls_dict]

        yield "Computing metrics..."
        metrics = compute_metrics(
            gold_sqls=gold_sqls,
            pred_sqls=pred_sqls,
            gold_dbs=gold_dbs,
            setup_sqls=setup_sqls,
            validate_sqls=validate_sqls,
            kmaps=kmaps,
            db_schemas=db_schemas,
            database_dir=db_dir,
            lowercase_schema_match=False,
            model_name=model_name,
            categories=categories,
        )

        # Save evaluation results to dataset
        save_evaluation(inference_api, model_name, prompt_format, custom_prompt, metrics)

        yield "Evaluation completed."
        if metrics:
            yield "Overall Results:"
            overall_metrics = metrics['exec']['all']
            yield f"All (n={overall_metrics['count']}) - Execution Accuracy: {overall_metrics['exec']:.3f}"
            yield f"All (n={overall_metrics['count']}) - Edit Distance: {metrics['edit_distance']['edit_distance']:.3f}"
            categories = ['easy', 'medium', 'hard', 'duckdb', 'ddl', 'all']
            for category in categories:
                if category in metrics['exec']:
                    category_metrics = metrics['exec'][category]
                    yield f"{category} (n={category_metrics['count']}) - Execution Accuracy: {category_metrics['exec']:.3f}"
                else:
                    yield f"{category}: No data available"
        else:
            yield "No evaluation metrics returned."
    except Exception as e:
        yield f"An unexpected error occurred: {str(e)}"
        yield f"Error traceback: {traceback.format_exc()}"
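
# Illustrative only: run_evaluation() and run_prediction() are generators, so a caller
# (e.g. a web UI) can stream their status messages as they are yielded. The client and
# model names below are placeholders, not values verified against this repo:
#
#     for status in run_evaluation("openrouter", "some-provider/some-model"):
#         print(status, flush=True)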

if __name__ == "__main__":
    inference_api = input("Enter the inference API: ")
    model_name = input("Enter the model name: ")
    prompt_format = input("Enter the prompt format (default is duckdbinstgraniteshort): ") or "duckdbinstgraniteshort"
    # run_evaluation expects (inference_api, model_name, prompt_format, ...)
    for result in run_evaluation(inference_api, model_name, prompt_format):
        print(result, flush=True)