Ge Zheng Feng Wang commited on
Commit
c62d838
·
1 Parent(s): 0fd8660

feat(demo): add class agnostic nms for demo and numpy based postprocessing (#588)

Browse files

feat(demo): add class agnostic nms for demo and numpy based postprocessing (#588)

Co-authored-by: Feng Wang <wangfeng19950315@163.com>

Files changed (3) hide show
  1. tools/demo.py +2 -1
  2. yolox/utils/boxes.py +15 -7
  3. yolox/utils/demo_utils.py +29 -1
tools/demo.py CHANGED
@@ -154,7 +154,8 @@ class Predictor(object):
154
  if self.decoder is not None:
155
  outputs = self.decoder(outputs, dtype=outputs.type())
156
  outputs = postprocess(
157
- outputs, self.num_classes, self.confthre, self.nmsthre
 
158
  )
159
  logger.info("Infer time: {:.4f}s".format(time.time() - t0))
160
  return outputs, img_info
 
154
  if self.decoder is not None:
155
  outputs = self.decoder(outputs, dtype=outputs.type())
156
  outputs = postprocess(
157
+ outputs, self.num_classes, self.confthre,
158
+ self.nmsthre, class_agnostic=True
159
  )
160
  logger.info("Infer time: {:.4f}s".format(time.time() - t0))
161
  return outputs, img_info
yolox/utils/boxes.py CHANGED
@@ -29,7 +29,7 @@ def filter_box(output, scale_range):
29
  return output[keep]
30
 
31
 
32
- def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45):
33
  box_corner = prediction.new(prediction.shape)
34
  box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
35
  box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
@@ -53,12 +53,20 @@ def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45):
53
  if not detections.size(0):
54
  continue
55
 
56
- nms_out_index = torchvision.ops.batched_nms(
57
- detections[:, :4],
58
- detections[:, 4] * detections[:, 5],
59
- detections[:, 6],
60
- nms_thre,
61
- )
 
 
 
 
 
 
 
 
62
  detections = detections[nms_out_index]
63
  if output[i] is None:
64
  output[i] = detections
 
29
  return output[keep]
30
 
31
 
32
+ def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False):
33
  box_corner = prediction.new(prediction.shape)
34
  box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
35
  box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
 
53
  if not detections.size(0):
54
  continue
55
 
56
+ if class_agnostic:
57
+ nms_out_index = torchvision.ops.nms(
58
+ detections[:, :4],
59
+ detections[:, 4] * detections[:, 5],
60
+ nms_thre,
61
+ )
62
+ else:
63
+ nms_out_index = torchvision.ops.batched_nms(
64
+ detections[:, :4],
65
+ detections[:, 4] * detections[:, 5],
66
+ detections[:, 6],
67
+ nms_thre,
68
+ )
69
+
70
  detections = detections[nms_out_index]
71
  if output[i] is None:
72
  output[i] = detections
yolox/utils/demo_utils.py CHANGED
@@ -44,8 +44,17 @@ def nms(boxes, scores, nms_thr):
44
  return keep
45
 
46
 
47
- def multiclass_nms(boxes, scores, nms_thr, score_thr):
48
  """Multiclass NMS implemented in Numpy"""
 
 
 
 
 
 
 
 
 
49
  final_dets = []
50
  num_classes = scores.shape[1]
51
  for cls_ind in range(num_classes):
@@ -68,6 +77,25 @@ def multiclass_nms(boxes, scores, nms_thr, score_thr):
68
  return np.concatenate(final_dets, 0)
69
 
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  def demo_postprocess(outputs, img_size, p6=False):
72
 
73
  grids = []
 
44
  return keep
45
 
46
 
47
+ def multiclass_nms(boxes, scores, nms_thr, score_thr, class_agnostic=True):
48
  """Multiclass NMS implemented in Numpy"""
49
+ if class_agnostic:
50
+ nms_method = multiclass_nms_class_agnostic
51
+ else:
52
+ nms_method = multiclass_nms_class_aware
53
+ return nms_method(boxes, scores, nms_thr, score_thr)
54
+
55
+
56
+ def multiclass_nms_class_aware(boxes, scores, nms_thr, score_thr):
57
+ """Multiclass NMS implemented in Numpy. Class-aware version."""
58
  final_dets = []
59
  num_classes = scores.shape[1]
60
  for cls_ind in range(num_classes):
 
77
  return np.concatenate(final_dets, 0)
78
 
79
 
80
+ def multiclass_nms_class_agnostic(boxes, scores, nms_thr, score_thr):
81
+ """Multiclass NMS implemented in Numpy. Class-agnostic version."""
82
+ cls_inds = scores.argmax(1)
83
+ cls_scores = scores[np.arange(len(cls_inds)), cls_inds]
84
+
85
+ valid_score_mask = cls_scores > score_thr
86
+ if valid_score_mask.sum() == 0:
87
+ return None
88
+ valid_scores = cls_scores[valid_score_mask]
89
+ valid_boxes = boxes[valid_score_mask]
90
+ valid_cls_inds = cls_inds[valid_score_mask]
91
+ keep = nms(valid_boxes, valid_scores, nms_thr)
92
+ if keep:
93
+ dets = np.concatenate(
94
+ [valid_boxes[keep], valid_scores[keep, None], valid_cls_inds[keep, None]], 1
95
+ )
96
+ return dets
97
+
98
+
99
  def demo_postprocess(outputs, img_size, p6=False):
100
 
101
  grids = []