Wrapper classes for torch Dataloaders, Optimizers, and Schedulers
The internal classes Accelerate uses to prepare objects for distributed training when calling prepare().
Datasets and DataLoaders
accelerate.data_loader.prepare_data_loader
(dataloader: DataLoader, device: Optional = None, num_processes: Optional = None, process_index: Optional = None, split_batches: bool = False, put_on_device: bool = False, rng_types: Optional = None, dispatch_batches: Optional = None, even_batches: bool = True, slice_fn_for_dispatch: Optional = None, use_seedable_sampler: bool = False, non_blocking: bool = False) → torch.utils.data.dataloader.DataLoader
Parameters
- dataloader (torch.utils.data.dataloader.DataLoader) — The data loader to split across several devices.
- device (torch.device) — The target device for the returned DataLoader.
- num_processes (int, optional) — The number of processes running concurrently. Will default to the value given by AcceleratorState.
- process_index (int, optional) — The index of the current process. Will default to the value given by AcceleratorState.
- split_batches (bool, optional, defaults to False) — Whether the resulting DataLoader should split the batches of the original data loader across devices or yield full batches (in which case it will yield batches starting at the process_index-th and advancing by num_processes batches at each iteration). Another way to see this is that the observed batch size will be the same as the initial dataloader if this option is set to True, and the batch size of the initial dataloader multiplied by num_processes otherwise. Setting this option to True requires that the batch size of the dataloader is a round multiple of num_processes.
- put_on_device (bool, optional, defaults to False) — Whether or not to put the batches on device (only works if the batches are nested lists, tuples or dictionaries of tensors).
- rng_types (list of str or RNGType) — The list of random number generators to synchronize at the beginning of each iteration. Should be one or several of:
  - "torch": the base torch random number generator
  - "cuda": the CUDA random number generator (GPU only)
  - "xla": the XLA random number generator (TPU only)
  - "generator": the torch.Generator of the sampler (or batch sampler if there is no sampler in your dataloader) or of the iterable dataset (if it exists) if the underlying dataset is of that type.
- dispatch_batches (bool, optional) — If set to True, the prepared dataloader is only iterated through on the main process and then the batches are split and broadcast to each process. Will default to True when the underlying dataset is an IterableDataset, False otherwise.
- even_batches (bool, optional, defaults to True) — If set to True, in cases where the total batch size across all processes does not exactly divide the dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among all workers.
- slice_fn_for_dispatch (Callable, optional) — If passed, this function will be used to slice tensors across num_processes. Will default to [slice_tensors()](/docs/accelerate/v0.30.0/en/package_reference/utilities#accelerate.utils.slice_tensors). This argument is used only when dispatch_batches is set to True and will be ignored otherwise.
- use_seedable_sampler (bool, optional, defaults to False) — Whether to use the SeedableRandomSampler instead of a RandomSampler for better reproducibility. Comes at a cost of potentially different performance due to different shuffling algorithms, but ensures results will be exactly the same. Should be paired with set_seed() at every self.set_epoch call.
- non_blocking (bool, optional, defaults to False) — If set to True, the dataloader will utilize non-blocking host-to-device transfers. If the dataloader has pin_memory set to True, this will help to increase overlap between data transfer and computation.
Returns
torch.utils.data.dataloader.DataLoader
A new data loader that will yield the portion of the batches belonging to the current process.
Wraps a PyTorch DataLoader to generate batches for one of the processes only.
Depending on the value of the drop_last attribute of the dataloader passed, it will either stop the iteration at the first batch that would be too small / not present on all processes or loop with indices from the beginning.
BatchSamplers with varying batch sizes are not enabled by default. To enable this behaviour, set even_batches equal to False.
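The batch-size relationship described for split_batches can be checked with a small pure-Python helper. The function name is hypothetical and not part of Accelerate; it only reproduces the arithmetic of the parameter description (per-process batch size and total batch size across processes). Raising an error when a batch cannot be cut into num_processes equal pieces is an assumption of this sketch.

```python
def per_process_and_total_batch_size(batch_size, num_processes, split_batches):
    """Return (batch size seen by each process, total batch size across processes).

    Hypothetical illustration only, not an Accelerate function.
    """
    if split_batches:
        # Each original batch is cut into num_processes equal pieces.
        if batch_size % num_processes != 0:
            raise ValueError(
                f"batch_size ({batch_size}) must be a round multiple of num_processes ({num_processes})"
            )
        per_process = batch_size // num_processes
    else:
        # Each process receives full batches of the original size.
        per_process = batch_size
    return per_process, per_process * num_processes


print(per_process_and_total_batch_size(8, 2, split_batches=True))   # (4, 8)
print(per_process_and_total_batch_size(8, 2, split_batches=False))  # (8, 16)
```

With split_batches=True the observed total batch size equals the original one; otherwise it grows linearly with the number of processes.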
accelerate.data_loader.skip_first_batches
(dataloader, num_batches = 0)
Creates a torch.utils.data.DataLoader that will efficiently skip the first num_batches batches.
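Skipping batches efficiently means not loading the skipped samples at all. The pure-Python sketch below assumes the skipping is done on the batch sampler's index lists before any sample is fetched; CountingDataset and skip_batches_by_index are hypothetical names used only for illustration, and a plain list of index lists stands in for a BatchSampler.

```python
from itertools import islice


class CountingDataset:
    """Toy map-style dataset that records how many samples were actually loaded."""

    def __init__(self, size):
        self.size = size
        self.loads = 0

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        self.loads += 1  # stands in for expensive I/O or preprocessing
        return index * 10


def skip_batches_by_index(dataset, batch_sampler, num_batches):
    """Yield batches after dropping the first num_batches at the index level (hypothetical helper)."""
    # islice discards index lists only, so dataset[i] is never called for skipped batches.
    for indices in islice(batch_sampler, num_batches, None):
        yield [dataset[i] for i in indices]


dataset = CountingDataset(6)
batch_sampler = [[0, 1], [2, 3], [4, 5]]  # what a BatchSampler with batch_size=2 would produce
remaining = list(skip_batches_by_index(dataset, batch_sampler, num_batches=2))
print(remaining)      # [[40, 50]]
print(dataset.loads)  # 2
```

Iterating a regular loader and discarding the first two batches would have loaded all six samples instead of two.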
class accelerate.data_loader.BatchSamplerShard
(batch_sampler: BatchSampler, num_processes: int = 1, process_index: int = 0, split_batches: bool = False, even_batches: bool = True)
Parameters
- batch_sampler (torch.utils.data.sampler.BatchSampler) — The batch sampler to split in several shards.
- num_processes (int, optional, defaults to 1) — The number of processes running concurrently.
- process_index (int, optional, defaults to 0) — The index of the current process.
- split_batches (bool, optional, defaults to False) — Whether the shards should be created by splitting a batch to give a piece of it on each process, or by yielding different full batches on each process. On two processes with a sampler of [[0, 1, 2, 3], [4, 5, 6, 7]], this will result in:
  - the sampler on process 0 to yield [0, 1, 2, 3] and the sampler on process 1 to yield [4, 5, 6, 7] if this argument is set to False.
  - the sampler on process 0 to yield [0, 1] then [4, 5] and the sampler on process 1 to yield [2, 3] then [6, 7] if this argument is set to True.
- even_batches (bool, optional, defaults to True) — Whether or not to loop back at the beginning of the sampler when the number of samples is not a round multiple of (original batch size / number of processes).
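The two split_batches modes in the example above can be reproduced with a short pure-Python sketch. shard_batches is a hypothetical helper, not Accelerate's implementation; lists of indices stand in for the batches a BatchSampler yields, and only the case where every batch has the same size is handled.

```python
def shard_batches(batches, num_processes, process_index, split_batches=False):
    """Return the batches one process would see (illustrative sketch only)."""
    if not split_batches:
        # Full batches are dealt out round-robin: process i gets batches i, i + n, i + 2n, ...
        return [list(b) for b in batches[process_index::num_processes]]
    shards = []
    for batch in batches:
        if len(batch) % num_processes != 0:
            raise ValueError("batch size must be a round multiple of num_processes")
        size = len(batch) // num_processes
        # Each batch is cut into num_processes contiguous pieces.
        shards.append(list(batch[process_index * size : (process_index + 1) * size]))
    return shards


batches = [[0, 1, 2, 3], [4, 5, 6, 7]]
for split in (False, True):
    for rank in range(2):
        print(f"split_batches={split}, process {rank}:", shard_batches(batches, 2, rank, split))
```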
Wraps a PyTorch BatchSampler to generate batches for one of the processes only. Instances of this class will always yield a number of batches that is a round multiple of num_processes and that all have the same size.
Depending on the value of the drop_last attribute of the batch sampler passed, it will either stop the iteration at the first batch that would be too small / not present on all processes or loop with indices from the beginning.
BatchSamplers with varying batch sizes are not enabled by default. To enable this behaviour, set even_batches equal to False.
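The drop_last / loop-back behaviour for split_batches=False can be sketched on a flat list of sample indices. This is a simplified illustration with a hypothetical function name, assuming even_batches=True: with drop_last it stops at the first round that cannot give every process a full batch, otherwise it wraps around to indices from the beginning until every process gets the same number of full-size batches. The real class handles additional edge cases.

```python
def shard_with_padding(indices, batch_size, num_processes, process_index, drop_last=False):
    """Even-batch sharding sketch for split_batches=False (not Accelerate's implementation)."""
    indices = list(indices)
    round_size = batch_size * num_processes  # samples consumed by one batch on every process
    if drop_last:
        # Stop at the first round that cannot give every process a full batch.
        usable = len(indices) - len(indices) % round_size
        padded = indices[:usable]
    else:
        # Loop back to the beginning until the last round is complete.
        target = -(-len(indices) // round_size) * round_size  # ceil to a multiple of round_size
        padded = list(indices)
        while len(padded) < target:
            padded += indices[: target - len(padded)]
    batches = [padded[i : i + batch_size] for i in range(0, len(padded), batch_size)]
    return batches[process_index::num_processes]


# 10 samples, batch_size=3, 2 processes: the last batch is completed with indices 0 and 1.
print(shard_with_padding(range(10), 3, 2, 1))  # [[3, 4, 5], [9, 0, 1]]
```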
class accelerate.data_loader.IterableDatasetShard
(dataset: IterableDataset, batch_size: int = 1, drop_last: bool = False, num_processes: int = 1, process_index: int = 0, split_batches: bool = False)
Parameters
- dataset (torch.utils.data.dataset.IterableDataset) — The iterable dataset to split in several shards.
- batch_size (int, optional, defaults to 1) — The size of the batches per shard (if split_batches=False) or the size of the batches (if split_batches=True).
- drop_last (bool, optional, defaults to False) — Whether or not to drop the last incomplete batch or complete the last batches by using the samples from the beginning.
- num_processes (int, optional, defaults to 1) — The number of processes running concurrently.
- process_index (int, optional, defaults to 0) — The index of the current process.
- split_batches (bool, optional, defaults to False) — Whether the shards should be created by splitting a batch to give a piece of it on each process, or by yielding different full batches on each process. On two processes with an iterable dataset yielding [0, 1, 2, 3, 4, 5, 6, 7] (and batch_size=4), this will result in:
  - the shard on process 0 to yield [0, 1, 2, 3] and the shard on process 1 to yield [4, 5, 6, 7] if this argument is set to False.
  - the shard on process 0 to yield [0, 1, 4, 5] and the shard on process 1 to yield [2, 3, 6, 7] if this argument is set to True.
Wraps a PyTorch IterableDataset to generate samples for one of the processes only. Instances of this class will always yield a number of samples that is a round multiple of the actual batch size (depending on the value of split_batches, this is either batch_size or batch_size x num_processes). Depending on the value of the drop_last argument, it will either stop the iteration at the first batch that would be too small or loop with samples from the beginning.
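The sharding of an iterable dataset can be sketched as a generator: buffer samples until a full "real" batch is assembled, hand each process its contiguous slice, and either drop or complete the final partial batch. This is a simplified pure-Python illustration (hypothetical function name, any Python iterable standing in for an IterableDataset); it reproduces the [0, ..., 7] example above with batch_size=4.

```python
def iterable_shard(iterable, batch_size, num_processes, process_index,
                   split_batches=False, drop_last=False):
    """Yield the samples one process would receive (simplified sketch, not Accelerate's code)."""
    # Size of the batch assembled from the stream before it is divided among processes.
    real_batch_size = batch_size if split_batches else batch_size * num_processes
    per_process = batch_size // num_processes if split_batches else batch_size
    start = process_index * per_process
    first_batch = None
    current = []
    for sample in iterable:
        current.append(sample)
        if len(current) == real_batch_size:
            yield from current[start : start + per_process]
            if first_batch is None:
                first_batch = list(current)
            current = []
    # Incomplete final batch: drop it, or complete it with samples from the beginning.
    if not drop_last and current:
        if first_batch is None:
            first_batch = list(current)
        while len(current) < real_batch_size:
            current += first_batch
        yield from current[start : start + per_process]


print(list(iterable_shard(range(8), 4, 2, 0)))                      # [0, 1, 2, 3]
print(list(iterable_shard(range(8), 4, 2, 0, split_batches=True)))  # [0, 1, 4, 5]
```

Because only one real batch is buffered at a time, the sketch never needs the length of the stream, which is why drop_last decides what happens to the leftover samples.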
class accelerate.data_loader.DataLoaderShard
(dataset, device = None, rng_types = None, synchronized_generator = None, skip_batches = 0, _drop_last: bool = False, _non_blocking: bool = False, **kwargs)
Parameters
- dataset (torch.utils.data.dataset.Dataset) — The dataset to use to build this dataloader.
- device (torch.device, optional) — If passed, the device to put all batches on.
- rng_types (list of str or RNGType) — The list of random number generators to synchronize at the beginning of each iteration. Should be one or several of:
  - "torch": the base torch random number generator
  - "cuda": the CUDA random number generator (GPU only)
  - "xla": the XLA random number generator (TPU only)
  - "generator": an optional torch.Generator
- synchronized_generator (torch.Generator, optional) — A random number generator to keep synchronized across processes.
- skip_batches (int, optional, defaults to 0) — The number of batches to skip at the beginning.
- **kwargs (additional keyword arguments, optional) — All other keyword arguments to pass to the regular DataLoader initialization.
Subclass of a PyTorch DataLoader that will deal with device placement and current distributed setup.
Available attributes:
- total_batch_size (int) — Total batch size of the dataloader across all processes. Equal to the original batch size when split_batches=True; otherwise the original batch size * the total number of processes.
- total_dataset_length (int) — Total length of the inner dataset across all processes.
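Synchronizing random number generators at the start of each iteration matters because every process shuffles the full index list and then keeps only its own share: if the shuffles differ, shards can overlap or miss samples. The sketch below uses Python's random module as a stand-in for the torch/CUDA/XLA generators listed in rng_types, and copies process 0's state with getstate/setstate as a stand-in for a cross-process broadcast. All function names are hypothetical.

```python
import random


def shard_of_shuffled_indices(rng, dataset_length, num_processes, process_index):
    """Shuffle all indices with this process's RNG, then keep every num_processes-th one."""
    order = list(range(dataset_length))
    rng.shuffle(order)
    return order[process_index::num_processes]


def simulate_epoch(dataset_length, num_processes, synchronize):
    # Each simulated process starts from a different RNG state.
    rngs = [random.Random(1234 + rank) for rank in range(num_processes)]
    if synchronize:
        # Copy process 0's state to every other process (stand-in for a broadcast).
        state = rngs[0].getstate()
        for rng in rngs[1:]:
            rng.setstate(state)
    return [shard_of_shuffled_indices(rngs[rank], dataset_length, num_processes, rank)
            for rank in range(num_processes)]


shards = simulate_epoch(10, num_processes=2, synchronize=True)
# With a shared shuffle, the shards are disjoint and together cover the dataset.
print(sorted(shards[0] + shards[1]) == list(range(10)))  # True
```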
class accelerate.data_loader.DataLoaderDispatcher
(dataset, split_batches: bool = False, skip_batches = 0, _drop_last: bool = False, _non_blocking: bool = False, slice_fn = None, **kwargs)
Parameters
- split_batches (bool, optional, defaults to False) — Whether the resulting DataLoader should split the batches of the original data loader across devices or yield full batches (in which case it will yield batches starting at the process_index-th and advancing by num_processes batches at each iteration). Another way to see this is that the observed batch size will be the same as the initial dataloader if this option is set to True, and the batch size of the initial dataloader multiplied by num_processes otherwise. Setting this option to True requires that the batch size of the dataloader is a round multiple of num_processes.
- skip_batches (int, optional, defaults to 0) — The number of batches to skip at the beginning of an iteration.
Subclass of a PyTorch DataLoader that will iterate and preprocess on process 0 only, then dispatch to each process its part of the batch.
Available attributes:
- total_batch_size (int) — Total batch size of the dataloader across all processes. Equal to the original batch size when split_batches=True; otherwise the original batch size * the total number of processes.
- total_dataset_length (int) — Total length of the inner dataset across all processes.
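The dispatch step amounts to slicing every tensor in a (possibly nested) batch along its first dimension, which is the role slice_fn plays (slice_tensors by default in prepare_data_loader). The pure-Python sketch below uses lists as stand-ins for tensors and dicts/tuples as containers; the helper names are hypothetical, and the broadcast from process 0 is omitted.

```python
def slice_nested(data, index):
    """Apply a slice to every leaf of nested dicts/tuples; lists stand in for tensors."""
    if isinstance(data, dict):
        return {key: slice_nested(value, index) for key, value in data.items()}
    if isinstance(data, tuple):
        return tuple(slice_nested(value, index) for value in data)
    return data[index]


def dispatch_to_process(full_batch, batch_length, num_processes, process_index):
    """Return the part of a batch that process 0 would send to process_index (sketch)."""
    per_process = batch_length // num_processes
    start = process_index * per_process
    return slice_nested(full_batch, slice(start, start + per_process))


# Process 0 loaded and preprocessed this batch; every process takes its own slice.
batch = {"input_ids": [0, 1, 2, 3], "labels": ([10, 11, 12, 13],)}
print(dispatch_to_process(batch, 4, num_processes=2, process_index=1))
# {'input_ids': [2, 3], 'labels': ([12, 13],)}
```

Keeping the container structure intact means the training loop sees the same batch layout on every process, only with fewer rows.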
Optimizers
class accelerate.optimizer.AcceleratedOptimizer
(optimizer, device_placement = True, scaler = None)
Parameters
- optimizer (torch.optim.optimizer.Optimizer) — The optimizer to wrap.
- device_placement (bool, optional, defaults to True) — Whether or not the optimizer should handle device placement. If so, it will place the state dictionary of optimizer on the right device.
- scaler (torch.cuda.amp.grad_scaler.GradScaler, optional) — The scaler to use in the step function if training with mixed precision.
Internal wrapper around a torch optimizer.
Performs step and zero_grad only when gradients should be synchronized, so that both are skipped on the intermediate steps of gradient accumulation.
Sets the optimizer to “eval” mode. Useful for optimizers like schedule_free.
Sets the optimizer to “train” mode. Useful for optimizers like schedule_free.
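The conditional step can be sketched with a toy wrapper that forwards step and zero_grad only when a sync_gradients flag is set. In Accelerate the flag is managed by the library (for example inside accelerator.accumulate()); here the training loop sets it by hand, and CountingOptimizer and AccumulationAwareOptimizer are hypothetical names.

```python
class CountingOptimizer:
    """Toy optimizer that only counts calls."""

    def __init__(self):
        self.steps = 0
        self.zero_grads = 0

    def step(self):
        self.steps += 1

    def zero_grad(self):
        self.zero_grads += 1


class AccumulationAwareOptimizer:
    """Forward step/zero_grad only when gradients are synchronized (hypothetical wrapper)."""

    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.sync_gradients = True  # toggled by the training loop in this sketch

    def step(self):
        if self.sync_gradients:
            self.optimizer.step()

    def zero_grad(self):
        if self.sync_gradients:
            self.optimizer.zero_grad()


inner = CountingOptimizer()
optimizer = AccumulationAwareOptimizer(inner)
accumulation_steps = 4
for micro_step in range(8):
    # ... forward and backward pass would go here ...
    optimizer.sync_gradients = (micro_step + 1) % accumulation_steps == 0
    optimizer.step()
    optimizer.zero_grad()
print(inner.steps, inner.zero_grads)  # 2 2
```

Gradients keep accumulating across the skipped calls because zero_grad is skipped too; only every fourth micro-step updates the parameters and clears them.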
Schedulers
class accelerate.scheduler.AcceleratedScheduler
(scheduler, optimizers, step_with_optimizer: bool = True, split_batches: bool = False)
Parameters
- scheduler (torch.optim.lr_scheduler._LRScheduler) — The scheduler to wrap.
- optimizers (one or a list of torch.optim.Optimizer) — The optimizers used.
- step_with_optimizer (bool, optional, defaults to True) — Whether or not the scheduler should be stepped at each optimizer step.
- split_batches (bool, optional, defaults to False) — Whether or not the dataloaders split one batch across the different processes (so batch size is the same regardless of the number of processes) or create batches on each process (so batch size is the original batch size multiplied by the number of processes).
A wrapper around a learning rate scheduler that will only step when the optimizer(s) have a training step. Useful to avoid making a scheduler step too fast when gradients overflowed and there was no training step (in mixed precision training).
When performing gradient accumulation, scheduler lengths should not be changed accordingly; Accelerate will always step the scheduler to account for it.
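The stepping rules can be sketched with a toy wrapper: always step when step_with_optimizer is False, skip when the optimizer step was skipped (for example after a gradient overflow in mixed precision), and otherwise advance the inner scheduler. The sketch assumes that, when split_batches is False, one training step consumed num_processes original batches, so the inner scheduler is advanced num_processes times; this is an assumption based on the split_batches description, not a statement of Accelerate's code. All class names are hypothetical.

```python
class CountingScheduler:
    """Toy LR scheduler that only counts how often it was stepped."""

    def __init__(self):
        self.steps = 0

    def step(self):
        self.steps += 1


class OptimizerAwareScheduler:
    """Sketch of the stepping rules described above (not Accelerate's implementation)."""

    def __init__(self, scheduler, num_processes, step_with_optimizer=True, split_batches=False):
        self.scheduler = scheduler
        self.num_processes = num_processes
        self.step_with_optimizer = step_with_optimizer
        self.split_batches = split_batches

    def step(self, optimizer_step_was_skipped=False):
        if not self.step_with_optimizer:
            self.scheduler.step()
            return
        if optimizer_step_was_skipped:
            # e.g. the mixed-precision grad scaler found inf/NaN gradients and skipped the update
            return
        # Assumption: without split_batches each step consumed num_processes original batches.
        for _ in range(1 if self.split_batches else self.num_processes):
            self.scheduler.step()


scheduler = OptimizerAwareScheduler(CountingScheduler(), num_processes=4)
scheduler.step()                                 # optimizer stepped: advances 4 times
scheduler.step(optimizer_step_was_skipped=True)  # overflow: no advance
print(scheduler.scheduler.steps)  # 4
```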