File size: 2,275 Bytes
8a3099e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""Evaluate a trained CodeBERT cross-encoder."""

from __future__ import annotations

import argparse
import inspect
import json
import tempfile
from pathlib import Path

import pandas as pd
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

from src.codebert_dataset import SQLCodeBERTDataCollator, SQLCodeBERTDataset, normalize_dataframe
from src.hf_metrics import build_compute_metrics, compute_multilabel_metrics
from src.hf_metrics import build_compute_metrics, compute_multilabel_metrics

# Project root = parent of the directory that contains this script; the
# default data file and model directory are resolved relative to it so the
# script works regardless of the current working directory.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
DEFAULT_DATA = PROJECT_ROOT.joinpath("data", "sql_errors_dev.parquet")
DEFAULT_MODEL = PROJECT_ROOT.joinpath("models", "codebert-cross-encoder")


def evaluate(
    model_dir: Path = DEFAULT_MODEL,
    data_path: Path = DEFAULT_DATA,
    sample_size: int | None = 10_000,
    threshold: float = 0.5,
    seed: int = 42,
    batch_size: int = 8,
) -> dict:
    """Score a saved CodeBERT cross-encoder on a parquet dataset.

    Loads ``data_path``, normalizes it with ``normalize_dataframe``, optionally
    down-samples it, runs ``Trainer.predict`` with the model and tokenizer
    stored in ``model_dir`` and computes metrics with
    ``compute_multilabel_metrics``.

    Args:
        model_dir: Directory loadable by ``AutoTokenizer.from_pretrained`` and
            ``AutoModelForSequenceClassification.from_pretrained``.
        data_path: Parquet file to evaluate on.
        sample_size: Maximum number of rows to score. When the dataset has
            more rows, a random sample of this size (drawn with ``seed``) is
            used. ``None`` or a non-positive value evaluates every row.
        threshold: Decision threshold forwarded to both metric builders.
        seed: Random state for row sampling; also forwarded to
            ``TrainingArguments``.
        batch_size: Per-device batch size used for prediction.

    Returns:
        The metrics dict returned by ``compute_multilabel_metrics``; it is
        also printed to stdout as indented JSON.

    Raises:
        ValueError: If the normalized dataset contains no rows.
    """
    df = normalize_dataframe(pd.read_parquet(data_path))
    if df.empty:
        raise ValueError(f"No rows to evaluate in {data_path}")
    if sample_size is not None and 0 < sample_size < len(df):
        # The sample keeps the original (now shuffled) row labels; give it a
        # contiguous 0..n-1 index in case the dataset looks rows up by label.
        # NOTE(review): confirm how SQLCodeBERTDataset indexes the frame.
        df = df.sample(n=sample_size, random_state=seed).reset_index(drop=True)

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    dataset = SQLCodeBERTDataset(df, tokenizer)

    # Newer transformers releases take the tokenizer as ``processing_class``;
    # older ones only know ``tokenizer``. Choose from the signature rather
    # than retrying on any TypeError, which would also hide real errors
    # raised by the constructor.
    tokenizer_kw = (
        "processing_class"
        if "processing_class" in inspect.signature(Trainer.__init__).parameters
        else "tokenizer"
    )

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Without explicit args, Trainer falls back to
        # TrainingArguments(output_dir="tmp_trainer") in the current working
        # directory with the library's default reporting integrations. Use a
        # throwaway directory and disable reporting for evaluation.
        training_args = TrainingArguments(
            output_dir=tmp_dir,
            per_device_eval_batch_size=batch_size,
            report_to="none",
            seed=seed,
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            data_collator=SQLCodeBERTDataCollator(tokenizer),
            compute_metrics=build_compute_metrics(threshold=threshold),
            **{tokenizer_kw: tokenizer},
        )
        output = trainer.predict(dataset)

    metrics = compute_multilabel_metrics(
        output.predictions, output.label_ids, threshold=threshold
    )
    print(json.dumps(metrics, indent=2))
    return metrics


def main(argv: list[str] | None = None) -> None:
    """Parse command-line arguments and run :func:`evaluate`.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes ``argparse``
            read ``sys.argv[1:]``, matching the previous behavior.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate CodeBERT SQL classifier",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--model-dir",
        type=Path,
        default=DEFAULT_MODEL,
        help="Directory containing the saved model and tokenizer.",
    )
    parser.add_argument(
        "--data",
        type=Path,
        default=DEFAULT_DATA,
        help="Parquet file to evaluate on.",
    )
    parser.add_argument(
        "--sample-size",
        type=int,
        default=10_000,
        help="Maximum number of rows to evaluate; larger datasets are randomly sampled.",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.5,
        help="Decision threshold passed to the metric computation.",
    )
    # evaluate() already accepts a seed; expose it so the row sample is
    # reproducible/adjustable from the command line.
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed used when sampling rows.",
    )
    args = parser.parse_args(argv)
    evaluate(
        model_dir=args.model_dir,
        data_path=args.data,
        sample_size=args.sample_size,
        threshold=args.threshold,
        seed=args.seed,
    )


if __name__ == "__main__":
    main()