Internally, 🤗 Accelerate works by first analyzing the environment in which the script is launched to determine which
kind of distributed setup is used, how many different processes there are and which one the current script is in. All
that information is stored in the AcceleratorState.
This class is initialized the first time you instantiate an Accelerator and performs any specific initialization your distributed setup needs. Its state is then uniquely shared through all instances of AcceleratorState. (The same can also be done with the PartialState, a more barebones version that it inherits from.)
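For illustration, here is a minimal sketch of that shared state (the attributes shown, distributed_type, num_processes and process_index, are ones AcceleratorState exposes in current releases; exact names may vary across versions):

```python
from accelerate import Accelerator
from accelerate.state import AcceleratorState

accelerator = Accelerator()  # the first instantiation initializes the shared state

# Any AcceleratorState created afterwards sees the exact same information.
state = AcceleratorState()
print(state.distributed_type, state.num_processes, state.process_index)
```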
Then, when calling prepare(), the library:
- wraps your model(s) in the container adapted for the distributed setup,
- wraps your optimizer(s) in an AcceleratedOptimizer,
- wraps your scheduler(s) in an AcceleratedScheduler,
- creates a new version of your dataloader(s) in a DataLoaderShard or DataLoaderDispatcher.
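Concretely, here is a rough sketch of what a prepare() call looks like (the printed class names are an assumption about your installed version; they may be the wrappers listed above or setup-specific containers):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
dataloader = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 2)), batch_size=4)

# Everything passed to prepare() comes back wrapped for the current setup.
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
print(type(optimizer).__name__, type(dataloader).__name__)
```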
While the model(s), optimizer(s), and scheduler(s) are just put in simple wrappers, the dataloader(s) are re-created. This is mostly because PyTorch does not let the user change the batch_sampler of a dataloader once it has been created, and the library handles the sharding of your data between processes by changing that batch_sampler to yield every other num_processes batches (if enabled).
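As a purely illustrative sketch (this is not the batch sampler wrapper 🤗 Accelerate actually uses), "every other num_processes batches" means each process keeps one batch out of every num_processes:

```python
from torch.utils.data import BatchSampler, SequentialSampler

batch_sampler = BatchSampler(SequentialSampler(range(16)), batch_size=2, drop_last=False)
num_processes = 2

# Process `rank` would see batches rank, rank + num_processes, rank + 2 * num_processes, ...
for rank in range(num_processes):
    shard = [batch for i, batch in enumerate(batch_sampler) if i % num_processes == rank]
    print(f"process {rank}: {shard}")
```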
The DataLoaderShard subclasses DataLoader to add the following functionality:
- it synchronizes the appropriate random number generator of all processes at each new iteration, to ensure any randomization (like shuffling) is done the exact same way across processes.
- it puts the batches on the proper device before yielding them (unless you have opted out of device_placement=True), as sketched below.
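A minimal sketch of opting out of that device placement, assuming you pass device_placement=False when creating the Accelerator:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

# With device_placement=False, prepared dataloaders yield batches on CPU and
# you are responsible for moving them to accelerator.device yourself.
accelerator = Accelerator(device_placement=False)

dataloader = DataLoader(TensorDataset(torch.randn(32, 4)), batch_size=8)
dataloader = accelerator.prepare(dataloader)
batch = next(iter(dataloader))
```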
The DataLoaderDispatcher differs from the DataLoaderShard in that, when iterating through the DataLoader, all of the data is read on process 0 and then split and sent to each process, rather than the split happening at the dataset level.
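If you want to opt into that behavior explicitly, a hedged sketch follows; older 🤗 Accelerate releases accept a dispatch_batches flag directly on the Accelerator, while more recent ones expect it inside a DataLoaderConfiguration, so check your version's documentation:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

# Assumption: this release still accepts dispatch_batches on the Accelerator.
accelerator = Accelerator(dispatch_batches=True)

dataloader = DataLoader(TensorDataset(torch.randn(32, 4)), batch_size=8)
dataloader = accelerator.prepare(dataloader)  # now a DataLoaderDispatcher
```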
The random number generator synchronization will by default synchronize:
- the generator attribute of a given sampler (like the PyTorch RandomSampler) for PyTorch >= 1.6
- the main random number generator in PyTorch <=1.5.1
You can choose which random number generator(s) to synchronize with the rng_types argument of the main Accelerator. In PyTorch >= 1.6, it is recommended to rely on a local generator to avoid setting the same seed in the main random number generator in all processes.
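For example, a short sketch (the accepted strings include "torch", "cuda", "xla" and "generator", but verify against your version's documentation):

```python
from accelerate import Accelerator

# Only synchronize the local `generator` attribute of the samplers, leaving the
# global torch/CUDA RNGs untouched so per-process randomness (such as
# torch-based data augmentation) stays independent.
accelerator = Accelerator(rng_types=["generator"])
```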
Synchronization of the main torch (or CUDA or XLA) random number generator will affect any other potential random artifacts you could have in your dataset (like random data augmentation) in the sense that all processes will get the same random numbers from the torch random modules (so will apply the same random data augmentation if it’s controlled by torch).
The randomization part of your custom sampler, batch sampler, or iterable dataset should be done using a local torch.Generator object (in PyTorch >= 1.6); see the traditional RandomSampler as an example.
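A short sketch of what such a local generator looks like with the built-in RandomSampler:

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(10, dtype=torch.float32))

# Shuffling is driven by a local torch.Generator, which 🤗 Accelerate can then
# synchronize across processes without touching the global torch RNG.
generator = torch.Generator()
generator.manual_seed(42)

sampler = RandomSampler(dataset, generator=generator)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=2)
```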
For more details about the internals, see the Internals page.