When training a PyTorch model with 🤗 Accelerate, you may often want to save the current state of training and resume it later. Doing so requires saving and loading the model, optimizer, RNG generators, and the GradScaler. Inside 🤗 Accelerate are two convenience functions to achieve this quickly:
- Use `save_state()` for saving everything mentioned above to a folder location
- Use `load_state()` for loading everything stored from an earlier `save_state()` call
Note that these states are expected to come from the same training script; they should not come from two separate scripts.
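As a minimal sketch, a save/restore round trip with these two functions looks like the following (the checkpoint folder name here is an arbitrary example):

```python
from accelerate import Accelerator

accelerator = Accelerator()
# ... create and prepare your model, optimizer, and dataloaders here ...

# Write all tracked states (model, optimizer, RNG generators, GradScaler)
# into the given folder
accelerator.save_state("checkpoints/step_0")

# Later in the same training script: restore everything from that folder
accelerator.load_state("checkpoints/step_0")
```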
- By using `register_for_checkpointing()`, you can register custom objects to be automatically stored or loaded by the two prior functions, so long as the object has both `state_dict` and `load_state_dict` functionality. This could include objects such as a learning rate scheduler.
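For illustration, here is a minimal sketch of such a custom object; the `StepCounter` class is hypothetical and exists only to show the required `state_dict`/`load_state_dict` interface:

```python
from accelerate import Accelerator


class StepCounter:
    """Hypothetical custom object tracking the global training step."""

    def __init__(self):
        self.steps = 0

    def state_dict(self):
        # Return everything needed to restore this object later
        return {"steps": self.steps}

    def load_state_dict(self, state_dict):
        # Restore this object from a previously saved state dict
        self.steps = state_dict["steps"]


accelerator = Accelerator()
counter = StepCounter()

# The counter will now be saved and restored by save_state()/load_state()
accelerator.register_for_checkpointing(counter)
```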
Below is a brief example using checkpointing to save and reload a state during training:
```python
from accelerate import Accelerator
import torch

accelerator = Accelerator()

# `my_model`, `my_optimizer`, `my_training_dataloader`, `my_loss_function`,
# and `num_epochs` are assumed to be defined earlier in the script
my_scheduler = torch.optim.lr_scheduler.StepLR(my_optimizer, step_size=1, gamma=0.99)
my_model, my_optimizer, my_training_dataloader = accelerator.prepare(
    my_model, my_optimizer, my_training_dataloader
)

# Register the LR scheduler so it is saved and restored alongside the rest
accelerator.register_for_checkpointing(my_scheduler)

# Save the starting state
accelerator.save_state("my/save/path")

device = accelerator.device
my_model.to(device)

# Perform training
for epoch in range(num_epochs):
    for batch in my_training_dataloader:
        my_optimizer.zero_grad()
        inputs, targets = batch
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = my_model(inputs)
        loss = my_loss_function(outputs, targets)
        accelerator.backward(loss)
        my_optimizer.step()
        my_scheduler.step()

# Restore the previous state
accelerator.load_state("my/save/path")
```
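After `load_state()`, the model weights, optimizer state, RNG generators, and the registered scheduler are all restored together, so training can resume from the saved point.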