# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""
Implements the Generalized R-CNN framework
"""

import timeit

import torch
from torch import nn

from maskrcnn_benchmark.structures.image_list import to_image_list

from ..backbone import build_backbone
from ..rpn import build_rpn
from ..roi_heads import build_roi_heads


class GeneralizedRCNN(nn.Module):
    """
    Main class for Generalized R-CNN. Currently supports boxes and masks.
    It consists of three main parts:
    - backbone
    - rpn
    - heads: takes the features + the proposals from the RPN and computes
        detections / masks from it.
    """

    # Parameter-name substrings that stay trainable under linear probing;
    # every other RPN / ROI-head parameter has its gradients disabled in `train`.
    _LINEAR_PROB_TRAINABLE_KEYS = ("bbox_pred", "cls_logits", "centerness", "cosine_scale")

    def __init__(self, cfg):
        super(GeneralizedRCNN, self).__init__()

        self.backbone = build_backbone(cfg)
        self.rpn = build_rpn(cfg)
        self.roi_heads = build_roi_heads(cfg)
        self.DEBUG = cfg.MODEL.DEBUG

        # When set, `forward` feeds `images` straight to the backbone without
        # converting them with `to_image_list`.
        self.ONNX = cfg.MODEL.ONNX

        self.freeze_backbone = cfg.MODEL.BACKBONE.FREEZE
        self.freeze_fpn = cfg.MODEL.FPN.FREEZE
        self.freeze_rpn = cfg.MODEL.RPN.FREEZE

        if cfg.MODEL.LINEAR_PROB:
            # Linear probing only trains the prediction layers, so the feature
            # extractor must already be configured as frozen.
            assert cfg.MODEL.BACKBONE.FREEZE, "For linear probing, backbone should be frozen!"
            if hasattr(self.backbone, "fpn"):
                assert cfg.MODEL.FPN.FREEZE, "For linear probing, FPN should be frozen!"
        self.linear_prob = cfg.MODEL.LINEAR_PROB

    def train(self, mode=True):
        """Convert the model into training mode while keeping configured layers frozen.

        After delegating to ``nn.Module.train(mode)``:
        - ``freeze_backbone``: ``backbone.body`` is put in eval mode and its
          parameters get ``requires_grad = False``;
        - ``freeze_fpn``: the same for ``backbone.fpn``;
        - ``freeze_rpn``: the same for ``rpn``;
        - ``linear_prob``: every RPN / ROI-head parameter whose name contains
          none of ``_LINEAR_PROB_TRAINABLE_KEYS`` gets ``requires_grad = False``.

        Args:
            mode (bool): True for training mode, False for evaluation mode.

        Returns:
            GeneralizedRCNN: ``self``, as required by the ``nn.Module.train``
            contract (``nn.Module.eval`` returns the result of ``train(False)``).
        """
        super(GeneralizedRCNN, self).train(mode)
        if self.freeze_backbone:
            self._freeze_module(self.backbone.body)
        if self.freeze_fpn:
            self._freeze_module(self.backbone.fpn)
        if self.freeze_rpn:
            self._freeze_module(self.rpn)
        if self.linear_prob:
            if self.rpn is not None:
                self._freeze_all_but_trainable_heads(self.rpn)
            # Same truthiness test as `forward`: RPN-only models have no ROI heads.
            if self.roi_heads:
                self._freeze_all_but_trainable_heads(self.roi_heads)
        return self

    @staticmethod
    def _freeze_module(module):
        """Put `module` in eval mode and disable gradients for all its parameters."""
        module.eval()
        for p in module.parameters():
            p.requires_grad = False

    @classmethod
    def _freeze_all_but_trainable_heads(cls, module):
        """Disable gradients for parameters of `module` not matching any linear-probe key."""
        for key, value in module.named_parameters():
            if not any(name in key for name in cls._LINEAR_PROB_TRAINABLE_KEYS):
                value.requires_grad = False

    def forward(self, images, targets=None):
        """
        Arguments:
            images (list[Tensor] or ImageList): images to be processed
            targets (list[BoxList]): ground-truth boxes present in the image (optional)

        Returns:
            result (list[BoxList] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns list[BoxList] contains additional fields
                like `scores`, `labels` and `mask` (for Mask R-CNN models).
                When DEBUG is enabled, it instead returns a tuple
                ``(result, debug_info)`` regardless of training mode, where
                ``debug_info`` holds per-stage timings and sizes.

        Raises:
            ValueError: if called in training mode without `targets`.
        """
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")

        # NOTE(review): these wall-clock timings do not synchronize CUDA, so
        # asynchronous GPU work may be under-reported — confirm intent.
        if self.DEBUG:
            debug_info = {}
            # NOTE(review): assumes `images` is indexable (list of tensors or a
            # batched tensor); an ImageList input may not support `[0]` — confirm.
            debug_info["input_size"] = images[0].size()
            tic = timeit.default_timer()

        if self.ONNX:
            features = self.backbone(images)
        else:
            images = to_image_list(images)
            features = self.backbone(images.tensors)

        if self.DEBUG:
            debug_info["feat_time"] = timeit.default_timer() - tic
            debug_info["feat_size"] = [feat.size() for feat in features]
            tic = timeit.default_timer()

        proposals, proposal_losses = self.rpn(images, features, targets)

        if self.DEBUG:
            debug_info["rpn_time"] = timeit.default_timer() - tic
            debug_info["#rpn"] = list(proposals)
            tic = timeit.default_timer()

        if self.roi_heads:
            _, result, detector_losses = self.roi_heads(features, proposals, targets)
        else:
            # RPN-only models don't have roi_heads
            result = proposals
            detector_losses = {}

        if self.DEBUG:
            debug_info["rcnn_time"] = timeit.default_timer() - tic
            debug_info["#rcnn"] = result
            return result, debug_info

        if self.training:
            losses = {}
            losses.update(detector_losses)
            # Proposal (RPN) losses are merged last, so they win on key clashes.
            losses.update(proposal_losses)
            return losses

        return result