Spaces:
Running
on
A10G
Running
on
A10G
# ------------------------------------------------------------------------------ | |
# Copyright (c) Microsoft | |
# Licensed under the MIT License. | |
# The code is from Swin Transformer. | |
# (https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmcv_custom/checkpoint.py) | |
# ------------------------------------------------------------------------------ | |
import io | |
import os | |
import os.path as osp | |
import pkgutil | |
import time | |
import warnings | |
import numpy as np | |
from collections import OrderedDict | |
from importlib import import_module | |
from tempfile import TemporaryDirectory | |
from scipy import interpolate | |
import torch | |
import torchvision | |
import torch.distributed as dist | |
from torch.optim import Optimizer | |
from torch.utils import model_zoo | |
from torch.nn import functional as F | |
import mmcv | |
from mmcv.fileio import FileClient | |
from mmcv.fileio import load as load_file | |
from mmcv.parallel import is_module_wrapper | |
from mmcv.utils import mkdir_or_exist | |
from mmcv.runner import get_dist_info | |
from mmcv.utils import get_logger | |
import logging | |
def get_root_logger(log_file=None, log_level=logging.INFO):
    """Get the project's root logger, which is named ``'mmpose'``.

    The logger is obtained from mmcv's ``get_logger`` and is initialized only
    on first use. By default a StreamHandler is attached; when ``log_file`` is
    given, a FileHandler is attached as well.

    Args:
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the root logger.
        log_level (int): The root logger level. Note that only the process of
            rank 0 is affected, while other processes will set the level to
            "Error" and be silent most of the time.

    Returns:
        logging.Logger: The ``'mmpose'`` root logger.
    """
    return get_logger(name='mmpose', log_file=log_file, log_level=log_level)
def _get_mmcv_home(): | |
mmcv_home = os.path.expanduser( | |
os.getenv( | |
ENV_MMCV_HOME, | |
os.path.join( | |
os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv'))) | |
mkdir_or_exist(mmcv_home) | |
return mmcv_home | |
def load_state_dict(module, state_dict, strict=False, logger=None):
    """Copy the tensors in ``state_dict`` into ``module``.

    Adapted from :meth:`torch.nn.Module.load_state_dict`, with two
    differences: ``strict`` defaults to ``False``, and key mismatches are
    reported on rank 0 even when ``strict`` is ``False``.

    Args:
        module (Module): Module that receives the weights. Module wrappers
            (as detected by ``is_module_wrapper``) are unwrapped at every
            level of the module tree.
        state_dict (OrderedDict): Weights to load.
        strict (bool): When True, a mismatch raises ``RuntimeError`` on
            rank 0 instead of being reported. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Receives the mismatch
            report as a warning. When None, the report is printed.
    """
    unexpected = []
    missing_all = []
    messages = []

    # Shallow copy so the caller's dict is untouched; dict.copy() drops the
    # ``_metadata`` attribute, so re-attach it for version-aware loading.
    meta = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if meta is not None:
        state_dict._metadata = meta

    def _visit(mod, prefix=''):
        # Wrappers may appear anywhere, e.g. nn.Module(nn.Module(DDP)).
        if is_module_wrapper(mod):
            mod = mod.module
        if meta is None:
            local_meta = {}
        else:
            local_meta = meta.get(prefix[:-1], {})
        # Going through _load_from_state_dict keeps checkpoint version
        # control working for each module.
        mod._load_from_state_dict(state_dict, prefix, local_meta, True,
                                  missing_all, unexpected, messages)
        for child_name, child in mod._modules.items():
            if child is None:
                continue
            _visit(child, f'{prefix}{child_name}.')

    _visit(module)
    _visit = None  # drop the closure's reference to itself

    # BatchNorm's "num_batches_tracked" buffers are not reported as missing.
    missing = [k for k in missing_all if 'num_batches_tracked' not in k]

    if unexpected:
        messages.append('unexpected key in source '
                        f'state_dict: {", ".join(unexpected)}\n')
    if missing:
        messages.append(
            f'missing keys in source state_dict: {", ".join(missing)}\n')

    rank, _ = get_dist_info()
    if not messages or rank != 0:
        return

    messages.insert(
        0, 'The model and loaded state dict do not match exactly\n')
    report = '\n'.join(messages)
    if strict:
        raise RuntimeError(report)
    if logger is not None:
        logger.warning(report)
    else:
        print(report)
def load_url_dist(url, model_dir=None):
    """Fetch ``url`` through ``torch.utils.model_zoo``, local rank 0 first.

    The local rank is read from ``$LOCAL_RANK`` when set, otherwise taken from
    ``get_dist_info``. With more than one process, every rank waits on a
    barrier after rank 0's download, and ranks above 0 fetch only afterwards.

    Args:
        url (str): URL of the checkpoint.
        model_dir (str | None): Forwarded to ``model_zoo.load_url``.

    Returns:
        The object returned by ``model_zoo.load_url``.
    """
    global_rank, world_size = get_dist_info()
    local_rank = int(os.environ.get('LOCAL_RANK', global_rank))

    def _fetch():
        return model_zoo.load_url(url, model_dir=model_dir)

    if local_rank == 0:
        checkpoint = _fetch()
    if world_size > 1:
        torch.distributed.barrier()
    if local_rank > 0:
        checkpoint = _fetch()
    return checkpoint
def load_pavimodel_dist(model_path, map_location=None):
    """Download a checkpoint from pavi modelcloud, local rank 0 first.

    Each fetch downloads the model file into a temporary directory and
    deserializes it with :func:`torch.load` before the directory is removed.
    With more than one process, ranks synchronize on a barrier between the
    rank-0 fetch and the fetches of ranks above 0.

    Args:
        model_path (str): Path of the model node in modelcloud.
        map_location: Forwarded to :func:`torch.load`.

    Returns:
        The deserialized checkpoint.

    Raises:
        ImportError: If the ``pavi`` package is not installed.
    """
    try:
        from pavi import modelcloud
    except ImportError:
        raise ImportError(
            'Please install pavi to load checkpoint from modelcloud.')

    def _fetch():
        remote = modelcloud.get(model_path)
        with TemporaryDirectory() as tmp_dir:
            local_file = osp.join(tmp_dir, remote.name)
            remote.download(local_file)
            return torch.load(local_file, map_location=map_location)

    global_rank, world_size = get_dist_info()
    local_rank = int(os.environ.get('LOCAL_RANK', global_rank))
    if local_rank == 0:
        checkpoint = _fetch()
    if world_size > 1:
        torch.distributed.barrier()
    if local_rank > 0:
        checkpoint = _fetch()
    return checkpoint
def load_fileclient_dist(filename, backend, map_location):
    """Read a checkpoint through mmcv's ``FileClient``, local rank 0 first.

    Args:
        filename (str): Path understood by the chosen backend.
        backend (str): FileClient backend; only ``'ceph'`` is accepted.
        map_location: Forwarded to :func:`torch.load`.

    Returns:
        The deserialized checkpoint.

    Raises:
        ValueError: If ``backend`` is not supported.
    """
    global_rank, world_size = get_dist_info()
    local_rank = int(os.environ.get('LOCAL_RANK', global_rank))
    supported = ('ceph',)
    if backend not in supported:
        raise ValueError(f'Load from Backend {backend} is not supported.')

    def _read():
        client = FileClient(backend=backend)
        raw = io.BytesIO(client.get(filename))
        return torch.load(raw, map_location=map_location)

    if local_rank == 0:
        checkpoint = _read()
    if world_size > 1:
        torch.distributed.barrier()
    if local_rank > 0:
        checkpoint = _read()
    return checkpoint
def get_torchvision_models():
    """Merge the ``model_urls`` tables of all ``torchvision.models`` modules.

    Sub-packages are skipped, and modules without a ``model_urls`` attribute
    contribute nothing. On a name clash, a later module's entry wins.
    NOTE(review): newer torchvision releases may no longer expose
    ``model_urls``, in which case the result would be empty; confirm against
    the pinned version.

    Returns:
        dict: Mapping of model name to download URL.
    """
    collected = {}
    for _finder, mod_name, is_pkg in pkgutil.walk_packages(
            torchvision.models.__path__):
        if is_pkg:
            continue
        zoo_module = import_module(f'torchvision.models.{mod_name}')
        if hasattr(zoo_module, 'model_urls'):
            collected.update(zoo_module.model_urls)
    return collected
def get_external_models():
    """Return the open-mmlab model name to URL table.

    The table is read from ``model_zoo/open_mmlab.json`` inside the installed
    mmcv package. When ``open_mmlab.json`` exists in the mmcv home directory,
    its entries are overlaid on top.

    Returns:
        dict: Mapping of model name to URL (or to a path relative to the
        mmcv home directory).
    """
    home_dir = _get_mmcv_home()
    urls = load_file(osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json'))
    assert isinstance(urls, dict)

    override_path = osp.join(home_dir, 'open_mmlab.json')
    if osp.exists(override_path):
        overrides = load_file(override_path)
        assert isinstance(overrides, dict)
        urls.update(overrides)
    return urls
def get_mmcls_models():
    """Return the contents of mmcv's bundled ``model_zoo/mmcls.json``."""
    return load_file(osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json'))
def get_deprecated_model_names():
    """Return mmcv's table of deprecated open-mmlab model names.

    ``_load_checkpoint`` uses it to map a deprecated name to its replacement.

    Returns:
        dict: Mapping of deprecated name to replacement name.
    """
    json_path = osp.join(mmcv.__path__[0], 'model_zoo/deprecated.json')
    renames = load_file(json_path)
    assert isinstance(renames, dict)
    return renames
def _process_mmcls_checkpoint(checkpoint): | |
state_dict = checkpoint['state_dict'] | |
new_state_dict = OrderedDict() | |
for k, v in state_dict.items(): | |
if k.startswith('backbone.'): | |
new_state_dict[k[9:]] = v | |
new_checkpoint = dict(state_dict=new_state_dict) | |
return new_checkpoint | |
def _load_checkpoint(filename, map_location=None): | |
"""Load checkpoint from somewhere (modelzoo, file, url). | |
Args: | |
filename (str): Accept local filepath, URL, ``torchvision://xxx``, | |
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for | |
details. | |
map_location (str | None): Same as :func:`torch.load`. Default: None. | |
Returns: | |
dict | OrderedDict: The loaded checkpoint. It can be either an | |
OrderedDict storing model weights or a dict containing other | |
information, which depends on the checkpoint. | |
""" | |
if filename.startswith('modelzoo://'): | |
warnings.warn('The URL scheme of "modelzoo://" is deprecated, please ' | |
'use "torchvision://" instead') | |
model_urls = get_torchvision_models() | |
model_name = filename[11:] | |
checkpoint = load_url_dist(model_urls[model_name]) | |
elif filename.startswith('torchvision://'): | |
model_urls = get_torchvision_models() | |
model_name = filename[14:] | |
checkpoint = load_url_dist(model_urls[model_name]) | |
elif filename.startswith('open-mmlab://'): | |
model_urls = get_external_models() | |
model_name = filename[13:] | |
deprecated_urls = get_deprecated_model_names() | |
if model_name in deprecated_urls: | |
warnings.warn(f'open-mmlab://{model_name} is deprecated in favor ' | |
f'of open-mmlab://{deprecated_urls[model_name]}') | |
model_name = deprecated_urls[model_name] | |
model_url = model_urls[model_name] | |
# check if is url | |
if model_url.startswith(('http://', 'https://')): | |
checkpoint = load_url_dist(model_url) | |
else: | |
filename = osp.join(_get_mmcv_home(), model_url) | |
if not osp.isfile(filename): | |
raise IOError(f'{filename} is not a checkpoint file') | |
checkpoint = torch.load(filename, map_location=map_location) | |
elif filename.startswith('mmcls://'): | |
model_urls = get_mmcls_models() | |
model_name = filename[8:] | |
checkpoint = load_url_dist(model_urls[model_name]) | |
checkpoint = _process_mmcls_checkpoint(checkpoint) | |
elif filename.startswith(('http://', 'https://')): | |
checkpoint = load_url_dist(filename) | |
elif filename.startswith('pavi://'): | |
model_path = filename[7:] | |
checkpoint = load_pavimodel_dist(model_path, map_location=map_location) | |
elif filename.startswith('s3://'): | |
checkpoint = load_fileclient_dist( | |
filename, backend='ceph', map_location=map_location) | |
else: | |
if not osp.isfile(filename): | |
raise IOError(f'{filename} is not a checkpoint file') | |
checkpoint = torch.load(filename, map_location=map_location) | |
return checkpoint | |
def _is_main_process():
    """Return True outside torch.distributed, or on global rank 0 inside it."""
    # dist.get_rank() raises when no process group is initialized (e.g. plain
    # single-process inference), so check availability/initialization first.
    if not (dist.is_available() and dist.is_initialized()):
        return True
    return dist.get_rank() == 0


def _interpolate_rel_pos_bias_geo(table, src_size, dst_size, logger):
    """Resample a relative position bias table with the "geo" scheme.

    Args:
        table (Tensor): Pretrained table of shape (src_size**2, num_heads).
        src_size (int): Side length of the source offset grid.
        dst_size (int): Side length of the target offset grid.
        logger (logging.Logger): Receives the source/target positions.

    Returns:
        Tensor: Table of shape (len(dx)**2, num_heads), on ``table``'s
        device. ``dx`` has ``2 * (dst_size // 2) + 1`` points, so the table
        has dst_size**2 rows when dst_size is odd.
    """
    def geometric_progression(a, r, n):
        return a * (1.0 - r ** n) / (1.0 - r)

    # Bisect, within [1.01, 1.5], for the ratio q whose progression
    # 1 + q + ... + q**(src_size // 2 - 1) matches dst_size // 2. That sum
    # is the outermost stretched source offset computed below.
    left, right = 1.01, 1.5
    while right - left > 1e-6:
        q = (left + right) / 2.0
        gp = geometric_progression(1, q, src_size // 2)
        if gp > dst_size // 2:
            right = q
        else:
            left = q

    # Positive source offsets 1, 1+q, 1+q+q**2, ..., mirrored around 0.
    dis = []
    cur = 1
    for i in range(src_size // 2):
        dis.append(cur)
        cur += q ** (i + 1)
    r_ids = [-d for d in reversed(dis)]
    x = r_ids + [0] + dis
    y = r_ids + [0] + dis

    t = dst_size // 2.0
    dx = np.arange(-t, t + 0.1, 1.0)
    dy = np.arange(-t, t + 0.1, 1.0)
    logger.info("Original positions = %s" % str(x))
    logger.info("Target positions = %s" % str(dx))

    columns = []
    for i in range(table.size(1)):
        z = table[:, i].view(src_size, src_size).float().cpu().numpy()
        # scipy.interpolate.interp2d was removed in SciPy 1.14. On a
        # rectangular grid, interp2d(x, y, z, kind='cubic')(dx, dy) equals a
        # bicubic RectBivariateSpline fitted to z.T and evaluated then
        # transposed (interp2d takes/returns (len(y), len(x)) arrays).
        spline = interpolate.RectBivariateSpline(x, y, z.T, kx=3, ky=3)
        resampled = np.ascontiguousarray(spline(dx, dy).T)
        columns.append(
            torch.Tensor(resampled).contiguous().view(-1, 1).to(table.device))
    return torch.cat(columns, dim=-1)


def _interpolate_pos_embed(model, state_dict):
    """Resize ``state_dict['pos_embed']`` in place to ``model``'s patch grid.

    Extra (non-patch) tokens are kept unchanged. Their count is the model's
    ``pos_embed`` length minus ``model.patch_embed.num_patches``. The
    remaining patch tokens are resampled bicubically. Both grids are treated
    as square (side = int(sqrt(count))). NOTE(review): a non-square patch
    grid would be mis-sized here; confirm for the models this is used with.
    """
    pos_embed_checkpoint = state_dict['pos_embed']
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = model.patch_embed.num_patches
    num_extra_tokens = model.pos_embed.shape[-2] - num_patches
    # height (== width) for the checkpoint position embedding
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    # height (== width) for the new position embedding
    new_size = int(num_patches ** 0.5)
    if orig_size == new_size:
        return
    if _is_main_process():
        print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
    # class_token and dist_token are kept unchanged
    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
    # only the position tokens are interpolated
    pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
    pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
    pos_tokens = F.interpolate(
        pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
    pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    state_dict['pos_embed'] = torch.cat((extra_tokens, pos_tokens), dim=1)


def load_checkpoint_swin(model,
                         filename,
                         map_location='cpu',
                         strict=False,
                         rpe_interpolation='outer_mask',
                         logger=None):
    """Load a checkpoint into ``model``, adapting position-embedding shapes.

    After the checkpoint is loaded with ``_load_checkpoint`` and unwrapped,
    the state dict is adjusted before being copied into ``model``:

    * a leading ``module.`` prefix is stripped;
    * MoBY checkpoints keep only their ``encoder.`` weights, prefix removed;
    * Swin's ``absolute_pos_embed`` is reshaped from (N, L, C) to the model's
      (N, C, H, W) layout when the sizes agree;
    * ``relative_position_bias_table`` entries whose length differs from the
      model's are resampled according to ``rpe_interpolation``;
    * a ViT-style ``pos_embed`` is resized to the model's patch grid.

    Args:
        model (Module): Module to load checkpoint into.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str): Same as :func:`torch.load`. Default: ``'cpu'``.
        strict (bool): Passed to ``load_state_dict``, which raises
            ``RuntimeError`` on key mismatches when True. Default: False.
        rpe_interpolation (str): How to resize relative position bias tables
            of a different length: ``'bicubic'``, ``'bilinear'`` or
            ``'nearest'`` (``F.interpolate`` modes), or ``'geo'``. Any other
            value, including the default ``'outer_mask'``, leaves such tables
            unchanged.
        logger (:mod:`logging.Logger` or None): Logger for messages. When
            None, this function logs to ``logging.getLogger(__name__)`` and
            ``load_state_dict`` falls back to ``print``.

    Returns:
        dict or OrderedDict: The loaded checkpoint.

    Raises:
        RuntimeError: If the loaded checkpoint is not a dict.
    """
    # Messages emitted here need a logger even when the caller passed None
    # (the default). ``logger`` itself is still what load_state_dict
    # receives, so its print() fallback is unchanged.
    log = logger if logger is not None else logging.getLogger(__name__)

    checkpoint = _load_checkpoint(filename, map_location)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')

    # get state_dict from checkpoint
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    elif 'module' in checkpoint:
        state_dict = checkpoint['module']
    else:
        state_dict = checkpoint

    # Strip the 'module.' prefix. Only the first key is inspected, on the
    # assumption that either every key carries the prefix or none does.
    keys = list(state_dict.keys())
    if keys and keys[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}

    # For MoBY, load the model of the online branch. The detection heuristic
    # looks at the third key in sorted order, so it needs at least 3 keys.
    sorted_keys = sorted(state_dict.keys())
    if len(sorted_keys) > 2 and sorted_keys[2].startswith('encoder'):
        # Slice off only the leading prefix; str.replace would also remove
        # any 'encoder.' occurring deeper inside a key.
        state_dict = {k[len('encoder.'):]: v
                      for k, v in state_dict.items()
                      if k.startswith('encoder.')}

    # reshape absolute position embedding for Swin
    absolute_pos_embed = state_dict.get('absolute_pos_embed')
    if absolute_pos_embed is not None:
        N1, L, C1 = absolute_pos_embed.size()
        N2, C2, H, W = model.absolute_pos_embed.size()
        if N1 != N2 or C1 != C2 or L != H * W:
            log.warning("Error in loading absolute_pos_embed, pass")
        else:
            state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)

    # interpolate position bias table if needed
    rpe_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
    # Build the model's state dict once, not once per table.
    model_state = model.state_dict() if rpe_keys else {}
    for k in rpe_keys:
        if k not in model_state:
            # No counterpart in the model, so there is nothing to resize
            # against. load_state_dict reports it as an unexpected key.
            continue
        table_pretrained = state_dict[k]
        L1, nH1 = table_pretrained.size()
        L2, nH2 = model_state[k].size()
        if nH1 != nH2:
            log.warning(f"Error in loading {k}, pass")
            continue
        if L1 == L2:
            continue
        if rpe_interpolation in ['bicubic', 'bilinear', 'nearest']:
            log.info(f"Interpolate relative_position_bias_table using {rpe_interpolation}")
            S1 = int(L1 ** 0.5)
            S2 = int(L2 ** 0.5)
            table_pretrained_resized = F.interpolate(
                table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
                size=(S2, S2), mode=rpe_interpolation)
            state_dict[k] = table_pretrained_resized.view(nH2, L2).permute(1, 0)
        elif rpe_interpolation == 'geo':
            log.info("Interpolate relative_position_bias_table using geo.")
            state_dict[k] = _interpolate_rel_pos_bias_geo(
                table_pretrained, int(L1 ** 0.5), int(L2 ** 0.5), log)
        # Any other mode leaves the table as-is; torch's loader then presumably
        # records a size mismatch that load_state_dict reports.

    if 'pos_embed' in state_dict:
        _interpolate_pos_embed(model, state_dict)

    # load state_dict
    load_state_dict(model, state_dict, strict, logger)
    return checkpoint
def weights_to_cpu(state_dict):
    """Return a CPU copy of a model state_dict.

    Args:
        state_dict (Mapping[str, Tensor]): Model weights, possibly on GPU.

    Returns:
        OrderedDict: The same keys in the same order, each value replaced by
        ``value.cpu()`` (weights on CPU).
    """
    return OrderedDict((key, val.cpu()) for key, val in state_dict.items())
def _save_to_state_dict(module, destination, prefix, keep_vars): | |
"""Saves module state to `destination` dictionary. | |
This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. | |
Args: | |
module (nn.Module): The module to generate state_dict. | |
destination (dict): A dict where state will be stored. | |
prefix (str): The prefix for parameters and buffers used in this | |
module. | |
""" | |
for name, param in module._parameters.items(): | |
if param is not None: | |
destination[prefix + name] = param if keep_vars else param.detach() | |
for name, buf in module._buffers.items(): | |
# remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d | |
if buf is not None: | |
destination[prefix + name] = buf if keep_vars else buf.detach() | |
def get_state_dict(module, destination=None, prefix='', keep_vars=False):
    """Collect the full state (parameters and buffers) of a module tree.

    This mirrors :meth:`torch.nn.Module.state_dict`. The difference is that
    module wrappers (per ``is_module_wrapper``) are unwrapped at every level,
    so structures like nn.Module(nn.Module(DDP)) produce clean keys.

    Args:
        module (nn.Module): Root of the tree to serialize.
        destination (OrderedDict): Dict to fill. When None, a new
            OrderedDict with a ``_metadata`` attribute is created.
        prefix (str): Prefix for every key written at this level.
        keep_vars (bool): Whether to keep the variable property of the
            parameters. Default: False.

    Returns:
        dict: The filled ``destination``, or the return value of the last
        state-dict hook that returned something other than None.
    """
    if is_module_wrapper(module):
        module = module.module

    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    local_metadata = dict(version=module._version)
    destination._metadata[prefix[:-1]] = local_metadata

    _save_to_state_dict(module, destination, prefix, keep_vars)

    for child_name, child in module._modules.items():
        if child is None:
            continue
        get_state_dict(
            child, destination, f'{prefix}{child_name}.', keep_vars=keep_vars)

    for hook in module._state_dict_hooks.values():
        replacement = hook(module, destination, prefix, local_metadata)
        if replacement is not None:
            destination = replacement
    return destination
def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.

    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
    ``optimizer``. ``meta`` always gains ``mmcv_version`` and ``time``, plus
    ``CLASSES`` when the (unwrapped) model has a non-None ``CLASSES``
    attribute. The caller's ``meta`` dict is not modified.

    Args:
        model (Module): Module whose params are to be saved. A module
            wrapper is unwrapped first.
        filename (str): Checkpoint filename. A ``pavi://`` path uploads to
            pavi modelcloud; anything else is written locally, creating the
            parent directory when needed.
        optimizer (:obj:`Optimizer` | dict, optional): Optimizer, or a dict
            of named optimizers, whose state is saved.
        meta (dict, optional): Metadata to be saved in checkpoint.

    Raises:
        TypeError: If ``meta`` is neither a dict nor None.
        ImportError: If a ``pavi://`` path is given but pavi is missing.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    else:
        # Work on a copy so the updates below don't leak into the caller's dict.
        meta = meta.copy()
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    if is_module_wrapper(model):
        model = model.module

    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # save class name to the meta
        meta.update(CLASSES=model.CLASSES)

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model))
    }
    # save optimizer state dict in the checkpoint
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {
            name: optim.state_dict() for name, optim in optimizer.items()
        }

    if filename.startswith('pavi://'):
        try:
            from pavi import modelcloud
            from pavi.exception import NodeNotFoundError
        except ImportError:
            raise ImportError(
                'Please install pavi to save checkpoint to modelcloud.')
        model_path = filename[7:]
        root = modelcloud.Folder()
        model_dir, model_name = osp.split(model_path)
        # Separate name so the ``model`` argument isn't shadowed by the
        # remote modelcloud node.
        try:
            remote_model = modelcloud.get(model_dir)
        except NodeNotFoundError:
            remote_model = root.create_training_model(model_dir)
        with TemporaryDirectory() as tmp_dir:
            checkpoint_file = osp.join(tmp_dir, model_name)
            with open(checkpoint_file, 'wb') as f:
                torch.save(checkpoint, f)
                f.flush()
            remote_model.create_file(checkpoint_file, name=model_name)
    else:
        mmcv.mkdir_or_exist(osp.dirname(filename))
        # immediately flush buffer
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()