# sql-error-classifier / src/hf_predict_codebert.py
# Last commit by nishu08: "Deploy CodeBERT inference Space" (7aae828, verified)
"""Inference for CodeBERT SQL error cross-encoder."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import List, Optional, Union
import numpy as np
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from src.codebert_formatting import format_cross_encoder_input, sql_queries_equivalent
from src.device_utils import get_device
from src.codebert_labels import load_codebert_labels, multihot_to_label_names
from src.hf_metrics import sigmoid
# Repository root: this module lives one directory below it (in ``src/``).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Default location of the fine-tuned cross-encoder checkpoint.
DEFAULT_MODEL_DIR = PROJECT_ROOT.joinpath("models", "codebert-cross-encoder")
def _is_hub_id(model_dir: Union[str, Path]) -> bool:
    """Return True when *model_dir* should be treated as a Hugging Face Hub id.

    A reference counts as a Hub id when its string form contains "/" and
    nothing exists at that path on disk; anything else is a local path.
    """
    ref = str(model_dir)
    if "/" not in ref:
        return False
    return not Path(ref).exists()
class CodeBERTSQLErrorClassifier:
    """CodeBERT cross-encoder inference wrapper.

    Scores a (question, schema, student SQL, reference SQL) example with a
    sequence-classification model, applies a per-label sigmoid, and reports
    the labels whose probability passes the decision threshold
    (multi-label classification).
    """

    def __init__(
        self,
        model_dir: Union[str, Path] = DEFAULT_MODEL_DIR,
        threshold: float = 0.5,
        device: Optional[str] = None,
    ):
        """Load label metadata, tokenizer and model.

        Args:
            model_dir: Local model directory, or a Hugging Face Hub id (any
                string containing "/" that does not exist on disk).
            threshold: Default per-label decision threshold. NOTE: when a
                local ``label_config.json`` contains ``"threshold"``, that
                value overrides this argument; pass ``threshold=`` to
                :meth:`predict` / :meth:`predict_batch` to force a value.
            device: Torch device string. Defaults to "cuda" when available,
                otherwise "cpu".
        """
        self.hub_id = str(model_dir) if _is_hub_id(model_dir) else None
        self.model_dir = None if self.hub_id else Path(model_dir)
        self.threshold = threshold
        # MPS inference can be flaky for some ops; CPU is reliable on Mac,
        # so only CUDA is ever auto-selected.
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Defaults, used for Hub ids and for local dirs without a config file.
        self.label_list = load_codebert_labels()
        self.max_length = 512
        if self.model_dir is not None:
            config_path = self.model_dir / "label_config.json"
            if config_path.exists():
                with open(config_path, encoding="utf-8") as f:
                    cfg = json.load(f)
                self.label_list = cfg.get("labels", self.label_list)
                self.threshold = cfg.get("threshold", threshold)
                self.max_length = cfg.get("max_length", 512)

        model_ref = self.hub_id or str(self.model_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(model_ref)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_ref
        ).to(self.device)
        self.model.eval()

    def predict(
        self,
        question: str,
        schema: str,
        student_sql: str,
        correct_sql: str,
        threshold: Optional[float] = None,
        top_k: int = 5,
    ) -> dict:
        """Classify the errors in a single student query.

        Args:
            question, schema, student_sql, correct_sql: Fields combined by
                ``format_cross_encoder_input`` into the model input.
            threshold: Per-call decision threshold; ``None`` uses
                ``self.threshold``.
            top_k: Number of highest-probability labels listed in ``"top_k"``.

        Returns:
            A dict with keys ``error_labels`` (labels selected by
            ``multihot_to_label_names`` at the threshold), ``probabilities``
            (every label, highest first), ``top_k``, ``primary_label``,
            ``primary_confidence`` and ``match_detected``. If
            ``sql_queries_equivalent`` judges the two queries equivalent, the
            model is skipped and a NO_ERROR result is returned.
        """
        thr = self.threshold if threshold is None else threshold
        if sql_queries_equivalent(student_sql, correct_sql):
            return self._match_result(top_k)
        text = format_cross_encoder_input(
            question=question,
            schema=schema,
            student_sql=student_sql,
            correct_sql=correct_sql,
        )
        return self._build_result(self._score([text])[0], thr, top_k)

    def predict_batch(
        self,
        examples: List[dict],
        batch_size: int = 16,
        threshold: Optional[float] = None,
        top_k: int = 5,
    ) -> List[dict]:
        """Classify many examples, running the model ``batch_size`` at a time.

        Each example needs the keys ``question``, ``schema``, ``student_sql``
        and ``correct_sql``. Results come back in input order with the same
        shape and semantics as :meth:`predict`. That includes the
        equivalent-query short-circuit and the NO_ERROR primary label.
        """
        thr = self.threshold if threshold is None else threshold
        results: List[Optional[dict]] = [None] * len(examples)
        # (position in ``examples``, formatted model input) for the examples
        # that actually need a forward pass.
        pending = []
        for idx, ex in enumerate(examples):
            if sql_queries_equivalent(ex["student_sql"], ex["correct_sql"]):
                results[idx] = self._match_result(top_k)
                continue
            text = format_cross_encoder_input(
                question=ex["question"],
                schema=ex["schema"],
                student_sql=ex["student_sql"],
                correct_sql=ex["correct_sql"],
            )
            pending.append((idx, text))

        for start in range(0, len(pending), batch_size):
            chunk = pending[start : start + batch_size]
            rows = self._score([text for _, text in chunk])
            for (idx, _), probs in zip(chunk, rows):
                results[idx] = self._build_result(probs, thr, top_k)
        return results

    def _score(self, texts: List[str]) -> list:
        """Run the model on *texts*; return one sigmoid-probability row per text."""
        encoded = self.tokenizer(
            texts,
            truncation=True,
            max_length=self.max_length,
            padding=True,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            logits = self.model(**encoded).logits.cpu().numpy()
        return [sigmoid(row) for row in logits]

    def _build_result(self, probs: np.ndarray, thr: float, top_k: int) -> dict:
        """Turn one row of label probabilities into the public result dict."""
        ranked = sorted(
            zip(self.label_list, probs.tolist()),
            key=lambda pair: pair[1],
            reverse=True,
        )
        # Primary label comes from the full ranking, so top_k <= 0 is safe.
        primary_label, primary_confidence = self._primary(ranked, thr)
        return {
            "error_labels": multihot_to_label_names(
                probs, self.label_list, threshold=thr
            ),
            "probabilities": {name: float(p) for name, p in ranked},
            "top_k": [
                {"label": name, "probability": float(p)}
                for name, p in ranked[:top_k]
            ],
            "primary_label": primary_label,
            "primary_confidence": primary_confidence,
            "match_detected": False,
        }

    def _match_result(self, top_k: int) -> dict:
        """Result returned when the student query matches the reference."""
        return {
            "error_labels": [],
            "probabilities": {name: 0.0 for name in self.label_list},
            "top_k": [
                {"label": name, "probability": 0.0}
                for name in self.label_list[:top_k]
            ],
            "primary_label": "NO_ERROR",
            "primary_confidence": 1.0,
            "match_detected": True,
        }

    @staticmethod
    def _primary(ranked: list, thr: float) -> tuple:
        """Pick (label, confidence) from a descending (label, prob) ranking.

        The top label wins if it reaches *thr*. Otherwise the result is
        NO_ERROR with confidence ``1 - top probability``, or 1.0 when there
        are no labels.
        """
        if ranked and ranked[0][1] >= thr:
            return ranked[0][0], float(ranked[0][1])
        top_prob = ranked[0][1] if ranked else 0.0
        return "NO_ERROR", 1.0 - float(top_prob)
def main() -> None:
    """Command-line entry point: classify one example and print the result as JSON."""
    parser = argparse.ArgumentParser(description="CodeBERT SQL error inference")
    # type=str (not Path): on Windows, Path("org/model") renders the "/" of a
    # Hub id as "\\", and it would then no longer be recognised as a Hub id.
    parser.add_argument("--model-dir", type=str, default=DEFAULT_MODEL_DIR)
    parser.add_argument("--question", type=str, required=True)
    parser.add_argument("--schema", type=str, required=True)
    parser.add_argument("--student-sql", type=str, required=True)
    parser.add_argument("--correct-sql", type=str, required=True)
    parser.add_argument(
        "--threshold",
        type=float,
        default=None,
        help="Decision threshold (default: the model's label_config.json "
        "value if present, else 0.5).",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=5,
        help="Number of highest-probability labels to list (default: 5).",
    )
    args = parser.parse_args()
    clf = CodeBERTSQLErrorClassifier(args.model_dir)
    # Pass the threshold per call. The constructor lets label_config.json
    # override its own ``threshold`` argument, so the flag would be ignored.
    result = clf.predict(
        question=args.question,
        schema=args.schema,
        student_sql=args.student_sql,
        correct_sql=args.correct_sql,
        threshold=args.threshold,
        top_k=args.top_k,
    )
    print(json.dumps(result, indent=2))


if __name__ == "__main__":
    main()