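"""Accelerate-backed parallel backend and checkpointer for finetrainers.

`AccelerateParallelBackend` wraps Hugging Face Accelerate behind the
`BaseParallelBackend` interface; only plain Distributed Data Parallelism is
supported, and pipeline/tensor/context parallel degrees greater than 1 are
rejected at construction time. `AccelerateCheckpointer` persists training
state via `Accelerator.save_state`/`load_state`.
"""
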
import datetime
import os
import pathlib
import shutil
import time
from typing import Any, Callable, Dict, Optional, Tuple

import torch
from diffusers.utils import is_accelerate_available

from finetrainers.logging import get_logger
from finetrainers.utils import get_device_info

from .base import BaseCheckpointer, BaseParallelBackend


if not is_accelerate_available():
    raise ImportError(
        "Please install the accelerate package using `pip install accelerate` to use the AccelerateParallelBackend."
    )

from accelerate import Accelerator
from accelerate.data_loader import DataLoader
from accelerate.utils import (
    DataLoaderConfiguration,
    DistributedDataParallelKwargs,
    InitProcessGroupKwargs,
    ProjectConfiguration,
    set_seed,
)


logger = get_logger()
_device_type, _device_module = get_device_info()


class AccelerateParallelBackend(BaseParallelBackend):
    def __init__(
        self,
        world_size: int,
        pp_degree: int = 1,
        dp_degree: int = 1,
        dp_shards: int = -1,
        cp_degree: int = 1,
        tp_degree: int = 1,
        backend: str = "nccl",
        timeout: int = 180,
        logging_dir: Optional[str] = None,
        output_dir: Optional[str] = None,
        gradient_accumulation_steps: Optional[int] = None,
    ) -> None:
        super().__init__()

        self._world_size = world_size
        self._pp_degree = pp_degree
        self._dp_degree = dp_degree
        self._dp_shards = dp_shards
        self._cp_degree = cp_degree
        self._tp_degree = tp_degree
        self._output_dir = pathlib.Path(output_dir) if output_dir is not None else None
        self._logging_dir = (
            self._output_dir / logging_dir if output_dir is not None and logging_dir is not None else None
        )
        self._backend = backend
        self._timeout = timeout
        self._gradient_accumulation_steps = gradient_accumulation_steps

        if pp_degree > 1 or dp_shards > 1 or cp_degree > 1 or tp_degree > 1:
            raise ValueError(
                "AccelerateParallelBackend does not support anything but Distributed Data Parallelism at the moment."
            )
        if dp_degree != world_size:
            raise ValueError("Data parallel degree must be equal to world size.")

        self._accelerator = None
        if world_size == 1:
            # Needs special handling for single GPU training
            project_config = ProjectConfiguration(project_dir=self._output_dir, logging_dir=self._logging_dir)
            dataloader_config = DataLoaderConfiguration(
                split_batches=False, dispatch_batches=False, use_stateful_dataloader=True
            )
            init_process_group_kwargs = InitProcessGroupKwargs(
                backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout)
            )
            self._accelerator = Accelerator(
                project_config=project_config,
                dataloader_config=dataloader_config,
                gradient_accumulation_steps=gradient_accumulation_steps,
                log_with=None,
                kwargs_handlers=[init_process_group_kwargs],
            )
            if torch.backends.mps.is_available():
                self._accelerator.native_amp = False

        self._mesh: torch.distributed.DeviceMesh = None

    def enable_determinism(self, seed: int) -> None:
        set_seed(seed)

    def apply_ddp(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
        project_config = None
        ddp_kwargs = None
        init_process_group_kwargs = None
        dataloader_config = None  # stays None when an Accelerator was already created in __init__
        if self._accelerator is None:
            project_config = ProjectConfiguration(project_dir=self._output_dir, logging_dir=self._logging_dir)
            ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
            dataloader_config = DataLoaderConfiguration(
                split_batches=False, dispatch_batches=False, use_stateful_dataloader=True
            )
            init_process_group_kwargs = InitProcessGroupKwargs(
                backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout)
            )
        self._accelerator, model = apply_ddp(
            model,
            project_config,
            ddp_kwargs,
            init_process_group_kwargs,
            dataloader_config,
            self._gradient_accumulation_steps,
            accelerator=self._accelerator,
        )
        logger.debug("Applied AccelerateParallel::apply_ddp to model.")
        return model

    def prepare_model(self, model: torch.nn.Module) -> torch.nn.Module:
        return self._accelerator.prepare_model(model)

    def prepare_dataset(self, dataset: torch.utils.data.IterableDataset) -> torch.utils.data.IterableDataset:
        logger.debug("AccelerateParallelBackend::prepare_dataset completed!")
        return dataset

    def prepare_dataloader(
        self,
        dataset: torch.utils.data.IterableDataset,
        batch_size: int = 1,
        num_workers: int = 0,
        pin_memory: bool = False,
    ) -> DataLoader:
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory
        )
        dataloader = self._accelerator.prepare_data_loader(dataloader)
        logger.debug("AccelerateParallelBackend::prepare_dataloader completed!")
        return dataloader

    def prepare_optimizer(self, optimizer, lr_scheduler):
        optimizer = self._accelerator.prepare_optimizer(optimizer)
        lr_scheduler = self._accelerator.prepare_scheduler(lr_scheduler)
        return optimizer, lr_scheduler
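
    # Mesh construction note: only the `dp_replicate` and `dp_shard` dimensions are
    # materialized (degrees <= 1 are dropped), after which the composite `dp`,
    # `dp_cp`, and `dp_shard_cp` sub-meshes are flattened out for lookup by name.
    # The `cp` branches below cannot trigger in this backend, since `__init__`
    # rejects `cp_degree > 1`; they appear to be kept for parity with other backends.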
    def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh:
        def _get_mesh():
            if name is None:
                return self._mesh
            try:
                return self._mesh[name]
            except (KeyError, RuntimeError):
                return self._mesh

        if self._mesh is not None:
            return _get_mesh()

        mesh_list = [("dp_replicate", self._dp_degree), ("dp_shard", self._dp_shards)]
        mesh_list = [(name, degree) for name, degree in mesh_list if degree > 1]
        names = [x[0] for x in mesh_list]
        degrees = [x[1] for x in mesh_list]
        mesh = torch.distributed.device_mesh.init_device_mesh(_device_type, mesh_shape=degrees, mesh_dim_names=names)

        dp_mesh_names, dp_cp_mesh_names, dp_shard_cp_mesh_names = [], [], []

        if self.data_replication_enabled:
            dp_mesh_names.append("dp_replicate")
            dp_cp_mesh_names.append("dp_replicate")
        if self.data_sharding_enabled:
            dp_mesh_names.append("dp_shard")
            dp_cp_mesh_names.append("dp_shard")
            dp_shard_cp_mesh_names.append("dp_shard")
        if self.context_parallel_enabled:
            dp_cp_mesh_names.append("cp")
            dp_shard_cp_mesh_names.append("cp")

        if len(dp_mesh_names) > 0:
            mesh[tuple(dp_mesh_names)]._flatten(mesh_dim_name="dp")
        if len(dp_cp_mesh_names) > 0:
            mesh[tuple(dp_cp_mesh_names)]._flatten(mesh_dim_name="dp_cp")
        if len(dp_shard_cp_mesh_names) > 0:
            mesh[tuple(dp_shard_cp_mesh_names)]._flatten(mesh_dim_name="dp_shard_cp")

        logger.debug(f"Device mesh: {mesh}")
        self._mesh = mesh
        return _get_mesh()

    def get_checkpointer(self, *args, **kwargs):
        return AccelerateCheckpointer(self._accelerator, *args, **kwargs)

    @property
    def world_size(self):
        return self._accelerator.num_processes

    @property
    def rank(self):
        return self._accelerator.process_index

    @property
    def local_rank(self):
        return self._accelerator.local_process_index

    @property
    def is_main_process(self):
        r"""Returns `True` if the current process is the main process on the master node."""
        return self._accelerator.is_main_process

    @property
    def is_local_main_process(self):
        r"""Returns `True` if the current process is the main process on the local node."""
        return self._accelerator.is_local_main_process

    @property
    def device(self):
        return self._accelerator.device

    def wait_for_everyone(self):
        self._accelerator.wait_for_everyone()

    def destroy(self):
        if self.is_main_process and self.tracker is not None:
            self.tracker.finish()
        self._accelerator.end_training()

    @property
    def pipeline_parallel_enabled(self):
        return self._pp_degree > 1

    @property
    def data_parallel_enabled(self):
        return self._dp_degree > 1 or self._dp_shards > 1

    @property
    def data_replication_enabled(self):
        return self._dp_degree > 1

    @property
    def data_sharding_enabled(self):
        return self._dp_shards > 1

    @property
    def context_parallel_enabled(self):
        return self._cp_degree > 1

    @property
    def tensor_parallel_enabled(self):
        return self._tp_degree > 1


class AccelerateCheckpointer(BaseCheckpointer):
    def __init__(
        self,
        accelerator: Accelerator,
        states: Dict[str, Any],
        checkpointing_steps: int,
        checkpointing_limit: int,
        output_dir: str,
        enable: bool = True,
        _callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None,
        _prefix: str = "finetrainers_step",
        *args,
        **kwargs,
    ) -> None:
        self.accelerator = accelerator
        self.states = states
        self.checkpointing_steps = checkpointing_steps
        self.checkpointing_limit = checkpointing_limit
        self.output_dir = pathlib.Path(output_dir)
        self.enable = enable
        self._callback_fn = _callback_fn
        self._prefix = _prefix

        def save_model_hook(models, weights, output_dir: str) -> None:
            if not self.accelerator.is_main_process:
                return
            # TODO(aryan): this is a temporary assertion since we only support training the transformer
            # at the moment. Remove it when adding support for training text encoders/VAE and more.
            assert len(models) == 1
            self._callback_fn(weights[0])
            torch.save(self.states, os.path.join(output_dir, "states.pt"))

        def load_model_hook(models, input_dir) -> None:
            self.states = torch.load(os.path.join(input_dir, "states.pt"))

        self.accelerator.register_save_state_pre_hook(save_model_hook)
        self.accelerator.register_load_state_pre_hook(load_model_hook)

        logger.info(f"Checkpointing enabled. Checkpoints will be stored in '{self.output_dir}'")

    def save(
        self, step: int = -1, force: bool = False, *, _device: torch.device, _is_main_process: bool
    ) -> Optional[str]:
        if not self._should_checkpoint(step, force):
            return None

        checkpoint_dir = self._get_checkpoint_dir(step)
        begin_time = time.monotonic()
        self.accelerator.save_state(checkpoint_dir.as_posix(), safe_serialization=True)
        end_time = time.monotonic()
        logger.info(
            f"Saved checkpoint in {end_time - begin_time:.2f} seconds at step {step}. Directory: {checkpoint_dir}"
        )
        self._purge_stale_checkpoints()

        return checkpoint_dir.as_posix()

    def load(self, step: int = -1) -> bool:
        if not self.enable:
            return False
        if not self.output_dir.exists():
            return False
        if step != -1 and not self._get_checkpoint_dir(step).exists():
            return False

        if step == -1:
            latest_checkpoint_dir = self._find_latest_checkpoint_dir()
            if latest_checkpoint_dir is None:
                return False
            step = int(latest_checkpoint_dir.name.split("_")[-1])

        checkpoint_dir = self._get_checkpoint_dir(step)
        logger.info(f"Loading checkpoint from '{checkpoint_dir}' at step {step}")

        begin_time = time.monotonic()
        self.accelerator.load_state(checkpoint_dir.as_posix())
        end_time = time.monotonic()
        logger.info(f"Loaded checkpoint in {end_time - begin_time:.2f} seconds.")

        return True

    def _should_checkpoint(self, step: int, force: bool) -> bool:
        if not self.enable:
            return False
        if not force:
            if step % self.checkpointing_steps != 0:
                return False
        return True

    def _get_checkpoint_dir(self, step: int) -> pathlib.Path:
        return self.output_dir / f"{self._prefix}_{step}"

    def _find_latest_checkpoint_dir(self) -> Optional[pathlib.Path]:
        checkpoints = sorted(self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]))
        return checkpoints[-1] if len(checkpoints) > 0 else None

    def _purge_stale_checkpoints(self) -> None:
        if self.checkpointing_limit is None or self.checkpointing_limit <= 0:
            return
        checkpoints = sorted(
            self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]), reverse=True
        )
        for checkpoint in checkpoints[self.checkpointing_limit :]:
            logger.info(f"Deleting stale checkpoint: {checkpoint}")
            shutil.rmtree(checkpoint, ignore_errors=True)


def apply_ddp(
    model: torch.nn.Module,
    project_config: Optional[ProjectConfiguration] = None,
    ddp_kwargs: Optional[DistributedDataParallelKwargs] = None,
    init_process_group_kwargs: Optional[InitProcessGroupKwargs] = None,
    dataloader_config: Optional[DataLoaderConfiguration] = None,
    gradient_accumulation_steps: Optional[int] = None,
    accelerator: Optional[Accelerator] = None,
) -> Tuple[Accelerator, torch.nn.Module]:
    if accelerator is None:
        accelerator = Accelerator(
            project_config=project_config,
            dataloader_config=dataloader_config,
            gradient_accumulation_steps=gradient_accumulation_steps,
            log_with=None,
            kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
        )
        if torch.backends.mps.is_available():
            accelerator.native_amp = False
    # Capture the prepared module: `prepare_model` moves it to the right device
    # and wraps it for distributed training.
    model = accelerator.prepare_model(model)
    return accelerator, model
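

# Example usage (a minimal sketch; the argument values are illustrative, not
# finetrainers defaults):
#
#     backend = AccelerateParallelBackend(
#         world_size=1, dp_degree=1, output_dir="outputs", gradient_accumulation_steps=1
#     )
#     backend.enable_determinism(seed=42)
#     model = backend.apply_ddp(model)
#     dataloader = backend.prepare_dataloader(dataset, batch_size=1, num_workers=2)
#     optimizer, lr_scheduler = backend.prepare_optimizer(optimizer, lr_scheduler)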