Spaces:
Running
Running
| """Inference for CodeBERT SQL error cross-encoder.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| from typing import List, Optional, Union | |
| import numpy as np | |
| import torch | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| from src.codebert_formatting import format_cross_encoder_input, sql_queries_equivalent | |
| from src.device_utils import get_device | |
| from src.codebert_labels import load_codebert_labels, multihot_to_label_names | |
| from src.hf_metrics import sigmoid | |
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Default location of the fine-tuned checkpoint inside the project tree.
DEFAULT_MODEL_DIR = PROJECT_ROOT / "models" / "codebert-cross-encoder"


def _is_hub_id(model_dir: Union[str, Path]) -> bool:
    """Return True if ``model_dir`` should be loaded as a Hugging Face Hub repo id.

    A Hub repo id has the form ``namespace/name``: exactly one ``/`` with
    non-empty parts on both sides. The following are treated as local
    directories rather than Hub ids:

    * anything that exists on disk;
    * absolute paths;
    * paths with more than one ``/``;
    * paths starting with ``.`` or ``~``, or containing a drive colon.

    This keeps a missing local model directory (e.g. the default
    ``DEFAULT_MODEL_DIR``) from being misrouted to a Hub lookup.

    Args:
        model_dir: Local path or Hub repo id.

    Returns:
        True when the value looks like a Hub id and no such local path exists.
    """
    text = str(model_dir)
    if Path(text).exists():
        return False
    namespace, sep, name = text.partition("/")
    return (
        bool(sep)
        and bool(namespace)  # empty namespace => absolute POSIX path
        and bool(name)
        and "/" not in name  # more than one "/" => a filesystem path
        and not namespace.startswith((".", "~"))
        and ":" not in namespace  # Windows drive letters such as "C:"
    )
class CodeBERTSQLErrorClassifier:
    """CodeBERT cross-encoder inference wrapper.

    Loads a sequence-classification checkpoint (a local directory or a
    Hugging Face Hub id) and turns its logits into per-label probabilities,
    using ``sigmoid``, for a (question, schema, student SQL, correct SQL) input.
    """

    def __init__(
        self,
        model_dir: Union[str, Path] = DEFAULT_MODEL_DIR,
        threshold: Optional[float] = None,
        device: Optional[str] = None,
    ):
        """Load the tokenizer, the model and the label metadata.

        Args:
            model_dir: Local checkpoint directory, or a Hub repo id when
                ``_is_hub_id`` says so.
            threshold: Probability threshold for a label to count as
                predicted. An explicit value always wins. When ``None``, the
                ``threshold`` from ``label_config.json`` is used if present,
                otherwise 0.5.
            device: Torch device string. Defaults to ``"cuda"`` when
                available, else ``"cpu"``.
        """
        self.hub_id = str(model_dir) if _is_hub_id(model_dir) else None
        self.model_dir = Path(model_dir) if not self.hub_id else None
        # MPS inference can be flaky for some ops; CPU is reliable on Mac, so
        # only CUDA is auto-selected and everything else falls back to CPU.
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        cfg = self._read_label_config()
        # The packaged label list is loaded only when the checkpoint's config
        # does not provide one (a dict.get default would load it eagerly).
        self.label_list = cfg["labels"] if "labels" in cfg else load_codebert_labels()
        self.max_length = cfg.get("max_length", 512)
        # Previously the config value silently overrode an explicit argument.
        self.threshold = threshold if threshold is not None else cfg.get("threshold", 0.5)

        model_ref = self.hub_id or str(self.model_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(model_ref)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_ref
        ).to(self.device)
        self.model.eval()

    def _read_label_config(self) -> dict:
        """Return the parsed ``label_config.json`` of a local checkpoint.

        Returns ``{}`` when loading from the Hub or when the file is absent.
        """
        if self.model_dir is None:
            return {}
        config_path = self.model_dir / "label_config.json"
        if not config_path.exists():
            return {}
        with open(config_path, encoding="utf-8") as f:
            return json.load(f)

    def _logits(self, texts: List[str]) -> np.ndarray:
        """Tokenize ``texts`` as one padded batch and return the model logits.

        The result is a NumPy array with one row per input text.
        """
        encoded = self.tokenizer(
            texts,
            truncation=True,
            max_length=self.max_length,
            padding=True,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            return self.model(**encoded).logits.cpu().numpy()

    def predict(
        self,
        question: str,
        schema: str,
        student_sql: str,
        correct_sql: str,
        threshold: Optional[float] = None,
        top_k: int = 5,
    ) -> dict:
        """Classify the errors in one student query.

        Args:
            question: Natural-language question the SQL should answer.
            schema: Database schema text.
            student_sql: Query written by the student.
            correct_sql: Reference query.
            threshold: Per-call override of ``self.threshold``.
            top_k: Number of ranked labels to report under ``"top_k"``.
                Negative values are treated as 0.

        Returns:
            A dict with the following keys:

            * ``error_labels``: labels at or above the threshold.
            * ``probabilities``: every label mapped to its probability,
              highest first.
            * ``top_k``: list of ``{"label", "probability"}`` entries.
            * ``primary_label``: the best label, or ``"NO_ERROR"`` when it is
              below the threshold.
            * ``primary_confidence``: confidence for ``primary_label``.
            * ``match_detected``: True when the queries were judged
              equivalent.
        """
        thr = self.threshold if threshold is None else threshold
        # A negative slice bound would silently drop labels from the end.
        k = max(top_k, 0)
        if sql_queries_equivalent(student_sql, correct_sql):
            # Short-circuit: equivalent queries are reported as correct
            # without running the model.
            return {
                "error_labels": [],
                "probabilities": {name: 0.0 for name in self.label_list},
                "top_k": [
                    {"label": name, "probability": 0.0}
                    for name in self.label_list[:k]
                ],
                "primary_label": "NO_ERROR",
                "primary_confidence": 1.0,
                "match_detected": True,
            }
        text = format_cross_encoder_input(
            question=question,
            schema=schema,
            student_sql=student_sql,
            correct_sql=correct_sql,
        )
        probs = sigmoid(self._logits([text])[0])
        predicted = multihot_to_label_names(probs, self.label_list, threshold=thr)
        # Rank every label. top_k is only a view on this ranking, so
        # primary_label and probabilities do not depend on it (and
        # top_k=0 no longer raises IndexError).
        ranked = sorted(
            zip(self.label_list, probs.tolist()),
            key=lambda item: item[1],
            reverse=True,
        )
        top_prob = float(ranked[0][1]) if ranked else 0.0
        if ranked and top_prob >= thr:
            primary_label, primary_confidence = ranked[0][0], top_prob
        else:
            primary_label, primary_confidence = "NO_ERROR", 1.0 - top_prob
        return {
            "error_labels": predicted,
            "probabilities": {name: float(p) for name, p in ranked},
            "top_k": [
                {"label": name, "probability": float(p)} for name, p in ranked[:k]
            ],
            "primary_label": primary_label,
            "primary_confidence": primary_confidence,
            "match_detected": False,
        }

    def predict_batch(
        self,
        examples: List[dict],
        batch_size: int = 16,
        threshold: Optional[float] = None,
    ) -> List[dict]:
        """Score many examples in mini-batches.

        Unlike ``predict``, this does not short-circuit on SQL equivalence.
        ``primary_label`` is always the arg-max label, regardless of the
        threshold.

        Args:
            examples: Dicts with ``question``, ``schema``, ``student_sql`` and
                ``correct_sql`` keys.
            batch_size: Number of examples per forward pass. Must be >= 1.
            threshold: Per-call override of ``self.threshold``.

        Returns:
            One dict per example, in input order, with ``error_labels``,
            ``primary_label`` and ``primary_confidence``.

        Raises:
            ValueError: If ``batch_size`` is less than 1. A negative step
                would otherwise silently produce no results.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        thr = self.threshold if threshold is None else threshold
        results = []
        for start in range(0, len(examples), batch_size):
            chunk = examples[start : start + batch_size]
            texts = [
                format_cross_encoder_input(
                    question=ex["question"],
                    schema=ex["schema"],
                    student_sql=ex["student_sql"],
                    correct_sql=ex["correct_sql"],
                )
                for ex in chunk
            ]
            for row in self._logits(texts):
                probs = sigmoid(row)
                results.append(
                    {
                        "error_labels": multihot_to_label_names(
                            probs, self.label_list, threshold=thr
                        ),
                        "primary_label": self.label_list[int(np.argmax(probs))],
                        "primary_confidence": float(np.max(probs)),
                    }
                )
        return results
def main() -> None:
    """Command-line entry point: classify one student query and print JSON."""
    parser = argparse.ArgumentParser(description="CodeBERT SQL error inference")
    parser.add_argument("--model-dir", type=Path, default=DEFAULT_MODEL_DIR)
    parser.add_argument("--question", type=str, required=True)
    parser.add_argument("--schema", type=str, required=True)
    parser.add_argument("--student-sql", type=str, required=True)
    parser.add_argument("--correct-sql", type=str, required=True)
    parser.add_argument(
        "--threshold",
        type=float,
        default=None,
        help="Decision threshold (default: label_config.json value, else 0.5).",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=5,
        help="Number of ranked labels to report (default: 5).",
    )
    args = parser.parse_args()

    # The classifier is built without a threshold, so the checkpoint's
    # configured value stays the default. --threshold is applied per call,
    # so an explicit flag is never overridden by label_config.json.
    clf = CodeBERTSQLErrorClassifier(args.model_dir)
    result = clf.predict(
        question=args.question,
        schema=args.schema,
        student_sql=args.student_sql,
        correct_sql=args.correct_sql,
        threshold=args.threshold,
        top_k=args.top_k,
    )
    print(json.dumps(result, indent=2))


if __name__ == "__main__":
    main()