Accelerate documentation

Checkpointing

You are viewing the v0.12.0 documentation. A newer version, v1.2.1, is available.

When training a PyTorch model with 🤗 Accelerate, you will often want to save the training state and resume from it later. Doing so requires saving and loading the model, the optimizer, the random number generator (RNG) states, and the GradScaler. 🤗 Accelerate provides two convenience functions for this:

  • Use save_state() for saving everything mentioned above to a folder location
  • Use load_state() for loading everything stored from an earlier save_state

These states are expected to come from the same training script; they should not come from two separate scripts.
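
To illustrate why the RNG state is part of a checkpoint, here is a minimal sketch that uses only Python's standard `random` module. It is not the Accelerate API and does not reproduce what save_state() does internally; it only shows that restoring a captured generator state makes the "resumed" run draw the same numbers as the uninterrupted one.

```python
import random

random.seed(0)

# Pretend a few training steps already consumed some random numbers.
_ = [random.random() for _ in range(3)]

# "Checkpoint": capture the generator state at this point.
rng_state = random.getstate()

# Continue the uninterrupted run and record what it draws next.
expected_next = [random.random() for _ in range(5)]

# "Resume": restore the captured state and draw again.
random.setstate(rng_state)
resumed_next = [random.random() for _ in range(5)]

# The resumed run reproduces the uninterrupted one exactly.
assert resumed_next == expected_next
```

Without restoring the state, the resumed run would shuffle data and sample dropout masks differently from the original run, which makes results harder to reproduce.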

  • Use register_for_checkpointing() to register custom objects that the two functions above should store and load automatically. Any object that implements both state_dict and load_state_dict can be registered, for example a learning rate scheduler.
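
As a rough sketch of what such a custom object might look like, the class below implements the two required methods. The class name `EarlyStoppingTracker` and its fields are hypothetical and chosen only for illustration; the example runs without Accelerate and simply shows the state round trip that registration relies on.

```python
class EarlyStoppingTracker:
    """Hypothetical helper that tracks the best validation loss seen so far."""

    def __init__(self, patience=3):
        self.patience = patience
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def update(self, loss):
        """Record a new loss; return True once patience is exhausted."""
        if loss < self.best_loss:
            self.best_loss = loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

    def state_dict(self):
        # Everything needed to resume tracking after a restart.
        return {"best_loss": self.best_loss, "bad_epochs": self.bad_epochs}

    def load_state_dict(self, state):
        self.best_loss = state["best_loss"]
        self.bad_epochs = state["bad_epochs"]


tracker = EarlyStoppingTracker(patience=2)
tracker.update(0.9)  # improvement: best_loss becomes 0.9
tracker.update(1.1)  # no improvement: bad_epochs becomes 1
saved = tracker.state_dict()

# Simulate a restart: a fresh object picks up where the old one stopped.
restored = EarlyStoppingTracker(patience=2)
restored.load_state_dict(saved)

# With an Accelerator instance available, registration would be:
# accelerator.register_for_checkpointing(tracker)
```

Once registered, the object's state would be saved and restored together with the model and optimizer whenever save_state() and load_state() are called.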

Below is a brief example using checkpointing to save and reload a state during training:

from accelerate import Accelerator
import torch

accelerator = Accelerator()

# my_model, my_optimizer, and my_training_dataloader are assumed to be defined earlier
my_scheduler = torch.optim.lr_scheduler.StepLR(my_optimizer, step_size=1, gamma=0.99)
my_model, my_optimizer, my_training_dataloader = accelerator.prepare(my_model, my_optimizer, my_training_dataloader)

# Register the LR scheduler
accelerator.register_for_checkpointing(my_scheduler)

# Save the starting state
accelerator.save_state("my/save/path")

device = accelerator.device
my_model.to(device)

# Perform training
for epoch in range(num_epochs):
    for batch in my_training_dataloader:
        my_optimizer.zero_grad()
        inputs, targets = batch
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = my_model(inputs)
        loss = my_loss_function(outputs, targets)
        accelerator.backward(loss)
        my_optimizer.step()
    my_scheduler.step()

# Restore previous state
accelerator.load_state("my/save/path")