Spaces:
Sleeping
Sleeping
File size: 1,124 Bytes
"""Cross-encoder input formatting for CodeBERT."""
from __future__ import annotations
import re
# Section labels for the cross-encoder input text. format_cross_encoder_input
# writes each tag on its own line, directly above the field it introduces.
QUESTION_TAG = "QUESTION:"
SCHEMA_TAG = "SCHEMA:"
STUDENT_TAG = "STUDENT_SQL:"
CORRECT_TAG = "CORRECT_SQL:"
def normalize_sql(sql: str) -> str:
    """Return a canonical form of *sql* for textual equality checks.

    The normalization is purely textual (no SQL parsing):

    * leading and trailing whitespace is removed;
    * trailing semicolons are removed, together with any whitespace
      before or between them, so ``"SELECT 1 ;"`` and ``"SELECT 1"``
      normalize to the same string;
    * every run of whitespace is collapsed to a single space;
    * the whole string is lower-cased.

    Note that lower-casing and whitespace collapsing apply to the entire
    text, including the contents of quoted string literals.

    Args:
        sql: Raw SQL text.

    Returns:
        The normalized SQL string (possibly empty).
    """
    text = sql.strip()
    # Peel terminators one layer at a time: dropping a run of semicolons can
    # expose whitespace (e.g. "SELECT 1 ;"), and dropping that whitespace can
    # expose further semicolons (e.g. "SELECT 1; ;").
    while text.endswith(";"):
        text = text.rstrip(";").rstrip()
    return re.sub(r"\s+", " ", text).lower()
def sql_queries_equivalent(student_sql: str, correct_sql: str) -> bool:
    """Return True when both queries are identical after ``normalize_sql``.

    This is a textual comparison of the normalized strings, not a semantic
    equivalence check of the queries themselves.
    """
    student_normalized = normalize_sql(student_sql)
    reference_normalized = normalize_sql(correct_sql)
    return student_normalized == reference_normalized
def format_cross_encoder_input(
    question: str,
    schema: str,
    student_sql: str,
    correct_sql: str,
) -> str:
    """
    Build one text sequence containing all four fields.

    Each field is stripped of surrounding whitespace and placed on the line
    after its tag (QUESTION_TAG, SCHEMA_TAG, STUDENT_TAG, CORRECT_TAG, in
    that order). Consecutive sections are separated by one blank line, and
    the result has no trailing newline.
    """
    labelled_fields = (
        (QUESTION_TAG, question),
        (SCHEMA_TAG, schema),
        (STUDENT_TAG, student_sql),
        (CORRECT_TAG, correct_sql),
    )
    sections = [f"{tag}\n{body.strip()}" for tag, body in labelled_fields]
    return "\n\n".join(sections)
|