Checkpointing

When training a PyTorch model with Accelerate, you may often want to save the current state of training and resume from it later. Doing so requires saving and loading the model, optimizer, RNG states, and the GradScaler. Accelerate provides two convenience functions to achieve this quickly:

  • Use save_state() for saving everything mentioned above to a folder location
  • Use load_state() for loading everything stored from an earlier save_state() call
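At their simplest, both functions can be pointed at an explicit folder. Below is a minimal sketch, assuming the model, optimizer, and dataloader have already been created; the "my_checkpoint" path is just an illustrative name:

from accelerate import Accelerator

accelerator = Accelerator()
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

# Save the model, optimizer, RNG states, and GradScaler to an explicit folder
accelerator.save_state("my_checkpoint")

# ... later, restore everything from that same folder
accelerator.load_state("my_checkpoint")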

To further customize where and how states are saved through save_state(), the ProjectConfiguration class can be used. For example, if automatic_checkpoint_naming is enabled, each saved checkpoint will be located at Accelerator.project_dir/checkpoints/checkpoint_{checkpoint_number}.
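For instance, here is a minimal sketch of enabling automatic checkpoint naming; the total_limit value is just an illustrative choice:

from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration

# Name checkpoints automatically and keep at most 3 saved checkpoints
config = ProjectConfiguration(
    project_dir="my/save/path",
    automatic_checkpoint_naming=True,
    total_limit=3,
)
accelerator = Accelerator(project_config=config)

# Writes to my/save/path/checkpoints/checkpoint_0, then checkpoint_1, ...
accelerator.save_state()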

Note that these states are expected to come from the same training script; they should not come from two separate scripts.

By using register_for_checkpointing(), you can register custom objects to be automatically stored or loaded by the two functions above, so long as the object has state_dict and load_state_dict functionality. This could include objects such as a learning rate scheduler.
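For illustration, here is a minimal sketch of registering a custom object; the EMATracker class and its fields are purely hypothetical, the only requirement being the state_dict and load_state_dict methods:

from accelerate import Accelerator

# Hypothetical custom object: anything exposing state_dict()/load_state_dict() works
class EMATracker:
    def __init__(self):
        self.value = 0.0

    def state_dict(self):
        return {"value": self.value}

    def load_state_dict(self, state_dict):
        self.value = state_dict["value"]

accelerator = Accelerator(project_dir="my/save/path")
ema = EMATracker()

# ema is now included in every save_state()/load_state() call
accelerator.register_for_checkpointing(ema)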

Below is a brief example using checkpointing to save and reload a state during training:

from accelerate import Accelerator
import torch

accelerator = Accelerator(project_dir="my/save/path")

# my_model, my_optimizer, my_training_dataloader, my_loss_function, and num_epochs
# are assumed to be defined earlier in the script
my_scheduler = torch.optim.lr_scheduler.StepLR(my_optimizer, step_size=1, gamma=0.99)
my_model, my_optimizer, my_training_dataloader = accelerator.prepare(my_model, my_optimizer, my_training_dataloader)

# Register the LR scheduler
accelerator.register_for_checkpointing(my_scheduler)

# Save the starting state
accelerator.save_state()

device = accelerator.device
my_model.to(device)

# Perform training
for epoch in range(num_epochs):
    for batch in my_training_dataloader:
        my_optimizer.zero_grad()
        inputs, targets = batch
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = my_model(inputs)
        loss = my_loss_function(outputs, targets)
        accelerator.backward(loss)
        my_optimizer.step()
    my_scheduler.step()

# Restore the previous state
accelerator.load_state("my/save/path/checkpointing/checkpoint_0")

Restoring the state of the DataLoader

After resuming from a checkpoint, it may also be desirable to resume from a particular point in the active DataLoader if the state was saved in the middle of an epoch. You can use skip_first_batches() to do so.

from accelerate import Accelerator

accelerator = Accelerator(project_dir="my/save/path")

# train_dataloader is assumed to be defined earlier in the script
train_dataloader = accelerator.prepare(train_dataloader)
accelerator.load_state("my_state")

# Assume the checkpoint was saved 100 steps into the epoch
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, 100)

# After the first iteration, go back to `train_dataloader`

# First epoch
for batch in skipped_dataloader:
    # Do something
    pass

# Second epoch
for batch in train_dataloader:
    # Do something
    pass