"""Cross-encoder input formatting for CodeBERT.""" from __future__ import annotations import re QUESTION_TAG = "QUESTION:" SCHEMA_TAG = "SCHEMA:" STUDENT_TAG = "STUDENT_SQL:" CORRECT_TAG = "CORRECT_SQL:" def normalize_sql(sql: str) -> str: """Normalize SQL for equality checks (whitespace, case, trailing semicolon).""" text = sql.strip().rstrip(";") return re.sub(r"\s+", " ", text).lower() def sql_queries_equivalent(student_sql: str, correct_sql: str) -> bool: return normalize_sql(student_sql) == normalize_sql(correct_sql) def format_cross_encoder_input( question: str, schema: str, student_sql: str, correct_sql: str, ) -> str: """ Concatenate all fields into a single CodeBERT input sequence. The model attends jointly across question intent, schema, student SQL, and the reference solution — cross-encoder style in one forward pass. """ return ( f"{QUESTION_TAG}\n{question.strip()}\n\n" f"{SCHEMA_TAG}\n{schema.strip()}\n\n" f"{STUDENT_TAG}\n{student_sql.strip()}\n\n" f"{CORRECT_TAG}\n{correct_sql.strip()}" )