zdou0830's picture
clean
971be17
raw
history blame
32.1 kB
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import logging
import torch
from maskrcnn_benchmark.modeling.box_coder import BoxCoder
from maskrcnn_benchmark.structures.bounding_box import BoxList, _onnx_clip_boxes_to_image
from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_nms
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_ml_nms
from maskrcnn_benchmark.structures.boxlist_ops import remove_small_boxes
from ..utils import permute_and_flatten
import pdb
class RPNPostProcessor(torch.nn.Module):
    """
    Converts the raw RPN head outputs into per-image proposal lists.

    For every feature level the top-scoring anchors are decoded into boxes,
    clipped to the image, filtered by size and passed through NMS; the
    per-level results are then merged per image before being handed to the
    heads.
    """

    def __init__(
        self, pre_nms_top_n, post_nms_top_n, nms_thresh, min_size, box_coder=None, fpn_post_nms_top_n=None, onnx=False
    ):
        """
        Arguments:
            pre_nms_top_n (int): cap on the number of top-scoring anchors kept
                per image and feature level before NMS
            post_nms_top_n (int): forwarded to ``boxlist_nms`` as ``max_proposals``
            nms_thresh (float): threshold forwarded to ``boxlist_nms``
            min_size (int): forwarded to ``remove_small_boxes``
            box_coder (BoxCoder): decoder for the regression output; a coder
                with unit weights is created when omitted
            fpn_post_nms_top_n (int): cap used when selecting over all levels;
                falls back to ``post_nms_top_n`` when omitted
            onnx (bool): clip with ``_onnx_clip_boxes_to_image`` instead of
                ``BoxList.clip_to_image``
        """
        super(RPNPostProcessor, self).__init__()
        self.pre_nms_top_n = pre_nms_top_n
        self.post_nms_top_n = post_nms_top_n
        self.nms_thresh = nms_thresh
        self.min_size = min_size
        self.onnx = onnx

        self.box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) if box_coder is None else box_coder
        self.fpn_post_nms_top_n = post_nms_top_n if fpn_post_nms_top_n is None else fpn_post_nms_top_n

    def add_gt_proposals(self, proposals, targets):
        """
        Append the ground-truth boxes of each image to its proposals.

        Arguments:
            proposals: list[BoxList]
            targets: list[BoxList]

        Returns:
            list[BoxList]: one concatenated BoxList per (proposal, target) pair
        """
        # All dummy scores are created on the device the proposals live on.
        device = proposals[0].bbox.device

        gt_boxes = []
        for target in targets:
            gt_box = target.copy_with_fields([])
            # cat_boxlist needs the same fields on both sides, so the
            # ground-truth boxes get a dummy objectness of 1.
            gt_box.add_field("objectness", torch.ones(len(gt_box), device=device))
            gt_boxes.append(gt_box)

        return [cat_boxlist(pair) for pair in zip(proposals, gt_boxes)]

    def forward_for_single_feature_map(self, anchors, objectness, box_regression):
        """
        Arguments:
            anchors: list[BoxList]
            objectness: tensor of size N, A, H, W
            box_regression: tensor of size N, A * 4, H, W

        Returns:
            list[BoxList]: one BoxList of proposals per image, carrying an
            "objectness" field
        """
        device = objectness.device
        N, A, H, W = objectness.shape

        # Bring scores and deltas into the (N, H * W * A[, 4]) layout of the anchors.
        scores = objectness.permute(0, 2, 3, 1).reshape(N, -1).sigmoid()
        deltas = box_regression.view(N, -1, 4, H, W).permute(0, 3, 4, 1, 2).reshape(N, -1, 4)

        # Keep only the best-scoring candidates of every image.
        keep_n = min(self.pre_nms_top_n, A * H * W)
        scores, topk_idx = scores.topk(keep_n, dim=1, sorted=True)

        batch_idx = torch.arange(N, device=device)[:, None]
        deltas = deltas[batch_idx, topk_idx]

        image_shapes = [box.size for box in anchors]
        anchor_boxes = torch.cat([a.bbox for a in anchors], dim=0)
        anchor_boxes = anchor_boxes.reshape(N, -1, 4)[batch_idx, topk_idx]

        decoded = self.box_coder.decode(deltas.view(-1, 4), anchor_boxes.view(-1, 4)).view(N, -1, 4)

        result = []
        for proposal, score, im_shape in zip(decoded, scores, image_shapes):
            # The ONNX path clips the raw tensor up front; the regular path
            # clips through the BoxList API.
            if self.onnx:
                proposal = _onnx_clip_boxes_to_image(proposal, im_shape)
            boxlist = BoxList(proposal, im_shape, mode="xyxy")
            if not self.onnx:
                boxlist = boxlist.clip_to_image(remove_empty=False)

            boxlist.add_field("objectness", score)
            boxlist = remove_small_boxes(boxlist, self.min_size)
            boxlist = boxlist_nms(
                boxlist,
                self.nms_thresh,
                max_proposals=self.post_nms_top_n,
                score_field="objectness",
            )
            result.append(boxlist)
        return result

    def forward(self, anchors, objectness, box_regression, targets=None):
        """
        Arguments:
            anchors: list[list[BoxList]]
            objectness: list[tensor]
            box_regression: list[tensor]
            targets: list[BoxList] or None; only used in training mode

        Returns:
            boxlists (list[BoxList]): the post-processed anchors, after
                applying box decoding and NMS
        """
        num_levels = len(objectness)
        # Regroup the anchors so that each entry holds one level across images.
        per_level_anchors = list(zip(*anchors))
        sampled_boxes = [
            self.forward_for_single_feature_map(a, o, b)
            for a, o, b in zip(per_level_anchors, objectness, box_regression)
        ]

        # Regroup back to per-image and merge the levels of every image.
        boxlists = [cat_boxlist(per_image) for per_image in zip(*sampled_boxes)]

        if num_levels > 1:
            boxlists = self.select_over_all_levels(boxlists)

        # append ground-truth bboxes to proposals
        if self.training and targets is not None:
            boxlists = self.add_gt_proposals(boxlists, targets)

        return boxlists

    def select_over_all_levels(self, boxlists):
        """
        Keep the highest-objectness proposals after the levels were merged.

        The given list is updated in place and also returned.
        """
        # different behavior during training and during testing:
        # during training, post_nms_top_n is over *all* the proposals combined, while
        # during testing, it is over the proposals for each image
        # TODO resolve this difference and make it consistent. It should be per image,
        # and not per batch
        if self.training:
            objectness = torch.cat([boxlist.get_field("objectness") for boxlist in boxlists], dim=0)
            box_sizes = [len(boxlist) for boxlist in boxlists]
            post_nms_top_n = min(self.fpn_post_nms_top_n, len(objectness))
            _, inds_sorted = torch.topk(objectness, post_nms_top_n, dim=0, sorted=True)
            inds_mask = torch.zeros_like(objectness, dtype=torch.bool)
            inds_mask[inds_sorted] = True
            # Cut the batch-wide mask back into one chunk per image.
            for i, per_image_mask in enumerate(inds_mask.split(box_sizes)):
                boxlists[i] = boxlists[i][per_image_mask]
        else:
            for i, boxlist in enumerate(boxlists):
                objectness = boxlist.get_field("objectness")
                post_nms_top_n = min(self.fpn_post_nms_top_n, len(objectness))
                _, inds_sorted = torch.topk(objectness, post_nms_top_n, dim=0, sorted=True)
                boxlists[i] = boxlist[inds_sorted]
        return boxlists
def make_rpn_postprocessor(config, rpn_box_coder, is_train):
    """
    Build an RPNPostProcessor from the ``MODEL.RPN`` section of ``config``.

    The ``*_TRAIN`` limits are read first and replaced by their ``*_TEST``
    counterparts when ``is_train`` is false.
    """
    rpn_cfg = config.MODEL.RPN

    fpn_post_nms_top_n = rpn_cfg.FPN_POST_NMS_TOP_N_TRAIN
    pre_nms_top_n = rpn_cfg.PRE_NMS_TOP_N_TRAIN
    post_nms_top_n = rpn_cfg.POST_NMS_TOP_N_TRAIN
    if not is_train:
        fpn_post_nms_top_n = rpn_cfg.FPN_POST_NMS_TOP_N_TEST
        pre_nms_top_n = rpn_cfg.PRE_NMS_TOP_N_TEST
        post_nms_top_n = rpn_cfg.POST_NMS_TOP_N_TEST

    return RPNPostProcessor(
        pre_nms_top_n=pre_nms_top_n,
        post_nms_top_n=post_nms_top_n,
        nms_thresh=rpn_cfg.NMS_THRESH,
        min_size=rpn_cfg.MIN_SIZE,
        box_coder=rpn_box_coder,
        fpn_post_nms_top_n=fpn_post_nms_top_n,
        onnx=config.MODEL.ONNX,
    )
class RetinaPostProcessor(torch.nn.Module):
    """
    Performs post-processing on the outputs of the RetinaNet boxes.
    This is only used in the testing.
    """

    def __init__(
        self,
        pre_nms_thresh,
        pre_nms_top_n,
        nms_thresh,
        fpn_post_nms_top_n,
        min_size,
        num_classes,
        box_coder=None,
    ):
        """
        Arguments:
            pre_nms_thresh (float): sigmoid scores must exceed this value to
                become candidates
            pre_nms_top_n (int): cap on the number of candidates kept per
                image and feature level
            nms_thresh (float): threshold forwarded to ``boxlist_nms``
            fpn_post_nms_top_n (int): cap on detections per image after NMS;
                no cap is applied when it is not positive
            min_size (int): forwarded to ``remove_small_boxes``
            num_classes (int): number of classes including the background
                (label 0 is skipped during NMS)
            box_coder (BoxCoder): decoder for the regression output; a coder
                with weights (10, 10, 5, 5) is created when omitted
        """
        super(RetinaPostProcessor, self).__init__()
        self.pre_nms_thresh = pre_nms_thresh
        self.pre_nms_top_n = pre_nms_top_n
        self.nms_thresh = nms_thresh
        self.fpn_post_nms_top_n = fpn_post_nms_top_n
        self.min_size = min_size
        self.num_classes = num_classes

        if box_coder is None:
            box_coder = BoxCoder(weights=(10.0, 10.0, 5.0, 5.0))
        self.box_coder = box_coder

    def forward_for_single_feature_map(self, anchors, box_cls, box_regression):
        """
        Arguments:
            anchors: list[BoxList]
            box_cls: tensor of size N, A * C, H, W
            box_regression: tensor of size N, A * 4, H, W

        Returns:
            list[BoxList]: one BoxList per image with "labels" and "scores"
        """
        N, _, H, W = box_cls.shape
        A = box_regression.size(1) // 4
        C = box_cls.size(1) // A

        # put in the same format as anchors
        box_cls = permute_and_flatten(box_cls, N, A, C, H, W)
        box_cls = box_cls.sigmoid()

        box_regression = permute_and_flatten(box_regression, N, A, 4, H, W)
        box_regression = box_regression.reshape(N, -1, 4)

        candidate_inds = box_cls > self.pre_nms_thresh

        # Per-image candidate count, capped at the configured maximum.
        pre_nms_top_n = candidate_inds.view(N, -1).sum(1)
        pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n)

        results = []
        for per_box_cls, per_box_regression, per_pre_nms_top_n, per_candidate_inds, per_anchors in zip(
            box_cls, box_regression, pre_nms_top_n, candidate_inds, anchors
        ):
            # Sort and select TopN
            # TODO most of this can be made out of the loop for
            # all images.
            # TODO:Yang: Not easy to do. Because the numbers of detections are
            # different in each image. Therefore, this part needs to be done
            # per image.
            per_box_cls = per_box_cls[per_candidate_inds]

            per_box_cls, top_k_indices = per_box_cls.topk(per_pre_nms_top_n, sorted=False)

            # Rows are (box index, class index) pairs of the surviving candidates.
            per_candidate_nonzeros = per_candidate_inds.nonzero()[top_k_indices, :]

            per_box_loc = per_candidate_nonzeros[:, 0]
            # Labels are shifted by one so that 0 stays reserved for background.
            per_class = per_candidate_nonzeros[:, 1] + 1

            detections = self.box_coder.decode(
                per_box_regression[per_box_loc, :].view(-1, 4), per_anchors.bbox[per_box_loc, :].view(-1, 4)
            )

            boxlist = BoxList(detections, per_anchors.size, mode="xyxy")
            boxlist.add_field("labels", per_class)
            boxlist.add_field("scores", per_box_cls)
            boxlist = boxlist.clip_to_image(remove_empty=False)
            boxlist = remove_small_boxes(boxlist, self.min_size)
            results.append(boxlist)

        return results

    # TODO very similar to filter_results from PostProcessor
    # but filter_results is per image
    # TODO Yang: solve this issue in the future. No good solution
    # right now.
    def select_over_all_levels(self, boxlists):
        """
        Run per-class NMS on every image and cap the number of detections.

        Arguments:
            boxlists: list[BoxList], one per image, with "scores" and "labels"

        Returns:
            list[BoxList]: the filtered detections, one BoxList per image
        """
        num_images = len(boxlists)
        results = []
        for i in range(num_images):
            scores = boxlists[i].get_field("scores")
            labels = boxlists[i].get_field("labels")
            boxes = boxlists[i].bbox
            boxlist = boxlists[i]
            result = []
            # skip the background
            for j in range(1, self.num_classes):
                inds = (labels == j).nonzero().view(-1)

                scores_j = scores[inds]
                boxes_j = boxes[inds, :].view(-1, 4)
                boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
                boxlist_for_class.add_field("scores", scores_j)
                boxlist_for_class = boxlist_nms(boxlist_for_class, self.nms_thresh, score_field="scores")
                num_labels = len(boxlist_for_class)
                boxlist_for_class.add_field(
                    "labels", torch.full((num_labels,), j, dtype=torch.int64, device=scores.device)
                )
                result.append(boxlist_for_class)

            result = cat_boxlist(result)
            number_of_detections = len(result)

            # Limit to max_per_image detections **over all classes**
            if number_of_detections > self.fpn_post_nms_top_n > 0:
                cls_scores = result.get_field("scores")
                # kthvalue runs on a float32 CPU copy (same as the ATSS
                # post-processor) so reduced-precision scores are handled too.
                image_thresh, _ = torch.kthvalue(
                    cls_scores.cpu().float(), number_of_detections - self.fpn_post_nms_top_n + 1
                )
                keep = cls_scores >= image_thresh.item()
                keep = torch.nonzero(keep).squeeze(1)
                result = result[keep]
            results.append(result)
        return results

    def forward(self, anchors, objectness, box_regression, targets=None):
        """
        Arguments:
            anchors: list[list[BoxList]]
            objectness: list[tensor]; each entry is passed on as ``box_cls``
                to ``forward_for_single_feature_map``
            box_regression: list[tensor]
            targets: unused

        Returns:
            boxlists (list[BoxList]): the post-processed anchors, after
                applying box decoding and NMS
        """
        sampled_boxes = []
        # Regroup the anchors so that each entry holds one level across images.
        anchors = list(zip(*anchors))
        for a, o, b in zip(anchors, objectness, box_regression):
            sampled_boxes.append(self.forward_for_single_feature_map(a, o, b))

        # Regroup back to per-image and merge the levels of every image.
        boxlists = list(zip(*sampled_boxes))
        boxlists = [cat_boxlist(boxlist) for boxlist in boxlists]

        boxlists = self.select_over_all_levels(boxlists)

        return boxlists
def make_retina_postprocessor(config, rpn_box_coder, is_train):
    """
    Build a RetinaPostProcessor from the ``MODEL.RETINANET`` section of ``config``.

    ``is_train`` is accepted for signature parity with the other factories
    but is not consulted; ``min_size`` is fixed to 0.
    """
    retina_cfg = config.MODEL.RETINANET
    return RetinaPostProcessor(
        pre_nms_thresh=retina_cfg.INFERENCE_TH,
        pre_nms_top_n=retina_cfg.PRE_NMS_TOP_N,
        nms_thresh=retina_cfg.NMS_TH,
        fpn_post_nms_top_n=retina_cfg.DETECTIONS_PER_IMG,
        min_size=0,
        num_classes=retina_cfg.NUM_CLASSES,
        box_coder=rpn_box_coder,
    )
class FCOSPostProcessor(torch.nn.Module):
    """
    Performs post-processing on the outputs of the FCOS head: per-location
    class scores, box distances and centerness are turned into per-image
    detections.
    """

    def __init__(
        self,
        pre_nms_thresh,
        pre_nms_top_n,
        nms_thresh,
        fpn_post_nms_top_n,
        min_size,
        num_classes,
        bbox_aug_enabled=False,
    ):
        """
        Arguments:
            pre_nms_thresh (float): sigmoid class scores must exceed this
                value to become candidates
            pre_nms_top_n (int): cap on the number of candidates kept per
                image and feature level
            nms_thresh (float): threshold forwarded to ``boxlist_ml_nms``
            fpn_post_nms_top_n (int): cap on detections per image after NMS;
                no cap is applied when it is not positive
            min_size (int): forwarded to ``remove_small_boxes``
            num_classes (int): stored on the instance; not read by the
                methods of this class
            bbox_aug_enabled (bool): when true, ``forward`` skips the
                NMS / top-k selection over all levels
        """
        super(FCOSPostProcessor, self).__init__()
        self.pre_nms_thresh = pre_nms_thresh
        self.pre_nms_top_n = pre_nms_top_n
        self.nms_thresh = nms_thresh
        self.fpn_post_nms_top_n = fpn_post_nms_top_n
        self.min_size = min_size
        self.num_classes = num_classes
        self.bbox_aug_enabled = bbox_aug_enabled

    def forward_for_single_feature_map(self, locations, box_cls, box_regression, centerness, image_sizes):
        """
        Arguments:
            locations: tensor indexed by flattened H * W position; columns 0
                and 1 are used as the x and y of each location
            box_cls: tensor of size N, C, H, W
            box_regression: tensor of size N, 4, H, W
            centerness: tensor of size N, 1, H, W
            image_sizes: list[(h, w)], one entry per image

        Returns:
            list[BoxList]: one BoxList per image with "centers", "labels"
            and "scores"
        """
        N, C, H, W = box_cls.shape

        # put in the same format as locations
        box_cls = box_cls.view(N, C, H, W).permute(0, 2, 3, 1)
        box_cls = box_cls.reshape(N, -1, C).sigmoid()
        box_regression = box_regression.view(N, 4, H, W).permute(0, 2, 3, 1)
        box_regression = box_regression.reshape(N, -1, 4)
        centerness = centerness.view(N, 1, H, W).permute(0, 2, 3, 1)
        centerness = centerness.reshape(N, -1).sigmoid()

        # Candidates are thresholded on the class score alone, i.e. before
        # the centerness weighting below.
        candidate_inds = box_cls > self.pre_nms_thresh
        pre_nms_top_n = candidate_inds.reshape(N, -1).sum(1)
        pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n)

        # multiply the classification scores with centerness scores
        box_cls = box_cls * centerness[:, :, None]

        results = []
        for i in range(N):
            per_box_cls = box_cls[i]
            per_candidate_inds = candidate_inds[i]
            per_box_cls = per_box_cls[per_candidate_inds]

            # Rows are (location index, class index) pairs of the candidates.
            per_candidate_nonzeros = per_candidate_inds.nonzero()
            per_box_loc = per_candidate_nonzeros[:, 0]
            # Labels are shifted by one so that 0 stays reserved for background.
            per_class = per_candidate_nonzeros[:, 1] + 1

            per_box_regression = box_regression[i]
            per_box_regression = per_box_regression[per_box_loc]
            per_locations = locations[per_box_loc]

            per_pre_nms_top_n = pre_nms_top_n[i]

            # Only trim when there are more candidates than allowed.
            if per_candidate_inds.sum().item() > per_pre_nms_top_n.item():
                per_box_cls, top_k_indices = per_box_cls.topk(per_pre_nms_top_n, sorted=False)
                per_class = per_class[top_k_indices]
                per_box_regression = per_box_regression[top_k_indices]
                per_locations = per_locations[top_k_indices]

            # The four regression values are offsets from the location to the
            # left, top, right and bottom box edges.
            detections = torch.stack(
                [
                    per_locations[:, 0] - per_box_regression[:, 0],
                    per_locations[:, 1] - per_box_regression[:, 1],
                    per_locations[:, 0] + per_box_regression[:, 2],
                    per_locations[:, 1] + per_box_regression[:, 3],
                ],
                dim=1,
            )

            h, w = image_sizes[i]
            boxlist = BoxList(detections, (int(w), int(h)), mode="xyxy")
            boxlist.add_field("centers", per_locations)
            boxlist.add_field("labels", per_class)
            # Final score is the square root of (class score * centerness).
            boxlist.add_field("scores", torch.sqrt(per_box_cls))
            boxlist = boxlist.clip_to_image(remove_empty=False)
            boxlist = remove_small_boxes(boxlist, self.min_size)
            results.append(boxlist)

        return results

    def forward(self, locations, box_cls, box_regression, centerness, image_sizes):
        """
        Arguments:
            locations: list[tensor], one per feature level
            box_cls: list[tensor]
            box_regression: list[tensor]
            centerness: list[tensor]
            image_sizes: list[(h, w)]

        Returns:
            boxlists (list[BoxList]): the post-processed detections, one
                BoxList per image; NMS is applied unless ``bbox_aug_enabled``
        """
        sampled_boxes = []
        for l, o, b, c in zip(locations, box_cls, box_regression, centerness):
            sampled_boxes.append(self.forward_for_single_feature_map(l, o, b, c, image_sizes))

        # Regroup to per-image and merge the levels of every image.
        boxlists = list(zip(*sampled_boxes))
        boxlists = [cat_boxlist(boxlist) for boxlist in boxlists]
        if not self.bbox_aug_enabled:
            boxlists = self.select_over_all_levels(boxlists)

        return boxlists

    # TODO very similar to filter_results from PostProcessor
    # but filter_results is per image
    # TODO Yang: solve this issue in the future. No good solution
    # right now.
    def select_over_all_levels(self, boxlists):
        """
        Run multi-class NMS on every image and cap the number of detections.

        Arguments:
            boxlists: list[BoxList], one per image

        Returns:
            list[BoxList]: the filtered detections, one BoxList per image
        """
        num_images = len(boxlists)
        results = []
        for i in range(num_images):
            # multiclass nms
            result = boxlist_ml_nms(boxlists[i], self.nms_thresh)
            number_of_detections = len(result)

            # Limit to max_per_image detections **over all classes**
            if number_of_detections > self.fpn_post_nms_top_n > 0:
                cls_scores = result.get_field("scores")
                # kthvalue runs on a float32 CPU copy (same as the ATSS
                # post-processor) so reduced-precision scores are handled too.
                image_thresh, _ = torch.kthvalue(
                    cls_scores.cpu().float(), number_of_detections - self.fpn_post_nms_top_n + 1
                )
                keep = cls_scores >= image_thresh.item()
                keep = torch.nonzero(keep).squeeze(1)
                result = result[keep]
            results.append(result)
        return results
def make_fcos_postprocessor(config, is_train=False):
    """
    Build an FCOSPostProcessor from the ``MODEL.FCOS`` section of ``config``.

    The inference-time threshold and limits are read first and replaced by
    their training counterparts when ``is_train`` is true; ``min_size`` is
    fixed to 0.
    """
    fcos_cfg = config.MODEL.FCOS

    pre_nms_thresh = fcos_cfg.INFERENCE_TH
    pre_nms_top_n = fcos_cfg.PRE_NMS_TOP_N
    fpn_post_nms_top_n = fcos_cfg.DETECTIONS_PER_IMG
    if is_train:
        pre_nms_thresh = fcos_cfg.INFERENCE_TH_TRAIN
        pre_nms_top_n = fcos_cfg.PRE_NMS_TOP_N_TRAIN
        fpn_post_nms_top_n = fcos_cfg.POST_NMS_TOP_N_TRAIN

    return FCOSPostProcessor(
        pre_nms_thresh=pre_nms_thresh,
        pre_nms_top_n=pre_nms_top_n,
        nms_thresh=fcos_cfg.NMS_TH,
        fpn_post_nms_top_n=fpn_post_nms_top_n,
        min_size=0,
        num_classes=fcos_cfg.NUM_CLASSES,
    )
class ATSSPostProcessor(torch.nn.Module):
    """
    Post-processing for the ATSS head.

    Class scores can come from plain classification logits, from token
    logits or from dot-product logits; the latter two are mapped back to
    per-class scores through ``positive_map``. Scores are weighted by
    centerness, boxes are decoded against the anchors and, unless disabled
    by the augmentation flags, multi-class NMS is applied per image.
    """

    def __init__(
        self,
        pre_nms_thresh,
        pre_nms_top_n,
        nms_thresh,
        fpn_post_nms_top_n,
        min_size,
        num_classes,
        box_coder,
        bbox_aug_enabled=False,
        bbox_aug_vote=False,
        score_agg="MEAN",
        mdetr_style_aggregate_class_num=-1,
    ):
        """
        Arguments:
            pre_nms_thresh (float): class scores must exceed this value to
                become candidates
            pre_nms_top_n (int): cap on the number of candidates kept per
                image and feature level
            nms_thresh (float): threshold forwarded to ``boxlist_ml_nms``
            fpn_post_nms_top_n (int): cap on detections per image after NMS;
                no cap is applied when it is not positive
            min_size (int): forwarded to ``remove_small_boxes``
            num_classes (int): stored on the instance; not read by the
                methods of this class
            box_coder: object whose ``decode`` turns regression output and
                anchors into boxes
            bbox_aug_enabled (bool), bbox_aug_vote (bool): the selection over
                all levels is skipped only when augmentation is enabled
                without voting
            score_agg (str): aggregation mode forwarded to the
                grounding-to-class conversion helpers
            mdetr_style_aggregate_class_num (int): when not -1, dot-product
                logits are converted with ``convert_grounding_to_od_logits_v2``
                using this number of classes
        """
        super(ATSSPostProcessor, self).__init__()
        self.pre_nms_thresh = pre_nms_thresh
        self.pre_nms_top_n = pre_nms_top_n
        self.nms_thresh = nms_thresh
        self.fpn_post_nms_top_n = fpn_post_nms_top_n
        self.min_size = min_size
        self.num_classes = num_classes
        self.box_coder = box_coder
        self.bbox_aug_enabled = bbox_aug_enabled
        self.bbox_aug_vote = bbox_aug_vote
        self.score_agg = score_agg
        self.mdetr_style_aggregate_class_num = mdetr_style_aggregate_class_num

    def forward_for_single_feature_map(
        self,
        box_regression,
        centerness,
        anchors,
        box_cls=None,
        token_logits=None,
        dot_product_logits=None,
        positive_map=None,
    ):
        """
        Arguments:
            box_regression: tensor of size N, A * 4, H, W
            centerness: tensor, rearranged here as N, A * 1, H, W
            anchors: list[BoxList], one per image
            box_cls: optional tensor with A * C channels
            token_logits: optional tensor with A * T channels
            dot_product_logits: optional tensor, used as given
            positive_map: mapping forwarded to the conversion helpers

        Returns:
            list[BoxList]: one BoxList per image with "labels" and "scores"
        """
        N, _, H, W = box_regression.shape
        A = box_regression.size(1) // 4

        # put in the same format as anchors
        if box_cls is not None:
            C = box_cls.size(1) // A
            box_cls = permute_and_flatten(box_cls, N, A, C, H, W).sigmoid()

        # binary focal loss version
        if token_logits is not None:
            T = token_logits.size(1) // A
            token_logits = permute_and_flatten(token_logits, N, A, T, H, W).sigmoid()
            # turn back to original classes
            box_cls = convert_grounding_to_od_logits(
                logits=token_logits, box_cls=box_cls, positive_map=positive_map, score_agg=self.score_agg
            )

        # binary dot product focal version
        if dot_product_logits is not None:
            dot_product_logits = dot_product_logits.sigmoid()
            if self.mdetr_style_aggregate_class_num == -1:
                box_cls = convert_grounding_to_od_logits(
                    logits=dot_product_logits, box_cls=box_cls, positive_map=positive_map, score_agg=self.score_agg
                )
            else:
                box_cls = convert_grounding_to_od_logits_v2(
                    logits=dot_product_logits,
                    num_class=self.mdetr_style_aggregate_class_num,
                    positive_map=positive_map,
                    score_agg=self.score_agg,
                    disable_minus_one=False,
                )

        box_regression = permute_and_flatten(box_regression, N, A, 4, H, W).reshape(N, -1, 4)

        # Candidates are thresholded on the class score alone, i.e. before
        # the centerness weighting below.
        candidate_inds = box_cls > self.pre_nms_thresh
        pre_nms_top_n = candidate_inds.reshape(N, -1).sum(1).clamp(max=self.pre_nms_top_n)

        centerness = permute_and_flatten(centerness, N, A, 1, H, W).reshape(N, -1).sigmoid()

        # multiply the classification scores with centerness scores
        box_cls = box_cls * centerness[:, :, None]

        results = []
        per_image = zip(box_cls, box_regression, pre_nms_top_n, candidate_inds, anchors)
        for img_scores, img_deltas, img_top_n, img_candidates, img_anchors in per_image:
            img_scores = img_scores[img_candidates]
            img_scores, top_k_indices = img_scores.topk(img_top_n, sorted=False)

            # Rows are (box index, class index) pairs of the surviving candidates.
            kept = img_candidates.nonzero()[top_k_indices, :]
            box_loc = kept[:, 0]
            # Labels are shifted by one so that 0 stays reserved for background.
            labels = kept[:, 1] + 1

            detections = self.box_coder.decode(
                img_deltas[box_loc, :].view(-1, 4), img_anchors.bbox[box_loc, :].view(-1, 4)
            )

            boxlist = BoxList(detections, img_anchors.size, mode="xyxy")
            boxlist.add_field("labels", labels)
            # Final score is the square root of (class score * centerness).
            boxlist.add_field("scores", torch.sqrt(img_scores))
            boxlist = boxlist.clip_to_image(remove_empty=False)
            boxlist = remove_small_boxes(boxlist, self.min_size)
            results.append(boxlist)

        return results

    def forward(
        self,
        box_regression,
        centerness,
        anchors,
        box_cls=None,
        token_logits=None,
        dot_product_logits=None,
        positive_map=None,
    ):
        """
        Arguments:
            box_regression: list[tensor], one per feature level
            centerness: list[tensor], one per feature level
            anchors: list[list[BoxList]]
            box_cls, token_logits, dot_product_logits: optional per-level
                sequences, indexed by feature level
            positive_map: mapping forwarded to the conversion helpers

        Returns:
            list[BoxList]: the detections of every image
        """
        # Regroup the anchors so that each entry holds one level across images.
        per_level_anchors = list(zip(*anchors))

        sampled_boxes = []
        for idx, (b, c, a) in enumerate(zip(box_regression, centerness, per_level_anchors)):
            o = box_cls[idx] if box_cls is not None else None
            t = token_logits[idx] if token_logits is not None else None
            d = dot_product_logits[idx] if dot_product_logits is not None else None
            sampled_boxes.append(self.forward_for_single_feature_map(b, c, a, o, t, d, positive_map))

        # Regroup back to per-image and merge the levels of every image.
        boxlists = [cat_boxlist(per_image) for per_image in zip(*sampled_boxes)]

        # Selection is skipped only for augmentation without voting.
        if not self.bbox_aug_enabled or self.bbox_aug_vote:
            boxlists = self.select_over_all_levels(boxlists)

        return boxlists

    # TODO very similar to filter_results from PostProcessor
    # but filter_results is per image
    # TODO Yang: solve this issue in the future. No good solution
    # right now.
    def select_over_all_levels(self, boxlists):
        """
        Run multi-class NMS on every image and cap the number of detections.

        Arguments:
            boxlists: list[BoxList], one per image

        Returns:
            list[BoxList]: the filtered detections, one BoxList per image
        """
        results = []
        for per_image_boxes in boxlists:
            # multiclass nms
            result = boxlist_ml_nms(per_image_boxes, self.nms_thresh)
            number_of_detections = len(result)

            # Limit to max_per_image detections **over all classes**
            if number_of_detections > self.fpn_post_nms_top_n > 0:
                cls_scores = result.get_field("scores")
                # The threshold is computed on a float32 CPU copy of the scores.
                kth = number_of_detections - self.fpn_post_nms_top_n + 1
                image_thresh, _ = torch.kthvalue(cls_scores.cpu().float(), kth)
                keep = torch.nonzero(cls_scores >= image_thresh.item()).squeeze(1)
                result = result[keep]
            results.append(result)
        return results
def convert_grounding_to_od_logits(logits, box_cls, positive_map, score_agg=None):
    """
    Aggregate per-token scores into per-class scores.

    Arguments:
        logits: tensor indexed as [batch, box, token]
        box_cls: tensor whose third dimension gives the number of classes of
            the result
        positive_map: mapping from 1-based class label to the token
            position(s) of that class (a sequence of ints or a single int),
            or None
        score_agg (str): "MEAN" or "MAX" reduce over the tokens of each
            class; "ONEHOT" returns the first ``len(positive_map)`` token
            columns unchanged

    Returns:
        tensor of per-class scores; all zeros when ``positive_map`` is None

    Raises:
        NotImplementedError: ``positive_map`` is given and ``score_agg`` is
            not one of the modes above
    """
    # Allocate on the target device right away instead of CPU-then-copy.
    scores = torch.zeros(logits.shape[0], logits.shape[1], box_cls.shape[2], device=logits.device)
    if positive_map is None:
        return scores

    if score_agg == "ONEHOT":
        # one hot
        return logits[:, :, : len(positive_map)]
    if score_agg not in ("MEAN", "MAX"):
        raise NotImplementedError

    # 256 -> 80, aggregate the token scores of each class
    for label_j in positive_map:
        token_positions = positive_map[label_j]
        # A bare int must be wrapped: torch.LongTensor(int) would allocate an
        # uninitialized tensor of that *length* rather than hold the index.
        if isinstance(token_positions, int):
            token_positions = [token_positions]
        token_idx = torch.as_tensor(token_positions, dtype=torch.long, device=logits.device)
        selected = logits[:, :, token_idx]
        if score_agg == "MEAN":
            scores[:, :, label_j - 1] = selected.mean(-1)
        else:
            # torch.max() returns (values, indices)
            scores[:, :, label_j - 1] = selected.max(-1)[0]
    return scores
def convert_grounding_to_od_logits_v2(logits, num_class, positive_map, score_agg=None, disable_minus_one=True):
    """
    Aggregate per-token scores into ``num_class`` per-class scores.

    Arguments:
        logits: tensor indexed as [batch, box, token]
        num_class (int): size of the class dimension of the result
        positive_map: mapping from class label to the token position(s) of
            that class (a sequence of ints or a single int), or None
        score_agg (str): "MEAN" averages, "MAX" takes the maximum and
            "POWER" takes the geometric mean over the tokens of each class;
            "ONEHOT" returns the first ``len(positive_map)`` token columns
            unchanged
        disable_minus_one (bool): when true the label is used directly as the
            class column, otherwise ``label - 1`` is used

    Returns:
        tensor of per-class scores; all zeros when ``positive_map`` is None

    Raises:
        NotImplementedError: ``positive_map`` is given and ``score_agg`` is
            not one of the modes above
    """
    # Allocate on the target device right away instead of CPU-then-copy.
    scores = torch.zeros(logits.shape[0], logits.shape[1], num_class, device=logits.device)
    if positive_map is None:
        return scores

    if score_agg == "ONEHOT":
        # one hot
        return logits[:, :, : len(positive_map)]
    if score_agg not in ("MEAN", "POWER", "MAX"):
        raise NotImplementedError

    # 256 -> 80, aggregate the token scores of each class
    for label_j in positive_map:
        locations_label_j = positive_map[label_j]
        # A bare int must be wrapped for every mode: torch.LongTensor(int)
        # would allocate an uninitialized tensor of that *length*.
        if isinstance(locations_label_j, int):
            locations_label_j = [locations_label_j]
        token_idx = torch.as_tensor(locations_label_j, dtype=torch.long, device=logits.device)
        selected = logits[:, :, token_idx]

        if score_agg == "MEAN":
            class_score = selected.mean(-1)
        elif score_agg == "POWER":
            # Geometric mean. The product already drops the token axis, so no
            # further squeeze is applied (it would collapse a length-1 box axis).
            class_score = torch.pow(torch.prod(selected, dim=-1), 1 / len(locations_label_j))
        else:
            # torch.max() returns (values, indices)
            class_score = selected.max(-1)[0]

        scores[:, :, label_j if disable_minus_one else label_j - 1] = class_score
    return scores
def make_atss_postprocessor(config, box_coder, is_train=False):
    """
    Build an ATSSPostProcessor from ``config``.

    Thresholds and limits come from ``MODEL.ATSS``; the inference-time values
    are read first and replaced by their training counterparts when
    ``is_train`` is true. The score aggregation mode comes from
    ``MODEL.DYHEAD.SCORE_AGG`` and the augmentation / class-aggregation
    settings from ``TEST``. ``min_size`` is fixed to 0.
    """
    atss_cfg = config.MODEL.ATSS

    pre_nms_thresh = atss_cfg.INFERENCE_TH
    pre_nms_top_n = atss_cfg.PRE_NMS_TOP_N
    fpn_post_nms_top_n = atss_cfg.DETECTIONS_PER_IMG
    if is_train:
        pre_nms_thresh = atss_cfg.INFERENCE_TH_TRAIN
        pre_nms_top_n = atss_cfg.PRE_NMS_TOP_N_TRAIN
        fpn_post_nms_top_n = atss_cfg.POST_NMS_TOP_N_TRAIN

    return ATSSPostProcessor(
        pre_nms_thresh=pre_nms_thresh,
        pre_nms_top_n=pre_nms_top_n,
        nms_thresh=atss_cfg.NMS_TH,
        fpn_post_nms_top_n=fpn_post_nms_top_n,
        min_size=0,
        num_classes=atss_cfg.NUM_CLASSES,
        box_coder=box_coder,
        bbox_aug_enabled=config.TEST.USE_MULTISCALE,
        score_agg=config.MODEL.DYHEAD.SCORE_AGG,
        mdetr_style_aggregate_class_num=config.TEST.MDETR_STYLE_AGGREGATE_CLASS_NUM,
    )