File size: 3,626 Bytes
9b2cded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""Evaluate trained model with confusion matrix and per-class metrics."""

from __future__ import annotations

import argparse
import json
from pathlib import Path

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import ConfusionMatrixDisplay, classification_report

from src.categories import id_to_name, load_categories
from src.model import DEFAULT_MODEL_PATH, combine_features, load_model
from src.cross_encoder_model import (
    CrossEncoderClassifier,
    FineTunedCrossEncoderClassifier,
)
from src.multi_tower_model import MultiTowerClassifier, contexts_from_dataframe

# Model classes that evaluate() feeds with contexts_from_dataframe(df) output
# instead of the combined text features; used as the isinstance() target there.
CONTEXT_MODELS = (
    CrossEncoderClassifier,
    FineTunedCrossEncoderClassifier,
    MultiTowerClassifier,
)

# Directory two levels above this file; the default data/output paths hang off it.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Default evaluation dataset (parquet) and default directory for the reports.
DEFAULT_DATA = PROJECT_ROOT.joinpath("data", "sql_errors_1m.parquet")
DEFAULT_OUTPUT = PROJECT_ROOT.joinpath("models", "evaluation")


def _predict_labels(model, df: pd.DataFrame, use_error_message: bool):
    """Return ``model.predict`` output for ``df`` in the input format the model needs."""
    if isinstance(model, CONTEXT_MODELS):
        # Context models are fed contexts built from the frame, so the error
        # message is hidden by dropping its column before building them.
        if not use_error_message and "error_message" in df.columns:
            df = df.drop(columns=["error_message"])
        return model.predict(contexts_from_dataframe(df))

    # Every other model is fed the output of combine_features(); the optional
    # columns are passed as None when absent from the frame.
    texts = combine_features(
        queries=df["query"].tolist(),
        error_messages=df["error_message"].tolist() if use_error_message else None,
        schemas=df["schema"].tolist() if "schema" in df.columns else None,
        questions=df["question"].tolist() if "question" in df.columns else None,
    )
    return model.predict(texts)


def _save_confusion_matrix(labels, preds, target_names: list[str], destination: Path) -> None:
    """Plot the confusion matrix for ``labels`` vs ``preds`` and save it to ``destination``."""
    display = ConfusionMatrixDisplay.from_predictions(
        labels,
        preds,
        display_labels=target_names,
        xticks_rotation=90,
        cmap="Blues",
        colorbar=False,
    )
    fig = display.figure_
    try:
        fig.set_size_inches(14, 12)
        fig.tight_layout()
        fig.savefig(destination, dpi=150)
    finally:
        # Close even when layout/saving fails so the figure does not stay
        # registered in pyplot for the rest of the process.
        plt.close(fig)


def evaluate(
    data_path: Path = DEFAULT_DATA,
    model_path: Path = DEFAULT_MODEL_PATH,
    output_dir: Path = DEFAULT_OUTPUT,
    sample_size: int = 100_000,
    use_error_message: bool = True,
    seed: int = 42,
) -> dict:
    """Evaluate a saved model on a parquet dataset and write the reports.

    Writes ``classification_report.json`` and ``confusion_matrix.png`` into
    ``output_dir`` (created if missing) and prints the accuracy.

    Args:
        data_path: Parquet file to evaluate on. It must contain a ``label_id``
            column; models outside ``CONTEXT_MODELS`` also need ``query`` and,
            when ``use_error_message`` is true, ``error_message``.
        model_path: Path handed to ``load_model``.
        output_dir: Directory that receives the two report files.
        sample_size: If the dataset has more rows than this, a random sample
            of exactly this many rows is evaluated instead.
        use_error_message: Whether the error message is given to the model.
        seed: Random state for the row sampling.

    Returns:
        The ``classification_report(..., output_dict=True)`` dictionary.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    df = pd.read_parquet(data_path)
    if len(df) > sample_size:
        df = df.sample(n=sample_size, random_state=seed)

    labels = df["label_id"].values
    model = load_model(model_path)
    preds = _predict_labels(model, df, use_error_message)

    # NOTE(review): target_names is matched positionally to the sorted label
    # ids found in labels/preds. This presumably relies on load_categories()
    # returning categories in label-id order and on every category occurring
    # in the evaluated sample - confirm, since a small sample could miss one.
    target_names = [category.name for category in load_categories()]
    report = classification_report(
        labels, preds, target_names=target_names, output_dict=True, zero_division=0
    )

    # Explicit encoding so the file does not depend on the platform default.
    report_path = output_dir / "classification_report.json"
    with open(report_path, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)

    _save_confusion_matrix(
        labels, preds, target_names, output_dir / "confusion_matrix.png"
    )

    print(f"Accuracy: {report['accuracy']:.4f}")
    print(f"Reports saved to {output_dir}")
    return report


def main() -> None:
    """Command-line entry point: parse the options and call ``evaluate``.

    Options: ``--data``, ``--model`` and ``--output`` (paths), ``--sample-size``
    and ``--seed`` (integers), and the ``--no-error-message`` switch.
    """
    cli = argparse.ArgumentParser(description="Evaluate SQL error classifier")

    # The three path options have the same shape, so register them together.
    for flag, fallback in (
        ("--data", DEFAULT_DATA),
        ("--model", DEFAULT_MODEL_PATH),
        ("--output", DEFAULT_OUTPUT),
    ):
        cli.add_argument(flag, type=Path, default=fallback)
    cli.add_argument("--sample-size", type=int, default=100_000)
    cli.add_argument("--no-error-message", action="store_true")
    cli.add_argument("--seed", type=int, default=42)

    opts = cli.parse_args()

    # The switch is phrased negatively on the CLI; evaluate() takes the
    # positive form.
    include_error_message = not opts.no_error_message
    evaluate(
        data_path=opts.data,
        model_path=opts.model,
        output_dir=opts.output,
        sample_size=opts.sample_size,
        use_error_message=include_error_message,
        seed=opts.seed,
    )


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()