"""Evaluate a trained CodeBERT cross-encoder.""" from __future__ import annotations import argparse import json from pathlib import Path import pandas as pd from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer from src.codebert_dataset import SQLCodeBERTDataCollator, SQLCodeBERTDataset, normalize_dataframe from src.hf_metrics import build_compute_metrics, compute_multilabel_metrics PROJECT_ROOT = Path(__file__).resolve().parent.parent DEFAULT_DATA = PROJECT_ROOT / "data" / "sql_errors_dev.parquet" DEFAULT_MODEL = PROJECT_ROOT / "models" / "codebert-cross-encoder" def evaluate( model_dir: Path = DEFAULT_MODEL, data_path: Path = DEFAULT_DATA, sample_size: int = 10_000, threshold: float = 0.5, seed: int = 42, ) -> dict: df = normalize_dataframe(pd.read_parquet(data_path)) if len(df) > sample_size: df = df.sample(n=sample_size, random_state=seed) tokenizer = AutoTokenizer.from_pretrained(model_dir) model = AutoModelForSequenceClassification.from_pretrained(model_dir) dataset = SQLCodeBERTDataset(df, tokenizer) trainer_kwargs = dict( model=model, data_collator=SQLCodeBERTDataCollator(tokenizer), compute_metrics=build_compute_metrics(threshold=threshold), ) try: trainer = Trainer(processing_class=tokenizer, **trainer_kwargs) except TypeError: trainer = Trainer(tokenizer=tokenizer, **trainer_kwargs) output = trainer.predict(dataset) metrics = compute_multilabel_metrics( output.predictions, output.label_ids, threshold=threshold ) print(json.dumps(metrics, indent=2)) return metrics def main() -> None: parser = argparse.ArgumentParser(description="Evaluate CodeBERT SQL classifier") parser.add_argument("--model-dir", type=Path, default=DEFAULT_MODEL) parser.add_argument("--data", type=Path, default=DEFAULT_DATA) parser.add_argument("--sample-size", type=int, default=10_000) parser.add_argument("--threshold", type=float, default=0.5) args = parser.parse_args() evaluate( model_dir=args.model_dir, data_path=args.data, sample_size=args.sample_size, threshold=args.threshold, ) if __name__ == "__main__": main()