Spaces:
Running
on
T4
Running
on
T4
import os | |
import torch | |
import torch.nn as nn | |
from mono.utils.comm import main_process | |
import copy | |
import inspect | |
import logging | |
import glob | |
def load_ckpt(load_path, model, optimizer=None, scheduler=None, strict_match=True, loss_scaler=None):
    """
    Load a checkpoint for resuming training or finetuning.

    Args:
        load_path: Path of a checkpoint file readable by ``torch.load``. It
            must contain a ``'model_state_dict'`` entry. ``'optimizer'`` and
            ``'scheduler'`` entries are required when the matching argument
            is given; a ``'scaler'`` entry is optional.
        model: Model receiving the weights. They are loaded into
            ``model.module`` first and into ``model`` itself if that fails.
        optimizer: Optional optimizer whose state is restored.
        scheduler: Optional lr scheduler whose state is restored.
        strict_match: Forwarded as ``strict`` to ``load_state_dict``.
        loss_scaler: Optional loss scaler; its state is restored only when
            the checkpoint holds a ``'scaler'`` entry.

    Returns:
        The tuple ``(model, optimizer, scheduler, loss_scaler)``.

    Raises:
        RuntimeError: On the main process, if ``load_path`` is not a file.
    """
    logger = logging.getLogger()

    if not os.path.isfile(load_path):
        # NOTE(review): only the main process raises here; every other
        # process returns its arguments untouched. Confirm this is intended.
        if main_process():
            raise RuntimeError(f"No weight found at '{load_path}'")
        return model, optimizer, scheduler, loss_scaler

    if main_process():
        logger.info(f"Loading weight '{load_path}'")

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(load_path, map_location="cpu")
    ckpt_state_dict = checkpoint['model_state_dict']
    try:
        model.module.load_state_dict(ckpt_state_dict, strict=strict_match)
    except Exception:
        # Fallback when the model has no ``.module`` attribute or when the
        # inner module rejects the state dict. ``Exception`` (rather than a
        # bare ``except``) lets KeyboardInterrupt/SystemExit propagate.
        model.load_state_dict(ckpt_state_dict, strict=strict_match)

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler'])
    if loss_scaler is not None and 'scaler' in checkpoint:
        # The scaler state belongs to the loss scaler, not to the scheduler.
        loss_scaler.load_state_dict(checkpoint['scaler'])

    # Drop the references to the loaded checkpoint data before returning.
    del ckpt_state_dict
    del checkpoint

    if main_process():
        logger.info(f"Successfully loaded weight: '{load_path}'")
        if scheduler is not None and optimizer is not None:
            logger.info(f"Resume training from: '{load_path}'")
    return model, optimizer, scheduler, loss_scaler
def save_ckpt(cfg, model, optimizer, scheduler, curr_iter=0, curr_epoch=None, loss_scaler=None):
    """
    Save the model, optimizer, lr scheduler and optional loss scaler state.

    The checkpoint is written to ``<cfg.work_dir>/ckpt/step<curr_iter>.pth``
    with ``curr_iter`` zero-padded to 8 digits.

    Args:
        cfg: Config object. Reads ``cfg.runner.type``, ``cfg.work_dir`` and
            ``cfg.runner.max_iters`` (or ``cfg.runner.max_epochs`` when
            ``'max_iters'`` is not in ``cfg.runner``).
        model: Model to save. If it has a ``module`` attribute, the state of
            ``model.module`` is saved; otherwise the state of ``model``.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        scheduler: Lr scheduler whose ``state_dict()`` is saved.
        curr_iter: Step number used in the checkpoint file name.
        curr_epoch: Unused; kept for interface compatibility.
        loss_scaler: Optional loss scaler; when given, its ``state_dict()``
            is stored under the ``'scaler'`` key.

    Raises:
        TypeError: If ``cfg.runner.type`` names neither an
            ``IterBasedRunner`` nor an ``EpochBasedRunner``.
    """
    logger = logging.getLogger()

    runner_type = cfg.runner.type
    if 'IterBasedRunner' not in runner_type and 'EpochBasedRunner' not in runner_type:
        raise TypeError(f'{cfg.runner.type} is not supported')

    # Accept both wrapped models (exposing ``.module``) and bare models,
    # matching what load_ckpt is able to load.
    net = getattr(model, 'module', model)

    ckpt = dict(
        model_state_dict=net.state_dict(),
        optimizer=optimizer.state_dict(),
        max_iter=cfg.runner.max_iters if 'max_iters' in cfg.runner
        else cfg.runner.max_epochs,
        scheduler=scheduler.state_dict(),
    )
    if loss_scaler is not None:
        ckpt['scaler'] = loss_scaler.state_dict()

    ckpt_dir = os.path.join(cfg.work_dir, 'ckpt')
    os.makedirs(ckpt_dir, exist_ok=True)
    save_name = os.path.join(ckpt_dir, 'step%08d.pth' % curr_iter)

    # Listed before saving, so a newly created file is not counted below.
    saved_ckpts = glob.glob(os.path.join(ckpt_dir, 'step*.pth'))
    torch.save(ckpt, save_name)

    # Once more than 20 earlier checkpoints exist, delete the oldest one
    # (zero-padded step numbers make lexicographic order chronological).
    if len(saved_ckpts) > 20:
        os.remove(min(saved_ckpts))
    logger.info(f'Save model: {save_name}')