Checkpointing
When training a PyTorch model with 🤗 Accelerate, you may often want to save the state of training so you can resume it later. Doing so requires saving and loading the model, the optimizer, the random number generator (RNG) states, and the GradScaler (used for mixed-precision training). 🤗 Accelerate provides two convenience functions to do this quickly:
- Use save_state() to save everything mentioned above to a folder location.
- Use load_state() to load everything stored by an earlier save_state() call.
To further customize where and how states are saved through save_state(), use the ProjectConfiguration class. For example, if automatic_checkpoint_naming is enabled, each saved checkpoint is placed at Accelerator.project_dir/checkpoints/checkpoint_{checkpoint_number}.
Note that the saved states are expected to come from the same training script; they should not come from two separate scripts.
- By using register_for_checkpointing(), you can register custom objects to be automatically saved and loaded by the two functions above, as long as the object has state_dict and load_state_dict methods. This could include objects such as a learning rate scheduler.
Below is a brief example using checkpointing to save and reload a state during training:
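from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy placeholders so the example runs end to end; replace them with your own objects
my_model = torch.nn.Linear(8, 1)
my_optimizer = torch.optim.SGD(my_model.parameters(), lr=0.1)
my_training_dataloader = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=16)
my_loss_function = torch.nn.MSELoss()
num_epochs = 3

# automatic_checkpoint_naming=True lets save_state() be called without a path
config = ProjectConfiguration(project_dir="my/save/path", automatic_checkpoint_naming=True)
accelerator = Accelerator(project_config=config)

my_scheduler = torch.optim.lr_scheduler.StepLR(my_optimizer, step_size=1, gamma=0.99)
my_model, my_optimizer, my_training_dataloader = accelerator.prepare(my_model, my_optimizer, my_training_dataloader)

# Register the LR scheduler so save_state()/load_state() also handle it
accelerator.register_for_checkpointing(my_scheduler)

# Save the starting state (written to my/save/path/checkpoints/checkpoint_0)
accelerator.save_state()

# Perform training; prepare() already placed the model and batches on accelerator.device
for epoch in range(num_epochs):
    for batch in my_training_dataloader:
        my_optimizer.zero_grad()
        inputs, targets = batch
        outputs = my_model(inputs)
        loss = my_loss_function(outputs, targets)
        accelerator.backward(loss)
        my_optimizer.step()
    my_scheduler.step()

# Restore the previous (starting) state
accelerator.load_state("my/save/path/checkpoints/checkpoint_0")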
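When resuming in a fresh run with automatic_checkpoint_naming enabled, you usually want the newest checkpoint rather than a hard-coded one. The helper latest_checkpoint() below is hypothetical (it is not part of 🤗 Accelerate); it assumes the checkpoints/checkpoint_{n} layout described above and compares the numeric suffixes, so checkpoint_10 is picked over checkpoint_9 even though plain string sorting would rank it lower:

```python
import os
import re
import tempfile


def latest_checkpoint(project_dir):
    """Hypothetical helper: return the highest-numbered checkpoint_{n} folder, or None."""
    checkpoints_dir = os.path.join(project_dir, "checkpoints")
    if not os.path.isdir(checkpoints_dir):
        return None
    numbered = []
    for name in os.listdir(checkpoints_dir):
        match = re.fullmatch(r"checkpoint_(\d+)", name)
        if match:
            numbered.append((int(match.group(1)), name))
    if not numbered:
        return None
    # Compare numerically so checkpoint_10 ranks after checkpoint_9
    _, newest_name = max(numbered)
    return os.path.join(checkpoints_dir, newest_name)


# Demo on a throwaway directory laid out like automatic_checkpoint_naming output
with tempfile.TemporaryDirectory() as project_dir:
    for n in (0, 9, 10):
        os.makedirs(os.path.join(project_dir, "checkpoints", f"checkpoint_{n}"))
    newest = os.path.basename(latest_checkpoint(project_dir))

print(newest)  # checkpoint_10

# In a real script you would then call, e.g.:
# accelerator.load_state(latest_checkpoint("my/save/path"))
```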