---
language: en
license: mit
tags:
  - codebert
  - sql
  - education
  - text-classification
  - cross-encoder
base_model: microsoft/codebert-base
pipeline_tag: text-classification
---

# SQL CodeBERT Cross-Encoder

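Multi-label SQL error classifier built on **microsoft/codebert-base** and used as a cross-encoder: the question, schema, student query, and reference query are encoded together as one sequence, and the model predicts which error labels apply to the student query.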

## Input Format

All four fields are concatenated into one sequence, each under its own header:

```
QUESTION:
{question}

SCHEMA:
{schema}

STUDENT_SQL:
{student_sql}

CORRECT_SQL:
{correct_sql}
```
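The layout above can be assembled with a few lines of Python. This is only a minimal sketch: `build_input` is a hypothetical helper name, not a function from this repository, and the exact whitespace handling used during training is an assumption.

```python
def build_input(question: str, schema: str, student_sql: str, correct_sql: str) -> str:
    """Join the four fields into the single-sequence layout shown above.

    Hypothetical helper for illustration; the repository's own
    preprocessing may differ in whitespace handling or truncation.
    """
    sections = [
        ("QUESTION", question),
        ("SCHEMA", schema),
        ("STUDENT_SQL", student_sql),
        ("CORRECT_SQL", correct_sql),
    ]
    # Each section is "HEADER:\nvalue"; sections are separated by a blank line.
    return "\n\n".join(f"{name}:\n{value}" for name, value in sections)


text = build_input(
    question="What is the average score per department?",
    schema="students(id, score, department_id)",
    student_sql="SELECT department_id, SUM(score) FROM students GROUP BY department_id",
    correct_sql="SELECT department_id, AVG(score) FROM students GROUP BY department_id",
)
print(text)
```

Keeping the headers in a fixed order matters for a cross-encoder, since the model only ever sees the fields in the arrangement it was trained on.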

## Labels

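The classifier is multi-label, so a single student query can receive any subset of these nine labels:

`JOIN_ERROR`, `AGGREGATION_ERROR`, `FILTER_ERROR`, `WINDOW_FUNCTION_ERROR`,
`SUBQUERY_ERROR`, `NULL_HANDLING_ERROR`, `PERFORMANCE_ERROR`, `LOGICAL_ERROR`, `SYNTAX_ERROR`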
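In a multi-label setup, each label is typically scored independently with a sigmoid and kept if it clears a threshold. The sketch below illustrates that decoding step in plain Python. The label order, the 0.5 threshold, the `decode_labels` name, and the example logits are all assumptions made for illustration, not values taken from this model.

```python
import math

# Assumed to match the order listed above; the trained model's
# actual id-to-label mapping may differ.
LABELS = [
    "JOIN_ERROR", "AGGREGATION_ERROR", "FILTER_ERROR", "WINDOW_FUNCTION_ERROR",
    "SUBQUERY_ERROR", "NULL_HANDLING_ERROR", "PERFORMANCE_ERROR",
    "LOGICAL_ERROR", "SYNTAX_ERROR",
]


def decode_labels(logits: list[float], threshold: float = 0.5) -> list[str]:
    """Apply an independent sigmoid per label and keep those at or above threshold."""
    if len(logits) != len(LABELS):
        raise ValueError(f"expected {len(LABELS)} logits, got {len(logits)}")
    probabilities = [1.0 / (1.0 + math.exp(-z)) for z in logits]
    return [label for label, p in zip(LABELS, probabilities) if p >= threshold]


# Made-up logits, one per label, purely for demonstration.
example_logits = [-2.0, 3.1, -0.5, -4.0, -3.2, -1.7, -2.8, 0.4, -5.0]
predicted = decode_labels(example_logits)
print(predicted)  # → ['AGGREGATION_ERROR', 'LOGICAL_ERROR']
```

Unlike a softmax, the per-label sigmoid lets zero, one, or several labels fire for the same query.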

## Training

```bash
python -m src.hf_train_codebert \
  --data data/sql_errors_1m.parquet \
  --output-dir models/codebert-cross-encoder \
  --epochs 3 \
  --push-to-hub \
  --hub-model-id YOUR_USERNAME/sql-codebert-cross-encoder
```

Replace `YOUR_USERNAME` with your Hugging Face Hub username or organization.

## Inference

```python
from src.hf_predict_codebert import CodeBERTSQLErrorClassifier

# Load the fine-tuned classifier from the Hub (replace YOUR_USERNAME).
clf = CodeBERTSQLErrorClassifier("YOUR_USERNAME/sql-codebert-cross-encoder")

# The student query uses SUM where the reference query uses AVG.
result = clf.predict(
    question="What is the average score per department?",
    schema="students(id, score, department_id)",
    student_sql="SELECT department_id, SUM(score) FROM students GROUP BY department_id",
    correct_sql="SELECT department_id, AVG(score) FROM students GROUP BY department_id",
)

# Predicted error labels for the student query.
print(result["error_labels"])
```