import functools
import pathlib
import shutil
import time
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import torch
import torch.distributed.checkpoint
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    set_model_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful

from ..logging import get_logger

if TYPE_CHECKING:
    from .. import optimizer

logger = get_logger()


class ModelWrapper(Stateful):
    """Adapts one or more model parts to the `Stateful` protocol expected by DCP."""

    def __init__(self, model: Union[torch.nn.Module, List[torch.nn.Module]]) -> None:
        self.model = [model] if isinstance(model, torch.nn.Module) else model

    def state_dict(self) -> Dict[str, Any]:
        # Merge the state dicts of all model parts into a single flat dict.
        return {k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # `strict=False` lets each model part pick out only the keys it owns
        # from the merged state dict.
        func = functools.partial(
            set_model_state_dict,
            model_state_dict=state_dict,
            options=StateDictOptions(strict=False),
        )
        list(map(func, self.model))
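
# A minimal usage sketch (illustrative names, not part of this module): wrapping
# model parts is what lets `torch.distributed.checkpoint` round-trip them as a
# single entry in a checkpoint state dict.
#
#     wrapper = ModelWrapper([transformer, text_encoder])
#     torch.distributed.checkpoint.save({"model": wrapper}, checkpoint_id="/tmp/ckpt")
#     torch.distributed.checkpoint.load({"model": wrapper}, checkpoint_id="/tmp/ckpt")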


class PTDCheckpointManager:
    def __init__(
        self,
        dataloader: torch.utils.data.DataLoader,
        model_parts: List[torch.nn.Module],
        optimizers: "optimizer.OptimizerWrapper",
        schedulers: "optimizer.SchedulerWrapper",
        states: Dict[str, Any],
        checkpointing_steps: int,
        checkpointing_limit: Optional[int],
        output_dir: str,
        enable: bool = True,
        _callback_fn: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
        _prefix: str = "finetrainers_step",
    ) -> None:
        self.states = states
        self.states.update(
            {
                "model": ModelWrapper(model_parts),
                "optimizer": optimizers,
                "dataloader": dataloader,
            }
        )
        self.states.update(schedulers.get_lr_scheduler_state())

        self.checkpointing_steps = checkpointing_steps
        self.checkpointing_limit = checkpointing_limit
        self.output_dir = pathlib.Path(output_dir)
        self.enable = enable
        self._callback_fn = _callback_fn
        self._prefix = _prefix

        if self.enable:
            logger.info(f"Checkpointing enabled. Checkpoints will be stored in '{self.output_dir}'")

    def save(self, step: int = -1, force: bool = False, *, _device: torch.device, _is_main_process: bool) -> Optional[str]:
        if not self._should_checkpoint(step, force):
            return None

        checkpoint_dir = self._get_checkpoint_dir(step)
        begin_time = time.monotonic()
        # Collective call: every rank participates and writes its own shards.
        torch.distributed.checkpoint.save(self.states, checkpoint_id=checkpoint_dir.as_posix())
        end_time = time.monotonic()
        logger.info(
            f"Saved checkpoint in {end_time - begin_time:.2f} seconds at step {step}. Directory: {checkpoint_dir}"
        )
        self._purge_stale_checkpoints()

        if self._callback_fn is not None:
            # Gathering full state dicts is expensive, so only do it when a
            # callback actually consumes them.
            state_dicts = [
                gather_state_dict_on_cpu_rank0(model, _device, is_main_process=_is_main_process)
                for model in self.states["model"].model
            ]
            list(map(self._callback_fn, state_dicts))

        return checkpoint_dir.as_posix()
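
    # How `save` is meant to be driven from the training loop (a sketch; `device`
    # and `is_main_process` would come from the parallel/accelerator state):
    #
    #     for step in range(1, num_train_steps + 1):
    #         ...
    #         manager.save(step, _device=device, _is_main_process=is_main_process)
    #     manager.save(step, force=True, _device=device, _is_main_process=is_main_process)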

    def load(self, step: int = -1) -> bool:
        if not self.enable:
            return False
        if not self.output_dir.exists():
            return False
        if step != -1 and not self._get_checkpoint_dir(step).exists():
            return False

        if step == -1:
            latest_checkpoint_dir = self._find_latest_checkpoint_dir()
            if latest_checkpoint_dir is None:
                return False
            step = int(latest_checkpoint_dir.name.split("_")[-1])

        checkpoint_dir = self._get_checkpoint_dir(step)
        logger.info(f"Loading checkpoint from '{checkpoint_dir}' at step {step}")

        # For step 0, only the model is loaded: optimizers and schedulers do not
        # exist yet, as they are created during training after the first step.
        states = {"model": self.states["model"]} if step == 0 else self.states

        # Workaround for https://github.com/pytorch/pytorch/pull/138575: dcp.load()
        # may replace Stateful entries in `states` rather than only updating them
        # in-place, so keep references to the originals and restore them afterwards.
        original_stateful_states = {k: v for k, v in states.items() if isinstance(v, Stateful)}
        begin_time = time.monotonic()
        torch.distributed.checkpoint.load(states, checkpoint_id=checkpoint_dir.as_posix())
        end_time = time.monotonic()
        logger.info(f"Loaded checkpoint in {end_time - begin_time:.2f} seconds.")

        # Restore the original Stateful objects; their states were already updated
        # in-place by dcp.load().
        states.update(original_stateful_states)
        return True
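
    # Resume sketch: `load(-1)` restores the most recent checkpoint, if any, into
    # the registered `states` in-place.
    #
    #     if manager.load(step=-1):
    #         logger.info("Resumed training state from the latest checkpoint")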

    def _should_checkpoint(self, step: int, force: bool) -> bool:
        if not self.enable:
            return False
        return force or step % self.checkpointing_steps == 0

    def _get_checkpoint_dir(self, step: int) -> pathlib.Path:
        return self.output_dir / f"{self._prefix}_{step}"

    def _find_latest_checkpoint_dir(self) -> Optional[pathlib.Path]:
        checkpoints = sorted(self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]))
        return checkpoints[-1] if len(checkpoints) > 0 else None

    def _purge_stale_checkpoints(self) -> None:
        if self.checkpointing_limit is None or self.checkpointing_limit <= 0:
            return
        checkpoints = sorted(
            self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]), reverse=True
        )
        for checkpoint in checkpoints[self.checkpointing_limit :]:
            logger.info(f"Deleting stale checkpoint: {checkpoint}")
            shutil.rmtree(checkpoint, ignore_errors=True)


def gather_state_dict_on_cpu_rank0(
    model, device: Optional[torch.device] = None, *, is_main_process: bool
) -> Dict[str, Any]:
    cpu_state_dict = {}
    sharded_sd = model.state_dict()
    for param_name, param in sharded_sd.items():
        if param.is_cpu:
            # Move back to device if offloaded to CPU
            param = param.to(device)
        if hasattr(param, "_local_tensor"):
            # Gather DTensor shards from all ranks into a full, unsharded tensor
            param = param.full_tensor()
        if is_main_process:
            cpu_state_dict[param_name] = param.cpu()
        # Keep ranks in lockstep so only one full parameter is materialized at a time
        torch.distributed.barrier()
    return cpu_state_dict
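
# A sketch of producing a single consolidated file from a sharded model
# (`transformer`, `device`, and the output path are illustrative):
#
#     state_dict = gather_state_dict_on_cpu_rank0(transformer, device, is_main_process=is_main_process)
#     if is_main_process:
#         torch.save(state_dict, "consolidated_model.pt")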


# # Copied from pytorch (torch/distributed/checkpoint/format_utils.py) to support callbacks to modify state_dict
# def dcp_to_torch_save(
#     dcp_checkpoint_dir: Union[str, os.PathLike],
#     torch_save_path: Union[str, os.PathLike],
#     callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None,
# ):
#     """
#     Given a directory containing a DCP checkpoint, this function will convert it into a
#     Torch save file.
#
#     Args:
#         dcp_checkpoint_dir: Directory containing the DCP checkpoint.
#         torch_save_path: Filename to store the converted Torch save file.
#         callback_fn: Optional callback function that takes the state_dict as input and returns a modified state_dict.
#
#     .. warning::
#         To avoid OOM, it's recommended to only run this function on a single rank.
#     """
#     state_dict = {}
#     _load_state_dict(
#         state_dict,
#         storage_reader=FileSystemReader(dcp_checkpoint_dir),
#         planner=_EmptyStateDictLoadPlanner(),
#         no_dist=True,
#     )
#     if callback_fn is not None:
#         state_dict = callback_fn(state_dict)
#     torch.save(state_dict, torch_save_path)
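
# Note: recent PyTorch versions ship a stock converter without the callback hook;
# a minimal sketch of using it instead of the variant above:
#
#     from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
#     dcp_to_torch_save("outputs/checkpoints/finetrainers_step_500", "model.pt")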