
🤗 Accelerate’s internal mechanisms

You are viewing the documentation for v0.34.2. A newer version, v1.2.1, is available.


Internally, 🤗 Accelerate works by first analyzing the environment in which the script is launched to determine which kind of distributed setup is used, how many processes there are, and which one the current script is running in. All that information is stored in the AcceleratorState.
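As an illustration of this environment analysis, the sketch below infers the number of processes and the current process index from the environment variables that `torchrun` sets for each process it spawns (`WORLD_SIZE`, `RANK`, `LOCAL_RANK`). It is a minimal sketch under that assumption, not Accelerate's actual detection logic, which covers many more launchers and backends; `DetectedSetup`, `detect_setup`, and the `"NO"`/`"MULTI_PROCESS"` labels are hypothetical names.

```python
import os
from dataclasses import dataclass


@dataclass
class DetectedSetup:
    # Hypothetical labels: "NO" for a single process, "MULTI_PROCESS" otherwise.
    distributed_type: str
    num_processes: int
    process_index: int
    local_process_index: int


def detect_setup(env=None):
    """Infer the distributed setup from launcher environment variables (simplified)."""
    env = os.environ if env is None else env
    # torchrun exports these variables to every process it launches;
    # a plain `python script.py` run has none of them, so defaults apply.
    num_processes = int(env.get("WORLD_SIZE", "1"))
    process_index = int(env.get("RANK", "0"))
    local_process_index = int(env.get("LOCAL_RANK", "0"))
    distributed_type = "MULTI_PROCESS" if num_processes > 1 else "NO"
    return DetectedSetup(distributed_type, num_processes, process_index, local_process_index)


setup = detect_setup({"WORLD_SIZE": "4", "RANK": "2", "LOCAL_RANK": "2"})
print(setup.distributed_type, setup.num_processes, setup.process_index)  # MULTI_PROCESS 4 2
```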

This class is initialized the first time you instantiate an Accelerator, which also performs any specific initialization your distributed setup needs. Its state is then shared across all instances of AcceleratorState. (The same holds for PartialState, a more barebones version that AcceleratorState inherits from.)
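Sharing one state across all instances can be achieved with the "shared-state" (Borg) pattern, where every instance aliases the same attribute dictionary. The sketch below shows that pattern in isolation; `SharedState` is a hypothetical class, not Accelerate's code, and it only illustrates the behavior described above: the first instantiation initializes the state, and later instances see it unchanged.

```python
class SharedState:
    """Every instance reads and writes the same attribute dictionary."""

    _shared_state = {}

    def __init__(self, **setup):
        # Point this instance's attribute dict at the single class-level dict,
        # so any attribute set on one instance is visible on all of them.
        self.__dict__ = self._shared_state
        # Only the first instantiation performs the initialization.
        if not self._shared_state.get("initialized", False):
            self.num_processes = setup.get("num_processes", 1)
            self.process_index = setup.get("process_index", 0)
            self.initialized = True


first = SharedState(num_processes=2, process_index=1)
second = SharedState(num_processes=8)  # ignored: the state is already initialized
print(second.num_processes, second.process_index)  # 2 1
print(first is second)  # False: distinct objects that share one state
```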

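Then, when calling prepare(), the library:

  • wraps your model(s) in the container adapted for the distributed setup,
  • wraps your optimizer(s) in an AcceleratedOptimizer,
  • wraps your scheduler(s) in an AcceleratedScheduler,
  • creates a new version of your dataloader(s) in a DataLoaderShard or DataLoaderDispatcher.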

While the model(s), optimizer(s), and scheduler(s) are just put in simple wrappers, the dataloader(s) are re-created. This is mostly because PyTorch does not let the user change the batch_sampler of a dataloader once it has been created, and the library shards your data between processes by replacing that batch_sampler with one that makes each process yield only one out of every num_processes batches (if enabled).
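The batch-level sharding described above can be sketched in plain Python: every process iterates over the same sequence of batches but keeps only those whose position matches its process index modulo num_processes. This is a minimal sketch that assumes the number of batches divides evenly and ignores how Accelerate handles uneven final batches; `make_batches` and `shard_batches` are hypothetical helpers.

```python
def make_batches(indices, batch_size):
    """A toy batch sampler: group indices into lists of batch_size."""
    return [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]


def shard_batches(batches, num_processes, process_index):
    """Yield only the batches this process is responsible for."""
    for i, batch in enumerate(batches):
        # Process k keeps batches k, k + num_processes, k + 2 * num_processes, ...
        if i % num_processes == process_index:
            yield batch


batches = make_batches(list(range(12)), batch_size=2)
# [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]]
for rank in range(2):
    print(rank, list(shard_batches(batches, num_processes=2, process_index=rank)))
# 0 [[0, 1], [4, 5], [8, 9]]
# 1 [[2, 3], [6, 7], [10, 11]]
```

Because each process sees disjoint batches, the processes together cover the dataset exactly once per epoch.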

The DataLoaderShard subclasses DataLoader to add the following functionality:

  • it synchronizes the appropriate random number generator of all processes at each new iteration, to ensure any randomization (like shuffling) is done the exact same way across processes.
  • it puts the batches on the proper device before yielding them (unless you have opted out by passing device_placement=False).
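The two behaviors listed above can be combined in a toy loader: at the start of each iteration every process builds a generator in the same state, so all processes compute the same shuffled order, and each then keeps its own shard and passes every batch through a device-placement hook. This is a sketch under stated assumptions: Python's `random.Random` seeded from a shared seed plus the epoch number stands in for synchronizing a torch generator (which, roughly, means sharing the main process's generator state), and `to_device` is a hypothetical identity hook standing in for moving tensors. `ToyShardedLoader` is not Accelerate's DataLoaderShard.

```python
import random


class ToyShardedLoader:
    """Shuffle identically on every process, then keep this process's shard."""

    def __init__(self, dataset, batch_size, num_processes, process_index, seed=0, to_device=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_processes = num_processes
        self.process_index = process_index
        self.seed = seed
        self.epoch = 0
        # Hypothetical stand-in for moving a batch to the accelerator.
        self.to_device = to_device or (lambda batch: batch)

    def __iter__(self):
        # "Synchronized" generator: every process derives the same state for
        # this epoch, so every process computes the same permutation.
        generator = random.Random(self.seed + self.epoch)
        self.epoch += 1
        order = list(range(len(self.dataset)))
        generator.shuffle(order)
        batches = [order[i:i + self.batch_size] for i in range(0, len(order), self.batch_size)]
        for i, batch in enumerate(batches):
            if i % self.num_processes == self.process_index:
                yield self.to_device([self.dataset[j] for j in batch])


data = list("abcdefgh")
loaders = [ToyShardedLoader(data, 2, num_processes=2, process_index=r) for r in range(2)]
epoch0 = [list(loader) for loader in loaders]
# The two shards are disjoint and together cover the whole dataset.
print(sorted(x for shard in epoch0 for batch in shard for x in batch))
# ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
```

If the generators were not synchronized, each process would shuffle differently and the shards could overlap or miss samples.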

The DataLoaderDispatcher subclass differs from DataLoaderShard in that, when iterating through the DataLoader, all the data starts on process 0 and is then split and sent to each process, rather than being split at the dataset level.
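The dispatching step can be sketched as a split performed on process 0: it reads one large batch and cuts it into one equal slice per process. This is a minimal sketch assuming the batch size is divisible by num_processes; the actual sending of each slice over the distributed backend is omitted, and `dispatch` is a hypothetical helper rather than part of Accelerate's API.

```python
def dispatch(full_batch, num_processes):
    """On process 0: split one large batch into equal slices, one per process."""
    if len(full_batch) % num_processes != 0:
        raise ValueError("batch size must be divisible by num_processes in this sketch")
    per_process = len(full_batch) // num_processes
    return [full_batch[k * per_process:(k + 1) * per_process] for k in range(num_processes)]


# Process 0 reads num_processes * batch_size samples from the underlying dataset...
full_batch = list(range(8))
# ...then slice k would be sent to process k.
slices = dispatch(full_batch, num_processes=4)
print(slices)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```

Only process 0 touches the dataset here, which suits iterable datasets that cannot easily be sharded, at the cost of extra communication.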

The random number generator synchronization will by default synchronize:

  • the generator attribute of a given sampler (like the PyTorch RandomSampler) for PyTorch >= 1.6
  • the main random number generator in PyTorch <= 1.5.1

You can choose which random number generator(s) to synchronize with the rng_types argument of the main Accelerator. In PyTorch >= 1.6, it is recommended to rely on a local generator to avoid setting the same seed in the main random number generator in all processes.

Synchronizing the main torch (or CUDA or XLA) random number generator also affects any other random artifacts in your dataset, such as random data augmentation: all processes will get the same random numbers from the torch random modules, so they will apply the same random data augmentation if it is controlled by torch.
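The side effect on augmentation can be simulated with two "processes" whose main generators were synchronized to the same state. In this sketch, `random.Random` instances with equal seeds stand in for synchronized torch generators (an assumption for illustration), and `random_flip` is a hypothetical augmentation. Even though each process holds a different sample, both make the same random decision.

```python
import random


def random_flip(sample, rng):
    """Toy augmentation: reverse the sample with probability 0.5."""
    return sample[::-1] if rng.random() < 0.5 else sample


# Each simulated process holds a different sample...
samples = [[1, 2, 3], [4, 5, 6]]
# ...but their main generators were synchronized to the same state.
rngs = [random.Random(1234), random.Random(1234)]

augmented = [random_flip(s, rng) for s, rng in zip(samples, rngs)]
flipped = [out != s for out, s in zip(augmented, samples)]
print(flipped[0] == flipped[1])  # True: both processes made the same flip decision
```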

The randomization in your custom sampler, batch sampler, or iterable dataset should be done with a local torch.Generator object (in PyTorch >= 1.6); see the traditional RandomSampler for an example.
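A custom sampler following this advice keeps its randomness in a local generator stored on a `generator` attribute, as RandomSampler does, so it never touches the global random state. The sketch below uses Python's `random.Random` as a stand-in for `torch.Generator` (an assumption so the example runs without PyTorch); `ToyRandomSampler` is a hypothetical class.

```python
import random


class ToyRandomSampler:
    """Yields a random permutation of indices using its own local generator."""

    def __init__(self, length, seed=None):
        self.length = length
        # Local generator exposed as an attribute, so it can be synchronized
        # across processes without seeding the global random state.
        self.generator = random.Random(seed)

    def __iter__(self):
        order = list(range(self.length))
        self.generator.shuffle(order)
        return iter(order)

    def __len__(self):
        return self.length


a, b = ToyRandomSampler(6, seed=0), ToyRandomSampler(6, seed=0)
print(list(a) == list(b))  # True: same generator state, same order

# The global random state is left untouched by the sampler.
random.seed(99)
expected = random.random()
random.seed(99)
list(ToyRandomSampler(6, seed=0))
print(random.random() == expected)  # True
```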

<Note>

If you have torchdata>=0.8.0 installed and you have passed use_stateful_dataloader=True into your DataLoaderConfiguration, these classes inherit directly from StatefulDataLoader instead and maintain a state_dict.

</Note>
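The idea behind a stateful dataloader is that its progress can be saved with `state_dict()` and restored with `load_state_dict()`, so training can resume mid-epoch. The toy class below only tracks how many batches have been yielded; it is a minimal sketch, not torchdata's StatefulDataLoader, which also has to handle samplers and worker processes. `ToyResumableLoader` is hypothetical, and only the two method names mirror the ones StatefulDataLoader exposes.

```python
class ToyResumableLoader:
    """Tracks how many batches were yielded so iteration can be resumed."""

    def __init__(self, batches):
        self.batches = batches
        self.batches_yielded = 0

    def __iter__(self):
        # Resume from the position recorded in the (possibly restored) state.
        while self.batches_yielded < len(self.batches):
            batch = self.batches[self.batches_yielded]
            self.batches_yielded += 1
            yield batch
        self.batches_yielded = 0  # full pass completed; the next epoch starts fresh

    def state_dict(self):
        return {"batches_yielded": self.batches_yielded}

    def load_state_dict(self, state):
        self.batches_yielded = state["batches_yielded"]


loader = ToyResumableLoader([[0, 1], [2, 3], [4, 5]])
it = iter(loader)
next(it)                     # consume one batch, then "checkpoint"
state = loader.state_dict()  # {'batches_yielded': 1}

restored = ToyResumableLoader([[0, 1], [2, 3], [4, 5]])
restored.load_state_dict(state)
print(list(restored))  # [[2, 3], [4, 5]]
```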

For more details about the internals, see the Internals page.
