"""Checkpoint-loading helpers: key checking, 'module.' prefix stripping,
pretrained-weight loading, and full training-state restoration."""
import logging

import torch

logger = logging.getLogger('global')
def check_keys(model, pretrained_state_dict):
    """Compare checkpoint keys with the model's keys and log any mismatch."""
    ckpt_keys = set(pretrained_state_dict.keys())
    model_keys = set(model.state_dict().keys())
    used_pretrained_keys = model_keys & ckpt_keys
    unused_pretrained_keys = ckpt_keys - model_keys
    missing_keys = model_keys - ckpt_keys
    if len(missing_keys) > 0:
        logger.info('[Warning] missing keys: {}'.format(missing_keys))
        logger.info('missing keys: {}'.format(len(missing_keys)))
    if len(unused_pretrained_keys) > 0:
        logger.info('[Warning] unused pretrained keys: {}'.format(unused_pretrained_keys))
        logger.info('unused checkpoint keys: {}'.format(len(unused_pretrained_keys)))
    logger.info('used keys: {}'.format(len(used_pretrained_keys)))
    assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'
    return True
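
# A minimal sketch of check_keys in isolation. The model and checkpoint here
# are hypothetical, chosen only to show which overlaps get logged:
#
#   import torch.nn as nn
#   net = nn.Linear(4, 2)                                  # keys: 'weight', 'bias'
#   ckpt = {'weight': torch.zeros(2, 4), 'extra': torch.zeros(1)}
#   check_keys(net, ckpt)   # logs 'bias' as missing and 'extra' as unused,
#                           # then passes because 'weight' is shared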
def remove_prefix(state_dict, prefix):
    """Old-style models (e.g. saved under nn.DataParallel) store every
    parameter name with a shared prefix such as 'module.'; strip it off."""
    logger.info("remove prefix '{}'".format(prefix))
    f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
    return {f(key): value for key, value in state_dict.items()}
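
# Example of the prefix stripping on hypothetical keys; entries without the
# prefix pass through unchanged:
#
#   sd = {'module.conv.weight': 0, 'conv.bias': 1}
#   remove_prefix(sd, 'module.')
#   # -> {'conv.weight': 0, 'conv.bias': 1}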
def load_pretrain(model, pretrained_path):
    """Load pretrained weights into `model`, tolerating prefix mismatches."""
    logger.info('load pretrained model from {}'.format(pretrained_path))
    if not torch.cuda.is_available():
        # No GPU available: map every tensor onto the CPU.
        pretrained_dict = torch.load(pretrained_path,
                                     map_location=lambda storage, loc: storage)
    else:
        device = torch.cuda.current_device()
        pretrained_dict = torch.load(pretrained_path,
                                     map_location=lambda storage, loc: storage.cuda(device))

    # A full training checkpoint nests the weights under 'state_dict';
    # a bare weight file is the state dict itself.
    if 'state_dict' in pretrained_dict.keys():
        pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
    else:
        pretrained_dict = remove_prefix(pretrained_dict, 'module.')

    try:
        check_keys(model, pretrained_dict)
    except AssertionError:
        # No key overlap at all: the checkpoint was likely saved from a bare
        # backbone, so retry with every key re-rooted under 'features.'.
        logger.info('[Warning]: using pretrain as features. Adding "features." as prefix')
        pretrained_dict = {'features.' + k: v for k, v in pretrained_dict.items()}
        check_keys(model, pretrained_dict)
    model.load_state_dict(pretrained_dict, strict=False)
    return model
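
# Typical call, e.g. to warm-start a backbone. The model class and path below
# are placeholders, not part of this module:
#
#   model = MyModel()
#   model = load_pretrain(model, 'pretrained/resnet50.pth')
#   # strict=False means extra checkpoint keys are tolerated; check_keys has
#   # already logged exactly what was and wasn't matched.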
def restore_from(model, optimizer, ckpt_path):
    """Resume training: restore model weights, optimizer state, and
    bookkeeping (epoch, best accuracy, architecture name)."""
    logger.info('restore from {}'.format(ckpt_path))
    device = torch.cuda.current_device()
    ckpt = torch.load(ckpt_path,
                      map_location=lambda storage, loc: storage.cuda(device))
    epoch = ckpt['epoch']
    best_acc = ckpt['best_acc']
    arch = ckpt['arch']

    ckpt_model_dict = remove_prefix(ckpt['state_dict'], 'module.')
    check_keys(model, ckpt_model_dict)
    model.load_state_dict(ckpt_model_dict, strict=False)

    # Optimizers expose state_dict()/load_state_dict() too, so the same key
    # check applies (their keys are 'state' and 'param_groups').
    check_keys(optimizer, ckpt['optimizer'])
    optimizer.load_state_dict(ckpt['optimizer'])
    return model, optimizer, epoch, best_acc, arch
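
# Resuming a run might look like this (the checkpoint path, training objects,
# and train_one_epoch are hypothetical); the returned epoch lets the loop pick
# up where the checkpoint left off:
#
#   model, optimizer, epoch, best_acc, arch = restore_from(
#       model, optimizer, 'snapshots/checkpoint_e10.pth')
#   for e in range(epoch, total_epochs):
#       train_one_epoch(model, optimizer)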