| """ |
| train/trainer.py — Core training loop. |
| |
| Provides: |
| TrainConfig : Dataclass of all training hyper-parameters. |
| Trainer : Orchestrates gradient accumulation, AMP, gradient clipping, |
| tensorboard logging, and checkpoint saving. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import contextlib |
| import math |
| import time |
| from dataclasses import dataclass, field |
| from pathlib import Path |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from torch.optim import Optimizer |
| from torch.optim.lr_scheduler import LambdaLR |
| from torch.utils.data import DataLoader |
| from torch.utils.data.distributed import DistributedSampler |
| try: |
| from torch.utils.tensorboard import SummaryWriter |
| HAS_TENSORBOARD = True |
| except (ImportError, AttributeError): |
| SummaryWriter = None |
| HAS_TENSORBOARD = False |
|
|
| from train.utils import get_grad_norm, is_main_process, save_checkpoint |
|
|
|
|
| |
| |
| |
| try: |
| import transformer_engine.pytorch as te |
| from transformer_engine.common.recipe import DelayedScaling, Format |
| HAS_TE = True |
| except ImportError: |
| te = None |
| HAS_TE = False |
|
|
|
|
| |
| |
| |
|
|
|
|
@dataclass
class TrainConfig:
    """Hyper-parameters that control the training loop.

    All ``*_interval`` fields are expressed in optimizer steps (one optimizer
    step consumes ``grad_accum_steps`` micro-batches).

    Raises:
        ValueError: If ``grad_accum_steps``, ``log_interval``,
            ``save_interval`` or ``eval_interval`` is smaller than 1.  The
            trainer uses these as a divisor / modulo operand, so such values
            would otherwise only fail deep inside the training loop.
    """

    # Optimizer step index at which training stops (exclusive upper bound of
    # the step range the trainer iterates over).
    max_steps: int = 100_000

    # Number of micro-batches whose gradients are accumulated before each
    # optimizer step.
    grad_accum_steps: int = 1

    # Global gradient-norm clipping threshold.  A value <= 0 disables clipping;
    # the trainer then only measures the norm.
    max_grad_norm: float = 1.0

    # Emit a progress line (and tensorboard scalars) every N optimizer steps.
    log_interval: int = 10

    # Write a periodic checkpoint every N optimizer steps.
    save_interval: int = 1000

    # Run validation every N optimizer steps (only if a validation loader is
    # supplied to the trainer).
    eval_interval: int = 500

    # Directory passed to the checkpoint writer; tensorboard event files go to
    # the ``tensorboard`` sub-directory.
    checkpoint_dir: str = "checkpoints"

    # Run forward passes under bf16 ``torch.autocast``.
    use_amp: bool = True

    # Wrap the (inner) model with ``torch.compile``.
    compile_model: bool = False

    # FP8 training through Transformer Engine.  ``fp8_format == "MXFP8"``
    # selects ``MXFP8BlockScaling``; any other value is looked up as a member
    # of Transformer Engine's ``Format`` and used with ``DelayedScaling``, in
    # which case the two ``fp8_amax_*`` fields are forwarded to the recipe.
    use_fp8: bool = False
    fp8_amax_history_len: int = 16
    fp8_amax_compute_algo: str = "max"
    fp8_format: str = "MXFP8"

    # Optional path of a text log file the main process appends to.
    log_file: Optional[str] = None

    # NOTE(review): the two fields below are not read by the Trainer code in
    # this module — presumably consumed by other tooling; confirm before
    # relying on them.
    log_grad_norm_interval: int = 100

    log_memory_interval: int = 100

    def __post_init__(self) -> None:
        """Reject values that would make the training loop divide by zero."""
        for name in (
            "grad_accum_steps",
            "log_interval",
            "save_interval",
            "eval_interval",
        ):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(
                    f"TrainConfig.{name} must be >= 1, got {value!r}"
                )
|
|
|
|
| |
| |
| |
|
|
|
|
class Trainer:
    """
    Manages the full pretraining loop for a decoder-only LLM.

    Supports:
        - Gradient accumulation over ``config.grad_accum_steps`` micro-batches.
        - bf16 mixed-precision via ``torch.autocast`` (no GradScaler required).
        - Optional FP8 execution through Transformer Engine when it is installed.
        - Global gradient norm clipping.
        - Tensorboard and optional text-file logging on the main process.
        - Periodic, best-validation and emergency checkpoint saving via
          :func:`train.utils.save_checkpoint`.
        - Periodic validation with patience-based early stopping.
        - Graceful shutdown requested through :meth:`request_shutdown`.
        - Optional ``torch.compile`` acceleration.

    Args:
        model:        The LLM (plain ``nn.Module`` or DDP-wrapped).  Its forward
                      is called with ``input_ids`` only and must return a
                      ``(logits, <anything>)`` pair.
        train_loader: DataLoader yielding ``(input_ids, targets)`` batches.
        optimizer:    AdamW (or any ``Optimizer``) configured externally.
        scheduler:    LR scheduler produced by the caller.
        config:       ``TrainConfig`` instance controlling all loop behaviour.
        device:       Target device for data and model.
        rank:         Process rank; stored on the instance.  Main-process
                      detection itself uses ``train.utils.is_main_process``.
        sampler:      Optional ``DistributedSampler``; ``set_epoch`` is called on
                      it every time the training loader is restarted.
        val_loader:   Optional validation DataLoader; when ``None`` validation
                      and early stopping are skipped.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        optimizer: Optimizer,
        scheduler: LambdaLR,
        config: TrainConfig,
        device: torch.device,
        rank: int = 0,
        sampler: Optional[DistributedSampler] = None,
        val_loader: Optional[DataLoader] = None,
    ) -> None:
        self.model = model
        self.train_loader = train_loader
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.config = config
        self.device = device
        self.rank = rank
        self._is_main = is_main_process()
        self._sampler = sampler
        self._epoch = 0
        self._val_loader = val_loader
        self._best_val_loss: float = float("inf")
        self._val_patience_counter: int = 0
        self._val_patience_limit: int = 10

        # Graceful-shutdown state, flipped by request_shutdown().
        self._shutdown_requested = False
        self._shutdown_signal = ""

        # (batch, seq_len) of the most recent micro-batch; refreshed by _step()
        # and read by the throughput computation.
        self._last_batch_shape: tuple[int, int] = (0, 0)

        # FP8 recipe (only when requested AND Transformer Engine is importable).
        self._fp8_recipe = None
        fp8_unavailable = config.use_fp8 and not HAS_TE
        if config.use_fp8 and HAS_TE:
            if config.fp8_format == "MXFP8":
                from transformer_engine.common.recipe import MXFP8BlockScaling
                self._fp8_recipe = MXFP8BlockScaling()
            else:
                self._fp8_recipe = DelayedScaling(
                    fp8_format=getattr(Format, config.fp8_format),
                    amax_history_len=config.fp8_amax_history_len,
                    amax_compute_algo=config.fp8_amax_compute_algo,
                )

        # Compile the inner module so that a DDP wrapper (if any) stays outermost.
        if config.compile_model:
            inner: nn.Module = getattr(self.model, "module", self.model)
            compiled = torch.compile(inner)
            if hasattr(self.model, "module"):
                self.model.module = compiled
            else:
                self.model = compiled

        # Logging sinks exist on the main process only.
        self._writer: Optional[SummaryWriter] = None
        self._log_fh = None
        if self._is_main:
            if HAS_TENSORBOARD:
                log_dir = Path(config.checkpoint_dir) / "tensorboard"
                self._writer = SummaryWriter(log_dir=str(log_dir))
            if config.log_file is not None:
                Path(config.log_file).parent.mkdir(parents=True, exist_ok=True)
                # Line-buffered append; closed by train() when it returns.
                self._log_fh = open(config.log_file, "a", encoding="utf-8", buffering=1)

        # Previously FP8 was dropped silently when Transformer Engine was
        # missing; say so once the log sinks are available.
        if fp8_unavailable and self._is_main:
            self._log(
                "use_fp8=True but transformer_engine is not importable; "
                "FP8 is disabled and the use_amp setting applies instead.",
                level="WARN",
            )

        import datetime
        self._train_start_time = datetime.datetime.now()

        # Persistent iterator so batches continue across optimizer steps.
        self._loader_iter = iter(self.train_loader)

    def request_shutdown(self, signal_name: str = "UNKNOWN") -> None:
        """Request graceful shutdown after the current training step.

        Called from signal handlers (SIGHUP, SIGTERM). Sets a flag
        that the training loop checks after each optimizer step.
        The loop will save an emergency checkpoint and exit cleanly.
        """
        self._shutdown_requested = True
        self._shutdown_signal = signal_name

    def train(self, start_step: int = 0) -> None:
        """
        Run the main training loop from ``start_step`` to ``config.max_steps``.

        The loop ends when ``max_steps`` is reached, when early stopping
        triggers, or when a shutdown was requested.  In every case — including
        an exception — the tensorboard writer and the log file are closed
        before this method returns.

        Args:
            start_step: First optimiser step index (non-zero when resuming).

        Raises:
            RuntimeError: On CUDA out-of-memory, on any ``RuntimeError`` raised
                by a micro-step (re-raised after logging), or when the step
                loss is not finite.
        """
        cfg = self.config

        # Resuming an already finished run: there is no loss to checkpoint, so
        # return instead of failing on an unbound loss variable.
        if start_step >= cfg.max_steps:
            if self._is_main:
                self._log(
                    f"start_step {start_step} >= max_steps {cfg.max_steps}; "
                    f"nothing to train.",
                    level="WARN",
                )
            self._close_log_sinks()
            return

        self.model.train()
        try:
            self._run_steps(start_step)
        finally:
            self._close_log_sinks()

    def _run_steps(self, start_step: int) -> None:
        """Iterate optimizer steps; returns early on shutdown or early stop."""
        cfg = self.config

        # Logging-window accumulators (reset after every progress line).
        window_start = time.perf_counter()
        window_loss = 0.0
        window_steps = 0
        step_loss = float("nan")

        for step in range(start_step, cfg.max_steps):
            global_step = step + 1
            step_loss = self._accumulate_gradients(step)

            # Abort before the optimizer consumes NaN/Inf gradients.
            if not math.isfinite(step_loss):
                mem_gb, mem_total = self._gpu_memory_gb()
                raise RuntimeError(
                    f"Non-finite loss detected: {step_loss}. "
                    f"GPU mem: {mem_gb:.1f}/{mem_total:.1f} GB. "
                    f"Check lr, grad clipping, FP8 amax history. "
                    f"Try: lower lr, increase fp8_amax_history_len, or switch to BF16."
                )

            if cfg.max_grad_norm > 0.0:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), cfg.max_grad_norm
                ).item()
            else:
                grad_norm = get_grad_norm(self.model)

            self.optimizer.step()
            self.scheduler.step()

            # Checked after the optimizer step so the emergency checkpoint
            # contains a fully applied update.
            if self._shutdown_requested:
                self._save_and_shutdown(global_step, step_loss)
                return

            window_loss += step_loss
            window_steps += 1

            if global_step % cfg.log_interval == 0 and self._is_main:
                now = time.perf_counter()
                self._log_progress(
                    global_step,
                    window_loss / window_steps,
                    window_steps,
                    now - window_start,
                    grad_norm,
                )
                window_loss, window_steps, window_start = 0.0, 0, now

            if global_step % cfg.eval_interval == 0 and self._val_loader is not None:
                if self._evaluate_and_check_early_stop(global_step):
                    return

            if global_step % cfg.save_interval == 0 and self._is_main:
                # The checkpoint always records this step's loss (not the
                # logging-window average), independent of log_interval.
                ckpt_path = save_checkpoint(
                    model=self.model,
                    optimizer=self.optimizer,
                    scheduler=self.scheduler,
                    step=global_step,
                    loss=step_loss,
                    path=cfg.checkpoint_dir,
                )
                self._log(f"Checkpoint saved → {ckpt_path}")

        if self._is_main:
            final_path = save_checkpoint(
                model=self.model,
                optimizer=self.optimizer,
                scheduler=self.scheduler,
                step=cfg.max_steps,
                loss=step_loss,
                path=cfg.checkpoint_dir,
            )
            self._log(f"Training complete. Final checkpoint → {final_path}")

            import datetime
            elapsed = (datetime.datetime.now() - self._train_start_time).total_seconds()
            total_steps_done = cfg.max_steps - start_step
            self._log(
                f"Training summary: {total_steps_done} steps, "
                f"{elapsed/3600:.2f}h elapsed, "
                f"avg {total_steps_done/max(elapsed, 1e-9):.1f} steps/s"
            )

    def _accumulate_gradients(self, step: int) -> float:
        """Run all micro-batches of one optimizer step; return their mean loss."""
        cfg = self.config
        model = self.model

        self.optimizer.zero_grad(set_to_none=True)
        # Summed on-device so only one host sync (.item()) happens per step.
        loss_sum = torch.zeros(1, device=self.device)
        last_micro = cfg.grad_accum_steps - 1

        for micro_step in range(cfg.grad_accum_steps):
            batch = self._next_batch()
            # Under DDP, skip the gradient all-reduce on every micro-batch
            # except the last one of the accumulation window.
            if isinstance(model, DDP) and micro_step != last_micro:
                sync_ctx = model.no_sync()
            else:
                sync_ctx = contextlib.nullcontext()
            try:
                with sync_ctx:
                    micro_loss = self._step(batch)
            except torch.cuda.OutOfMemoryError as e:
                torch.cuda.empty_cache()
                mem_alloc, mem_total = self._gpu_memory_gb()
                raise RuntimeError(
                    f"CUDA OOM at step {step}, micro_step {micro_step}. "
                    f"GPU mem: {mem_alloc:.1f}/{mem_total:.1f} GB. "
                    f"Try reducing batch_size or grad_accum_steps."
                ) from e
            except RuntimeError as e:
                self._log(f"RuntimeError at step {step}, micro_step {micro_step}: {e}", level="ERROR")
                raise
            loss_sum += micro_loss

        return loss_sum.item() / cfg.grad_accum_steps

    def _gpu_memory_gb(self) -> tuple[float, float]:
        """Return ``(allocated, total)`` GB for ``self.device``; zeros off-CUDA."""
        if self.device.type != "cuda" or not torch.cuda.is_available():
            return 0.0, 0.0
        allocated = torch.cuda.memory_allocated(self.device) / 1e9
        total = torch.cuda.get_device_properties(self.device).total_memory / 1e9
        return allocated, total

    def _save_and_shutdown(self, global_step: int, step_loss: float) -> None:
        """Write the emergency checkpoint (main process) and sync ranks."""
        # NOTE(review): the shutdown flag is process-local.  If only some ranks
        # receive the signal, the barrier below may wait on ranks that are
        # still training — confirm the launcher signals every rank.
        cfg = self.config
        self._log(
            f"Graceful shutdown initiated (signal: {self._shutdown_signal}) "
            f"at step {global_step}, loss={step_loss:.4f}",
            level="WARN",
        )
        if self._is_main:
            ckpt_path = save_checkpoint(
                model=self.model,
                optimizer=self.optimizer,
                scheduler=self.scheduler,
                step=global_step,
                loss=step_loss,
                path=cfg.checkpoint_dir,
            )
            self._log(f"Emergency checkpoint saved → {ckpt_path}", level="WARN")

        # Best effort: a failing barrier must not prevent the process from
        # exiting, but it is reported instead of being swallowed.
        try:
            if torch.distributed.is_initialized():
                torch.distributed.barrier()
        except Exception as exc:
            self._log(f"Barrier failed during shutdown (continuing): {exc}", level="WARN")
        self._log("Shutdown complete. Exiting training loop.", level="WARN")

    def _log_progress(
        self,
        global_step: int,
        mean_loss: float,
        window_steps: int,
        elapsed: float,
        grad_norm: float,
    ) -> None:
        """Emit one progress line plus tensorboard scalars for a log window."""
        cfg = self.config
        batch_size, seq_len = self._last_batch_shape
        # Uses the number of steps actually seen in this window, which can be
        # smaller than log_interval right after resuming.
        tokens_per_sec = (
            batch_size * seq_len * cfg.grad_accum_steps * window_steps
        ) / max(elapsed, 1e-9)

        current_lr = self.scheduler.get_last_lr()[0]
        mem_gb, _ = self._gpu_memory_gb()
        self._log(
            f"step {global_step:>7d} | "
            f"loss {mean_loss:.4f} | "
            f"lr {current_lr:.2e} | "
            f"gnorm {grad_norm:.3f} | "
            f"tok/s {tokens_per_sec:,.0f} | "
            f"mem {mem_gb:.1f}GB | "
            f"epoch {self._epoch}"
        )

        if self._writer is not None:
            self._writer.add_scalar("train/loss", mean_loss, global_step)
            self._writer.add_scalar("train/lr", current_lr, global_step)
            self._writer.add_scalar("train/grad_norm", grad_norm, global_step)
            self._writer.add_scalar("train/tokens_per_sec", tokens_per_sec, global_step)

    def _evaluate_and_check_early_stop(self, global_step: int) -> bool:
        """Run validation; return True when training should stop early.

        Every rank evaluates, but only the main process tracks the best loss,
        saves the ``best`` checkpoint and decides on early stopping.  When
        ``torch.distributed`` is initialised the decision is broadcast from
        rank 0 so that all ranks leave the loop together.
        """
        cfg = self.config
        val_loss = self._run_validation()

        should_stop = False
        if self._is_main:
            self._log(f"step {global_step:>7d} | val_loss {val_loss:.4f}")
            if self._writer is not None:
                self._writer.add_scalar("val/loss", val_loss, global_step)

            if val_loss < self._best_val_loss:
                self._best_val_loss = val_loss
                self._val_patience_counter = 0
                best_path = save_checkpoint(
                    model=self.model,
                    optimizer=self.optimizer,
                    scheduler=self.scheduler,
                    step=global_step,
                    loss=val_loss,
                    path=cfg.checkpoint_dir,
                    suffix="best",
                )
                self._log(
                    f"New best val_loss={val_loss:.4f} → {best_path}"
                )
            else:
                self._val_patience_counter += 1
                self._log(
                    f"val_loss {val_loss:.4f} did not improve "
                    f"(best={self._best_val_loss:.4f}, "
                    f"patience={self._val_patience_counter}/{self._val_patience_limit})"
                )
                if self._val_patience_counter >= self._val_patience_limit:
                    self._log(
                        f"Early stopping triggered at step {global_step} "
                        f"(patience {self._val_patience_limit} exhausted)"
                    )
                    should_stop = True

        if torch.distributed.is_initialized():
            stop_tensor = torch.tensor(
                [1 if should_stop else 0], dtype=torch.int32,
                device=self.device,
            )
            torch.distributed.broadcast(stop_tensor, src=0)
            should_stop = stop_tensor.item() == 1
        return should_stop

    def _close_log_sinks(self) -> None:
        """Close the tensorboard writer and the log file (safe to repeat)."""
        if self._writer is not None:
            self._writer.close()
        if self._log_fh is not None:
            self._log_fh.close()
            # Drop the handle so later _log() calls fall back to stdout only
            # instead of writing to a closed file.
            self._log_fh = None

    @contextlib.contextmanager
    def _precision_context(self):
        """Enter the autocast context(s) selected by the config.

        FP8 (bf16 autocast + Transformer Engine ``fp8_autocast``) takes
        precedence when a recipe was built; otherwise bf16 autocast is used
        when ``use_amp`` is set; otherwise no context is entered.
        """
        device_type = self.device.type
        with contextlib.ExitStack() as stack:
            if self.config.use_fp8 and self._fp8_recipe is not None:
                stack.enter_context(
                    torch.autocast(device_type=device_type, dtype=torch.bfloat16)
                )
                stack.enter_context(
                    te.fp8_autocast(enabled=True, fp8_recipe=self._fp8_recipe)
                )
            elif self.config.use_amp:
                stack.enter_context(
                    torch.autocast(device_type=device_type, dtype=torch.bfloat16)
                )
            yield

    @torch.no_grad()
    def _run_validation(self) -> float:
        """
        Evaluate the model on the entire validation set and return the mean loss.

        The result is the unweighted mean of the per-batch losses.  The model
        is switched to eval mode for the duration and put back into train mode
        afterwards, even if evaluation raises.

        Returns:
            Mean validation loss, or ``inf`` when the loader yields no batches.
        """
        model = self.model
        model.eval()
        total_loss = 0.0
        total_batches = 0

        try:
            for batch in self._val_loader:
                input_ids = batch[0].to(self.device, dtype=torch.long, non_blocking=True)
                targets = batch[1].to(self.device, dtype=torch.long, non_blocking=True)
                # A third batch element (attention mask) is accepted but not
                # used: the model is called with input_ids only.

                with self._precision_context():
                    logits, _ = model(input_ids)
                    loss = self._compute_loss(logits, targets)

                total_loss += loss.item()
                total_batches += 1
        finally:
            model.train()

        if total_batches == 0:
            self._log("Validation set is empty — returning inf", level="WARN")
            return float("inf")
        return total_loss / total_batches

    def _log(self, msg: str, level: str = "INFO") -> None:
        """Print to stdout and optionally write to the log file."""
        import datetime
        ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        line = f"[{ts}] [{level}] {msg}"
        print(line)
        if self._log_fh is not None:
            self._log_fh.write(line + "\n")

    def _step(self, batch: tuple) -> torch.Tensor:
        """
        Execute one forward + backward pass for a single micro-batch.

        The loss is divided by ``grad_accum_steps`` so that gradients
        accumulated over multiple micro-batches sum to the correct scale.

        Args:
            batch: ``(input_ids, targets)`` or ``(input_ids, targets, attention_mask)``;
                   the first two tensors are moved to the device here.  The
                   optional attention mask is currently not used.

        Returns:
            Raw (un-scaled) loss as a detached tensor on the device (no CPU sync).
            The caller is responsible for calling .item() once per optimizer step.
        """
        input_ids = batch[0].to(self.device, dtype=torch.long, non_blocking=True)
        targets = batch[1].to(self.device, dtype=torch.long, non_blocking=True)

        # Remembered for the tokens/sec computation in the progress log.
        self._last_batch_shape = (input_ids.shape[0], input_ids.shape[1])

        with self._precision_context():
            logits, _ = self.model(input_ids)
            loss = self._compute_loss(logits, targets)

        # Backward runs outside the autocast context, on the scaled loss.
        scaled_loss = loss / self.config.grad_accum_steps
        scaled_loss.backward()

        return loss.detach()

    @staticmethod
    def _compute_loss(
        logits: torch.Tensor, targets: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute cross-entropy loss, ignoring target positions equal to -1.

        Args:
            logits:  ``[B, T, vocab_size]`` float tensor.
            targets: ``[B, T]`` long tensor (may contain -1 as ignore index).

        Returns:
            Scalar loss tensor.
        """
        B, T, V = logits.shape
        # reshape (not view) so non-contiguous inputs are accepted as well.
        return nn.functional.cross_entropy(
            logits.reshape(B * T, V),
            targets.reshape(B * T),
            ignore_index=-1,
        )

    def _next_batch(self) -> tuple:
        """Return the next batch, restarting the DataLoader iterator if exhausted.

        Each restart increments the epoch counter and, when a sampler was
        supplied, forwards the new epoch to ``sampler.set_epoch``.

        Raises:
            RuntimeError: If the training DataLoader yields no batches at all.
        """
        try:
            return next(self._loader_iter)
        except StopIteration:
            pass

        self._epoch += 1
        if self._sampler is not None:
            self._sampler.set_epoch(self._epoch)
        self._loader_iter = iter(self.train_loader)
        try:
            return next(self._loader_iter)
        except StopIteration:
            # Do not let a bare StopIteration escape into the training loop.
            raise RuntimeError(
                "train_loader yielded no batches; cannot continue training."
            ) from None
|
|