| from __future__ import annotations |
|
|
| import bisect |
| import functools |
| import importlib.util |
| import json |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Dict, Iterable, List, Optional, Tuple |
|
|
| import torch |
| from torch.utils.data import DataLoader, Dataset |
| from xqs_stack import choose_optimizer_backend |
|
|
# File name, inside the dataset directory, of the JSON cache recording every
# shard's sample count so later runs can skip re-loading each shard.
SHARD_INDEX_FILENAME: str = "shard_index.json"
# While the index is being built, a "dataset_index_build_progress" line is
# printed after every this-many shards (and once more after the last shard).
SHARD_INDEX_PROGRESS_EVERY: int = 256
|
|
|
|
@dataclass
class TrainStackConfig:
    """Settings for the optimizer, data pipeline and precision of this stack.

    The field order below is also the positional constructor order, so new
    fields should only ever be appended.
    """

    # -- optimizer: ``build_optimizer`` lower-cases ``optimizer_name``; "auto"
    #    defers the choice to the optimizer-backend helper.
    optimizer_name: str = "adafactor"
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    # -- batching: ``grad_accum_steps`` is not read anywhere in this section.
    batch_size: int = 4
    grad_accum_steps: int = 1
    # -- DataLoader: ``build_dataloader`` forwards ``prefetch_factor`` and
    #    honours ``persistent_workers`` only when ``num_workers > 0``.
    num_workers: int = 2
    pin_memory: bool = True
    prefetch_factor: int = 4
    persistent_workers: bool = True
    # -- data: batches are collated to at most ``max_seq_len + 1`` tokens.
    #    An empty ``dataset_dir`` presumably means "use synthetic data" when
    #    forwarded to ``build_dataset`` (confirm at the call site).
    max_seq_len: int = 2048
    dataset_dir: str = ""
    # -- precision
    use_bf16: bool = True
|
|
|
|
class PretokenizedShardDataset(Dataset):
    """Map-style dataset over a directory of pre-tokenized ``*.pt`` shards.

    Each shard file is read with ``torch.load`` and must contain one of:

    * a 2-D tensor, one sample per row;
    * a 1-D tensor, treated as a single sample;
    * a list of token sequences (anything ``torch.as_tensor`` accepts); the
      sequences may have different lengths.

    Per-shard sample counts come from the cached index file
    (``SHARD_INDEX_FILENAME``) or from ``metadata.json`` when either lists
    every shard on disk. Otherwise each shard is loaded once to count its
    samples, and the result is written back on a best-effort basis.

    Only the most recently used shard is kept in memory, so access that
    alternates between shards reloads a shard on each switch.

    NOTE(review): depending on the installed torch version, ``torch.load`` may
    unpickle arbitrary objects; only point this at trusted shard files.
    """

    def __init__(self, dataset_dir: str, max_seq_len: int):
        """Discover and index the shards directly under ``dataset_dir``.

        Args:
            dataset_dir: Directory containing the ``*.pt`` shards.
            max_seq_len: Samples are truncated to ``max_seq_len + 1`` tokens.

        Raises:
            FileNotFoundError: If the directory is missing or holds no shards.
            TypeError: If a shard has to be inspected and has an unsupported format.
        """
        self.root = Path(dataset_dir)
        if not self.root.exists():
            raise FileNotFoundError(f"Dataset directory not found: {dataset_dir}")
        self.max_seq_len = max_seq_len
        self.shard_paths = sorted(self.root.glob("*.pt"))
        if not self.shard_paths:
            raise FileNotFoundError(f"No .pt shards found in {dataset_dir}")
        self.shard_sizes: List[int] = []
        # cumulative_sizes[i] = number of samples in shards 0..i (inclusive).
        self.cumulative_sizes: List[int] = []
        # Single-entry cache of the last shard decoded by ``_load_shard``: a
        # 2-D tensor, or a list of 1-D long tensors for list shards.
        self._cached_shard_path: Optional[Path] = None
        self._cached_shard_tensor: Optional[torch.Tensor | List[torch.Tensor]] = None
        total = 0
        for _shard_path, shard_len in self._load_or_build_shard_index():
            total += shard_len
            self.shard_sizes.append(shard_len)
            self.cumulative_sizes.append(total)

    @staticmethod
    def _log_event(event: str, **fields: object) -> None:
        """Print one flushed JSON log line ``{"event": event, **fields}``."""
        print(json.dumps({"event": event, **fields}), flush=True)

    def _shard_index_path(self) -> Path:
        """Return the location of the cached shard index inside the dataset dir."""
        return self.root / SHARD_INDEX_FILENAME

    def _read_json_file(self, path: Path) -> Dict[str, object]:
        """Parse ``path`` as a JSON object, or return ``{}`` if that is impossible.

        Unreadable files, invalid UTF-8 or JSON, and JSON whose top level is not
        an object all yield ``{}``, which callers treat as "no usable index".
        """
        try:
            payload = json.loads(path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            return {}
        # e.g. an unrelated metadata.json holding a list must not crash .get().
        return payload if isinstance(payload, dict) else {}

    def _extract_index_entries(self, payload: Dict[str, object]) -> Optional[List[Tuple[Path, int]]]:
        """Resolve every on-disk shard to its length using an index payload.

        Returns ``None`` (meaning "rebuild the index") when ``payload`` has no
        ``"shards"`` list, when any entry is malformed or has a negative length,
        or when a shard in ``self.shard_paths`` is missing from the index.
        Index entries for files that are not on disk are ignored.
        """
        shard_entries = payload.get("shards")
        if not isinstance(shard_entries, list):
            return None
        lengths_by_name: Dict[str, int] = {}
        for entry in shard_entries:
            if not isinstance(entry, dict):
                return None
            file_name = entry.get("file")
            length = entry.get("length")
            # bool is a subclass of int, so reject it explicitly.
            if (
                not isinstance(file_name, str)
                or not isinstance(length, int)
                or isinstance(length, bool)
                or length < 0
            ):
                return None
            lengths_by_name[file_name] = length
        resolved: List[Tuple[Path, int]] = []
        for shard_path in self.shard_paths:
            length = lengths_by_name.get(shard_path.name)
            if length is None:
                return None
            resolved.append((shard_path, length))
        return resolved

    def _load_cached_index(self) -> Optional[List[Tuple[Path, int]]]:
        """Return shard lengths from the cached index or ``metadata.json``.

        The files are tried in that order; the first one that covers every
        shard wins and a ``dataset_index_loaded`` event is logged. Returns
        ``None`` when neither file is usable.
        """
        for candidate in (self._shard_index_path(), self.root / "metadata.json"):
            if not candidate.exists():
                continue
            resolved = self._extract_index_entries(self._read_json_file(candidate))
            if resolved is not None:
                self._log_event(
                    "dataset_index_loaded",
                    dataset_dir=str(self.root),
                    source=candidate.name,
                    shards=len(resolved),
                    samples=sum(length for _, length in resolved),
                )
                return resolved
        return None

    def _infer_shard_len(self, shard_path: Path) -> int:
        """Load one shard and return its sample count.

        Accepts the same formats as ``_load_shard`` so the index never counts a
        shard that cannot be read back.

        Raises:
            TypeError: For anything other than a list or a 1-D/2-D tensor.
        """
        shard = torch.load(shard_path, map_location="cpu")
        if isinstance(shard, list):
            return len(shard)
        if isinstance(shard, torch.Tensor) and shard.ndim in (1, 2):
            return 1 if shard.ndim == 1 else int(shard.size(0))
        raise TypeError(f"Unsupported shard format in {shard_path}")

    def _write_cached_index(self, entries: List[Tuple[Path, int]]) -> None:
        """Persist shard lengths to the index file; failures are logged, not raised.

        The index is only a cache, so a read-only or full dataset directory must
        not abort dataset construction after every shard was already inspected.
        A partially written file fails to parse later and is simply rebuilt.
        """
        payload = {
            "shards": [{"file": path.name, "length": length} for path, length in entries],
            "total_samples": sum(length for _, length in entries),
        }
        index_path = self._shard_index_path()
        try:
            index_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
        except OSError as err:
            self._log_event(
                "dataset_index_write_failed",
                dataset_dir=str(self.root),
                path=str(index_path),
                error=str(err),
            )

    def _load_or_build_shard_index(self) -> List[Tuple[Path, int]]:
        """Return ``(path, length)`` for every shard, in ``self.shard_paths`` order.

        Uses a cached index when one covers every shard. Otherwise it loads each
        shard to count its samples, logs progress every
        ``SHARD_INDEX_PROGRESS_EVERY`` shards, and tries to cache the result.
        """
        cached = self._load_cached_index()
        if cached is not None:
            return cached
        dataset_dir = str(self.root)
        total_shards = len(self.shard_paths)
        self._log_event("dataset_index_build_start", dataset_dir=dataset_dir, shards=total_shards)
        entries: List[Tuple[Path, int]] = []
        running_total = 0
        for shard_idx, shard_path in enumerate(self.shard_paths, start=1):
            shard_len = self._infer_shard_len(shard_path)
            entries.append((shard_path, shard_len))
            running_total += shard_len
            if shard_idx % SHARD_INDEX_PROGRESS_EVERY == 0 or shard_idx == total_shards:
                self._log_event(
                    "dataset_index_build_progress",
                    dataset_dir=dataset_dir,
                    indexed_shards=shard_idx,
                    total_shards=total_shards,
                    samples=running_total,
                )
        self._write_cached_index(entries)
        self._log_event(
            "dataset_index_build_done",
            dataset_dir=dataset_dir,
            shards=len(entries),
            samples=running_total,
        )
        return entries

    def __len__(self) -> int:
        """Total number of samples across all shards."""
        return self.cumulative_sizes[-1]

    def _load_shard(self, shard_idx: int) -> torch.Tensor | List[torch.Tensor]:
        """Return the rows of shard ``shard_idx``, reusing the last-loaded shard.

        Tensor shards are returned as a 2-D tensor (a 1-D tensor becomes one
        row). List shards are returned as a list of long tensors and are NOT
        stacked, because sequences of different lengths make ``torch.stack``
        fail. Both forms support ``rows[i]``.

        Raises:
            TypeError: For anything other than a list or a 1-D/2-D tensor.
        """
        shard_path = self.shard_paths[shard_idx]
        if self._cached_shard_path == shard_path and self._cached_shard_tensor is not None:
            return self._cached_shard_tensor
        shard = torch.load(shard_path, map_location="cpu")
        if isinstance(shard, list):
            rows = [torch.as_tensor(item, dtype=torch.long) for item in shard]
        elif isinstance(shard, torch.Tensor) and shard.ndim in (1, 2):
            rows = shard.unsqueeze(0) if shard.ndim == 1 else shard
        else:
            raise TypeError(f"Unsupported shard format in {shard_path}")
        self._cached_shard_path = shard_path
        self._cached_shard_tensor = rows
        return rows

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return sample ``idx`` as a long tensor of at most ``max_seq_len + 1`` tokens.

        Negative indices count from the end. Samples shorter than two tokens
        are right-padded with zeros to length two, so that an input/target
        pair always exists.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        position = idx + size if idx < 0 else idx
        # Without this check an index below -len(self) would become a negative
        # row inside shard 0 and silently return the wrong sample.
        if not 0 <= position < size:
            raise IndexError(f"index {idx} is out of range for dataset of length {size}")
        shard_idx = bisect.bisect_right(self.cumulative_sizes, position)
        shard_start = self.cumulative_sizes[shard_idx - 1] if shard_idx > 0 else 0
        tokens = self._load_shard(shard_idx)[position - shard_start].to(dtype=torch.long)
        if tokens.numel() < 2:
            padded = torch.zeros(2, dtype=torch.long)
            padded[: tokens.numel()] = tokens
            tokens = padded
        return tokens[: self.max_seq_len + 1]
|
|
|
|
class SyntheticTokenDataset(Dataset):
    """Random fixed-length token sequences for running without real data.

    The index passed to ``__getitem__`` is ignored: every lookup draws a new
    sample from torch's global RNG, and ``__len__`` only bounds iteration.
    """

    def __init__(self, vocab_size: int, max_seq_len: int, num_samples: int = 128):
        """Record the sampling parameters.

        Args:
            vocab_size: Token ids are drawn uniformly from ``[0, vocab_size)``.
            max_seq_len: Every sample holds ``max_seq_len + 1`` tokens.
            num_samples: Length reported by ``__len__``.
        """
        self.vocab_size, self.max_seq_len, self.num_samples = vocab_size, max_seq_len, num_samples

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int) -> torch.Tensor:
        sample_shape = (self.max_seq_len + 1,)
        return torch.randint(0, self.vocab_size, sample_shape, dtype=torch.long)
|
|
|
|
class LayerWiseSGD(torch.optim.Optimizer):
    """SGD whose step size is normalised by a running per-group gradient RMS.

    For every parameter group, each step computes::

        g_p      = p.grad + weight_decay * p                  (L2-style decay)
        rms      = sqrt(mean over params of mean(g_p ** 2))
        velocity = momentum * velocity + rms                  (scalar per group)
        p       -= lr / max(velocity, 1e-8) * g_p

    ``velocity`` is kept as a 0-dim tensor under ``group["layer_velocity"]``,
    i.e. inside ``param_groups`` rather than the per-parameter ``state``.
    Effective gradients are recomputed in the second pass instead of being
    stored, so no extra gradient-sized buffers are held between the passes.
    """

    def __init__(self, params: Iterable[torch.nn.Parameter], lr: float = 1e-2, momentum: float = 0.9, weight_decay: float = 0.0):
        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @staticmethod
    def _effective_grad(param: torch.Tensor, weight_decay: float) -> torch.Tensor:
        """Return ``param.grad`` with L2 weight decay folded in (a new tensor if decay != 0)."""
        grad = param.grad
        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)
        return grad

    @torch.no_grad()
    def step(self, closure=None):
        """Apply one update to every parameter that has a gradient.

        Args:
            closure: Optional callable re-evaluating the model; it runs with
                autograd enabled and its result is returned.

        Returns:
            The closure's return value, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            params_with_grad = [p for p in group["params"] if p.grad is not None]
            if not params_with_grad:
                continue
            lr = group["lr"]
            momentum = group["momentum"]
            weight_decay = group["weight_decay"]
            device = params_with_grad[0].device

            mean_grad_sq = torch.zeros((), device=device)
            for p in params_with_grad:
                mean_grad_sq = mean_grad_sq + self._effective_grad(p, weight_decay).pow(2).mean()
            mean_grad_sq = mean_grad_sq / len(params_with_grad)

            velocity = group.get("layer_velocity")
            if velocity is None:
                velocity = torch.zeros((), device=device)
            velocity = (momentum * velocity) + mean_grad_sq.sqrt()
            group["layer_velocity"] = velocity

            # Read the step size back to the host once per group. Passing the
            # 0-dim tensor itself as ``alpha`` makes every ``add_`` call convert
            # it to a Python scalar, i.e. one device sync per parameter on GPU.
            step_size = float(lr / velocity.clamp(min=1e-8))
            for p in params_with_grad:
                p.add_(self._effective_grad(p, weight_decay), alpha=-step_size)
        return loss
|
|
|
|
| def _build_adafactor(params: Iterable[torch.nn.Parameter], cfg: TrainStackConfig): |
| if importlib.util.find_spec("transformers") is None: |
| return torch.optim.AdamW(params, lr=cfg.learning_rate, weight_decay=cfg.weight_decay) |
| transformers = __import__("transformers") |
| return transformers.Adafactor( |
| params, |
| lr=cfg.learning_rate, |
| relative_step=False, |
| scale_parameter=False, |
| warmup_init=False, |
| weight_decay=cfg.weight_decay, |
| ) |
|
|
|
|
| def _build_adam8bit(params: Iterable[torch.nn.Parameter], cfg: TrainStackConfig): |
| if importlib.util.find_spec("bitsandbytes") is None: |
| return torch.optim.AdamW(params, lr=cfg.learning_rate, weight_decay=cfg.weight_decay) |
| bnb = __import__("bitsandbytes") |
| return bnb.optim.Adam8bit(params, lr=cfg.learning_rate, weight_decay=cfg.weight_decay) |
|
|
|
|
def build_optimizer(model: torch.nn.Module, cfg: TrainStackConfig) -> torch.optim.Optimizer:
    """Create the optimizer named by ``cfg.optimizer_name`` for ``model``.

    Names are matched case-insensitively:

    * ``"auto"``: resolved through ``choose_optimizer_backend(prefer_low_memory=True)``;
    * ``"adamw_fused"`` / ``"fused_adamw"``: fused AdamW when CUDA is available
      and torch accepts it for these parameters, otherwise plain AdamW;
    * ``"adafactor"``;
    * ``"adam8bit"`` / ``"adam_8bit"`` / ``"8bit-adam"``;
    * ``"layerwisesgd"`` / ``"lowmemsgd"`` / ``"sgd"``: ``LayerWiseSGD`` with momentum 0.9.

    Any other name, including ``"adamw"``, yields plain AdamW. All variants use
    ``cfg.learning_rate`` and ``cfg.weight_decay``.
    """
    name = cfg.optimizer_name.lower()
    if name == "auto":
        # Normalise the backend's answer the same way as user-supplied names.
        name = str(choose_optimizer_backend(prefer_low_memory=True)).lower()
    lr, weight_decay = cfg.learning_rate, cfg.weight_decay
    if name in {"adamw_fused", "fused_adamw"}:
        if torch.cuda.is_available():
            try:
                return torch.optim.AdamW(
                    model.parameters(),
                    lr=lr,
                    weight_decay=weight_decay,
                    fused=True,
                )
            except (TypeError, RuntimeError):
                # TypeError: this torch has no ``fused`` argument.
                # RuntimeError: torch rejects the params for the fused kernel
                # (e.g. they are not yet on a supported device).
                pass
        return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    if name == "adafactor":
        return _build_adafactor(model.parameters(), cfg)
    if name in {"adam8bit", "adam_8bit", "8bit-adam"}:
        return _build_adam8bit(model.parameters(), cfg)
    if name in {"layerwisesgd", "lowmemsgd", "sgd"}:
        return LayerWiseSGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
|
|
|
|
def collate_token_batch(batch: List[torch.Tensor], fixed_length: Optional[int] = None) -> Dict[str, torch.Tensor]:
    """Turn token sequences into next-token ``input_ids`` / ``target_ids`` tensors.

    Each sequence ``s`` contributes ``s[:-1]`` as inputs and ``s[1:]`` as targets.

    Args:
        batch: 1-D token tensors, possibly of different lengths.
        fixed_length: Optional maximum sequence length (inputs + 1). Longer
            sequences are truncated to it. When every sequence reaches it, the
            batch is stacked directly without padding.

    Returns:
        ``{"input_ids": ..., "target_ids": ...}``, each of shape
        ``(len(batch), L - 1)``. On the stacked fast path ``L`` is
        ``fixed_length``; otherwise ``L`` is the longest (truncated) sequence.
        In that case the tensors are long, inputs are padded with ``0`` and
        targets with ``-100`` (presumably the loss's ignore index). An empty
        batch yields two ``(0, 0)`` tensors.
    """
    if fixed_length is not None and batch and all(item.numel() >= fixed_length for item in batch):
        stacked = torch.stack([item[:fixed_length] for item in batch], dim=0)
        return {"input_ids": stacked[:, :-1], "target_ids": stacked[:, 1:]}
    # Truncate here as well, so fixed_length caps the padded path the same
    # way it caps the fast path.
    items = batch if fixed_length is None else [item[:fixed_length] for item in batch]
    max_len = max((item.numel() for item in items), default=0)
    seq_len = max(max_len - 1, 0)
    inputs = torch.zeros((len(items), seq_len), dtype=torch.long)
    targets = torch.full((len(items), seq_len), -100, dtype=torch.long)
    for row, item in enumerate(items):
        # A sequence of n tokens yields n - 1 pairs; empty or 1-token
        # sequences yield none and leave the row fully padded.
        span = max(item.numel() - 1, 0)
        inputs[row, :span] = item[:span]
        targets[row, :span] = item[1 : span + 1]
    return {"input_ids": inputs, "target_ids": targets}
|
|
|
|
def build_dataset(dataset_dir: str, vocab_size: int, max_seq_len: int, synthetic_samples: int = 128) -> Dataset:
    """Return the shard dataset for ``dataset_dir``, or synthetic data if it is empty.

    Args:
        dataset_dir: Directory of ``*.pt`` shards; any falsy value (e.g. ``""``)
            selects ``SyntheticTokenDataset`` instead.
        vocab_size: Only used by the synthetic dataset.
        max_seq_len: Forwarded to whichever dataset is built.
        synthetic_samples: Length of the synthetic dataset.
    """
    if not dataset_dir:
        return SyntheticTokenDataset(vocab_size=vocab_size, max_seq_len=max_seq_len, num_samples=synthetic_samples)
    return PretokenizedShardDataset(dataset_dir, max_seq_len=max_seq_len)
|
|
|
|
def build_dataloader(dataset: Dataset, cfg: TrainStackConfig, shuffle: bool = True) -> DataLoader:
    """Wrap ``dataset`` in a DataLoader configured from ``cfg``.

    Batches are built by ``collate_token_batch`` with ``fixed_length`` set to
    ``cfg.max_seq_len + 1``. ``persistent_workers`` is only enabled and
    ``prefetch_factor`` only passed when ``cfg.num_workers > 0``.
    """
    uses_workers = cfg.num_workers > 0
    loader_kwargs = {
        "batch_size": cfg.batch_size,
        "shuffle": shuffle,
        "num_workers": cfg.num_workers,
        "pin_memory": cfg.pin_memory,
        "persistent_workers": cfg.persistent_workers and uses_workers,
        "collate_fn": functools.partial(collate_token_batch, fixed_length=cfg.max_seq_len + 1),
    }
    if uses_workers:
        loader_kwargs["prefetch_factor"] = cfg.prefetch_factor
    return DataLoader(dataset, **loader_kwargs)
|
|
|
|
def move_batch_to_device(batch: Dict[str, torch.Tensor], device: torch.device, non_blocking: bool = True) -> Dict[str, torch.Tensor]:
    """Return a new dict holding every tensor of ``batch`` moved to ``device``.

    ``non_blocking`` is forwarded to ``Tensor.to``. It permits an asynchronous
    copy where torch supports one (e.g. from pinned host memory); it does not
    force one.
    """
    moved: Dict[str, torch.Tensor] = {}
    for name, tensor in batch.items():
        moved[name] = tensor.to(device, non_blocking=non_blocking)
    return moved
|
|
|
|
|
|
def train_demo_steps(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    dataloader: DataLoader,
    device: torch.device,
    steps: int = 2,
    use_bf16: bool = True,
) -> Tuple[float, int]:
    """Run up to ``steps`` training steps and report mean loss and token count.

    Each step moves the batch to ``device`` and computes
    ``model.training_loss(input_ids, target_ids)``, under bf16 autocast only
    when ``use_bf16`` is set and ``device`` is CUDA. It then back-propagates,
    clips the global gradient norm to 1.0 and calls ``optimizer.step()``.

    Returns:
        ``(mean_loss, total_tokens)``. ``mean_loss`` is averaged over the steps
        actually run (0.0 if none ran). ``total_tokens`` counts target
        positions that are not ``-100``.
    """
    model.train()
    autocast_enabled = use_bf16 and device.type == "cuda"
    total_loss = 0.0
    total_tokens = 0
    steps_run = 0
    for step_idx, batch in enumerate(dataloader):
        if step_idx >= steps:
            break
        batch = move_batch_to_device(batch, device)
        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=autocast_enabled):
            loss = model.training_loss(batch["input_ids"], batch["target_ids"])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += float(loss.detach().item())
        total_tokens += int((batch["target_ids"] != -100).sum().item())
        steps_run += 1
    # Divide by the steps actually taken: a short dataloader can run out
    # before ``steps``, and dividing by ``steps`` would understate the loss.
    mean_loss = total_loss / max(1, steps_run)
    return mean_loss, total_tokens
|
|