import datetime
import functools
import os
import pathlib
import shutil
import time
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import datasets.distributed
import torch
import torch.distributed._functional_collectives
import torch.distributed.checkpoint
import torch.distributed.checkpoint.stateful
from diffusers.hooks import HookRegistry, ModelHook
from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
from torch.distributed._composable.replicate import replicate
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    set_model_state_dict,
)
from torch.distributed.tensor import DTensor, Shard

from finetrainers._metadata import ContextParallelModelPlan, CPInput, CPOutput, TransformerRegistry
from finetrainers.data import DPDataLoader
from finetrainers.logging import get_logger
from finetrainers.utils import enable_determinism, get_device_info, get_submodule_by_name, unwrap_module
from finetrainers.utils._common import DIFFUSERS_TRANSFORMER_BLOCK_NAMES

from .base import BaseCheckpointer, BaseParallelBackend


if TYPE_CHECKING:
    from finetrainers import optimizer


_device_type, _device_module = get_device_info()
logger = get_logger()


class PytorchDTensorParallelBackend(BaseParallelBackend):
    def __init__(
        self,
        world_size: int,
        pp_degree: int = 1,
        dp_degree: int = 1,
        dp_shards: int = -1,
        cp_degree: int = 1,
        tp_degree: int = 1,
        backend: str = "nccl",
        timeout: int = 180,
        logging_dir: Optional[str] = None,
        output_dir: Optional[str] = None,
        gradient_accumulation_steps: Optional[int] = None,
    ) -> None:
        super().__init__()

        self._world_size = world_size
        self._pp_degree = pp_degree
        self._dp_degree = dp_degree
        self._dp_shards = dp_shards
        self._cp_degree = cp_degree
        self._tp_degree = tp_degree
        self._output_dir = pathlib.Path(output_dir) if output_dir is not None else None
        self._logging_dir = (
            self._output_dir / logging_dir if output_dir is not None and logging_dir is not None else None
        )
        self._backend = backend
        self._timeout = timeout

        for degree in [pp_degree, dp_degree, dp_shards, cp_degree, tp_degree]:
            if degree < 1:
                raise ValueError(f"Parallel degree must be at least 1, got {degree}.")

        if dp_shards * pp_degree * dp_degree * cp_degree * tp_degree != world_size:
            raise ValueError(
                f"World size {world_size} must equal the product of all parallel degrees and data parallel shards."
            )

        torch.distributed.init_process_group(backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout))
        _device_module.set_device(self.local_rank)

        logger.info(
            f"Initialized parallel state with:\n"
            f"  - World size: {world_size}\n"
            f"  - Pipeline parallel degree: {pp_degree}\n"
            f"  - Data parallel degree: {dp_degree}\n"
            f"  - Context parallel degree: {cp_degree}\n"
            f"  - Tensor parallel degree: {tp_degree}\n"
            f"  - Data parallel shards: {dp_shards}\n"
        )

        self._mesh: torch.distributed.DeviceMesh = None

    def enable_determinism(self, seed):
        world_mesh = self.get_mesh()
        enable_determinism(seed, world_mesh)

    def apply_ddp(
        self, model: torch.nn.Module, device_mesh: Optional[torch.distributed.DeviceMesh] = None
    ) -> torch.nn.Module:
        if device_mesh is None:
            device_mesh = self.get_mesh()
        apply_ddp(model, device_mesh)
        logger.debug("Applied PytorchDTensorParallel::apply_ddp to model.")
        return model

    def apply_fsdp2(
        self,
        model: torch.nn.Module,
        param_dtype: torch.dtype,
        reduce_dtype: torch.dtype,
        output_dtype: torch.dtype,
        pp_enabled: bool = False,
        cpu_offload: bool = False,
        device_mesh: Optional[torch.distributed.DeviceMesh] = None,
    ) -> torch.nn.Module:
        if device_mesh is None:
            device_mesh = self.get_mesh()
        apply_fsdp2(model, device_mesh, param_dtype, reduce_dtype, output_dtype, pp_enabled, cpu_offload)
        logger.debug("Applied PytorchDTensorParallel::apply_fsdp2 to model.")
        return model

    def apply_context_parallel(
        self, model: torch.nn.Module, device_mesh: Optional[torch.distributed.DeviceMesh] = None
    ) -> torch.nn.Module:
        if device_mesh is None:
            device_mesh = self.get_mesh()
        apply_context_parallel(model, device_mesh)
        logger.debug("Applied PytorchDTensorParallel::apply_context_parallel to model.")
        return model

    def prepare_model(self, model: torch.nn.Module) -> torch.nn.Module:
        return model

    def prepare_dataset(self, dataset: torch.utils.data.IterableDataset) -> torch.utils.data.IterableDataset:
        if self._dp_degree == 1:
            return dataset
        dp_mesh = self.get_mesh()["dp_replicate"]
        dp_local_rank, dp_world_size = dp_mesh.get_local_rank(), dp_mesh.size()
        dataset._data = datasets.distributed.split_dataset_by_node(dataset._data, dp_local_rank, dp_world_size)
        logger.debug("PytorchDTensorParallelBackend::prepare_dataset completed!")
        return dataset

    def prepare_dataloader(
        self, dataset: torch.utils.data.IterableDataset, batch_size: int, num_workers: int, pin_memory: bool
    ) -> DPDataLoader:
        if self._dp_degree == 1:
            dp_local_rank = 0
        else:
            dp_mesh = self.get_mesh()["dp_replicate"]
            dp_local_rank = dp_mesh.get_local_rank()
        dataloader = DPDataLoader(dp_local_rank, dataset, batch_size=batch_size, num_workers=num_workers)
        logger.debug("PytorchDTensorParallelBackend::prepare_dataloader completed!")
        return dataloader

    def prepare_optimizer(self, optimizer, lr_scheduler):
        logger.debug("PytorchDTensorParallelBackend::prepare_optimizer completed!")
        return optimizer, lr_scheduler

    def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh:
        def _get_mesh():
            if name is None:
                return self._mesh
            try:
                return self._mesh[name]
            except (KeyError, RuntimeError):
                if self._mesh.ndim == 0:
                    return None
                return self._mesh

        if self._mesh is not None:
            return _get_mesh()

        mesh_list = [
            ("pp", self._pp_degree),
            ("dp_replicate", self._dp_degree),
            ("dp_shard", self._dp_shards),
            ("cp", self._cp_degree),
            ("tp", self._tp_degree),
        ]
        mesh_list = [(name, degree) for name, degree in mesh_list if degree > 1]
        names = [x[0] for x in mesh_list]
        degrees = [x[1] for x in mesh_list]
        mesh = torch.distributed.device_mesh.init_device_mesh(_device_type, mesh_shape=degrees, mesh_dim_names=names)

        dp_mesh_names, dp_cp_mesh_names, dp_shard_cp_mesh_names = [], [], []

        if self.data_replication_enabled:
            dp_mesh_names.append("dp_replicate")
            dp_cp_mesh_names.append("dp_replicate")
        if self.data_sharding_enabled:
            dp_mesh_names.append("dp_shard")
            dp_cp_mesh_names.append("dp_shard")
            dp_shard_cp_mesh_names.append("dp_shard")
        if self.context_parallel_enabled:
            dp_cp_mesh_names.append("cp")
            dp_shard_cp_mesh_names.append("cp")

        if len(dp_mesh_names) > 0:
            mesh[tuple(dp_mesh_names)]._flatten(mesh_dim_name="dp")
        if len(dp_cp_mesh_names) > 0:
            mesh[tuple(dp_cp_mesh_names)]._flatten(mesh_dim_name="dp_cp")
        if len(dp_shard_cp_mesh_names) > 0:
            mesh[tuple(dp_shard_cp_mesh_names)]._flatten(mesh_dim_name="dp_shard_cp")

        logger.debug(f"Device mesh: {mesh}")
        self._mesh = mesh
        return _get_mesh()

    def get_checkpointer(self, *args, **kwargs):
        return PTDCheckpointer(*args, **kwargs)

    @property
    def world_size(self):
        return torch.distributed.get_world_size()

    @property
    def rank(self):
        return torch.distributed.get_rank()

    @property
    def local_rank(self):
        return int(os.environ.get("LOCAL_RANK", 0))

    @property
    def is_main_process(self):
        r"""Returns `True` if the current process is the main process on the master node."""
        return self.rank == 0

    @property
    def is_local_main_process(self):
        r"""Returns `True` if the current process is the main process on the local node."""
        return self.local_rank == 0

    @property
    def device(self):
        return torch.device(_device_type, self.local_rank)

    def wait_for_everyone(self):
        return torch.distributed.barrier()

    # @contextmanager
    # def main_process_first(self):
    #     if self.is_main_process:
    #         yield
    #         self.wait_for_everyone()
    #     else:
    #         self.wait_for_everyone()
    #         yield

    def destroy(self):
        if self.is_main_process and self.tracker is not None:
            self.tracker.finish()
        return torch.distributed.destroy_process_group()

    @property
    def pipeline_parallel_enabled(self):
        return self._pp_degree > 1

    @property
    def data_parallel_enabled(self):
        return self._dp_degree > 1 or self._dp_shards > 1

    @property
    def data_replication_enabled(self):
        return self._dp_degree > 1

    @property
    def data_sharding_enabled(self):
        return self._dp_shards > 1

    @property
    def context_parallel_enabled(self):
        return self._cp_degree > 1

    @property
    def tensor_parallel_enabled(self):
        return self._tp_degree > 1
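

# Illustrative sketch (not part of the module): how this backend is typically wired up. The argument values
# below are hypothetical; only the methods and properties defined above are assumed.
#
#     backend = PytorchDTensorParallelBackend(
#         world_size=8, dp_degree=2, dp_shards=2, cp_degree=2, tp_degree=1
#     )
#     backend.enable_determinism(seed=42)
#     mesh = backend.get_mesh()                     # lazily builds the ("dp_replicate", "dp_shard", "cp") mesh
#     dp_mesh = backend.get_mesh("dp_replicate")    # sub-mesh used for dataset/dataloader sharding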


class ModelWrapper(torch.distributed.checkpoint.stateful.Stateful):
    def __init__(self, model: Union[torch.nn.Module, List[torch.nn.Module]]) -> None:
        self.model = [model] if isinstance(model, torch.nn.Module) else model

    def state_dict(self) -> Dict[str, Any]:
        return {k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        func = functools.partial(
            set_model_state_dict,
            model_state_dict=state_dict,
            options=StateDictOptions(strict=False),
        )
        list(map(func, self.model))
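

# Illustrative sketch (not part of the module): ModelWrapper lets DCP treat one or more model parts as a single
# stateful object, merging their state dicts under one key. `transformer` is a hypothetical model part.
#
#     wrapper = ModelWrapper([transformer])   # or ModelWrapper(transformer)
#     flat_sd = wrapper.state_dict()          # merged {param_name: tensor/DTensor} across all parts
#     wrapper.load_state_dict(flat_sd)        # non-strict set_model_state_dict on each part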


class PTDCheckpointer(BaseCheckpointer):
    def __init__(
        self,
        dataloader: torch.utils.data.DataLoader,
        model_parts: List[torch.nn.Module],
        optimizers: "optimizer.OptimizerWrapper",
        schedulers: "optimizer.SchedulerWrapper",
        states: Dict[str, Any],
        checkpointing_steps: int,
        checkpointing_limit: int,
        output_dir: str,
        enable: bool = True,
        _callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None,
        _prefix: str = "finetrainers_step",
    ) -> None:
        self.states = states
        self.states.update(
            {
                "model": ModelWrapper(model_parts),
                "optimizer": optimizers,
                "dataloader": dataloader,
            }
        )
        self.states.update(schedulers.get_lr_scheduler_state())

        self.checkpointing_steps = checkpointing_steps
        self.checkpointing_limit = checkpointing_limit
        self.output_dir = pathlib.Path(output_dir)
        self.enable = enable
        self._callback_fn = _callback_fn
        self._prefix = _prefix

        logger.info(f"Checkpointing enabled. Checkpoints will be stored in '{self.output_dir}'")

    def save(self, step: int = -1, force: bool = False, *, _device: torch.device, _is_main_process: bool) -> str:
        if not self._should_checkpoint(step, force):
            return None

        checkpoint_dir = self._get_checkpoint_dir(step)
        begin_time = time.monotonic()
        torch.distributed.checkpoint.save(self.states, checkpoint_id=checkpoint_dir.as_posix())
        end_time = time.monotonic()
        logger.info(
            f"Saved checkpoint in {end_time - begin_time:.2f} seconds at step {step}. Directory: {checkpoint_dir}"
        )
        self._purge_stale_checkpoints()

        state_dicts = [
            gather_state_dict_on_cpu_rank0(model, _device, is_main_process=_is_main_process)
            for model in self.states["model"].model
        ]
        if self._callback_fn is not None:
            list(map(self._callback_fn, state_dicts))

        return checkpoint_dir.as_posix()

    def load(self, step: int = -1) -> bool:
        if not self.enable:
            return False
        if not self.output_dir.exists():
            return False
        if step != -1 and not self._get_checkpoint_dir(step).exists():
            return False

        if step == -1:
            latest_checkpoint_dir = self._find_latest_checkpoint_dir()
            if latest_checkpoint_dir is None:
                return False
            step = int(latest_checkpoint_dir.name.split("_")[-1])

        checkpoint_dir = self._get_checkpoint_dir(step)
        logger.info(f"Loading checkpoint from '{checkpoint_dir}' at step {step}")

        # For step 0, optimizers/schedulers are not available as they are created during training after first step
        states = {"model": self.states["model"]} if step == 0 else self.states

        # See bug: https://github.com/pytorch/pytorch/pull/138575
        original_stateful_states = {
            k: v for k, v in states.items() if isinstance(v, torch.distributed.checkpoint.stateful.Stateful)
        }
        begin_time = time.monotonic()
        torch.distributed.checkpoint.load(states, checkpoint_id=checkpoint_dir.as_posix())
        end_time = time.monotonic()
        logger.info(f"Loaded checkpoint in {end_time - begin_time:.2f} seconds.")

        # bugfix from above: restore the original stateful objects, whose states were already updated in-place by dcp.load()
        states.update(original_stateful_states)
        return True

    def _should_checkpoint(self, step: int, force: bool) -> bool:
        if not self.enable:
            return False
        if not force:
            if step % self.checkpointing_steps != 0:
                return False
        return True

    def _get_checkpoint_dir(self, step: int) -> pathlib.Path:
        return self.output_dir / f"{self._prefix}_{step}"

    def _find_latest_checkpoint_dir(self) -> Optional[pathlib.Path]:
        checkpoints = sorted(self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]))
        return checkpoints[-1] if len(checkpoints) > 0 else None

    def _purge_stale_checkpoints(self) -> None:
        if self.checkpointing_limit is None or self.checkpointing_limit <= 0:
            return
        checkpoints = sorted(
            self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]), reverse=True
        )
        for checkpoint in checkpoints[self.checkpointing_limit :]:
            logger.info(f"Deleting stale checkpoint: {checkpoint}")
            shutil.rmtree(checkpoint, ignore_errors=True)
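

# Illustrative sketch (not part of the module): a typical save/load cycle with the checkpointer.
# `train_dataloader`, `transformer`, `optimizers`, and `schedulers` are hypothetical objects supplied by the trainer.
#
#     checkpointer = PTDCheckpointer(
#         dataloader=train_dataloader,
#         model_parts=[transformer],
#         optimizers=optimizers,
#         schedulers=schedulers,
#         states={},
#         checkpointing_steps=500,
#         checkpointing_limit=2,
#         output_dir="outputs/checkpoints",
#     )
#     checkpointer.load(step=-1)  # resume from the latest checkpoint, if one exists
#     ...
#     checkpointer.save(step=500, _device=backend.device, _is_main_process=backend.is_main_process)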


def gather_state_dict_on_cpu_rank0(
    model, device: Optional[torch.device] = None, *, is_main_process: bool
) -> Dict[str, Any]:
    cpu_state_dict = {}
    sharded_sd = model.state_dict()
    for param_name, param in sharded_sd.items():
        if param.is_cpu:
            # Move back to device if offloaded to CPU
            param = param.to(device)
        if hasattr(param, "_local_tensor"):
            # Gather DTensor
            param = param.full_tensor()
        if is_main_process:
            cpu_state_dict[param_name] = param.cpu()
    torch.distributed.barrier()
    return cpu_state_dict
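

# Illustrative sketch (not part of the module): after a model has been sharded with FSDP2, every rank must call
# this function collectively; only the main process receives the materialized CPU state dict.
#
#     state_dict = gather_state_dict_on_cpu_rank0(
#         transformer, device=backend.device, is_main_process=backend.is_main_process
#     )
#     # on rank 0: full, unsharded parameters on CPU; on the other ranks: an empty dict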


# # Copied from pytorch (torch/distributed/checkpoint/format_utils.py) to support callbacks to modify state_dict
# def dcp_to_torch_save(
#     dcp_checkpoint_dir: Union[str, os.PathLike],
#     torch_save_path: Union[str, os.PathLike],
#     callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None,
# ):
#     """
#     Given a directory containing a DCP checkpoint, this function will convert it into a
#     Torch save file.
#
#     Args:
#         dcp_checkpoint_dir: Directory containing the DCP checkpoint.
#         torch_save_path: Filename to store the converted Torch save file.
#         callback_fn: Optional callback function that takes the state_dict as input and returns a modified state_dict.
#
#     .. warning::
#         To avoid OOM, it's recommended to only run this function on a single rank.
#     """
#     state_dict = {}
#     _load_state_dict(
#         state_dict,
#         storage_reader=FileSystemReader(dcp_checkpoint_dir),
#         planner=_EmptyStateDictLoadPlanner(),
#         no_dist=True,
#     )
#     if callback_fn is not None:
#         state_dict = callback_fn(state_dict)
#     torch.save(state_dict, torch_save_path)


def apply_ddp(model: torch.nn.Module, dp_mesh: torch.distributed.device_mesh.DeviceMesh) -> None:
    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)


def apply_fsdp2(
    model: torch.nn.Module,
    dp_mesh: torch.distributed.device_mesh.DeviceMesh,
    param_dtype: torch.dtype,
    reduce_dtype: torch.dtype,
    output_dtype: torch.dtype,
    pp_enabled: bool = False,
    cpu_offload: bool = False,
) -> None:
    """Apply FSDP2 on a model."""
    mp_policy = MixedPrecisionPolicy(param_dtype, reduce_dtype, output_dtype, cast_forward_inputs=True)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

    if cpu_offload:
        fsdp_config["offload_policy"] = CPUOffloadPolicy(pin_memory=True)

    def apply_fully_shard(blocks):
        for layer_index, block in enumerate(blocks):
            if pp_enabled:
                # For PP, do not reshard after forward to avoid per-microbatch
                # all-gathers, which can be expensive and non-overlapped
                reshard_after_forward = False
            else:
                # As an optimization, do not reshard after forward for the last
                # transformer block since FSDP would prefetch it immediately
                reshard_after_forward = layer_index < len(blocks) - 1
            fully_shard(block, **fsdp_config, reshard_after_forward=reshard_after_forward)

    for transformer_block_name in DIFFUSERS_TRANSFORMER_BLOCK_NAMES:
        blocks = getattr(model, transformer_block_name, None)
        if blocks is not None:
            apply_fully_shard(blocks)

    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
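

# Illustrative sketch (not part of the module): how apply_fsdp2 is typically invoked. `transformer` is a
# hypothetical diffusers transformer whose block attribute name appears in DIFFUSERS_TRANSFORMER_BLOCK_NAMES
# (e.g. `transformer_blocks`); the mesh name and dtypes below are one plausible mixed-precision configuration,
# not a prescribed one.
#
#     apply_fsdp2(
#         transformer,
#         dp_mesh=backend.get_mesh("dp_shard_cp"),
#         param_dtype=torch.bfloat16,
#         reduce_dtype=torch.float32,
#         output_dtype=torch.bfloat16,
#         cpu_offload=False,
#     )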


def apply_context_parallel(
    model: torch.nn.Module,
    mesh: torch.distributed.device_mesh.DeviceMesh,
    plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
) -> None:
    """Apply context parallel on a model."""
    logger.debug(f"Applying context parallel with CP mesh: {mesh}")

    model_cls = unwrap_module(model).__class__
    if plan is None:
        plan = TransformerRegistry.get(model_cls).cp_plan

    for module_id, cp_model_plan in plan.items():
        module = get_submodule_by_name(model, module_id)
        if not isinstance(module, list):
            module = [module]
        logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(module)} modules")
        for m in module:
            registry = HookRegistry.check_if_exists_or_initialize(m)
            if isinstance(cp_model_plan, list):
                # Metadata can only be a list when it is a list of CPOutput
                assert all(isinstance(x, CPOutput) for x in cp_model_plan)
                hook = ContextParallelGatherHook(cp_model_plan, mesh)
                hook_name = f"cp_output---{module_id}"
            else:
                hook = ContextParallelSplitHook(cp_model_plan, mesh)
                hook_name = f"cp_input---{module_id}"
            registry.register_hook(hook, hook_name)
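

# Illustrative sketch (not part of the module): the rough shape of a context-parallel plan. The exact constructor
# signatures of CPInput/CPOutput live in finetrainers._metadata; the keyword names below are inferred from how
# their attributes are read by the hooks in this file (split_dim/expected_dims/split_output, gather_dim) and the
# module names are hypothetical, so treat all of this as an assumption.
#
#     cp_plan = {
#         # Module whose tensor inputs get split along the sequence dim across the CP mesh:
#         "blocks.0": {...},  # parameter-identifier -> CPInput(split_dim=..., expected_dims=..., split_output=...)
#         # Module whose outputs get gathered back to the full sequence:
#         "proj_out": [CPOutput(gather_dim=1)],
#     }
#     apply_context_parallel(transformer, mesh=backend.get_mesh("cp"), plan=cp_plan)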


class ContextParallelSplitHook(ModelHook):
    def __init__(self, metadata: ContextParallelModelPlan, mesh: torch.distributed.device_mesh.DeviceMesh) -> None:
        super().__init__()
        self.metadata = metadata
        self.mesh = mesh

    def pre_forward(self, module, *args, **kwargs):
        args_list = list(args)

        for param_identifier, cpm in self.metadata.items():
            name = param_identifier.name
            index = param_identifier.index

            if isinstance(cpm, CPInput) and cpm.split_output:
                continue

            # Maybe the parameter was passed as a keyword argument
            is_kwarg = True
            input_val = kwargs.get(name, None)

            # If not, maybe it was passed as a positional argument
            if input_val is None and index is not None:
                if index < len(args_list):  # Ensure index is within bounds
                    input_val = args_list[index]
                    is_kwarg = False
                else:
                    logger.warning(f"Index {index} out of bounds for args of length {len(args_list)}.")
                    continue  # Skip if index is invalid

            # Either the input_val is truly None, or the argument was passed positionally
            # but the user forgot to specify the index when registering metadata
            if input_val is None:
                continue

            # The input_val may be a tensor or a list/tuple of tensors. In certain cases, the user may specify to
            # shard the output instead of the input for a particular layer by setting split_output=True
            if torch.is_tensor(input_val):
                input_val = self._prepare_cp_input(input_val, cpm)
            elif isinstance(input_val, (list, tuple)):
                if len(input_val) != len(cpm):
                    raise ValueError(
                        f"Expected input model plan to have {len(input_val)} elements, but got {len(cpm)}."
                    )
                sharded_input_val = []
                for i, x in enumerate(input_val):
                    if torch.is_tensor(x) and not cpm[i].split_output:
                        x = self._prepare_cp_input(x, cpm[i])
                    sharded_input_val.append(x)
                input_val = sharded_input_val
            else:
                raise ValueError(f"Unsupported input type: {type(input_val)}")

            if is_kwarg:
                kwargs[name] = input_val
            elif index is not None and index < len(args_list):
                args_list[index] = input_val

        return tuple(args_list), kwargs

    def post_forward(self, module, output):
        is_tensor = torch.is_tensor(output)
        is_tensor_list = isinstance(output, (list, tuple)) and all(torch.is_tensor(x) for x in output)

        if not is_tensor and not is_tensor_list:
            raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")

        output = [output] if is_tensor else list(output)
        for param_identifier, cpm in self.metadata.items():
            if not isinstance(cpm, CPInput) or not cpm.split_output:
                continue
            index = param_identifier.index
            if index >= len(output):
                raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.")
            current_output = output[index]
            current_output = self._prepare_cp_input(current_output, cpm)
            output[index] = current_output

        return output[0] if is_tensor else tuple(output)

    def _prepare_cp_input(self, x: torch.Tensor, cp_input: CPInput) -> torch.Tensor:
        if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
            raise ValueError(
                f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
            )
        return _EquipartitionSharder.shard(x, cp_input.split_dim, self.mesh)


class ContextParallelGatherHook(ModelHook):
    def __init__(self, metadata: ContextParallelModelPlan, mesh: torch.distributed.device_mesh.DeviceMesh) -> None:
        super().__init__()
        self.metadata = metadata
        self.mesh = mesh

    def post_forward(self, module, output):
        is_tensor = torch.is_tensor(output)
        if is_tensor:
            output = [output]
        output = list(output)
        assert len(output) == len(self.metadata), f"Expected {len(self.metadata)} outputs, but got {len(output)}."
        for i, cpm in enumerate(self.metadata):
            if cpm is None:
                continue
            output[i] = _EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.mesh)
        return output[0] if is_tensor else tuple(output)


class _ContextParallelSharder:
    @classmethod
    def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        raise NotImplementedError("_ContextParallelSharder::shard should be implemented in subclasses")

    @classmethod
    def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        raise NotImplementedError("_ContextParallelSharder::unshard should be implemented in subclasses")


class _EquipartitionSharder(_ContextParallelSharder):
    """
    Shards the input tensor along the specified dimension into cp_mesh's world size chunks.
    Essentially, rank_i gets the i-th chunk.

    This sharding strategy should only be used when performing full attention. Otherwise, it will
    incur a performance penalty. If using causal attention, please use _CausalSharder instead.
    """

    @classmethod
    def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        assert tensor.size()[dim] % mesh.size() == 0
        return tensor.chunk(mesh.size(), dim=dim)[mesh.get_local_rank()]

    @classmethod
    def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        tensor = tensor.contiguous()
        # TODO(aryan): pass a shape here so that we can allow uneven sharding across seq dim
        result = DTensor.from_local(tensor, mesh, placements=[Shard(dim)]).full_tensor()
        return result
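

# Illustrative sketch (not part of the module): with a CP mesh of size 4 and a tensor whose sequence dim has
# length 8, _EquipartitionSharder.shard hands each rank one contiguous chunk of length 2. `hidden_states` and
# `cp_mesh` are hypothetical.
#
#     # rank 0 -> tokens [0, 1], rank 1 -> [2, 3], rank 2 -> [4, 5], rank 3 -> [6, 7]
#     local = _EquipartitionSharder.shard(hidden_states, dim=1, mesh=cp_mesh)
#     # unshard all-gathers the chunks back into the full sequence on every rank
#     full = _EquipartitionSharder.unshard(local, dim=1, mesh=cp_mesh)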


# TODO(aryan): this class is untested
class _CausalSharder(_ContextParallelSharder):
    """
    Shards the input tensor along the specified dimension into 2x cp_mesh's world size chunks.
    Essentially, rank_i gets the i-th chunk and the (2 * cp_world_size - 1 - i)-th chunk.

    This sharding strategy improves the performance for causal attention, as it allows
    equal distribution of computation across all ranks.

    Causal attention mask:

    ```
    1 0 0 0  <--- Group 0
    1 1 0 0  <--- Group 1
    1 1 1 0  <--- Group 1
    1 1 1 1  <--- Group 0
    ```
    """

    @classmethod
    def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        world_size = mesh.size()
        rank = mesh.get_local_rank()
        assert tensor.size()[dim] % (2 * world_size) == 0
        chunks = tensor.chunk(2 * world_size, dim=dim)
        i, j = rank, 2 * world_size - 1 - rank
        return torch.cat((chunks[i], chunks[j]), dim=dim)

    @classmethod
    def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        tensor = tensor.contiguous()
        world_size = mesh.size()
        # TODO(aryan): pass a shape here so that we can allow uneven sharding across seq dim
        all_tensors = DTensor.from_local(tensor, mesh, placements=[Shard(dim)]).full_tensor()
        sliced_tensors = [st for t in all_tensors for st in t.chunk(2, dim=dim)]
        ordered_tensors = list(sliced_tensors)
        for i, t in enumerate(sliced_tensors):
            if i % 2 == 0:
                ordered_tensors[i // 2] = t
            else:
                ordered_tensors[world_size * 2 - (i // 2) - 1] = t
        return torch.cat(ordered_tensors, dim=dim)
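

# Illustrative sketch (not part of the module): for a CP mesh of size 2 and a sequence split into 4 chunks
# [c0, c1, c2, c3], the causal strategy pairs the lightest and heaviest attention rows so each rank does a
# similar amount of work. `hidden_states` and `cp_mesh` are hypothetical.
#
#     # rank 0 -> torch.cat([c0, c3]), rank 1 -> torch.cat([c1, c2])
#     local = _CausalSharder.shard(hidden_states, dim=1, mesh=cp_mesh)
#     full = _CausalSharder.unshard(local, dim=1, mesh=cp_mesh)  # restores the original [c0, c1, c2, c3] order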