import datetime
import os
import pathlib
import shutil
import time
from typing import Any, Callable, Dict, Optional, Tuple

import torch
from diffusers.utils import is_accelerate_available

from finetrainers.logging import get_logger
from finetrainers.utils import get_device_info

from .base import BaseCheckpointer, BaseParallelBackend


if not is_accelerate_available():
    raise ImportError(
        "Please install the accelerate package using `pip install accelerate` to use the AccelerateParallelBackend."
    )

from accelerate import Accelerator
from accelerate.data_loader import DataLoader
from accelerate.utils import (
    DataLoaderConfiguration,
    DistributedDataParallelKwargs,
    InitProcessGroupKwargs,
    ProjectConfiguration,
    set_seed,
)


logger = get_logger()
_device_type, _device_module = get_device_info()


class AccelerateParallelBackend(BaseParallelBackend):
    def __init__(
        self,
        world_size: int,
        pp_degree: int = 1,
        dp_degree: int = 1,
        dp_shards: int = -1,
        cp_degree: int = 1,
        tp_degree: int = 1,
        backend: str = "nccl",
        timeout: int = 180,
        logging_dir: Optional[str] = None,
        output_dir: Optional[str] = None,
        gradient_accumulation_steps: Optional[int] = None,
    ) -> None:
        super().__init__()

        self._world_size = world_size
        self._pp_degree = pp_degree
        self._dp_degree = dp_degree
        self._dp_shards = dp_shards
        self._cp_degree = cp_degree
        self._tp_degree = tp_degree
        self._output_dir = pathlib.Path(output_dir) if output_dir is not None else None
        self._logging_dir = (
            self._output_dir / logging_dir if output_dir is not None and logging_dir is not None else None
        )
        self._backend = backend
        self._timeout = timeout
        self._gradient_accumulation_steps = gradient_accumulation_steps

        if pp_degree > 1 or dp_shards > 1 or cp_degree > 1 or tp_degree > 1:
            raise ValueError(
                "AccelerateParallelBackend currently supports only Distributed Data Parallelism. Pipeline, context, "
                "and tensor parallelism, as well as data sharding, are not supported yet."
            )
        if dp_degree != world_size:
            raise ValueError("Data parallel degree must be equal to world size.")

        self._accelerator = None
        if world_size == 1:
            # For single-process runs the Accelerator is created eagerly; `apply_ddp` will reuse it later
            # instead of creating a new one.
            project_config = ProjectConfiguration(project_dir=self._output_dir, logging_dir=self._logging_dir)
            dataloader_config = DataLoaderConfiguration(
                split_batches=False, dispatch_batches=False, use_stateful_dataloader=True
            )
            init_process_group_kwargs = InitProcessGroupKwargs(
                backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout)
            )
            self._accelerator = Accelerator(
                project_config=project_config,
                dataloader_config=dataloader_config,
                gradient_accumulation_steps=gradient_accumulation_steps,
                log_with=None,
                kwargs_handlers=[init_process_group_kwargs],
            )
            if torch.backends.mps.is_available():
                # Native AMP is not supported on MPS devices.
                self._accelerator.native_amp = False

        self._mesh: torch.distributed.DeviceMesh = None

    def enable_determinism(self, seed: int) -> None:
        set_seed(seed)

    def apply_ddp(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
        project_config = None
        ddp_kwargs = None
        init_process_group_kwargs = None
        dataloader_config = None
        if self._accelerator is None:
            project_config = ProjectConfiguration(project_dir=self._output_dir, logging_dir=self._logging_dir)
            ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
            dataloader_config = DataLoaderConfiguration(
                split_batches=False, dispatch_batches=False, use_stateful_dataloader=True
            )
            init_process_group_kwargs = InitProcessGroupKwargs(
                backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout)
            )
        self._accelerator, model = apply_ddp(
            model,
            project_config,
            ddp_kwargs,
            init_process_group_kwargs,
            dataloader_config,
            self._gradient_accumulation_steps,
            accelerator=self._accelerator,
        )
        logger.debug("Applied AccelerateParallel::apply_ddp to model.")
        return model

    def prepare_model(self, model: torch.nn.Module) -> torch.nn.Module:
        return self._accelerator.prepare_model(model)

    def prepare_dataset(self, dataset: torch.utils.data.IterableDataset) -> torch.utils.data.IterableDataset:
        logger.debug("AccelerateParallelBackend::prepare_dataset completed!")
        return dataset

    def prepare_dataloader(
        self,
        dataset: torch.utils.data.IterableDataset,
        batch_size: int = 1,
        num_workers: int = 0,
        pin_memory: bool = False,
    ) -> DataLoader:
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory
        )
        dataloader = self._accelerator.prepare_data_loader(dataloader)
        logger.debug("AccelerateParallelBackend::prepare_dataloader completed!")
        return dataloader

    def prepare_optimizer(self, optimizer, lr_scheduler):
        optimizer = self._accelerator.prepare_optimizer(optimizer)
        lr_scheduler = self._accelerator.prepare_scheduler(lr_scheduler)
        return optimizer, lr_scheduler

    def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh:
        def _get_mesh():
            if name is None:
                return self._mesh
            try:
                return self._mesh[name]
            except (KeyError, RuntimeError):
                return self._mesh

        if self._mesh is not None:
            return _get_mesh()

        # Build the device mesh lazily from the parallel dimensions that are actually enabled (degree > 1).
        mesh_list = [("dp_replicate", self._dp_degree), ("dp_shard", self._dp_shards)]
        mesh_list = [(name, degree) for name, degree in mesh_list if degree > 1]
        names = [x[0] for x in mesh_list]
        degrees = [x[1] for x in mesh_list]
        mesh = torch.distributed.device_mesh.init_device_mesh(_device_type, mesh_shape=degrees, mesh_dim_names=names)

        dp_mesh_names, dp_cp_mesh_names, dp_shard_cp_mesh_names = [], [], []

        if self.data_replication_enabled:
            dp_mesh_names.append("dp_replicate")
            dp_cp_mesh_names.append("dp_replicate")
        if self.data_sharding_enabled:
            dp_mesh_names.append("dp_shard")
            dp_cp_mesh_names.append("dp_shard")
            dp_shard_cp_mesh_names.append("dp_shard")
        if self.context_parallel_enabled:
            dp_cp_mesh_names.append("cp")
            dp_shard_cp_mesh_names.append("cp")

        # Flatten composite dimensions so they can be looked up by a single name, e.g. mesh["dp"].
        if len(dp_mesh_names) > 0:
            mesh[tuple(dp_mesh_names)]._flatten(mesh_dim_name="dp")
        if len(dp_cp_mesh_names) > 0:
            mesh[tuple(dp_cp_mesh_names)]._flatten(mesh_dim_name="dp_cp")
        if len(dp_shard_cp_mesh_names) > 0:
            mesh[tuple(dp_shard_cp_mesh_names)]._flatten(mesh_dim_name="dp_shard_cp")

        logger.debug(f"Device mesh: {mesh}")
        self._mesh = mesh
        return _get_mesh()

    def get_checkpointer(self, *args, **kwargs):
        return AccelerateCheckpointer(self._accelerator, *args, **kwargs)

    @property
    def world_size(self):
        return self._accelerator.num_processes

    @property
    def rank(self):
        return self._accelerator.process_index

    @property
    def local_rank(self):
        return self._accelerator.local_process_index

    @property
    def is_main_process(self):
        r"""Returns `True` if the current process is the main process on the master node."""
        return self._accelerator.is_main_process

    @property
    def is_local_main_process(self):
        r"""Returns `True` if the current process is the main process on the local node."""
        return self._accelerator.is_local_main_process

    @property
    def device(self):
        return self._accelerator.device

    def wait_for_everyone(self):
        self._accelerator.wait_for_everyone()

    def destroy(self):
        if self.is_main_process and self.tracker is not None:
            self.tracker.finish()
        self._accelerator.end_training()

    @property
    def pipeline_parallel_enabled(self):
        return self._pp_degree > 1

    @property
    def data_parallel_enabled(self):
        return self._dp_degree > 1 or self._dp_shards > 1

    @property
    def data_replication_enabled(self):
        return self._dp_degree > 1

    @property
    def data_sharding_enabled(self):
        return self._dp_shards > 1

    @property
    def context_parallel_enabled(self):
        return self._cp_degree > 1

    @property
    def tensor_parallel_enabled(self):
        return self._tp_degree > 1
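

# Illustrative usage sketch of the backend (assumed trainer-side wiring; names and values below are
# placeholders, not finetrainers defaults):
#
#     backend = AccelerateParallelBackend(world_size=2, dp_degree=2, output_dir="outputs", logging_dir="logs")
#     transformer = backend.apply_ddp(transformer)
#     dataloader = backend.prepare_dataloader(dataset, batch_size=1, num_workers=4, pin_memory=True)
#     optimizer, lr_scheduler = backend.prepare_optimizer(optimizer, lr_scheduler)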


class AccelerateCheckpointer(BaseCheckpointer):
    def __init__(
        self,
        accelerator: Accelerator,
        states: Dict[str, Any],
        checkpointing_steps: int,
        checkpointing_limit: int,
        output_dir: str,
        enable: bool = True,
        _callback_fn: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
        _prefix: str = "finetrainers_step",
        *args,
        **kwargs,
    ) -> None:
        self.accelerator = accelerator
        self.states = states

        self.checkpointing_steps = checkpointing_steps
        self.checkpointing_limit = checkpointing_limit
        self.output_dir = pathlib.Path(output_dir)
        self.enable = enable
        self._callback_fn = _callback_fn
        self._prefix = _prefix

        def save_model_hook(models, weights, output_dir: str) -> None:
            # Only the main process writes the extra training state alongside the accelerate checkpoint.
            if not self.accelerator.is_main_process:
                return

            # Exactly one model is expected to be registered with the Accelerator.
            assert len(models) == 1

            _callback_fn(weights[0])
            torch.save(self.states, os.path.join(output_dir, "states.pt"))

        def load_model_hook(models, input_dir) -> None:
            self.states = torch.load(os.path.join(input_dir, "states.pt"))

        self.accelerator.register_save_state_pre_hook(save_model_hook)
        self.accelerator.register_load_state_pre_hook(load_model_hook)

        logger.info(f"Checkpointing enabled. Checkpoints will be stored in '{self.output_dir}'")

    def save(
        self, step: int = -1, force: bool = False, *, _device: torch.device, _is_main_process: bool
    ) -> Optional[str]:
        # `_device` and `_is_main_process` are unused here; presumably kept for interface parity with other backends.
        if not self._should_checkpoint(step, force):
            return None

        checkpoint_dir = self._get_checkpoint_dir(step)
        begin_time = time.monotonic()
        self.accelerator.save_state(checkpoint_dir.as_posix(), safe_serialization=True)
        end_time = time.monotonic()
        logger.info(
            f"Saved checkpoint in {end_time - begin_time:.2f} seconds at step {step}. Directory: {checkpoint_dir}"
        )
        self._purge_stale_checkpoints()

        return checkpoint_dir.as_posix()

    def load(self, step: int = -1) -> bool:
        if not self.enable:
            return False
        if not self.output_dir.exists():
            return False
        if step != -1 and not self._get_checkpoint_dir(step).exists():
            return False

        if step == -1:
            latest_checkpoint_dir = self._find_latest_checkpoint_dir()
            if latest_checkpoint_dir is None:
                return False
            step = int(latest_checkpoint_dir.name.split("_")[-1])

        checkpoint_dir = self._get_checkpoint_dir(step)
        logger.info(f"Loading checkpoint from '{checkpoint_dir}' at step {step}")

        begin_time = time.monotonic()
        self.accelerator.load_state(checkpoint_dir.as_posix())
        end_time = time.monotonic()
        logger.info(f"Loaded checkpoint in {end_time - begin_time:.2f} seconds.")

        return True

    def _should_checkpoint(self, step: int, force: bool) -> bool:
        if not self.enable:
            return False
        if not force:
            if step % self.checkpointing_steps != 0:
                return False
        return True

    def _get_checkpoint_dir(self, step: int) -> pathlib.Path:
        return self.output_dir / f"{self._prefix}_{step}"

    def _find_latest_checkpoint_dir(self) -> Optional[pathlib.Path]:
        checkpoints = sorted(self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]))
        return checkpoints[-1] if len(checkpoints) > 0 else None

    def _purge_stale_checkpoints(self) -> None:
        if self.checkpointing_limit is None or self.checkpointing_limit <= 0:
            return
        checkpoints = sorted(
            self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]), reverse=True
        )
        for checkpoint in checkpoints[self.checkpointing_limit :]:
            logger.info(f"Deleting stale checkpoint: {checkpoint}")
            shutil.rmtree(checkpoint, ignore_errors=True)
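

# Illustrative sketch of driving the checkpointer from a training loop (assumed call pattern; values are placeholders):
#
#     checkpointer = backend.get_checkpointer(
#         states={"global_step": 0},
#         checkpointing_steps=500,
#         checkpointing_limit=2,
#         output_dir="outputs",
#         _callback_fn=lambda state_dict: state_dict,
#     )
#     checkpointer.load(step=-1)  # resume from the newest `finetrainers_step_*` directory, if present
#     checkpointer.save(step=global_step, _device=backend.device, _is_main_process=backend.is_main_process)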


def apply_ddp(
    model: torch.nn.Module,
    project_config: Optional[ProjectConfiguration] = None,
    ddp_kwargs: Optional[DistributedDataParallelKwargs] = None,
    init_process_group_kwargs: Optional[InitProcessGroupKwargs] = None,
    dataloader_config: Optional[DataLoaderConfiguration] = None,
    gradient_accumulation_steps: Optional[int] = None,
    accelerator: Optional[Accelerator] = None,
) -> Tuple[Accelerator, torch.nn.Module]:
    if accelerator is None:
        accelerator = Accelerator(
            project_config=project_config,
            dataloader_config=dataloader_config,
            gradient_accumulation_steps=gradient_accumulation_steps,
            log_with=None,
            kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
        )
        if torch.backends.mps.is_available():
            # Native AMP is not supported on MPS devices.
            accelerator.native_amp = False
    # Wrap the model (DDP) and place it on the correct device; the prepared model is returned to the caller.
    model = accelerator.prepare_model(model)
    return accelerator, model