sql-error-classifier-train / src /hf_predict_codebert.py
nishu08's picture
Deploy CodeBERT training Space
9b2cded verified
"""Inference for CodeBERT SQL error cross-encoder."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import List, Optional, Union
import numpy as np
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from src.codebert_formatting import format_cross_encoder_input
from src.codebert_labels import load_codebert_labels, multihot_to_label_names
from src.hf_metrics import sigmoid
# Two levels up from this file: <root>/src/<this file> -> <root>.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Checkpoint directory used when the caller does not supply one.
DEFAULT_MODEL_DIR = PROJECT_ROOT.joinpath("models", "codebert-cross-encoder")
class CodeBERTSQLErrorClassifier:
    """CodeBERT cross-encoder inference wrapper.

    Loads a sequence-classification checkpoint and its tokenizer from
    ``model_dir`` and scores (question, schema, student SQL, correct SQL)
    examples. Each logit row is passed through ``sigmoid`` and the resulting
    per-label scores are compared against a threshold to pick error labels.

    Attributes:
        model_dir: Directory the tokenizer and model are loaded from.
        threshold: Default decision threshold used when a call does not
            override it.
        device: Torch device string the model and inputs are moved to.
        label_list: Label names, index-aligned with the model's logits.
        max_length: Truncation length passed to the tokenizer.
        tokenizer: Tokenizer loaded from ``model_dir``.
        model: Classification model loaded from ``model_dir``, in eval mode.
    """

    def __init__(
        self,
        model_dir: Union[str, Path] = DEFAULT_MODEL_DIR,
        threshold: float = 0.5,
        device: Optional[str] = None,
    ):
        """Load labels, tokenizer and model from ``model_dir``.

        Args:
            model_dir: Directory containing the checkpoint and, optionally,
                a ``label_config.json`` file.
            threshold: Decision threshold. NOTE: a ``"threshold"`` entry in
                ``label_config.json`` takes precedence over this argument;
                use the per-call ``threshold`` of ``predict`` /
                ``predict_batch`` to force a specific value.
            device: Torch device string. When falsy, ``"cuda"`` is used if
                available, otherwise ``"cpu"``.
        """
        self.model_dir = Path(model_dir)
        self.threshold = threshold
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        config_path = self.model_dir / "label_config.json"
        if config_path.exists():
            with open(config_path, encoding="utf-8") as f:
                cfg = json.load(f)
            # Only fall back to the project-wide label loader when the config
            # does not carry its own labels; the loader is not needed otherwise.
            if "labels" in cfg:
                self.label_list = cfg["labels"]
            else:
                self.label_list = load_codebert_labels()
            self.threshold = cfg.get("threshold", threshold)
            self.max_length = cfg.get("max_length", 512)
        else:
            self.label_list = load_codebert_labels()
            self.max_length = 512

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_dir
        ).to(self.device)
        self.model.eval()

    def _logits(self, texts: List[str]) -> np.ndarray:
        """Tokenize ``texts`` and return the model logits, one row per text."""
        encoded = self.tokenizer(
            texts,
            truncation=True,
            max_length=self.max_length,
            padding=True,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            return self.model(**encoded).logits.cpu().numpy()

    def predict(
        self,
        question: str,
        schema: str,
        student_sql: str,
        correct_sql: str,
        threshold: Optional[float] = None,
        top_k: int = 5,
    ) -> dict:
        """Score a single example.

        Args:
            question: Natural-language question.
            schema: Database schema text.
            student_sql: The SQL being assessed.
            correct_sql: The reference SQL.
            threshold: Per-call decision threshold; ``None`` uses
                ``self.threshold``.
            top_k: Number of highest-scoring labels to report. Values below 1
                yield empty ``"probabilities"`` / ``"top_k"`` entries.

        Returns:
            A dict with keys ``"error_labels"`` (labels selected at the
            threshold), ``"probabilities"`` (name -> score, top-k labels
            only), ``"top_k"`` (list of ``{"label", "probability"}`` dicts),
            ``"primary_label"`` and ``"primary_confidence"`` (the single
            highest-scoring label, independent of ``top_k``).
        """
        text = format_cross_encoder_input(
            question=question,
            schema=schema,
            student_sql=student_sql,
            correct_sql=correct_sql,
        )
        probs = sigmoid(self._logits([text])[0])
        thr = threshold if threshold is not None else self.threshold
        predicted = multihot_to_label_names(probs, self.label_list, threshold=thr)

        # Rank every label first so the primary label never depends on top_k
        # (previously top_k=0 raised IndexError and a negative top_k dropped
        # labels from the tail of the ranking).
        ranked_all = sorted(
            zip(self.label_list, probs.tolist()),
            key=lambda pair: pair[1],
            reverse=True,
        )
        ranked = ranked_all[: max(top_k, 0)]
        primary_label, primary_confidence = ranked_all[0]

        return {
            "error_labels": predicted,
            "probabilities": {name: float(p) for name, p in ranked},
            "top_k": [
                {"label": name, "probability": float(p)} for name, p in ranked
            ],
            "primary_label": primary_label,
            "primary_confidence": float(primary_confidence),
        }

    def predict_batch(
        self,
        examples: List[dict],
        batch_size: int = 16,
        threshold: Optional[float] = None,
    ) -> List[dict]:
        """Score many examples in mini-batches.

        Args:
            examples: Dicts with keys ``"question"``, ``"schema"``,
                ``"student_sql"`` and ``"correct_sql"``.
            batch_size: Number of examples per forward pass; must be >= 1.
            threshold: Per-call decision threshold; ``None`` uses
                ``self.threshold``.

        Returns:
            One dict per example, in input order, with keys
            ``"error_labels"``, ``"primary_label"`` and
            ``"primary_confidence"``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        thr = threshold if threshold is not None else self.threshold

        results = []
        for start in range(0, len(examples), batch_size):
            chunk = examples[start : start + batch_size]
            texts = [
                format_cross_encoder_input(
                    question=x["question"],
                    schema=x["schema"],
                    student_sql=x["student_sql"],
                    correct_sql=x["correct_sql"],
                )
                for x in chunk
            ]
            for row in self._logits(texts):
                probs = sigmoid(row)
                results.append(
                    {
                        "error_labels": multihot_to_label_names(
                            probs, self.label_list, threshold=thr
                        ),
                        "primary_label": self.label_list[int(np.argmax(probs))],
                        "primary_confidence": float(np.max(probs)),
                    }
                )
        return results
def main() -> None:
    """Command-line entry point: score one example and print the result as JSON.

    ``--threshold`` is forwarded to ``predict`` as a per-call override. Passing
    it only to the constructor is not enough, because a ``"threshold"`` entry
    in the model's ``label_config.json`` replaces the constructor value and
    the flag would then be silently ignored. When the flag is omitted, the
    classifier's own threshold (config value, else 0.5) is used.
    """
    parser = argparse.ArgumentParser(description="CodeBERT SQL error inference")
    parser.add_argument("--model-dir", type=Path, default=DEFAULT_MODEL_DIR)
    parser.add_argument("--question", type=str, required=True)
    parser.add_argument("--schema", type=str, required=True)
    parser.add_argument("--student-sql", type=str, required=True)
    parser.add_argument("--correct-sql", type=str, required=True)
    parser.add_argument(
        "--threshold",
        type=float,
        default=None,
        help="Decision threshold (default: the model's configured threshold, or 0.5)",
    )
    args = parser.parse_args()

    clf = CodeBERTSQLErrorClassifier(args.model_dir)
    result = clf.predict(
        question=args.question,
        schema=args.schema,
        student_sql=args.student_sql,
        correct_sql=args.correct_sql,
        threshold=args.threshold,
    )
    print(json.dumps(result, indent=2))


if __name__ == "__main__":
    main()