# Copyright (c) Facebook, Inc. and its affiliates. import logging from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads @ROI_HEADS_REGISTRY.register() class PointRendROIHeads(StandardROIHeads): """ Identical to StandardROIHeads, except for some weights conversion code to handle old models. """ _version = 2 def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): version = local_metadata.get("version", None) if version is None or version < 2: logger = logging.getLogger(__name__) logger.warning( "Weight format of PointRend models have changed! " "Please upgrade your models. Applying automatic conversion now ..." ) for k in list(state_dict.keys()): newk = k if k.startswith(prefix + "mask_point_head"): newk = k.replace(prefix + "mask_point_head", prefix + "mask_head.point_head") if k.startswith(prefix + "mask_coarse_head"): newk = k.replace(prefix + "mask_coarse_head", prefix + "mask_head.coarse_head") if newk != k: state_dict[newk] = state_dict[k] del state_dict[k] @classmethod def _init_mask_head(cls, cfg, input_shape): if cfg.MODEL.MASK_ON and cfg.MODEL.ROI_MASK_HEAD.NAME != "PointRendMaskHead": logger = logging.getLogger(__name__) logger.warning( "Config of PointRend models have changed! " "Please upgrade your models. Applying automatic conversion now ..." ) assert cfg.MODEL.ROI_MASK_HEAD.NAME == "CoarseMaskHead" cfg.defrost() cfg.MODEL.ROI_MASK_HEAD.NAME = "PointRendMaskHead" cfg.MODEL.ROI_MASK_HEAD.POOLER_TYPE = "" cfg.freeze() return super()._init_mask_head(cfg, input_shape)