Spaces:
Running
Running
| """Evaluate a trained CodeBERT cross-encoder.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| import pandas as pd | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer | |
| from src.codebert_dataset import SQLCodeBERTDataCollator, SQLCodeBERTDataset, normalize_dataframe | |
| from src.hf_metrics import build_compute_metrics, compute_multilabel_metrics | |
# Repository root: the directory two levels above this file.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Default evaluation split and model checkpoint, both located under the root.
DEFAULT_DATA = PROJECT_ROOT.joinpath("data", "sql_errors_dev.parquet")
DEFAULT_MODEL = PROJECT_ROOT.joinpath("models", "codebert-cross-encoder")
def evaluate(
    model_dir: Path = DEFAULT_MODEL,
    data_path: Path = DEFAULT_DATA,
    sample_size: int | None = 10_000,
    threshold: float = 0.5,
    seed: int = 42,
) -> dict:
    """Score a trained CodeBERT cross-encoder on a labelled parquet split.

    The split is loaded, passed through ``normalize_dataframe``, optionally
    subsampled, wrapped in ``SQLCodeBERTDataset`` and run through a Hugging
    Face ``Trainer`` in prediction mode. Metrics are computed once from the
    raw predictions with ``compute_multilabel_metrics``, printed as indented
    JSON and returned.

    Args:
        model_dir: Directory holding the saved model and tokenizer (anything
            ``from_pretrained`` accepts).
        data_path: Parquet file with the evaluation rows.
        sample_size: Maximum number of rows to evaluate. When the split has
            more rows, a random subsample of this size is drawn. ``None`` or
            a non-positive value disables subsampling and uses every row.
        threshold: Decision threshold forwarded to
            ``compute_multilabel_metrics``.
        seed: ``random_state`` for the subsample, so repeated runs with the
            same seed score the same rows.

    Returns:
        The metrics dict produced by ``compute_multilabel_metrics``.
    """
    import inspect  # local: only needed for the Trainer signature probe below

    df = normalize_dataframe(pd.read_parquet(data_path))
    # Non-positive caps used to reach ``df.sample`` and produce an empty
    # (n=0) or invalid (n<0 -> ValueError) sample; treat them as "no cap".
    if sample_size is not None and 0 < sample_size < len(df):
        df = df.sample(n=sample_size, random_state=seed)

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    dataset = SQLCodeBERTDataset(df, tokenizer)

    # Newer transformers releases take the tokenizer as ``processing_class``;
    # older ones only know ``tokenizer``. Probe the signature instead of
    # catching TypeError: a blanket ``except TypeError`` would also swallow
    # genuine TypeErrors raised inside Trainer.__init__ and silently retry
    # with the wrong keyword.
    trainer_params = inspect.signature(Trainer.__init__).parameters
    tokenizer_kwarg = (
        "processing_class" if "processing_class" in trainer_params else "tokenizer"
    )
    # No ``compute_metrics`` here: the metrics are computed once below from
    # the returned predictions, so having Trainer.predict compute them as
    # well would only duplicate the work.
    trainer = Trainer(
        model=model,
        data_collator=SQLCodeBERTDataCollator(tokenizer),
        **{tokenizer_kwarg: tokenizer},
    )
    output = trainer.predict(dataset)
    metrics = compute_multilabel_metrics(
        output.predictions, output.label_ids, threshold=threshold
    )
    print(json.dumps(metrics, indent=2))
    return metrics
def main() -> None:
    """Parse command-line options and run ``evaluate`` with them.

    Every option mirrors a keyword argument of ``evaluate``. ``--seed`` is
    exposed so the row subsample can be pinned or varied from the CLI.
    """
    parser = argparse.ArgumentParser(description="Evaluate CodeBERT SQL classifier")
    parser.add_argument(
        "--model-dir",
        type=Path,
        default=DEFAULT_MODEL,
        help="Directory containing the saved model and tokenizer.",
    )
    parser.add_argument(
        "--data",
        type=Path,
        default=DEFAULT_DATA,
        help="Parquet file with the evaluation rows.",
    )
    parser.add_argument(
        "--sample-size",
        type=int,
        default=10_000,
        help="Maximum number of rows to evaluate (random subsample).",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.5,
        help="Decision threshold used when computing metrics.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for the row subsample.",
    )
    args = parser.parse_args()
    evaluate(
        model_dir=args.model_dir,
        data_path=args.data,
        sample_size=args.sample_size,
        threshold=args.threshold,
        seed=args.seed,
    )


if __name__ == "__main__":
    main()