"""Evaluate trained model with confusion matrix and per-class metrics.""" from __future__ import annotations import argparse import json from pathlib import Path import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np import pandas as pd from sklearn.metrics import ConfusionMatrixDisplay, classification_report from src.categories import id_to_name, load_categories from src.model import DEFAULT_MODEL_PATH, combine_features, load_model from src.cross_encoder_model import ( CrossEncoderClassifier, FineTunedCrossEncoderClassifier, ) from src.multi_tower_model import MultiTowerClassifier, contexts_from_dataframe CONTEXT_MODELS = ( CrossEncoderClassifier, FineTunedCrossEncoderClassifier, MultiTowerClassifier, ) PROJECT_ROOT = Path(__file__).resolve().parent.parent DEFAULT_DATA = PROJECT_ROOT / "data" / "sql_errors_1m.parquet" DEFAULT_OUTPUT = PROJECT_ROOT / "models" / "evaluation" def evaluate( data_path: Path = DEFAULT_DATA, model_path: Path = DEFAULT_MODEL_PATH, output_dir: Path = DEFAULT_OUTPUT, sample_size: int = 100_000, use_error_message: bool = True, seed: int = 42, ) -> dict: output_dir.mkdir(parents=True, exist_ok=True) df = pd.read_parquet(data_path) if len(df) > sample_size: df = df.sample(n=sample_size, random_state=seed) labels = df["label_id"].values model = load_model(model_path) if isinstance(model, CONTEXT_MODELS): if not use_error_message and "error_message" in df.columns: df = df.drop(columns=["error_message"]) preds = model.predict(contexts_from_dataframe(df)) else: texts = combine_features( queries=df["query"].tolist(), error_messages=df["error_message"].tolist() if use_error_message else None, schemas=df["schema"].tolist() if "schema" in df.columns else None, questions=df["question"].tolist() if "question" in df.columns else None, ) preds = model.predict(texts) categories = load_categories() target_names = [c.name for c in categories] report = classification_report( labels, preds, target_names=target_names, output_dict=True, zero_division=0 ) with open(output_dir / "classification_report.json", "w") as f: json.dump(report, f, indent=2) cm = ConfusionMatrixDisplay.from_predictions( labels, preds, display_labels=target_names, xticks_rotation=90, cmap="Blues", colorbar=False, ) fig = cm.figure_ fig.set_size_inches(14, 12) fig.tight_layout() fig.savefig(output_dir / "confusion_matrix.png", dpi=150) plt.close(fig) print(f"Accuracy: {report['accuracy']:.4f}") print(f"Reports saved to {output_dir}") return report def main() -> None: parser = argparse.ArgumentParser(description="Evaluate SQL error classifier") parser.add_argument("--data", type=Path, default=DEFAULT_DATA) parser.add_argument("--model", type=Path, default=DEFAULT_MODEL_PATH) parser.add_argument("--output", type=Path, default=DEFAULT_OUTPUT) parser.add_argument("--sample-size", type=int, default=100_000) parser.add_argument("--no-error-message", action="store_true") parser.add_argument("--seed", type=int, default=42) args = parser.parse_args() evaluate( data_path=args.data, model_path=args.model, output_dir=args.output, sample_size=args.sample_size, use_error_message=not args.no_error_message, seed=args.seed, ) if __name__ == "__main__": main()