Spaces:
Sleeping
Sleeping
| """Evaluate trained model with confusion matrix and per-class metrics.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pandas as pd | |
| from sklearn.metrics import ConfusionMatrixDisplay, classification_report | |
| from src.categories import id_to_name, load_categories | |
| from src.model import DEFAULT_MODEL_PATH, combine_features, load_model | |
| from src.cross_encoder_model import ( | |
| CrossEncoderClassifier, | |
| FineTunedCrossEncoderClassifier, | |
| ) | |
| from src.multi_tower_model import MultiTowerClassifier, contexts_from_dataframe | |
# Classifier types whose ``predict`` is fed the output of
# ``contexts_from_dataframe`` instead of combined feature strings
# (see the ``isinstance`` check in ``evaluate``).
CONTEXT_MODELS = (
    CrossEncoderClassifier,
    FineTunedCrossEncoderClassifier,
    MultiTowerClassifier,
)

# Two directory levels above this file.
PROJECT_ROOT = Path(__file__).resolve().parent.parent

# Default input dataset and default directory for evaluation artifacts.
DEFAULT_DATA = PROJECT_ROOT.joinpath("data", "sql_errors_1m.parquet")
DEFAULT_OUTPUT = PROJECT_ROOT.joinpath("models", "evaluation")
def evaluate(
    data_path: Path = DEFAULT_DATA,
    model_path: Path = DEFAULT_MODEL_PATH,
    output_dir: Path = DEFAULT_OUTPUT,
    sample_size: int = 100_000,
    use_error_message: bool = True,
    seed: int = 42,
) -> dict:
    """Score a saved classifier on a parquet dataset and write the reports.

    Writes ``classification_report.json`` and ``confusion_matrix.png`` into
    ``output_dir`` and prints the overall accuracy.

    Args:
        data_path: Parquet file to evaluate on. It must have a ``label_id``
            column. Models outside ``CONTEXT_MODELS`` also need ``query``
            (and ``error_message`` when ``use_error_message`` is true);
            ``schema`` and ``question`` are used only if present.
        model_path: Location handed to ``load_model``.
        output_dir: Directory for the report files; created if missing.
        sample_size: Upper bound on evaluated rows. Larger datasets are
            randomly down-sampled to exactly this many rows.
        use_error_message: When false, context models receive the frame
            without its ``error_message`` column and other models receive
            ``error_messages=None``.
        seed: Random state for the down-sampling step.

    Returns:
        The dictionary produced by
        ``classification_report(..., output_dict=True)``.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    df = pd.read_parquet(data_path)
    if len(df) > sample_size:
        df = df.sample(n=sample_size, random_state=seed)
    # Capture the ground truth before any column is dropped below.
    labels = df["label_id"].values

    model = load_model(model_path)
    if isinstance(model, CONTEXT_MODELS):
        if not use_error_message and "error_message" in df.columns:
            df = df.drop(columns=["error_message"])
        preds = model.predict(contexts_from_dataframe(df))
    else:
        texts = combine_features(
            queries=df["query"].tolist(),
            error_messages=df["error_message"].tolist() if use_error_message else None,
            schemas=df["schema"].tolist() if "schema" in df.columns else None,
            questions=df["question"].tolist() if "question" in df.columns else None,
        )
        preds = model.predict(texts)

    categories = load_categories()
    target_names = [c.name for c in categories]
    # NOTE(review): no ``labels=`` is passed, so scikit-learn pairs
    # ``target_names`` with the sorted labels it finds in ``labels``/``preds``.
    # A small sample that misses a category would presumably make the counts
    # disagree and raise; confirm, and pass the category ids explicitly if so.
    report = classification_report(
        labels, preds, target_names=target_names, output_dict=True, zero_division=0
    )
    # Explicit encoding keeps the report independent of the platform default.
    with open(
        output_dir / "classification_report.json", "w", encoding="utf-8"
    ) as f:
        json.dump(report, f, indent=2)

    cm = ConfusionMatrixDisplay.from_predictions(
        labels,
        preds,
        display_labels=target_names,
        xticks_rotation=90,
        cmap="Blues",
        colorbar=False,
    )
    fig = cm.figure_
    try:
        fig.set_size_inches(14, 12)
        fig.tight_layout()
        fig.savefig(output_dir / "confusion_matrix.png", dpi=150)
    finally:
        # Release the figure even when layout or saving fails.
        plt.close(fig)

    print(f"Accuracy: {report['accuracy']:.4f}")
    print(f"Reports saved to {output_dir}")
    return report
def main() -> None:
    """Parse the command-line options and run ``evaluate`` with them."""
    cli = argparse.ArgumentParser(description="Evaluate SQL error classifier")

    # The three Path-typed options differ only in flag name and default.
    path_options = (
        ("--data", DEFAULT_DATA),
        ("--model", DEFAULT_MODEL_PATH),
        ("--output", DEFAULT_OUTPUT),
    )
    for flag, fallback in path_options:
        cli.add_argument(flag, type=Path, default=fallback)
    cli.add_argument("--sample-size", type=int, default=100_000)
    cli.add_argument("--no-error-message", action="store_true")
    cli.add_argument("--seed", type=int, default=42)

    opts = cli.parse_args()
    settings = {
        "data_path": opts.data,
        "model_path": opts.model,
        "output_dir": opts.output,
        "sample_size": opts.sample_size,
        # The flag is negative on the command line, positive in the API.
        "use_error_message": not opts.no_error_message,
        "seed": opts.seed,
    }
    evaluate(**settings)


if __name__ == "__main__":
    main()