Metric3D/mono/utils/running.py
import glob
import logging
import os

import torch

from mono.utils.comm import main_process


def load_ckpt(load_path, model, optimizer=None, scheduler=None, strict_match=True, loss_scaler=None):
    """
    Load a checkpoint for resuming training or finetuning.
    """
    logger = logging.getLogger()
    if os.path.isfile(load_path):
        if main_process():
            logger.info(f"Loading weight '{load_path}'")
        checkpoint = torch.load(load_path, map_location="cpu")
        ckpt_state_dict = checkpoint['model_state_dict']
        try:
            # DistributedDataParallel wraps the network in a `.module` attribute.
            model.module.load_state_dict(ckpt_state_dict, strict=strict_match)
        except AttributeError:
            # Plain (non-wrapped) model.
            model.load_state_dict(ckpt_state_dict, strict=strict_match)

        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer'])
        if scheduler is not None:
            scheduler.load_state_dict(checkpoint['scheduler'])
        if loss_scaler is not None and 'scaler' in checkpoint:
            loss_scaler.load_state_dict(checkpoint['scaler'])

        del ckpt_state_dict
        del checkpoint
        if main_process():
            logger.info(f"Successfully loaded weight: '{load_path}'")
            if scheduler is not None and optimizer is not None:
                logger.info(f"Resume training from: '{load_path}'")
    else:
        # Only the main process raises; other ranks return unchanged.
        if main_process():
            raise RuntimeError(f"No weight found at '{load_path}'")
    return model, optimizer, scheduler, loss_scaler
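
# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of how load_ckpt might be called to resume an
# interrupted run. The optimizer/scheduler choices and the checkpoint
# path are hypothetical placeholders; the real entry point wires these
# up from the training config.
def _example_resume(cfg, model):
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    resume_path = os.path.join(cfg.work_dir, 'ckpt', 'step00001000.pth')  # hypothetical
    # Passing optimizer/scheduler restores the full training state,
    # not just the model weights.
    model, optimizer, scheduler, _ = load_ckpt(
        resume_path, model, optimizer=optimizer, scheduler=scheduler
    )
    return model, optimizer, scheduler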

def save_ckpt(cfg, model, optimizer, scheduler, curr_iter=0, curr_epoch=None, loss_scaler=None):
    """
    Save the model, optimizer, and lr scheduler.
    """
    logger = logging.getLogger()

    if 'IterBasedRunner' in cfg.runner.type:
        max_iters = cfg.runner.max_iters
    elif 'EpochBasedRunner' in cfg.runner.type:
        max_iters = cfg.runner.max_epochs
    else:
        raise TypeError(f'{cfg.runner.type} is not supported')

    ckpt = dict(
        model_state_dict=model.module.state_dict(),
        optimizer=optimizer.state_dict(),
        max_iter=max_iters,
        scheduler=scheduler.state_dict(),
    )

    if loss_scaler is not None:
        ckpt.update(dict(scaler=loss_scaler.state_dict()))

    ckpt_dir = os.path.join(cfg.work_dir, 'ckpt')
    os.makedirs(ckpt_dir, exist_ok=True)

    save_name = os.path.join(ckpt_dir, 'step%08d.pth' % curr_iter)
    saved_ckpts = glob.glob(ckpt_dir + '/step*.pth')
    torch.save(ckpt, save_name)

    # Keep only the most recent checkpoints: once more than 20 exist,
    # delete the oldest. File names are zero-padded, so a lexicographic
    # sort orders them by step.
    if len(saved_ckpts) > 20:
        saved_ckpts.sort()
        os.remove(saved_ckpts.pop(0))

    logger.info(f'Save model: {save_name}')
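
# --- Usage sketch (illustrative, not part of the original file) ---
# How save_ckpt might fit into an iteration-based training loop: save
# every `save_interval` steps. The loop body is a stand-in for the real
# runner logic, and `save_interval` is a hypothetical parameter.
def _example_training_loop(cfg, model, optimizer, scheduler, data_loader, save_interval=1000):
    # save_ckpt reads model.module.state_dict(), so `model` is assumed
    # to be wrapped in DistributedDataParallel.
    curr_iter = 0
    for batch in data_loader:
        # ... forward pass, loss.backward(), optimizer.step(),
        # scheduler.step() would go here ...
        curr_iter += 1
        if curr_iter % save_interval == 0:
            save_ckpt(cfg, model, optimizer, scheduler, curr_iter=curr_iter)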