SQL CodeBERT Cross-Encoder
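A multi-label classifier that tags a student's SQL query with one or more error categories. It uses microsoft/codebert-base as a cross-encoder: the question, schema, student query, and correct query are encoded together as a single sequence.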
Input Format
All four fields are concatenated into a single input sequence, in this order:
QUESTION:
{question}
SCHEMA:
{schema}
STUDENT_SQL:
{student_sql}
CORRECT_SQL:
{correct_sql}
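The template above can be sketched as a small formatting helper. This is a minimal sketch, assuming each marker sits on its own line followed by its value and that sections are joined with single newlines; the name `build_input` is hypothetical, and the actual preprocessing in `src.hf_train_codebert` may use different separators or truncation.

```python
def build_input(question: str, schema: str, student_sql: str, correct_sql: str) -> str:
    """Concatenate the four fields into one cross-encoder input string.

    Hypothetical helper: assumes newline separators exactly as in the
    template; the real training script may format the sequence differently.
    """
    return (
        f"QUESTION:\n{question}\n"
        f"SCHEMA:\n{schema}\n"
        f"STUDENT_SQL:\n{student_sql}\n"
        f"CORRECT_SQL:\n{correct_sql}"
    )
```

The resulting string is what the CodeBERT tokenizer would receive. Because codebert-base accepts at most 512 tokens, long schemas or queries may need truncating.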
Labels
JOIN_ERROR, AGGREGATION_ERROR, FILTER_ERROR, WINDOW_FUNCTION_ERROR,
SUBQUERY_ERROR, NULL_HANDLING_ERROR, PERFORMANCE_ERROR, LOGICAL_ERROR, SYNTAX_ERROR
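Because the task is multi-label, each example's labels can be represented as a multi-hot target vector with one slot per error category. This is a minimal sketch, assuming the label order matches the list above; `LABELS` and `to_multi_hot` are hypothetical names, not taken from the repository.

```python
# Assumed label order: identical to the list in this model card.
LABELS = [
    "JOIN_ERROR", "AGGREGATION_ERROR", "FILTER_ERROR", "WINDOW_FUNCTION_ERROR",
    "SUBQUERY_ERROR", "NULL_HANDLING_ERROR", "PERFORMANCE_ERROR", "LOGICAL_ERROR",
    "SYNTAX_ERROR",
]
LABEL_TO_INDEX = {name: i for i, name in enumerate(LABELS)}


def to_multi_hot(active: list[str]) -> list[float]:
    """Turn a list of active error labels into a multi-hot target vector."""
    vector = [0.0] * len(LABELS)
    for name in active:
        if name not in LABEL_TO_INDEX:
            raise ValueError(f"unknown label: {name}")
        vector[LABEL_TO_INDEX[name]] = 1.0
    return vector
```

For example, `to_multi_hot(["AGGREGATION_ERROR", "LOGICAL_ERROR"])` sets slots 1 and 7 to `1.0` and leaves the other seven at `0.0`. A query with no errors maps to an all-zero vector.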
Training
python -m src.hf_train_codebert \
--data data/sql_errors_1m.parquet \
--output-dir models/codebert-cross-encoder \
--epochs 3 \
--push-to-hub \
--hub-model-id YOUR_USERNAME/sql-codebert-cross-encoder
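The training objective is not documented here. A common choice for multi-label classification is an independent sigmoid per label with binary cross-entropy, averaged over labels; the sketch below shows that loss for a single example, assuming the model emits one raw logit per label. It computes the same quantity as PyTorch's `BCEWithLogitsLoss` with mean reduction, but `bce_with_logits` is a hypothetical name and is not claimed to be what `src.hf_train_codebert` uses.

```python
import math


def bce_with_logits(logits: list[float], targets: list[float]) -> float:
    """Mean binary cross-entropy over labels, computed from raw logits."""
    if len(logits) != len(targets):
        raise ValueError("logits and targets must have the same length")
    total = 0.0
    for x, y in zip(logits, targets):
        # Numerically stable form of
        # -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)
```

A logit of 0 gives a loss of ln 2 for that label regardless of the target, while confident correct logits drive the loss toward zero.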
Inference
from src.hf_predict_codebert import CodeBERTSQLErrorClassifier
clf = CodeBERTSQLErrorClassifier("YOUR_USERNAME/sql-codebert-cross-encoder")
# The student query uses SUM where the question asks for an average (AVG).
result = clf.predict(
    question="What is the average score per department?",
    schema="students(id, score, department_id)",
    student_sql="SELECT department_id, SUM(score) FROM students GROUP BY department_id",
    correct_sql="SELECT department_id, AVG(score) FROM students GROUP BY department_id",
)
print(result["error_labels"])
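How `predict` turns model outputs into `error_labels` is not documented in this card. One plausible approach for a multi-label head is to apply a sigmoid to each logit and keep every label above a threshold. The sketch below assumes a threshold of 0.5 and the label order listed above; `decode_error_labels` is a hypothetical helper, not part of `src.hf_predict_codebert`.

```python
import math

# Assumed label order: identical to the list in this model card.
LABELS = [
    "JOIN_ERROR", "AGGREGATION_ERROR", "FILTER_ERROR", "WINDOW_FUNCTION_ERROR",
    "SUBQUERY_ERROR", "NULL_HANDLING_ERROR", "PERFORMANCE_ERROR", "LOGICAL_ERROR",
    "SYNTAX_ERROR",
]


def sigmoid(x: float) -> float:
    """Logistic function, written to avoid overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def decode_error_labels(logits: list[float], threshold: float = 0.5) -> list[str]:
    """Return every label whose sigmoid probability meets the threshold."""
    if len(logits) != len(LABELS):
        raise ValueError(f"expected {len(LABELS)} logits, got {len(logits)}")
    return [name for name, x in zip(LABELS, logits) if sigmoid(x) >= threshold]
```

Since the labels are scored independently, a query can receive several labels or none; raising `threshold` trades recall for precision.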
Model tree for nishu08/sql-codebert-classifier
Base model: microsoft/codebert-base