import os
import torch
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.distributed import DistributedDataParallel
from loguru import logger


class CheckPoint:
    """Persist the latest model, optimizer and LR-scheduler states to disk."""

    def __init__(self, dir=None, name="tmp"):
        self.name = name
        self.dir = dir
        os.makedirs(self.dir, exist_ok=True)

    def __call__(
        self,
        model,
        optimizer,
        lr_scheduler,
        n,
    ):
        assert model is not None
        # Unwrap (Distributed)DataParallel so the underlying module's weights are saved.
        if isinstance(model, (DataParallel, DistributedDataParallel)):
            model = model.module
        states = {
            "model": model.state_dict(),
            "n": n,
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
        }
        # Join with os.path.join so the file lands inside self.dir even when
        # the directory path does not end with a separator.
        torch.save(states, os.path.join(self.dir, f"{self.name}_latest.pth"))
        logger.info(f"Saved states {list(states.keys())}, at step {n}")