When training a PyTorch model with 🤗 Accelerate, you will often want to save the training state and resume from it later. Doing so requires saving and loading the model, optimizer, random number generators, and the GradScaler. 🤗 Accelerate provides two convenience functions for this:
- Use save_state() for saving everything mentioned above to a folder location
- Use load_state() for loading everything stored from an earlier save_state() call
To further customize where and how states are saved through save_state(), the ProjectConfiguration class can be used. For example, if automatic_checkpoint_naming is enabled, each saved checkpoint is placed in its own automatically numbered folder inside the project directory.
Note that these states are expected to come from the same training script; they should not come from two separate scripts.
- By using register_for_checkpointing(), you can register custom objects to be automatically stored or loaded by the two functions above, so long as the object has both state_dict and load_state_dict functionality. This could include objects such as a learning rate scheduler.
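Any object exposing these two methods qualifies. As a minimal sketch, the hypothetical `StepCounter` class below (not part of 🤗 Accelerate) tracks how many batches have been processed and round-trips that count through a plain dictionary. The registration call at the end is shown only as a comment and assumes an existing `accelerator`.

```python
class StepCounter:
    """Hypothetical helper: counts completed batches so the count can be checkpointed."""

    def __init__(self):
        self.epoch = 0
        self.step_in_epoch = 0

    def step(self):
        # Call once per processed batch.
        self.step_in_epoch += 1

    def next_epoch(self):
        # Call at the end of each epoch.
        self.epoch += 1
        self.step_in_epoch = 0

    def state_dict(self):
        # Return everything needed to rebuild this object.
        return {"epoch": self.epoch, "step_in_epoch": self.step_in_epoch}

    def load_state_dict(self, state):
        # Restore the values produced by state_dict().
        self.epoch = state["epoch"]
        self.step_in_epoch = state["step_in_epoch"]


counter = StepCounter()
for _ in range(3):
    counter.step()

# Round trip: a fresh object picks up the saved values.
restored = StepCounter()
restored.load_state_dict(counter.state_dict())
print(restored.state_dict())  # {'epoch': 0, 'step_in_epoch': 3}

# With an existing Accelerator instance, registration would be:
# accelerator.register_for_checkpointing(counter)
```

Keeping the state in a plain dictionary of built-in types keeps the object easy to serialize and inspect.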
Below is a brief example using checkpointing to save and reload a state during training:
    from accelerate import Accelerator
    from accelerate.utils import ProjectConfiguration
    import torch

    # my_model, my_optimizer, my_training_dataloader, my_loss_function and
    # num_epochs are assumed to be defined earlier in the training script.

    # Automatic naming lets save_state() be called without an explicit folder;
    # checkpoints are then written to <project_dir>/checkpoints/checkpoint_<n>.
    project_config = ProjectConfiguration(project_dir="my/save/path", automatic_checkpoint_naming=True)
    accelerator = Accelerator(project_config=project_config)

    my_scheduler = torch.optim.lr_scheduler.StepLR(my_optimizer, step_size=1, gamma=0.99)
    my_model, my_optimizer, my_training_dataloader = accelerator.prepare(my_model, my_optimizer, my_training_dataloader)

    # Register the LR scheduler
    accelerator.register_for_checkpointing(my_scheduler)

    # Save the starting state
    accelerator.save_state()

    device = accelerator.device
    my_model.to(device)

    # Perform training
    for epoch in range(num_epochs):
        for batch in my_training_dataloader:
            my_optimizer.zero_grad()
            inputs, targets = batch
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = my_model(inputs)
            loss = my_loss_function(outputs, targets)
            accelerator.backward(loss)
            my_optimizer.step()
        my_scheduler.step()

    # Restore the previous state
    accelerator.load_state("my/save/path/checkpoints/checkpoint_0")
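The example above restores `checkpoint_0` by its explicit path. When several numbered checkpoints exist, one way to pick the most recent is to scan the parent folder for the highest number. The sketch below uses only the standard library; `latest_checkpoint` is a hypothetical helper, not an 🤗 Accelerate function, and it assumes folders follow the `checkpoint_<number>` pattern seen in the path above.

```python
import re
import tempfile
from pathlib import Path
from typing import Optional


def latest_checkpoint(checkpoint_root: str) -> Optional[Path]:
    """Return the highest-numbered `checkpoint_<n>` folder, or None if there is none."""
    root = Path(checkpoint_root)
    if not root.is_dir():
        return None
    pattern = re.compile(r"checkpoint_(\d+)")
    numbered = []
    for entry in root.iterdir():
        match = pattern.fullmatch(entry.name)
        if entry.is_dir() and match:
            numbered.append((int(match.group(1)), entry))
    if not numbered:
        return None
    # Compare by the integer suffix so checkpoint_10 ranks above checkpoint_9.
    return max(numbered, key=lambda pair: pair[0])[1]


# Demonstration on a throwaway directory.
with tempfile.TemporaryDirectory() as tmp:
    for n in (0, 9, 10):
        (Path(tmp) / f"checkpoint_{n}").mkdir()
    newest_name = latest_checkpoint(tmp).name

print(newest_name)  # checkpoint_10
```

If a folder is found, its path could then be passed to load_state(), for example `accelerator.load_state(str(path))`.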
After resuming from a checkpoint, it may also be desirable to resume from a particular point in the active DataLoader if the state was saved during the middle of an epoch. You can use skip_first_batches() to do so.
    from accelerate import Accelerator

    accelerator = Accelerator(project_dir="my/save/path")

    train_dataloader = accelerator.prepare(train_dataloader)
    accelerator.load_state("my_state")

    # Assume the checkpoint was saved 100 steps into the epoch
    skipped_dataloader = accelerator.skip_first_batches(train_dataloader, 100)

    # After the first iteration, go back to `train_dataloader`

    # First epoch
    for batch in skipped_dataloader:
        # Do something
        pass

    # Second epoch
    for batch in train_dataloader:
        # Do something
        pass
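To make the effect concrete without a real dataloader, the sketch below reproduces the same pattern on a plain list of batches with `itertools.islice`. It only illustrates the skip-then-resume logic and is not how skip_first_batches() is implemented; `batches`, `steps_done`, and `seen` are made-up names.

```python
from itertools import islice

batches = [f"batch_{i}" for i in range(5)]  # stand-in for one epoch of data
steps_done = 2  # assume the checkpoint was saved 2 steps into the epoch

seen = []

# First (resumed) epoch: skip the batches that were already processed.
for batch in islice(batches, steps_done, None):
    seen.append(batch)

# Later epochs: iterate over the full set of batches again.
for batch in batches:
    seen.append(batch)

print(seen[:3])  # ['batch_2', 'batch_3', 'batch_4']
print(len(seen))  # 8
```

The skipped iterator is used for the resumed epoch only; reusing it for later epochs would silently drop the first batches every time.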