import os

import torch
from loguru import logger
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.distributed import DistributedDataParallel


class CheckPoint:
    """Saves model/optimizer/scheduler state to `<dir>/<name>_latest.pth`."""

    def __init__(self, dir=None, name="tmp"):
        self.name = name
        self.dir = dir
        os.makedirs(self.dir, exist_ok=True)

    def __call__(self, model, optimizer, lr_scheduler, n):
        assert model is not None
        # Unwrap (Distributed)DataParallel so the saved keys are not prefixed with "module."
        if isinstance(model, (DataParallel, DistributedDataParallel)):
            model = model.module
        states = {
            "model": model.state_dict(),
            "n": n,
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
        }
        # os.path.join avoids missing-separator bugs from plain string concatenation.
        path = os.path.join(self.dir, f"{self.name}_latest.pth")
        torch.save(states, path)
        logger.info(f"Saved states {list(states.keys())} at step {n}")
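

# Usage sketch (not part of the original source): the toy model, optimizer,
# and scheduler below are hypothetical stand-ins, only meant to illustrate
# how CheckPoint would typically be invoked inside a training loop.
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Linear(10, 2)  # stand-in for a real network
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

    ckpt = CheckPoint(dir="./checkpoints", name="run0")
    for step in range(3):
        optimizer.step()
        lr_scheduler.step()
        # Overwrites ./checkpoints/run0_latest.pth on every call.
        ckpt(model, optimizer, lr_scheduler, n=step)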