Accelerator
The Accelerator is the main class provided by 🤗 Accelerate. It serves as the main entry point for the API.
Quick adaptation of your code
To quickly adapt your script to work on any kind of setup with 🤗 Accelerate, just:
1. Initialize an Accelerator object (that we will call accelerator throughout this page) as early as possible in your script.
2. Pass your dataloader(s), model(s), optimizer(s), and scheduler(s) to the prepare() method.
3. Remove all the .cuda() or .to(device) calls from your code and let the accelerator handle the device placement for you.
Step three is optional, but considered a best practice.
4. Replace loss.backward() in your code with accelerator.backward(loss).
5. Gather your predictions and labels before storing them or using them for metric computation using gather().
Step five is mandatory when using distributed evaluation.
In most cases this is all that is needed. The next section lists a few more advanced use cases and nice features you should search for and replace by the corresponding methods of your accelerator.
Advanced recommendations
Printing
print statements should be replaced by accelerator.print() so that they are printed only once per server:
- print("My thing I want to print!")
+ accelerator.print("My thing I want to print!")
Executing processes
Once on a single server
For statements that should be executed once per server, use is_local_main_process:
if accelerator.is_local_main_process:
    do_thing_once_per_server()
A function can be wrapped using the on_local_main_process() decorator to achieve the same behavior on a function's execution:
@accelerator.on_local_main_process
def do_my_thing():
    "Something done once per server"
    do_thing_once_per_server()
Only ever once across all servers
For statements that should only ever be executed once, use is_main_process:
if accelerator.is_main_process:
    do_thing_once()
A function can be wrapped using the on_main_process() decorator to achieve the same behavior on a function's execution:
@accelerator.on_main_process
def do_my_thing():
    "Something done once across all servers"
    do_thing_once()
On specific processes
If a function should be run on a specific overall or local process index, there are similar decorators to achieve this:
@accelerator.on_local_process(local_process_idx=0)
def do_my_thing():
    "Something done on process index 0 on each server"
    do_thing_on_index_zero_on_each_server()

@accelerator.on_process(process_index=0)
def do_my_thing():
    "Something done on process index 0"
    do_thing_on_index_zero()
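Conceptually, these decorators only decide whether the function body runs, based on the index of the current process. The pure-Python sketch below models that idea; run_only_on and simulate are hypothetical names (not part of 🤗 Accelerate), and the loop merely pretends to be each process in turn, whereas real process indices come from the launcher.

```python
import functools


def run_only_on(target_index, current_index):
    """Hypothetical decorator factory: the wrapped function runs only when
    current_index equals target_index; otherwise the call returns None."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if current_index == target_index:
                return func(*args, **kwargs)
            return None
        return wrapper
    return decorator


def simulate(num_processes, target_index):
    """Pretend to be each process in turn and record which ones ran the body."""
    ran_on = []
    for current_index in range(num_processes):
        @run_only_on(target_index, current_index)
        def do_my_thing():
            ran_on.append(current_index)

        do_my_thing()
    return ran_on


print(simulate(num_processes=4, target_index=0))  # [0]
```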
Synchronicity control
Use wait_for_everyone() to make sure all processes reach that point before continuing. This is useful before saving a model, for instance.
Saving and loading
Use unwrap_model() before saving to remove all special model wrappers added during the distributed process.
model = MyModel()
model = accelerator.prepare(model)
# Unwrap
model = accelerator.unwrap_model(model)
Use save() instead of torch.save:
state_dict = model.state_dict()
- torch.save(state_dict, "my_state.pkl")
+ accelerator.save(state_dict, "my_state.pkl")
Operations
Use clip_grad_norm_() instead of torch.nn.utils.clip_grad_norm_ and clip_grad_value_() instead of torch.nn.utils.clip_grad_value_.
Gradient Accumulation
To perform gradient accumulation, use accumulate() and specify gradient_accumulation_steps when creating the Accelerator. This will also automatically ensure the gradients are synced or unsynced when on multi-device training, check if the step should actually be performed, and auto-scale the loss:
- accelerator = Accelerator()
+ accelerator = Accelerator(gradient_accumulation_steps=2)
  for (input, label) in training_dataloader:
+     with accelerator.accumulate(model):
          predictions = model(input)
          loss = loss_function(predictions, label)
          accelerator.backward(loss)
          optimizer.step()
          scheduler.step()
          optimizer.zero_grad()
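The automatic loss scaling is what keeps accumulation equivalent to training on a larger batch. As a toy check, assuming a one-parameter model with a squared-error loss and equally sized micro-batches, dividing each micro-batch's gradient by the number of accumulation steps and stepping once gives the same update as one step on the combined batch. The helpers grad, full_batch_step, and accumulated_step are hypothetical and unrelated to Accelerate's own code.

```python
def grad(w, x, y):
    """Derivative of the squared error (w * x - y) ** 2 with respect to w."""
    return 2 * x * (w * x - y)


def full_batch_step(w, batch, lr):
    """One SGD step on the mean loss over the whole batch."""
    return w - lr * sum(grad(w, x, y) for x, y in batch) / len(batch)


def accumulated_step(w, micro_batches, lr):
    """Accumulate gradients over equally sized micro-batches, dividing each
    micro-batch's mean-loss gradient by the number of accumulation steps,
    then take a single SGD step."""
    steps = len(micro_batches)  # plays the role of gradient_accumulation_steps
    accumulated = 0.0
    for batch in micro_batches:
        accumulated += sum(grad(w, x, y) for x, y in batch) / len(batch) / steps
    return w - lr * accumulated


data = [(1.0, 2.0), (2.0, 3.0), (3.0, 5.0), (4.0, 7.0)]
# Both lines print the same value (up to floating-point rounding)
print(full_batch_step(0.5, data, lr=0.01))
print(accumulated_step(0.5, [data[:2], data[2:]], lr=0.01))
```

If the micro-batches have different sizes, the two updates generally differ, because each micro-batch mean is weighted equally regardless of how many samples it contains.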
Overall API documentation:
class accelerate.Accelerator(device_placement: bool = True, split_batches: bool = False, mixed_precision: typing.Union[accelerate.utils.dataclasses.PrecisionType, str] = None, gradient_accumulation_steps: int = 1, cpu: bool = False, deepspeed_plugin: DeepSpeedPlugin = None, fsdp_plugin: FullyShardedDataParallelPlugin = None, megatron_lm_plugin: MegatronLMPlugin = None, rng_types: typing.Union[typing.List[typing.Union[str, accelerate.utils.dataclasses.RNGType]], NoneType] = None, log_with: typing.Union[typing.List[typing.Union[str, accelerate.utils.dataclasses.LoggerType, accelerate.tracking.GeneralTracker]], NoneType] = None, project_dir: typing.Union[str, os.PathLike, NoneType] = None, project_config: typing.Optional[accelerate.utils.dataclasses.ProjectConfiguration] = None, logging_dir: typing.Union[str, os.PathLike, NoneType] = None, dispatch_batches: typing.Optional[bool] = None, even_batches: bool = True, step_scheduler_with_optimizer: bool = True, kwargs_handlers: typing.Optional[typing.List[accelerate.utils.dataclasses.KwargsHandler]] = None, dynamo_backend: typing.Union[accelerate.utils.dataclasses.DynamoBackend, str] = None)
Parameters
- device_placement (bool, optional, defaults to True): Whether or not the accelerator should put objects on device (tensors yielded by the dataloader, model, etc.).
- split_batches (bool, optional, defaults to False): Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If True, the actual batch size used will be the same on any kind of distributed setup, but it must be a round multiple of the num_processes you are using. If False, the actual batch size used will be the one set in your script multiplied by the number of processes.
- mixed_precision (str, optional): Whether or not to use mixed precision training (fp16 or bfloat16). Choose from 'no', 'fp16', 'bf16'. Will default to the value in the environment variable ACCELERATE_MIXED_PRECISION, which will use the default value in the accelerate config of the current system or the flag passed with the accelerate.launch command. 'fp16' requires PyTorch 1.6 or higher. 'bf16' requires PyTorch 1.10 or higher.
- gradient_accumulation_steps (int, optional, defaults to 1): The number of steps over which gradients are accumulated before the optimizer actually updates the weights. A number > 1 should be combined with Accelerator.accumulate. If not passed, will default to the value in the environment variable ACCELERATE_GRADIENT_ACCUMULATION_STEPS.
- cpu (bool, optional): Whether or not to force the script to execute on CPU. If set to True, any available GPU is ignored and execution is forced onto one process only.
- deepspeed_plugin (DeepSpeedPlugin, optional): Tweak your DeepSpeed related args using this argument. This argument is optional and can be configured directly using accelerate config.
- fsdp_plugin (FullyShardedDataParallelPlugin, optional): Tweak your FSDP related args using this argument. This argument is optional and can be configured directly using accelerate config.
- megatron_lm_plugin (MegatronLMPlugin, optional): Tweak your MegatronLM related args using this argument. This argument is optional and can be configured directly using accelerate config.
- rng_types (list of str or RNGType): The list of random number generators to synchronize at the beginning of each iteration in your prepared dataloaders. Should be one or several of:
  - "torch": the base torch random number generator
  - "cuda": the CUDA random number generator (GPU only)
  - "xla": the XLA random number generator (TPU only)
  - "generator": the torch.Generator of the sampler (or batch sampler if there is no sampler in your dataloader) or of the iterable dataset (if it exists) if the underlying dataset is of that type.
  Will default to ["torch"] for PyTorch versions <=1.5.1 and ["generator"] for PyTorch versions >= 1.6.
- log_with (list of str, LoggerType or GeneralTracker, optional): A list of loggers to be set up for experiment tracking. Should be one or several of: "all", "tensorboard", "wandb", "comet_ml". If "all" is selected, will pick up all available trackers in the environment and initialize them. Can also accept implementations of GeneralTracker for custom trackers, and can be combined with "all".
- project_config (ProjectConfiguration, optional): A configuration for how saving the state can be handled.
- project_dir (str, os.PathLike, optional): A path to a directory for storing data such as logs of locally-compatible loggers and potentially saved checkpoints.
- dispatch_batches (bool, optional): If set to True, the dataloader prepared by the Accelerator is only iterated through on the main process and then the batches are split and broadcast to each process. Will default to True for a DataLoader whose underlying dataset is an IterableDataset, False otherwise.
- even_batches (bool, optional, defaults to True): If set to True, in cases where the total batch size across all processes does not exactly divide the dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among all workers.
- step_scheduler_with_optimizer (bool, optional, defaults to True): Set True if the learning rate scheduler is stepped at the same time as the optimizer, False if only done under certain circumstances (at the end of each epoch, for instance).
- kwargs_handlers (List[KwargsHandler], optional): A list of KwargsHandler to customize how the objects related to distributed training or mixed precision are created. See kwargs for more information.
- dynamo_backend (str or DynamoBackend, optional, defaults to "no"): Set to one of the possible dynamo backends to optimize your training with torch dynamo.
Creates an instance of an accelerator for distributed training (on multi-GPU, TPU) or mixed precision training.
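To make split_batches and even_batches more concrete, here is a simplified, pure-Python model of how sample indices could be distributed across processes. It is an assumption-laden sketch (shard_batches is a hypothetical helper), not the sampler 🤗 Accelerate actually uses, which also handles cases such as drop_last and iterable datasets. With two processes, a batch size of 5 and 9 samples, the second process's batch is completed with sample 0 from the start of the dataset.

```python
def shard_batches(num_samples, batch_size, num_processes,
                  split_batches=False, even_batches=True):
    """Hypothetical model: return, for each process, the list of index batches it sees."""
    if split_batches:
        # batch_size is the total batch size and is split across processes
        if batch_size % num_processes != 0:
            raise ValueError("batch_size must be a round multiple of num_processes")
        per_process = batch_size // num_processes
    else:
        # every process gets batch_size samples per step
        per_process = batch_size
    total = per_process * num_processes  # samples consumed per step across all processes

    indices = list(range(num_samples))
    if even_batches and num_samples % total != 0:
        # Complete the last step by reusing samples from the start of the dataset
        missing = total - num_samples % total
        indices += [i % num_samples for i in range(missing)]

    shards = [[] for _ in range(num_processes)]
    for start in range(0, len(indices), total):
        step = indices[start:start + total]
        for rank in range(num_processes):
            chunk = step[rank * per_process:(rank + 1) * per_process]
            if chunk:  # with even_batches=False a process may get nothing in the last step
                shards[rank].append(chunk)
    return shards


print(shard_batches(9, 5, 2))                      # [[[0, 1, 2, 3, 4]], [[5, 6, 7, 8, 0]]]
print(shard_batches(9, 5, 2, even_batches=False))  # [[[0, 1, 2, 3, 4]], [[5, 6, 7, 8]]]
print(shard_batches(8, 4, 2, split_batches=True))  # [[[0, 1], [4, 5]], [[2, 3], [6, 7]]]
```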
Available attributes:
- device (torch.device): The device to use.
- distributed_type (DistributedType): The distributed training configuration.
- local_process_index (int): The process index on the current machine.
- mixed_precision (str): The configured mixed precision mode.
- num_processes (int): The total number of processes used for training.
- optimizer_step_was_skipped (bool): Whether or not the optimizer update was skipped (because of gradient overflow in mixed precision), in which case the learning rate should not be changed.
- process_index (int): The overall index of the current process among all processes.
- state (AcceleratorState): The distributed setup state.
- sync_gradients (bool): Whether the gradients are currently being synced across all processes.
- use_distributed (bool): Whether the current configuration is for distributed training.
accumulate(model)
A context manager that will lightly wrap around the training step and perform gradient accumulation automatically.
Example:
>>> from accelerate import Accelerator
>>> accelerator = Accelerator(gradient_accumulation_steps=2)
>>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)
>>> for input, target in dataloader:
...     with accelerator.accumulate(model):
...         outputs = model(input)
...         loss = loss_func(outputs, target)
...         accelerator.backward(loss)
...         optimizer.step()
...         scheduler.step()
...         optimizer.zero_grad()
autocast
Will apply automatic mixed precision inside the block of this context manager, if it is enabled. Nothing different will happen otherwise.
backward
Scales the gradients in accordance with Accelerator.gradient_accumulation_steps and calls the correct backward() based on the configuration.
Should be used in lieu of loss.backward().
clear
Alias for Accelerator.free_memory: releases all references to the internal objects stored and calls the garbage collector. You should call this method between two trainings with different models/optimizers.
clip_grad_norm_(parameters, max_norm, norm_type=2) → torch.Tensor
Returns
torch.Tensor
Total norm of the parameter gradients (viewed as a single vector).
Should be used in place of torch.nn.utils.clip_grad_norm_.
Example:
>>> from accelerate import Accelerator
>>> accelerator = Accelerator(gradient_accumulation_steps=2)
>>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)
>>> for (input, target) in dataloader:
...     optimizer.zero_grad()
...     output = model(input)
...     loss = loss_func(output, target)
...     accelerator.backward(loss)
...     if accelerator.sync_gradients:
...         accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
...     optimizer.step()
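clip_grad_norm_ is meant as a replacement for torch.nn.utils.clip_grad_norm_, which treats all gradients as a single vector, computes its norm, and rescales every gradient by the same factor when that norm exceeds max_norm. The sketch below reproduces that arithmetic on plain lists of floats; clip_by_total_norm is a hypothetical name, and the 1e-6 epsilon is an assumption modeled on PyTorch's implementation.

```python
import math


def clip_by_total_norm(grads, max_norm, norm_type=2.0):
    """Scale a list of gradient vectors in place so that their combined
    norm_type-norm is at most max_norm. Returns the norm before clipping."""
    if norm_type == math.inf:
        total_norm = max(abs(g) for vec in grads for g in vec)
    else:
        total_norm = sum(abs(g) ** norm_type for vec in grads for g in vec) ** (1.0 / norm_type)
    # Small epsilon avoids division by zero (assumed to match PyTorch's 1e-6)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for vec in grads:
            for i in range(len(vec)):
                vec[i] *= clip_coef
    return total_norm


grads = [[3.0], [4.0]]  # two "parameters"; their combined L2 norm is 5
print(clip_by_total_norm(grads, max_norm=1.0))  # 5.0
print(grads)  # roughly [[0.6], [0.8]]
```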
clip_grad_value_
Should be used in place of torch.nn.utils.clip_grad_value_.
Example:
>>> from accelerate import Accelerator
>>> accelerator = Accelerator(gradient_accumulation_steps=2)
>>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)
>>> for (input, target) in dataloader:
...     optimizer.zero_grad()
...     output = model(input)
...     loss = loss_func(output, target)
...     accelerator.backward(loss)
...     if accelerator.sync_gradients:
...         accelerator.clip_grad_value_(model.parameters(), clip_value)
...     optimizer.step()
end_training
Runs any special end-of-training behaviors, such as stopping trackers on the main process only. Should always be called at the end of your script if using experiment tracking.
free_memory
Will release all references to the internal objects stored and call the garbage collector. You should call this method between two trainings with different models/optimizers.
gather(tensor) → torch.Tensor, or a nested tuple/list/dictionary of torch.Tensor
Parameters
- tensor (torch.Tensor, or a nested tuple/list/dictionary of torch.Tensor): The tensors to gather across all processes.
Returns
torch.Tensor, or a nested tuple/list/dictionary of torch.Tensor
The gathered tensor(s). Note that the first dimension of the result is num_processes multiplied by the first dimension of the input tensors.
Gather the values in tensor across all processes and concatenate them on the first dimension. Useful to regroup the predictions from all processes when doing evaluation.
Note: This gather happens in all processes.
gather_for_metrics(tensor)
Gathers tensor and potentially drops duplicates in the last batch if on a distributed system. Should be used for gathering the inputs and targets for metric calculation.
Example:
>>> # Assuming two processes, with a batch size of 5 on a dataset with 9 samples
>>> import torch
>>> from accelerate import Accelerator
>>> accelerator = Accelerator()
>>> dataloader = torch.utils.data.DataLoader(range(9), batch_size=5)
>>> dataloader = accelerator.prepare(dataloader)
>>> batch = next(iter(dataloader))
>>> gathered_items = accelerator.gather_for_metrics(batch)
>>> len(gathered_items)
9
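The difference between gather and gather_for_metrics can be modeled with plain lists, reusing the two-process, 9-sample setup from the example above: gather concatenates the last batches (10 items, including the duplicated sample 0), while gather_for_metrics keeps only the samples that actually remain in the dataset. Both functions below are hypothetical simplifications; truncating to the remaining sample count is an assumption about the mechanism, not a copy of the library code.

```python
def gather(per_process):
    """Concatenate every process's items in process-index order, like
    gathering along the first dimension."""
    return [item for items in per_process for item in items]


def gather_for_metrics(per_process, remaining_samples):
    """Gather, then keep only as many items as there are samples left in the
    dataset, dropping the duplicates added to even out the last batch."""
    return gather(per_process)[:remaining_samples]


# Last (and only) batch of each process: two processes, batch size 5, 9 samples
last_batch = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 0]]
print(len(gather(last_batch)))                              # 10
print(gather_for_metrics(last_batch, remaining_samples=9))  # [0, 1, 2, 3, 4, 5, 6, 7, 8]
```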
get_state_dict(model, unwrap=True) → dict
Parameters
- model (torch.nn.Module): A PyTorch model sent through Accelerator.prepare().
- unwrap (bool, optional, defaults to True): Whether to return the original underlying state_dict of model or to return the wrapped state_dict.
Returns
dict
The state dictionary of the model potentially without full precision.
Returns the state dictionary of a model sent through Accelerator.prepare() potentially without full precision.
get_tracker(name: str) → GeneralTracker
Returns a tracker from self.trackers based on name on the main process only.
init_trackers(project_name: str, config: typing.Optional[dict] = None, init_kwargs: typing.Optional[dict] = {})
Parameters
- project_name (str): The name of the project. All trackers will save their data based on this.
- config (dict, optional): Optional starting configuration to be logged.
- init_kwargs (dict, optional): A nested dictionary of kwargs to be passed to a specific tracker's __init__ function. Should be formatted as a dictionary keyed by tracker name, where each value holds the keyword arguments for that tracker.
Initializes a run for all trackers stored in self.log_with, potentially with starting configurations.
join_uneven_inputs(joinables, even_batches=None)
Parameters
- joinables (List[torch.distributed.algorithms.Joinable]): A list of models or optimizers that subclass torch.distributed.algorithms.Joinable. Most commonly, a PyTorch Module that was prepared with Accelerator.prepare for DistributedDataParallel training.
- even_batches (bool, optional): If set, this will override the value of even_batches set in the Accelerator. If it is not provided, the default Accelerator value will be used.
A context manager that facilitates distributed training or evaluation on uneven inputs, which acts as a wrapper around torch.distributed.algorithms.join. This is useful when the total batch size does not evenly divide the length of the dataset.
join_uneven_inputs is only supported for Distributed Data Parallel training on multiple GPUs. For any other configuration, this method will have no effect.
Overriding even_batches will not affect iterable-style data loaders.
Example:
>>> from accelerate import Accelerator
>>> accelerator = Accelerator(even_batches=True)
>>> ddp_model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
>>> with accelerator.join_uneven_inputs([ddp_model], even_batches=False):
...     for input, target in dataloader:
...         outputs = ddp_model(input)
...         loss = loss_func(outputs, target)
...         accelerator.backward(loss)
...         optimizer.step()
...         optimizer.zero_grad()
load_state(input_dir: str, **load_model_func_kwargs)
Parameters
- input_dir (str or os.PathLike): The name of the folder all relevant weights and states were saved in.
- load_model_func_kwargs (dict, optional): Additional keyword arguments for loading the model, which can be passed to the underlying load function, such as optional arguments for DeepSpeed's load_checkpoint function.
Loads the current states of the model, optimizer, scaler, RNG generators, and registered objects.
Should only be used in conjunction with Accelerator.save_state().
local_main_process_first
Lets the local main process go inside a with block.
The other processes will enter the with block after the local main process exits.
Example:
>>> from accelerate import Accelerator
>>> accelerator = Accelerator()
>>> with accelerator.local_main_process_first():
...     # This will be printed first by local process 0 then in a seemingly
...     # random order by the other processes.
...     print(f"This will be printed by process {accelerator.local_process_index}")
log(values: dict, step: typing.Optional[int] = None, log_kwargs: typing.Optional[dict] = {})
Parameters
- values (dict): Values should be a dictionary-like object containing only types int, float, or str.
- step (int, optional): The run step. If included, the log will be affiliated with this step.
- log_kwargs (dict, optional): A nested dictionary of kwargs to be passed to a specific tracker's log function. Should be formatted as a dictionary keyed by tracker name, where each value holds the keyword arguments for that tracker's log function.
Logs values to all stored trackers in self.trackers on the main process only.
main_process_first
Lets the main process go first inside a with block.
The other processes will enter the with block after the main process exits.
Example:
>>> from accelerate import Accelerator
>>> accelerator = Accelerator()
>>> with accelerator.main_process_first():
...     # This will be printed first by process 0 then in a seemingly
...     # random order by the other processes.
...     print(f"This will be printed by process {accelerator.process_index}")
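One way to get the ordering main_process_first provides is a barrier that non-main processes wait on before entering the block and that the main process reaches only after leaving it. The sketch below simulates this with threads standing in for processes and threading.Barrier standing in for wait_for_everyone(); run_simulation is hypothetical, and real processes would synchronize through the distributed backend instead.

```python
import threading
from contextlib import contextmanager


def run_simulation(num_processes):
    """Run one thread per simulated process and record the order in which
    they enter the protected block."""
    barrier = threading.Barrier(num_processes)  # stands in for wait_for_everyone()
    order = []
    lock = threading.Lock()

    @contextmanager
    def main_process_first(process_index):
        if process_index != 0:
            barrier.wait()  # non-main processes wait for the main process
        yield
        if process_index == 0:
            barrier.wait()  # main process releases the others after its block

    def worker(process_index):
        with main_process_first(process_index):
            with lock:
                order.append(process_index)

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_processes)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return order


order = run_simulation(4)
print(order[0])  # 0: the main process always enters first
```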
no_sync(model)
A context manager to disable gradient synchronizations across DDP processes by calling torch.nn.parallel.DistributedDataParallel.no_sync.
If model is not in DDP, this context manager does nothing.
Example:
>>> from accelerate import Accelerator
>>> accelerator = Accelerator()
>>> dataloader, model, optimizer = accelerator.prepare(dataloader, model, optimizer)
>>> batches = iter(dataloader)
>>> input_a = next(batches)
>>> input_b = next(batches)
>>> with accelerator.no_sync(model):
...     outputs = model(input_a)
...     loss = loss_func(outputs)
...     accelerator.backward(loss)
...     # No synchronization across processes, only accumulate gradients
>>> outputs = model(input_b)
>>> loss = loss_func(outputs)
>>> accelerator.backward(loss)
>>> # Synchronization across all processes
>>> optimizer.step()
>>> optimizer.zero_grad()
on_last_process
A decorator that will run the decorated function on the last process only.
on_local_main_process
A decorator that will run the decorated function on the local main process only.
on_local_process
A decorator that will run the decorated function on a given local process index only.
on_main_process
A decorator that will run the decorated function on the main process only.
on_process
A decorator that will run the decorated function on a given process index only.
pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False) → torch.Tensor, or a nested tuple/list/dictionary of torch.Tensor
Parameters
- tensor (nested list/tuple/dictionary of torch.Tensor): The data to pad.
- dim (int, optional, defaults to 0): The dimension on which to pad.
- pad_index (int, optional, defaults to 0): The value with which to pad.
- pad_first (bool, optional, defaults to False): Whether to pad at the beginning or the end.
Returns
torch.Tensor, or a nested tuple/list/dictionary of torch.Tensor
The padded tensor(s).
Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so they can safely be gathered.
Example:
>>> # Assuming two processes, with the first process having a tensor of size 1 and the second of size 2
>>> import torch
>>> from accelerate import Accelerator
>>> accelerator = Accelerator()
>>> process_tensor = torch.arange(accelerator.process_index + 1).to(accelerator.device)
>>> padded_tensor = accelerator.pad_across_processes(process_tensor)
>>> padded_tensor.shape
torch.Size([2])
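Padding matters because gathering concatenates along the first dimension, which requires compatible shapes on every process. A pure-Python sketch for 1-D lists, mirroring the two-process example above (pad_to_longest is a hypothetical helper; the real method works on tensors along any dim):

```python
def pad_to_longest(per_process, pad_index=0, pad_first=False):
    """Pad every process's 1-D list to the length of the longest one so the
    results can be concatenated (gathered) safely."""
    max_len = max(len(items) for items in per_process)
    padded = []
    for items in per_process:
        padding = [pad_index] * (max_len - len(items))
        padded.append(padding + items if pad_first else items + padding)
    return padded


# Process 0 holds [0] and process 1 holds [0, 1], as in the example above
print(pad_to_longest([[0], [0, 1]]))                                # [[0, 0], [0, 1]]
print(pad_to_longest([[0], [0, 1]], pad_index=-1, pad_first=True))  # [[-1, 0], [0, 1]]
```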
prepare(*args, device_placement=None)
Parameters
- *args (list of objects): Any of the following type of objects:
  - torch.utils.data.DataLoader: PyTorch Dataloader
  - torch.nn.Module: PyTorch Module
  - torch.optim.Optimizer: PyTorch Optimizer
  - torch.optim.lr_scheduler.LRScheduler: PyTorch LR Scheduler
- device_placement (List[bool], optional): Used to customize whether automatic device placement should be performed for each object passed. Needs to be a list of the same length as args.
Prepare all objects passed in args for distributed training and mixed precision, then return them in the same order.
You don't need to prepare a model if you only use it for inference without any kind of mixed precision.
prepare_data_loader(data_loader: DataLoader, device_placement=None)
Prepares a PyTorch DataLoader for training in any distributed setup. It is recommended to use Accelerator.prepare() instead.
prepare_model(model: Module, device_placement=None)
Prepares a PyTorch model for training in any distributed setup. It is recommended to use Accelerator.prepare() instead.
prepare_optimizer(optimizer: Optimizer, device_placement=None)
Prepares a PyTorch Optimizer for training in any distributed setup. It is recommended to use Accelerator.prepare() instead.
prepare_scheduler(scheduler: _LRScheduler)
Prepares a PyTorch Scheduler for training in any distributed setup. It is recommended to use Accelerator.prepare() instead.
print
Drop-in replacement of print() to only print once per server.
reduce(tensor, reduction='sum') → torch.Tensor, or a nested tuple/list/dictionary of torch.Tensor
Parameters
- tensor (torch.Tensor, or a nested tuple/list/dictionary of torch.Tensor): The tensors to reduce across all processes.
- reduction (str, optional, defaults to "sum"): A reduction type, can be one of 'sum', 'mean', or 'none'. If 'none', will not perform any operation.
Returns
torch.Tensor, or a nested tuple/list/dictionary of torch.Tensor
The reduced tensor(s).
Reduce the values in tensor across all processes based on reduction.
Note: All processes get the reduced value.
Example:
>>> # Assuming two processes
>>> import torch
>>> from accelerate import Accelerator
>>> accelerator = Accelerator()
>>> process_tensor = torch.arange(accelerator.num_processes) + 1 + (2 * accelerator.process_index)
>>> process_tensor = process_tensor.to(accelerator.device)
>>> reduced_tensor = accelerator.reduce(process_tensor, reduction="sum")
>>> reduced_tensor
tensor([4, 6])
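A plain-list model of the three reduction modes, using the same per-process values as the example above ([1, 2] on process 0 and [3, 4] on process 1). reduce_across is a hypothetical helper that sees all processes' values at once, whereas the real reduce runs a collective operation after which every process holds the result.

```python
def reduce_across(per_process, reduction="sum"):
    """Element-wise reduction of equally shaped 1-D lists across processes."""
    if reduction == "none":
        # No operation: each process keeps its own values
        return per_process
    summed = [sum(values) for values in zip(*per_process)]
    if reduction == "sum":
        return summed
    if reduction == "mean":
        return [s / len(per_process) for s in summed]
    raise ValueError(f"unknown reduction: {reduction!r}")


per_process = [[1, 2], [3, 4]]
print(reduce_across(per_process, "sum"))   # [4, 6]
print(reduce_across(per_process, "mean"))  # [2.0, 3.0]
```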
register_for_checkpointing
Makes note of objects and will save or load them during save_state or load_state.
These should be utilized when the state is being loaded or saved in the same script. It is not designed to be used in different scripts.
Every object must have a load_state_dict and a state_dict function to be stored.
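Any object with those two methods qualifies. Below is a minimal sketch of such an object, using a hypothetical StepCounter class; with an Accelerator available it would be registered with accelerator.register_for_checkpointing(counter) before calling save_state or load_state. Here JSON stands in for the on-disk format purely to show the round trip, not how 🤗 Accelerate serializes it.

```python
import json


class StepCounter:
    """Hypothetical custom object exposing the state_dict / load_state_dict pair."""

    def __init__(self):
        self.steps = 0

    def state_dict(self):
        return {"steps": self.steps}

    def load_state_dict(self, state_dict):
        self.steps = state_dict["steps"]


counter = StepCounter()
counter.steps = 42
# Round-trip the state as a checkpoint would (JSON is only a stand-in format)
saved = json.dumps(counter.state_dict())
restored = StepCounter()
restored.load_state_dict(json.loads(saved))
print(restored.steps)  # 42
```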
register_load_state_pre_hook(hook: typing.Callable[..., NoneType]) → torch.utils.hooks.RemovableHandle
Parameters
- hook (Callable): A function to be called in Accelerator.load_state() before load_checkpoint.
Returns
torch.utils.hooks.RemovableHandle
A handle that can be used to remove the added hook by calling handle.remove().
Registers a pre hook to be run before load_checkpoint is called in Accelerator.load_state().
The hook should have the following signature:
hook(models: List[torch.nn.Module], input_dir: str) -> None
The models argument is the list of models as saved in the accelerator state under accelerator._models, and the input_dir argument is the input_dir argument passed to Accelerator.load_state().
Should only be used in conjunction with Accelerator.register_save_state_pre_hook(). Can be useful to load configurations in addition to model weights. Can also be used to overwrite model loading with a customized method. In this case, make sure to remove already loaded models from the models list.
register_save_state_pre_hook(hook: typing.Callable[..., NoneType]) → torch.utils.hooks.RemovableHandle
Parameters
- hook (Callable): A function to be called in Accelerator.save_state() before save_checkpoint.
Returns
torch.utils.hooks.RemovableHandle
A handle that can be used to remove the added hook by calling handle.remove().
Registers a pre hook to be run before save_checkpoint is called in Accelerator.save_state().
The hook should have the following signature:
hook(models: List[torch.nn.Module], weights: List[Dict[str, torch.Tensor]], output_dir: str) -> None
The models argument is the list of models as saved in the accelerator state under accelerator._models, the weights argument is the list of state dicts of those models, and the output_dir argument is the output_dir argument passed to Accelerator.save_state().
Should only be used in conjunction with Accelerator.register_load_state_pre_hook(). Can be useful to save configurations in addition to model weights. Can also be used to overwrite model saving with a customized method. In this case, make sure to remove the weights you have already saved from the weights list.
save(obj, f)
Save the object passed to disk once per machine. Use in place of torch.save.
save_state(output_dir: str = None, **save_model_func_kwargs)
Parameters
- output_dir (str or os.PathLike): The name of the folder to save all relevant weights and states.
- save_model_func_kwargs (dict, optional): Additional keyword arguments for saving the model, which can be passed to the underlying save function, such as optional arguments for DeepSpeed's save_checkpoint function.
Saves the current states of the model, optimizer, scaler, RNG generators, and registered objects to a folder.
If a ProjectConfiguration was passed to the Accelerator object with automatic_checkpoint_naming enabled, then checkpoints will be saved to self.project_dir/checkpoints. If the number of current saves is greater than total_limit, then the oldest save is deleted. Each checkpoint is saved in a separate folder named checkpoint_<iteration>.
Otherwise they are just saved to output_dir.
Should only be used when wanting to save a checkpoint during training and restoring the state in the same environment.
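The rotation rule described above (keep at most total_limit folders named checkpoint_<iteration>, deleting the oldest first) can be sketched with the standard library. save_rotating_checkpoint is a hypothetical helper that writes a placeholder file instead of real states; it illustrates the naming and deletion logic only, not how 🤗 Accelerate serializes anything.

```python
import shutil
import tempfile
from pathlib import Path


def save_rotating_checkpoint(project_dir, iteration, total_limit=None):
    """Hypothetical helper: write checkpoints/checkpoint_<iteration> under
    project_dir, deleting the oldest folders first so that at most
    total_limit checkpoints remain after this save."""
    root = Path(project_dir) / "checkpoints"
    root.mkdir(parents=True, exist_ok=True)
    existing = sorted(root.glob("checkpoint_*"), key=lambda p: int(p.name.split("_")[-1]))
    if total_limit is not None and len(existing) + 1 > total_limit:
        for old in existing[: len(existing) + 1 - total_limit]:
            shutil.rmtree(old)
    out = root / f"checkpoint_{iteration}"
    out.mkdir()
    (out / "state.txt").write_text(f"iteration={iteration}\n")  # placeholder payload
    return out


with tempfile.TemporaryDirectory() as tmp:
    for it in range(5):
        save_rotating_checkpoint(tmp, it, total_limit=2)
    print(sorted(p.name for p in (Path(tmp) / "checkpoints").iterdir()))
    # ['checkpoint_3', 'checkpoint_4']
```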
skip_first_batches(dataloader, num_batches: int = 0)
Creates a new torch.utils.data.DataLoader that will efficiently skip the first num_batches batches.
Example:
>>> from accelerate import Accelerator
>>> accelerator = Accelerator()
>>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)
>>> for (input, target) in accelerator.skip_first_batches(dataloader, num_batches=2):
...     optimizer.zero_grad()
...     output = model(input)
...     loss = loss_func(output, target)
...     accelerator.backward(loss)
...     optimizer.step()
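When resuming from a checkpoint in the middle of an epoch, the num_batches to skip is typically derived from the number of batches already processed. The sketch below shows that bookkeeping; resume_position and iterate_from are hypothetical helpers, and itertools.islice still walks over the skipped batches, which is exactly the overhead skip_first_batches is described as avoiding.

```python
from itertools import islice


def resume_position(completed_batches, batches_per_epoch):
    """Hypothetical helper: map a count of already processed batches to
    (0-indexed epoch to resume in, number of batches to skip in that epoch)."""
    return divmod(completed_batches, batches_per_epoch)


def iterate_from(batches, num_batches):
    """Yield the batches of one epoch after skipping the first num_batches."""
    return islice(batches, num_batches, None)


epoch, skip = resume_position(completed_batches=7, batches_per_epoch=3)
print(epoch, skip)  # 2 1
batches = [[0, 1], [2, 3], [4, 5]]
print(list(iterate_from(batches, skip)))  # [[2, 3], [4, 5]]
```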
unscale_gradients(optimizer=None)
Parameters
- optimizer (torch.optim.Optimizer or List[torch.optim.Optimizer], optional): The optimizer(s) for which to unscale gradients. If not set, will unscale gradients on all optimizers that were passed to prepare().
Unscale the gradients in mixed precision training with AMP. This is a no-op in all other settings.
Likely should be called through Accelerator.clip_grad_norm_() or Accelerator.clip_grad_value_().
unwrap_model(model, keep_fp32_wrapper: bool = True) → torch.nn.Module
Unwraps the model from the additional layer possibly added by prepare(). Useful before saving the model.
Example:
>>> # Assuming two GPU processes
>>> from torch.nn.parallel import DistributedDataParallel
>>> from accelerate import Accelerator
>>> accelerator = Accelerator()
>>> model = accelerator.prepare(MyModel())
>>> print(model.__class__.__name__)
DistributedDataParallel
>>> model = accelerator.unwrap_model(model)
>>> print(model.__class__.__name__)
MyModel
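Wrappers such as DistributedDataParallel keep the original model in a .module attribute, so unwrapping amounts to peeling those layers off. A minimal pure-Python sketch; ToyWrapper and unwrap are hypothetical stand-ins, and the real unwrap_model also deals with other wrapper types and mixed precision (keep_fp32_wrapper).

```python
class ToyWrapper:
    """Hypothetical stand-in for a wrapper that stores the original model in .module."""

    def __init__(self, module):
        self.module = module


class MyModel:
    pass


def unwrap(model, wrapper_types=(ToyWrapper,)):
    """Peel off known wrapper layers until the original model is reached."""
    while isinstance(model, wrapper_types):
        model = model.module
    return model


wrapped = ToyWrapper(ToyWrapper(MyModel()))
print(type(unwrap(wrapped)).__name__)  # MyModel
```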
wait_for_everyone
Will stop the execution of the current process until every other process has reached that point (so this does nothing when the script is only run in one process). Useful to do before saving a model.
Example:
>>> # Assuming two GPU processes
>>> import time
>>> from accelerate import Accelerator
>>> accelerator = Accelerator()
>>> if accelerator.is_main_process:
...     time.sleep(2)
... else:
...     print("I'm waiting for the main process to finish its sleep...")
>>> accelerator.wait_for_everyone()
>>> # Should print on every process at the same time
>>> print("Everyone is here")