nishu08's picture
Deploy CodeBERT training Space
9b2cded verified
raw
history blame contribute delete
3.63 kB
"""Evaluate trained model with confusion matrix and per-class metrics."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import ConfusionMatrixDisplay, classification_report
from src.categories import id_to_name, load_categories
from src.model import DEFAULT_MODEL_PATH, combine_features, load_model
from src.cross_encoder_model import (
CrossEncoderClassifier,
FineTunedCrossEncoderClassifier,
)
from src.multi_tower_model import MultiTowerClassifier, contexts_from_dataframe
# Model classes whose ``predict`` is fed ``contexts_from_dataframe(df)``
# instead of the flat strings produced by ``combine_features``.
CONTEXT_MODELS = (
    CrossEncoderClassifier,
    FineTunedCrossEncoderClassifier,
    MultiTowerClassifier,
)

# Directory two levels above this file; default locations hang off it.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
DEFAULT_DATA = PROJECT_ROOT.joinpath("data", "sql_errors_1m.parquet")
DEFAULT_OUTPUT = PROJECT_ROOT.joinpath("models", "evaluation")
def evaluate(
    data_path: Path = DEFAULT_DATA,
    model_path: Path = DEFAULT_MODEL_PATH,
    output_dir: Path = DEFAULT_OUTPUT,
    sample_size: int = 100_000,
    use_error_message: bool = True,
    seed: int = 42,
) -> dict:
    """Score a saved classifier on a parquet dataset and write the reports.

    Writes ``classification_report.json`` and ``confusion_matrix.png`` into
    ``output_dir`` (created if missing) and prints the overall accuracy.

    Args:
        data_path: Parquet file to evaluate on. It must contain a
            ``label_id`` column; non-context models also read ``query`` and,
            when present, ``schema`` and ``question``.
        model_path: Location handed to ``load_model``.
        output_dir: Directory that receives the two report files.
        sample_size: If the dataset has more rows than this, a random sample
            of exactly this many rows is evaluated instead.
        use_error_message: When False, the ``error_message`` column is
            withheld from the model.
        seed: Random state used for the row sample.

    Returns:
        The dictionary produced by
        ``classification_report(..., output_dict=True)``.

    Raises:
        ValueError: If the label ids seen in the data/predictions cannot be
            aligned with the categories returned by ``load_categories``.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    df = pd.read_parquet(data_path)
    if len(df) > sample_size:
        df = df.sample(n=sample_size, random_state=seed)
    labels = df["label_id"].values

    model = load_model(model_path)
    if isinstance(model, CONTEXT_MODELS):
        # Context models build their own inputs from the frame, so hiding the
        # error message means removing the column before conversion.
        if not use_error_message and "error_message" in df.columns:
            df = df.drop(columns=["error_message"])
        preds = model.predict(contexts_from_dataframe(df))
    else:
        texts = combine_features(
            queries=df["query"].tolist(),
            error_messages=df["error_message"].tolist() if use_error_message else None,
            schemas=df["schema"].tolist() if "schema" in df.columns else None,
            questions=df["question"].tolist() if "question" in df.columns else None,
        )
        preds = model.predict(texts)

    categories = load_categories()
    target_names = [c.name for c in categories]

    # Without an explicit ``labels`` list scikit-learn infers the classes from
    # the data and raises when their count differs from ``target_names`` --
    # which happens as soon as the sample (or the predictions) misses a class.
    observed = np.union1d(labels, preds)
    if len(observed) == len(target_names):
        # Every category is represented: keep scikit-learn's own inference.
        class_ids = None
    else:
        # NOTE(review): assumes label ids are 0..n-1 in the order returned by
        # load_categories() -- confirm against src.categories.
        class_ids = np.arange(len(target_names))
        if not np.isin(observed, class_ids).all():
            raise ValueError(
                f"Label ids {observed.tolist()} cannot be aligned with "
                f"{len(target_names)} categories"
            )

    report = classification_report(
        labels,
        preds,
        labels=class_ids,
        target_names=target_names,
        output_dict=True,
        zero_division=0,
    )
    with open(output_dir / "classification_report.json", "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)

    cm = ConfusionMatrixDisplay.from_predictions(
        labels,
        preds,
        labels=class_ids,
        display_labels=target_names,
        xticks_rotation=90,
        cmap="Blues",
        colorbar=False,
    )
    fig = cm.figure_
    try:
        fig.set_size_inches(14, 12)
        fig.tight_layout()
        fig.savefig(output_dir / "confusion_matrix.png", dpi=150)
    finally:
        # Release the figure even if layout or saving fails.
        plt.close(fig)

    print(f"Accuracy: {report['accuracy']:.4f}")
    print(f"Reports saved to {output_dir}")
    return report
def main() -> None:
    """Read the command-line flags and run ``evaluate`` with them."""
    cli = argparse.ArgumentParser(description="Evaluate SQL error classifier")

    # The three path-valued flags differ only in name and default.
    path_flags = (
        ("--data", DEFAULT_DATA),
        ("--model", DEFAULT_MODEL_PATH),
        ("--output", DEFAULT_OUTPUT),
    )
    for flag, fallback in path_flags:
        cli.add_argument(flag, type=Path, default=fallback)
    cli.add_argument("--sample-size", type=int, default=100_000)
    cli.add_argument("--no-error-message", action="store_true")
    cli.add_argument("--seed", type=int, default=42)

    opts = cli.parse_args()

    settings = {
        "data_path": opts.data,
        "model_path": opts.model,
        "output_dir": opts.output,
        "sample_size": opts.sample_size,
        # The flag is negative; evaluate() takes the positive form.
        "use_error_message": not opts.no_error_message,
        "seed": opts.seed,
    }
    evaluate(**settings)


if __name__ == "__main__":
    main()