zejunyang
update
9667e74
raw
history blame
No virus
1.34 kB
import collections
def load_state(net, checkpoint):
    """Load compatible entries of ``checkpoint['state_dict']`` into ``net``.

    For each key in ``net.state_dict()``, the checkpoint entry with the same
    key is used when it exists and its ``size()`` equals that of the network's
    own entry. Otherwise the network's current value is kept and a warning
    naming the key is printed. Checkpoint keys the network does not have are
    ignored.

    Args:
        net: Object exposing ``state_dict()`` and ``load_state_dict()``
            (presumably a ``torch.nn.Module`` -- confirm against callers).
        checkpoint: Mapping with a ``'state_dict'`` entry that maps parameter
            names to values exposing ``size()``.

    Raises:
        KeyError: If ``checkpoint`` has no ``'state_dict'`` entry.
    """
    pretrained = checkpoint['state_dict']
    current = net.state_dict()
    merged = collections.OrderedDict()

    for name, own_value in current.items():
        compatible = (
            name in pretrained
            and pretrained[name].size() == own_value.size()
        )
        if not compatible:
            # Missing or differently-sized entry: keep what the net has now.
            merged[name] = own_value
            print('[WARNING] Not found pre-trained parameters for {}'.format(name))
            continue
        merged[name] = pretrained[name]

    net.load_state_dict(merged)
def load_from_mobilenet(net, checkpoint):
    """Load compatible checkpoint entries into ``net`` using remapped keys.

    Each key of ``net.state_dict()`` is looked up in
    ``checkpoint['state_dict']`` after every occurrence of ``'model'`` in the
    key has been rewritten to ``'module.model'``; keys without ``'model'`` are
    looked up unchanged. The checkpoint value is used when the remapped key
    exists and its ``size()`` equals that of the network's own entry.
    Otherwise the network's current value is kept and a warning naming the
    network key is printed.

    Args:
        net: Object exposing ``state_dict()`` and ``load_state_dict()``
            (presumably a ``torch.nn.Module`` -- confirm against callers).
        checkpoint: Mapping with a ``'state_dict'`` entry that maps parameter
            names to values exposing ``size()``.

    Raises:
        KeyError: If ``checkpoint`` has no ``'state_dict'`` entry.
    """
    pretrained = checkpoint['state_dict']
    current = net.state_dict()
    merged = collections.OrderedDict()

    for name, own_value in current.items():
        # str.replace leaves the key untouched when 'model' is absent, so no
        # separate containment check is needed.
        # NOTE(review): the 'module.' prefix presumably matches a checkpoint
        # saved from a wrapped model -- confirm.
        lookup = name.replace('model', 'module.model')
        compatible = (
            lookup in pretrained
            and pretrained[lookup].size() == own_value.size()
        )
        if not compatible:
            # Missing or differently-sized entry: keep what the net has now.
            merged[name] = own_value
            print('[WARNING] Not found pre-trained parameters for {}'.format(name))
            continue
        merged[name] = pretrained[lookup]

    net.load_state_dict(merged)