Fully Sharded Data Parallel utilities
enable_fsdp_ram_efficient_loading
Enables RAM efficient loading of Hugging Face models for FSDP in the environment.
disable_fsdp_ram_efficient_loading
Disables RAM efficient loading of Hugging Face models for FSDP in the environment.
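These toggles are typically called around model instantiation. A minimal sketch, assuming both functions are imported from accelerate.utils and the model is later prepared with FSDP:

```python
from accelerate.utils import (
    disable_fsdp_ram_efficient_loading,
    enable_fsdp_ram_efficient_loading,
)

# Turn RAM-efficient loading on before instantiating a Hugging Face model under FSDP,
# so that only the first process materializes the full checkpoint in CPU RAM.
enable_fsdp_ram_efficient_loading()
# model = AutoModelForCausalLM.from_pretrained(...)  # load the model here

# Turn it back off if later models should be loaded the regular way.
disable_fsdp_ram_efficient_loading()
```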
merge_fsdp_weights
accelerate.utils.merge_fsdp_weights
( checkpoint_dir: str, output_path: str, safe_serialization: bool = True, remove_checkpoint_dir: bool = False )
Parameters
- checkpoint_dir (str) — The directory containing the FSDP checkpoints (can be either the model or optimizer).
- output_path (str) — The path to save the merged checkpoint.
- safe_serialization (bool, optional, defaults to True) — Whether to save the merged weights with safetensors (recommended).
- remove_checkpoint_dir (bool, optional, defaults to False) — Whether to remove the checkpoint directory after merging.
Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if SHARDED_STATE_DICT was used for the model. Weights will be saved to {output_path}/model.safetensors if safe_serialization, else to pytorch_model.bin.

Note: this is a CPU-bound process.
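A minimal usage sketch; the checkpoint and output paths below are illustrative:

```python
from accelerate.utils import merge_fsdp_weights

# Merge sharded FSDP checkpoint files (saved with SHARDED_STATE_DICT) into one file.
merge_fsdp_weights(
    checkpoint_dir="outputs/pytorch_model_fsdp_0",  # illustrative path to the sharded checkpoint
    output_path="outputs/merged",                   # model.safetensors is written here
    safe_serialization=True,                        # save with safetensors (recommended)
    remove_checkpoint_dir=False,                    # keep the original shards
)
```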
FullyShardedDataParallelPlugin
class accelerate.FullyShardedDataParallelPlugin
( sharding_strategy: typing.Union[str, ForwardRef('torch.distributed.fsdp.ShardingStrategy')] = None, backward_prefetch: typing.Union[str, ForwardRef('torch.distributed.fsdp.BackwardPrefetch')] = None, mixed_precision_policy: typing.Union[dict, ForwardRef('torch.distributed.fsdp.MixedPrecision'), NoneType] = None, auto_wrap_policy: typing.Union[typing.Callable, typing.Literal['transformer_based_wrap', 'size_based_wrap', 'no_wrap'], NoneType] = None, cpu_offload: typing.Union[bool, ForwardRef('torch.distributed.fsdp.CPUOffload')] = None, ignored_modules: typing.Optional[typing.Iterable[torch.nn.modules.module.Module]] = None, state_dict_type: typing.Union[str, ForwardRef('torch.distributed.fsdp.StateDictType')] = None, state_dict_config: typing.Union[ForwardRef('torch.distributed.fsdp.FullStateDictConfig'), ForwardRef('torch.distributed.fsdp.ShardedStateDictConfig'), NoneType] = None, optim_state_dict_config: typing.Union[ForwardRef('torch.distributed.fsdp.FullOptimStateDictConfig'), ForwardRef('torch.distributed.fsdp.ShardedOptimStateDictConfig'), NoneType] = None, limit_all_gathers: bool = True, use_orig_params: bool = None, param_init_fn: typing.Optional[typing.Callable[[torch.nn.modules.module.Module], NoneType]] = None, sync_module_states: bool = None, forward_prefetch: bool = None, activation_checkpointing: bool = None, cpu_ram_efficient_loading: bool = None, transformer_cls_names_to_wrap: typing.Optional[typing.List[str]] = None, min_num_params: typing.Optional[int] = None )
Parameters
- sharding_strategy (Union[str, torch.distributed.fsdp.ShardingStrategy], defaults to 'FULL_SHARD') — Sharding strategy to use. Should be either a str or an instance of torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy.
- backward_prefetch (Union[str, torch.distributed.fsdp.BackwardPrefetch], defaults to 'NO_PREFETCH') — Backward prefetch strategy to use. Should be either a str or an instance of torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch.
- mixed_precision_policy (Optional[Union[dict, torch.distributed.fsdp.MixedPrecision]], defaults to None) — A config to enable mixed precision training with FullyShardedDataParallel. If passing in a dict, it should have the following keys: param_dtype, reduce_dtype, and buffer_dtype.
- auto_wrap_policy (Optional[Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]], defaults to NO_WRAP) — A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one of transformer_based_wrap, size_based_wrap, or no_wrap. See torch.distributed.fsdp.wrap.size_based_auto_wrap_policy for a direction on what it should look like.
- cpu_offload (Union[bool, torch.distributed.fsdp.CPUOffload], defaults to False) — Whether to offload parameters to CPU. Should be either a bool or an instance of torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload.
- ignored_modules (Optional[Iterable[torch.nn.Module]], defaults to None) — A list of modules to ignore when wrapping with FSDP.
- state_dict_type (Union[str, torch.distributed.fsdp.StateDictType], defaults to 'FULL_STATE_DICT') — State dict type to use. If a string, it must be one of full_state_dict, local_state_dict, or sharded_state_dict.
- state_dict_config (Optional[Union[torch.distributed.fsdp.FullStateDictConfig, torch.distributed.fsdp.ShardedStateDictConfig]], defaults to None) — State dict config to use. Is determined based on the state_dict_type if not passed in.
- optim_state_dict_config (Optional[Union[torch.distributed.fsdp.FullOptimStateDictConfig, torch.distributed.fsdp.ShardedOptimStateDictConfig]], defaults to None) — Optim state dict config to use. Is determined based on the state_dict_type if not passed in.
- limit_all_gathers (bool, defaults to True) — Whether to have FSDP explicitly synchronize the CPU thread to prevent too many in-flight all-gathers. This bool only affects the sharded strategies that schedule all-gathers. Enabling it can help lower the number of CUDA malloc retries.
- use_orig_params (bool, defaults to False) — Whether to use the original parameters for the optimizer.
- param_init_fn (Optional[Callable[[torch.nn.Module], None]], defaults to None) — A Callable[torch.nn.Module] -> None that specifies how modules that are currently on the meta device should be initialized onto an actual device. Only applicable when sync_module_states is True. By default, it is a lambda that calls to_empty on the module.
- sync_module_states (bool, defaults to False) — Whether each individually wrapped FSDP unit should broadcast module parameters from rank 0 to ensure they are the same across all ranks after initialization. Defaults to False unless cpu_ram_efficient_loading is True, in which case it is forcibly enabled.
- forward_prefetch (bool, defaults to False) — Whether to have FSDP explicitly prefetch the next upcoming all-gather while executing in the forward pass. Only use with static graphs.
- activation_checkpointing (bool, defaults to False) — A technique to reduce memory usage by clearing activations of certain layers and recomputing them during a backward pass. Effectively, this trades extra computation time for reduced memory usage.
- cpu_ram_efficient_loading (bool, defaults to None) — If True, only the first process loads the pretrained model checkpoint while all other processes have empty weights. Only applicable for Transformers. When using this, sync_module_states needs to be True.
- transformer_cls_names_to_wrap (Optional[List[str]], defaults to None) — A list of transformer layer class names to wrap. Only applicable when auto_wrap_policy is transformer_based_wrap.
- min_num_params (Optional[int], defaults to None) — The minimum number of parameters a module must have to be wrapped. Only applicable when auto_wrap_policy is size_based_wrap.
This plugin is used to enable fully sharded data parallelism.

Given a model, it creates an auto_wrap_policy based on the passed-in policy and, where applicable, the transformer_cls_names_to_wrap. It also sets the mixed precision policy for FSDP and sets the state dict config based on the StateDictType.
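As a rough sketch, the plugin can be constructed directly and handed to Accelerator. The values below are illustrative (not the only valid choices), the placeholder module stands in for a real Transformers model, and the script is assumed to run under accelerate launch so that FSDP is actually active:

```python
import torch
from accelerate import Accelerator, FullyShardedDataParallelPlugin

# Illustrative configuration; pick the sharding strategy and wrap policy that fit your model.
fsdp_plugin = FullyShardedDataParallelPlugin(
    sharding_strategy="FULL_SHARD",
    auto_wrap_policy="size_based_wrap",
    min_num_params=1_000_000,             # only wrap modules with at least 1M parameters
    state_dict_type="SHARDED_STATE_DICT",
)

accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

model = torch.nn.Linear(1024, 1024)       # placeholder module; typically a Transformers model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer = accelerator.prepare(model, optimizer)
```

Using SHARDED_STATE_DICT here pairs with merge_fsdp_weights above: each rank saves its own shard during training, and the shards are merged into a single checkpoint afterwards.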