| """ |
| Dataset classes for LLM training. |
| |
| TextDataset: Sliding window (stride 1) over a memory-mapped uint16 binary file. |
| PackedDataset: Non-overlapping windows (stride = seq_len) over the same file format. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
| from typing import Tuple, Union |
|
|
| import numpy as np |
| import torch |
| from torch.utils.data import Dataset |
|
|
|
|
class TextDataset(Dataset):
    """
    Sliding-window dataset over a memory-mapped numpy uint16 binary token file.

    Each sample is a (input_ids, targets) pair of length seq_len, where
    targets is input_ids shifted by one position. Windows overlap by
    (seq_len - 1) tokens, i.e. stride = 1.

    Tensors are returned as ``torch.int64`` (long): embedding lookups and
    ``nn.CrossEntropyLoss`` targets both require long tensors, and uint16
    has no torch equivalent anyway.

    Args:
        data_path: Path to the .bin file produced by data/prepare.py.
        seq_len: Number of tokens per sample (context length). Must be >= 1.

    Raises:
        FileNotFoundError: If ``data_path`` does not exist.
        ValueError: If ``seq_len`` < 1 or the file holds fewer than
            ``seq_len + 1`` tokens.
    """

    def __init__(self, data_path: Union[str, Path], seq_len: int) -> None:
        super().__init__()
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        self.seq_len = seq_len
        path = Path(data_path)
        if not path.exists():
            raise FileNotFoundError(f"Data file not found: {path}")

        # Memory-map so arbitrarily large corpora can be sampled without
        # loading them into RAM.
        self.data: np.ndarray = np.memmap(path, dtype="uint16", mode="r")

        # Best-effort kernel hint: stride-1 windows are read sequentially.
        # ``_mmap`` is a private numpy attribute and madvise is POSIX-only,
        # so any failure is deliberately ignored.
        import mmap as _mmap
        try:
            self.data._mmap.madvise(_mmap.MADV_SEQUENTIAL)
        except (AttributeError, OSError):
            pass
        if len(self.data) < seq_len + 1:
            raise ValueError(
                f"Data file has only {len(self.data)} tokens, "
                f"need at least {seq_len + 1}."
            )

    def __len__(self) -> int:
        # One sample per valid window start; the last start still leaves
        # room for seq_len inputs plus one shifted target token.
        return len(self.data) - self.seq_len

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Read seq_len + 1 tokens so targets can be input_ids shifted by one.
        chunk = self.data[idx : idx + self.seq_len + 1]
        # astype(int64) copies out of the read-only memmap and widens to
        # torch.long, the dtype required by embeddings and CE-loss targets.
        chunk = torch.from_numpy(chunk.astype(np.int64))
        input_ids = chunk[:-1]
        targets = chunk[1:]
        return input_ids, targets
|
|
|
|
class PackedDataset(Dataset):
    """
    Non-overlapping packed dataset over a memory-mapped uint16 binary token file.

    Intended for data that has already been packed (documents concatenated with
    EOS tokens). Windows do not overlap; stride = seq_len.

    The target sequence is shifted by one token relative to input_ids. Interior
    windows borrow the first token of the *next* window as their final target;
    only the very last window has no following token, so its final target slot
    is filled with -1. Pass ``ignore_index=-1`` to ``nn.CrossEntropyLoss``
    explicitly — its default ignore_index is -100, not -1 — so that slot is
    excluded from the loss.

    Tensors are returned as ``torch.int64`` (long), the dtype required by
    embeddings and ``nn.CrossEntropyLoss`` targets.

    Args:
        data_path: Path to the .bin file produced by data/prepare.py.
        seq_len: Number of tokens per sample (context length). Must be >= 1.

    Raises:
        FileNotFoundError: If ``data_path`` does not exist.
        ValueError: If ``seq_len`` < 1 or the file holds fewer than
            ``seq_len`` tokens.
    """

    def __init__(self, data_path: Union[str, Path], seq_len: int) -> None:
        super().__init__()
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        self.seq_len = seq_len
        path = Path(data_path)
        if not path.exists():
            raise FileNotFoundError(f"Data file not found: {path}")
        self.data: np.ndarray = np.memmap(path, dtype="uint16", mode="r")

        # Best-effort kernel hints: a shuffled DataLoader touches windows in
        # random order (MADV_RANDOM disables readahead) and the whole file
        # will eventually be needed (MADV_WILLNEED). ``_mmap`` is a private
        # numpy attribute and madvise is POSIX-only; failures are ignored.
        import mmap as _mmap
        try:
            self.data._mmap.madvise(_mmap.MADV_RANDOM)
            self.data._mmap.madvise(_mmap.MADV_WILLNEED)
        except (AttributeError, OSError):
            pass
        if len(self.data) < seq_len:
            raise ValueError(
                f"Data file has only {len(self.data)} tokens, "
                f"need at least {seq_len}."
            )

    def __len__(self) -> int:
        # Full non-overlapping windows only; a trailing partial window is dropped.
        return len(self.data) // self.seq_len

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        start = idx * self.seq_len
        end = start + self.seq_len

        # astype(int64) copies out of the read-only memmap and widens to
        # torch.long in one step.
        input_ids = torch.from_numpy(
            self.data[start:end].astype(np.int64)
        )

        if end < len(self.data):
            # Interior window: the final target is the next window's first token.
            targets = torch.from_numpy(
                self.data[start + 1 : end + 1].astype(np.int64)
            )
        else:
            # Last window: no token follows, so pad the final target with -1
            # (callers must set ignore_index=-1 on the loss).
            targets = torch.full((self.seq_len,), fill_value=-1, dtype=torch.int64)
            if self.seq_len > 1:
                targets[: self.seq_len - 1] = torch.from_numpy(
                    self.data[start + 1 : end].astype(np.int64)
                )

        return input_ids, targets
|
|