---
language: en
license: mit
tags:
- codebert
- sql
- education
- text-classification
- cross-encoder
base_model: microsoft/codebert-base
pipeline_tag: text-classification
---

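# SQL CodeBERT Cross-Encoder

A multi-label classifier for errors in student SQL queries, built on **microsoft/codebert-base** and used as a cross-encoder: the question, schema, student query, and reference query are encoded together as one input sequence.
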
## Input Format

The four fields are concatenated into a single text sequence using this template:

```
QUESTION:
{question}

SCHEMA:
{schema}

STUDENT_SQL:
{student_sql}

CORRECT_SQL:
{correct_sql}
```
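
As a rough illustration of the template above, the sketch below assembles the four fields into one string. The helper name `build_input` is hypothetical and not part of this repository, and stripping surrounding whitespace from each field is an assumption; the packaged inference class may format its input differently.

```python
def build_input(question: str, schema: str, student_sql: str, correct_sql: str) -> str:
    """Join the four fields into the single-sequence template (hypothetical helper)."""
    sections = [
        ("QUESTION", question),
        ("SCHEMA", schema),
        ("STUDENT_SQL", student_sql),
        ("CORRECT_SQL", correct_sql),
    ]
    # Each section is "<NAME>:" on one line and its value on the next;
    # sections are separated by one blank line, as in the template.
    # Assumption: surrounding whitespace in each field is stripped.
    return "\n\n".join(f"{name}:\n{value.strip()}" for name, value in sections)


text = build_input(
    question="What is the average score per department?",
    schema="students(id, score, department_id)",
    student_sql="SELECT department_id, SUM(score) FROM students GROUP BY department_id",
    correct_sql="SELECT department_id, AVG(score) FROM students GROUP BY department_id",
)
print(text)
```

Keeping the formatting in one function makes it easier to guarantee that training and inference see identically structured inputs.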

## Labels

The model can assign any subset of these nine labels to a student query:

`JOIN_ERROR`, `AGGREGATION_ERROR`, `FILTER_ERROR`, `WINDOW_FUNCTION_ERROR`,
`SUBQUERY_ERROR`, `NULL_HANDLING_ERROR`, `PERFORMANCE_ERROR`, `LOGICAL_ERROR`, `SYNTAX_ERROR`
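
Because the task is multi-label, each label is typically scored independently rather than through a softmax over all classes. The sketch below shows one common way to turn per-label logits into a label set. It is an assumption-laden illustration: the label order (taken from the list above), the 0.5 threshold, and the helper name `decode_labels` are not confirmed by this repository, and the logits are made-up values rather than real model output.

```python
import math

# Assumption: label indices follow the order in which the labels are listed above.
LABELS = [
    "JOIN_ERROR", "AGGREGATION_ERROR", "FILTER_ERROR", "WINDOW_FUNCTION_ERROR",
    "SUBQUERY_ERROR", "NULL_HANDLING_ERROR", "PERFORMANCE_ERROR",
    "LOGICAL_ERROR", "SYNTAX_ERROR",
]


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def decode_labels(logits, threshold=0.5):
    """Return every label whose sigmoid probability reaches the threshold."""
    if len(logits) != len(LABELS):
        raise ValueError(f"expected {len(LABELS)} logits, got {len(logits)}")
    return [label for label, z in zip(LABELS, logits) if sigmoid(z) >= threshold]


# Made-up logits for illustration only.
example_logits = [-2.0, 3.1, -1.5, -4.0, -3.2, -2.7, -3.9, 0.4, -5.0]
predicted = decode_labels(example_logits)
print(predicted)  # ['AGGREGATION_ERROR', 'LOGICAL_ERROR']
```

A single query can therefore receive several labels at once, or none if no probability reaches the threshold.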

## Training

```bash
python -m src.hf_train_codebert \
  --data data/sql_errors_1m.parquet \
  --output-dir models/codebert-cross-encoder \
  --epochs 3 \
  --push-to-hub \
  --hub-model-id YOUR_USERNAME/sql-codebert-cross-encoder
```

## Inference

```python
from src.hf_predict_codebert import CodeBERTSQLErrorClassifier

# Replace YOUR_USERNAME with the Hub account that hosts the model.
clf = CodeBERTSQLErrorClassifier("YOUR_USERNAME/sql-codebert-cross-encoder")

result = clf.predict(
    question="What is the average score per department?",
    schema="students(id, score, department_id)",
    # The student query uses SUM where the question asks for an average.
    student_sql="SELECT department_id, SUM(score) FROM students GROUP BY department_id",
    correct_sql="SELECT department_id, AVG(score) FROM students GROUP BY department_id",
)

# Labels predicted for the student query.
print(result["error_labels"])
```