sql-error-classifier / src /hf_eval_codebert.py
nishu08's picture
Deploy CodeBERT inference Space
8a3099e verified
"""Evaluate a trained CodeBERT cross-encoder."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import pandas as pd
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer
from src.codebert_dataset import SQLCodeBERTDataCollator, SQLCodeBERTDataset, normalize_dataframe
from src.hf_metrics import build_compute_metrics, compute_multilabel_metrics
# Parent of the directory that contains this file; the default data and model
# locations below are resolved relative to it.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Parquet file that evaluate() reads when no --data / data_path is supplied.
DEFAULT_DATA = PROJECT_ROOT / "data" / "sql_errors_dev.parquet"
# Directory from which evaluate() loads the tokenizer and model by default.
DEFAULT_MODEL = PROJECT_ROOT / "models" / "codebert-cross-encoder"
def evaluate(
    model_dir: Path = DEFAULT_MODEL,
    data_path: Path = DEFAULT_DATA,
    sample_size: int | None = 10_000,
    threshold: float = 0.5,
    seed: int = 42,
) -> dict:
    """Evaluate a saved sequence-classification model on a parquet dataset.

    The parquet file is loaded, passed through ``normalize_dataframe`` and,
    when it has more rows than ``sample_size``, randomly down-sampled.
    Predictions are produced with a Hugging Face ``Trainer`` and scored with
    ``compute_multilabel_metrics``; the resulting metrics are printed as
    indented JSON and returned.

    Args:
        model_dir: Directory passed to ``from_pretrained`` for both the
            tokenizer and the model.
        data_path: Parquet file holding the evaluation examples.
        sample_size: Maximum number of rows to evaluate. Larger datasets are
            randomly sampled down to this many rows; ``None`` disables
            sampling and evaluates every row.
        threshold: Forwarded unchanged to ``build_compute_metrics`` and
            ``compute_multilabel_metrics`` (presumably the decision threshold
            applied to per-label scores -- confirm in ``src.hf_metrics``).
        seed: ``random_state`` used for the row sampling, so the same subset
            is drawn on every run.

    Returns:
        The metrics mapping produced by ``compute_multilabel_metrics``.
    """
    # Local import keeps this function self-contained; ``inspect`` is only
    # needed to probe the installed Trainer's constructor signature.
    import inspect

    df = normalize_dataframe(pd.read_parquet(data_path))
    if sample_size is not None and len(df) > sample_size:
        df = df.sample(n=sample_size, random_state=seed)

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    dataset = SQLCodeBERTDataset(df, tokenizer)

    # Newer transformers releases take the tokenizer as ``processing_class``;
    # older ones only know ``tokenizer``. Choose the keyword from the actual
    # signature instead of catching TypeError around the constructor, which
    # would also swallow unrelated TypeErrors raised during Trainer setup and
    # replace them with a misleading "unexpected keyword" failure.
    trainer_params = inspect.signature(Trainer.__init__).parameters
    tokenizer_kwarg = (
        "processing_class" if "processing_class" in trainer_params else "tokenizer"
    )
    trainer = Trainer(
        model=model,
        data_collator=SQLCodeBERTDataCollator(tokenizer),
        compute_metrics=build_compute_metrics(threshold=threshold),
        **{tokenizer_kwarg: tokenizer},
    )

    output = trainer.predict(dataset)
    # The returned metrics are computed directly from the raw predictions and
    # labels; the Trainer's own metrics (output.metrics) are not used here.
    metrics = compute_multilabel_metrics(
        output.predictions, output.label_ids, threshold=threshold
    )
    print(json.dumps(metrics, indent=2))
    return metrics
def main() -> None:
    """Parse command-line options and run :func:`evaluate` with them.

    Every tunable parameter of ``evaluate`` is exposed as a flag, including
    ``--seed`` so the sampled evaluation subset can be changed or pinned from
    the command line. Defaults match the defaults of ``evaluate``.
    """
    parser = argparse.ArgumentParser(description="Evaluate CodeBERT SQL classifier")
    parser.add_argument(
        "--model-dir",
        type=Path,
        default=DEFAULT_MODEL,
        help="directory containing the saved tokenizer and model",
    )
    parser.add_argument(
        "--data",
        type=Path,
        default=DEFAULT_DATA,
        help="parquet file with the evaluation examples",
    )
    parser.add_argument(
        "--sample-size",
        type=int,
        default=10_000,
        help="maximum number of rows to evaluate; larger datasets are sampled",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.5,
        help="threshold forwarded to the metric computation",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="random seed used when sampling rows",
    )
    args = parser.parse_args()

    evaluate(
        model_dir=args.model_dir,
        data_path=args.data,
        sample_size=args.sample_size,
        threshold=args.threshold,
        seed=args.seed,
    )


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()