| """ |
| DataLoader utilities for SLM training. |
| |
| Provides efficient batching and data loading for training. |
| """ |
|
|
| import os |
| from typing import Dict, Optional, List |
|
|
| import torch |
| from torch.utils.data import DataLoader, Dataset, DistributedSampler |
|
|
| from .dataset import ConversationalDataset, StreamingTextDataset, PackedDataset |
| from .tokenizer import SLMTokenizer |
|
|
|
|
def create_dataloader(
    dataset: Dataset,
    batch_size: int,
    shuffle: bool = True,
    num_workers: int = 4,
    pin_memory: Optional[bool] = None,
    drop_last: bool = True,
    distributed: bool = False,
    world_size: int = 1,
    rank: int = 0,
) -> DataLoader:
    """Create a DataLoader with optimal settings.

    Args:
        dataset: The dataset to load from
        batch_size: Batch size per device
        shuffle: Whether to shuffle data
        num_workers: Number of data loading workers
        pin_memory: Pin memory for faster GPU transfer (defaults to CUDA availability)
        drop_last: Drop last incomplete batch
        distributed: Whether using distributed training
        world_size: Number of distributed processes
        rank: Current process rank

    Returns:
        Configured DataLoader
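
    Example (illustrative; ``train_ds`` is a placeholder for a map-style
    dataset whose samples are dicts of equal-length tensors):
        loader = create_dataloader(train_ds, batch_size=8, num_workers=2)
        batch = next(iter(loader))  # keys: input_ids, attention_mask, labels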
| """ |
| sampler = None |
| if distributed: |
| sampler = DistributedSampler( |
| dataset, |
| num_replicas=world_size, |
| rank=rank, |
| shuffle=shuffle, |
| ) |
| shuffle = False |
|
|
| |
| if pin_memory is None: |
| import torch |
| pin_memory = torch.cuda.is_available() |
|
|
| return DataLoader( |
| dataset, |
| batch_size=batch_size, |
| shuffle=shuffle if sampler is None else False, |
| sampler=sampler, |
| num_workers=num_workers, |
| pin_memory=pin_memory, |
| drop_last=drop_last, |
| collate_fn=default_collate_fn, |
| ) |
|
|
|
|
def default_collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Collate function for batching samples.

    Args:
        batch: List of sample dictionaries

    Returns:
        Batched dictionary with stacked tensors
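
    Note: ``torch.stack`` requires all samples in the batch to share the same
    sequence length, so the upstream dataset is assumed to pad or pack each
    sample to a fixed length.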
| """ |
| return { |
| "input_ids": torch.stack([s["input_ids"] for s in batch]), |
| "attention_mask": torch.stack([s["attention_mask"] for s in batch]), |
| "labels": torch.stack([s["labels"] for s in batch]), |
| } |
|
|
|
|
class DataModule:
    """Data module for managing train/val dataloaders.

    Provides a unified interface for data loading during training.
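
    Example (illustrative; the data and tokenizer paths are placeholders):
        dm = DataModule("data/processed", "tokenizer.json", batch_size=16)
        train_loader = dm.train_dataloader()
        val_loader = dm.val_dataloader()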
| """ |
|
|
| def __init__( |
| self, |
| data_dir: str, |
| tokenizer_path: str, |
| max_length: int = 1024, |
| batch_size: int = 32, |
| num_workers: int = 4, |
| val_batch_size: Optional[int] = None, |
| ): |
| """Initialize data module. |
| |
| Args: |
| data_dir: Directory containing processed data |
| tokenizer_path: Path to tokenizer.json |
| max_length: Maximum sequence length |
| batch_size: Training batch size |
| num_workers: Number of data loading workers |
| val_batch_size: Validation batch size (defaults to batch_size) |
| """ |
| self.data_dir = data_dir |
| self.max_length = max_length |
| self.batch_size = batch_size |
| self.val_batch_size = val_batch_size or batch_size |
| self.num_workers = num_workers |
|
|
| |
| self.tokenizer = SLMTokenizer.from_file(tokenizer_path) |
|
|
| |
| self._train_dataset = None |
| self._val_dataset = None |
|
|
    @property
    def train_dataset(self) -> Dataset:
        """Get or create training dataset."""
        if self._train_dataset is None:
            self._train_dataset = ConversationalDataset(
                data_path=self.data_dir,
                tokenizer=self.tokenizer,
                max_length=self.max_length,
                split="train",
            )
        return self._train_dataset

    @property
    def val_dataset(self) -> Dataset:
        """Get or create validation dataset."""
        if self._val_dataset is None:
            self._val_dataset = ConversationalDataset(
                data_path=self.data_dir,
                tokenizer=self.tokenizer,
                max_length=self.max_length,
                split="val",
            )
        return self._val_dataset

    def train_dataloader(
        self,
        distributed: bool = False,
        world_size: int = 1,
        rank: int = 0,
    ) -> DataLoader:
        """Get training dataloader."""
        return create_dataloader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True,
            distributed=distributed,
            world_size=world_size,
            rank=rank,
        )

    def val_dataloader(self) -> DataLoader:
        """Get validation dataloader."""
        return create_dataloader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
        )


class StreamingDataModule:
    """Data module for streaming large datasets.

    Memory-efficient loading for large text corpora.
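
    Example (illustrative; the file and tokenizer paths are placeholders):
        dm = StreamingDataModule(["corpus/shard_00.txt"], "tokenizer.json")
        loader = dm.train_dataloader()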
| """ |
|
|
| def __init__( |
| self, |
| data_files: List[str], |
| tokenizer_path: str, |
| max_length: int = 1024, |
| batch_size: int = 32, |
| num_workers: int = 4, |
| ): |
| """Initialize streaming data module. |
| |
| Args: |
| data_files: List of text file paths |
| tokenizer_path: Path to tokenizer.json |
| max_length: Maximum sequence length |
| batch_size: Batch size |
| num_workers: Number of data loading workers |
| """ |
| self.data_files = data_files |
| self.max_length = max_length |
| self.batch_size = batch_size |
| self.num_workers = num_workers |
|
|
| |
| self.tokenizer = SLMTokenizer.from_file(tokenizer_path) |
|
|
    def train_dataloader(self) -> DataLoader:
        """Get training dataloader for streaming data."""
        dataset = StreamingTextDataset(
            data_files=self.data_files,
            tokenizer=self.tokenizer,
            max_length=self.max_length,
            shuffle=True,
        )

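        # Note: with a streaming/iterable dataset and num_workers > 0, each
        # worker iterates its own copy of the stream unless the dataset shards
        # work per worker (e.g. via torch.utils.data.get_worker_info()).
        # StreamingTextDataset is assumed to handle this.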
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=default_collate_fn,
        )


def estimate_dataset_tokens(data_dir: str, tokenizer_path: str) -> Dict[str, float]:
    """Estimate total tokens in a dataset.

    Args:
        data_dir: Directory containing data files
        tokenizer_path: Path to tokenizer

    Returns:
        Dictionary with total and average token counts
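
    Example (illustrative; paths are placeholders):
        stats = estimate_dataset_tokens("data/processed", "tokenizer.json")
        print(stats["total_tokens"], stats["avg_tokens_per_sample"])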
| """ |
| import json |
| from pathlib import Path |
|
|
| tokenizer = SLMTokenizer.from_file(tokenizer_path) |
|
|
| total_tokens = 0 |
| total_samples = 0 |
|
|
| for file_path in Path(data_dir).glob("*.json*"): |
| with open(file_path, "r") as f: |
| if file_path.suffix == ".jsonl": |
| samples = [json.loads(line) for line in f if line.strip()] |
| else: |
| samples = json.load(f) |
| if not isinstance(samples, list): |
| samples = [samples] |
|
|
| for sample in samples: |
| if "user" in sample and "assistant" in sample: |
| tokens = tokenizer.encode_conversation( |
| sample["user"], sample["assistant"] |
| ) |
| elif "text" in sample: |
| tokens = tokenizer.encode(sample["text"]) |
| else: |
| continue |
|
|
| total_tokens += len(tokens) |
| total_samples += 1 |
|
|
| return { |
| "total_tokens": total_tokens, |
| "total_samples": total_samples, |
| "avg_tokens_per_sample": total_tokens / max(total_samples, 1), |
| } |
|
|
|
|
def get_dataloader_stats(dataloader: DataLoader) -> Dict[str, float]:
    """Get statistics from a dataloader.

    Args:
        dataloader: The dataloader to analyze

    Returns:
        Dictionary with statistics
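
    Example (illustrative; ``val_loader`` is a placeholder dataloader):
        stats = get_dataloader_stats(val_loader)
        print(f"non-pad ratio: {stats['non_pad_ratio']:.2%}")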
| """ |
| total_batches = 0 |
| total_tokens = 0 |
| total_non_pad_tokens = 0 |
|
|
| for batch in dataloader: |
| total_batches += 1 |
| total_tokens += batch["input_ids"].numel() |
| total_non_pad_tokens += batch["attention_mask"].sum().item() |
|
|
| |
| if total_batches >= 100: |
| break |
|
|
| return { |
| "batches_sampled": total_batches, |
| "tokens_per_batch": total_tokens / max(total_batches, 1), |
| "non_pad_ratio": total_non_pad_tokens / max(total_tokens, 1), |
| } |
|
|