# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. import logging import torch from maskrcnn_benchmark.modeling.box_coder import BoxCoder from maskrcnn_benchmark.structures.bounding_box import BoxList, _onnx_clip_boxes_to_image from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist from maskrcnn_benchmark.structures.boxlist_ops import boxlist_nms from maskrcnn_benchmark.structures.boxlist_ops import boxlist_ml_nms from maskrcnn_benchmark.structures.boxlist_ops import remove_small_boxes from ..utils import permute_and_flatten import pdb class RPNPostProcessor(torch.nn.Module): """ Performs post-processing on the outputs of the RPN boxes, before feeding the proposals to the heads """ def __init__( self, pre_nms_top_n, post_nms_top_n, nms_thresh, min_size, box_coder=None, fpn_post_nms_top_n=None, onnx=False ): """ Arguments: pre_nms_top_n (int) post_nms_top_n (int) nms_thresh (float) min_size (int) box_coder (BoxCoder) fpn_post_nms_top_n (int) """ super(RPNPostProcessor, self).__init__() self.pre_nms_top_n = pre_nms_top_n self.post_nms_top_n = post_nms_top_n self.nms_thresh = nms_thresh self.min_size = min_size self.onnx = onnx if box_coder is None: box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) self.box_coder = box_coder if fpn_post_nms_top_n is None: fpn_post_nms_top_n = post_nms_top_n self.fpn_post_nms_top_n = fpn_post_nms_top_n def add_gt_proposals(self, proposals, targets): """ Arguments: proposals: list[BoxList] targets: list[BoxList] """ # Get the device we're operating on device = proposals[0].bbox.device gt_boxes = [target.copy_with_fields([]) for target in targets] # later cat of bbox requires all fields to be present for all bbox # so we need to add a dummy for objectness that's missing for gt_box in gt_boxes: gt_box.add_field("objectness", torch.ones(len(gt_box), device=device)) proposals = [cat_boxlist((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)] return proposals def forward_for_single_feature_map(self, anchors, objectness, box_regression): """ Arguments: anchors: list[BoxList] objectness: tensor of size N, A, H, W box_regression: tensor of size N, A * 4, H, W """ device = objectness.device N, A, H, W = objectness.shape # put in the same format as anchors objectness = objectness.permute(0, 2, 3, 1).reshape(N, -1) objectness = objectness.sigmoid() box_regression = box_regression.view(N, -1, 4, H, W).permute(0, 3, 4, 1, 2) box_regression = box_regression.reshape(N, -1, 4) num_anchors = A * H * W pre_nms_top_n = min(self.pre_nms_top_n, num_anchors) objectness, topk_idx = objectness.topk(pre_nms_top_n, dim=1, sorted=True) batch_idx = torch.arange(N, device=device)[:, None] box_regression = box_regression[batch_idx, topk_idx] image_shapes = [box.size for box in anchors] concat_anchors = torch.cat([a.bbox for a in anchors], dim=0) concat_anchors = concat_anchors.reshape(N, -1, 4)[batch_idx, topk_idx] proposals = self.box_coder.decode(box_regression.view(-1, 4), concat_anchors.view(-1, 4)) proposals = proposals.view(N, -1, 4) result = [] for proposal, score, im_shape in zip(proposals, objectness, image_shapes): if self.onnx: proposal = _onnx_clip_boxes_to_image(proposal, im_shape) boxlist = BoxList(proposal, im_shape, mode="xyxy") else: boxlist = BoxList(proposal, im_shape, mode="xyxy") boxlist = boxlist.clip_to_image(remove_empty=False) boxlist.add_field("objectness", score) boxlist = remove_small_boxes(boxlist, self.min_size) boxlist = boxlist_nms( boxlist, self.nms_thresh, max_proposals=self.post_nms_top_n, score_field="objectness", ) result.append(boxlist) return result def forward(self, anchors, objectness, box_regression, targets=None): """ Arguments: anchors: list[list[BoxList]] objectness: list[tensor] box_regression: list[tensor] Returns: boxlists (list[BoxList]): the post-processed anchors, after applying box decoding and NMS """ sampled_boxes = [] num_levels = len(objectness) anchors = list(zip(*anchors)) for a, o, b in zip(anchors, objectness, box_regression): sampled_boxes.append(self.forward_for_single_feature_map(a, o, b)) boxlists = list(zip(*sampled_boxes)) boxlists = [cat_boxlist(boxlist) for boxlist in boxlists] if num_levels > 1: boxlists = self.select_over_all_levels(boxlists) # append ground-truth bboxes to proposals if self.training and targets is not None: boxlists = self.add_gt_proposals(boxlists, targets) return boxlists def select_over_all_levels(self, boxlists): num_images = len(boxlists) # different behavior during training and during testing: # during training, post_nms_top_n is over *all* the proposals combined, while # during testing, it is over the proposals for each image # TODO resolve this difference and make it consistent. It should be per image, # and not per batch if self.training: objectness = torch.cat([boxlist.get_field("objectness") for boxlist in boxlists], dim=0) box_sizes = [len(boxlist) for boxlist in boxlists] post_nms_top_n = min(self.fpn_post_nms_top_n, len(objectness)) _, inds_sorted = torch.topk(objectness, post_nms_top_n, dim=0, sorted=True) inds_mask = torch.zeros_like(objectness, dtype=torch.bool) inds_mask[inds_sorted] = 1 inds_mask = inds_mask.split(box_sizes) for i in range(num_images): boxlists[i] = boxlists[i][inds_mask[i]] else: for i in range(num_images): objectness = boxlists[i].get_field("objectness") post_nms_top_n = min(self.fpn_post_nms_top_n, len(objectness)) _, inds_sorted = torch.topk(objectness, post_nms_top_n, dim=0, sorted=True) boxlists[i] = boxlists[i][inds_sorted] return boxlists def make_rpn_postprocessor(config, rpn_box_coder, is_train): fpn_post_nms_top_n = config.MODEL.RPN.FPN_POST_NMS_TOP_N_TRAIN if not is_train: fpn_post_nms_top_n = config.MODEL.RPN.FPN_POST_NMS_TOP_N_TEST pre_nms_top_n = config.MODEL.RPN.PRE_NMS_TOP_N_TRAIN post_nms_top_n = config.MODEL.RPN.POST_NMS_TOP_N_TRAIN if not is_train: pre_nms_top_n = config.MODEL.RPN.PRE_NMS_TOP_N_TEST post_nms_top_n = config.MODEL.RPN.POST_NMS_TOP_N_TEST nms_thresh = config.MODEL.RPN.NMS_THRESH min_size = config.MODEL.RPN.MIN_SIZE onnx = config.MODEL.ONNX box_selector = RPNPostProcessor( pre_nms_top_n=pre_nms_top_n, post_nms_top_n=post_nms_top_n, nms_thresh=nms_thresh, min_size=min_size, box_coder=rpn_box_coder, fpn_post_nms_top_n=fpn_post_nms_top_n, onnx=onnx, ) return box_selector class RetinaPostProcessor(torch.nn.Module): """ Performs post-processing on the outputs of the RetinaNet boxes. This is only used in the testing. """ def __init__( self, pre_nms_thresh, pre_nms_top_n, nms_thresh, fpn_post_nms_top_n, min_size, num_classes, box_coder=None, ): """ Arguments: pre_nms_thresh (float) pre_nms_top_n (int) nms_thresh (float) fpn_post_nms_top_n (int) min_size (int) num_classes (int) box_coder (BoxCoder) """ super(RetinaPostProcessor, self).__init__() self.pre_nms_thresh = pre_nms_thresh self.pre_nms_top_n = pre_nms_top_n self.nms_thresh = nms_thresh self.fpn_post_nms_top_n = fpn_post_nms_top_n self.min_size = min_size self.num_classes = num_classes if box_coder is None: box_coder = BoxCoder(weights=(10.0, 10.0, 5.0, 5.0)) self.box_coder = box_coder def forward_for_single_feature_map(self, anchors, box_cls, box_regression): """ Arguments: anchors: list[BoxList] box_cls: tensor of size N, A * C, H, W box_regression: tensor of size N, A * 4, H, W """ device = box_cls.device N, _, H, W = box_cls.shape A = box_regression.size(1) // 4 C = box_cls.size(1) // A # put in the same format as anchors box_cls = permute_and_flatten(box_cls, N, A, C, H, W) box_cls = box_cls.sigmoid() box_regression = permute_and_flatten(box_regression, N, A, 4, H, W) box_regression = box_regression.reshape(N, -1, 4) num_anchors = A * H * W candidate_inds = box_cls > self.pre_nms_thresh pre_nms_top_n = candidate_inds.view(N, -1).sum(1) pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n) results = [] for per_box_cls, per_box_regression, per_pre_nms_top_n, per_candidate_inds, per_anchors in zip( box_cls, box_regression, pre_nms_top_n, candidate_inds, anchors ): # Sort and select TopN # TODO most of this can be made out of the loop for # all images. # TODO:Yang: Not easy to do. Because the numbers of detections are # different in each image. Therefore, this part needs to be done # per image. per_box_cls = per_box_cls[per_candidate_inds] per_box_cls, top_k_indices = per_box_cls.topk(per_pre_nms_top_n, sorted=False) per_candidate_nonzeros = per_candidate_inds.nonzero()[top_k_indices, :] per_box_loc = per_candidate_nonzeros[:, 0] per_class = per_candidate_nonzeros[:, 1] per_class += 1 detections = self.box_coder.decode( per_box_regression[per_box_loc, :].view(-1, 4), per_anchors.bbox[per_box_loc, :].view(-1, 4) ) boxlist = BoxList(detections, per_anchors.size, mode="xyxy") boxlist.add_field("labels", per_class) boxlist.add_field("scores", per_box_cls) boxlist = boxlist.clip_to_image(remove_empty=False) boxlist = remove_small_boxes(boxlist, self.min_size) results.append(boxlist) return results # TODO very similar to filter_results from PostProcessor # but filter_results is per image # TODO Yang: solve this issue in the future. No good solution # right now. def select_over_all_levels(self, boxlists): num_images = len(boxlists) results = [] for i in range(num_images): scores = boxlists[i].get_field("scores") labels = boxlists[i].get_field("labels") boxes = boxlists[i].bbox boxlist = boxlists[i] result = [] # skip the background for j in range(1, self.num_classes): inds = (labels == j).nonzero().view(-1) scores_j = scores[inds] boxes_j = boxes[inds, :].view(-1, 4) boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy") boxlist_for_class.add_field("scores", scores_j) boxlist_for_class = boxlist_nms(boxlist_for_class, self.nms_thresh, score_field="scores") num_labels = len(boxlist_for_class) boxlist_for_class.add_field( "labels", torch.full((num_labels,), j, dtype=torch.int64, device=scores.device) ) result.append(boxlist_for_class) result = cat_boxlist(result) number_of_detections = len(result) # Limit to max_per_image detections **over all classes** if number_of_detections > self.fpn_post_nms_top_n > 0: cls_scores = result.get_field("scores") image_thresh, _ = torch.kthvalue(cls_scores.cpu(), number_of_detections - self.fpn_post_nms_top_n + 1) keep = cls_scores >= image_thresh.item() keep = torch.nonzero(keep).squeeze(1) result = result[keep] results.append(result) return results def forward(self, anchors, objectness, box_regression, targets=None): """ Arguments: anchors: list[list[BoxList]] objectness: list[tensor] box_regression: list[tensor] Returns: boxlists (list[BoxList]): the post-processed anchors, after applying box decoding and NMS """ sampled_boxes = [] anchors = list(zip(*anchors)) for a, o, b in zip(anchors, objectness, box_regression): sampled_boxes.append(self.forward_for_single_feature_map(a, o, b)) boxlists = list(zip(*sampled_boxes)) boxlists = [cat_boxlist(boxlist) for boxlist in boxlists] boxlists = self.select_over_all_levels(boxlists) return boxlists def make_retina_postprocessor(config, rpn_box_coder, is_train): pre_nms_thresh = config.MODEL.RETINANET.INFERENCE_TH pre_nms_top_n = config.MODEL.RETINANET.PRE_NMS_TOP_N nms_thresh = config.MODEL.RETINANET.NMS_TH fpn_post_nms_top_n = config.MODEL.RETINANET.DETECTIONS_PER_IMG min_size = 0 box_selector = RetinaPostProcessor( pre_nms_thresh=pre_nms_thresh, pre_nms_top_n=pre_nms_top_n, nms_thresh=nms_thresh, fpn_post_nms_top_n=fpn_post_nms_top_n, min_size=min_size, num_classes=config.MODEL.RETINANET.NUM_CLASSES, box_coder=rpn_box_coder, ) return box_selector class FCOSPostProcessor(torch.nn.Module): """ Performs post-processing on the outputs of the RetinaNet boxes. This is only used in the testing. """ def __init__( self, pre_nms_thresh, pre_nms_top_n, nms_thresh, fpn_post_nms_top_n, min_size, num_classes, bbox_aug_enabled=False, ): """ Arguments: pre_nms_thresh (float) pre_nms_top_n (int) nms_thresh (float) fpn_post_nms_top_n (int) min_size (int) num_classes (int) box_coder (BoxCoder) """ super(FCOSPostProcessor, self).__init__() self.pre_nms_thresh = pre_nms_thresh self.pre_nms_top_n = pre_nms_top_n self.nms_thresh = nms_thresh self.fpn_post_nms_top_n = fpn_post_nms_top_n self.min_size = min_size self.num_classes = num_classes self.bbox_aug_enabled = bbox_aug_enabled def forward_for_single_feature_map(self, locations, box_cls, box_regression, centerness, image_sizes): """ Arguments: anchors: list[BoxList] box_cls: tensor of size N, A * C, H, W box_regression: tensor of size N, A * 4, H, W """ N, C, H, W = box_cls.shape # put in the same format as locations box_cls = box_cls.view(N, C, H, W).permute(0, 2, 3, 1) box_cls = box_cls.reshape(N, -1, C).sigmoid() box_regression = box_regression.view(N, 4, H, W).permute(0, 2, 3, 1) box_regression = box_regression.reshape(N, -1, 4) centerness = centerness.view(N, 1, H, W).permute(0, 2, 3, 1) centerness = centerness.reshape(N, -1).sigmoid() candidate_inds = box_cls > self.pre_nms_thresh pre_nms_top_n = candidate_inds.reshape(N, -1).sum(1) pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n) # multiply the classification scores with centerness scores box_cls = box_cls * centerness[:, :, None] results = [] for i in range(N): per_box_cls = box_cls[i] per_candidate_inds = candidate_inds[i] per_box_cls = per_box_cls[per_candidate_inds] per_candidate_nonzeros = per_candidate_inds.nonzero() per_box_loc = per_candidate_nonzeros[:, 0] per_class = per_candidate_nonzeros[:, 1] + 1 per_box_regression = box_regression[i] per_box_regression = per_box_regression[per_box_loc] per_locations = locations[per_box_loc] per_pre_nms_top_n = pre_nms_top_n[i] if per_candidate_inds.sum().item() > per_pre_nms_top_n.item(): per_box_cls, top_k_indices = per_box_cls.topk(per_pre_nms_top_n, sorted=False) per_class = per_class[top_k_indices] per_box_regression = per_box_regression[top_k_indices] per_locations = per_locations[top_k_indices] detections = torch.stack( [ per_locations[:, 0] - per_box_regression[:, 0], per_locations[:, 1] - per_box_regression[:, 1], per_locations[:, 0] + per_box_regression[:, 2], per_locations[:, 1] + per_box_regression[:, 3], ], dim=1, ) h, w = image_sizes[i] boxlist = BoxList(detections, (int(w), int(h)), mode="xyxy") boxlist.add_field("centers", per_locations) boxlist.add_field("labels", per_class) boxlist.add_field("scores", torch.sqrt(per_box_cls)) boxlist = boxlist.clip_to_image(remove_empty=False) boxlist = remove_small_boxes(boxlist, self.min_size) results.append(boxlist) return results def forward(self, locations, box_cls, box_regression, centerness, image_sizes): """ Arguments: anchors: list[list[BoxList]] box_cls: list[tensor] box_regression: list[tensor] image_sizes: list[(h, w)] Returns: boxlists (list[BoxList]): the post-processed anchors, after applying box decoding and NMS """ sampled_boxes = [] for _, (l, o, b, c) in enumerate(zip(locations, box_cls, box_regression, centerness)): sampled_boxes.append(self.forward_for_single_feature_map(l, o, b, c, image_sizes)) boxlists = list(zip(*sampled_boxes)) boxlists = [cat_boxlist(boxlist) for boxlist in boxlists] if not self.bbox_aug_enabled: boxlists = self.select_over_all_levels(boxlists) return boxlists # TODO very similar to filter_results from PostProcessor # but filter_results is per image # TODO Yang: solve this issue in the future. No good solution # right now. def select_over_all_levels(self, boxlists): num_images = len(boxlists) results = [] for i in range(num_images): # multiclass nms result = boxlist_ml_nms(boxlists[i], self.nms_thresh) number_of_detections = len(result) # Limit to max_per_image detections **over all classes** if number_of_detections > self.fpn_post_nms_top_n > 0: cls_scores = result.get_field("scores") image_thresh, _ = torch.kthvalue(cls_scores.cpu(), number_of_detections - self.fpn_post_nms_top_n + 1) keep = cls_scores >= image_thresh.item() keep = torch.nonzero(keep).squeeze(1) result = result[keep] results.append(result) return results def make_fcos_postprocessor(config, is_train=False): pre_nms_thresh = config.MODEL.FCOS.INFERENCE_TH if is_train: pre_nms_thresh = config.MODEL.FCOS.INFERENCE_TH_TRAIN pre_nms_top_n = config.MODEL.FCOS.PRE_NMS_TOP_N fpn_post_nms_top_n = config.MODEL.FCOS.DETECTIONS_PER_IMG if is_train: pre_nms_top_n = config.MODEL.FCOS.PRE_NMS_TOP_N_TRAIN fpn_post_nms_top_n = config.MODEL.FCOS.POST_NMS_TOP_N_TRAIN nms_thresh = config.MODEL.FCOS.NMS_TH box_selector = FCOSPostProcessor( pre_nms_thresh=pre_nms_thresh, pre_nms_top_n=pre_nms_top_n, nms_thresh=nms_thresh, fpn_post_nms_top_n=fpn_post_nms_top_n, min_size=0, num_classes=config.MODEL.FCOS.NUM_CLASSES, ) return box_selector class ATSSPostProcessor(torch.nn.Module): def __init__( self, pre_nms_thresh, pre_nms_top_n, nms_thresh, fpn_post_nms_top_n, min_size, num_classes, box_coder, bbox_aug_enabled=False, bbox_aug_vote=False, score_agg="MEAN", mdetr_style_aggregate_class_num=-1, ): super(ATSSPostProcessor, self).__init__() self.pre_nms_thresh = pre_nms_thresh self.pre_nms_top_n = pre_nms_top_n self.nms_thresh = nms_thresh self.fpn_post_nms_top_n = fpn_post_nms_top_n self.min_size = min_size self.num_classes = num_classes self.bbox_aug_enabled = bbox_aug_enabled self.box_coder = box_coder self.bbox_aug_vote = bbox_aug_vote self.score_agg = score_agg self.mdetr_style_aggregate_class_num = mdetr_style_aggregate_class_num def forward_for_single_feature_map( self, box_regression, centerness, anchors, box_cls=None, token_logits=None, dot_product_logits=None, positive_map=None, ): N, _, H, W = box_regression.shape A = box_regression.size(1) // 4 if box_cls is not None: C = box_cls.size(1) // A if token_logits is not None: T = token_logits.size(1) // A # put in the same format as anchors if box_cls is not None: # print('Classification.') box_cls = permute_and_flatten(box_cls, N, A, C, H, W) box_cls = box_cls.sigmoid() # binary focal loss version if token_logits is not None: # print('Token.') token_logits = permute_and_flatten(token_logits, N, A, T, H, W) token_logits = token_logits.sigmoid() # turn back to original classes scores = convert_grounding_to_od_logits( logits=token_logits, box_cls=box_cls, positive_map=positive_map, score_agg=self.score_agg ) box_cls = scores # binary dot product focal version if dot_product_logits is not None: # print('Dot Product.') dot_product_logits = dot_product_logits.sigmoid() if self.mdetr_style_aggregate_class_num != -1: scores = convert_grounding_to_od_logits_v2( logits=dot_product_logits, num_class=self.mdetr_style_aggregate_class_num, positive_map=positive_map, score_agg=self.score_agg, disable_minus_one=False, ) else: scores = convert_grounding_to_od_logits( logits=dot_product_logits, box_cls=box_cls, positive_map=positive_map, score_agg=self.score_agg ) box_cls = scores box_regression = permute_and_flatten(box_regression, N, A, 4, H, W) box_regression = box_regression.reshape(N, -1, 4) candidate_inds = box_cls > self.pre_nms_thresh pre_nms_top_n = candidate_inds.reshape(N, -1).sum(1) pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n) centerness = permute_and_flatten(centerness, N, A, 1, H, W) centerness = centerness.reshape(N, -1).sigmoid() # multiply the classification scores with centerness scores box_cls = box_cls * centerness[:, :, None] results = [] for per_box_cls, per_box_regression, per_pre_nms_top_n, per_candidate_inds, per_anchors in zip( box_cls, box_regression, pre_nms_top_n, candidate_inds, anchors ): per_box_cls = per_box_cls[per_candidate_inds] per_box_cls, top_k_indices = per_box_cls.topk(per_pre_nms_top_n, sorted=False) per_candidate_nonzeros = per_candidate_inds.nonzero()[top_k_indices, :] per_box_loc = per_candidate_nonzeros[:, 0] per_class = per_candidate_nonzeros[:, 1] + 1 # print(per_class) detections = self.box_coder.decode( per_box_regression[per_box_loc, :].view(-1, 4), per_anchors.bbox[per_box_loc, :].view(-1, 4) ) boxlist = BoxList(detections, per_anchors.size, mode="xyxy") boxlist.add_field("labels", per_class) boxlist.add_field("scores", torch.sqrt(per_box_cls)) boxlist = boxlist.clip_to_image(remove_empty=False) boxlist = remove_small_boxes(boxlist, self.min_size) results.append(boxlist) return results def forward( self, box_regression, centerness, anchors, box_cls=None, token_logits=None, dot_product_logits=None, positive_map=None, ): sampled_boxes = [] anchors = list(zip(*anchors)) for idx, (b, c, a) in enumerate(zip(box_regression, centerness, anchors)): o = None t = None d = None if box_cls is not None: o = box_cls[idx] if token_logits is not None: t = token_logits[idx] if dot_product_logits is not None: d = dot_product_logits[idx] sampled_boxes.append(self.forward_for_single_feature_map(b, c, a, o, t, d, positive_map)) boxlists = list(zip(*sampled_boxes)) boxlists = [cat_boxlist(boxlist) for boxlist in boxlists] if not (self.bbox_aug_enabled and not self.bbox_aug_vote): boxlists = self.select_over_all_levels(boxlists) return boxlists # TODO very similar to filter_results from PostProcessor # but filter_results is per image # TODO Yang: solve this issue in the future. No good solution # right now. def select_over_all_levels(self, boxlists): num_images = len(boxlists) results = [] for i in range(num_images): # multiclass nms result = boxlist_ml_nms(boxlists[i], self.nms_thresh) number_of_detections = len(result) # Limit to max_per_image detections **over all classes** if number_of_detections > self.fpn_post_nms_top_n > 0: cls_scores = result.get_field("scores") image_thresh, _ = torch.kthvalue( # cls_scores.cpu(), cls_scores.cpu().float(), number_of_detections - self.fpn_post_nms_top_n + 1, ) keep = cls_scores >= image_thresh.item() keep = torch.nonzero(keep).squeeze(1) result = result[keep] results.append(result) return results def convert_grounding_to_od_logits(logits, box_cls, positive_map, score_agg=None): scores = torch.zeros(logits.shape[0], logits.shape[1], box_cls.shape[2]).to(logits.device) # 256 -> 80, average for each class if positive_map is not None: # score aggregation method if score_agg == "MEAN": for label_j in positive_map: scores[:, :, label_j - 1] = logits[:, :, torch.LongTensor(positive_map[label_j])].mean(-1) elif score_agg == "MAX": # torch.max() returns (values, indices) for label_j in positive_map: scores[:, :, label_j - 1] = logits[:, :, torch.LongTensor(positive_map[label_j])].max(-1)[0] elif score_agg == "ONEHOT": # one hot scores = logits[:, :, : len(positive_map)] else: raise NotImplementedError return scores def convert_grounding_to_od_logits_v2(logits, num_class, positive_map, score_agg=None, disable_minus_one=True): scores = torch.zeros(logits.shape[0], logits.shape[1], num_class).to(logits.device) # 256 -> 80, average for each class if positive_map is not None: # score aggregation method if score_agg == "MEAN": for label_j in positive_map: locations_label_j = positive_map[label_j] if isinstance(locations_label_j, int): locations_label_j = [locations_label_j] scores[:, :, label_j if disable_minus_one else label_j - 1] = logits[ :, :, torch.LongTensor(locations_label_j) ].mean(-1) elif score_agg == "POWER": for label_j in positive_map: locations_label_j = positive_map[label_j] if isinstance(locations_label_j, int): locations_label_j = [locations_label_j] probability = torch.prod(logits[:, :, torch.LongTensor(locations_label_j)], dim=-1).squeeze(-1) probability = torch.pow(probability, 1 / len(locations_label_j)) scores[:, :, label_j if disable_minus_one else label_j - 1] = probability elif score_agg == "MAX": # torch.max() returns (values, indices) for label_j in positive_map: scores[:, :, label_j if disable_minus_one else label_j - 1] = logits[ :, :, torch.LongTensor(positive_map[label_j]) ].max(-1)[0] elif score_agg == "ONEHOT": # one hot scores = logits[:, :, : len(positive_map)] else: raise NotImplementedError return scores def make_atss_postprocessor(config, box_coder, is_train=False): pre_nms_thresh = config.MODEL.ATSS.INFERENCE_TH if is_train: pre_nms_thresh = config.MODEL.ATSS.INFERENCE_TH_TRAIN pre_nms_top_n = config.MODEL.ATSS.PRE_NMS_TOP_N fpn_post_nms_top_n = config.MODEL.ATSS.DETECTIONS_PER_IMG if is_train: pre_nms_top_n = config.MODEL.ATSS.PRE_NMS_TOP_N_TRAIN fpn_post_nms_top_n = config.MODEL.ATSS.POST_NMS_TOP_N_TRAIN nms_thresh = config.MODEL.ATSS.NMS_TH score_agg = config.MODEL.DYHEAD.SCORE_AGG box_selector = ATSSPostProcessor( pre_nms_thresh=pre_nms_thresh, pre_nms_top_n=pre_nms_top_n, nms_thresh=nms_thresh, fpn_post_nms_top_n=fpn_post_nms_top_n, min_size=0, num_classes=config.MODEL.ATSS.NUM_CLASSES, box_coder=box_coder, bbox_aug_enabled=config.TEST.USE_MULTISCALE, score_agg=score_agg, mdetr_style_aggregate_class_num=config.TEST.MDETR_STYLE_AGGREGATE_CLASS_NUM, ) return box_selector