"""Inference for CodeBERT SQL error cross-encoder.""" from __future__ import annotations import argparse import json from pathlib import Path from typing import List, Optional, Union import numpy as np import torch from transformers import AutoModelForSequenceClassification, AutoTokenizer from src.codebert_formatting import format_cross_encoder_input, sql_queries_equivalent from src.device_utils import get_device from src.codebert_labels import load_codebert_labels, multihot_to_label_names from src.hf_metrics import sigmoid PROJECT_ROOT = Path(__file__).resolve().parent.parent DEFAULT_MODEL_DIR = PROJECT_ROOT / "models" / "codebert-cross-encoder" def _is_hub_id(model_dir: Union[str, Path]) -> bool: text = str(model_dir) local = Path(text) return "/" in text and not local.exists() class CodeBERTSQLErrorClassifier: """CodeBERT cross-encoder inference wrapper.""" def __init__( self, model_dir: Union[str, Path] = DEFAULT_MODEL_DIR, threshold: float = 0.5, device: Optional[str] = None, ): self.hub_id = str(model_dir) if _is_hub_id(model_dir) else None self.model_dir = Path(model_dir) if not self.hub_id else None self.threshold = threshold # MPS inference can be flaky for some ops; CPU is reliable on Mac self.device = device or ( "cuda" if torch.cuda.is_available() else "cpu" ) model_ref = self.hub_id or str(self.model_dir) if self.hub_id: self.label_list = load_codebert_labels() self.max_length = 512 else: config_path = self.model_dir / "label_config.json" if config_path.exists(): with open(config_path) as f: cfg = json.load(f) self.label_list = cfg.get("labels", load_codebert_labels()) self.threshold = cfg.get("threshold", threshold) self.max_length = cfg.get("max_length", 512) else: self.label_list = load_codebert_labels() self.max_length = 512 self.tokenizer = AutoTokenizer.from_pretrained(model_ref) self.model = AutoModelForSequenceClassification.from_pretrained( model_ref ).to(self.device) self.model.eval() def predict( self, question: str, schema: str, student_sql: str, correct_sql: str, threshold: Optional[float] = None, top_k: int = 5, ) -> dict: thr = threshold if threshold is not None else self.threshold if sql_queries_equivalent(student_sql, correct_sql): return { "error_labels": [], "probabilities": {name: 0.0 for name in self.label_list}, "top_k": [ {"label": name, "probability": 0.0} for name in self.label_list[:5] ], "primary_label": "NO_ERROR", "primary_confidence": 1.0, "match_detected": True, } text = format_cross_encoder_input( question=question, schema=schema, student_sql=student_sql, correct_sql=correct_sql, ) encoded = self.tokenizer( text, truncation=True, max_length=self.max_length, padding=True, return_tensors="pt", ).to(self.device) with torch.no_grad(): logits = self.model(**encoded).logits.cpu().numpy()[0] probs = sigmoid(logits) predicted = multihot_to_label_names(probs, self.label_list, threshold=thr) ranked = sorted( zip(self.label_list, probs.tolist()), key=lambda x: x[1], reverse=True, )[:top_k] top_label, top_prob = ranked[0] if top_prob >= thr: primary_label = top_label primary_confidence = float(top_prob) else: primary_label = "NO_ERROR" primary_confidence = 1.0 - float(top_prob) return { "error_labels": predicted, "probabilities": {name: float(p) for name, p in ranked}, "top_k": [ {"label": name, "probability": float(p)} for name, p in ranked ], "primary_label": primary_label, "primary_confidence": primary_confidence, "match_detected": False, } def predict_batch( self, examples: List[dict], batch_size: int = 16, ) -> List[dict]: results = [] for i in range(0, len(examples), batch_size): chunk = examples[i : i + batch_size] texts = [ format_cross_encoder_input( question=x["question"], schema=x["schema"], student_sql=x["student_sql"], correct_sql=x["correct_sql"], ) for x in chunk ] encoded = self.tokenizer( texts, truncation=True, max_length=self.max_length, padding=True, return_tensors="pt", ).to(self.device) with torch.no_grad(): logits = self.model(**encoded).logits.cpu().numpy() for j, row in enumerate(logits): probs = sigmoid(row) results.append( { "error_labels": multihot_to_label_names( probs, self.label_list, self.threshold ), "primary_label": self.label_list[int(np.argmax(probs))], "primary_confidence": float(np.max(probs)), } ) return results def main() -> None: parser = argparse.ArgumentParser(description="CodeBERT SQL error inference") parser.add_argument("--model-dir", type=Path, default=DEFAULT_MODEL_DIR) parser.add_argument("--question", type=str, required=True) parser.add_argument("--schema", type=str, required=True) parser.add_argument("--student-sql", type=str, required=True) parser.add_argument("--correct-sql", type=str, required=True) parser.add_argument("--threshold", type=float, default=0.5) args = parser.parse_args() clf = CodeBERTSQLErrorClassifier(args.model_dir, threshold=args.threshold) result = clf.predict( question=args.question, schema=args.schema, student_sql=args.student_sql, correct_sql=args.correct_sql, ) print(json.dumps(result, indent=2)) if __name__ == "__main__": main()