diff --git "a/llm/Lib/site-packages/accelerate/accelerator.py" "b/llm/Lib/site-packages/accelerate/accelerator.py"
new file mode 100644
--- /dev/null
+++ "b/llm/Lib/site-packages/accelerate/accelerator.py"
@@ -0,0 +1,3259 @@
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import contextlib
+import functools
+import json
+import math
+import os
+import re
+import shutil
+import sys
+import warnings
+from collections import OrderedDict
+from contextlib import contextmanager
+from functools import partial
+from types import MethodType
+from typing import Any, Callable, Union
+
+import torch
+import torch.utils.hooks as hooks
+
+from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
+from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches
+from .hooks import AlignDevicesHook
+from .logging import get_logger
+from .optimizer import AcceleratedOptimizer
+from .scheduler import AcceleratedScheduler
+from .state import AcceleratorState, GradientState, PartialState
+from .tracking import LOGGER_TYPE_TO_CLASS, GeneralTracker, filter_trackers
+from .utils import (
+ MODEL_NAME,
+ SAFE_WEIGHTS_INDEX_NAME,
+ SAFE_WEIGHTS_NAME,
+ WEIGHTS_INDEX_NAME,
+ WEIGHTS_NAME,
+ AutocastKwargs,
+ DataLoaderConfiguration,
+ DeepSpeedPlugin,
+ DistributedDataParallelKwargs,
+ DistributedType,
+ DynamoBackend,
+ FP8RecipeKwargs,
+ FullyShardedDataParallelPlugin,
+ GradientAccumulationPlugin,
+ GradScalerKwargs,
+ InitProcessGroupKwargs,
+ KwargsHandler,
+ LoggerType,
+ MegatronLMPlugin,
+ PrecisionType,
+ ProjectConfiguration,
+ RNGType,
+ TorchDynamoPlugin,
+ check_os_kernel,
+ clean_state_dict_for_safetensors,
+ compare_versions,
+ convert_model,
+ convert_outputs_to_fp32,
+ extract_model_from_parallel,
+ gather,
+ gather_object,
+ get_mixed_precision_context_manager,
+ get_pretty_name,
+ has_transformer_engine_layers,
+ is_bf16_available,
+ is_deepspeed_available,
+ is_fp8_available,
+ is_ipex_available,
+ is_megatron_lm_available,
+ is_mlu_available,
+ is_msamp_available,
+ is_npu_available,
+ is_torch_version,
+ is_torch_xla_available,
+ is_xpu_available,
+ load_fsdp_model,
+ load_fsdp_optimizer,
+ pad_across_processes,
+ parse_choice_from_env,
+ recursively_apply,
+ reduce,
+ release_memory,
+ save,
+ save_fsdp_model,
+ save_fsdp_optimizer,
+ shard_checkpoint,
+ wait_for_everyone,
+)
+from .utils.constants import FSDP_PYTORCH_VERSION
+from .utils.modeling import get_state_dict_offloaded_model
+from .utils.other import is_compiled_module
+
+
+if is_deepspeed_available():
+ from .utils import (
+ DeepSpeedEngineWrapper,
+ DeepSpeedOptimizerWrapper,
+ DeepSpeedSchedulerWrapper,
+ DummyOptim,
+ DummyScheduler,
+ )
+
+if is_fp8_available():
+ import transformer_engine.common.recipe as te_recipe
+ from transformer_engine.pytorch import fp8_autocast
+
+
+if is_megatron_lm_available():
+ from .utils import (
+ MegatronEngine,
+ MegatronLMDummyDataLoader,
+ MegatronLMDummyScheduler,
+ MegatronLMOptimizerWrapper,
+ MegatronLMSchedulerWrapper,
+ megatron_lm_initialize,
+ megatron_lm_prepare_data_loader,
+ megatron_lm_prepare_model,
+ megatron_lm_prepare_optimizer,
+ megatron_lm_prepare_scheduler,
+ )
+
+from torch.distributed.algorithms.join import Join
+
+
+if is_torch_xla_available():
+ import torch_xla.amp as xamp
+ import torch_xla.core.xla_model as xm
+ import torch_xla.distributed.xla_multiprocessing as xmp
+
+
+if is_npu_available(check_device=False):
+ import torch_npu # noqa: F401
+
+
+try:
+ from torch.optim.lr_scheduler import LRScheduler
+except ImportError:
+ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+
+logger = get_logger(__name__)
+
+# Sentinel default values, used to detect whether the caller explicitly passed these (deprecated) arguments
+_split_batches = object()
+_dispatch_batches = object()
+_even_batches = object()
+_use_seedable_sampler = object()
+
+
+class Accelerator:
+ """
+ Creates an instance of an accelerator for distributed training (on multi-GPU, TPU) or mixed precision training.
+
+ Args:
+ device_placement (`bool`, *optional*, defaults to `True`):
+ Whether or not the accelerator should put objects on device (tensors yielded by the dataloader, model,
+ etc...).
+ mixed_precision (`str`, *optional*):
+ Whether or not to use mixed precision training. Choose from 'no', 'fp16', 'bf16' or 'fp8'. Will default to the
+ value in the environment variable `ACCELERATE_MIXED_PRECISION`, which will use the default value in the
+ accelerate config of the current system or the flag passed with the `accelerate.launch` command. 'fp8'
+ requires the installation of transformer-engine.
+ gradient_accumulation_steps (`int`, *optional*, defaults to 1):
+ The number of steps to accumulate gradients over before stepping the optimizer. A number > 1 should be combined with
+ `Accelerator.accumulate`. If not passed, will default to the value in the environment variable
+ `ACCELERATE_GRADIENT_ACCUMULATION_STEPS`. Can also be configured through a `GradientAccumulationPlugin`.
+ cpu (`bool`, *optional*):
+ Whether or not to force the script to execute on CPU. Will ignore any available GPU if set to `True` and force
+ the execution on one process only.
+ dataloader_config (`DataLoaderConfiguration`, *optional*):
+ A configuration for how the dataloaders should be handled in distributed scenarios.
+ deepspeed_plugin ([`~utils.DeepSpeedPlugin`], *optional*):
+ Tweak your DeepSpeed related args using this argument. This argument is optional and can be configured
+ directly using *accelerate config*
+ fsdp_plugin ([`~utils.FullyShardedDataParallelPlugin`], *optional*):
+ Tweak your FSDP related args using this argument. This argument is optional and can be configured directly
+ using *accelerate config*
+ megatron_lm_plugin ([`~utils.MegatronLMPlugin`], *optional*):
+ Tweak your MegatronLM related args using this argument. This argument is optional and can be configured
+ directly using *accelerate config*
+ rng_types (list of `str` or [`~utils.RNGType`]):
+ The list of random number generators to synchronize at the beginning of each iteration in your prepared
+ dataloaders. Should be one or several of:
+
+ - `"torch"`: the base torch random number generator
+ - `"cuda"`: the CUDA random number generator (GPU only)
+ - `"xla"`: the XLA random number generator (TPU only)
+ - `"generator"`: the `torch.Generator` of the sampler (or batch sampler if there is no sampler in your
+ dataloader) or of the iterable dataset (if it exists) if the underlying dataset is of that type.
+
+ Will default to `["torch"]` for PyTorch versions <=1.5.1 and `["generator"]` for PyTorch versions >= 1.6.
+ log_with (list of `str`, [`~utils.LoggerType`] or [`~tracking.GeneralTracker`], *optional*):
+ A list of loggers to be setup for experiment tracking. Should be one or several of:
+
+ - `"all"`
+ - `"tensorboard"`
+ - `"wandb"`
+ - `"comet_ml"`
+ If `"all"` is selected, will pick up all available trackers in the environment and initialize them. Can
+ also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`.
+ project_config ([`~utils.ProjectConfiguration`], *optional*):
+ A configuration for how saving the state can be handled.
+ project_dir (`str`, `os.PathLike`, *optional*):
+ A path to a directory for storing data such as logs of locally-compatible loggers and potentially saved
+ checkpoints.
+ step_scheduler_with_optimizer (`bool`, *optional*, defaults to `True`):
+ Set `True` if the learning rate scheduler is stepped at the same time as the optimizer, `False` if only
+ done under certain circumstances (at the end of each epoch, for instance).
+ kwargs_handlers (list of [`~utils.KwargsHandler`], *optional*):
+ A list of [`~utils.KwargsHandler`] to customize how the objects related to distributed training or mixed
+ precision are created. See [kwargs](kwargs) for more information.
+ dynamo_backend (`str` or [`~utils.DynamoBackend`], *optional*, defaults to `"no"`):
+ Set to one of the possible dynamo backends to optimize your training with torch dynamo.
+ gradient_accumulation_plugin ([`~utils.GradientAccumulationPlugin`], *optional*):
+ A configuration for how gradient accumulation should be handled, if more tweaking than just the
+ `gradient_accumulation_steps` is needed.
+
+ **Available attributes:**
+
+ - **device** (`torch.device`) -- The device to use.
+ - **distributed_type** ([`~utils.DistributedType`]) -- The distributed training configuration.
+ - **local_process_index** (`int`) -- The process index on the current machine.
+ - **mixed_precision** (`str`) -- The configured mixed precision mode.
+ - **num_processes** (`int`) -- The total number of processes used for training.
+ - **optimizer_step_was_skipped** (`bool`) -- Whether or not the optimizer update was skipped (because of
+ gradient overflow in mixed precision), in which case the learning rate should not be changed.
+ - **process_index** (`int`) -- The overall index of the current process among all processes.
+ - **state** ([`~state.AcceleratorState`]) -- The distributed setup state.
+ - **sync_gradients** (`bool`) -- Whether the gradients are currently being synced across all processes.
+ - **use_distributed** (`bool`) -- Whether the current configuration is for distributed training.
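+
+ Example (a minimal sketch of a typical training loop; `model`, `optimizer`, `dataloader`, and `loss_func` are
+ assumed to be defined by the user, and the argument values shown are illustrative only):
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator(mixed_precision="fp16", gradient_accumulation_steps=2)
+ >>> model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
+ >>> for batch in dataloader:
+ ... with accelerator.accumulate(model):
+ ... outputs = model(**batch)
+ ... loss = loss_func(outputs)
+ ... accelerator.backward(loss)
+ ... optimizer.step()
+ ... optimizer.zero_grad()
+ ```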
+ """
+
+ def __init__(
+ self,
+ device_placement: bool = True,
+ split_batches: bool = _split_batches,
+ mixed_precision: PrecisionType | str | None = None,
+ gradient_accumulation_steps: int = 1,
+ cpu: bool = False,
+ dataloader_config: DataLoaderConfiguration | None = None,
+ deepspeed_plugin: DeepSpeedPlugin | None = None,
+ fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
+ megatron_lm_plugin: MegatronLMPlugin | None = None,
+ rng_types: list[str | RNGType] | None = None,
+ log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None,
+ project_dir: str | os.PathLike | None = None,
+ project_config: ProjectConfiguration | None = None,
+ gradient_accumulation_plugin: GradientAccumulationPlugin | None = None,
+ dispatch_batches: bool | None = _dispatch_batches,
+ even_batches: bool = _even_batches,
+ use_seedable_sampler: bool = _use_seedable_sampler,
+ step_scheduler_with_optimizer: bool = True,
+ kwargs_handlers: list[KwargsHandler] | None = None,
+ dynamo_backend: DynamoBackend | str | None = None,
+ ):
+ self.trackers = []
+ if project_config is not None:
+ self.project_configuration = project_config
+ else:
+ self.project_configuration = ProjectConfiguration(project_dir=project_dir)
+ if project_dir is not None and self.project_dir is None:
+ self.project_configuration.set_directories(project_dir)
+ if mixed_precision is not None:
+ mixed_precision = str(mixed_precision)
+ if mixed_precision not in PrecisionType:
+ raise ValueError(
+ f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}"
+ )
+
+ dynamo_plugin = TorchDynamoPlugin() if dynamo_backend is None else TorchDynamoPlugin(backend=dynamo_backend)
+
+ if deepspeed_plugin is None: # init from env variables
+ deepspeed_plugin = (
+ DeepSpeedPlugin() if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true" else None
+ )
+ else:
+ assert isinstance(
+ deepspeed_plugin, DeepSpeedPlugin
+ ), "`deepspeed_plugin` must be an `accelerate.utils.DeepSpeedPlugin` object."
+ os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" # use DeepSpeed if plugin is provided
+ if deepspeed_plugin:
+ if not is_deepspeed_available():
+ raise ImportError("DeepSpeed is not installed => run `pip install deepspeed` or build it from source.")
+ if is_mlu_available():
+ if compare_versions("deepspeed-mlu", "<", "0.10.1"):
+ raise ImportError("DeepSpeed MLU version must be >= 0.10.1. Please update DeepSpeed MLU.")
+ elif compare_versions("deepspeed", "<", "0.9.3"):
+ raise ImportError("DeepSpeed version must be >= 0.9.3. Please update DeepSpeed.")
+
+ mixed_precision = (
+ os.environ.get("ACCELERATE_MIXED_PRECISION", "no") if mixed_precision is None else mixed_precision
+ )
+ deepspeed_plugin.set_mixed_precision(mixed_precision)
+ deepspeed_plugin.set_deepspeed_weakref()
+
+ if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" or isinstance(
+ fsdp_plugin, FullyShardedDataParallelPlugin
+ ):
+ if is_torch_version("<", FSDP_PYTORCH_VERSION):
+ raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")
+
+ if fsdp_plugin is None: # init from env variables
+ fsdp_plugin = (
+ FullyShardedDataParallelPlugin() if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" else None
+ )
+ else:
+ if not isinstance(fsdp_plugin, FullyShardedDataParallelPlugin):
+ raise TypeError("`fsdp_plugin` must be a FullyShardedDataParallelPlugin object.")
+ os.environ["ACCELERATE_USE_FSDP"] = "true" # use FSDP if plugin is provided
+
+ if megatron_lm_plugin is None: # init from env variables
+ megatron_lm_plugin = (
+ MegatronLMPlugin() if os.environ.get("ACCELERATE_USE_MEGATRON_LM", "false") == "true" else None
+ )
+ else:
+ if not isinstance(megatron_lm_plugin, MegatronLMPlugin):
+ raise TypeError("`megatron_lm_plugin` must be a MegatronLMPlugin object.")
+ os.environ["ACCELERATE_USE_MEGATRON_LM"] = "true" # use MegatronLM if plugin is provided
+
+ if megatron_lm_plugin:
+ if not is_megatron_lm_available():
+ raise ImportError("Megatron is not installed. please build it from source.")
+
+ # Kwargs handlers
+ self.ddp_handler = None
+ self.scaler_handler = None
+ self.init_handler = None
+ self.fp8_recipe_handler = None
+ self.autocast_handler = None
+ if kwargs_handlers is not None:
+ for handler in kwargs_handlers:
+ assert isinstance(
+ handler, KwargsHandler
+ ), f"Unsupported kwargs handler passed: {handler}, must be one that inherits `accelerate.utils.KwargsHandler`."
+ if isinstance(handler, DistributedDataParallelKwargs):
+ if self.ddp_handler is not None:
+ raise ValueError("You can only pass one `DistributedDataParallelKwargs` in `kwargs_handler`.")
+ else:
+ self.ddp_handler = handler
+ elif isinstance(handler, GradScalerKwargs):
+ if self.scaler_handler is not None:
+ raise ValueError("You can only pass one `GradScalerKwargs` in `kwargs_handler`.")
+ else:
+ self.scaler_handler = handler
+ elif isinstance(handler, InitProcessGroupKwargs):
+ if self.init_handler is not None:
+ raise ValueError("You can only pass one `InitProcessGroupKwargs` in `kwargs_handler`.")
+ else:
+ self.init_handler = handler
+ elif isinstance(handler, FP8RecipeKwargs):
+ if self.fp8_recipe_handler is not None:
+ raise ValueError("You can only pass one `FP8RecipeKwargs` in `kwargs_handler`.")
+ else:
+ self.fp8_recipe_handler = handler
+ elif isinstance(handler, AutocastKwargs):
+ if self.autocast_handler is not None:
+ raise ValueError("You can only pass one `AutocastKwargs` in `kwargs_handler`.")
+ else:
+ self.autocast_handler = handler
+
+ kwargs = self.init_handler.to_kwargs() if self.init_handler is not None else {}
+ self.state = AcceleratorState(
+ mixed_precision=mixed_precision,
+ cpu=cpu,
+ dynamo_plugin=dynamo_plugin,
+ deepspeed_plugin=deepspeed_plugin,
+ fsdp_plugin=fsdp_plugin,
+ megatron_lm_plugin=megatron_lm_plugin,
+ _from_accelerator=True,
+ **kwargs,
+ )
+
+ if self.fp8_recipe_handler is None and self.state.mixed_precision == "fp8":
+ self.fp8_recipe_handler = FP8RecipeKwargs(backend="MSAMP" if is_msamp_available() else "TE")
+
+ trackers = filter_trackers(log_with, self.logging_dir)
+ if len(trackers) < 1 and log_with is not None:
+ warnings.warn(f"`log_with={log_with}` was passed but no supported trackers are currently installed.")
+ self.log_with = trackers
+
+ if (
+ (mixed_precision != "bf16")
+ and getattr(self.state, "downcast_bfloat", False)
+ and (self.state.distributed_type != DistributedType.XLA)
+ ):
+ raise ValueError("Can only use `downcast_bf16` when using `mixed_precision='bf16'` and on a TPU")
+
+ if gradient_accumulation_plugin is not None:
+ if gradient_accumulation_steps != 1:
+ raise ValueError(
+ "You can only pass one of `gradient_accumulation_steps` and `gradient_accumulation_plugin`. Please only pass in the created `GradientAccumulationPlugin` object."
+ )
+ else:
+ gradient_accumulation_steps = int(
+ parse_choice_from_env("ACCELERATE_GRADIENT_ACCUMULATION_STEPS", gradient_accumulation_steps)
+ )
+ gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=gradient_accumulation_steps)
+ self.gradient_state = GradientState(
+ gradient_accumulation_plugin=gradient_accumulation_plugin,
+ )
+
+ self.device_placement = device_placement
+ if dataloader_config is None:
+ dataloader_config = DataLoaderConfiguration()
+ self.dataloader_config = dataloader_config
+ # Deal with deprecated args
+ # TODO: Remove in v1.0.0
+ deprecated_dl_args = {}
+ if dispatch_batches is not _dispatch_batches:
+ deprecated_dl_args["dispatch_batches"] = dispatch_batches
+ self.dataloader_config.dispatch_batches = dispatch_batches
+ if split_batches is not _split_batches:
+ deprecated_dl_args["split_batches"] = split_batches
+ self.dataloader_config.split_batches = split_batches
+ if even_batches is not _even_batches:
+ deprecated_dl_args["even_batches"] = even_batches
+ self.dataloader_config.even_batches = even_batches
+ if use_seedable_sampler is not _use_seedable_sampler:
+ deprecated_dl_args["use_seedable_sampler"] = use_seedable_sampler
+ self.dataloader_config.use_seedable_sampler = use_seedable_sampler
+ if len(deprecated_dl_args) > 0:
+ values = ", ".join([f"{k}={v}" for k, v in deprecated_dl_args.items()])
+ warnings.warn(
+ f"Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: {deprecated_dl_args.keys()}. "
+ "Please pass an `accelerate.DataLoaderConfiguration` instead: \n"
+ f"dataloader_config = DataLoaderConfiguration({values})",
+ FutureWarning,
+ )
+ self.step_scheduler_with_optimizer = step_scheduler_with_optimizer
+
+ # Mixed precision attributes
+ self.scaler = None
+ self.native_amp = False
+ if (
+ self.state.mixed_precision == "fp16"
+ and self.device.type != "cpu"
+ and self.distributed_type not in (DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM)
+ ):
+ self.native_amp = True
+ if self.device.type not in ("xpu", "cuda", "mps", "npu", "xla", "mlu") or is_torch_xla_available(
+ check_is_tpu=True
+ ):
+ raise ValueError(f"fp16 mixed precision requires a GPU (not {self.device.type!r}).")
+ kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
+ if self.distributed_type == DistributedType.FSDP:
+ from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+
+ self.scaler = ShardedGradScaler(**kwargs)
+ elif is_torch_xla_available(check_is_gpu=True):
+ self.scaler = xamp.GradScaler(**kwargs)
+ elif is_mlu_available():
+ self.scaler = torch.mlu.amp.GradScaler(**kwargs)
+ elif is_npu_available():
+ self.scaler = torch.npu.amp.GradScaler(**kwargs)
+ else:
+ self.scaler = torch.cuda.amp.GradScaler(**kwargs)
+
+ elif self.state.mixed_precision == "bf16" and self.distributed_type not in (
+ DistributedType.DEEPSPEED,
+ DistributedType.MEGATRON_LM,
+ ):
+ if self.device.type in ["cpu", "xpu"]:
+ self.native_amp = True
+ else:
+ self.native_amp = is_bf16_available(True)
+ if mixed_precision == "bf16" and not self.native_amp and not is_torch_xla_available():
+ raise ValueError("bf16 mixed precision requires PyTorch >= 1.10 and a supported device.")
+
+ # Start of internal step tracking
+ self.step = 0
+
+ # Internal references to the training objects
+ self._optimizers = []
+ self._models = []
+ self._schedulers = []
+ self._dataloaders = []
+ self._custom_objects = []
+
+ # Hooks
+ self._load_model_state_pre_hook = OrderedDict()
+ self._save_model_state_pre_hook = OrderedDict()
+
+ # RNG Types
+ self.rng_types = rng_types
+ if self.rng_types is None:
+ self.rng_types = ["generator"]
+
+ # Set a flag tensor for early stopping and other breakpoints
+ self.flag_tensor = None
+
+ check_os_kernel()
+
+ @property
+ def use_distributed(self):
+ """
+ Whether the Accelerator is configured for distributed training
+ """
+ return self.state.use_distributed
+
+ @property
+ def distributed_type(self):
+ return self.state.distributed_type
+
+ @property
+ def num_processes(self):
+ return self.state.num_processes
+
+ @property
+ def process_index(self):
+ return self.state.process_index
+
+ @property
+ def local_process_index(self):
+ return self.state.local_process_index
+
+ @property
+ def device(self):
+ return self.state.device
+
+ @property
+ def split_batches(self):
+ return self.dataloader_config.split_batches
+
+ @property
+ def dispatch_batches(self):
+ return self.dataloader_config.dispatch_batches
+
+ @property
+ def even_batches(self):
+ return self.dataloader_config.even_batches
+
+ @even_batches.setter
+ def even_batches(self, value: bool):
+ self.dataloader_config.even_batches = value
+
+ @property
+ def use_seedable_sampler(self):
+ return self.dataloader_config.use_seedable_sampler
+
+ @property
+ def project_dir(self):
+ return self.project_configuration.project_dir
+
+ @property
+ def logging_dir(self):
+ return self.project_configuration.logging_dir
+
+ @property
+ def save_iteration(self):
+ return self.project_configuration.iteration
+
+ @property
+ def is_main_process(self):
+ """True for one process only."""
+ return self.state.is_main_process
+
+ @property
+ def is_local_main_process(self):
+ """True for one process per server."""
+ return self.state.is_local_main_process
+
+ @property
+ def use_fp16(self):
+ warnings.warn(
+ "The `use_fp16` property is deprecated and will be removed in version 1.0 of Accelerate use "
+ "`Accelerator.mixed_precision == 'fp16'` instead.",
+ FutureWarning,
+ )
+ return self.mixed_precision != "no"
+
+ @property
+ def is_last_process(self):
+ return self.process_index == self.num_processes - 1
+
+ @property
+ def mixed_precision(self):
+ return self.state.mixed_precision
+
+ @contextmanager
+ def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):
+ """
+ Splits `inputs` between `self.num_processes` quickly, so the result can then be used on that process. Useful when doing
+ distributed inference, such as with different prompts.
+
+ Note that when using a `dict`, all keys need to have the same number of elements.
+
+ Args:
+ inputs (`list`, `tuple`, `torch.Tensor`, or `dict` of `list`/`tuple`/`torch.Tensor`):
+ The input to split between processes.
+ apply_padding (`bool`, `optional`, defaults to `False`):
+ Whether to apply padding by repeating the last element of the input so that all processes have the same
+ number of elements. Useful when trying to perform actions such as `Accelerator.gather()` on the outputs
+ or when passing in fewer inputs than there are processes. If so, just remember to drop the padded elements
+ afterwards.
+
+ Example:
+
+ ```python
+ # Assume there are two processes
+ from accelerate import Accelerator
+
+ accelerator = Accelerator()
+ with accelerator.split_between_processes(["A", "B", "C"]) as inputs:
+ print(inputs)
+ # Process 0
+ ["A", "B"]
+ # Process 1
+ ["C"]
+
+ with accelerator.split_between_processes(["A", "B", "C"], apply_padding=True) as inputs:
+ print(inputs)
+ # Process 0
+ ["A", "B"]
+ # Process 1
+ ["C", "C"]
+ ```
+ """
+ with PartialState().split_between_processes(inputs, apply_padding=apply_padding) as inputs:
+ yield inputs
+
+ def on_main_process(self, function: Callable[..., Any] = None):
+ """
+ A decorator that will run the decorated function on the main process only. Can also be called using the
+ `PartialState` class.
+
+ Args:
+ function (`Callable`): The function to decorate.
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+
+
+ >>> @accelerator.on_main_process
+ ... def print_something():
+ ... print("This will be printed by process 0 only.")
+
+
+ >>> print_something()
+ "This will be printed by process 0 only"
+ ```
+ """
+ # For times when the `Accelerator` object itself utilizes this decorator.
+ if function is None:
+ if "Accelerator." in self.__qualname__:
+ function = self
+ else:
+ raise ValueError(
+ "The `on_main_process` decorator must be called with a function on an instantiated `Accelerator` object."
+ )
+
+ def _inner(*args, **kwargs):
+ return PartialState().on_main_process(function)(*args, **kwargs)
+
+ return _inner
+
+ def on_local_main_process(self, function: Callable[..., Any] = None):
+ """
+ A decorator that will run the decorated function on the local main process only. Can also be called using the
+ `PartialState` class.
+
+ Args:
+ function (`Callable`): The function to decorate.
+
+ Example:
+ ```python
+ # Assume we have 2 servers with 4 processes each.
+ from accelerate import Accelerator
+
+ accelerator = Accelerator()
+
+
+ @accelerator.on_local_main_process
+ def print_something():
+ print("This will be printed by process 0 only on each server.")
+
+
+ print_something()
+ # On server 1:
+ "This will be printed by process 0 only"
+ # On server 2:
+ "This will be printed by process 0 only"
+ ```
+ """
+ # For times when the `Accelerator` object itself utilizes this decorator.
+ if function is None:
+ if "Accelerator." in self.__qualname__:
+ function = self
+ else:
+ raise ValueError(
+ "The `on_local_main_process` decorator must be called with a function on an instantiated `Accelerator` object."
+ )
+
+ def _inner(*args, **kwargs):
+ return PartialState().on_local_main_process(function)(*args, **kwargs)
+
+ return _inner
+
+ def on_last_process(self, function: Callable[..., Any] = None):
+ """
+ A decorator that will run the decorated function on the last process only. Can also be called using the
+ `PartialState` class.
+
+ Args:
+ function (`Callable`): The function to decorate.
+
+ Example:
+ ```python
+ # Assume we have 4 processes.
+ from accelerate import Accelerator
+
+ accelerator = Accelerator()
+
+
+ @accelerator.on_last_process
+ def print_something():
+ print(f"Printed on process {accelerator.process_index}")
+
+
+ print_something()
+ "Printed on process 3"
+ ```
+ """
+ # For times when the `Accelerator` object itself utilizes this decorator.
+ if function is None:
+ if "Accelerator." in self.__qualname__:
+ function = self
+ else:
+ raise ValueError(
+ "The `on_last_process` decorator must be called with a function on an instantiated `Accelerator` object."
+ )
+
+ def _inner(*args, **kwargs):
+ return PartialState().on_last_process(function)(*args, **kwargs)
+
+ return _inner
+
+ def on_process(self, function: Callable[..., Any] = None, process_index: int = None):
+ """
+ A decorator that will run the decorated function on a given process index only. Can also be called using the
+ `PartialState` class.
+
+ Args:
+ function (`Callable`, `optional`):
+ The function to decorate.
+ process_index (`int`, `optional`):
+ The index of the process on which to run the function.
+
+ Example:
+ ```python
+ # Assume we have 4 processes.
+ from accelerate import Accelerator
+
+ accelerator = Accelerator()
+
+
+ @accelerator.on_process(process_index=2)
+ def print_something():
+ print(f"Printed on process {accelerator.process_index}")
+
+
+ print_something()
+ "Printed on process 2"
+ ```
+ """
+ # Initial construction of the decorator.
+ if (self is not None) and (process_index is not None) and (function is None):
+ return partial(self.on_process, process_index=process_index)
+ # For times when the `Accelerator` object itself utilizes this decorator.
+ if function is None:
+ if "Accelerator." in self.__qualname__:
+ function = self
+ else:
+ raise ValueError(
+ "The `on_main_process` decorator must be called with a function on an instantiated `Accelerator` object."
+ )
+
+ def _inner(*args, **kwargs):
+ return PartialState().on_process(function, process_index)(*args, **kwargs)
+
+ return _inner
+
+ def on_local_process(self, function: Callable[..., Any] = None, local_process_index: int = None):
+ """
+ A decorator that will run the decorated function on a given local process index only. Can also be called using
+ the `PartialState` class.
+
+ Args:
+ function (`Callable`, *optional*):
+ The function to decorate.
+ local_process_index (`int`, *optional*):
+ The index of the local process on which to run the function.
+
+ Example:
+ ```python
+ # Assume we have 2 servers with 4 processes each.
+ from accelerate import Accelerator
+
+ accelerator = Accelerator()
+
+
+ @accelerator.on_local_process(local_process_index=2)
+ def print_something():
+ print(f"Printed on process {accelerator.local_process_index}")
+
+
+ print_something()
+ # On server 1:
+ "Printed on process 2"
+ # On server 2:
+ "Printed on process 2"
+ ```
+ """
+ # Initial construction of the decorator.
+ if (self is not None) and (local_process_index is not None) and (function is None):
+ return partial(self.on_local_process, local_process_index=local_process_index)
+ # For times when the `Accelerator` object itself utilizes this decorator.
+ if function is None:
+ if "Accelerator." in self.__qualname__:
+ function = self
+ else:
+ raise ValueError(
+ "The `on_main_process` decorator must be called with a function on an instantiated `Accelerator` object."
+ )
+
+ def _inner(*args, **kwargs):
+ return PartialState().on_local_process(function, local_process_index)(*args, **kwargs)
+
+ return _inner
+
+ @contextmanager
+ def main_process_first(self):
+ """
+ Lets the main process go first inside a with block.
+
+ The other processes will enter the with block after the main process exits.
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> with accelerator.main_process_first():
+ ... # This will be printed first by process 0 then in a seemingly
+ ... # random order by the other processes.
+ ... print(f"This will be printed by process {accelerator.process_index}")
+ ```
+ """
+ with self.state.main_process_first():
+ yield
+
+ @contextmanager
+ def local_main_process_first(self):
+ """
+ Lets the local main process go first inside a with block.
+
+ The other processes will enter the with block after the local main process exits.
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> with accelerator.local_main_process_first():
+ ... # This will be printed first by local process 0 then in a seemingly
+ ... # random order by the other processes.
+ ... print(f"This will be printed by process {accelerator.local_process_index}")
+ ```
+ """
+ with self.state.local_main_process_first():
+ yield
+
+ @contextmanager
+ def no_sync(self, model):
+ """
+ A context manager to disable gradient synchronizations across DDP processes by calling
+ `torch.nn.parallel.DistributedDataParallel.no_sync`.
+
+ If `model` is not in DDP, this context manager does nothing
+
+ Args:
+ model (`torch.nn.Module`):
+ PyTorch Module that was prepared with `Accelerator.prepare`
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> dataloader, model, optimizer = accelerator.prepare(dataloader, model, optimizer)
+ >>> input_a = next(iter(dataloader))
+ >>> input_b = next(iter(dataloader))
+
+ >>> with accelerator.no_sync(model):
+ ... outputs = model(input_a)
+ ... loss = loss_func(outputs)
+ ... accelerator.backward(loss)
+ ... # No synchronization across processes, only accumulate gradients
+ >>> outputs = model(input_b)
+ >>> loss = loss_func(outputs)
+ >>> accelerator.backward(loss)
+ >>> # Synchronization across all processes
+ >>> optimizer.step()
+ >>> optimizer.zero_grad()
+ ```
+ """
+ context = contextlib.nullcontext
+ if self.use_distributed:
+ context = getattr(model, "no_sync", context)
+
+ with context():
+ yield
+
+ @staticmethod
+ @contextmanager
+ def trigger_sync_in_backward(model):
+ """Trigger the sync of the gradients in the next backward pass of the model after multiple forward passes under
+ `Accelerator.no_sync` (only applicable in multi-GPU scenarios).
+
+ If the script is not launched in distributed mode, this context manager does nothing.
+
+ Args:
+ model (`torch.nn.Module`):
+ The model for which to trigger the gradient synchronization.
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> dataloader, model, optimizer = accelerator.prepare(dataloader, model, optimizer)
+
+ >>> with accelerator.no_sync(model):
+ ... loss_a = loss_func(model(input_a)) # first forward pass
+ ... loss_b = loss_func(model(input_b)) # second forward pass
+ >>> accelerator.backward(loss_a) # No synchronization across processes, only accumulate gradients
+ >>> with accelerator.trigger_sync_in_backward(model):
+ ... accelerator.backward(loss_b) # Synchronization across all processes
+ >>> optimizer.step()
+ >>> optimizer.zero_grad()
+ ```
+ """
+ if not isinstance(model, torch.nn.parallel.DistributedDataParallel):
+ yield
+ return
+
+ old_require_backward_grad_sync = model.require_backward_grad_sync
+ old_require_forward_param_sync = model.require_forward_param_sync
+
+ # EXPERIMENTAL: This will force grad sync during `backward()`, but it is unknown if it breaks other DDP features.
+ # https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/torch/nn/parallel/distributed.py#L1453-L1466
+ model.require_backward_grad_sync = True
+ model.require_forward_param_sync = True
+ # https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/torch/csrc/distributed/c10d/reducer.cpp#L1371-L1402
+ model.reducer.prepare_for_backward([])
+ try:
+ yield
+ finally:
+ model.require_backward_grad_sync = old_require_backward_grad_sync
+ model.require_forward_param_sync = old_require_forward_param_sync
+
+ def _do_sync(self, force: bool = False):
+ "Sets the right `sync_gradients` context and either resets or increases `self.step`"
+ if self.gradient_state.sync_with_dataloader and self.gradient_state.end_of_dataloader:
+ self.step = 0
+ self.gradient_state._set_sync_gradients(True)
+ else:
+ self.step += 1
+ self.gradient_state._set_sync_gradients(force or ((self.step % self.gradient_state.num_steps) == 0))
+
+ @property
+ def sync_gradients(self):
+ return self.gradient_state.sync_gradients
+
+ @sync_gradients.setter
+ def sync_gradients(self, sync_gradients):
+ self.gradient_state.sync_gradients = sync_gradients
+
+ @property
+ def gradient_accumulation_steps(self):
+ return self.gradient_state.num_steps
+
+ @gradient_accumulation_steps.setter
+ def gradient_accumulation_steps(self, gradient_accumulation_steps):
+ self.gradient_state.plugin_kwargs.update({"num_steps": gradient_accumulation_steps})
+
+ @contextmanager
+ def accumulate(self, *models):
+ """
+ A context manager that will lightly wrap around and perform gradient accumulation automatically
+
+ Args:
+ *models (list of `torch.nn.Module`):
+ PyTorch Modules that were prepared with `Accelerator.prepare`. Models passed to `accumulate()` will
+ skip gradient syncing during backward pass in distributed training
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator(gradient_accumulation_steps=1)
+ >>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)
+
+ >>> for input, output in dataloader:
+ ... with accelerator.accumulate(model):
+ ... outputs = model(input)
+ ... loss = loss_func(outputs)
+ ... loss.backward()
+ ... optimizer.step()
+ ... scheduler.step()
+ ... optimizer.zero_grad()
+ ```
+ """
+ # sync_each_batch=True will guarantee below that self.sync_gradients=True, therefore
+ # resulting in the nullcontext always being selected.
+ self._do_sync(force=self.gradient_state.plugin_kwargs.get("sync_each_batch", False))
+ with contextlib.ExitStack() as cm_stack:
+ for m in models:
+ cm_stack.enter_context(contextlib.nullcontext() if self.sync_gradients else self.no_sync(m))
+ yield
+
+ @contextmanager
+ def join_uneven_inputs(self, joinables, even_batches=None):
+ """
+ A context manager that facilitates distributed training or evaluation on uneven inputs, which acts as a wrapper
+ around `torch.distributed.algorithms.join`. This is useful when the total batch size does not evenly divide the
+ length of the dataset.
+
+ Args:
+ joinables (`list[torch.distributed.algorithms.Joinable]`):
+ A list of models or optimizers that subclass `torch.distributed.algorithms.Joinable`. Most commonly, a
+ PyTorch Module that was prepared with `Accelerator.prepare` for DistributedDataParallel training.
+ even_batches (`bool`, *optional*):
+ If set, this will override the value of `even_batches` set in the `Accelerator`. If it is not provided,
+ the default `Accelerator` value will be used.
+
+ `join_uneven_inputs` is only supported for Distributed Data Parallel training on multiple GPUs. For any other
+ configuration, this method will have no effect.
+
+ Overriding `even_batches` will not affect iterable-style data loaders.
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator(even_batches=True)
+ >>> ddp_model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
+
+ >>> with accelerator.join_uneven_inputs([ddp_model], even_batches=False):
+ ... for input, output in dataloader:
+ ... outputs = model(input)
+ ... loss = loss_func(outputs)
+ ... loss.backward()
+ ... optimizer.step()
+ ... optimizer.zero_grad()
+ ```
+ """
+ if self.distributed_type in (
+ DistributedType.MULTI_GPU,
+ DistributedType.MULTI_NPU,
+ DistributedType.MULTI_MLU,
+ DistributedType.MULTI_XPU,
+ ):
+ dl_even_batches_values = []
+
+ if even_batches is not None:
+ iterable_dl_seen = False
+ # override value in batch sampler for map-style datasets
+ for dl_idx, dl in enumerate(self._dataloaders):
+ if isinstance(dl, DataLoaderDispatcher):
+ iterable_dl_seen = True
+ continue
+ dl_even_batches_values.append((dl_idx, dl.batch_sampler.even_batches))
+ dl.batch_sampler.even_batches = even_batches
+
+ if iterable_dl_seen:
+ warnings.warn(
+ "Overridding even_batches is only supported for map-style datasets, yet some dataloaders given were iterable"
+ )
+ else:
+ even_batches = self.even_batches
+
+ enable_join = not even_batches
+ try:
+ with Join(joinables, enable=enable_join, throw_on_early_termination=False):
+ yield
+ finally:
+ # reset any batch samplers that have been modified
+ for dl_idx, even_batches_value in dl_even_batches_values:
+ self._dataloaders[dl_idx].batch_sampler.even_batches = even_batches_value
+ else:
+ # Even when disabled, Join expects models to subclass Joinable, so skip entirely for single process runs
+ if self.distributed_type != DistributedType.NO:
+ warnings.warn(
+ "Joining uneven inputs is only supported for multi-GPU training, as a result `join_uneven_inputs` will have no effect."
+ )
+
+ with contextlib.nullcontext(joinables):
+ yield
+
+ def print(self, *args, **kwargs):
+ """
+ Drop-in replacement for `print()` that only prints once per server.
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> accelerator.print("Hello world!")
+ ```
+ """
+ self.state.print(*args, **kwargs)
+
+ def _prepare_one(self, obj, first_pass=False, device_placement=None):
+ # First pass of preparation: DataLoader, model, optimizer
+ if first_pass:
+ if isinstance(obj, torch.utils.data.DataLoader):
+ return self.prepare_data_loader(obj, device_placement=device_placement)
+ elif isinstance(obj, torch.nn.Module):
+ return self.prepare_model(obj, device_placement=device_placement)
+ elif isinstance(obj, torch.optim.Optimizer):
+ optimizer = self.prepare_optimizer(obj, device_placement=device_placement)
+ return optimizer
+ # Second pass of preparation: LR scheduler (which needs the full list of optimizers)
+ elif isinstance(obj, LRScheduler):
+ scheduler = self.prepare_scheduler(obj)
+ return scheduler
+ # Return the unprocessed object if previous criteria was not met
+ return obj
+
+ def prepare(self, *args, device_placement=None):
+ """
+ Prepare all objects passed in `args` for distributed training and mixed precision, then return them in the same
+ order.
+
+ Args:
+ *args (list of objects):
+ Any of the following type of objects:
+
+ - `torch.utils.data.DataLoader`: PyTorch Dataloader
+ - `torch.nn.Module`: PyTorch Module
+ - `torch.optim.Optimizer`: PyTorch Optimizer
+ - `torch.optim.lr_scheduler.LRScheduler`: PyTorch LR Scheduler
+
+ device_placement (`list[bool]`, *optional*):
+ Used to customize whether automatic device placement should be performed for each object passed. Needs
+ to be a list of the same length as `args`. Not compatible with DeepSpeed or FSDP.
+
+
+ You don't need to prepare a model if you only use it for inference without any kind of mixed precision.
+
+ Examples:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> # Assume a model, optimizer, data_loader and scheduler are defined
+ >>> model, optimizer, data_loader, scheduler = accelerator.prepare(model, optimizer, data_loader, scheduler)
+ ```
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> # Assume a model, optimizer, data_loader and scheduler are defined
+ >>> device_placement = [True, True, False, False]
+ >>> # Will place the first two items passed in automatically on the right device but not the last two.
+ >>> model, optimizer, data_loader, scheduler = accelerator.prepare(
+ ... model, optimizer, data_loader, scheduler, device_placement=device_placement
+ ... )
+ ```
+ """
+ if device_placement is None:
+ device_placement = [None for _ in args]
+ elif self.distributed_type in (DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM):
+ raise ValueError("You can't customize device placements with DeepSpeed or Megatron-LM.")
+ elif len(device_placement) != len(args):
+ raise ValueError(
+ f"`device_placement` should be a list with {len(args)} elements (the number of objects passed)."
+ )
+
+ for obj in args:
+ # TODO: Look at enabling native TP training directly with a proper config
+ if (
+ isinstance(obj, torch.nn.Module)
+ and self.verify_device_map(obj)
+ and self.distributed_type != DistributedType.NO
+ and os.environ.get("ACCELERATE_BYPASS_DEVICE_MAP", "false") != "true"
+ ):
+ raise ValueError(
+ "You can't train a model that has been loaded with `device_map='auto'` in any distributed mode."
+ " Please rerun your script specifying `--num_processes=1` or by launching with `python {{myscript.py}}`."
+ )
+
+ if self.distributed_type == DistributedType.DEEPSPEED:
+ model_count = 0
+ for obj in args:
+ if isinstance(obj, torch.nn.Module):
+ model_count += 1
+ if model_count > 1:
+ raise AssertionError(
+ "You can't use same `Accelerator()` instance with multiple models when using DeepSpeed"
+ )
+
+ # On TPUs, putting the model on the XLA device will create new parameters, so the corresponding optimizer will
+ # have parameters disconnected from the model (so no training :-( ).
+ # If the model and optimizer have parameters on different devices we raise an error.
+ if self.distributed_type == DistributedType.XLA:
+ model_device, optimizer_device = self._get_devices()
+ if model_device is not None and optimizer_device is not None and model_device != optimizer_device:
+ raise ValueError(
+ "The model and the optimizer parameters are not on the same device, which probably means you "
+ "created an optimizer around your model **before** putting on the device. Make sure the line "
+ "model.to(device) is before the optimizer creation in your script or remove it entirely and use "
+ "the flag default value for `device_placement` in your `Accelerator` to let it handle that "
+ "part for you."
+ )
+
+ # When device placement on XLA (or fp8 conversion with TransformerEngine) replaces the model's parameters,
+ # the optimizer's parameter references need to be remapped afterwards (steps 1-4 below).
+ tpu_should_fix_optimizer = self.device_placement and self.distributed_type == DistributedType.XLA
+ if tpu_should_fix_optimizer or (self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "TE"):
+ # 1. grabbing old model parameters
+ old_named_params = self._get_named_parameters(*args)
+
+ if self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
+ if self.device.type == "cpu" and self.state.use_ipex:
+ args = self._prepare_ipex(*args)
+ elif self.device.type == "xpu" and is_xpu_available():
+ args = self._prepare_ipex(*args)
+ if self.distributed_type == DistributedType.DEEPSPEED:
+ result = self._prepare_deepspeed(*args)
+ elif self.distributed_type == DistributedType.MEGATRON_LM:
+ result = self._prepare_megatron_lm(*args)
+ else:
+ if self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "MSAMP":
+ args = self._prepare_msamp(*args)
+ # MS-AMP will handle the device placement
+ device_placement = [False for _ in args]
+ result = tuple(
+ self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
+ )
+ result = tuple(self._prepare_one(obj, device_placement=d) for obj, d in zip(result, device_placement))
+
+ if tpu_should_fix_optimizer or (self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "TE"):
+ # 2. grabbing new model parameters
+ new_named_params = self._get_named_parameters(*result)
+ # 3. building a map from the first to the second
+ mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
+ # 4. using that map to update the parameters of the optimizer
+ for obj in result:
+ if isinstance(obj, torch.optim.Optimizer):
+ obj._switch_parameters(mapping)
+
+ for item in result:
+ if any(
+ item in container
+ for container in (self._dataloaders, self._models, self._optimizers, self._schedulers)
+ ):
+ item._is_accelerate_prepared = True
+
+ return result if len(result) > 1 else result[0]
+
+ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, evaluation_mode: bool = False):
+ """
+ Prepares a PyTorch model for training in any distributed setup. It is recommended to use
+ [`Accelerator.prepare`] instead.
+
+ Args:
+ model (`torch.nn.Module`):
+ A PyTorch model to prepare. You don't need to prepare a model if it is used only for inference without
+ any kind of mixed precision
+ device_placement (`bool`, *optional*):
+ Whether or not to place the model on the proper device. Will default to `self.device_placement`.
+ evaluation_mode (`bool`, *optional*, defaults to `False`):
+ Whether or not to set the model for evaluation only, by just applying mixed precision and
+ `torch.compile` (if configured in the `Accelerator` object).
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> # Assume a model is defined
+ >>> model = accelerator.prepare_model(model)
+ ```
+ """
+ if device_placement is None:
+ device_placement = self.device_placement and self.distributed_type != DistributedType.FSDP
+ self._models.append(model)
+
+ # TODO: Look at enabling native TP training directly with a proper config
+ if (
+ self.verify_device_map(model)
+ and self.distributed_type != DistributedType.NO
+ and os.environ.get("ACCELERATE_BYPASS_DEVICE_MAP", "false") != "true"
+ ):
+ raise ValueError(
+ "You can't train a model that has been loaded with `device_map='auto'` in any distributed mode."
+ " Please rerun your script specifying `--num_processes=1` or by launching with `python {{myscript.py}}`."
+ )
+
+ if self.native_amp:
+ model._original_forward = model.forward
+ model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward
+ autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler)
+ new_forward = autocast_context(model_forward_func)
+ if hasattr(model.forward, "__func__"):
+ model.forward = MethodType(new_forward, model)
+ model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model)
+ else:
+ model.forward = convert_outputs_to_fp32(new_forward)
+ elif self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "TE":
+ if not has_transformer_engine_layers(model):
+ with torch.no_grad():
+ convert_model(model)
+ model._converted_to_transformer_engine = True
+ model._original_forward = model.forward
+
+ kwargs = self.fp8_recipe_handler.to_kwargs() if self.fp8_recipe_handler is not None else {}
+ if "fp8_format" in kwargs:
+ kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"])
+ fp8_recipe = te_recipe.DelayedScaling(**kwargs)
+ model.forward = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)(model.forward)
+
+ if (getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)) and getattr(
+ model, "hf_device_map", False
+ ):
+ model_devices = set(model.hf_device_map.values())
+ if len(model_devices) > 1 and self.distributed_type != DistributedType.NO:
+ raise ValueError(
+ "You can't train a model that has been loaded in 8-bit precision on multiple devices in any distributed mode."
+ " In order to use 8-bit models that have been loaded across multiple GPUs the solution is to use Naive Pipeline Parallelism."
+ " Therefore you should not specify that you are under any distributed regime in your accelerate config."
+ )
+ current_device = list(model_devices)[0]
+ current_device_index = current_device.index if isinstance(current_device, torch.device) else current_device
+
+ if torch.device(current_device_index) != self.device:
+ # if on the first device (GPU 0) we don't care
+ if (self.device.index is not None) or (current_device_index != 0):
+ raise ValueError(
+ "You can't train a model that has been loaded in 8-bit precision on a different device than the one "
+ "you're training on. Make sure you loaded the model on the correct device using for example `device_map={'':torch.cuda.current_device() or device_map={'':torch.xpu.current_device()}"
+ )
+
+ if "cpu" in model_devices or "disk" in model_devices:
+ raise ValueError(
+ "You can't train a model that has been loaded in 8-bit precision with CPU or disk offload."
+ )
+ elif device_placement and not self.verify_device_map(model):
+ model = model.to(self.device)
+ if not evaluation_mode:
+ if self.distributed_type in (
+ DistributedType.MULTI_GPU,
+ DistributedType.MULTI_MLU,
+ DistributedType.MULTI_NPU,
+ DistributedType.MULTI_XPU,
+ ):
+ if any(p.requires_grad for p in model.parameters()):
+ kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
+ # TODO: Look at enabling native TP training directly with a proper config
+ if os.environ.get("ACCELERATE_BYPASS_DEVICE_MAP", "false") != "true":
+ device_ids, output_device = [self.local_process_index], self.local_process_index
+ else:
+ device_ids, output_device = None, None
+
+ model = torch.nn.parallel.DistributedDataParallel(
+ model, device_ids=device_ids, output_device=output_device, **kwargs
+ )
+ elif self.distributed_type == DistributedType.FSDP:
+ from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
+
+ # Check if the model is already a FSDP model due to `Manual Wrapping` and if so,
+ # don't wrap it again
+ # In case the model is already compiled using PyTorch 2.0 and the wrapped model in it
+ # is a FSDP model, don't wrap it again
+ is_type_fsdp = isinstance(model, FSDP) or (
+ is_compiled_module(model) and isinstance(model._orig_mod, FSDP)
+ )
+
+ if not is_type_fsdp:
+ self.state.fsdp_plugin.set_auto_wrap_policy(model)
+ fsdp_plugin = self.state.fsdp_plugin
+ kwargs = {
+ "sharding_strategy": fsdp_plugin.sharding_strategy,
+ "cpu_offload": fsdp_plugin.cpu_offload,
+ "auto_wrap_policy": fsdp_plugin.auto_wrap_policy,
+ "mixed_precision": fsdp_plugin.mixed_precision_policy,
+ "sync_module_states": fsdp_plugin.sync_module_states,
+ "backward_prefetch": fsdp_plugin.backward_prefetch,
+ "forward_prefetch": fsdp_plugin.forward_prefetch,
+ "use_orig_params": fsdp_plugin.use_orig_params,
+ "param_init_fn": fsdp_plugin.param_init_fn,
+ "ignored_modules": fsdp_plugin.ignored_modules,
+ "limit_all_gathers": fsdp_plugin.limit_all_gathers,
+ "device_id": self.device,
+ }
+ model = FSDP(model, **kwargs)
+ if fsdp_plugin.activation_checkpointing:
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+ CheckpointImpl,
+ apply_activation_checkpointing,
+ checkpoint_wrapper,
+ )
+
+ apply_activation_checkpointing(
+ model,
+ checkpoint_wrapper_fn=functools.partial(
+ checkpoint_wrapper,
+ checkpoint_impl=CheckpointImpl.NO_REENTRANT,
+ ),
+ auto_wrap_policy=fsdp_plugin.auto_wrap_policy,
+ )
+ # if the previous and current models are same, delete the previous one
+ if len(self._models) > 1 and (self._models[-2] is self._models[-1]):
+ del self._models[-2]
+ self._models[-1] = model
+ elif self.distributed_type == DistributedType.MULTI_CPU:
+ kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
+ model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
+ elif self.distributed_type == DistributedType.XLA and self.state.fork_launched:
+ model = xmp.MpModelWrapper(model).to(self.device)
+ # torch.compile should be called last and only if the model isn't already compiled.
+ if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
+ if not is_torch_version(">=", "2.0"):
+ raise ValueError("Using `torch.compile` requires PyTorch 2.0 or higher.")
+ model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
+ return model
+
+ def _prepare_deepspeed(self, *args):
+ import deepspeed
+
+ deepspeed_plugin = self.state.deepspeed_plugin
+
+ is_dataloader_present = any(isinstance(obj, torch.utils.data.DataLoader) for obj in args)
+ result = [
+ self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) else obj
+ for obj in args
+ ]
+
+ if deepspeed_plugin.is_auto("train_micro_batch_size_per_gpu"):
+ if is_dataloader_present:
+ batch_sizes = [obj.batch_size for obj in args if hasattr(obj, "batch_size")]
+ if any(bs is None for bs in batch_sizes):
+ raise ValueError(
+ "At least one of the dataloaders passed to `accelerate.prepare()` has `None` as batch size. "
+ "Please set an integer value in `train_micro_batch_size_per_gpu` in the deepspeed config file "
+ "or assign integer value to `AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu']`."
+ )
+ if self.split_batches:
+ batch_sizes = [batch_size // self.num_processes for batch_size in batch_sizes]
+
+ batch_size_per_device = min(batch_sizes) if deepspeed_plugin.is_train_batch_min else max(batch_sizes)
+ if len(batch_sizes) > 1:
+ logger.info(
+ "Since you passed both train and evaluation dataloader, `is_train_batch_min` (here "
+ f"{deepspeed_plugin.is_train_batch_min} will decide the `train_batch_size` ({batch_size_per_device})."
+ )
+ else:
+ raise ValueError(
+ "When using DeepSpeed, `accelerate.prepare()` requires you to pass at least one of training or evaluation dataloaders "
+ "with `batch_size` attribute returning an integer value "
+ "or alternatively set an integer value in `train_micro_batch_size_per_gpu` in the deepspeed config file "
+ "or assign integer value to `AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu']`."
+ )
+ else:
+ batch_size_per_device = deepspeed_plugin.get_value("train_micro_batch_size_per_gpu")
+
+ # handle `gradient_accumulation_steps` when the value is `auto`
+ deepspeed_plugin.fill_match(
+ "gradient_accumulation_steps",
+ must_match=False,
+ gradient_accumulation_steps=self.gradient_accumulation_steps,
+ )
+
+ config_kwargs = {
+ "train_micro_batch_size_per_gpu": batch_size_per_device,
+ "train_batch_size": batch_size_per_device
+ * deepspeed_plugin.get_value("gradient_accumulation_steps")
+ * self.num_processes,
+ "gradient_clipping": 1.0,
+ "zero_optimization.stage3_gather_16bit_weights_on_model_save": False,
+ }
+
+ model = None
+ optimizer = None
+ scheduler = None
+ for obj in result:
+ if isinstance(obj, torch.nn.Module):
+ model = obj
+ elif isinstance(obj, (torch.optim.Optimizer, DummyOptim)):
+ optimizer = obj
+ elif (isinstance(obj, (LRScheduler, DummyScheduler))) or (
+ type(obj).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES
+ ):
+ scheduler = obj
+
+ if optimizer is not None:
+ if "optimizer" in deepspeed_plugin.deepspeed_config and not isinstance(optimizer, (DummyOptim)):
+ raise ValueError(
+ "You cannot specify an optimizer in the config file and in the code at the same time. "
+ "Please remove the optimizer from the config file or "
+ "create `accelerate.utils.DummyOptim` in the code."
+ )
+ elif "optimizer" not in deepspeed_plugin.deepspeed_config and isinstance(optimizer, (DummyOptim)):
+ raise ValueError(
+ "You cannot create a `DummyOptim` without specifying an optimizer in the config file."
+ )
+
+ if isinstance(optimizer, (torch.optim.Optimizer)):
+ deepspeed_plugin.deepspeed_config["zero_allow_untested_optimizer"] = True
+
+ if scheduler is not None:
+ if "scheduler" in deepspeed_plugin.deepspeed_config and not isinstance(scheduler, (DummyScheduler)):
+ raise ValueError(
+ "You cannot specify a scheduler in the config file and in the code at the same time. "
+ "Please remove the scheduler from the config file or "
+ "create `accelerate.utils.DummyScheduler` in the code."
+ )
+ elif (
+ "scheduler" not in deepspeed_plugin.deepspeed_config
+ and isinstance(scheduler, (DummyScheduler))
+ and scheduler.lr_scheduler_callable is None
+ ):
+ raise ValueError(
+ "Either specify a scheduler in the config file or "
+ "pass in the `lr_scheduler_callable` parameter when using `accelerate.utils.DummyScheduler`."
+ )
+
+ if optimizer is not None and scheduler is not None:
+ if isinstance(optimizer, (DummyOptim)) and not isinstance(scheduler, (DummyScheduler)):
+ raise ValueError(
+ "You can only specify `accelerate.utils.DummyScheduler` in the code when using "
+ "`accelerate.utils.DummyOptim`."
+ )
+
+ if model is not None:
+ # deal with config keys that use `auto` value and rely on model's hidden_size
+ hidden_size_based_keys = [
+ "zero_optimization.reduce_bucket_size",
+ "zero_optimization.stage3_prefetch_bucket_size",
+ "zero_optimization.stage3_param_persistence_threshold",
+ ]
+ hidden_size_auto_keys = [x for x in hidden_size_based_keys if deepspeed_plugin.is_auto(x)]
+ if len(hidden_size_auto_keys) > 0:
+ reasoning = (
+ "therefore it's not possible to automatically fill out the following `auto` entries "
+ + f"in the DeepSpeed config file: {hidden_size_auto_keys}. You can fix that by replacing "
+ + "`auto` values for these keys with an integer value of your choice."
+ )
+ if not hasattr(model, "config"):
+ raise ValueError("Can't find `model.config` entry, " + reasoning)
+
+ if hasattr(model.config, "hidden_size"):
+ hidden_size = model.config.hidden_size
+ elif hasattr(model.config, "hidden_sizes"):
+ # if there are many hidden sizes pick the largest one
+ hidden_size = max(model.config.hidden_sizes)
+ else:
+ raise ValueError(
+ "Can find neither `model.config.hidden_size` nor `model.config.hidden_sizes`, " + reasoning
+ )
+
+ config_kwargs.update(
+ {
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
+ }
+ )
+
+ if isinstance(optimizer, (DummyOptim)):
+ config_kwargs.update(
+ {"optimizer.params.lr": optimizer.lr, "optimizer.params.weight_decay": optimizer.weight_decay}
+ )
+ if isinstance(scheduler, (DummyScheduler)) and scheduler.lr_scheduler_callable is None:
+ max_lr = (
+ getattr(scheduler.optimizer, "lr", None)
+ if getattr(scheduler.optimizer, "defaults", None) is None
+ else scheduler.optimizer.defaults["lr"]
+ )
+ config_kwargs.update(
+ {
+ "scheduler.params.warmup_min_lr": 0,
+ "scheduler.params.warmup_max_lr": max_lr,
+ "scheduler.params.warmup_num_steps": scheduler.warmup_num_steps,
+ }
+ )
+ if scheduler.total_num_steps is not None:
+ config_kwargs["scheduler.params.total_num_steps"] = (
+ math.ceil(scheduler.total_num_steps / self.num_processes)
+ if not self.split_batches
+ else scheduler.total_num_steps
+ )
+ deepspeed_plugin.deepspeed_config_process(must_match=False, **config_kwargs)
+ self.deepspeed_config = deepspeed_plugin.deepspeed_config
+ kwargs = dict(model=model, config_params=self.deepspeed_config)
+ if optimizer is not None:
+ if isinstance(optimizer, (DummyOptim)):
+ kwargs["model_parameters"] = optimizer.params
+ if isinstance(scheduler, (DummyScheduler)) and scheduler.lr_scheduler_callable is not None:
+ kwargs["lr_scheduler"] = scheduler.lr_scheduler_callable
+ else:
+ if self.deepspeed_config["zero_optimization"].get("offload_optimizer", {}).get(
+ "device", "none"
+ ) != "none" and self.deepspeed_config.get("zero_force_ds_cpu_optimizer", True):
+ from deepspeed.ops.adam import DeepSpeedCPUAdam
+
+ defaults = {k: v for k, v in optimizer.defaults.items() if k in ["lr", "weight_decay"]}
+ optimizer = DeepSpeedCPUAdam(optimizer.param_groups, **defaults)
+ kwargs["optimizer"] = optimizer
+ if scheduler is not None:
+ if type(scheduler).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES:
+ kwargs["lr_scheduler"] = scheduler
+
+ engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)
+ if optimizer is not None:
+ optimizer = DeepSpeedOptimizerWrapper(optimizer)
+ if scheduler is not None:
+ if lr_scheduler is None:
+ scheduler = AcceleratedScheduler(
+ scheduler,
+ optimizer,
+ step_with_optimizer=self.step_scheduler_with_optimizer,
+ split_batches=self.split_batches,
+ )
+ else:
+ scheduler = DeepSpeedSchedulerWrapper(lr_scheduler, optimizer)
+
+ for i in range(len(result)):
+ if isinstance(result[i], torch.nn.Module):
+ result[i] = engine
+ elif isinstance(result[i], (torch.optim.Optimizer, DummyOptim)):
+ result[i] = optimizer
+ elif (isinstance(result[i], (LRScheduler, DummyScheduler))) or (
+ type(result[i]).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES
+ ):
+ result[i] = scheduler
+ # pointing for deepspeed_engine_wrapped.backward()
+ self.deepspeed_engine_wrapped = DeepSpeedEngineWrapper(engine)
+ self._models.append(engine)
+ if optimizer is not None:
+ self._optimizers.append(optimizer)
+ if scheduler is not None:
+ self._schedulers.append(scheduler)
+ if len(self._models) > 1:
+ raise AssertionError(
+ "You can't use same `Accelerator()` instance with multiple models when using DeepSpeed"
+ )
+ return tuple(result)
+
+ def _prepare_megatron_lm(self, *args):
+ megatron_lm_plugin = self.state.megatron_lm_plugin
+ if not megatron_lm_plugin.megatron_dataset_flag:
+ batch_sizes = [obj.batch_size for obj in args if hasattr(obj, "batch_size")]
+ if len(batch_sizes) == 0:
+ raise ValueError(
+ "You must specify a training or evaluation dataloader in `accelerate.prepare()` when using Megatron-LM."
+ )
+
+ micro_batch_size = min(batch_sizes) if megatron_lm_plugin.is_train_batch_min else max(batch_sizes)
+ if len(batch_sizes) > 1:
+ logger.info(
+ "Since you passed both train and evaluation dataloader, `is_train_batch_min` (here "
+ f"{megatron_lm_plugin.is_train_batch_min} will decide the `train_batch_size` ({micro_batch_size})."
+ )
+ else:
+ for obj in args:
+ if isinstance(obj, MegatronLMDummyDataLoader):
+ micro_batch_size = obj.dataset_args["micro_batch_size"]
+ break
+
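+ # The data-parallel degree is what remains of the world size after the tensor-parallel
+ # and pipeline-parallel splits.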
+ dp_degree = self.num_processes // (megatron_lm_plugin.tp_degree * megatron_lm_plugin.pp_degree)
+ megatron_lm_plugin.set_training_args(micro_batch_size, dp_degree)
+
+ model = None
+ optimizer = None
+ scheduler = None
+ is_dummy_scheduler = False
+ batch_data = None
+ for obj in args:
+ if isinstance(obj, torch.utils.data.DataLoader) and batch_data is None:
+ batch_data = next(iter(obj))
+ if isinstance(obj, torch.nn.Module):
+ model = obj
+ elif isinstance(obj, (torch.optim.Optimizer)):
+ optimizer = obj
+ elif isinstance(obj, (LRScheduler, MegatronLMDummyScheduler)):
+ scheduler = obj
+
+ if model is not None:
+ megatron_lm_plugin.set_network_size_args(model, batch_data)
+ if optimizer is not None:
+ megatron_lm_plugin.set_optimizer_type(optimizer)
+ if scheduler is not None:
+ is_dummy_scheduler = isinstance(scheduler, MegatronLMDummyScheduler)
+ if not is_dummy_scheduler:
+ raise ValueError(
+ "You can't use a custom scheduler with Megatron-LM. Please use the `accelerate.utils.MegatronLMDummyScheduler` instead."
+ )
+ megatron_lm_plugin.set_scheduler_args(scheduler)
+
+ # initialize megatron-lm
+ megatron_lm_initialize(self, args_defaults=megatron_lm_plugin.megatron_lm_default_args)
+ counter = 0
+ result = []
+ for obj in args:
+ if isinstance(obj, torch.utils.data.DataLoader):
+ result.append(megatron_lm_prepare_data_loader(self, obj))
+ counter += 1
+ elif isinstance(obj, MegatronLMDummyDataLoader):
+ if counter == 0:
+ obj.set_megatron_data_args()
+ dataloaders = megatron_lm_prepare_data_loader(self, obj)
+ result.append(dataloaders[counter])
+ counter += 1
+ else:
+ result.append(obj)
+
+ if model is not None:
+ model = megatron_lm_prepare_model(self)
+ if optimizer is not None:
+ optimizer = megatron_lm_prepare_optimizer(self, model)
+ if scheduler is not None:
+ scheduler = megatron_lm_prepare_scheduler(self, optimizer, scheduler)
+
+ if model is not None:
+ model = MegatronEngine(self, model, optimizer, scheduler)
+ if optimizer is not None:
+ optimizer = MegatronLMOptimizerWrapper(optimizer)
+ if scheduler is not None:
+ scheduler = MegatronLMSchedulerWrapper(scheduler, optimizer)
+
+ for i in range(len(result)):
+ if isinstance(result[i], torch.nn.Module):
+ result[i] = model
+ elif isinstance(result[i], torch.optim.Optimizer):
+ result[i] = optimizer
+ elif isinstance(result[i], MegatronLMDummyScheduler):
+ result[i] = scheduler
+ if model is not None:
+ self._models.append(model)
+ if optimizer is not None:
+ self._optimizers.append(optimizer)
+ if scheduler is not None:
+ self._schedulers.append(scheduler)
+ if len(self._models) > 1:
+ raise AssertionError(
+ "You can't use same `Accelerator()` instance with multiple models when using Megatron-LM"
+ )
+ return tuple(result)
+
+ def _prepare_ipex(self, *args):
+ if not is_ipex_available():
+ raise ImportError(
+ "IPEX is not installed or IPEX's version does not match current PyTorch version. Please refer"
+ " to https://github.com/intel/intel-extension-for-pytorch."
+ )
+ else:
+ import intel_extension_for_pytorch as ipex
+
+ model = None
+ optimizer = None
+ result = [obj for obj in args]
+ for obj in result:
+ if isinstance(obj, torch.nn.Module):
+ model = obj
+ model.train()
+ elif isinstance(obj, (torch.optim.Optimizer)):
+ optimizer = obj
+ if optimizer is not None and model is not None:
+ dtype = torch.bfloat16 if self.state.mixed_precision == "bf16" else None
+ if self.device.type == "xpu" and is_xpu_available():
+ model = model.to(self.device)
+ model, optimizer = torch.xpu.optimize(
+ model, optimizer=optimizer, dtype=dtype, inplace=True, level="O1"
+ )
+ else:
+ model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=dtype, inplace=True, level="O1")
+ for i in range(len(result)):
+ if isinstance(result[i], torch.nn.Module):
+ result[i] = model
+ elif isinstance(result[i], (torch.optim.Optimizer)):
+ result[i] = optimizer
+ return tuple(result)
+
+ def _prepare_msamp(self, *args):
+ if not is_msamp_available():
+ raise ImportError(
+ "MS-AMP was not found on your system. Please ensure that MS-AMP is available "
+ " or choose `'te'` as the backend for FP8 mixed precision training."
+ )
+ else:
+ import msamp
+
+ model, optimizer = None, None
+ num_models, num_optimizers = 0, 0
+ result = [obj for obj in args]
+ for obj in result:
+ if isinstance(obj, torch.nn.Module):
+ model = obj
+ num_models += 1
+ elif isinstance(obj, (torch.optim.Optimizer)):
+ optimizer = obj
+ num_optimizers += 1
+ if optimizer is None or model is None:
+ raise ValueError(
+ "You must pass a model and an optimizer together to `accelerate.prepare()` when using MS-AMP."
+ )
+ elif num_models > 1 or num_optimizers > 1:
+ raise ValueError(
+ f"You can't use multiple models ({num_models}) or optimizers {num_optimizers} with MS-AMP."
+ )
+ else:
+ model, optimizer = msamp.initialize(model, optimizer, opt_level=self.fp8_recipe_handler.opt_level)
+ for i in range(len(result)):
+ if isinstance(result[i], torch.nn.Module):
+ result[i] = model
+ elif isinstance(result[i], (torch.optim.Optimizer)):
+ result[i] = optimizer
+ return tuple(result)
+
+ def prepare_data_loader(
+ self, data_loader: torch.utils.data.DataLoader, device_placement=None, slice_fn_for_dispatch=None
+ ):
+ """
+ Prepares a PyTorch DataLoader for training in any distributed setup. It is recommended to use
+ [`Accelerator.prepare`] instead.
+
+ Args:
+ data_loader (`torch.utils.data.DataLoader`):
+ A vanilla PyTorch DataLoader to prepare
+ device_placement (`bool`, *optional*):
+ Whether or not to place the batches on the proper device in the prepared dataloader. Will default to
+ `self.device_placement`.
+ slice_fn_for_dispatch (`Callable`, *optional*):
+ If passed, this function will be used to slice tensors across `num_processes`. Will default to
+ [`~utils.slice_tensors`]. This argument is used only when `dispatch_batches` is set to `True` and will
+ be ignored otherwise.
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> data_loader = torch.utils.data.DataLoader(...)
+ >>> data_loader = accelerator.prepare_data_loader(data_loader, device_placement=True)
+ ```
+ """
+ # Ensure we can't double wrap a DataLoader due to `find_batch_size`
+ if getattr(data_loader, "_is_accelerate_prepared", False):
+ if data_loader not in self._dataloaders:
+ self._dataloaders.append(data_loader)
+ return data_loader
+ if device_placement is None:
+ device_placement = self.device_placement if self.distributed_type != DistributedType.XLA else False
+ prepared_data_loader = prepare_data_loader(
+ data_loader,
+ self.device,
+ num_processes=self.num_processes,
+ process_index=self.process_index,
+ split_batches=self.split_batches,
+ put_on_device=device_placement,
+ rng_types=self.rng_types.copy(),
+ dispatch_batches=self.dispatch_batches,
+ even_batches=self.even_batches,
+ slice_fn_for_dispatch=slice_fn_for_dispatch,
+ use_seedable_sampler=self.use_seedable_sampler,
+ )
+ self._dataloaders.append(prepared_data_loader)
+ return prepared_data_loader
+
+ def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement=None):
+ """
+ Prepares a PyTorch Optimizer for training in any distributed setup. It is recommended to use
+ [`Accelerator.prepare`] instead.
+
+ Args:
+ optimizer (`torch.optim.Optimizer`):
+ A vanilla PyTorch optimizer to prepare
+ device_placement (`bool`, *optional*):
+ Whether or not to place the optimizer on the proper device. Will default to `self.device_placement`.
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> optimizer = torch.optim.Adam(...)
+ >>> optimizer = accelerator.prepare_optimizer(optimizer, device_placement=True)
+ ```
+ """
+ # Ensure we can't double wrap an optimizer due to `find_batch_size`
+ if getattr(optimizer, "_is_accelerate_prepared", False):
+ if optimizer not in self._optimizers:
+ self._optimizers.append(optimizer)
+ return optimizer
+ if device_placement is None:
+ device_placement = self.device_placement
+ optimizer = AcceleratedOptimizer(optimizer, device_placement=device_placement, scaler=self.scaler)
+ self._optimizers.append(optimizer)
+ return optimizer
+
+ def prepare_scheduler(self, scheduler: LRScheduler):
+ """
+ Prepares a PyTorch Scheduler for training in any distributed setup. It is recommended to use
+ [`Accelerator.prepare`] instead.
+
+ Args:
+ scheduler (`torch.optim.lr_scheduler.LRScheduler`):
+ A vanilla PyTorch scheduler to prepare
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> optimizer = torch.optim.Adam(...)
+ >>> scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, ...)
+ >>> scheduler = accelerator.prepare_scheduler(scheduler)
+ ```
+ """
+ # Ensure we can't double wrap a scheduler due to `find_batch_size`
+ if getattr(scheduler, "_is_accelerate_prepared", False):
+ if scheduler not in self._schedulers:
+ self._schedulers.append(scheduler)
+ return scheduler
+ # We try to find the optimizer associated with `scheduler`, the default is the full list.
+ optimizer = self._optimizers
+ for opt in self._optimizers:
+ if getattr(scheduler, "optimizer", None) == opt.optimizer:
+ optimizer = opt
+ break
+ scheduler = AcceleratedScheduler(
+ scheduler,
+ optimizer,
+ step_with_optimizer=self.step_scheduler_with_optimizer,
+ split_batches=self.split_batches,
+ )
+ self._schedulers.append(scheduler)
+ return scheduler
+
+ def backward(self, loss, **kwargs):
+ """
+ Scales the gradients in accordance to the `GradientAccumulationPlugin` and calls the correct `backward()` based
+ on the configuration.
+
+ Should be used in lieu of `loss.backward()`.
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator(gradient_accumulation_steps=2)
+ >>> outputs = model(inputs)
+ >>> loss = loss_fn(outputs, labels)
+ >>> accelerator.backward(loss)
+ ```
+ """
+ if self.distributed_type != DistributedType.DEEPSPEED:
+ # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
+ loss = loss / self.gradient_accumulation_steps
+ if self.distributed_type == DistributedType.DEEPSPEED:
+ self.deepspeed_engine_wrapped.backward(loss, **kwargs)
+ elif self.distributed_type == DistributedType.MEGATRON_LM:
+ return
+ elif self.scaler is not None:
+ self.scaler.scale(loss).backward(**kwargs)
+ else:
+ loss.backward(**kwargs)
+
+ def set_trigger(self):
+ """
+ Sets the internal trigger tensor to 1 on the current process. A later check with [`Accelerator.check_trigger`]
+ should follow, which will check across all processes.
+
+ Note:
+ Does not require `wait_for_everyone()`
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> # Assume later in the training script
+ >>> # `should_do_breakpoint` is a custom function to monitor when to break,
+ >>> # e.g. when the loss is NaN
+ >>> if should_do_breakpoint(loss):
+ ... accelerator.set_trigger()
+ >>> # Assume later in the training script
+ >>> if accelerator.check_trigger():
+ ... break
+ ```
+ """
+ self.flag_tensor = torch.tensor(1, device=self.device)
+
+ def check_trigger(self):
+ """
+ Checks if the internal trigger tensor has been set to 1 in any of the processes. If so, will return `True` and
+ reset the trigger tensor to 0.
+
+ Note:
+ Does not require `wait_for_everyone()`
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> # Assume later in the training script
+ >>> # `should_do_breakpoint` is a custom function to monitor when to break,
+ >>> # e.g. when the loss is NaN
+ >>> if should_do_breakpoint(loss):
+ ... accelerator.set_trigger()
+ >>> # Assume later in the training script
+ >>> if accelerator.check_trigger():
+ ... break
+ ```
+ """
+ # Now that we are outside `__init__`, we can initialize it if it is `None` on device
+ if self.flag_tensor is None:
+ self.flag_tensor = torch.tensor(0, device=self.device)
+ flag_tensor = self.reduce(self.flag_tensor)
+ if flag_tensor.item() >= 1:
+ self.flag_tensor = torch.tensor(0, device=self.device)
+ return True
+ return False
+
+ def unscale_gradients(self, optimizer=None):
+ """
+ Unscale the gradients in mixed precision training with AMP. This is a noop in all other settings.
+
+ Likely should be called through [`Accelerator.clip_grad_norm_`] or [`Accelerator.clip_grad_value_`]
+
+ Args:
+ optimizer (`torch.optim.Optimizer` or `list[torch.optim.Optimizer]`, *optional*):
+ The optimizer(s) for which to unscale gradients. If not set, will unscale gradients on all optimizers
+ that were passed to [`~Accelerator.prepare`].
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> model, optimizer = accelerator.prepare(model, optimizer)
+ >>> outputs = model(inputs)
+ >>> loss = loss_fn(outputs, labels)
+ >>> accelerator.backward(loss)
+ >>> accelerator.unscale_gradients(optimizer=optimizer)
+ ```
+ """
+ if self.native_amp and self.mixed_precision == "fp16":
+ if optimizer is None:
+ # TODO: this unscales all optimizers, but we should only unscale the one that holds the relevant parameters.
+ optimizer = self._optimizers
+ elif not isinstance(optimizer, (tuple, list)):
+ optimizer = [optimizer]
+ for opt in optimizer:
+ while isinstance(opt, AcceleratedOptimizer):
+ opt = opt.optimizer
+ self.scaler.unscale_(opt)
+
+ def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
+ """
+ Should be used in place of `torch.nn.utils.clip_grad_norm_`.
+
+ Returns:
+ `torch.Tensor`: Total norm of the parameter gradients (viewed as a single vector).
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator(gradient_accumulation_steps=2)
+ >>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)
+
+ >>> for input, target in dataloader:
+ ... optimizer.zero_grad()
+ ... output = model(input)
+ ... loss = loss_func(output, target)
+ ... accelerator.backward(loss)
+ ... if accelerator.sync_gradients:
+ ... accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
+ ... optimizer.step()
+ ```
+ """
+ if self.distributed_type == DistributedType.FSDP:
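+ # With FSDP, clipping must go through the wrapped model's own `clip_grad_norm_` so that sharded
+ # gradients are handled correctly; find the prepared model that owns these parameters.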
+ self.unscale_gradients()
+ parameters = [p for p in parameters]
+ for model in self._models:
+ if parameters == [p for p in model.parameters()]:
+ return model.clip_grad_norm_(max_norm, norm_type)
+ elif self.distributed_type == DistributedType.DEEPSPEED:
+ # `accelerator.backward(loss)` already handles gradient clipping for DeepSpeed, so no implementation
+ # is needed here. We also cannot return the gradient norm because DeepSpeed manages it internally.
+ return None
+ elif self.distributed_type == DistributedType.XLA:
+ # Reduce gradients first for XLA
+ for acc_opt in self._optimizers:
+ if not acc_opt.gradient_state.is_xla_gradients_synced:
+ opt = acc_opt
+ while isinstance(opt, AcceleratedOptimizer):
+ opt = opt.optimizer
+ gradients = xm._fetch_gradients(opt)
+ # Use xm.all_reduce to perform an in-place all-reduce. Recursively all-reducing the tensors
+ # one by one via self.reduce is not in-place.
+ xm.all_reduce("sum", gradients, scale=1.0 / self.num_processes)
+ # Set is_xla_gradients_synced to True to avoid all-reduce twice in the AcceleratedOptimizer step.
+ acc_opt.gradient_state.is_xla_gradients_synced = True
+ self.unscale_gradients()
+ return torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type)
+
+ def clip_grad_value_(self, parameters, clip_value):
+ """
+ Should be used in place of `torch.nn.utils.clip_grad_value_`.
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator(gradient_accumulation_steps=2)
+ >>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)
+
+ >>> for input, target in dataloader:
+ ... optimizer.zero_grad()
+ ... output = model(input)
+ ... loss = loss_func(output, target)
+ ... accelerator.backward(loss)
+ ... if accelerator.sync_gradients:
+ ... accelerator.clip_grad_value_(model.parameters(), clip_value)
+ ... optimizer.step()
+ ```
+ """
+ if self.distributed_type in [DistributedType.DEEPSPEED, DistributedType.FSDP]:
+ raise Exception("DeepSpeed and FSDP do not support `clip_grad_value_`. Use `clip_grad_norm_` instead.")
+ self.unscale_gradients()
+ torch.nn.utils.clip_grad_value_(parameters, clip_value)
+
+ def gather(self, tensor):
+ """
+ Gather the values in *tensor* across all processes and concatenate them on the first dimension. Useful to
+ regroup the predictions from all processes when doing evaluation.
+
+ Note:
+ This gather happens in all processes.
+
+ Args:
+ tensor (`torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`):
+ The tensors to gather across all processes.
+
+ Returns:
+ `torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`: The gathered tensor(s). Note that the
+ first dimension of the result is *num_processes* multiplied by the first dimension of the input tensors.
+
+ Example:
+
+ ```python
+ >>> # Assuming four processes
+ >>> import torch
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> process_tensor = torch.tensor([accelerator.process_index])
+ >>> gathered_tensor = accelerator.gather(process_tensor)
+ >>> gathered_tensor
+ tensor([0, 1, 2, 3])
+ ```
+ """
+ return gather(tensor)
+
+ def gather_for_metrics(self, input_data):
+ """
+ Gathers `input_data` and potentially drops duplicates in the last batch if on a distributed system. Should be
+ used for gathering the inputs and targets for metric calculation.
+
+ Args:
+ input_data (`torch.Tensor`, `object`, a nested tuple/list/dictionary of `torch.Tensor`, or a nested tuple/list/dictionary of `object`):
+ The tensors or objects for calculating metrics across all processes
+
+ Example:
+
+ ```python
+ >>> # Assuming two processes, with a batch size of 5 on a dataset with 9 samples
+ >>> import torch
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> dataloader = torch.utils.data.DataLoader(range(9), batch_size=5)
+ >>> dataloader = accelerator.prepare(dataloader)
+ >>> batch = next(iter(dataloader))
+ >>> gathered_items = accelerator.gather_for_metrics(batch)
+ >>> len(gathered_items)
+ 9
+ ```
+ """
+
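+ # Detect whether `input_data` is composed purely of tensors; anything else falls back to
+ # `gather_object` instead of the tensor-based `gather`.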
+ try:
+ recursively_apply(lambda x: x, input_data, error_on_other_type=True)
+ all_tensors = True
+ except TypeError:
+ all_tensors = False
+
+ if not all_tensors:
+ data = gather_object(input_data)
+ else:
+ data = self.gather(input_data)
+
+ try:
+ if self.gradient_state.end_of_dataloader:
+ # At the end of a dataloader, `gather_for_metrics` regresses to
+ # `gather` unless the dataset has a remainder, so log a note when that happens.
+ if self.gradient_state.remainder == -1:
+ logger.info(
+ "The used dataset had no length, returning gathered tensors. You should drop the remainder yourself."
+ )
+ return data
+ elif self.gradient_state.remainder > 0:
+ # Last batch needs to be truncated on distributed systems as it contains additional samples
+ def _adjust_samples(tensor):
+ return tensor[: self.gradient_state.remainder]
+
+ return recursively_apply(_adjust_samples, data)
+ else: # remainder is 0
+ # no remainder even though at end of dataloader, so nothing to do.
+ return data
+ else:
+ # Not at the end of the dataloader, no need to adjust the tensors
+ return data
+ except Exception:
+ # Dataset had no length or raised an error
+ return data
+
+ def reduce(self, tensor, reduction="sum", scale=1.0):
+ """
+ Reduce the values in *tensor* across all processes based on *reduction*.
+
+ Note:
+ All processes get the reduced value.
+
+ Args:
+ tensor (`torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`):
+ The tensors to reduce across all processes.
+ reduction (`str`, *optional*, defaults to "sum"):
+ A reduction type, can be one of 'sum', 'mean', or 'none'. If 'none', will not perform any operation.
+ scale (`float`, *optional*, defaults to 1.0):
+ A default scaling value to be applied after the reduce, only valid on XLA.
+
+ Returns:
+ `torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`:
+ The reduced tensor(s).
+
+ Example:
+
+ ```python
+ >>> # Assuming two processes
+ >>> import torch
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> process_tensor = torch.arange(accelerator.num_processes) + 1 + (2 * accelerator.process_index)
+ >>> process_tensor = process_tensor.to(accelerator.device)
+ >>> reduced_tensor = accelerator.reduce(process_tensor, reduction="sum")
+ >>> reduced_tensor
+ tensor([4, 6])
+ ```
+ """
+ return reduce(tensor, reduction, scale)
+
+ def pad_across_processes(self, tensor, dim=0, pad_index=0, pad_first=False):
+ """
+ Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so
+ they can safely be gathered.
+
+ Args:
+ tensor (nested list/tuple/dictionary of `torch.Tensor`):
+ The data to gather.
+ dim (`int`, *optional*, defaults to 0):
+ The dimension on which to pad.
+ pad_index (`int`, *optional*, defaults to 0):
+ The value with which to pad.
+ pad_first (`bool`, *optional*, defaults to `False`):
+ Whether to pad at the beginning or the end.
+
+ Returns:
+ `torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`:
+ The padded tensor(s).
+
+ Example:
+
+ ```python
+ >>> # Assuming two processes, with the first processes having a tensor of size 1 and the second of size 2
+ >>> import torch
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> process_tensor = torch.arange(accelerator.process_index + 1).to(accelerator.device)
+ >>> padded_tensor = accelerator.pad_across_processes(process_tensor)
+ >>> padded_tensor.shape
+ torch.Size([2])
+ ```
+ """
+ return pad_across_processes(tensor, dim=dim, pad_index=pad_index, pad_first=pad_first)
+
+ def unwrap_model(self, model, keep_fp32_wrapper: bool = True):
+ """
+ Unwraps the `model` from the additional layer possibly added by [`~Accelerator.prepare`]. Useful before saving
+ the model.
+
+ Args:
+ model (`torch.nn.Module`):
+ The model to unwrap.
+ keep_fp32_wrapper (`bool`, *optional*, defaults to `True`):
+ Whether to keep the mixed precision hook (if one was added) instead of removing it.
+
+ Returns:
+ `torch.nn.Module`: The unwrapped model.
+
+ Example:
+
+ ```python
+ >>> # Assuming two GPU processes
+ >>> from torch.nn.parallel import DistributedDataParallel
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> model = accelerator.prepare(MyModel())
+ >>> print(model.__class__.__name__)
+ DistributedDataParallel
+
+ >>> model = accelerator.unwrap_model(model)
+ >>> print(model.__class__.__name__)
+ MyModel
+ ```
+ """
+ return extract_model_from_parallel(model, keep_fp32_wrapper)
+
+ def wait_for_everyone(self):
+ """
+ Will stop the execution of the current process until every other process has reached that point (so this does
+ nothing when the script is only run in one process). Useful to do before saving a model.
+
+ Example:
+
+ ```python
+ >>> # Assuming two GPU processes
+ >>> import time
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> if accelerator.is_main_process:
+ ... time.sleep(2)
+ ... else:
+ ... print("I'm waiting for the main process to finish its sleep...")
+ >>> accelerator.wait_for_everyone()
+ >>> # Should print on every process at the same time
+ >>> print("Everyone is here")
+ ```
+ """
+ wait_for_everyone()
+
+ @on_main_process
+ def init_trackers(self, project_name: str, config: dict | None = None, init_kwargs: dict | None = {}):
+ """
+ Initializes a run for all trackers stored in `self.log_with`, potentially with starting configurations
+
+ Args:
+ project_name (`str`):
+ The name of the project. All trackers will save their data based on this
+ config (`dict`, *optional*):
+ Optional starting configuration to be logged.
+ init_kwargs (`dict`, *optional*):
+ A nested dictionary of kwargs to be passed to a specific tracker's `__init__` function. Should be
+ formatted like so:
+ ```python
+ {"wandb": {"tags": ["tag_a", "tag_b"]}}
+ ```
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator(log_with="tensorboard")
+ >>> accelerator.init_trackers(
+ ... project_name="my_project",
+ ... config={"learning_rate": 0.001, "batch_size": 32},
+ ... init_kwargs={"tensorboard": {"flush_secs": 60}},
+ ... )
+ ```
+ """
+ for tracker in self.log_with:
+ if issubclass(type(tracker), GeneralTracker):
+ # Custom trackers are already initialized
+ self.trackers.append(tracker)
+ else:
+ tracker_init = LOGGER_TYPE_TO_CLASS[str(tracker)]
+ if tracker_init.requires_logging_directory:
+ # We can skip this check since it was done in `__init__`
+ self.trackers.append(
+ tracker_init(project_name, self.logging_dir, **init_kwargs.get(str(tracker), {}))
+ )
+ else:
+ self.trackers.append(tracker_init(project_name, **init_kwargs.get(str(tracker), {})))
+ if config is not None:
+ for tracker in self.trackers:
+ tracker.store_init_configuration(config)
+
+ def get_tracker(self, name: str, unwrap: bool = False):
+ """
+ Returns a `tracker` from `self.trackers` based on `name` on the main process only.
+
+ Args:
+ name (`str`):
+ The name of a tracker, corresponding to the `.name` property.
+ unwrap (`bool`):
+ Whether to return the internal tracking mechanism (`unwrap=True`) or the wrapped tracker itself
+ (`unwrap=False`, recommended).
+
+ Returns:
+ `GeneralTracker`: The tracker corresponding to `name` if it exists.
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator(log_with="tensorboard")
+ >>> accelerator.init_trackers("my_project")
+ >>> tensorboard_tracker = accelerator.get_tracker("tensorboard")
+ ```
+ """
+ if len(self.trackers) > 0:
+ for tracker in self.trackers:
+ if tracker.name == name:
+ return tracker.tracker if unwrap else tracker
+ raise ValueError(f"{name} is not an available tracker stored inside the `Accelerator`.")
+ # Handle tracker only made on main process
+ return GeneralTracker(_blank=True)
+
+ @on_main_process
+ def log(self, values: dict, step: int | None = None, log_kwargs: dict | None = {}):
+ """
+ Logs `values` to all stored trackers in `self.trackers` on the main process only.
+
+ Args:
+ values (`dict`):
+ Values should be a dictionary-like object containing only types `int`, `float`, or `str`.
+ step (`int`, *optional*):
+ The run step. If included, the log will be affiliated with this step.
+ log_kwargs (`dict`, *optional*):
+ A nested dictionary of kwargs to be passed to a specific tracker's `log` function. Should be formatted
+ like so:
+ ```python
+ {"wandb": {"tags": ["tag_a", "tag_b"]}}
+ ```
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator(log_with="tensorboard")
+ >>> accelerator.init_trackers("my_project")
+ >>> accelerator.log({"loss": 0.5, "accuracy": 0.9})
+ ```
+ """
+ for tracker in self.trackers:
+ tracker.log(values, step=step, **log_kwargs.get(tracker.name, {}))
+
+ @on_main_process
+ def end_training(self):
+ """
+ Runs any special end training behaviors, such as stopping trackers on the main process only. Should always be
+ called at the end of your script if using experiment tracking.
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator(log_with="tensorboard")
+ >>> accelerator.init_trackers("my_project")
+ >>> # Do training
+ >>> accelerator.end_training()
+ ```
+ """
+ for tracker in self.trackers:
+ tracker.finish()
+
+ def save(self, obj, f, safe_serialization=False):
+ """
+ Save the object passed to disk once per machine. Use in place of `torch.save`.
+
+ Args:
+ obj (`object`): The object to save.
+ f (`str` or `os.PathLike`): Where to save the content of `obj`.
+ safe_serialization (`bool`, *optional*, defaults to `False`): Whether to save `obj` using `safetensors`
+
+ Note:
+ If `save_on_each_node` was set in the `ProjectConfiguration`, the object will be saved once per node,
+ rather than only once on the main node.
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> arr = [0, 1, 2, 3]
+ >>> accelerator.save(arr, "array.pkl")
+ ```
+ """
+ save(
+ obj,
+ f,
+ save_on_each_node=self.project_configuration.save_on_each_node,
+ safe_serialization=safe_serialization,
+ )
+
+ def save_model(
+ self,
+ model: torch.nn.Module,
+ save_directory: Union[str, os.PathLike],
+ max_shard_size: Union[int, str] = "10GB",
+ safe_serialization: bool = True,
+ ):
+ """
+ Save a model so that it can be re-loaded using [`load_checkpoint_in_model`].
+
+ Arguments:
+ model (`torch.nn.Module`):
+ Model to be saved. The model can be wrapped or unwrapped.
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
+ The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be smaller
+ than this size. If expressed as a string, it needs to be digits followed by a unit (like `"5MB"`).
+
+
+
+ If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
+ which will be bigger than `max_shard_size`.
+
+
+
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> model = ...
+ >>> accelerator.save_model(model, save_directory)
+ ```
+ """
+
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ # get the state_dict of the model
+ if any(
+ [
+ module._hf_hook.offload
+ for module in model.modules()
+ if hasattr(module, "_hf_hook") and isinstance(module._hf_hook, AlignDevicesHook)
+ ]
+ ):
+ state_dict = get_state_dict_offloaded_model(model)
+ else:
+ if any(param.device == torch.device("meta") for param in model.parameters()):
+ raise RuntimeError("You can't save the model since some parameters are on the meta device.")
+ state_dict = self.get_state_dict(model)
+
+ if safe_serialization:
+ state_dict = clean_state_dict_for_safetensors(state_dict)
+ weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+
+ # Shard the model if it is too big.
+ shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
+
+ # Clean the folder from a previous save
+ for filename in os.listdir(save_directory):
+ full_filename = os.path.join(save_directory, filename)
+ # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
+ # in distributed settings to avoid race conditions.
+ weights_no_suffix = weights_name.replace(".bin", "")
+
+ # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
+ filename_no_suffix = filename.replace(".bin", "")
+ reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
+
+ if (
+ filename.startswith(weights_no_suffix)
+ and os.path.isfile(full_filename)
+ and filename not in shards.keys()
+ and reg.fullmatch(filename_no_suffix) is not None
+ and PartialState().is_main_process
+ ):
+ os.remove(full_filename)
+
+ # Save the model
+ for shard_file, shard in shards.items():
+ self.save(shard, os.path.join(save_directory, shard_file), safe_serialization=safe_serialization)
+
+ if index is None:
+ path_to_weights = os.path.join(save_directory, WEIGHTS_NAME)
+ logger.info(f"Model weights saved in {path_to_weights}")
+ else:
+ save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+ save_index_file = os.path.join(save_directory, save_index_file)
+ # Save the index as well
+ with open(save_index_file, "w", encoding="utf-8") as f:
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+ f.write(content)
+ logger.info(
+ f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+ f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}."
+ )
+
+ def register_save_state_pre_hook(self, hook: Callable[..., None]) -> hooks.RemovableHandle:
+ """
+ Registers a pre hook to be run before `save_checkpoint` is called in [`Accelerator.save_state`].
+
+ Args:
+ hook (`Callable`):
+ A function to be called in [`Accelerator.save_state`] before `save_checkpoint`.
+
+ The hook should have the following signature:
+
+ `hook(models: list[torch.nn.Module], weights: list[dict[str, torch.Tensor]], input_dir: str) -> None`
+
+ The `models` argument is the list of models saved in the accelerator state under `accelerator._models`, the `weights`
+ argument is the list of state dicts of those models, and the `input_dir` argument is the `output_dir` argument passed
+ to [`Accelerator.save_state`].
+
+
+
+ Should only be used in conjunction with [`Accelerator.register_load_state_pre_hook`]. Can be useful to save
+ configurations in addition to model weights. Can also be used to overwrite model saving with a customized
+ method. In this case, make sure to remove the weights you have already saved from the weights list.
+
+
+
+ Returns:
+ `torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling
+ `handle.remove()`
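+
+ Example (a minimal sketch; the hook body and file name below are illustrative):
+
+ ```python
+ >>> import os
+ >>> import torch
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> def save_extra_state_hook(models, weights, output_dir):
+ ...     # Illustrative only: persist an extra, hypothetical piece of state next to the checkpoint
+ ...     torch.save({"note": "extra state"}, os.path.join(output_dir, "extra_state.bin"))
+ >>> handle = accelerator.register_save_state_pre_hook(save_extra_state_hook)
+ >>> # The hook can later be removed with `handle.remove()`
+ ```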
+ """
+ handle = hooks.RemovableHandle(self._save_model_state_pre_hook)
+ self._save_model_state_pre_hook[handle.id] = hook
+ return handle
+
+ def save_state(self, output_dir: str = None, safe_serialization: bool = True, **save_model_func_kwargs):
+ """
+ Saves the current states of the model, optimizer, scaler, RNG generators, and registered objects to a folder.
+
+ If a `ProjectConfiguration` was passed to the `Accelerator` object with `automatic_checkpoint_naming` enabled
+ then checkpoints will be saved to `self.project_dir/checkpoints`. If the number of current saves is greater
+ than `total_limit` then the oldest save is deleted. Each checkpoint is saved in a separate folder named
+ `checkpoint_{save_iteration}`.
+
+ Otherwise they are just saved to `output_dir`.
+
+
+
+ Should only be used when wanting to save a checkpoint during training and restoring the state in the same
+ environment.
+
+
+
+ Args:
+ output_dir (`str` or `os.PathLike`):
+ The name of the folder to save all relevant weights and states.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+ save_model_func_kwargs (`dict`, *optional*):
+ Additional keyword arguments for saving model which can be passed to the underlying save function, such
+ as optional arguments for DeepSpeed's `save_checkpoint` function.
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> model, optimizer, lr_scheduler = ...
+ >>> model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
+ >>> accelerator.save_state(output_dir="my_checkpoint")
+ ```
+ """
+ if self.project_configuration.automatic_checkpoint_naming:
+ output_dir = os.path.join(self.project_dir, "checkpoints")
+ os.makedirs(output_dir, exist_ok=True)
+ if self.project_configuration.automatic_checkpoint_naming:
+ folders = [os.path.join(output_dir, folder) for folder in os.listdir(output_dir)]
+ if (
+ self.project_configuration.total_limit is not None
+ and (len(folders) + 1 > self.project_configuration.total_limit)
+ and self.is_main_process
+ ):
+
+ def _inner(folder):
+ return list(map(int, re.findall(r"[\/]?([0-9]+)(?=[^\/]*$)", folder)))[0]
+
+ folders.sort(key=_inner)
+ logger.warning(
+ f"Deleting {len(folders) + 1 - self.project_configuration.total_limit} checkpoints to make room for new checkpoint."
+ )
+ for folder in folders[: len(folders) + 1 - self.project_configuration.total_limit]:
+ shutil.rmtree(folder)
+ output_dir = os.path.join(output_dir, f"checkpoint_{self.save_iteration}")
+ if os.path.exists(output_dir):
+ raise ValueError(
+ f"Checkpoint directory {output_dir} ({self.save_iteration}) already exists. Please manually override `self.save_iteration` with what iteration to start with."
+ )
+ self.wait_for_everyone()
+ os.makedirs(output_dir, exist_ok=True)
+ logger.info(f"Saving current state to {output_dir}")
+
+ if self.distributed_type == DistributedType.XLA:
+ # Finish running the previous step before checkpointing
+ xm.mark_step()
+
+ # Save the models taking care of FSDP and DeepSpeed nuances
+ weights = []
+ for i, model in enumerate(self._models):
+ if self.distributed_type == DistributedType.FSDP:
+ logger.info("Saving FSDP model")
+ save_fsdp_model(self.state.fsdp_plugin, self, model, output_dir, i)
+ logger.info(f"FSDP Model saved to output dir {output_dir}")
+ elif self.distributed_type == DistributedType.DEEPSPEED:
+ logger.info("Saving DeepSpeed Model and Optimizer")
+ ckpt_id = f"{MODEL_NAME}" if i == 0 else f"{MODEL_NAME}_{i}"
+ model.save_checkpoint(output_dir, ckpt_id, **save_model_func_kwargs)
+ logger.info(f"DeepSpeed Model and Optimizer saved to output dir {os.path.join(output_dir, ckpt_id)}")
+ elif self.distributed_type == DistributedType.MEGATRON_LM:
+ logger.info("Saving Megatron-LM Model, Optimizer and Scheduler")
+ model.save_checkpoint(output_dir)
+ logger.info(f"Megatron-LM Model , Optimizer and Scheduler saved to output dir {output_dir}")
+ else:
+ weights.append(self.get_state_dict(model, unwrap=False))
+
+ # Save the optimizers taking care of FSDP and DeepSpeed nuances
+ optimizers = []
+ if self.distributed_type == DistributedType.FSDP:
+ for i, opt in enumerate(self._optimizers):
+ logger.info("Saving FSDP Optimizer")
+ save_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], output_dir, i)
+ logger.info(f"FSDP Optimizer saved to output dir {output_dir}")
+ elif self.distributed_type not in [DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:
+ optimizers = self._optimizers
+
+ # Save the lr schedulers taking care of DeepSpeed nuances
+ schedulers = []
+ if self.distributed_type == DistributedType.DEEPSPEED:
+ for i, scheduler in enumerate(self._schedulers):
+ if isinstance(scheduler, DeepSpeedSchedulerWrapper):
+ continue
+ schedulers.append(scheduler)
+ elif self.distributed_type not in [DistributedType.MEGATRON_LM]:
+ schedulers = self._schedulers
+
+ # Save the samplers of the dataloaders
+ dataloaders = self._dataloaders
+
+ # Call model saving hooks that might have been registered with
+ # accelerator.register_save_state_pre_hook
+ for hook in self._save_model_state_pre_hook.values():
+ hook(self._models, weights, output_dir)
+
+ save_location = save_accelerator_state(
+ output_dir,
+ weights,
+ optimizers,
+ schedulers,
+ dataloaders,
+ self.state.process_index,
+ self.scaler,
+ save_on_each_node=self.project_configuration.save_on_each_node,
+ safe_serialization=safe_serialization,
+ )
+ for i, obj in enumerate(self._custom_objects):
+ save_custom_state(obj, output_dir, i, save_on_each_node=self.project_configuration.save_on_each_node)
+ self.project_configuration.iteration += 1
+ return save_location
+
+ def register_load_state_pre_hook(self, hook: Callable[..., None]) -> hooks.RemovableHandle:
+ """
+ Registers a pre hook to be run before [`load_checkpoint`] is called in [`Accelerator.load_state`].
+
+ Args:
+ hook (`Callable`):
+ A function to be called in [`Accelerator.load_state`] before `load_checkpoint`.
+
+ The hook should have the following signature:
+
+ `hook(models: list[torch.nn.Module], input_dir: str) -> None`
+
+ The `models` argument is the list of models saved in the accelerator state under `accelerator._models`, and the
+ `input_dir` argument is the `input_dir` argument passed to [`Accelerator.load_state`].
+
+
+
+ Should only be used in conjunction with [`Accelerator.register_save_state_pre_hook`]. Can be useful to load
+ configurations in addition to model weights. Can also be used to overwrite model loading with a customized
+ method. In this case, make sure to remove already loaded models from the models list.
+
+
+
+ Returns:
+ `torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling
+ `handle.remove()`
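+
+ Example (a minimal sketch; the hook body and file name below are illustrative):
+
+ ```python
+ >>> import os
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> def check_extra_state_hook(models, input_dir):
+ ...     # Illustrative only: verify a hypothetical extra file exists before the state is loaded
+ ...     assert os.path.isfile(os.path.join(input_dir, "extra_state.bin"))
+ >>> handle = accelerator.register_load_state_pre_hook(check_extra_state_hook)
+ >>> # The hook can later be removed with `handle.remove()`
+ ```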
+ """
+ handle = hooks.RemovableHandle(self._load_model_state_pre_hook)
+ self._load_model_state_pre_hook[handle.id] = hook
+ return handle
+
+ def load_state(self, input_dir: str = None, **load_model_func_kwargs):
+ """
+ Loads the current states of the model, optimizer, scaler, RNG generators, and registered objects.
+
+
+
+ Should only be used in conjunction with [`Accelerator.save_state`]. If a file is not registered for
+ checkpointing, it will not be loaded if stored in the directory.
+
+
+
+ Args:
+ input_dir (`str` or `os.PathLike`):
+ The name of the folder all relevant weights and states were saved in. Can be `None` if
+ `automatic_checkpoint_naming` is used, and will pick up from the latest checkpoint.
+ load_model_func_kwargs (`dict`, *optional*):
+ Additional keyword arguments for loading model which can be passed to the underlying load function,
+ such as optional arguments for DeepSpeed's `load_checkpoint` function or a `map_location` to load the
+ model and optimizer on.
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> model, optimizer, lr_scheduler = ...
+ >>> model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
+ >>> accelerator.load_state("my_checkpoint")
+ ```
+ """
+ if input_dir is not None:
+ # Check if folder exists
+ input_dir = os.path.expanduser(input_dir)
+ if not os.path.isdir(input_dir):
+ raise ValueError(f"Tried to find {input_dir} but folder does not exist")
+ elif self.project_configuration.automatic_checkpoint_naming:
+ # Pick up from automatic checkpoint naming
+ input_dir = os.path.join(self.project_dir, "checkpoints")
+ folders = [os.path.join(input_dir, folder) for folder in os.listdir(input_dir)]
+
+ def _inner(folder):
+ return list(map(int, re.findall(r"[\/]?([0-9]+)(?=[^\/]*$)", folder)))[0]
+
+ folders.sort(key=_inner)
+ input_dir = folders[-1]
+ else:
+ raise ValueError("No input_dir provided and automatic checkpoint naming is disabled.")
+ logger.info(f"Loading states from {input_dir}")
+
+ # Load the models taking care of FSDP and DeepSpeed nuances
+ models = []
+ for i, model in enumerate(self._models):
+ if self.distributed_type == DistributedType.FSDP:
+ logger.info("Loading FSDP model")
+ load_fsdp_model(self.state.fsdp_plugin, self, model, input_dir, i)
+ logger.info(f"FSDP Model loaded from input dir {input_dir}")
+ elif self.distributed_type == DistributedType.DEEPSPEED:
+ logger.info("Loading DeepSpeed Model and Optimizer")
+ ckpt_id = f"{MODEL_NAME}" if i == 0 else f"{MODEL_NAME}_{i}"
+ model.load_checkpoint(input_dir, ckpt_id, **load_model_func_kwargs)
+ logger.info(f"DeepSpeed Model and Optimizer loaded from input dir {os.path.join(input_dir, ckpt_id)}")
+ elif self.distributed_type == DistributedType.MEGATRON_LM:
+ logger.info("Loading Megatron-LM Model, Optimizer and Scheduler")
+ model.load_checkpoint(input_dir)
+ logger.info(f"Megatron-LM Model , Optimizer and Scheduler loaded from input dir {input_dir}")
+ else:
+ models.append(model)
+
+ # Load the optimizers taking care of FSDP and DeepSpeed nuances
+ optimizers = []
+ if self.distributed_type == DistributedType.FSDP:
+ for i, opt in enumerate(self._optimizers):
+ logger.info("Loading FSDP Optimizer")
+ load_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], input_dir, i)
+ logger.info(f"FSDP Optimizer loaded from input dir {input_dir}")
+ elif self.distributed_type not in [DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:
+ optimizers = self._optimizers
+
+ # Load the lr schedulers taking care of DeepSpeed nuances
+ schedulers = []
+ if self.distributed_type == DistributedType.DEEPSPEED:
+ for i, scheduler in enumerate(self._schedulers):
+ if isinstance(scheduler, DeepSpeedSchedulerWrapper):
+ continue
+ schedulers.append(scheduler)
+ elif self.distributed_type not in [DistributedType.MEGATRON_LM]:
+ schedulers = self._schedulers
+
+ dataloaders = self._dataloaders
+
+ # Call model loading hooks that might have been registered with
+ # accelerator.register_load_state_pre_hook
+ for hook in self._load_model_state_pre_hook.values():
+ hook(models, input_dir)
+
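+ # Unless a `map_location` is passed explicitly, load states directly onto the device in
+ # multi-device runs, and onto CPU otherwise.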
+ map_location = load_model_func_kwargs.pop("map_location", None)
+ if map_location is None:
+ if self.num_processes > 1 and self.distributed_type in (
+ DistributedType.MULTI_GPU,
+ DistributedType.MULTI_MLU,
+ DistributedType.MULTI_NPU,
+ ):
+ map_location = "on_device"
+ else:
+ map_location = "cpu"
+
+ load_accelerator_state(
+ input_dir,
+ models,
+ optimizers,
+ schedulers,
+ dataloaders,
+ self.state.process_index,
+ self.scaler,
+ map_location,
+ **load_model_func_kwargs,
+ )
+ custom_checkpoints = [
+ f for f in os.listdir(input_dir) if re.search(r"^custom_checkpoint_\d+\.pkl$", f) is not None
+ ]
+ if len(custom_checkpoints) != len(self._custom_objects):
+ err = "Number of custom checkpoints in folder {input_dir} does not match the number of registered objects:"
+ err += f"\n\tFound checkpoints: {len(custom_checkpoints)}"
+ err += f"\n\tRegistered objects: {len(self._custom_objects)}\n"
+ err += "Please make sure to only load checkpoints from folders that were created with the same set of registered objects,"
+ err += "or avoid using `custom_checkpoint` in the filename for files in that same directory and load them in manually."
+ raise RuntimeError(err)
+ else:
+ logger.info(f"Loading in {len(custom_checkpoints)} custom states")
+ for index, obj in enumerate(self._custom_objects):
+ load_custom_state(obj, input_dir, index)
+
+ def free_memory(self):
+ """
+ Will release all references to the internal objects stored and call the garbage collector. You should call this
+ method between two trainings with different models/optimizers. Also will reset `Accelerator.step` to 0.
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> model, optimizer, scheduler = ...
+ >>> model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)
+ >>> accelerator.free_memory()
+ >>> del model, optimizer, scheduler
+ ```
+ """
+ self._schedulers = []
+ self._optimizers = []
+ self._models = []
+ self._dataloaders = []
+ self.deepspeed_engine_wrapped = None
+ self.step = 0
+ release_memory()
+
+ def clear(self):
+ """
+ Alias for [`Accelerator.free_memory`]; releases all references to the internal objects stored and calls the
+ garbage collector. You should call this method between two trainings with different models/optimizers.
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> model, optimizer, scheduler = ...
+ >>> model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)
+ >>> accelerator.clear()
+ >>> del model, optimizer, scheduler
+ ```
+ """
+ self.free_memory()
+
+ def _get_named_parameters(self, *args):
+ named_parameters = {}
+ for obj in args:
+ if isinstance(obj, torch.nn.Module):
+ obj = extract_model_from_parallel(obj)
+ named_parameters.update({n: p for n, p in obj.named_parameters()})
+ return named_parameters
+
+ def _get_devices(self, *args):
+ model_device = None
+ optimizer_device = None
+ for obj in args:
+ # Loop through model parameters and stop at the first once we have its device.
+ if isinstance(obj, torch.nn.Module):
+ for param in obj.parameters():
+ model_device = param.device
+ break
+ # Loop through optimizer parameters groups and stop at the first once we have its device.
+ if isinstance(obj, torch.optim.Optimizer):
+ for param_group in obj.param_groups:
+ if len(param_group["params"]) > 0:
+ optimizer_device = param_group["params"][0].device
+ break
+ return (model_device, optimizer_device)
+
+ def get_state_dict(self, model, unwrap=True):
+ """
+ Returns the state dictionary of a model sent through [`Accelerator.prepare`] potentially without full
+ precision.
+
+ Args:
+ model (`torch.nn.Module`):
+ A PyTorch model sent through [`Accelerator.prepare`]
+ unwrap (`bool`, *optional*, defaults to `True`):
+ Whether to return the original underlying state_dict of `model` or to return the wrapped state_dict
+
+ Returns:
+ `dict`: The state dictionary of the model potentially without full precision.
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> net = torch.nn.Linear(2, 2)
+ >>> net = accelerator.prepare(net)
+ >>> state_dict = accelerator.get_state_dict(net)
+ ```
+ """
+
+ if self.distributed_type == DistributedType.DEEPSPEED:
+ if self.deepspeed_config["zero_optimization"]["stage"] == 3:
+ if model.zero_gather_16bit_weights_on_model_save():
+ state_dict = model._zero3_consolidated_16bit_state_dict()
+ else:
+ raise ValueError(
+ "Cannot get 16bit model weights because `stage3_gather_16bit_weights_on_model_save` in DeepSpeed config is False. "
+ "To save the model weights in 16bit, set `stage3_gather_16bit_weights_on_model_save` to True in DeepSpeed config file or "
+ "set `zero3_save_16bit_model` to True when using `accelerate config`. "
+ "To save the full checkpoint, run `model.save_checkpoint(save_dir)` and use `zero_to_fp32.py` to recover weights."
+ )
+ else:
+ from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
+
+ state_dict = clone_tensors_for_torch_save(self.unwrap_model(model).state_dict())
+ elif self.distributed_type == DistributedType.FSDP:
+ from torch.distributed.fsdp import FullStateDictConfig, StateDictType
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+
+ full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+ with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_dict_config):
+ state_dict = model.state_dict()
+ else:
+ if unwrap:
+ model = self.unwrap_model(model)
+ state_dict = model.state_dict()
+
+ return state_dict
+
+ def register_for_checkpointing(self, *objects):
+ """
+ Makes note of `objects` and will save or load them during `save_state` or `load_state`.
+
+ These should be used when the state is being saved or loaded from within the same script; this mechanism is
+ not designed to be used across different scripts.
+
+ Every `object` must have a `load_state_dict` and `state_dict` function to be stored.
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> # Assume `CustomObject` has a `state_dict` and `load_state_dict` function.
+ >>> obj = CustomObject()
+ >>> accelerator.register_for_checkpointing(obj)
+ >>> accelerator.save_state("checkpoint.pt")
+ ```
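+
+ The registered object is restored the same way (sketch, reusing the hypothetical `CustomObject` and path from
+ above): calling [`Accelerator.load_state`] on the same checkpoint will invoke `obj.load_state_dict` for you:
+
+ ```python
+ >>> accelerator.load_state("checkpoint.pt")
+ ```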
+ """
+ invalid_objects = []
+ for obj in objects:
+ if not hasattr(obj, "state_dict") or not hasattr(obj, "load_state_dict"):
+ invalid_objects.append(obj)
+ if len(invalid_objects) > 0:
+ err = "All `objects` must include a `state_dict` and `load_state_dict` function to be stored. The following inputs are invalid:"
+ for index, obj in enumerate(invalid_objects):
+ err += f"\n\t- Item at index {index}, `{get_pretty_name(obj)}`"
+ raise ValueError(err)
+ self._custom_objects.extend(objects)
+
+ @contextmanager
+ def autocast(self, cache_enabled: bool = False, autocast_handler: AutocastKwargs = None):
+ """
+ Will apply automatic mixed precision inside the block wrapped by this context manager, if it is enabled.
+ Nothing different will happen otherwise.
+
+ A different `autocast_handler` can be passed in to override the one set in the `Accelerator` object. This is
+ useful in blocks under `autocast` where you want to revert to fp32.
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator(mixed_precision="fp16")
+ >>> with accelerator.autocast():
+ ... train()
+ ```
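+
+ To revert to full precision inside a specific block, a handler can be passed in explicitly (illustrative
+ sketch; assumes the default `AutocastKwargs` fields):
+
+ ```python
+ >>> from accelerate.utils import AutocastKwargs
+
+ >>> with accelerator.autocast(autocast_handler=AutocastKwargs(enabled=False)):
+ ... loss = model(batch) # runs without autocast, i.e. in full precision
+ ```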
+ """
+ if cache_enabled:
+ warnings.warn(
+ "Passing `cache_enabled=True` to `accelerator.autocast` is deprecated and will be removed in v0.23.0. "
+ "Please use the `AutocastKwargs` class instead and pass it to the `Accelerator` as a `kwarg_handler`.",
+ FutureWarning,
+ )
+ if self.autocast_handler is not None:
+ self.autocast_handler.cache_enabled = True
+ else:
+ self.autocast_handler = AutocastKwargs(cache_enabled=True)
+ if autocast_handler is None:
+ autocast_handler = self.autocast_handler
+ autocast_context = get_mixed_precision_context_manager(self.native_amp, autocast_handler)
+ autocast_context.__enter__()
+ # Wrap the `yield` in try/finally so the autocast context is exited even if the wrapped block raises.
+ try:
+ yield
+ finally:
+ autocast_context.__exit__(*sys.exc_info())
+
+ @property
+ def optimizer_step_was_skipped(self):
+ """
+ Whether or not the optimizer update was skipped (because of gradient overflow in mixed precision), in which
+ case the learning rate should not be changed.
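+
+ Example (sketch of the usual pattern; assumes `scheduler` was not prepared and is stepped manually):
+
+ ```python
+ >>> accelerator.backward(loss)
+ >>> optimizer.step()
+ >>> if not accelerator.optimizer_step_was_skipped:
+ ... scheduler.step()
+ ```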
+ """
+ for optimizer in self._optimizers:
+ if optimizer.step_was_skipped:
+ return True
+ return False
+
+ def skip_first_batches(self, dataloader, num_batches: int = 0):
+ """
+ Creates a new `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.
+
+ Args:
+ dataloader (`torch.utils.data.DataLoader`): The data loader in which to skip batches.
+ num_batches (`int`, *optional*, defaults to 0): The number of batches to skip.
+
+ Example:
+
+ ```python
+ >>> from accelerate import Accelerator
+
+ >>> accelerator = Accelerator()
+ >>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)
+ >>> skipped_dataloader = accelerator.skip_first_batches(dataloader, num_batches=2)
+ >>> # for the first epoch only
+ >>> for input, target in skipped_dataloader:
+ ... optimizer.zero_grad()
+ ... output = model(input)
+ ... loss = loss_func(output, target)
+ ... accelerator.backward(loss)
+ ... optimizer.step()
+
+ >>> # subsequent epochs
+ >>> for input, target in dataloader:
+ ... optimizer.zero_grad()
+ ... ...
+ ```
+ """
+ return skip_first_batches(dataloader, num_batches=num_batches)
+
+ def __deepcopy__(self, memo):
+ logger.info("Deep copying the `Accelerator` object, note that this will point to the same original object.")
+ return self
+
+ def verify_device_map(self, model: torch.nn.Module) -> bool:
+ """
+ Checks whether `model` has been prepared with big model inference, i.e. whether any of its modules carries an
+ `hf_device_map` with more than one entry (as produced by a device map resembling `auto`). Returns `True` if so.
+ """
+ # Checks if any of the child modules has the attribute `hf_device_map` and this map has more than one entry.
+ for m in model.modules():
+ if hasattr(m, "hf_device_map") and len(m.hf_device_map) > 1:
+ return True
+
+ return False