sql-error-classifier / src /codebert_dataset.py
nishu08's picture
Deploy CodeBERT inference Space
8a3099e verified
"""PyTorch Dataset and preprocessing for CodeBERT Trainer."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizerBase
from src.codebert_formatting import format_cross_encoder_input
from src.codebert_labels import label_to_multihot, load_codebert_labels
def normalize_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """Map project column names to the canonical training schema.

    The aliases ``query``, ``correct_query`` and ``label_name`` are renamed to
    ``student_sql``, ``correct_sql`` and ``error_labels`` respectively. An
    alias is left untouched when its canonical column is already present, so
    the result never contains two columns with the same name.

    Args:
        df: Input frame using canonical and/or alias column names.

    Returns:
        A new DataFrame; ``df`` itself is not modified.

    Raises:
        ValueError: If a required canonical column is still missing after the
            aliases have been applied.
    """
    col_map = {
        "query": "student_sql",
        "correct_query": "correct_sql",
        "label_name": "error_labels",
    }
    # Renaming an alias onto an already-present canonical column would create
    # duplicate column labels, after which a lookup by that name returns
    # several values instead of one. Keep the canonical column in that case.
    renames = {
        alias: canonical
        for alias, canonical in col_map.items()
        if alias in df.columns and canonical not in df.columns
    }
    out = df.rename(columns=renames).copy()
    required = ["question", "schema", "student_sql", "correct_sql", "error_labels"]
    missing = [c for c in required if c not in out.columns]
    if missing:
        raise ValueError(
            f"Dataset missing required columns: {missing}. "
            f"Expected {required} (or aliases query/correct_query/label_name)."
        )
    return out
class SQLCodeBERTDataset(Dataset):
    """Map-style dataset that tokenizes one SQL error example per lookup.

    Rows are tokenized on access rather than up front. Each item is the
    tokenizer output with an added ``labels`` entry derived from the row's
    ``error_labels`` value.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        tokenizer: PreTrainedTokenizerBase,
        label_list: Optional[List[str]] = None,
        max_length: int = 512,
    ):
        """Store the normalized frame, the tokenizer and the label vocabulary.

        Args:
            df: Source rows; passed through ``normalize_dataframe`` and
                re-indexed from zero.
            tokenizer: Tokenizer called on each formatted example.
            label_list: Label names. When ``None`` or empty, the result of
                ``load_codebert_labels()`` is used instead.
            max_length: Truncation limit handed to the tokenizer.
        """
        canonical = normalize_dataframe(df)
        self.df = canonical.reset_index(drop=True)
        self.tokenizer = tokenizer
        # Any falsy ``label_list`` (None or empty) falls back to the default.
        self.label_list = label_list if label_list else load_codebert_labels()
        self.max_length = max_length
        self.num_labels = len(self.label_list)

    def __len__(self) -> int:
        """Return the number of rows in the stored frame."""
        return len(self.df)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Tokenize the row at position ``idx`` and attach its labels.

        Returns:
            The tokenizer output (unpadded, truncated to ``max_length``) with
            ``labels`` set to the ``.tolist()`` form of the multi-hot vector.
        """
        record = self.df.iloc[idx]
        # Every text field is coerced with ``str`` before formatting.
        text_fields = {
            name: str(record[name])
            for name in ("question", "schema", "student_sql", "correct_sql")
        }
        item = self.tokenizer(
            format_cross_encoder_input(**text_fields),
            truncation=True,
            max_length=self.max_length,
            padding=False,  # padding is deferred to batch collation
            return_tensors=None,
        )
        multihot = label_to_multihot(str(record["error_labels"]), self.label_list)
        item["labels"] = multihot.tolist()
        return item
class SQLCodeBERTDataCollator:
    """Pad batches dynamically for Trainer.

    Token fields are padded through ``tokenizer.pad``; the per-example
    ``labels`` lists are stacked separately into a float tensor.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        """Keep the tokenizer whose ``pad`` method is used for batching."""
        self.tokenizer = tokenizer

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Collate ``features`` into one padded batch.

        Args:
            features: Per-example mappings, each holding the tokenizer fields
                plus a ``labels`` entry.

        Returns:
            The object returned by ``tokenizer.pad`` with an added ``labels``
            float tensor.
        """
        labels = [f["labels"] for f in features]
        # Build label-free copies instead of popping from the caller's dicts:
        # mutating the inputs makes a second pass over cached features fail
        # with KeyError.
        inputs = [
            {key: value for key, value in f.items() if key != "labels"}
            for f in features
        ]
        batch = self.tokenizer.pad(inputs, padding=True, return_tensors="pt")
        batch["labels"] = torch.tensor(labels, dtype=torch.float)
        return batch
def prepare_datasets(
    df: pd.DataFrame,
    tokenizer: PreTrainedTokenizerBase,
    test_size: float = 0.1,
    val_size: float = 0.1,
    max_length: int = 512,
    seed: int = 42,
    label_list: Optional[List[str]] = None,
) -> tuple[SQLCodeBERTDataset, SQLCodeBERTDataset, SQLCodeBERTDataset]:
    """Split ``df`` into train/validation/test datasets.

    Both splits are stratified on the ``error_labels`` column and use the same
    ``seed``. The test split is carved out first; the validation split is then
    taken from the remainder.

    Args:
        df: Source rows; normalized with ``normalize_dataframe``.
        tokenizer: Tokenizer handed to each ``SQLCodeBERTDataset``.
        test_size: Fraction of all rows reserved for the test split.
        val_size: Fraction of all rows reserved for the validation split.
        max_length: Truncation limit handed to each dataset.
        seed: Random state used for both splits.
        label_list: Optional label names forwarded to each dataset; ``None``
            keeps each dataset's own default.

    Returns:
        A ``(train, val, test)`` tuple of ``SQLCodeBERTDataset`` objects.

    Raises:
        ValueError: If the split fractions are not usable, or if the frame is
            missing required columns.
    """
    from sklearn.model_selection import train_test_split

    # Fail before any splitting work when the fractions cannot yield three
    # non-empty parts.
    if not 0 < test_size < 1:
        raise ValueError(f"test_size must be in (0, 1), got {test_size!r}.")
    # val_size is a fraction of the full frame, but the second split only sees
    # the rows left after the test split, so rescale it.
    relative_val = val_size / (1 - test_size)
    if not 0 < relative_val < 1:
        raise ValueError(
            f"val_size={val_size!r} with test_size={test_size!r} leaves no "
            "valid train/validation split; need val_size > 0 and "
            "val_size + test_size < 1."
        )

    df = normalize_dataframe(df)
    trainval, test_df = train_test_split(
        df,
        test_size=test_size,
        random_state=seed,
        stratify=df["error_labels"],
    )
    train_df, val_df = train_test_split(
        trainval,
        test_size=relative_val,
        random_state=seed,
        stratify=trainval["error_labels"],
    )
    return (
        SQLCodeBERTDataset(train_df, tokenizer, label_list=label_list, max_length=max_length),
        SQLCodeBERTDataset(val_df, tokenizer, label_list=label_list, max_length=max_length),
        SQLCodeBERTDataset(test_df, tokenizer, label_list=label_list, max_length=max_length),
    )