"""Evaluate a trained CodeBERT cross-encoder."""
from __future__ import annotations

import argparse
import inspect
import json
from pathlib import Path

import pandas as pd
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer

from src.codebert_dataset import SQLCodeBERTDataCollator, SQLCodeBERTDataset, normalize_dataframe
from src.hf_metrics import build_compute_metrics, compute_multilabel_metrics
# Project root: the directory two levels above this (symlink-resolved) file.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Default evaluation split, relative to the project root.
DEFAULT_DATA = PROJECT_ROOT.joinpath("data", "sql_errors_dev.parquet")
# Default directory holding the fine-tuned model and its tokenizer.
DEFAULT_MODEL = PROJECT_ROOT.joinpath("models", "codebert-cross-encoder")
def _make_trainer(model, tokenizer, threshold: float) -> Trainer:
    """Build an inference-only ``Trainer`` that works across transformers versions.

    Newer transformers releases take the tokenizer as ``processing_class``;
    older ones only know the (since deprecated) ``tokenizer`` keyword. The
    keyword is chosen by inspecting ``Trainer.__init__`` instead of catching
    ``TypeError``, so a genuine ``TypeError`` raised inside the constructor
    surfaces as-is rather than being retried with the other keyword.
    """
    init_params = inspect.signature(Trainer.__init__).parameters
    tokenizer_kw = "processing_class" if "processing_class" in init_params else "tokenizer"
    return Trainer(
        model=model,
        data_collator=SQLCodeBERTDataCollator(tokenizer),
        compute_metrics=build_compute_metrics(threshold=threshold),
        **{tokenizer_kw: tokenizer},
    )


def evaluate(
    model_dir: Path = DEFAULT_MODEL,
    data_path: Path = DEFAULT_DATA,
    sample_size: int | None = 10_000,
    threshold: float = 0.5,
    seed: int = 42,
) -> dict:
    """Run the saved CodeBERT classifier over a parquet split and report metrics.

    Args:
        model_dir: Directory passed to ``from_pretrained`` for both the
            tokenizer and the sequence-classification model.
        data_path: Parquet file to evaluate on; it is passed through
            ``normalize_dataframe`` before use.
        sample_size: Evaluate on at most this many rows, drawn at random when
            the file has more. ``None`` evaluates every row.
        threshold: Decision threshold forwarded to both
            ``build_compute_metrics`` and ``compute_multilabel_metrics``.
        seed: ``random_state`` for the row sample, so repeated runs score the
            same subset.

    Returns:
        The dict returned by ``compute_multilabel_metrics``; it is also
        printed to stdout as indented JSON.

    Raises:
        ValueError: If ``sample_size`` is not ``None`` and not positive.
    """
    if sample_size is not None and sample_size <= 0:
        # n=0 would hand an empty dataset to predict(); negative n fails deep
        # inside pandas. Reject both up front with a clear message.
        raise ValueError(f"sample_size must be a positive int or None, got {sample_size}")

    df = normalize_dataframe(pd.read_parquet(data_path))
    if sample_size is not None and len(df) > sample_size:
        # Reset to a contiguous 0..n-1 index after shuffling. This is harmless
        # for positional access and protects label-based lookups in
        # SQLCodeBERTDataset (its access pattern is not visible here).
        df = df.sample(n=sample_size, random_state=seed).reset_index(drop=True)

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    dataset = SQLCodeBERTDataset(df, tokenizer)

    trainer = _make_trainer(model, tokenizer, threshold)
    output = trainer.predict(dataset)
    metrics = compute_multilabel_metrics(
        output.predictions, output.label_ids, threshold=threshold
    )
    print(json.dumps(metrics, indent=2))
    return metrics
def main() -> None:
    """Command-line entry point: parse flags and run ``evaluate`` with them.

    Every ``evaluate`` parameter, including ``seed``, is exposed as a flag,
    so the sampled subset is reproducible and adjustable from the CLI.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate CodeBERT SQL classifier",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--model-dir", type=Path, default=DEFAULT_MODEL,
        help="directory with the saved model and tokenizer",
    )
    parser.add_argument(
        "--data", type=Path, default=DEFAULT_DATA,
        help="parquet file to evaluate on",
    )
    parser.add_argument(
        "--sample-size", type=int, default=10_000,
        help="maximum number of rows to evaluate (random subsample when larger)",
    )
    parser.add_argument(
        "--threshold", type=float, default=0.5,
        help="decision threshold forwarded to the metric functions",
    )
    parser.add_argument(
        "--seed", type=int, default=42,
        help="random seed for the row subsample",
    )
    args = parser.parse_args()
    evaluate(
        model_dir=args.model_dir,
        data_path=args.data,
        sample_size=args.sample_size,
        threshold=args.threshold,
        seed=args.seed,
    )


# Only run the CLI when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|