---
language: en
license: mit
tags:
- codebert
- sql
- education
- text-classification
- cross-encoder
base_model: microsoft/codebert-base
pipeline_tag: text-classification
---

# SQL CodeBERT Cross-Encoder

Multi-label SQL error classifier using **microsoft/codebert-base** as a cross-encoder.

## Input Format

All fields are concatenated into one sequence:

```
QUESTION: {question}
SCHEMA: {schema}
STUDENT_SQL: {student_sql}
CORRECT_SQL: {correct_sql}
```

## Labels

`JOIN_ERROR`, `AGGREGATION_ERROR`, `FILTER_ERROR`, `WINDOW_FUNCTION_ERROR`, `SUBQUERY_ERROR`, `NULL_HANDLING_ERROR`, `PERFORMANCE_ERROR`, `LOGICAL_ERROR`, `SYNTAX_ERROR`

## Training

```bash
python -m src.hf_train_codebert \
  --data data/sql_errors_1m.parquet \
  --output-dir models/codebert-cross-encoder \
  --epochs 3 \
  --push-to-hub \
  --hub-model-id YOUR_USERNAME/sql-codebert-cross-encoder
```

## Inference

```python
from src.hf_predict_codebert import CodeBERTSQLErrorClassifier

clf = CodeBERTSQLErrorClassifier("YOUR_USERNAME/sql-codebert-cross-encoder")

result = clf.predict(
    question="What is the average score per department?",
    schema="students(id, score, department_id)",
    student_sql="SELECT department_id, SUM(score) FROM students GROUP BY department_id",
    correct_sql="SELECT department_id, AVG(score) FROM students GROUP BY department_id",
)
print(result["error_labels"])
```
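The `predict` call returns `error_labels`, a list of the labels that fired. Because the task is multi-label, each label is typically scored independently with a sigmoid rather than a softmax, so a query can carry several errors at once (or none). The sketch below shows roughly what input construction and label decoding could look like. It is a minimal illustration, not the actual `src.hf_predict_codebert` code: `build_input`, `decode_labels`, the newline separator between fields, the assumption that logits follow the order of the Labels list, and the 0.5 threshold are all hypothetical.

```python
import math
from typing import Dict, List, Sequence

# Label order copied from the "Labels" section. ASSUMPTION: this matches the
# index order of the model's output logits; check the model config to confirm.
LABELS: List[str] = [
    "JOIN_ERROR",
    "AGGREGATION_ERROR",
    "FILTER_ERROR",
    "WINDOW_FUNCTION_ERROR",
    "SUBQUERY_ERROR",
    "NULL_HANDLING_ERROR",
    "PERFORMANCE_ERROR",
    "LOGICAL_ERROR",
    "SYNTAX_ERROR",
]


def build_input(question: str, schema: str, student_sql: str, correct_sql: str) -> str:
    """Join all fields into one cross-encoder sequence (hypothetical helper).

    ASSUMPTION: fields are separated by newlines, as in the Input Format template.
    """
    return "\n".join([
        f"QUESTION: {question}",
        f"SCHEMA: {schema}",
        f"STUDENT_SQL: {student_sql}",
        f"CORRECT_SQL: {correct_sql}",
    ])


def _sigmoid(z: float) -> float:
    # Numerically stable logistic function (avoids overflow for large |z|).
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def decode_labels(logits: Sequence[float], threshold: float = 0.5) -> Dict[str, object]:
    """Multi-label decoding: independent sigmoid per label, then a threshold.

    ASSUMPTION: 0.5 is used as the default decision threshold.
    """
    if len(logits) != len(LABELS):
        raise ValueError(f"expected {len(LABELS)} logits, got {len(logits)}")
    probs = [_sigmoid(z) for z in logits]
    return {
        "probabilities": dict(zip(LABELS, probs)),
        "error_labels": [lab for lab, p in zip(LABELS, probs) if p >= threshold],
    }


if __name__ == "__main__":
    text = build_input(
        question="What is the average score per department?",
        schema="students(id, score, department_id)",
        student_sql="SELECT department_id, SUM(score) FROM students GROUP BY department_id",
        correct_sql="SELECT department_id, AVG(score) FROM students GROUP BY department_id",
    )
    print(text)
    # Made-up logits for illustration: only the AGGREGATION_ERROR slot is positive.
    logits = [-3.0, 2.5, -2.0, -4.0, -3.5, -2.5, -3.0, -1.0, -5.0]
    print(decode_labels(logits)["error_labels"])  # ['AGGREGATION_ERROR']
```

Thresholding each probability independently is what lets `error_labels` hold zero, one, or several entries. In practice the threshold could be tuned per label on a validation set.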