Utilities for Fully Sharded Data Parallelism
class accelerate.FullyShardedDataParallelPlugin
( sharding_strategy: typing.Any = None,
backward_prefetch: typing.Any = None,
mixed_precision_policy: typing.Any = None,
auto_wrap_policy: typing.Optional[typing.Callable] = None,
cpu_offload: typing.Any = None,
ignored_modules: typing.Optional[typing.Iterable[torch.nn.modules.module.Module]] = None,
state_dict_type: typing.Any = None,
state_dict_config: typing.Any = None,
optim_state_dict_config: typing.Any = None,
limit_all_gathers: bool = False,
use_orig_params: bool = True,
param_init_fn: typing.Optional[typing.Callable[[torch.nn.modules.module.Module], NoneType]] = None,
sync_module_states: bool = True,
forward_prefetch: bool = False,
activation_checkpointing: bool = False )
This plugin is used to enable fully sharded data parallelism (FSDP), PyTorch's mechanism for sharding model parameters, gradients, and optimizer states across data-parallel workers.
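A minimal sketch of how the plugin is commonly constructed and handed to the Accelerator; the state-dict settings shown are just one possible configuration, and model and optimizer are placeholders for your own objects:

from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
)

from accelerate import Accelerator, FullyShardedDataParallelPlugin

# Gather full (unsharded) state dicts on rank 0, offloaded to CPU, when saving checkpoints.
fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
)

accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

# model and optimizer stand in for your own objects; prepare() applies the FSDP wrapping.
# model, optimizer = accelerator.prepare(model, optimizer)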
get_module_class_from_name
( module, name )
Gets a class from a module by recursively searching the module's children for a submodule whose class name matches name.
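A small, self-contained sketch of how this helper can be used, for example to resolve a block class for an FSDP auto-wrap policy; Block and TinyModel are toy stand-ins for a real model's layer classes, and the import path from accelerate.utils is assumed:

import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from accelerate.utils import get_module_class_from_name


class Block(nn.Module):
    # Stand-in for a transformer block in a real model.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)


class TinyModel(nn.Module):
    # Stand-in model containing the blocks we want FSDP to wrap.
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([Block() for _ in range(2)])


model = TinyModel()

# Recursively search the model's submodules and return the class whose name matches.
block_class = get_module_class_from_name(model, "Block")
print(block_class)  # <class '__main__.Block'>

# One common use: build an auto-wrap policy that puts each block in its own FSDP unit.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={block_class}
)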