# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import numpy as np
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn

from detectron2.config import configurable
from detectron2.data.detection_utils import convert_image_to_rgb
from detectron2.layers import move_device_like
from detectron2.structures import ImageList, Instances
from detectron2.utils.events import get_event_storage
from detectron2.utils.logger import log_first_n

from ..backbone import Backbone, build_backbone
from ..postprocessing import detector_postprocess
from ..proposal_generator import build_proposal_generator
from ..roi_heads import build_roi_heads
from .build import META_ARCH_REGISTRY

__all__ = ["GeneralizedRCNN", "ProposalNetwork"]


@META_ARCH_REGISTRY.register()
class GeneralizedRCNN(nn.Module):
    """
    Generalized R-CNN. Any model that contains the following three components:
    1. Per-image feature extraction (aka backbone)
    2. Region proposal generation
    3. Per-region feature extraction and prediction
    """

    @configurable
    def __init__(
        self,
        *,
        backbone: Backbone,
        proposal_generator: nn.Module,
        roi_heads: nn.Module,
        pixel_mean: Tuple[float, ...],
        pixel_std: Tuple[float, ...],
        input_format: Optional[str] = None,
        vis_period: int = 0,
    ):
        """
        Args:
            backbone: a backbone module, must follow detectron2's backbone interface
            proposal_generator: a module that generates proposals using backbone features.
                May be None, in which case precomputed proposals must be supplied in the
                inputs under the "proposals" key.
            roi_heads: a ROI head that performs per-region computation
            pixel_mean, pixel_std: list or tuple with #channels elements, representing
                the per-channel mean and std used to normalize the input image.
                Both must have the same number of elements.
            input_format: describes the meaning of the input channels. Needed by
                visualization (required when ``vis_period > 0``).
            vis_period: the period (in training iterations) at which to run
                visualization. Set to 0 to disable.
        """
        super().__init__()
        self.backbone = backbone
        self.proposal_generator = proposal_generator
        self.roi_heads = roi_heads

        self.input_format = input_format
        self.vis_period = vis_period
        if vis_period > 0:
            assert input_format is not None, "input_format is required for visualization!"

        # Reshaped to (N, 1, 1) so they broadcast over (C, H, W) images. Registered with
        # persistent=False (third positional arg) so they are not part of the state_dict.
        self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False)
        assert (
            self.pixel_mean.shape == self.pixel_std.shape
        ), f"{self.pixel_mean} and {self.pixel_std} have different shapes!"

    @classmethod
    def from_config(cls, cfg):
        """Build the keyword arguments of ``__init__`` from a detectron2 config."""
        backbone = build_backbone(cfg)
        return {
            "backbone": backbone,
            "proposal_generator": build_proposal_generator(cfg, backbone.output_shape()),
            "roi_heads": build_roi_heads(cfg, backbone.output_shape()),
            "input_format": cfg.INPUT.FORMAT,
            "vis_period": cfg.VIS_PERIOD,
            "pixel_mean": cfg.MODEL.PIXEL_MEAN,
            "pixel_std": cfg.MODEL.PIXEL_STD,
        }

    @property
    def device(self):
        """The device of the model, taken from the ``pixel_mean`` buffer."""
        return self.pixel_mean.device

    def _move_to_current_device(self, x):
        """Move ``x`` to the device of ``pixel_mean`` using ``move_device_like``."""
        return move_device_like(x, self.pixel_mean)

    def visualize_training(self, batched_inputs, proposals):
        """
        Visualize ground-truth boxes and predicted proposals for one image of the batch.

        The image written to the event storage shows the ground-truth bounding boxes
        on the left and the first (up to 20) predicted proposals on the right. The
        proposals are presumably sorted by score by the proposal generator, which
        would make these the top-scoring ones.
        Users can implement different visualization functions for different models.

        Args:
            batched_inputs (list): a list that contains input to the model.
            proposals (list): a list that contains predicted proposals. Both
                batched_inputs and proposals should have the same length.
        """
        from detectron2.utils.visualizer import Visualizer

        storage = get_event_storage()
        max_vis_prop = 20

        # `inp` rather than `input` to avoid shadowing the builtin.
        for inp, prop in zip(batched_inputs, proposals):
            img = inp["image"]
            # (C, H, W) -> (H, W, C), then convert to RGB for the visualizer.
            img = convert_image_to_rgb(img.permute(1, 2, 0), self.input_format)
            v_gt = Visualizer(img, None)
            v_gt = v_gt.overlay_instances(boxes=inp["instances"].gt_boxes)
            anno_img = v_gt.get_image()
            box_size = min(len(prop.proposal_boxes), max_vis_prop)
            v_pred = Visualizer(img, None)
            v_pred = v_pred.overlay_instances(
                boxes=prop.proposal_boxes[0:box_size].tensor.cpu().numpy()
            )
            prop_img = v_pred.get_image()
            # Place the two views side by side (along width), then back to (C, H, W).
            vis_img = np.concatenate((anno_img, prop_img), axis=1)
            vis_img = vis_img.transpose(2, 0, 1)
            vis_name = "Left: GT bounding boxes; Right: Predicted proposals"
            storage.put_image(vis_name, vis_img)
            break  # only visualize one image in a batch

    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper` .
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:

                * image: Tensor, image in (C, H, W) format.
                * instances (optional): groundtruth :class:`Instances`
                * proposals (optional): :class:`Instances`, precomputed proposals.

                Other information that's included in the original dicts, such as:

                * "height", "width" (int): the output resolution of the model, used in inference.
                  See :meth:`_postprocess` for details.

        Returns:
            In training mode: dict[str, Tensor] of losses from the proposal generator
            and the ROI heads.
            In eval mode: list[dict]. Each dict is the output for one input image.
            The dict contains one key "instances" whose value is a :class:`Instances`.
            The :class:`Instances` object has the following keys:
            "pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints"
        """
        if not self.training:
            return self.inference(batched_inputs)

        images = self.preprocess_image(batched_inputs)
        # Ground truth is optional; it is checked on the first image only.
        if "instances" in batched_inputs[0]:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        else:
            gt_instances = None

        features = self.backbone(images.tensor)

        if self.proposal_generator is not None:
            proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
        else:
            # Without a proposal generator, precomputed proposals must be provided.
            assert "proposals" in batched_inputs[0]
            proposals = [x["proposals"].to(self.device) for x in batched_inputs]
            proposal_losses = {}

        _, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
        if self.vis_period > 0:
            storage = get_event_storage()
            if storage.iter % self.vis_period == 0:
                self.visualize_training(batched_inputs, proposals)

        # Proposal losses are applied last, so they win on any key collision.
        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)
        return losses

    def inference(
        self,
        batched_inputs: List[Dict[str, torch.Tensor]],
        detected_instances: Optional[List[Instances]] = None,
        do_postprocess: bool = True,
    ):
        """
        Run inference on the given inputs.

        Args:
            batched_inputs (list[dict]): same as in :meth:`forward`
            detected_instances (None or list[Instances]): if not None, it
                contains an `Instances` object per image. The `Instances`
                object contains "pred_boxes" and "pred_classes" which are
                known boxes in the image.
                The inference will then skip the detection of bounding boxes,
                and only predict other per-ROI outputs.
            do_postprocess (bool): whether to apply post-processing on the outputs.

        Returns:
            When do_postprocess=True, same as in :meth:`forward`.
            Otherwise, a list[Instances] containing raw network outputs.
        """
        assert not self.training

        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images.tensor)

        if detected_instances is None:
            if self.proposal_generator is not None:
                proposals, _ = self.proposal_generator(images, features, None)
            else:
                assert "proposals" in batched_inputs[0]
                proposals = [x["proposals"].to(self.device) for x in batched_inputs]

            results, _ = self.roi_heads(images, features, proposals, None)
        else:
            # Boxes are given: skip proposal generation and box prediction.
            detected_instances = [x.to(self.device) for x in detected_instances]
            results = self.roi_heads.forward_with_given_boxes(features, detected_instances)

        if do_postprocess:
            assert not torch.jit.is_scripting(), "Scripting is not supported for postprocess."
            return GeneralizedRCNN._postprocess(results, batched_inputs, images.image_sizes)
        return results

    def preprocess_image(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        """
        Normalize, pad and batch the input images.

        Each ``x["image"]`` is moved to the model's device and normalized with
        ``(image - pixel_mean) / pixel_std``. The results are then batched into an
        :class:`ImageList`, using the backbone's size divisibility and padding
        constraints.
        """
        images = [self._move_to_current_device(x["image"]) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(
            images,
            self.backbone.size_divisibility,
            padding_constraints=self.backbone.padding_constraints,
        )
        return images

    @staticmethod
    def _postprocess(instances, batched_inputs: List[Dict[str, torch.Tensor]], image_sizes):
        """
        Rescale the output instances to the target size.

        The target size is the input's "height"/"width" when present; otherwise it
        is the (pre-padding) image size recorded in ``image_sizes``.
        Returns a list of ``{"instances": Instances}`` dicts, one per image.
        """
        # note: private function; subject to changes
        processed_results = []
        for results_per_image, input_per_image, image_size in zip(
            instances, batched_inputs, image_sizes
        ):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            r = detector_postprocess(results_per_image, height, width)
            processed_results.append({"instances": r})
        return processed_results


@META_ARCH_REGISTRY.register()
class ProposalNetwork(nn.Module):
    """
    A meta architecture that only predicts object proposals.
    """

    @configurable
    def __init__(
        self,
        *,
        backbone: Backbone,
        proposal_generator: nn.Module,
        pixel_mean: Tuple[float, ...],
        pixel_std: Tuple[float, ...],
    ):
        """
        Args:
            backbone: a backbone module, must follow detectron2's backbone interface
            proposal_generator: a module that generates proposals using backbone features
            pixel_mean, pixel_std: list or tuple with #channels elements, representing
                the per-channel mean and std used to normalize the input image.
                Both must have the same number of elements.
        """
        super().__init__()
        self.backbone = backbone
        self.proposal_generator = proposal_generator
        # Reshaped to (N, 1, 1) so they broadcast over (C, H, W) images; persistent=False
        # (third positional arg) keeps them out of the state_dict.
        self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False)
        # Same validation as GeneralizedRCNN for the same config keys.
        assert (
            self.pixel_mean.shape == self.pixel_std.shape
        ), f"{self.pixel_mean} and {self.pixel_std} have different shapes!"

    @classmethod
    def from_config(cls, cfg):
        """Build the keyword arguments of ``__init__`` from a detectron2 config."""
        backbone = build_backbone(cfg)
        return {
            "backbone": backbone,
            "proposal_generator": build_proposal_generator(cfg, backbone.output_shape()),
            "pixel_mean": cfg.MODEL.PIXEL_MEAN,
            "pixel_std": cfg.MODEL.PIXEL_STD,
        }

    @property
    def device(self):
        """The device of the model, taken from the ``pixel_mean`` buffer."""
        return self.pixel_mean.device

    def _move_to_current_device(self, x):
        """Move ``x`` to the device of ``pixel_mean`` using ``move_device_like``."""
        return move_device_like(x, self.pixel_mean)

    def forward(self, batched_inputs):
        """
        Args:
            Same as in :class:`GeneralizedRCNN.forward`

        Returns:
            In training mode: the proposal generator's loss dict.
            In eval mode: list[dict]. Each dict is the output for one input image.
            The dict contains one key "proposals" whose value is a
            :class:`Instances` with keys "proposal_boxes" and "objectness_logits".
        """
        # Normalize, pad and batch the images (same steps as GeneralizedRCNN.preprocess_image).
        images = [self._move_to_current_device(x["image"]) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(
            images,
            self.backbone.size_divisibility,
            padding_constraints=self.backbone.padding_constraints,
        )
        features = self.backbone(images.tensor)

        if "instances" in batched_inputs[0]:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        elif "targets" in batched_inputs[0]:
            # Backward compatibility with the legacy "targets" input key.
            log_first_n(
                logging.WARNING, "'targets' in the model inputs is now renamed to 'instances'!", n=10
            )
            gt_instances = [x["targets"].to(self.device) for x in batched_inputs]
        else:
            gt_instances = None
        proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
        # In training, the proposals are not useful at all but we generate them anyway.
        # This makes RPN-only models about 5% slower.
        if self.training:
            return proposal_losses

        # Rescale proposals to the requested output resolution (falls back to image size).
        processed_results = []
        for results_per_image, input_per_image, image_size in zip(
            proposals, batched_inputs, images.image_sizes
        ):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            r = detector_postprocess(results_per_image, height, width)
            processed_results.append({"proposals": r})
        return processed_results