FP8

Below are the functions and classes related to the underlying FP8 implementation.

FP8RecipeKwargs

class accelerate.utils.FP8RecipeKwargs

( backend: typing.Literal['MSAMP', 'TE'] = None, use_autocast_during_eval: bool = None, opt_level: typing.Literal['O1', 'O2'] = None, margin: int = None, interval: int = None, fp8_format: typing.Literal['E4M3', 'HYBRID'] = None, amax_history_len: int = None, amax_compute_algo: typing.Literal['max', 'most_recent'] = None, override_linear_precision: typing.Tuple[bool, bool, bool] = None )

Parameters

  • backend (str, optional) — Which FP8 engine to use. Must be one of "msamp" (MS-AMP) or "te" (TransformerEngine). If not passed, whichever engine is available in the environment is used, with MS-AMP taking priority.
  • use_autocast_during_eval (bool, optional, defaults to False) — Whether to use FP8 autocast during eval mode. Metrics are generally better when this is False.
  • margin (int, optional, defaults to 0) — The margin to use when computing the scaling factor.
  • interval (int, optional, defaults to 1) — How often the scaling factor is recomputed.
  • fp8_format (str, optional, defaults to "HYBRID") — The format to use for the FP8 recipe. Must be one of HYBRID or E4M3 (generally HYBRID for training, E4M3 for evaluation).
  • amax_history_len (int, optional, defaults to 1024) — The length of the history to use for the scaling factor computation.
  • amax_compute_algo (str, optional, defaults to "most_recent") — The algorithm to use for the scaling factor computation. Must be one of max or most_recent.
  • override_linear_precision (tuple of three bool, optional, defaults to (False, False, False)) — Whether or not to execute the fprop, dgrad, and wgrad GEMMs in higher precision.
  • opt_level (str, optional, defaults to "O2") — The MS-AMP optimization level, which controls how much of training is carried out in 8-bit. Must be one of O1 or O2. In general:
    • O1: Weight gradients and all_reduce communications are done in FP8, reducing GPU memory usage and communication bandwidth.
    • O2: First-order optimizer states are in 8-bit and second-order states are in FP16. Only available when using Adam or AdamW. This maintains accuracy and can potentially save the most memory.
    • O3: Specifically for DeepSpeed; stores the model's weights and master weights in FP8. If fp8 is selected and DeepSpeed is enabled, this level is used by default. (Not currently available.)
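To show how margin, interval, amax_history_len, and amax_compute_algo interact, here is a plain-Python sketch of delayed scaling. It is not TransformerEngine's implementation; it assumes the usual rule scale = (fp8_max / amax) / 2**margin, where amax is taken from a rolling history either as its maximum or its most recent entry. The names compute_scale and DelayedScaler are hypothetical. For reference, E4M3 tops out at 448 and E5M2 at 57344; HYBRID uses E4M3 in the forward pass and E5M2 in the backward pass.

```python
from collections import deque

# Largest finite magnitudes of the two FP8 encodings.
FP8_MAX = {"E4M3": 448.0, "E5M2": 57344.0}


def compute_scale(amax_history, fp8_max, margin=0, amax_compute_algo="most_recent"):
    """Return the factor that maps observed values into the FP8 range (sketch)."""
    if amax_compute_algo == "max":
        amax = max(amax_history)
    elif amax_compute_algo == "most_recent":
        amax = amax_history[-1]
    else:
        raise ValueError(f"unknown amax_compute_algo: {amax_compute_algo!r}")
    if amax <= 0.0:
        return 1.0  # nothing meaningful observed yet; keep a neutral scale
    # Map amax onto fp8_max, then back off by 2**margin for headroom.
    return (fp8_max / amax) / (2 ** margin)


class DelayedScaler:
    """Hypothetical tracker: records amax every step, recomputes the scale every `interval` steps."""

    def __init__(self, fp8_max, margin=0, interval=1, amax_history_len=1024,
                 amax_compute_algo="most_recent"):
        self.fp8_max = fp8_max
        self.margin = margin
        self.interval = interval
        self.amax_compute_algo = amax_compute_algo
        self.history = deque(maxlen=amax_history_len)  # oldest entries drop off
        self.scale = 1.0
        self.step = 0

    def update(self, tensor_values):
        self.history.append(max(abs(v) for v in tensor_values))
        self.step += 1
        if self.step % self.interval == 0:
            self.scale = compute_scale(list(self.history), self.fp8_max,
                                       self.margin, self.amax_compute_algo)
        return self.scale


scaler = DelayedScaler(FP8_MAX["E4M3"], margin=1, amax_compute_algo="max")
print(scaler.update([0.5, -2.0, 1.0]))  # 112.0, i.e. (448 / 2.0) / 2
print(scaler.update([0.25, 1.0]))       # 112.0, the history max is still 2.0
```

With amax_compute_algo="most_recent" the second call would react immediately to the smaller amax and return 224.0; "max" trades that responsiveness for stability against spikes.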

Pass this object to your Accelerator to customize how the FP8 recipe is initialized for FP8 mixed-precision training with TransformerEngine or MS-AMP.

For more information on the TransformerEngine arguments, refer to the TransformerEngine API documentation.

For more information on the MS-AMP arguments, refer to the MS-AMP Optimization Level documentation.

from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs

kwargs = FP8RecipeKwargs(backend="te", fp8_format="HYBRID")
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[kwargs])

To use MS-AMP as the engine, pass backend="msamp" and an opt_level:

kwargs = FP8RecipeKwargs(backend="msamp", opt_level="O2")
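The optimization levels start with a capital letter O, so "02" (digit zero) is an easy typo. The helper below is a hypothetical sketch that only restates the choices documented above; it is not Accelerate's own validation code, and it assumes backend names are matched case-insensitively, as the lowercase "te" and "msamp" examples suggest.

```python
# Hypothetical helper: restates the documented choices for FP8RecipeKwargs.
# It is NOT Accelerate's own validation code.
BACKENDS = ("MSAMP", "TE")
OPT_LEVELS = ("O1", "O2")  # capital letter O followed by a digit
FP8_FORMATS = ("E4M3", "HYBRID")
AMAX_COMPUTE_ALGOS = ("max", "most_recent")


def check_fp8_recipe_choices(backend=None, opt_level="O2", fp8_format="HYBRID",
                             amax_compute_algo="most_recent"):
    """Return the arguments with backend upper-cased, or raise ValueError."""
    # Assumption: backend names are compared case-insensitively ("te" -> "TE").
    if backend is not None:
        backend = backend.upper()
        if backend not in BACKENDS:
            raise ValueError(f"backend must be one of {BACKENDS}, got {backend!r}")
    if opt_level not in OPT_LEVELS:
        raise ValueError(f"opt_level must be one of {OPT_LEVELS}, got {opt_level!r}")
    if fp8_format not in FP8_FORMATS:
        raise ValueError(f"fp8_format must be one of {FP8_FORMATS}, got {fp8_format!r}")
    if amax_compute_algo not in AMAX_COMPUTE_ALGOS:
        raise ValueError(
            f"amax_compute_algo must be one of {AMAX_COMPUTE_ALGOS}, got {amax_compute_algo!r}"
        )
    return {"backend": backend, "opt_level": opt_level,
            "fp8_format": fp8_format, "amax_compute_algo": amax_compute_algo}


print(check_fp8_recipe_choices(backend="msamp", opt_level="O2"))
try:
    check_fp8_recipe_choices(backend="msamp", opt_level="02")  # digit zero, rejected
except ValueError as err:
    print(err)
```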

convert_model

accelerate.utils.convert_model

( model, to_transformer_engine = True, _convert_linear = True, _convert_ln = True )

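Recursively converts the linear and layernorm layers of a model to their transformer_engine counterparts. Pass to_transformer_engine=False to convert transformer_engine layers back to their PyTorch counterparts.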
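The recursion can be illustrated without PyTorch or TransformerEngine installed. In this sketch, Module, Linear, LayerNorm, TELinear, TELayerNorm, and convert_model_sketch are hypothetical stand-ins, not Accelerate or TransformerEngine classes, and only layer types are swapped; a real conversion must also carry the weights over.

```python
class Module:
    """Minimal stand-in for torch.nn.Module: a container of named child modules."""

    def __init__(self, **children):
        self._children = dict(children)

    def named_children(self):
        return list(self._children.items())

    def set_child(self, name, module):
        self._children[name] = module


class _LinearShape(Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features


class _NormShape(Module):
    def __init__(self, normalized_shape):
        super().__init__()
        self.normalized_shape = normalized_shape


class Linear(_LinearShape):
    """Stand-in for a PyTorch linear layer."""


class TELinear(_LinearShape):
    """Stand-in for a TransformerEngine linear layer."""


class LayerNorm(_NormShape):
    """Stand-in for a PyTorch layer norm."""


class TELayerNorm(_NormShape):
    """Stand-in for a TransformerEngine layer norm."""


def convert_model_sketch(model, to_te=True, convert_linear=True, convert_ln=True):
    """Swap matching children in place, recursing into everything else."""
    for name, child in model.named_children():
        if to_te and convert_linear and isinstance(child, Linear):
            model.set_child(name, TELinear(child.in_features, child.out_features))
        elif to_te and convert_ln and isinstance(child, LayerNorm):
            model.set_child(name, TELayerNorm(child.normalized_shape))
        elif not to_te and convert_linear and isinstance(child, TELinear):
            model.set_child(name, Linear(child.in_features, child.out_features))
        elif not to_te and convert_ln and isinstance(child, TELayerNorm):
            model.set_child(name, LayerNorm(child.normalized_shape))
        else:
            convert_model_sketch(child, to_te, convert_linear, convert_ln)
    return model


model = Module(block=Module(fc=Linear(4, 8), norm=LayerNorm(8)), head=Linear(8, 2))
convert_model_sketch(model)
print({name: type(child).__name__ for name, child in model.named_children()})
# {'block': 'Module', 'head': 'TELinear'}
```

Containers that match neither case are descended into, which is what makes the conversion reach layers nested at any depth.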

has_transformer_engine_layers

accelerate.utils.has_transformer_engine_layers

( model )

Returns whether a given model contains any transformer_engine layers.
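Conceptually this is a walk over every submodule. The sketch below assumes a model "has" transformer_engine layers if any submodule is an instance of a TE layer class; Module, Linear, TELinear, TELayerNorm, and has_te_layers_sketch are hypothetical stand-ins, and modules() mimics the depth-first traversal of torch.nn.Module.modules.

```python
class Module:
    """Minimal stand-in for torch.nn.Module holding named children."""

    def __init__(self, **children):
        self._children = dict(children)

    def modules(self):
        """Yield this module and every descendant, depth-first."""
        yield self
        for child in self._children.values():
            yield from child.modules()


class Linear(Module):
    """Stand-in for a PyTorch linear layer."""


class TELinear(Module):
    """Stand-in for a TransformerEngine linear layer."""


class TELayerNorm(Module):
    """Stand-in for a TransformerEngine layer norm."""


TE_LAYER_TYPES = (TELinear, TELayerNorm)


def has_te_layers_sketch(model):
    # any() stops at the first TE layer found, so the walk can end early.
    return any(isinstance(m, TE_LAYER_TYPES) for m in model.modules())


plain = Module(a=Linear(), b=Module(c=Linear()))
mixed = Module(a=Linear(), b=Module(c=TELinear()))
print(has_te_layers_sketch(plain), has_te_layers_sketch(mixed))
```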

contextual_fp8_autocast

accelerate.utils.contextual_fp8_autocast

( model_forward, fp8_recipe, use_during_eval = False )

Wraps a model's forward method to apply FP8 autocast. The wrapper is context-aware: by default (use_during_eval=False) it disables FP8 autocast in eval mode, which generally yields more accurate metrics.
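The context-aware behavior amounts to checking the model's training flag on every call. This is a minimal sketch under stated assumptions: fake_fp8_autocast is a hypothetical stand-in for a real FP8 autocast context manager (it only records whether it was enabled), contextual_autocast_sketch and TinyModel are illustrative names, and the training attribute mirrors torch.nn.Module.training.

```python
import contextlib
import types

AUTOCAST_LOG = []  # records whether autocast was enabled on each forward call


@contextlib.contextmanager
def fake_fp8_autocast(enabled, fp8_recipe):
    """Hypothetical stand-in for a real FP8 autocast context manager."""
    AUTOCAST_LOG.append(enabled)
    yield


def contextual_autocast_sketch(model_forward, fp8_recipe, use_during_eval=False):
    """Wrap model_forward so FP8 autocast is only enabled when appropriate."""

    def forward(self, *args, **kwargs):
        # Enabled in training mode; in eval mode only if use_during_eval is True.
        enabled = use_during_eval or self.training
        with fake_fp8_autocast(enabled=enabled, fp8_recipe=fp8_recipe):
            return model_forward(*args, **kwargs)

    forward.__wrapped__ = model_forward  # keep a handle on the original forward
    return forward


class TinyModel:
    def __init__(self):
        self.training = True  # mirrors torch.nn.Module.training

    def forward(self, x):
        return 2 * x


model = TinyModel()
wrapped = contextual_autocast_sketch(model.forward, fp8_recipe=None)
model.forward = types.MethodType(wrapped, model)  # bind so `self` is the model
model.forward(3)     # training mode: autocast enabled
model.training = False
model.forward(3)     # eval mode: autocast disabled
print(AUTOCAST_LOG)  # [True, False]
```

Reading self.training at call time, rather than when the wrapper is built, is what lets one wrapped forward serve both training and evaluation.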

apply_fp8_autowrap

accelerate.utils.apply_fp8_autowrap

( model, fp8_recipe_handler )

Applies the FP8 autocast context manager to the model's forward method.
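One way to picture the steps involved: read the handler's fields, split off the eval-mode flag (it configures the wrapper rather than the recipe), and patch the model's forward with an autocast wrapper. This is a plain-Python sketch, not Accelerate's code. RecipeHandlerSketch is a hypothetical stand-in carrying only a few FP8RecipeKwargs fields, the recipe is a plain dict where a real backend would build a recipe object, and fake_fp8_autocast only records whether it was enabled.

```python
import contextlib
import types
from dataclasses import asdict, dataclass

AUTOCAST_LOG = []


@dataclass
class RecipeHandlerSketch:
    """Hypothetical stand-in for FP8RecipeKwargs, with only a few of its fields."""
    fp8_format: str = "HYBRID"
    margin: int = 0
    interval: int = 1
    use_autocast_during_eval: bool = False


@contextlib.contextmanager
def fake_fp8_autocast(enabled, fp8_recipe):
    """Hypothetical stand-in for a real FP8 autocast context manager."""
    AUTOCAST_LOG.append(enabled)
    yield


def apply_autowrap_sketch(model, handler):
    kwargs = asdict(handler)
    # The eval flag configures the wrapper rather than the recipe, so split it off.
    use_during_eval = kwargs.pop("use_autocast_during_eval")
    recipe = kwargs  # placeholder: a real backend would build a recipe object here
    original_forward = model.forward

    def forward(self, *args, **kw):
        enabled = use_during_eval or self.training
        with fake_fp8_autocast(enabled=enabled, fp8_recipe=recipe):
            return original_forward(*args, **kw)

    # Patch this instance only; other instances keep the class's forward.
    model.forward = types.MethodType(forward, model)
    return model


class TinyModel:
    def __init__(self):
        self.training = True

    def forward(self, x):
        return x + 1


model = apply_autowrap_sketch(TinyModel(), RecipeHandlerSketch(margin=2))
print(model.forward(1))  # 2, computed with autocast enabled because the model is training
```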
