import os
import torch
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.distributed import DistributedDataParallel
from loguru import logger
import gc

import roma


def _unwrap(model):
    """Return the wrapped module of a (Distributed)DataParallel, else ``model``."""
    if isinstance(model, (DataParallel, DistributedDataParallel)):
        return model.module
    return model


class CheckPoint:
    """Save and restore a single rolling "latest" training checkpoint.

    The checkpoint file is ``<dir>/<name>_latest.pth``. It holds the model,
    optimizer and LR-scheduler state dicts together with the step counter
    ``n``. Only the process with ``roma.RANK == 0`` writes or reads it.
    """

    def __init__(self, dir=None, name="tmp"):
        """
        Args:
            dir: Directory in which the checkpoint is stored; created if it
                does not exist. ``None`` means the current working directory
                (previously ``None`` crashed inside ``os.makedirs``).
            name: Prefix of the checkpoint file name.
        """
        self.name = name
        self.dir = "." if dir is None else dir
        os.makedirs(self.dir, exist_ok=True)

    def _checkpoint_path(self):
        """Path of the checkpoint file inside ``self.dir``."""
        return os.path.join(self.dir, self.name + "_latest.pth")

    def _legacy_checkpoint_path(self):
        """Path produced by the old plain string concatenation.

        When ``dir`` lacked a trailing separator, this pointed *next to* the
        directory rather than into it. It is only consulted by ``load`` so
        checkpoints written by older versions are still found.
        """
        return self.dir + self.name + "_latest.pth"

    def save(
        self,
        model,
        optimizer,
        lr_scheduler,
        n,
    ):
        """Write model/optimizer/scheduler state and step ``n`` (rank 0 only).

        The state is first written to a temporary file in the same directory
        and then moved into place with ``os.replace``. An interrupted save
        therefore leaves the previous checkpoint intact.

        Args:
            model: Module to save; a (Distributed)DataParallel wrapper is
                unwrapped so the stored keys have no ``module.`` prefix.
            optimizer: Optimizer whose ``state_dict()`` is stored.
            lr_scheduler: LR scheduler whose ``state_dict()`` is stored.
            n: Step counter stored under key ``"n"``.
        """
        if roma.RANK != 0:
            return
        assert model is not None
        model = _unwrap(model)
        states = {
            "model": model.state_dict(),
            "n": n,
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
        }
        path = self._checkpoint_path()
        tmp_path = path + ".tmp"
        torch.save(states, tmp_path)
        os.replace(tmp_path, path)
        logger.info(f"Saved states {list(states.keys())}, at step {n}")

    def load(
        self,
        model,
        optimizer,
        lr_scheduler,
        n,
    ):
        """Restore state from the checkpoint, if one exists (rank 0 only).

        Each of ``"model"``, ``"n"``, ``"optimizer"`` and ``"lr_scheduler"``
        is restored only if present in the file. A failure to load the
        optimizer state is logged and otherwise ignored. A stored step of
        ``0``/falsy keeps the caller's ``n``.

        NOTE(review): ranks other than 0 return their inputs unchanged. This
        presumably relies on DDP broadcasting rank-0 parameters at wrap time;
        optimizer state and ``n`` are not synchronized here. Confirm against
        the caller.

        Returns:
            ``(model, optimizer, lr_scheduler, n)``, where ``model`` is the same
            object that was passed in (wrapper included).
        """
        if roma.RANK != 0:
            return model, optimizer, lr_scheduler, n
        path = self._checkpoint_path()
        if not os.path.exists(path):
            legacy = self._legacy_checkpoint_path()
            if legacy == path or not os.path.exists(legacy):
                return model, optimizer, lr_scheduler, n
            path = legacy
        # Map to CPU: load_state_dict copies tensors onto each parameter's
        # device anyway. This avoids a transient extra copy on the GPU and
        # lets GPU-saved checkpoints load on hosts with a different device
        # layout.
        states = torch.load(path, map_location="cpu")
        if "model" in states:
            # Mirror save(): the stored keys belong to the unwrapped module.
            _unwrap(model).load_state_dict(states["model"])
        if "n" in states:
            n = states["n"] or n
        if "optimizer" in states:
            try:
                optimizer.load_state_dict(states["optimizer"])
            except Exception as e:
                # Best-effort: e.g. param groups changed between runs.
                logger.warning(f"Failed to load states for optimizer, with error {e}")
        if "lr_scheduler" in states:
            lr_scheduler.load_state_dict(states["lr_scheduler"])
        logger.info(f"Loaded states {list(states.keys())}, at step {n}")
        del states
        gc.collect()
        torch.cuda.empty_cache()
        return model, optimizer, lr_scheduler, n