| """Checkpoint Management for Training"""
|
|
|
import json
import logging
import shutil
from dataclasses import dataclass, field, fields
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import torch
|
|
|
| logger = logging.getLogger(__name__)
|
|
|
|
|
@dataclass
class CheckpointMetadata:
    """Metadata for a checkpoint.

    Serialized to ``metadata.json`` next to the checkpoint weights so that
    training can be resumed and checkpoints can be ranked by metric.
    """

    step: int                # step within the current epoch
    epoch: int               # epoch at save time
    global_step: int         # total optimizer steps across all epochs
    metrics: Dict[str, float] = field(default_factory=dict)
    config: Dict[str, Any] = field(default_factory=dict)
    model_name: str = "zenith"
    timestamp: str = ""      # free-form; callers may leave it empty

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict of all fields."""
        return {
            "step": self.step,
            "epoch": self.epoch,
            "global_step": self.global_step,
            "metrics": self.metrics,
            "config": self.config,
            "model_name": self.model_name,
            "timestamp": self.timestamp,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "CheckpointMetadata":
        """Build metadata from a dict, ignoring unknown keys.

        Filtering to declared fields makes loading tolerant of metadata
        files written by newer code that added extra fields; previously
        ``cls(**data)`` raised TypeError on any unrecognized key.
        """
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in data.items() if k in known})
|
|
|
|
|
class CheckpointManager:
    """Manages saving and loading of checkpoints.

    Checkpoints live under ``checkpoint_dir`` as subdirectories named
    ``checkpoint-<name>`` containing model weights, optional optimizer /
    scheduler / scaler state, and a ``metadata.json`` sidecar.
    """

    def __init__(
        self,
        checkpoint_dir: str,
        save_total_limit: int = 5,
        save_best_only: bool = False,
        metric_for_best: str = "eval_loss",
        greater_is_better: bool = False,
    ):
        """Initialize the manager and index any existing checkpoints.

        Args:
            checkpoint_dir: Root directory for checkpoints (created if missing).
            save_total_limit: Max checkpoints to keep; <= 0 disables pruning.
            save_best_only: Stored for callers; not enforced inside this class.
            metric_for_best: Metric key used by ``get_best_checkpoint``.
            greater_is_better: Whether larger metric values are better.
        """
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.save_total_limit = save_total_limit
        self.save_best_only = save_best_only
        self.metric_for_best = metric_for_best
        self.greater_is_better = greater_is_better

        # NOTE(review): best_metric is initialized but never updated in this
        # class; presumably callers maintain it -- confirm before relying on it.
        self.best_metric = None
        self.checkpoints: List[Path] = []

        self._scan_checkpoints()

    @staticmethod
    def _sort_key(path: Path) -> int:
        """Numeric ordering key derived from the trailing ``-<number>``.

        The previous key ``int(name.split("-")[1])`` raised ValueError for
        names such as ``checkpoint-step-100`` (which the module-level
        ``save_checkpoint`` helper produces). Using the last component
        handles those, and non-numeric suffixes sort first instead of
        crashing the scan.
        """
        try:
            return int(path.name.rsplit("-", 1)[-1])
        except ValueError:
            return -1

    def _scan_checkpoints(self):
        """Rebuild the in-memory checkpoint index from disk.

        Resets ``self.checkpoints`` before scanning; the previous version
        appended to the existing list, so a rescan after ``cleanup`` kept
        stale deleted paths and duplicated surviving ones.
        """
        self.checkpoints = [
            path for path in self.checkpoint_dir.glob("checkpoint-*") if path.is_dir()
        ]
        self.checkpoints.sort(key=self._sort_key)

    def save_checkpoint(
        self,
        state: Dict[str, Any],
        name: str,
        metrics: Optional[Dict[str, float]] = None,
    ) -> Path:
        """Save checkpoint to disk.

        Args:
            state: Must contain ``model_state_dict``; may also contain
                ``optimizer_state_dict``, ``scheduler_state_dict``,
                ``scaler_state_dict`` and the bookkeeping keys ``step``,
                ``epoch``, ``global_step``, ``config``, ``timestamp``.
            name: Suffix of the checkpoint directory ``checkpoint-<name>``.
            metrics: Optional evaluation metrics recorded in metadata.

        Returns:
            Path of the checkpoint directory that was written.
        """
        checkpoint_path = self.checkpoint_dir / f"checkpoint-{name}"
        checkpoint_path.mkdir(parents=True, exist_ok=True)

        torch.save(state["model_state_dict"], checkpoint_path / "pytorch_model.bin")

        # Auxiliary training state is optional; persist only what is present.
        if "optimizer_state_dict" in state:
            torch.save(state["optimizer_state_dict"], checkpoint_path / "optimizer.pt")
        if "scheduler_state_dict" in state and state["scheduler_state_dict"]:
            torch.save(state["scheduler_state_dict"], checkpoint_path / "scheduler.pt")
        if "scaler_state_dict" in state and state["scaler_state_dict"]:
            torch.save(state["scaler_state_dict"], checkpoint_path / "scaler.pt")

        metadata = CheckpointMetadata(
            step=state.get("step", 0),
            epoch=state.get("epoch", 0),
            global_step=state.get("global_step", 0),
            metrics=metrics or {},
            config=state.get("config", {}),
            timestamp=state.get("timestamp", ""),
        )
        with open(checkpoint_path / "metadata.json", "w") as f:
            json.dump(metadata.to_dict(), f, indent=2)

        logger.info(f"Checkpoint saved: {checkpoint_path}")

        if checkpoint_path not in self.checkpoints:
            self.checkpoints.append(checkpoint_path)
            self.checkpoints.sort(key=self._sort_key)

        # Prune oldest checkpoint(s) once over the limit (limit <= 0 disables).
        if self.save_total_limit > 0 and len(self.checkpoints) > self.save_total_limit:
            self._remove_oldest_checkpoint()

        return checkpoint_path

    def load_checkpoint(
        self,
        checkpoint_path: Union[str, Path],
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[Any] = None,
        scaler: Optional[torch.cuda.amp.GradScaler] = None,
    ) -> CheckpointMetadata:
        """Load checkpoint from disk.

        Restores model weights (warns if missing) and, when the matching
        file exists, optimizer / scheduler / scaler state for any objects
        that were passed in.

        Returns:
            The checkpoint's metadata, or an all-zero ``CheckpointMetadata``
            when no ``metadata.json`` is present.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` does not exist.
        """
        checkpoint_path = Path(checkpoint_path)

        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        model_path = checkpoint_path / "pytorch_model.bin"
        if model_path.exists():
            state_dict = torch.load(model_path, map_location="cpu")
            model.load_state_dict(state_dict)
            logger.info(f"Loaded model from {model_path}")
        else:
            logger.warning(f"Model weights not found at {model_path}")

        if optimizer is not None:
            opt_path = checkpoint_path / "optimizer.pt"
            if opt_path.exists():
                optimizer.load_state_dict(torch.load(opt_path, map_location="cpu"))
                logger.info(f"Loaded optimizer from {opt_path}")

        if scheduler is not None:
            sched_path = checkpoint_path / "scheduler.pt"
            if sched_path.exists():
                scheduler.load_state_dict(torch.load(sched_path, map_location="cpu"))
                logger.info(f"Loaded scheduler from {sched_path}")

        if scaler is not None:
            scaler_path = checkpoint_path / "scaler.pt"
            if scaler_path.exists():
                scaler.load_state_dict(torch.load(scaler_path, map_location="cpu"))
                logger.info(f"Loaded scaler from {scaler_path}")

        meta_path = checkpoint_path / "metadata.json"
        if meta_path.exists():
            with open(meta_path, "r") as f:
                metadata = CheckpointMetadata.from_dict(json.load(f))
            logger.info(f"Loaded metadata: epoch={metadata.epoch}, step={metadata.step}")
        else:
            metadata = CheckpointMetadata(step=0, epoch=0, global_step=0)

        return metadata

    def get_latest_checkpoint(self) -> Optional[Path]:
        """Get the most recent checkpoint, or None if there are none."""
        if self.checkpoints:
            return self.checkpoints[-1]
        return None

    def get_best_checkpoint(self) -> Optional[Path]:
        """Get the best checkpoint based on ``metric_for_best``.

        Checkpoints whose metadata lacks the metric are skipped; returns
        None when no checkpoint carries the metric.
        """
        if not self.checkpoints:
            return None

        best_path = None
        best_value = None

        for path in self.checkpoints:
            meta_path = path / "metadata.json"
            if not meta_path.exists():
                continue
            with open(meta_path, "r") as f:
                meta = CheckpointMetadata.from_dict(json.load(f))

            if self.metric_for_best not in meta.metrics:
                continue
            value = meta.metrics[self.metric_for_best]
            improved = (
                best_value is None
                or (self.greater_is_better and value > best_value)
                or (not self.greater_is_better and value < best_value)
            )
            if improved:
                best_value = value
                best_path = path

        return best_path

    def _remove_oldest_checkpoint(self):
        """Remove the oldest checkpoint to maintain ``save_total_limit``."""
        if len(self.checkpoints) > self.save_total_limit:
            oldest = self.checkpoints.pop(0)
            if oldest.exists():
                shutil.rmtree(oldest)
                logger.info(f"Removed old checkpoint: {oldest}")

    def cleanup(self, keep: Optional[List[Path]] = None):
        """Clean up checkpoints, optionally keeping specific ones.

        Args:
            keep: Paths to preserve; everything else is deleted from disk.
        """
        if keep is None:
            keep = []

        for path in self.checkpoints:
            if path not in keep:
                if path.exists():
                    shutil.rmtree(path)
                    logger.info(f"Removed checkpoint: {path}")

        # Re-index from disk so the in-memory list matches what survived.
        self._scan_checkpoints()
|
|
|
|
|
def save_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: Optional[Any],
    scaler: Optional[torch.cuda.amp.GradScaler],
    checkpoint_dir: str,
    epoch: int,
    global_step: int,
    metrics: Optional[Dict[str, float]] = None,
    config: Optional[Dict[str, Any]] = None,
    save_optimizer: bool = True,
    save_scheduler: bool = True,
):
    """Convenience function to save a checkpoint.

    Writes ``checkpoint-step-<global_step>`` under ``checkpoint_dir`` using a
    throwaway ``CheckpointManager`` with pruning disabled (save_total_limit=0),
    so no existing checkpoints are removed.

    Args:
        model: Model whose ``state_dict`` is saved as the weights.
        optimizer: Optimizer; saved only when ``save_optimizer`` is True.
        scheduler: LR scheduler; saved when ``save_scheduler`` is True and
            the scheduler is not None.
        scaler: AMP grad scaler; saved when not None.
        checkpoint_dir: Root directory for checkpoints.
        epoch: Epoch recorded in metadata.
        global_step: Step recorded in metadata and used in the directory name.
        metrics: Optional evaluation metrics recorded in metadata.
        config: Optional config dict recorded in metadata.
        save_optimizer: Whether to persist optimizer state.
        save_scheduler: Whether to persist scheduler state.

    Returns:
        Path of the checkpoint directory (backward-compatible addition;
        earlier versions implicitly returned None).
    """
    manager = CheckpointManager(checkpoint_dir, save_total_limit=0)
    state = {
        "model_state_dict": model.state_dict(),
        # Record step too; previously metadata.step was always left at 0.
        "step": global_step,
        "global_step": global_step,
        "epoch": epoch,
        "config": config or {},
        "timestamp": "",
    }

    if save_optimizer:
        state["optimizer_state_dict"] = optimizer.state_dict()
    if save_scheduler and scheduler is not None:
        state["scheduler_state_dict"] = scheduler.state_dict()
    # Bug fix: the scaler argument was previously accepted but never saved.
    if scaler is not None:
        state["scaler_state_dict"] = scaler.state_dict()

    return manager.save_checkpoint(state, f"step-{global_step}", metrics)
|
|
|
|
|
def load_checkpoint(
    checkpoint_path: str,
    model: torch.nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[Any] = None,
    scaler: Optional[torch.cuda.amp.GradScaler] = None,
) -> int:
    """Convenience wrapper around ``CheckpointManager.load_checkpoint``.

    Restores model (and, when provided, optimizer / scheduler / scaler)
    state from ``checkpoint_path``.

    Returns:
        The global step recorded in the checkpoint's metadata (0 when no
        metadata file is present).
    """
    target = Path(checkpoint_path)
    # A manager rooted at the parent directory indexes sibling checkpoints
    # but is only used here to perform the load.
    manager = CheckpointManager(target.parent)
    meta = manager.load_checkpoint(target, model, optimizer, scheduler, scaler)
    return meta.global_step
|
|
|