|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import numpy as np |
|
|
|
|
|
def _narrow_class_embed(loaded, expected_shape):
    """Return rows ``1 .. n`` of ``loaded``, where ``n = expected_shape[0]``.

    Row 0 of the checkpoint tensor is always skipped; rows ``1 .. n`` are kept.

    Raises:
        NotImplementedError: if the narrowed tensor cannot match
            ``expected_shape`` (``n < 1``, too few checkpoint rows, or
            differing trailing dimensions).
    """
    if len(expected_shape) == 0 or expected_shape[0] < 1 or loaded.dim() == 0:
        raise NotImplementedError('invalid shape: {}'.format(expected_shape))
    narrowed = loaded[1:expected_shape[0] + 1]
    if narrowed.shape != expected_shape:
        raise NotImplementedError('invalid shape: {}'.format(expected_shape))
    return narrowed


def _decayed_lr(base_lr, epoch, milestones, gamma=0.1):
    """Return ``base_lr`` multiplied by ``gamma`` once per milestone ``<= epoch``."""
    current = base_lr
    for step in milestones:
        if epoch >= step:
            current *= gamma
    return current


def load_model(model, model_path, optimizer=None, resume=False,
               lr=None, lr_step=None):
    """Load checkpoint weights into ``model`` and optionally resume ``optimizer``.

    Checkpoint parameters are reconciled with ``model.state_dict()`` by key:

    * keys only in the checkpoint are reported and ignored;
    * keys only in the model keep the model's current values;
    * keys whose shapes differ keep the model's current values, except keys
      containing ``'class_embed'``, which are narrowed to rows ``1 .. n`` of
      the checkpoint tensor (``n`` = the model's leading dimension).

    Args:
        model: module to load into; it is modified in place.
        model_path: passed to ``torch.load``. The loaded object must be a dict
            with a ``'model'`` state dict and, for resuming, ``'optimizer'``
            and ``'epoch'`` entries.
        optimizer: optional optimizer. When given, a 3-tuple is returned.
        resume: if true and ``optimizer`` is given, restore the optimizer
            state and the epoch from the checkpoint.
        lr: base learning rate to re-apply on resume, decayed by 0.1 for each
            ``lr_step`` milestone ``<= epoch``. If ``None``, the learning rate
            restored with the optimizer state is kept unchanged.
        lr_step: iterable of epoch milestones; ``None`` means no milestones.

    Returns:
        ``model`` when ``optimizer`` is None, otherwise
        ``(model, optimizer, start_epoch)``.

    Raises:
        NotImplementedError: if a shape-mismatched ``class_embed`` parameter
            cannot be narrowed to the model's shape.
    """
    start_epoch = 0
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    print(f'loaded {model_path}')
    state_dict = checkpoint['model']
    model_state_dict = model.state_dict()

    # Reconcile checkpoint parameters with the parameters the model expects.
    msg = ('If you see this, your model does not fully load the '
           'pre-trained weight. Please make sure '
           'you set the correct --num_classes for your own dataset.')
    for k in state_dict:
        if k not in model_state_dict:
            # Left in state_dict; load_state_dict(strict=False) ignores it.
            print('Drop parameter {}.'.format(k) + msg)
            continue
        if state_dict[k].shape == model_state_dict[k].shape:
            continue
        print('Skip loading parameter {}, required shape{}, '
              'loaded shape{}. {}'.format(
                  k, model_state_dict[k].shape, state_dict[k].shape, msg))
        if 'class_embed' in k:
            print("load class_embed: {} shape={}".format(k, state_dict[k].shape))
            state_dict[k] = _narrow_class_embed(state_dict[k], model_state_dict[k].shape)
        else:
            # Fall back to the model's own current value for this parameter.
            state_dict[k] = model_state_dict[k]
    for k in model_state_dict:
        if k not in state_dict:
            print('No param {}.'.format(k) + msg)
            state_dict[k] = model_state_dict[k]
    model.load_state_dict(state_dict, strict=False)

    # Resume optimizer state and learning rate.
    if optimizer is not None and resume:
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch']
            if lr is None:
                # No base lr supplied: keep the lr restored with the optimizer state
                # instead of writing None into the param groups.
                print('Resumed optimizer with checkpoint lr',
                      [group['lr'] for group in optimizer.param_groups])
            else:
                milestones = lr_step if lr_step is not None else ()
                start_lr = _decayed_lr(lr, start_epoch, milestones)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = start_lr
                print('Resumed optimizer with start lr', start_lr)
        else:
            print('No optimizer parameters in checkpoint.')

    if optimizer is not None:
        return model, optimizer, start_epoch
    return model
|
|
|
|
|
|
|
|