Spaces:
Running
Running
import os | |
import torch | |
from torch.nn.parallel.data_parallel import DataParallel | |
from torch.nn.parallel.distributed import DistributedDataParallel | |
from loguru import logger | |
import gc | |
import roma | |
class CheckPoint:
    """Save and restore model/optimizer/scheduler state in one "latest" file.

    The checkpoint path is built by plain string concatenation,
    ``dir + name + "_latest.pth"``, so ``dir`` should end with a path
    separator (e.g. ``"checkpoints/"``). Otherwise the file is written next
    to the directory rather than inside it.

    Both ``save`` and ``load`` only act when ``roma.RANK == 0``.
    """

    def __init__(self, dir=None, name="tmp"):
        """Remember the checkpoint location and create ``dir`` if missing.

        Args:
            dir: Directory prefix for the checkpoint file. Required; the
                ``None`` default is kept only for call compatibility.
            name: File-name stem; the file is ``<name>_latest.pth``.

        Raises:
            TypeError: If ``dir`` is None.
        """
        if dir is None:
            # os.makedirs(None) would raise an opaque TypeError; say why instead.
            raise TypeError("CheckPoint requires a checkpoint directory, got dir=None")
        self.name = name
        self.dir = dir
        os.makedirs(self.dir, exist_ok=True)

    def _latest_path(self):
        """Return the path of the "latest" checkpoint file."""
        # Concatenation (not os.path.join) is deliberate: it keeps the path
        # identical to what earlier versions wrote, so old checkpoints resume.
        return self.dir + self.name + "_latest.pth"

    @staticmethod
    def _unwrap(model):
        """Return the wrapped module of a (Distributed)DataParallel, else ``model``."""
        if isinstance(model, (DataParallel, DistributedDataParallel)):
            return model.module
        return model

    def save(
        self,
        model,
        optimizer,
        lr_scheduler,
        n,
    ):
        """Write model, optimizer and scheduler state plus step ``n`` to disk.

        Only runs on rank 0. Wrapper modules are unwrapped first, so the stored
        model state dict has no ``module.`` key prefix.

        Args:
            model: Model (optionally wrapped in DataParallel/DDP). Must not be None.
            optimizer: Object exposing ``state_dict()``.
            lr_scheduler: Object exposing ``state_dict()``.
            n: Training step recorded in the checkpoint.
        """
        if roma.RANK != 0:
            return
        assert model is not None
        states = {
            "model": self._unwrap(model).state_dict(),
            "n": n,
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
        }
        torch.save(states, self._latest_path())
        logger.info(f"Saved states {list(states.keys())}, at step {n}")

    def load(
        self,
        model,
        optimizer,
        lr_scheduler,
        n,
    ):
        """Restore state from the "latest" checkpoint if it exists (rank 0 only).

        Missing entries in the checkpoint are skipped. A failure to load the
        optimizer state is logged and tolerated (best effort); other load
        errors propagate.

        Args:
            model: Model to load into; DataParallel/DDP wrappers are unwrapped
                so the un-prefixed keys written by ``save`` match.
            optimizer: Object exposing ``load_state_dict()``.
            lr_scheduler: Object exposing ``load_state_dict()``.
            n: Fallback step used when the checkpoint stores no step.

        Returns:
            ``(model, optimizer, lr_scheduler, n)`` with the same objects that
            were passed in and the restored (or fallback) step ``n``.
        """
        path = self._latest_path()
        # Existence is checked first. Non-zero ranks return their inputs unchanged.
        if not os.path.exists(path) or roma.RANK != 0:
            return model, optimizer, lr_scheduler, n

        # Load onto CPU. load_state_dict copies/casts into the live tensors'
        # devices anyway, and this avoids a GPU memory spike or a failure on
        # hosts without the device the checkpoint was saved from.
        states = torch.load(path, map_location="cpu")
        if "model" in states:
            self._unwrap(model).load_state_dict(states["model"])
        # A saved step of 0 is a real step; only fall back when it is absent/None.
        if states.get("n") is not None:
            n = states["n"]
        if "optimizer" in states:
            try:
                optimizer.load_state_dict(states["optimizer"])
            except Exception as e:
                # Best effort: e.g. param groups changed since the checkpoint.
                logger.warning(f"Failed to load states for optimizer, with error {e}")
        if "lr_scheduler" in states:
            lr_scheduler.load_state_dict(states["lr_scheduler"])
        logger.info(f"Loaded states {list(states.keys())}, at step {n}")

        # Drop the loaded dict promptly and release cached allocator memory.
        del states
        gc.collect()
        torch.cuda.empty_cache()
        return model, optimizer, lr_scheduler, n