# Copyright (c) Facebook, Inc. and its affiliates. import logging from typing import Callable, Dict, List, Optional, Tuple, Union import fvcore.nn.weight_init as weight_init from torch import nn from torch.nn import functional as F from detectron2.config import configurable from detectron2.layers import Conv2d, ShapeSpec, get_norm from detectron2.modeling import SEM_SEG_HEADS_REGISTRY from ..transformer_decoder.maskformer_transformer_decoder import StandardTransformerDecoder from ..pixel_decoder.fpn import build_pixel_decoder @SEM_SEG_HEADS_REGISTRY.register() class PerPixelBaselineHead(nn.Module): _version = 2 def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): version = local_metadata.get("version", None) if version is None or version < 2: logger = logging.getLogger(__name__) # Do not warn if train from scratch scratch = True logger = logging.getLogger(__name__) for k in list(state_dict.keys()): newk = k if "sem_seg_head" in k and not k.startswith(prefix + "predictor"): newk = k.replace(prefix, prefix + "pixel_decoder.") # logger.warning(f"{k} ==> {newk}") if newk != k: state_dict[newk] = state_dict[k] del state_dict[k] scratch = False if not scratch: logger.warning( f"Weight format of {self.__class__.__name__} have changed! " "Please upgrade your models. Applying automatic conversion now ..." ) @configurable def __init__( self, input_shape: Dict[str, ShapeSpec], *, num_classes: int, pixel_decoder: nn.Module, loss_weight: float = 1.0, ignore_value: int = -1, ): """ NOTE: this interface is experimental. Args: input_shape: shapes (channels and stride) of the input features num_classes: number of classes to predict pixel_decoder: the pixel decoder module loss_weight: loss weight ignore_value: category id to be ignored during training. """ super().__init__() input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) self.in_features = [k for k, v in input_shape] feature_strides = [v.stride for k, v in input_shape] feature_channels = [v.channels for k, v in input_shape] self.ignore_value = ignore_value self.common_stride = 4 self.loss_weight = loss_weight self.pixel_decoder = pixel_decoder self.predictor = Conv2d( self.pixel_decoder.mask_dim, num_classes, kernel_size=1, stride=1, padding=0 ) weight_init.c2_msra_fill(self.predictor) @classmethod def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): return { "input_shape": { k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES }, "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, "pixel_decoder": build_pixel_decoder(cfg, input_shape), "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT, } def forward(self, features, targets=None): """ Returns: In training, returns (None, dict of losses) In inference, returns (CxHxW logits, {}) """ x = self.layers(features) if self.training: return None, self.losses(x, targets) else: x = F.interpolate( x, scale_factor=self.common_stride, mode="bilinear", align_corners=False ) return x, {} def layers(self, features): x, _, _ = self.pixel_decoder.forward_features(features) x = self.predictor(x) return x def losses(self, predictions, targets): predictions = predictions.float() # https://github.com/pytorch/pytorch/issues/48163 predictions = F.interpolate( predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False ) loss = F.cross_entropy( predictions, targets, reduction="mean", ignore_index=self.ignore_value ) losses = {"loss_sem_seg": loss * self.loss_weight} return losses @SEM_SEG_HEADS_REGISTRY.register() class PerPixelBaselinePlusHead(PerPixelBaselineHead): def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): version = local_metadata.get("version", None) if version is None or version < 2: # Do not warn if train from scratch scratch = True logger = logging.getLogger(__name__) for k in list(state_dict.keys()): newk = k if "sem_seg_head" in k and not k.startswith(prefix + "predictor"): newk = k.replace(prefix, prefix + "pixel_decoder.") logger.debug(f"{k} ==> {newk}") if newk != k: state_dict[newk] = state_dict[k] del state_dict[k] scratch = False if not scratch: logger.warning( f"Weight format of {self.__class__.__name__} have changed! " "Please upgrade your models. Applying automatic conversion now ..." ) @configurable def __init__( self, input_shape: Dict[str, ShapeSpec], *, # extra parameters transformer_predictor: nn.Module, transformer_in_feature: str, deep_supervision: bool, # inherit parameters num_classes: int, pixel_decoder: nn.Module, loss_weight: float = 1.0, ignore_value: int = -1, ): """ NOTE: this interface is experimental. Args: input_shape: shapes (channels and stride) of the input features transformer_predictor: the transformer decoder that makes prediction transformer_in_feature: input feature name to the transformer_predictor deep_supervision: whether or not to add supervision to the output of every transformer decoder layer num_classes: number of classes to predict pixel_decoder: the pixel decoder module loss_weight: loss weight ignore_value: category id to be ignored during training. """ super().__init__( input_shape, num_classes=num_classes, pixel_decoder=pixel_decoder, loss_weight=loss_weight, ignore_value=ignore_value, ) del self.predictor self.predictor = transformer_predictor self.transformer_in_feature = transformer_in_feature self.deep_supervision = deep_supervision @classmethod def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): ret = super().from_config(cfg, input_shape) ret["transformer_in_feature"] = cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder": in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM else: in_channels = input_shape[ret["transformer_in_feature"]].channels ret["transformer_predictor"] = StandardTransformerDecoder( cfg, in_channels, mask_classification=False ) ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION return ret def forward(self, features, targets=None): """ Returns: In training, returns (None, dict of losses) In inference, returns (CxHxW logits, {}) """ x, aux_outputs = self.layers(features) if self.training: if self.deep_supervision: losses = self.losses(x, targets) for i, aux_output in enumerate(aux_outputs): losses["loss_sem_seg" + f"_{i}"] = self.losses( aux_output["pred_masks"], targets )["loss_sem_seg"] return None, losses else: return None, self.losses(x, targets) else: x = F.interpolate( x, scale_factor=self.common_stride, mode="bilinear", align_corners=False ) return x, {} def layers(self, features): mask_features, transformer_encoder_features, _ = self.pixel_decoder.forward_features(features) if self.transformer_in_feature == "transformer_encoder": assert ( transformer_encoder_features is not None ), "Please use the TransformerEncoderPixelDecoder." predictions = self.predictor(transformer_encoder_features, mask_features) else: predictions = self.predictor(features[self.transformer_in_feature], mask_features) if self.deep_supervision: return predictions["pred_masks"], predictions["aux_outputs"] else: return predictions["pred_masks"], None