sql-error-classifier / src /codebert_formatting.py
nishu08's picture
Deploy CodeBERT inference Space
7aae828 verified
Raw
History Blame Contribute Delete
1.12 kB
"""Cross-encoder input formatting for CodeBERT."""
from __future__ import annotations
import re
# Section labels used by format_cross_encoder_input: each one is written on
# its own line immediately before the corresponding field's text.
QUESTION_TAG = "QUESTION:"
SCHEMA_TAG = "SCHEMA:"
STUDENT_TAG = "STUDENT_SQL:"
CORRECT_TAG = "CORRECT_SQL:"
def normalize_sql(sql: str) -> str:
    """Normalize SQL for equality checks (whitespace, case, trailing semicolon).

    The result has no leading/trailing whitespace, no trailing semicolons,
    every run of whitespace collapsed to a single space, and all characters
    lowercased.

    Note: lowercasing is applied to the whole text, so string literals that
    differ only by case normalize to the same value.

    Args:
        sql: Raw SQL text.

    Returns:
        The normalized SQL string (possibly empty).
    """
    text = sql.strip()
    # Remove trailing semicolons together with any whitespace in front of
    # them, so "SELECT 1 ;" and "SELECT 1; ;" normalize like "SELECT 1".
    # A plain rstrip(";") would leave the whitespace before the semicolon.
    while text.endswith(";"):
        text = text[:-1].rstrip()
    return re.sub(r"\s+", " ", text).lower()
def sql_queries_equivalent(student_sql: str, correct_sql: str) -> bool:
    """Return True when both queries are equal after ``normalize_sql``.

    This compares the normalized text of the two queries; it does not
    execute or parse them.
    """
    student_normalized = normalize_sql(student_sql)
    reference_normalized = normalize_sql(correct_sql)
    return student_normalized == reference_normalized
def format_cross_encoder_input(
    question: str,
    schema: str,
    student_sql: str,
    correct_sql: str,
) -> str:
    """Build the single text sequence used as CodeBERT cross-encoder input.

    Each field is stripped of surrounding whitespace and placed on the line
    after its tag; the four tagged sections (question, schema, student SQL,
    reference SQL, in that order) are separated by one blank line.
    """
    labelled_fields = (
        (QUESTION_TAG, question),
        (SCHEMA_TAG, schema),
        (STUDENT_TAG, student_sql),
        (CORRECT_TAG, correct_sql),
    )
    sections = [f"{tag}\n{value.strip()}" for tag, value in labelled_fields]
    return "\n\n".join(sections)