"""ATSS box coder, prediction head and top-level module."""
import math
import torch
import torch.nn.functional as F
from torch import nn

from .inference import make_atss_postprocessor
from .loss import make_atss_loss_evaluator

from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
from maskrcnn_benchmark.layers import Scale, DFConv2d, DYReLU, SELayer
from .anchor_generator import make_anchor_generator_complex


# Scaling applied to the (dx, dy, dw, dh) regression deltas.
_DELTA_WEIGHTS = (10.0, 10.0, 5.0, 5.0)

# Upper bound on the scaled log-size deltas before they go through torch.exp().
_MAX_LOG_SIZE_DELTA = math.log(1000.0 / 16)


def _widths_heights_centers(boxes, to_remove):
    """Return (widths, heights, ctr_x, ctr_y) for corner-format ``boxes``.

    ``boxes`` is indexed as ``boxes[:, 0..3]`` = (x1, y1, x2, y2).
    ``to_remove`` is added to both widths and heights (presumably the
    inclusive-pixel box convention -- confirm against the rest of the project).
    """
    x1, y1 = boxes[:, 0], boxes[:, 1]
    x2, y2 = boxes[:, 2], boxes[:, 3]
    widths = x2 - x1 + to_remove
    heights = y2 - y1 + to_remove
    ctr_x = (x2 + x1) / 2
    ctr_y = (y2 + y1) / 2
    return widths, heights, ctr_x, ctr_y


class BoxCoder(object):
    """Translate between corner-format boxes and anchor-relative deltas.

    Deltas are ordered (dx, dy, dw, dh) and scaled by the fixed weights
    (10, 10, 5, 5).
    """

    def __init__(self, cfg):
        self.cfg = cfg

    def encode(self, gt_boxes, anchors):
        """Compute regression targets that map ``anchors`` onto ``gt_boxes``.

        Both arguments are indexed row-wise as (x1, y1, x2, y2); the result
        stacks the four deltas along dim 1.
        """
        TO_REMOVE = 1  # TODO remove
        ex_widths, ex_heights, ex_ctr_x, ex_ctr_y = _widths_heights_centers(anchors, TO_REMOVE)
        gt_widths, gt_heights, gt_ctr_x, gt_ctr_y = _widths_heights_centers(gt_boxes, TO_REMOVE)

        wx, wy, ww, wh = _DELTA_WEIGHTS
        deltas = (
            wx * (gt_ctr_x - ex_ctr_x) / ex_widths,
            wy * (gt_ctr_y - ex_ctr_y) / ex_heights,
            ww * torch.log(gt_widths / ex_widths),
            wh * torch.log(gt_heights / ex_heights),
        )
        return torch.stack(deltas, dim=1)

    def decode(self, preds, anchors):
        """Apply predicted deltas to ``anchors`` and return corner-format boxes.

        ``preds`` is read in column groups of four (dx, dy, dw, dh) via
        ``[:, k::4]``; the output has the same shape and dtype as ``preds``.
        """
        anchors = anchors.to(preds.dtype)

        TO_REMOVE = 1  # TODO remove
        widths, heights, ctr_x, ctr_y = _widths_heights_centers(anchors, TO_REMOVE)

        wx, wy, ww, wh = _DELTA_WEIGHTS
        dx = preds[:, 0::4] / wx
        dy = preds[:, 1::4] / wy
        dw = preds[:, 2::4] / ww
        dh = preds[:, 3::4] / wh

        # Prevent sending too large values into torch.exp()
        dw = torch.clamp(dw, max=_MAX_LOG_SIZE_DELTA)
        dh = torch.clamp(dh, max=_MAX_LOG_SIZE_DELTA)

        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
        pred_w = torch.exp(dw) * widths[:, None]
        pred_h = torch.exp(dh) * heights[:, None]

        half_w = 0.5 * (pred_w - 1)
        half_h = 0.5 * (pred_h - 1)

        pred_boxes = torch.zeros_like(preds)
        pred_boxes[:, 0::4] = pred_ctr_x - half_w
        pred_boxes[:, 1::4] = pred_ctr_y - half_h
        pred_boxes[:, 2::4] = pred_ctr_x + half_w
        pred_boxes[:, 3::4] = pred_ctr_y + half_h
        return pred_boxes


class ATSSHead(torch.nn.Module):
    """Prediction head producing class logits, box deltas and centerness.

    The same classification and box towers are applied to every feature map
    passed to ``forward``.
    """

    def __init__(self, cfg):
        super(ATSSHead, self).__init__()
        self.cfg = cfg

        atss_cfg = cfg.MODEL.ATSS
        num_classes = atss_cfg.NUM_CLASSES - 1
        num_anchors = len(cfg.MODEL.RPN.ASPECT_RATIOS) * cfg.MODEL.RPN.SCALES_PER_OCTAVE
        in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
        channels = atss_cfg.CHANNELS
        use_gn = atss_cfg.USE_GN
        use_bn = atss_cfg.USE_BN
        use_dcn_in_tower = atss_cfg.USE_DFCONV
        use_dyrelu = atss_cfg.USE_DYRELU
        use_se = atss_cfg.USE_SE
        num_convs = atss_cfg.NUM_CONVS

        def build_stage(conv_func, stage_in):
            """Return the layers of one tower stage: conv, optional norms/SE, activation."""
            layers = [
                conv_func(stage_in, channels, kernel_size=3, stride=1, padding=1, bias=True)
            ]
            if use_gn:
                layers.append(nn.GroupNorm(32, channels))
            if use_bn:
                layers.append(nn.BatchNorm2d(channels))
            if use_se:
                layers.append(SELayer(channels))
            layers.append(DYReLU(channels, channels) if use_dyrelu else nn.ReLU())
            return layers

        cls_layers = []
        bbox_layers = []
        for idx in range(num_convs):
            # Only the last conv of each tower may be a deformable conv.
            is_last = idx == num_convs - 1
            conv_func = DFConv2d if (use_dcn_in_tower and is_last) else nn.Conv2d
            stage_in = in_channels if idx == 0 else channels
            # Build the cls stage before the bbox stage so modules are
            # constructed in the same order as before.
            cls_layers.extend(build_stage(conv_func, stage_in))
            bbox_layers.extend(build_stage(conv_func, stage_in))

        self.cls_tower = nn.Sequential(*cls_layers)
        self.bbox_tower = nn.Sequential(*bbox_layers)

        self.cls_logits = nn.Conv2d(
            channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1
        )
        self.bbox_pred = nn.Conv2d(
            channels, num_anchors * 4, kernel_size=3, stride=1, padding=1
        )
        self.centerness = nn.Conv2d(
            channels, num_anchors * 1, kernel_size=3, stride=1, padding=1
        )

        # initialization
        to_initialize = (
            self.cls_tower,
            self.bbox_tower,
            self.cls_logits,
            self.bbox_pred,
            self.centerness,
        )
        for group in to_initialize:
            for layer in group.modules():
                if isinstance(layer, nn.Conv2d):
                    torch.nn.init.normal_(layer.weight, std=0.01)
                    torch.nn.init.constant_(layer.bias, 0)

        # initialize the bias for focal loss
        prior_prob = atss_cfg.PRIOR_PROB
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        torch.nn.init.constant_(self.cls_logits.bias, bias_value)

        # One Scale module per feature level; forward() indexes this list by
        # level, so at most five feature maps are supported.
        self.scales = nn.ModuleList([Scale(init_value=1.0) for _ in range(5)])

    def forward(self, x):
        """Run the head on each feature map in ``x``.

        Returns three lists (class logits, box regression, centerness), each
        with one entry per feature map, in input order.
        """
        logits = []
        bbox_reg = []
        centerness = []
        for level, feature in enumerate(x):
            cls_feat = self.cls_tower(feature)
            box_feat = self.bbox_tower(feature)

            logits.append(self.cls_logits(cls_feat))
            bbox_reg.append(self.scales[level](self.bbox_pred(box_feat)))
            centerness.append(self.centerness(box_feat))
        return logits, bbox_reg, centerness


class ATSSModule(torch.nn.Module):
    """Tie together the ATSS head, anchor generator, loss and post-processing.

    ``forward`` returns a ``(boxes, losses)`` pair; ``losses`` is an empty
    dict outside of training.
    """

    def __init__(self, cfg):
        super(ATSSModule, self).__init__()
        self.cfg = cfg
        self.head = ATSSHead(cfg)
        box_coder = BoxCoder(cfg)
        self.loss_evaluator = make_atss_loss_evaluator(cfg, box_coder)
        self.box_selector_train = make_atss_postprocessor(cfg, box_coder, is_train=True)
        self.box_selector_test = make_atss_postprocessor(cfg, box_coder, is_train=False)
        self.anchor_generator = make_anchor_generator_complex(cfg)

    def forward(self, images, features, targets=None):
        """Run the head and anchor generator, then dispatch on ``self.training``."""
        box_cls, box_regression, centerness = self.head(features)
        anchors = self.anchor_generator(images, features)

        if not self.training:
            return self._forward_test(box_cls, box_regression, centerness, anchors)
        return self._forward_train(box_cls, box_regression, centerness, targets, anchors)

    def _forward_train(self, box_cls, box_regression, centerness, targets, anchors):
        """Compute the training losses and, unless RPN_ONLY is set, training boxes.

        Returns ``(None, losses)`` when ``cfg.MODEL.RPN_ONLY`` is truthy,
        otherwise ``(train_boxes, losses)`` where each entry of
        ``train_boxes`` is the selected boxes concatenated with that image's
        anchors.
        """
        loss_box_cls, loss_box_reg, loss_centerness = self.loss_evaluator(
            box_cls, box_regression, centerness, targets, anchors
        )
        losses = {
            "loss_cls": loss_box_cls,
            "loss_reg": loss_box_reg,
            "loss_centerness": loss_centerness,
        }
        if self.cfg.MODEL.RPN_ONLY:
            return None, losses

        boxes = self.box_selector_train(box_cls, box_regression, centerness, anchors)
        train_boxes = []
        for selected, image_anchors in zip(boxes, anchors):
            merged_anchors = cat_boxlist(image_anchors)
            visible = torch.ones(
                selected.bbox.shape[0], dtype=torch.bool, device=selected.bbox.device
            )
            selected.add_field("visibility", visible)
            # Drop the per-box fields before concatenating with the anchors.
            del selected.extra_fields["scores"]
            del selected.extra_fields["labels"]
            train_boxes.append(cat_boxlist([selected, merged_anchors]))
        return train_boxes, losses

    def _forward_test(self, box_cls, box_regression, centerness, anchors):
        """Return the post-processed test-time boxes and an empty loss dict."""
        boxes = self.box_selector_test(box_cls, box_regression, centerness, anchors)
        return boxes, {}