zhigangjiang's picture
no message
88b0dcb
raw history blame
No virus
5.88 kB
"""
@Date: 2021/07/17
@description:
"""
import os
import torch
import torch.nn as nn
import datetime
class BaseModule(nn.Module):
def __init__(self, ckpt_dir=None):
super().__init__()
self.ckpt_dir = ckpt_dir
if ckpt_dir:
if not os.path.exists(ckpt_dir):
os.makedirs(ckpt_dir)
else:
self.model_lst = [x for x in sorted(os.listdir(self.ckpt_dir)) if x.endswith('.pkl')]
self.last_model_path = None
self.best_model_path = None
self.best_accuracy = -float('inf')
self.acc_d = {}
def show_parameter_number(self, logger):
total = sum(p.numel() for p in self.parameters())
trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
logger.info('{} parameter total:{:,}, trainable:{:,}'.format(self._get_name(), total, trainable))
def load(self, device, logger, optimizer=None, best=False):
if len(self.model_lst) == 0:
logger.info('*'*50)
logger.info("Empty model folder! Using initial weights")
logger.info('*'*50)
return 0
last_model_lst = list(filter(lambda n: '_last_' in n, self.model_lst))
best_model_lst = list(filter(lambda n: '_best_' in n, self.model_lst))
if len(last_model_lst) == 0 and len(best_model_lst) == 0:
logger.info('*'*50)
ckpt_path = os.path.join(self.ckpt_dir, self.model_lst[0])
logger.info(f"Load: {ckpt_path}")
checkpoint = torch.load(ckpt_path, map_location=torch.device(device))
self.load_state_dict(checkpoint, strict=False)
logger.info('*'*50)
return 0
checkpoint = None
if len(last_model_lst) > 0:
self.last_model_path = os.path.join(self.ckpt_dir, last_model_lst[-1])
checkpoint = torch.load(self.last_model_path, map_location=torch.device(device))
self.best_accuracy = checkpoint['accuracy']
self.acc_d = checkpoint['acc_d']
if len(best_model_lst) > 0:
self.best_model_path = os.path.join(self.ckpt_dir, best_model_lst[-1])
best_checkpoint = torch.load(self.best_model_path, map_location=torch.device(device))
self.best_accuracy = best_checkpoint['accuracy']
self.acc_d = best_checkpoint['acc_d']
if best:
checkpoint = best_checkpoint
for k in self.acc_d:
if isinstance(self.acc_d[k], float):
self.acc_d[k] = {
'acc': self.acc_d[k],
'epoch': checkpoint['epoch']
}
if checkpoint is None:
logger.error("Invalid checkpoint")
return
self.load_state_dict(checkpoint['net'], strict=False)
if optimizer and not best: # best的时候使用新的优化器比如从adam->sgd
logger.info('Load optimizer')
optimizer.load_state_dict(checkpoint['optimizer'])
for state in optimizer.state.values():
for k, v in state.items():
if torch.is_tensor(v):
state[k] = v.to(device)
logger.info('*'*50)
if best:
logger.info(f"Lode best: {self.best_model_path}")
else:
logger.info(f"Lode last: {self.last_model_path}")
logger.info(f"Best accuracy: {self.best_accuracy}")
logger.info(f"Last epoch: {checkpoint['epoch'] + 1}")
logger.info('*'*50)
return checkpoint['epoch'] + 1
def update_acc(self, acc_d, epoch, logger):
logger.info("-" * 100)
for k in acc_d:
if k not in self.acc_d.keys() or acc_d[k] > self.acc_d[k]['acc']:
self.acc_d[k] = {
'acc': acc_d[k],
'epoch': epoch
}
logger.info(f"Update ACC: {k} {self.acc_d[k]['acc']:.4f}({self.acc_d[k]['epoch']}-{epoch})")
logger.info("-" * 100)
def save(self, optim, epoch, accuracy, logger, replace=True, acc_d=None, config=None):
"""
:param config:
:param optim:
:param epoch:
:param accuracy:
:param logger:
:param replace:
:param acc_d: 其他评估数据,visible_2/3d, full_2/3d, rmse...
:return:
"""
if acc_d:
self.update_acc(acc_d, epoch, logger)
name = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S_last_{:.4f}_{}'.format(accuracy, epoch))
name = f"model_{name}.pkl"
checkpoint = {
'net': self.state_dict(),
'optimizer': optim.state_dict(),
'epoch': epoch,
'accuracy': accuracy,
'acc_d': acc_d
}
# FIXME:: delete always true
if (True or config.MODEL.SAVE_LAST) and epoch % config.TRAIN.SAVE_FREQ == 0:
if replace and self.last_model_path and os.path.exists(self.last_model_path):
os.remove(self.last_model_path)
self.last_model_path = os.path.join(self.ckpt_dir, name)
torch.save(checkpoint, self.last_model_path)
logger.info(f"Saved last model: {self.last_model_path}")
if accuracy > self.best_accuracy:
self.best_accuracy = accuracy
# FIXME:: delete always true
if True or config.MODEL.SAVE_BEST:
if replace and self.best_model_path and os.path.exists(self.best_model_path):
os.remove(self.best_model_path)
self.best_model_path = os.path.join(self.ckpt_dir, name.replace('last', 'best'))
torch.save(checkpoint, self.best_model_path)
logger.info("#" * 100)
logger.info(f"Saved best model: {self.best_model_path}")
logger.info("#" * 100)