Spaces:
Running
Running
| """PyTorch Dataset and preprocessing for CodeBERT Trainer.""" | |
| from __future__ import annotations | |
| from typing import Any, Dict, List, Optional | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from torch.utils.data import Dataset | |
| from transformers import PreTrainedTokenizerBase | |
| from src.codebert_formatting import format_cross_encoder_input | |
| from src.codebert_labels import label_to_multihot, load_codebert_labels | |
def normalize_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """Map project column names to the canonical training schema.

    The aliases ``query``, ``correct_query`` and ``label_name`` are renamed to
    ``student_sql``, ``correct_sql`` and ``error_labels`` respectively. An
    alias is renamed only when its canonical column is not already present,
    so a frame carrying both never ends up with duplicate column labels
    (which would make ``row["student_sql"]`` return a Series, not a scalar).

    Args:
        df: Input frame. It is not modified.

    Returns:
        A new frame containing at least the canonical required columns.

    Raises:
        ValueError: If any required canonical column is still missing.
    """
    col_map = {
        "query": "student_sql",
        "correct_query": "correct_sql",
        "label_name": "error_labels",
    }
    # Skip aliases whose canonical column already exists, to avoid duplicates.
    renames = {
        alias: canonical
        for alias, canonical in col_map.items()
        if alias in df.columns and canonical not in df.columns
    }
    out = df.rename(columns=renames).copy()
    required = ["question", "schema", "student_sql", "correct_sql", "error_labels"]
    missing = [c for c in required if c not in out.columns]
    if missing:
        raise ValueError(
            f"Dataset missing required columns: {missing}. "
            f"Expected {required} (or aliases query/correct_query/label_name)."
        )
    return out
class SQLCodeBERTDataset(Dataset):
    """Map-style dataset of tokenized SQL-error examples for the HF Trainer.

    Each item is the tokenizer output for one cross-encoder input built from
    the row's question, schema, student SQL and correct SQL. The output is
    truncated to ``max_length`` and left unpadded; padding is deferred to the
    collator. A ``labels`` entry is derived from the row's ``error_labels``.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        tokenizer: PreTrainedTokenizerBase,
        label_list: Optional[List[str]] = None,
        max_length: int = 512,
    ):
        normalized = normalize_dataframe(df)
        # A fresh 0..n-1 index keeps positional lookups in __getitem__ simple.
        self.df = normalized.reset_index(drop=True)
        self.tokenizer = tokenizer
        # Any falsy label_list (None or empty) falls back to the project labels.
        self.label_list = label_list if label_list else load_codebert_labels()
        self.max_length = max_length
        self.num_labels = len(self.label_list)

    def __len__(self) -> int:
        return self.df.shape[0]

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        record = self.df.iloc[idx]
        text_fields = {
            field: str(record[field])
            for field in ("question", "schema", "student_sql", "correct_sql")
        }
        prompt = format_cross_encoder_input(**text_fields)
        encoding = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.max_length,
            padding=False,
            return_tensors=None,
        )
        # presumably a multi-hot array over self.label_list; converted to a
        # plain list via .tolist() — TODO confirm against label_to_multihot.
        target = label_to_multihot(str(record["error_labels"]), self.label_list)
        encoding["labels"] = target.tolist()
        return encoding
class SQLCodeBERTDataCollator:
    """Pad variable-length tokenized examples into a batch for ``Trainer``.

    The ``labels`` entry is kept out of ``tokenizer.pad`` and re-attached
    afterwards as a ``torch.float`` tensor. Features without ``labels``
    (e.g. at inference time) are collated without a ``labels`` key.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        self.tokenizer = tokenizer

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Collate ``features`` into padded ``"pt"`` tensors.

        The input dicts are not mutated, so the same features can be
        collated more than once.

        Raises:
            KeyError: If the first feature has ``labels`` but a later one
                does not.
        """
        has_labels = bool(features) and "labels" in features[0]
        # Build label-free copies instead of popping, so callers' dicts stay intact.
        inputs = [{k: v for k, v in f.items() if k != "labels"} for f in features]
        batch = self.tokenizer.pad(inputs, padding=True, return_tensors="pt")
        if has_labels:
            batch["labels"] = torch.tensor(
                [f["labels"] for f in features], dtype=torch.float
            )
        return batch
def _split_frame(
    frame: pd.DataFrame, test_size: float, seed: int
) -> List[pd.DataFrame]:
    """Split ``frame`` in two, stratifying on ``error_labels`` when sklearn allows it."""
    import warnings

    from sklearn.model_selection import train_test_split

    try:
        return train_test_split(
            frame,
            test_size=test_size,
            random_state=seed,
            stratify=frame["error_labels"],
        )
    except ValueError as err:
        # sklearn refuses to stratify when a label value occurs only once, or
        # when a split would hold fewer rows than there are distinct labels.
        # Rare label combinations make this likely, so fall back to a plain
        # seeded split instead of aborting the whole run.
        warnings.warn(
            f"Stratified split on 'error_labels' failed ({err}); "
            "falling back to an unstratified split.",
            stacklevel=3,
        )
        return train_test_split(frame, test_size=test_size, random_state=seed)


def prepare_datasets(
    df: pd.DataFrame,
    tokenizer: PreTrainedTokenizerBase,
    test_size: float = 0.1,
    val_size: float = 0.1,
    max_length: int = 512,
    seed: int = 42,
) -> tuple[SQLCodeBERTDataset, SQLCodeBERTDataset, SQLCodeBERTDataset]:
    """Split a labelled frame into train/validation/test datasets.

    Args:
        df: Raw frame; column aliases are normalized first.
        tokenizer: Tokenizer shared by all three datasets.
        test_size: Fraction of all rows held out for the test split.
        val_size: Fraction of all rows used for validation.
        max_length: Token truncation length passed to every dataset.
        seed: Random seed used for both splits.

    Returns:
        ``(train, val, test)`` datasets sharing a single label list.

    Raises:
        ValueError: If the sizes are not fractions in (0, 1) summing to less
            than 1, or if required columns are missing.
    """
    if not (0 < test_size < 1 and 0 < val_size < 1 and test_size + val_size < 1):
        raise ValueError(
            f"test_size ({test_size}) and val_size ({val_size}) must each be in "
            "(0, 1) and sum to less than 1."
        )
    df = normalize_dataframe(df)
    trainval, test_df = _split_frame(df, test_size, seed)
    # val_size is relative to the whole frame; rescale it to the remaining rows.
    relative_val = val_size / (1 - test_size)
    train_df, val_df = _split_frame(trainval, relative_val, seed)
    # Load the label vocabulary once, so every split uses the same list.
    label_list = load_codebert_labels()
    return (
        SQLCodeBERTDataset(train_df, tokenizer, label_list=label_list, max_length=max_length),
        SQLCodeBERTDataset(val_df, tokenizer, label_list=label_list, max_length=max_length),
        SQLCodeBERTDataset(test_df, tokenizer, label_list=label_list, max_length=max_length),
    )