"""Inference for CodeBERT SQL error cross-encoder.""" from __future__ import annotations import argparse import json from pathlib import Path from typing import List, Optional, Union import numpy as np import torch from transformers import AutoModelForSequenceClassification, AutoTokenizer from src.codebert_formatting import format_cross_encoder_input from src.codebert_labels import load_codebert_labels, multihot_to_label_names from src.hf_metrics import sigmoid PROJECT_ROOT = Path(__file__).resolve().parent.parent DEFAULT_MODEL_DIR = PROJECT_ROOT / "models" / "codebert-cross-encoder" class CodeBERTSQLErrorClassifier: """CodeBERT cross-encoder inference wrapper.""" def __init__( self, model_dir: Union[str, Path] = DEFAULT_MODEL_DIR, threshold: float = 0.5, device: Optional[str] = None, ): self.model_dir = Path(model_dir) self.threshold = threshold self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") config_path = self.model_dir / "label_config.json" if config_path.exists(): with open(config_path) as f: cfg = json.load(f) self.label_list = cfg.get("labels", load_codebert_labels()) self.threshold = cfg.get("threshold", threshold) self.max_length = cfg.get("max_length", 512) else: self.label_list = load_codebert_labels() self.max_length = 512 self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir) self.model = AutoModelForSequenceClassification.from_pretrained( self.model_dir ).to(self.device) self.model.eval() def predict( self, question: str, schema: str, student_sql: str, correct_sql: str, threshold: Optional[float] = None, top_k: int = 5, ) -> dict: text = format_cross_encoder_input( question=question, schema=schema, student_sql=student_sql, correct_sql=correct_sql, ) encoded = self.tokenizer( text, truncation=True, max_length=self.max_length, padding=True, return_tensors="pt", ).to(self.device) with torch.no_grad(): logits = self.model(**encoded).logits.cpu().numpy()[0] probs = sigmoid(logits) thr = threshold if threshold is not None 
else self.threshold predicted = multihot_to_label_names(probs, self.label_list, threshold=thr) ranked = sorted( zip(self.label_list, probs.tolist()), key=lambda x: x[1], reverse=True, )[:top_k] return { "error_labels": predicted, "probabilities": {name: float(p) for name, p in ranked}, "top_k": [ {"label": name, "probability": float(p)} for name, p in ranked ], "primary_label": ranked[0][0], "primary_confidence": float(ranked[0][1]), } def predict_batch( self, examples: List[dict], batch_size: int = 16, ) -> List[dict]: results = [] for i in range(0, len(examples), batch_size): chunk = examples[i : i + batch_size] texts = [ format_cross_encoder_input( question=x["question"], schema=x["schema"], student_sql=x["student_sql"], correct_sql=x["correct_sql"], ) for x in chunk ] encoded = self.tokenizer( texts, truncation=True, max_length=self.max_length, padding=True, return_tensors="pt", ).to(self.device) with torch.no_grad(): logits = self.model(**encoded).logits.cpu().numpy() for j, row in enumerate(logits): probs = sigmoid(row) results.append( { "error_labels": multihot_to_label_names( probs, self.label_list, self.threshold ), "primary_label": self.label_list[int(np.argmax(probs))], "primary_confidence": float(np.max(probs)), } ) return results def main() -> None: parser = argparse.ArgumentParser(description="CodeBERT SQL error inference") parser.add_argument("--model-dir", type=Path, default=DEFAULT_MODEL_DIR) parser.add_argument("--question", type=str, required=True) parser.add_argument("--schema", type=str, required=True) parser.add_argument("--student-sql", type=str, required=True) parser.add_argument("--correct-sql", type=str, required=True) parser.add_argument("--threshold", type=float, default=0.5) args = parser.parse_args() clf = CodeBERTSQLErrorClassifier(args.model_dir, threshold=args.threshold) result = clf.predict( question=args.question, schema=args.schema, student_sql=args.student_sql, correct_sql=args.correct_sql, ) print(json.dumps(result, indent=2)) 
if __name__ == "__main__":
    # Only launch the CLI when run as a script, never on import.
    main()