OMG_Seg / seg /models /utils /load_checkpoint.py
HarborYuan's picture
add omg code
b34d1d6
raw
history blame
1.19 kB
from mmengine.runner.checkpoint import CheckpointLoader
def load_checkpoint_with_prefix(filename, prefix=None, map_location='cpu', logger='current'):
    """Load a checkpoint, optionally keeping only one sub-module's weights.

    Args:
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        prefix (str | None): The prefix of the sub-module. A trailing ``.``
            is appended if missing. When falsy (``None`` or ``''``) the state
            dict is returned unfiltered. Defaults to None.
        map_location (str | None): Same as :func:`torch.load`.
            Defaults to ``'cpu'``.
        logger: Forwarded unchanged to
            ``CheckpointLoader.load_checkpoint``. Defaults to ``'current'``.

    Returns:
        dict or OrderedDict: The state dict. This is ``checkpoint['state_dict']``
        when the loaded checkpoint has that key, otherwise the checkpoint
        itself. When ``prefix`` is given, a new ``dict`` is returned. It holds
        only the entries whose keys start with ``prefix``, with that prefix
        stripped from each key.

    Raises:
        AssertionError: If ``prefix`` is given but no key starts with it.
    """
    checkpoint = CheckpointLoader.load_checkpoint(
        filename, map_location=map_location, logger=logger)
    # Weights may be nested under 'state_dict'; otherwise treat the loaded
    # object itself as the state dict.
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    if not prefix:
        return state_dict

    # Normalise to a trailing '.' so matching happens on whole module-name
    # components (e.g. 'backbone' does not also match 'backbone2.*').
    if not prefix.endswith('.'):
        prefix += '.'
    prefix_len = len(prefix)
    state_dict = {
        k[prefix_len:]: v
        for k, v in state_dict.items() if k.startswith(prefix)
    }
    # Explicit raise instead of ``assert`` so the check is not stripped under
    # ``python -O``; AssertionError is kept for backward compatibility.
    if not state_dict:
        raise AssertionError(f'{prefix} is not in the pretrained model')
    return state_dict