File size: 3,807 Bytes
9b2cded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""PyTorch Dataset and preprocessing for CodeBERT Trainer."""

from __future__ import annotations

from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizerBase

from src.codebert_formatting import format_cross_encoder_input
from src.codebert_labels import label_to_multihot, load_codebert_labels


def normalize_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* with project column names mapped to the training schema.

    The aliases ``query`` -> ``student_sql``, ``correct_query`` -> ``correct_sql``
    and ``label_name`` -> ``error_labels`` are renamed. An alias is only renamed
    when its canonical column is not already present; renaming it anyway would
    create duplicate column labels, and ``row["student_sql"]`` would then yield
    a Series instead of a scalar. The input frame is not modified.

    Raises:
        ValueError: if any canonical column is still missing after renaming.
    """
    col_map = {
        "query": "student_sql",
        "correct_query": "correct_sql",
        "label_name": "error_labels",
    }
    renames = {
        alias: canonical
        for alias, canonical in col_map.items()
        if alias in df.columns and canonical not in df.columns
    }
    out = df.rename(columns=renames).copy()

    required = ["question", "schema", "student_sql", "correct_sql", "error_labels"]
    missing = [c for c in required if c not in out.columns]
    if missing:
        raise ValueError(
            f"Dataset missing required columns: {missing}. "
            f"Expected {required} (or aliases query/correct_query/label_name)."
        )
    return out


class SQLCodeBERTDataset(Dataset):
    """Map-style dataset of tokenized SQL error examples for the HF ``Trainer``.

    Each item is the tokenizer's un-padded encoding of one row, rendered as
    text by ``format_cross_encoder_input``. A ``labels`` list is added to it,
    built by ``label_to_multihot`` over ``label_list``. Padding is left to a
    collator.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        tokenizer: PreTrainedTokenizerBase,
        label_list: Optional[List[str]] = None,
        max_length: int = 512,
    ):
        normalized = normalize_dataframe(df)
        # Positional 0..n-1 index so iloc lookups line up with dataset indices.
        self.df = normalized.reset_index(drop=True)
        self.tokenizer = tokenizer
        # Any falsy label_list (None or an empty list) falls back to the project labels.
        self.label_list = label_list if label_list else load_codebert_labels()
        self.max_length = max_length
        self.num_labels = len(self.label_list)

    def __len__(self) -> int:
        return self.df.shape[0]

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        record = self.df.iloc[idx]
        text_fields = {
            name: str(record[name])
            for name in ("question", "schema", "student_sql", "correct_sql")
        }
        prompt = format_cross_encoder_input(**text_fields)

        features = self.tokenizer(
            prompt,
            return_tensors=None,
            padding=False,
            max_length=self.max_length,
            truncation=True,
        )
        target = label_to_multihot(str(record["error_labels"]), self.label_list)
        features["labels"] = target.tolist()
        return features


class SQLCodeBERTDataCollator:
    """Pad per-example feature dicts into a tensor batch for the HF ``Trainer``.

    The input feature dicts are never modified. ``labels`` are read out, and a
    copy of each feature without ``labels`` is passed to ``tokenizer.pad``. The
    labels are returned as a float tensor, the dtype multi-label BCE losses
    expect. If no feature carries ``labels`` (e.g. at prediction time), the
    batch is returned without a ``labels`` entry.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        self.tokenizer = tokenizer

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # Only collect labels when present; if some but not all features carry
        # them, f["labels"] raises KeyError just as the pop-based version did.
        labels = (
            [f["labels"] for f in features]
            if any("labels" in f for f in features)
            else None
        )
        # Shallow copies without labels, so cached/re-used features keep theirs.
        inputs = [{k: v for k, v in f.items() if k != "labels"} for f in features]
        batch = self.tokenizer.pad(inputs, padding=True, return_tensors="pt")
        if labels is not None:
            batch["labels"] = torch.tensor(labels, dtype=torch.float)
        return batch


def prepare_datasets(
    df: pd.DataFrame,
    tokenizer: PreTrainedTokenizerBase,
    test_size: float = 0.1,
    val_size: float = 0.1,
    max_length: int = 512,
    seed: int = 42,
    label_list: Optional[List[str]] = None,
) -> tuple[SQLCodeBERTDataset, SQLCodeBERTDataset, SQLCodeBERTDataset]:
    """Split *df* into train / validation / test ``SQLCodeBERTDataset`` objects.

    The test split is carved off first (``test_size`` of all rows). The
    validation split is then taken from the remainder, sized so it is
    ``val_size`` of the original row count. Both splits are stratified on
    ``error_labels`` when scikit-learn can do so. If stratification fails
    (e.g. a label string that occurs only once), that split falls back to an
    unstratified one and a ``UserWarning`` is emitted.

    Args:
        df: Raw dataframe; column aliases are resolved by ``normalize_dataframe``.
        tokenizer: Tokenizer passed to every dataset.
        test_size: Fraction of rows for the test split, in (0, 1).
        val_size: Fraction of rows for the validation split, in (0, 1).
        max_length: Truncation length passed to every dataset.
        seed: ``random_state`` for both splits.
        label_list: Label vocabulary shared by all three datasets; loaded once
            via ``load_codebert_labels`` when not given.

    Returns:
        ``(train, val, test)`` datasets.

    Raises:
        ValueError: if required columns are missing or the split sizes are invalid.
    """
    import warnings

    from sklearn.model_selection import train_test_split

    if not (0 < test_size < 1 and 0 < val_size < 1 and test_size + val_size < 1):
        raise ValueError(
            f"Invalid split sizes: test_size={test_size}, val_size={val_size}; "
            "both must be in (0, 1) and sum to less than 1."
        )

    def _split(frame: pd.DataFrame, size: float):
        # Stratify on the label string when possible; rare label combinations
        # make sklearn reject stratification, so retry without it.
        try:
            return train_test_split(
                frame,
                test_size=size,
                random_state=seed,
                stratify=frame["error_labels"],
            )
        except ValueError as err:
            warnings.warn(
                f"Stratified split on 'error_labels' failed ({err}); "
                "falling back to an unstratified split.",
                stacklevel=3,
            )
            return train_test_split(frame, test_size=size, random_state=seed)

    df = normalize_dataframe(df)
    trainval, test_df = _split(df, test_size)
    # Rescale so the validation split is val_size of the ORIGINAL dataframe.
    relative_val = val_size / (1 - test_size)
    train_df, val_df = _split(trainval, relative_val)

    # Load labels once so all three datasets share the same label ordering.
    labels = label_list or load_codebert_labels()
    return (
        SQLCodeBERTDataset(train_df, tokenizer, label_list=labels, max_length=max_length),
        SQLCodeBERTDataset(val_df, tokenizer, label_list=labels, max_length=max_length),
        SQLCodeBERTDataset(test_df, tokenizer, label_list=labels, max_length=max_length),
    )