# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import os
import pickle
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple

import torch
import torch.nn as nn
from fvcore.common.checkpoint import Checkpointer, _IncompatibleKeys
from termcolor import colored
from torch.nn.parallel import DistributedDataParallel

import detectron2.utils.comm as comm
from detectron2.utils.file_io import PathManager

from .c2_model_loading import align_and_update_state_dicts


class DetectionCheckpointer(Checkpointer):
    """
    Same as :class:`Checkpointer`, but is able to:
    1. handle models in detectron & detectron2 model zoo, and apply conversions for legacy models.
    2. correctly load checkpoints that are only available on the master worker
    """

    def __init__(self, model, save_dir="", *, save_to_disk=None, **checkpointables):
        is_main_process = comm.is_main_process()
        super().__init__(
            model,
            save_dir,
            save_to_disk=is_main_process if save_to_disk is None else save_to_disk,
            **checkpointables,
        )
        self.path_manager = PathManager

    def load(self, path, *args, **kwargs):
        need_sync = False

        if path and isinstance(self.model, DistributedDataParallel):
            logger = logging.getLogger(__name__)
            path = self.path_manager.get_local_path(path)
            has_file = os.path.isfile(path)
            all_has_file = comm.all_gather(has_file)
            if not all_has_file[0]:
                raise OSError(f"File {path} not found on main worker.")
            if not all(all_has_file):
                logger.warning(
                    f"Not all workers can read checkpoint {path}. "
                    "Training may fail to fully resume."
                )
                # TODO: broadcast the checkpoint file contents from main
                # worker, and load from it instead.
                need_sync = True
            if not has_file:
                path = None  # don't load if not readable
        ret = super().load(path, *args, **kwargs)

        if need_sync:
            logger.info("Broadcasting model states from main worker ...")
            self.model._sync_params_and_buffers()
        return ret
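
    # Sketch of the rendezvous above (illustrative numbers): with 4 workers and a
    # checkpoint visible only on rank 0, `comm.all_gather(has_file)` returns
    # [True, False, False, False]; ranks without the file skip loading, and
    # `_sync_params_and_buffers()` then broadcasts the loaded state from the main worker.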

    def _load_file(self, filename):
        if filename.endswith(".pkl"):
            with PathManager.open(filename, "rb") as f:
                data = pickle.load(f, encoding="latin1")
            if "model" in data and "__author__" in data:
                # file is in Detectron2 model zoo format
                self.logger.info("Reading a file from '{}'".format(data["__author__"]))
                return data
            else:
                # assume file is from Caffe2 / Detectron1 model zoo
                if "blobs" in data:
                    # Detection models have "blobs", but ImageNet models don't
                    data = data["blobs"]
                data = {k: v for k, v in data.items() if not k.endswith("_momentum")}
                return {"model": data, "__author__": "Caffe2", "matching_heuristics": True}
        elif filename.endswith(".pyth"):
            # assume file is from pycls; no one else seems to use the ".pyth" extension
            with PathManager.open(filename, "rb") as f:
                data = torch.load(f)
            assert (
                "model_state" in data
            ), f"Cannot load .pyth file {filename}; pycls checkpoints must contain 'model_state'."
            model_state = {
                k: v
                for k, v in data["model_state"].items()
                if not k.endswith("num_batches_tracked")
            }
            return {"model": model_state, "__author__": "pycls", "matching_heuristics": True}

        loaded = super()._load_file(filename)  # load native pth checkpoint
        if "model" not in loaded:
            loaded = {"model": loaded}
        loaded["matching_heuristics"] = True
        return loaded
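
    # The dicts returned above share a common shape (a descriptive sketch; exact keys
    # and values vary by checkpoint source):
    #     {"model": {<parameter name>: <ndarray or Tensor>, ...},
    #      "__author__": "Caffe2" / "pycls" / ...,   # optional provenance tag
    #      "matching_heuristics": True}              # enables name matching in `_load_model`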

    def _load_model(self, checkpoint):
        if checkpoint.get("matching_heuristics", False):
            self._convert_ndarray_to_tensor(checkpoint["model"])
            # convert weights by name-matching heuristics
            checkpoint["model"] = align_and_update_state_dicts(
                self.model.state_dict(),
                checkpoint["model"],
                c2_conversion=checkpoint.get("__author__", None) == "Caffe2",
            )
        # for non-caffe2 models, use standard ways to load it
        incompatible = super()._load_model(checkpoint)

        model_buffers = dict(self.model.named_buffers(recurse=False))
        for k in ["pixel_mean", "pixel_std"]:
            # Ignore missing key message about pixel_mean/std.
            # Though they may be missing in old checkpoints, they will be correctly
            # initialized from config anyway.
            if k in model_buffers:
                try:
                    incompatible.missing_keys.remove(k)
                except ValueError:
                    pass
        for k in incompatible.unexpected_keys[:]:
            # Ignore unexpected keys about cell anchors. They exist in old checkpoints
            # but now they are non-persistent buffers and will not be in new checkpoints.
            if "anchor_generator.cell_anchors" in k:
                incompatible.unexpected_keys.remove(k)
        return incompatible

    def _log_incompatible_keys(self, incompatible: _IncompatibleKeys) -> None:
        """
        Log information about the incompatible keys returned by ``_load_model``.
        """
        for k, shape_checkpoint, shape_model in incompatible.incorrect_shapes:
            self.logger.warning(
                "Skip loading parameter '{}' to the model due to incompatible "
                "shapes: {} in the checkpoint but {} in the "
                "model! You might want to double check if this is expected.".format(
                    k, shape_checkpoint, shape_model
                )
            )
        if incompatible.missing_keys:
            missing_keys = _filter_reused_missing_keys(self.model, incompatible.missing_keys)
            if missing_keys:
                self.logger.warning(get_missing_parameters_message(missing_keys))
        if incompatible.unexpected_keys:
            self.logger.warning(
                get_unexpected_parameters_message(incompatible.unexpected_keys)
            )


def _filter_reused_missing_keys(model: nn.Module, keys: List[str]) -> List[str]:
    """
    Filter "missing keys" to not include keys that have been loaded with another name.
    """
    keyset = set(keys)
    param_to_names = defaultdict(set)  # param -> names that point to it
    for module_prefix, module in _named_modules_with_dup(model):
        for name, param in list(module.named_parameters(recurse=False)) + list(
            module.named_buffers(recurse=False)
        ):
            full_name = (module_prefix + "." if module_prefix else "") + name
            param_to_names[param].add(full_name)
    for names in param_to_names.values():
        # if one name appears missing but its alias exists, then this
        # name is not considered missing
        if any(n in keyset for n in names) and not all(n in keyset for n in names):
            for n in names:
                keyset.discard(n)
    return list(keyset)
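
# An illustrative sketch of the aliasing case handled above (hypothetical tied module):
# if `emb.weight` and `decoder.weight` refer to the same tensor (weight tying) and the
# checkpoint only stored `emb.weight`, then `decoder.weight` shows up as "missing" even
# though its values were loaded, so the filter drops it:
#
#     tied = nn.Linear(8, 8)
#     model = nn.ModuleDict({"emb": tied, "decoder": tied})
#     _filter_reused_missing_keys(model, ["decoder.weight"])  # -> []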


def get_missing_parameters_message(keys: List[str]) -> str:
    """
    Get a logging-friendly message to report parameter names (keys) that are in
    the model but not found in a checkpoint.

    Args:
        keys (list[str]): List of keys that were not found in the checkpoint.

    Returns:
        str: message.
    """
    groups = _group_checkpoint_keys(keys)
    msg_per_group = sorted(k + _group_to_str(v) for k, v in groups.items())
    msg = "Some model parameters or buffers are not found in the checkpoint:\n"
    msg += "\n".join([colored(x, "blue") for x in msg_per_group])
    return msg
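
# Worked example (illustrative parameter names): with
# keys = ["backbone.conv1.weight", "backbone.conv1.bias"], both keys collapse into one
# group and the (blue-colored) message reads roughly:
#
#     Some model parameters or buffers are not found in the checkpoint:
#     backbone.conv1.{bias, weight}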


def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]:
    """
    Group keys based on common prefixes. A prefix is the string up to the final
    "." in each key.

    Args:
        keys (list[str]): list of parameter names, i.e. keys in the model
            checkpoint dict.

    Returns:
        dict[list]: keys with common prefixes are grouped into lists.
    """
    groups = defaultdict(list)
    for key in keys:
        pos = key.rfind(".")
        if pos >= 0:
            head, tail = key[:pos], [key[pos + 1 :]]
        else:
            head, tail = key, []
        groups[head].extend(tail)
    return groups
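
# Example (illustrative keys): the prefix is everything before the last ".", and a key
# without any "." becomes its own group with an empty suffix list:
#
#     _group_checkpoint_keys(["roi_heads.cls_score.weight", "roi_heads.cls_score.bias", "fc"])
#     # -> {"roi_heads.cls_score": ["weight", "bias"], "fc": []}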


def _group_to_str(group: List[str]) -> str:
    """
    Format a group of parameter name suffixes into a loggable string.

    Args:
        group (list[str]): list of parameter name suffixes.

    Returns:
        str: formatted string.
    """
    if len(group) == 0:
        return ""
    if len(group) == 1:
        return "." + group[0]
    return ".{" + ", ".join(sorted(group)) + "}"


def get_unexpected_parameters_message(keys: List[str]) -> str:
    """
    Get a logging-friendly message to report parameter names (keys) that are in
    the checkpoint but not found in the model.

    Args:
        keys (list[str]): List of keys that were not found in the model.

    Returns:
        str: message.
    """
    groups = _group_checkpoint_keys(keys)
    msg = "The checkpoint state_dict contains keys that are not used by the model:\n"
    msg += "\n".join(
        "  " + colored(k + _group_to_str(v), "magenta") for k, v in groups.items()
    )
    return msg


def _named_modules_with_dup(
    model: nn.Module, prefix: str = ""
) -> Iterable[Tuple[str, nn.Module]]:
    """
    The same as `model.named_modules()`, except that it includes
    duplicated modules that have more than one name.
    """
    yield prefix, model
    for name, module in model._modules.items():
        if module is None:
            continue
        submodule_prefix = prefix + ("." if prefix else "") + name
        yield from _named_modules_with_dup(module, submodule_prefix)
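

# Example (sketch with a hypothetical shared submodule): a module reachable under two
# names is yielded once per name, unlike `model.named_modules()`, which deduplicates
# by object identity:
#
#     shared = nn.Linear(4, 4)
#     m = nn.Sequential(shared, shared)
#     [name for name, _ in _named_modules_with_dup(m)]  # -> ["", "0", "1"]
#     [name for name, _ in m.named_modules()]           # -> ["", "0"]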