| |
| from dataclasses import dataclass |
| from typing import Optional |
|
|
| import torch |
|
|
|
|
@dataclass(frozen=True)
class MixedPrecisionPolicy:
    """
    This configures FSDP's mixed precision. Unlike autocast, this applies mixed
    precision at the module level, not op level, which means low-precision
    activations are saved for backward and high-to-low-precision casts are
    incurred only at module boundaries.

    FSDP works well with module-level mixed precision since it keeps the
    high-precision sharded parameters in memory anyway. In other words, FSDP
    does not require any extra memory to keep a high-precision copy of the
    parameters for the optimizer step.

    Attributes:
        param_dtype (Optional[torch.dtype]): This specifies the dtype for
            the unsharded parameter and hence the dtype for forward/backward
            computation and the parameter all-gather. If this is ``None``, then
            the unsharded parameter uses the original dtype. The optimizer step
            uses the sharded parameter in the original dtype. (Default:
            ``None``)
        reduce_dtype (Optional[torch.dtype]): This specifies the dtype for
            gradient reduction (i.e. reduce-scatter or all-reduce). If this is
            ``None`` but ``param_dtype`` is not ``None``, then the reduction
            uses the compute dtype. This can be used to run gradient reduction
            in full precision while using low precision for compute. If also
            gradient reduction is disabled via :meth:`set_requires_gradient_sync`,
            then FSDP will accumulate gradients using ``reduce_dtype``.
            (Default: ``None``)
        output_dtype (Optional[torch.dtype]): This specifies the dtype for
            casting floating-point forward outputs. This can be used to
            help implement cases where different modules have different mixed
            precision policies. (Default: ``None``)
        cast_forward_inputs (bool): This specifies whether FSDP should cast the
            forward's floating-point input tensors to ``param_dtype`` or not.
            (Default: ``True``)
    """

    # frozen=True keeps the policy immutable and hashable so one instance can
    # be shared safely across many modules.
    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None
    output_dtype: Optional[torch.dtype] = None
    cast_forward_inputs: bool = True
|
|
|
|
@dataclass
class OffloadPolicy:
    """
    Base policy meaning "no offloading".

    It carries no configuration of its own; it exists so that the
    ``offload_policy`` argument has a sensible default value, and so that
    concrete offload policies (e.g. CPU offloading) can subclass it.
    """
|
|
|
|
@dataclass
class CPUOffloadPolicy(OffloadPolicy):
    """
    Policy that offloads parameters, gradients, and optimizer states to CPU.

    Sharded parameters are copied host-to-device before all-gather, and the
    all-gathered parameters are freed according to ``reshard_after_forward``.
    Sharded gradients are copied device-to-host in backward, and the optimizer
    step runs on CPU with CPU optimizer states.

    Attributes:
        pin_memory (bool): Whether to pin sharded parameter and gradient
            memory. Pinning memory allows both more efficient H2D/D2H copies
            and for the copies to overlap with compute. However, the pinned
            memory cannot be used by other processes. Set this to ``False`` if
            you have insufficient CPU memory. (Default: ``True``)
    """

    pin_memory: bool = True
|
|