# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import os
import pickle

import torch
from fvcore.common.checkpoint import Checkpointer
from torch.nn.parallel import DistributedDataParallel

import detectron2.utils.comm as comm
from detectron2.utils.env import TORCH_VERSION
from detectron2.utils.file_io import PathManager

from .c2_model_loading import align_and_update_state_dicts
from .clip_model_loading import align_and_update_state_dicts_for_CLIP


class DetectionCheckpointer(Checkpointer):
    """
    Same as :class:`Checkpointer`, but is able to:
    1. handle models in detectron & detectron2 model zoo, and apply conversions for legacy models.
    2. correctly load checkpoints that are only available on the master worker
    """

    def __init__(self, model, save_dir="", *, save_to_disk=None, bb_rpn_weights=False, **checkpointables):
        """
        Args:
            model: the model to checkpoint.
            save_dir: directory where checkpoints are saved.
            save_to_disk: whether this worker writes checkpoints; defaults to
                True only on the main (rank-0) worker so that distributed
                training writes each checkpoint exactly once.
            bb_rpn_weights: flag stored for downstream weight-loading logic
                (e.g. backbone/RPN-only checkpoints).
            **checkpointables: extra objects (optimizer, scheduler, ...) whose
                state is saved/loaded alongside the model.
        """
        is_main_process = comm.is_main_process()
        super().__init__(
            model,
            save_dir,
            save_to_disk=is_main_process if save_to_disk is None else save_to_disk,
            **checkpointables,
        )
        self.path_manager = PathManager
        self.bb_rpn_weights = bb_rpn_weights

    def load(self, path, *args, **kwargs):
        """
        Load a checkpoint, handling the distributed case where the file may
        only be readable on (some of) the workers.

        If the model is wrapped in DistributedDataParallel and some workers
        cannot read the file, those workers skip loading and the main worker's
        weights are broadcast to them afterwards.
        """
        need_sync = False

        if path and isinstance(self.model, DistributedDataParallel):
            logger = logging.getLogger(__name__)
            path = self.path_manager.get_local_path(path)
            has_file = os.path.isfile(path)
            # Gather readability of the file across all workers; index 0 is
            # the main worker, which must be able to read the checkpoint.
            all_has_file = comm.all_gather(has_file)
            if not all_has_file[0]:
                raise OSError(f"File {path} not found on main worker.")
            if not all(all_has_file):
                logger.warning(
                    f"Not all workers can read checkpoint {path}. "
                    "Training may fail to fully resume."
                )
                # TODO: broadcast the checkpoint file contents from main
                # worker, and load from it instead.
                need_sync = True
            if not has_file:
                path = None  # don't load if not readable

        ret = super().load(path, *args, **kwargs)

        if need_sync:
            logger.info("Broadcasting model states from main worker ...")
            if TORCH_VERSION >= (1, 7):
                # Re-synchronize parameters/buffers from rank 0 so workers
                # that skipped loading end up with the same weights.
                self.model._sync_params_and_buffers()
        return ret

    def _load_file(self, filename):
        """
        Read a checkpoint file and normalize it to a dict with at least a
        "model" key. Recognizes several formats:

        - ``.pkl``: Detectron2 model zoo (has "model" + "__author__") or
          legacy Caffe2/Detectron1 pickles (marked with matching_heuristics).
        - ``.pyth``: pycls checkpoints (must contain "model_state").
        - paths containing "OAI_CLIP": OpenAI CLIP pre-trained weights.
        - anything else: a native torch checkpoint, loaded by the base class.
        """
        if filename.endswith(".pkl"):
            with PathManager.open(filename, "rb") as f:
                # NOTE: pickle on untrusted files is unsafe; these are assumed
                # to be trusted model-zoo checkpoints.
                data = pickle.load(f, encoding="latin1")
            if "model" in data and "__author__" in data:
                # file is in Detectron2 model zoo format
                self.logger.info("Reading a file from '{}'".format(data["__author__"]))
                return data
            else:
                # assume file is from Caffe2 / Detectron1 model zoo
                if "blobs" in data:
                    # Detection models have "blobs", but ImageNet models don't
                    data = data["blobs"]
                data = {k: v for k, v in data.items() if not k.endswith("_momentum")}
                return {"model": data, "__author__": "Caffe2", "matching_heuristics": True}
        elif filename.endswith(".pyth"):
            # assume file is from pycls; no one else seems to use the ".pyth" extension
            with PathManager.open(filename, "rb") as f:
                data = torch.load(f)
            assert (
                "model_state" in data
            ), f"Cannot load .pyth file {filename}; pycls checkpoints must contain 'model_state'."
            model_state = {
                k: v
                for k, v in data["model_state"].items()
                if not k.endswith("num_batches_tracked")
            }
            return {"model": model_state, "__author__": "pycls", "matching_heuristics": True}
        elif "OAI_CLIP" in filename:
            # assume file is from OpenAI CLIP pre-trained model
            loaded = super()._load_file(filename)  # load native pth checkpoint
            if "model" not in loaded:
                loaded = {"model": loaded}
            return {"model": loaded["model"], "__author__": "OAI_CLIP", "matching_heuristics": True}

        loaded = super()._load_file(filename)  # load native pth checkpoint
        if "model" not in loaded:
            loaded = {"model": loaded}
        return loaded

    def _load_model(self, checkpoint):
        """
        Load the model state dict from ``checkpoint`` and post-process the
        incompatible-keys report: missing ``pixel_mean``/``pixel_std`` buffers
        are silenced because they are re-initialized from config anyway.
        """
        # NOTE(review): the name-matching conversion heuristics below
        # (Caffe2 / OAI_CLIP weight alignment) are intentionally disabled in
        # this variant — checkpoints marked "matching_heuristics" are loaded
        # with the standard by-name logic instead. Kept for reference:
        # if checkpoint.get("matching_heuristics", False) or self.bb_rpn_weights:
        #     self._convert_ndarray_to_tensor(checkpoint["model"])
        #     # convert weights by name-matching heuristics
        #     if checkpoint.get("__author__", "NA") == "OAI_CLIP" or self.bb_rpn_weights:  # for OAI_CLIP or 2nd ckpt (offline modules)
        #         checkpoint["model"] = align_and_update_state_dicts_for_CLIP(
        #             self.model.state_dict(),
        #             checkpoint["model"],
        #             bb_rpn_weights=self.bb_rpn_weights,
        #         )
        #     else:  # default loading
        #         checkpoint["model"] = align_and_update_state_dicts(
        #             self.model.state_dict(),
        #             checkpoint["model"],
        #             c2_conversion=checkpoint.get("__author__", None) == "Caffe2",
        #         )
        # # for non-caffe2 models, use standard ways to load it
        # if not self.bb_rpn_weights:
        #     checkpoint = {'model': {'backbone.' + key: val for key, val in checkpoint['model'].items()}}
        incompatible = super()._load_model(checkpoint)
        del checkpoint  # try saving memory

        model_buffers = dict(self.model.named_buffers(recurse=False))
        for k in ["pixel_mean", "pixel_std"]:
            # Ignore missing key message about pixel_mean/std.
            # Though they may be missing in old checkpoints, they will be correctly
            # initialized from config anyway.
            if k in model_buffers:
                try:
                    incompatible.missing_keys.remove(k)
                except ValueError:
                    pass
        return incompatible