# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. import math import torch import torch.nn.functional as F from torch import nn from maskrcnn_benchmark.modeling import registry from maskrcnn_benchmark.modeling.box_coder import BoxCoder from .loss import make_focal_loss_evaluator from .anchor_generator import make_anchor_generator_complex from .inference import make_retina_postprocessor @registry.RPN_HEADS.register("RetinaNetHead") class RetinaNetHead(torch.nn.Module): """ Adds a RetinNet head with classification and regression heads """ def __init__(self, cfg): """ Arguments: in_channels (int): number of channels of the input feature num_anchors (int): number of anchors to be predicted """ super(RetinaNetHead, self).__init__() # TODO: Implement the sigmoid version first. num_classes = cfg.MODEL.RETINANET.NUM_CLASSES - 1 in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS if cfg.MODEL.RPN.USE_FPN: num_anchors = len(cfg.MODEL.RPN.ASPECT_RATIOS) * cfg.MODEL.RPN.SCALES_PER_OCTAVE else: num_anchors = len(cfg.MODEL.RPN.ASPECT_RATIOS) * len(cfg.MODEL.RPN.ANCHOR_SIZES) cls_tower = [] bbox_tower = [] for i in range(cfg.MODEL.RETINANET.NUM_CONVS): cls_tower.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)) cls_tower.append(nn.ReLU()) bbox_tower.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)) bbox_tower.append(nn.ReLU()) self.add_module("cls_tower", nn.Sequential(*cls_tower)) self.add_module("bbox_tower", nn.Sequential(*bbox_tower)) self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1) self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1) # Initialization for modules in [self.cls_tower, self.bbox_tower, self.cls_logits, self.bbox_pred]: for l in modules.modules(): if isinstance(l, nn.Conv2d): torch.nn.init.normal_(l.weight, std=0.01) torch.nn.init.constant_(l.bias, 0) # retinanet_bias_init prior_prob = cfg.MODEL.RETINANET.PRIOR_PROB bias_value = -math.log((1 - prior_prob) / prior_prob) torch.nn.init.constant_(self.cls_logits.bias, bias_value) def forward(self, x): logits = [] bbox_reg = [] for feature in x: logits.append(self.cls_logits(self.cls_tower(feature))) bbox_reg.append(self.bbox_pred(self.bbox_tower(feature))) return logits, bbox_reg class RetinaNetModule(torch.nn.Module): """ Module for RetinaNet computation. Takes feature maps from the backbone and RetinaNet outputs and losses. Only Test on FPN now. """ def __init__(self, cfg): super(RetinaNetModule, self).__init__() self.cfg = cfg.clone() anchor_generator = make_anchor_generator_complex(cfg) head = RetinaNetHead(cfg) box_coder = BoxCoder(weights=(10.0, 10.0, 5.0, 5.0)) box_selector_test = make_retina_postprocessor(cfg, box_coder, is_train=False) loss_evaluator = make_focal_loss_evaluator(cfg, box_coder) self.anchor_generator = anchor_generator self.head = head self.box_selector_test = box_selector_test self.loss_evaluator = loss_evaluator def forward(self, images, features, targets=None): """ Arguments: images (ImageList): images for which we want to compute the predictions features (list[Tensor]): features computed from the images that are used for computing the predictions. Each tensor in the list correspond to different feature levels targets (list[BoxList): ground-truth boxes present in the image (optional) Returns: boxes (list[BoxList]): the predicted boxes from the RPN, one BoxList per image. losses (dict[Tensor]): the losses for the model during training. During testing, it is an empty dict. """ box_cls, box_regression = self.head(features) anchors = self.anchor_generator(images, features) if self.training: return self._forward_train(anchors, box_cls, box_regression, targets) else: return self._forward_test(anchors, box_cls, box_regression) def _forward_train(self, anchors, box_cls, box_regression, targets): loss_box_cls, loss_box_reg = self.loss_evaluator(anchors, box_cls, box_regression, targets) losses = { "loss_retina_cls": loss_box_cls, "loss_retina_reg": loss_box_reg, } return anchors, losses def _forward_test(self, anchors, box_cls, box_regression): boxes = self.box_selector_test(anchors, box_cls, box_regression) return boxes, {}