from accelerate import Accelerator
# NOTE(review): `dataloader`, `model`, `optimizer`, `scheduler` and
# `loss_function` are not defined in this snippet; they are assumed to be
# created earlier (the plain, un-wrapped training objects) — confirm.
accelerator = Accelerator()

# Hand the training objects to Accelerate. `prepare` returns the wrapped
# objects in the same order they were passed in, so the unpacking targets
# must line up with the arguments one-to-one.
dataloader, model, optimizer, scheduler = accelerator.prepare(
    dataloader, model, optimizer, scheduler
)

for batch in dataloader:
    optimizer.zero_grad()
    inputs, targets = batch
    outputs = model(inputs)
    loss = loss_function(outputs, targets)
    # Go through the accelerator rather than calling loss.backward() directly
    # so Accelerate can apply its setup-specific handling of the backward pass
    # (e.g. mixed-precision loss scaling).
    accelerator.backward(loss)
    optimizer.step()
    # The scheduler is stepped once per batch here, not once per epoch.
    scheduler.step()

# After training, checkpoint the state of the prepared objects to
# "checkpoint_dir". The return value is intentionally not used.
accelerator.save_state("checkpoint_dir")