from typing import Callable, Iterable, Optional, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.contract import contract
from torch.distributed._composable_state import _get_module_state, _insert_module_state
from torch.distributed.fsdp._common_utils import _FSDPState
from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo
from torch.distributed.fsdp._init_utils import (
    _init_buffer_state,
    _init_core_state,
    _init_device_handle,
    _init_ignored_module_states,
    _init_param_handle_from_module,
    _init_prefetching_state,
    _init_process_group_state,
    _init_runtime_state,
    _init_state_dict_state,
    HYBRID_SHARDING_STRATEGIES,
)
from torch.distributed.fsdp._runtime_utils import (
    _register_post_forward_hook,
    _register_pre_forward_hook,
    _register_root_pre_forward_hook,
)
from torch.distributed.fsdp._state_dict_utils import _register_all_state_dict_hooks
from torch.distributed.fsdp._wrap_utils import _auto_wrap
from torch.distributed.fsdp.api import (
    BackwardPrefetch,
    CPUOffload,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import _Policy


@contract(state_cls=_FSDPState)
def fully_shard(
    module: nn.Module,
    *,
    process_group: Optional[dist.ProcessGroup] = None,
    policy: Optional[_Policy] = None,
    strategy: Optional[ShardingStrategy] = None,
    mixed_precision: Optional[MixedPrecision] = None,
    cpu_offload: Optional[CPUOffload] = None,
    ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
    device_id: Optional[Union[int, torch.device]] = None,
    param_init_fn: Optional[Callable[[nn.Module], None]] = None,
    sync_module_states: bool = False,
    forward_prefetch: bool = False,
    ignored_states: Union[
        Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
    ] = None,
) -> nn.Module:
    """
    Applies ``FullyShardedDataParallel`` (FSDP) semantics to ``module``.

    This is the composable form of FSDP: ``module`` itself is returned (not a
    wrapper). Its ``_FSDPState`` is the object returned by
    ``fully_shard.state(module)`` (provided by the ``@contract`` decorator).
    That state is populated by the ``_init_*`` helpers. Pre-/post-forward,
    root pre-forward and state-dict hooks are then registered for it.

    Args:
        module: Module to shard. It is mutated in place and returned.
        process_group: Forwarded to ``_init_process_group_state`` together
            with ``strategy`` and ``policy``. It is also forwarded to
            auto-wrapped submodules, except with hybrid strategies (see
            the body).
        policy: Optional auto-wrap policy. It must be a ``_Policy`` instance
            or ``None``. When given, ``_auto_wrap`` is invoked with
            ``fully_shard`` itself and this call's arguments.
        strategy: Sharding strategy. ``None`` is treated as
            ``ShardingStrategy.FULL_SHARD`` for this module's core state.
        mixed_precision: Forwarded to ``_init_core_state``.
        cpu_offload: Forwarded to ``_init_core_state``.
        ignored_modules: Forwarded to ``_init_ignored_module_states``.
        device_id: Forwarded to ``_init_device_handle`` and
            ``_init_param_handle_from_module``.
        param_init_fn: Forwarded to ``_init_param_handle_from_module``.
        sync_module_states: Forwarded to ``_init_param_handle_from_module``.
        forward_prefetch: Forwarded to ``_init_prefetching_state``.
        ignored_states: Parameters or modules to ignore. Forwarded to
            ``_init_ignored_module_states``.

    The following settings are hard-coded rather than exposed as arguments:
        - ``limit_all_gathers=True``
        - ``use_orig_params=True``
        - backward prefetch ``BackwardPrefetch.BACKWARD_PRE``
        - backward and forward prefetch limits of 1

    Returns:
        The same ``module`` object, with FSDP state and hooks installed.

    Raises:
        ValueError: If ``policy`` is not ``None`` and is not a ``_Policy``.
    """
    # Record a one-time API-usage event for this entry point.
    torch._C._log_api_usage_once("torch.distributed.fully_shard")
    # Enforce the new auto wrap policy
    if policy is not None and not isinstance(policy, _Policy):
        raise ValueError(f"Expects a `_Policy` but got {policy}")
    state = fully_shard.state(module)
    # The initialization steps below are order-dependent: later steps read
    # attributes set by earlier ones (e.g. ``state._ignored_params`` right
    # below, and ``state.process_group`` / ``state._inter_node_pg`` in the
    # hybrid branch). Do not reorder them.
    state = _init_ignored_module_states(state, module, ignored_modules, ignored_states)
    state = _init_device_handle(state, module, state._ignored_params, device_id)
    # NOTE(review): the meaning of the positional ``True`` is not visible
    # here. It is presumably the ``use_orig_params`` flag, matching the
    # ``use_orig_params=True`` passed to ``_init_core_state`` below; confirm.
    _annotate_modules_for_dynamo(module, state._ignored_modules, True)
    state = _init_process_group_state(state, process_group, strategy, policy)
    if policy is not None:
        # Forward this call's arguments so that the submodules selected by
        # ``policy`` (presumably) get ``fully_shard`` applied with the same
        # settings. This runs before the root's own core/param-handle
        # initialization below.
        root_kwargs = {
            "process_group": process_group,
            "strategy": strategy,
            "mixed_precision": mixed_precision,
            "cpu_offload": cpu_offload,
            "ignored_modules": ignored_modules,
            "device_id": device_id,
            "param_init_fn": param_init_fn,
            "sync_module_states": sync_module_states,
            "forward_prefetch": forward_prefetch,
            "ignored_states": ignored_states,
        }
        if strategy in HYBRID_SHARDING_STRATEGIES:
            # Hybrid strategies: hand children the process-group pair that
            # ``_init_process_group_state`` already built for the root. That
            # pair is ``state.process_group`` (presumably intra-node) and
            # ``state._inter_node_pg``. This likely avoids each child
            # creating its own groups; verify against ``_init_process_group_state``.
            root_kwargs["process_group"] = (state.process_group, state._inter_node_pg)
        _auto_wrap(
            module,
            policy,
            state._ignored_modules,
            state._ignored_params,
            root_kwargs,
            fully_shard,
        )
    state = _init_core_state(
        state,
        strategy or ShardingStrategy.FULL_SHARD,
        mixed_precision,
        cpu_offload,
        limit_all_gathers=True,
        use_orig_params=True,
        backward_prefetch_limit=1,
        forward_prefetch_limit=1,
    )
    state = _init_runtime_state(state)
    state = _init_prefetching_state(
        state, BackwardPrefetch.BACKWARD_PRE, forward_prefetch=forward_prefetch
    )
    state = _init_buffer_state(state, module)
    state = _init_param_handle_from_module(
        state, module, device_id, param_init_fn, sync_module_states
    )
    state = _init_state_dict_state(state)
    _register_all_state_dict_hooks(state)
    _register_pre_forward_hook(state, module)
    _register_post_forward_hook(state, module)
    # Registered last. The original "prepend last" note suggests this hook is
    # prepended so it runs before the pre-forward hook above; confirm in
    # ``_register_root_pre_forward_hook``.
    _register_root_pre_forward_hook(state, module)  # prepend last
    # Always insert the state for the passed-in module even if it has no
    # managed parameters, in which case it has no handles and does not appear
    # in `_fully_sharded_module_to_handle`
    _insert_module_state(module, state)
    # Also attach this state to every submodule that has an entry in
    # ``state._fully_sharded_module_to_handle`` but no state of its own yet.
    # Submodules that already carry a state are left untouched (presumably
    # ones that received their own state during auto-wrapping).
    for submodule in module.modules():
        if (
            submodule in state._fully_sharded_module_to_handle
            and _get_module_state(submodule) is None
        ):
            _insert_module_state(submodule, state)
    return module