zdou0830's picture
desco
749745d
raw
history blame
No virus
7.81 kB
import math
import torch
import torch.nn.functional as F
from torch import nn
from maskrcnn_benchmark.modeling import registry
from maskrcnn_benchmark.layers import Scale, DFConv2d
from .loss import make_fcos_loss_evaluator
from .anchor_generator import make_center_anchor_generator
from .inference import make_fcos_postprocessor
@registry.RPN_HEADS.register("FCOSHead")
class FCOSHead(torch.nn.Module):
def __init__(self, cfg):
super(FCOSHead, self).__init__()
# TODO: Implement the sigmoid version first.
num_classes = cfg.MODEL.FCOS.NUM_CLASSES - 1
in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
use_gn = cfg.MODEL.FCOS.USE_GN
use_bn = cfg.MODEL.FCOS.USE_BN
use_dcn_in_tower = cfg.MODEL.FCOS.USE_DFCONV
self.fpn_strides = cfg.MODEL.FCOS.FPN_STRIDES
self.norm_reg_targets = cfg.MODEL.FCOS.NORM_REG_TARGETS
self.centerness_on_reg = cfg.MODEL.FCOS.CENTERNESS_ON_REG
cls_tower = []
bbox_tower = []
for i in range(cfg.MODEL.FCOS.NUM_CONVS):
if use_dcn_in_tower and i == cfg.MODEL.FCOS.NUM_CONVS - 1:
conv_func = DFConv2d
else:
conv_func = nn.Conv2d
cls_tower.append(conv_func(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=True))
if use_gn:
cls_tower.append(nn.GroupNorm(32, in_channels))
if use_bn:
cls_tower.append(nn.BatchNorm2d(in_channels))
cls_tower.append(nn.ReLU())
bbox_tower.append(conv_func(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=True))
if use_gn:
bbox_tower.append(nn.GroupNorm(32, in_channels))
if use_bn:
bbox_tower.append(nn.BatchNorm2d(in_channels))
bbox_tower.append(nn.ReLU())
self.add_module("cls_tower", nn.Sequential(*cls_tower))
self.add_module("bbox_tower", nn.Sequential(*bbox_tower))
self.cls_logits = nn.Conv2d(in_channels, num_classes, kernel_size=3, stride=1, padding=1)
self.bbox_pred = nn.Conv2d(in_channels, 4, kernel_size=3, stride=1, padding=1)
self.centerness = nn.Conv2d(in_channels, 1, kernel_size=3, stride=1, padding=1)
# initialization
for modules in [self.cls_tower, self.bbox_tower, self.cls_logits, self.bbox_pred, self.centerness]:
for l in modules.modules():
if isinstance(l, nn.Conv2d):
torch.nn.init.normal_(l.weight, std=0.01)
torch.nn.init.constant_(l.bias, 0)
# initialize the bias for focal loss
prior_prob = cfg.MODEL.FCOS.PRIOR_PROB
bias_value = -math.log((1 - prior_prob) / prior_prob)
torch.nn.init.constant_(self.cls_logits.bias, bias_value)
self.scales = nn.ModuleList([Scale(init_value=1.0) for _ in range(5)])
def forward(self, x):
logits = []
bbox_reg = []
centerness = []
for l, feature in enumerate(x):
cls_tower = self.cls_tower(feature)
box_tower = self.bbox_tower(feature)
logits.append(self.cls_logits(cls_tower))
if self.centerness_on_reg:
centerness.append(self.centerness(box_tower))
else:
centerness.append(self.centerness(cls_tower))
bbox_pred = self.scales[l](self.bbox_pred(box_tower))
if self.norm_reg_targets:
bbox_pred = F.relu(bbox_pred)
if self.training:
bbox_reg.append(bbox_pred)
else:
bbox_reg.append(bbox_pred * self.fpn_strides[l])
else:
bbox_reg.append(torch.exp(bbox_pred))
return logits, bbox_reg, centerness
class FCOSModule(torch.nn.Module):
"""
Module for FCOS computation. Takes feature maps from the backbone and
FCOS outputs and losses. Only Test on FPN now.
"""
def __init__(self, cfg):
super(FCOSModule, self).__init__()
head = FCOSHead(cfg)
box_selector_train = make_fcos_postprocessor(cfg, is_train=True)
box_selector_test = make_fcos_postprocessor(cfg, is_train=False)
loss_evaluator = make_fcos_loss_evaluator(cfg)
self.cfg = cfg
self.head = head
self.box_selector_train = box_selector_train
self.box_selector_test = box_selector_test
self.loss_evaluator = loss_evaluator
self.fpn_strides = cfg.MODEL.FCOS.FPN_STRIDES
if not cfg.MODEL.RPN_ONLY:
self.anchor_generator = make_center_anchor_generator(cfg)
def forward(self, images, features, targets=None):
"""
Arguments:
images (ImageList): images for which we want to compute the predictions
features (list[Tensor]): features computed from the images that are
used for computing the predictions. Each tensor in the list
correspond to different feature levels
targets (list[BoxList): ground-truth boxes present in the image (optional)
Returns:
boxes (list[BoxList]): the predicted boxes from the RPN, one BoxList per
image.
losses (dict[Tensor]): the losses for the model during training. During
testing, it is an empty dict.
"""
box_cls, box_regression, centerness = self.head(features)
locations = self.compute_locations(features)
if self.training and targets is not None:
return self._forward_train(locations, box_cls, box_regression, centerness, targets, images.image_sizes)
else:
return self._forward_test(locations, box_cls, box_regression, centerness, images.image_sizes)
def _forward_train(self, locations, box_cls, box_regression, centerness, targets, image_sizes=None):
loss_box_cls, loss_box_reg, loss_centerness = self.loss_evaluator(
locations, box_cls, box_regression, centerness, targets
)
losses = {"loss_cls": loss_box_cls, "loss_reg": loss_box_reg, "loss_centerness": loss_centerness}
if self.cfg.MODEL.RPN_ONLY:
return None, losses
else:
boxes = self.box_selector_train(locations, box_cls, box_regression, centerness, image_sizes)
proposals = self.anchor_generator(boxes, image_sizes, centerness)
return proposals, losses
def _forward_test(self, locations, box_cls, box_regression, centerness, image_sizes):
boxes = self.box_selector_test(locations, box_cls, box_regression, centerness, image_sizes)
if not self.cfg.MODEL.RPN_ONLY:
boxes = self.anchor_generator(boxes, image_sizes, centerness)
return boxes, {}
def compute_locations(self, features):
locations = []
for level, feature in enumerate(features):
h, w = feature.size()[-2:]
locations_per_level = self.compute_locations_per_level(h, w, self.fpn_strides[level], feature.device)
locations.append(locations_per_level)
return locations
def compute_locations_per_level(self, h, w, stride, device):
shifts_x = torch.arange(0, w * stride, step=stride, dtype=torch.float32, device=device)
shifts_y = torch.arange(0, h * stride, step=stride, dtype=torch.float32, device=device)
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2
return locations