---
language: en
license: mit
tags:
  - codebert
  - sql
  - education
  - text-classification
  - cross-encoder
base_model: microsoft/codebert-base
pipeline_tag: text-classification
---

SQL CodeBERT Cross-Encoder

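A multi-label classifier for errors in student SQL queries. It uses microsoft/codebert-base as a cross-encoder: the question, schema, student query and reference query are encoded together as one input rather than separately.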

Input Format

All four fields are concatenated into a single sequence using this template:

```text
QUESTION:
{question}

SCHEMA:
{schema}

STUDENT_SQL:
{student_sql}

CORRECT_SQL:
{correct_sql}
```
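The concatenation can be sketched in a few lines of Python. This assumes one newline after each tag and a blank line between sections, as the template is rendered; `build_cross_encoder_input` is a hypothetical helper, and the repository's own preprocessing is authoritative.

```python
def build_cross_encoder_input(question: str, schema: str, student_sql: str, correct_sql: str) -> str:
    """Join the four fields into one cross-encoder sequence.

    Hypothetical helper: the whitespace layout (tag line, value line,
    blank line between sections) is assumed from the rendered template.
    """
    sections = [
        ("QUESTION", question),
        ("SCHEMA", schema),
        ("STUDENT_SQL", student_sql),
        ("CORRECT_SQL", correct_sql),
    ]
    # Each section is "TAG:\n<value>"; sections are separated by a blank line.
    return "\n\n".join(f"{tag}:\n{value}" for tag, value in sections)


text = build_cross_encoder_input(
    question="What is the average score per department?",
    schema="students(id, score, department_id)",
    student_sql="SELECT department_id, SUM(score) FROM students GROUP BY department_id",
    correct_sql="SELECT department_id, AVG(score) FROM students GROUP BY department_id",
)
print(text)
```

Keeping the section tags in the text lets the encoder tell the student query apart from the reference query even though both are plain SQL.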

Labels

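Because the classifier is multi-label, a single student query can carry more than one of the following nine labels:

JOIN_ERROR, AGGREGATION_ERROR, FILTER_ERROR, WINDOW_FUNCTION_ERROR, SUBQUERY_ERROR, NULL_HANDLING_ERROR, PERFORMANCE_ERROR, LOGICAL_ERROR, SYNTAX_ERROR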
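With multiple labels per query, training targets are naturally multi-hot vectors, and predictions are commonly decoded with an independent sigmoid per label plus a threshold. The sketch below assumes the label order listed above and a 0.5 threshold; neither is confirmed by this card (the model config's `id2label` holds the real order), and `encode_labels` / `decode_logits` are hypothetical helpers, not part of `src.hf_predict_codebert`.

```python
import math

# Assumed label order (same as the list above); verify against the model's id2label.
LABELS = [
    "JOIN_ERROR", "AGGREGATION_ERROR", "FILTER_ERROR", "WINDOW_FUNCTION_ERROR",
    "SUBQUERY_ERROR", "NULL_HANDLING_ERROR", "PERFORMANCE_ERROR", "LOGICAL_ERROR",
    "SYNTAX_ERROR",
]


def encode_labels(active: set) -> list:
    """Multi-hot target: 1.0 for every label present, 0.0 otherwise."""
    unknown = set(active) - set(LABELS)
    if unknown:
        raise ValueError(f"unknown labels: {sorted(unknown)}")
    return [1.0 if name in active else 0.0 for name in LABELS]


def sigmoid(z: float) -> float:
    """Numerically stable logistic function (no overflow for large |z|)."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def decode_logits(logits: list, threshold: float = 0.5) -> list:
    """Score each label independently and keep those at or above the threshold."""
    if len(logits) != len(LABELS):
        raise ValueError("expected one logit per label")
    return [name for name, z in zip(LABELS, logits) if sigmoid(z) >= threshold]


print(encode_labels({"AGGREGATION_ERROR", "LOGICAL_ERROR"}))
print(decode_logits([-3.0, 2.5, -1.0, -4.0, -2.0, -2.0, -5.0, 0.4, -6.0]))
```

A per-label threshold is what allows a prediction to contain several labels, or none, unlike a softmax over mutually exclusive classes.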

Training

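Run from the repository root, replacing YOUR_USERNAME with your Hugging Face username:

```shell
python -m src.hf_train_codebert \
  --data data/sql_errors_1m.parquet \
  --output-dir models/codebert-cross-encoder \
  --epochs 3 \
  --push-to-hub \
  --hub-model-id YOUR_USERNAME/sql-codebert-cross-encoder
```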
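The command does not state which loss `src.hf_train_codebert` optimizes. For a multi-label head the usual choice is a per-label binary cross-entropy on the raw logits, which is what PyTorch's `BCEWithLogitsLoss` computes with its default mean reduction. The pure-Python sketch below shows that objective for one example; treat it as an assumption about the training script, not a description of it. `bce_with_logits` is a hypothetical name.

```python
import math


def bce_with_logits(logits: list, targets: list) -> float:
    """Mean per-label binary cross-entropy computed from raw logits.

    Uses the stable form max(z, 0) - z*y + log(1 + exp(-|z|)), which equals
    -[y*log(sigmoid(z)) + (1 - y)*log(1 - sigmoid(z))] without overflow.
    """
    if len(logits) != len(targets):
        raise ValueError("logits and targets must have the same length")
    total = 0.0
    for z, y in zip(logits, targets):
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


# Nine untrained (zero) logits against a multi-hot target: every label costs log(2).
print(bce_with_logits([0.0] * 9, [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]))
```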

Inference

Run from the repository root so that the src package is importable:

```python
from src.hf_predict_codebert import CodeBERTSQLErrorClassifier

# Hub model ID passed as --hub-model-id during training
clf = CodeBERTSQLErrorClassifier("YOUR_USERNAME/sql-codebert-cross-encoder")
result = clf.predict(
    question="What is the average score per department?",
    schema="students(id, score, department_id)",
    # The student uses SUM where the question asks for an average
    student_sql="SELECT department_id, SUM(score) FROM students GROUP BY department_id",
    correct_sql="SELECT department_id, AVG(score) FROM students GROUP BY department_id",
)
print(result["error_labels"])  # predicted error labels
```