# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

from typing import Tuple

import torch
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.data import MetadataCatalog
from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
from detectron2.modeling.backbone import Backbone
from detectron2.modeling.postprocessing import sem_seg_postprocess
from detectron2.structures import ImageList

from .modeling.criterion import SetCriterion
from .modeling.matcher import HungarianMatcher


@META_ARCH_REGISTRY.register()
class MaskFormer(nn.Module):
    """
    Main class for mask classification semantic segmentation architectures.
    """

    @configurable
    def __init__(
        self,
        *,
        backbone: Backbone,
        sem_seg_head: nn.Module,
        criterion: nn.Module,
        num_queries: int,
        panoptic_on: bool,
        object_mask_threshold: float,
        overlap_threshold: float,
        metadata,
        size_divisibility: int,
        sem_seg_postprocess_before_inference: bool,
        pixel_mean: Tuple[float],
        pixel_std: Tuple[float],
    ):
        """
        Args:
            backbone: a backbone module, must follow detectron2's backbone interface
            sem_seg_head: a module that predicts semantic segmentation from backbone features
            criterion: a module that defines the loss
            num_queries: int, number of queries
            panoptic_on: bool, whether to output panoptic segmentation prediction
            object_mask_threshold: float, threshold to filter query based on classification score
                for panoptic segmentation inference
            overlap_threshold: overlap threshold used in general inference for panoptic segmentation
            metadata: dataset meta, get `thing` and `stuff` category names for panoptic
                segmentation inference
            size_divisibility: Some backbones require the input height and width to be divisible by a
                specific integer. We can use this to override such requirement. A negative
                value means "use the backbone's own `size_divisibility`".
            sem_seg_postprocess_before_inference: whether to resize the prediction back
                to original input size before semantic segmentation inference or after.
                For high-resolution dataset like Mapillary, resizing predictions before
                inference will cause OOM error.
            pixel_mean, pixel_std: list or tuple with #channels element, representing
                the per-channel mean and std to be used to normalize the input image
        """
        super().__init__()
        self.backbone = backbone
        self.sem_seg_head = sem_seg_head
        self.criterion = criterion
        self.num_queries = num_queries
        self.overlap_threshold = overlap_threshold
        self.panoptic_on = panoptic_on
        self.object_mask_threshold = object_mask_threshold
        self.metadata = metadata
        if size_divisibility < 0:
            # use backbone size_divisibility if not set
            size_divisibility = self.backbone.size_divisibility
        self.size_divisibility = size_divisibility
        self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
        # Shaped (C, 1, 1) so they broadcast over a (C, H, W) image; registered as
        # non-persistent buffers (third positional arg False) so they follow the
        # module's device but are not written to checkpoints.
        self.register_buffer(
            "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False
        )
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

    @classmethod
    def from_config(cls, cfg):
        """Translate a detectron2 config into the keyword arguments of ``__init__``."""
        backbone = build_backbone(cfg)
        sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())

        # Loss parameters:
        deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
        no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT
        dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT
        mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT

        # building criterion
        matcher = HungarianMatcher(
            cost_class=1,
            cost_mask=mask_weight,
            cost_dice=dice_weight,
        )

        weight_dict = {"loss_ce": 1, "loss_mask": mask_weight, "loss_dice": dice_weight}
        if deep_supervision:
            # Auxiliary (intermediate decoder layer) losses reuse the same weights,
            # keyed with a "_{layer_index}" suffix for each of the first
            # DEC_LAYERS - 1 layers.
            dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS
            aux_weight_dict = {}
            for i in range(dec_layers - 1):
                aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
            weight_dict.update(aux_weight_dict)

        losses = ["labels", "masks"]

        criterion = SetCriterion(
            sem_seg_head.num_classes,
            matcher=matcher,
            weight_dict=weight_dict,
            eos_coef=no_object_weight,
            losses=losses,
        )

        return {
            "backbone": backbone,
            "sem_seg_head": sem_seg_head,
            "criterion": criterion,
            "num_queries": cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES,
            "panoptic_on": cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON,
            "object_mask_threshold": cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD,
            "overlap_threshold": cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD,
            "metadata": MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),
            "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
            # Panoptic inference is always run on masks already resized to the
            # output resolution, so enabling it forces post-processing first.
            "sem_seg_postprocess_before_inference": (
                cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE
                or cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON
            ),
            "pixel_mean": cfg.MODEL.PIXEL_MEAN,
            "pixel_std": cfg.MODEL.PIXEL_STD,
        }

    @property
    def device(self):
        """Device of the model, taken from the ``pixel_mean`` buffer."""
        return self.pixel_mean.device

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:
                   * "image": Tensor, image in (C, H, W) format.
                   * "instances": per-region ground truth
                   * Other information that's included in the original dicts, such as:
                     "height", "width" (int): the output resolution of the model (may be different
                     from input resolution), used in inference.
        Returns:
            In training mode: dict[str, Tensor] of weighted losses. Only losses
            that have an entry in ``criterion.weight_dict`` are returned, each
            multiplied by its weight.

            In inference mode: list[dict], each dict has the results for one image.
            The dict contains the following keys:

            * "sem_seg":
                A Tensor that represents the per-pixel segmentation predicted by the head.
                The prediction has shape KxHxW and holds, for each class and pixel, the
                sum over queries of the query's class probability (softmax, "no object"
                column dropped) times its mask probability (sigmoid). These are scores,
                not logits.
            * "panoptic_seg":
                A tuple that represents panoptic output
                panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
                segments_info (list[dict]): Describe each segment in `panoptic_seg`.
                    Each dict contains keys "id", "category_id", "isthing".
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)

        features = self.backbone(images.tensor)
        outputs = self.sem_seg_head(features)

        if self.training:
            # mask classification target
            if "instances" in batched_inputs[0]:
                gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
                targets = self.prepare_targets(gt_instances, images)
            else:
                targets = None

            # bipartite matching-based loss
            losses = self.criterion(outputs, targets)
            weight_dict = self.criterion.weight_dict
            # Weight each loss and drop any loss not listed in `weight_dict`.
            # Built as a new dict with an out-of-place multiply, rather than
            # popping from / scaling in place the criterion's own output.
            return {
                name: value * weight_dict[name]
                for name, value in losses.items()
                if name in weight_dict
            }

        mask_cls_results = outputs["pred_logits"]
        mask_pred_results = outputs["pred_masks"]
        # upsample masks to the padded batch resolution
        mask_pred_results = F.interpolate(
            mask_pred_results,
            size=(images.tensor.shape[-2], images.tensor.shape[-1]),
            mode="bilinear",
            align_corners=False,
        )

        processed_results = []
        for mask_cls_result, mask_pred_result, input_per_image, image_size in zip(
            mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes
        ):
            # Requested output resolution; defaults to the (unpadded) input size.
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])

            if self.sem_seg_postprocess_before_inference:
                mask_pred_result = sem_seg_postprocess(
                    mask_pred_result, image_size, height, width
                )

            # semantic segmentation inference
            r = self.semantic_inference(mask_cls_result, mask_pred_result)
            if not self.sem_seg_postprocess_before_inference:
                r = sem_seg_postprocess(r, image_size, height, width)
            processed_results.append({"sem_seg": r})

            # panoptic segmentation inference
            if self.panoptic_on:
                # NOTE(review): `panoptic_inference` is not defined in the visible
                # part of this class; confirm it exists (later in the file or in a
                # subclass) before enabling PANOPTIC_ON.
                panoptic_r = self.panoptic_inference(
                    mask_cls_result, mask_pred_result
                )
                processed_results[-1]["panoptic_seg"] = panoptic_r

        return processed_results

    def prepare_targets(self, targets, images):
        """Zero-pad each image's ground-truth masks to the padded batch size.

        Args:
            targets: per-image ground-truth objects exposing ``gt_masks`` (indexed
                as (N, h, w)) and ``gt_classes``.
            images: the batched ``ImageList``; its tensor's last two dims give the
                padded (H, W).

        Returns:
            list[dict]: one dict per image with keys "labels" (``gt_classes``) and
            "masks" (an (N, H, W) tensor with the original masks in the top-left
            corner and zeros elsewhere).
        """
        h, w = images.tensor.shape[-2:]
        new_targets = []
        for targets_per_image in targets:
            # pad gt
            gt_masks = targets_per_image.gt_masks
            padded_masks = torch.zeros(
                (gt_masks.shape[0], h, w), dtype=gt_masks.dtype, device=gt_masks.device
            )
            padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
            new_targets.append(
                {
                    "labels": targets_per_image.gt_classes,
                    "masks": padded_masks,
                }
            )
        return new_targets

    def semantic_inference(self, mask_cls, mask_pred):
        """Combine per-query class scores and masks into per-class score maps.

        Args:
            mask_cls: (Q, C + 1) class logits per query; the last column is dropped
                after the softmax (presumably the "no object" class — confirm
                against the criterion).
            mask_pred: (Q, H, W) mask logits per query.

        Returns:
            Tensor of shape (C, H, W): for each class and pixel, the sum over queries
            of class probability times mask probability.
        """
        mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
        mask_pred = mask_pred.sigmoid()
        semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
        return semseg