SkalskiP commited on
Commit
bcf1bcb
1 Parent(s): 56994b0

replace MMDetection Visualizer with Supervision Annotators

Browse files
Files changed (2) hide show
  1. requirements.txt +1 -1
  2. tools/demo.py +25 -16
requirements.txt CHANGED
@@ -9,7 +9,7 @@ addict
9
  yapf
10
  numpy
11
  opencv-python
12
- supervision==0.6.0
13
  ftfy
14
  regex
15
  pot
 
9
  yapf
10
  numpy
11
  opencv-python
12
+ supervision==0.18.0
13
  ftfy
14
  regex
15
  pot
tools/demo.py CHANGED
@@ -1,5 +1,6 @@
1
  # Copyright (c) Tencent Inc. All rights reserved.
2
  import os
 
3
  import argparse
4
  import os.path as osp
5
  from functools import partial
@@ -11,6 +12,7 @@ import onnxsim
11
  import torch
12
  import gradio as gr
13
  import numpy as np
 
14
  from PIL import Image
15
  from torchvision.ops import nms
16
  from mmengine.config import Config, ConfigDict, DictAction
@@ -23,6 +25,8 @@ from mmyolo.registry import RUNNERS
23
 
24
  from yolo_world.easydeploy.model import DeployModel, MMYOLOBackend
25
 
 
 
26
 
27
  def parse_args():
28
  parser = argparse.ArgumentParser(
@@ -65,27 +69,32 @@ def run_image(runner,
65
  output = runner.model.test_step(data_batch)[0]
66
  pred_instances = output.pred_instances
67
 
68
- keep_idxs = nms(pred_instances.bboxes,
69
- pred_instances.scores,
70
- iou_threshold=nms_thr)
71
- pred_instances = pred_instances[keep_idxs]
72
- pred_instances = pred_instances[
73
- pred_instances.scores.float() > score_thr]
74
  if len(pred_instances.scores) > max_num_boxes:
75
  indices = pred_instances.scores.float().topk(max_num_boxes)[1]
76
  pred_instances = pred_instances[indices]
77
- output.pred_instances = pred_instances
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  image = np.array(image)
80
- visualizer = DetLocalVisualizer()
81
- visualizer.dataset_meta['classes'] = [t[0] for t in texts]
82
- visualizer.add_datasample('image',
83
- np.array(image),
84
- output,
85
- draw_gt=False,
86
- out_file=image_path,
87
- pred_score_thr=score_thr)
88
- image = Image.open(image_path)
89
  return image
90
 
91
 
 
1
  # Copyright (c) Tencent Inc. All rights reserved.
2
  import os
3
+ import cv2
4
  import argparse
5
  import os.path as osp
6
  from functools import partial
 
12
  import torch
13
  import gradio as gr
14
  import numpy as np
15
+ import supervision as sv
16
  from PIL import Image
17
  from torchvision.ops import nms
18
  from mmengine.config import Config, ConfigDict, DictAction
 
25
 
26
  from yolo_world.easydeploy.model import DeployModel, MMYOLOBackend
27
 
28
+ BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
29
+ LABEL_ANNOTATOR = sv.LabelAnnotator(text_color=sv.Color.BLACK)
30
 
31
  def parse_args():
32
  parser = argparse.ArgumentParser(
 
69
  output = runner.model.test_step(data_batch)[0]
70
  pred_instances = output.pred_instances
71
 
72
+ keep = nms(pred_instances.bboxes, pred_instances.scores, iou_threshold=nms_thr)
73
+ pred_instances = pred_instances[keep]
74
+ pred_instances = pred_instances[pred_instances.scores.float() > score_thr]
75
+
 
 
76
  if len(pred_instances.scores) > max_num_boxes:
77
  indices = pred_instances.scores.float().topk(max_num_boxes)[1]
78
  pred_instances = pred_instances[indices]
79
+
80
+ pred_instances = pred_instances.cpu().numpy()
81
+ detections = sv.Detections(
82
+ xyxy=pred_instances['bboxes'],
83
+ class_id=pred_instances['labels'],
84
+ confidence=pred_instances['scores']
85
+ )
86
+ labels = [
87
+ f"{texts[class_id][0]} {confidence:0.2f}"
88
+ for class_id, confidence
89
+ in zip(detections.class_id, detections.confidence)
90
+ ]
91
 
92
  image = np.array(image)
93
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
94
+ image = BOUNDING_BOX_ANNOTATOR.annotate(image, detections)
95
+ image = LABEL_ANNOTATOR.annotate(image, detections, labels=labels)
96
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
97
+ image = Image.fromarray(image)
 
 
 
 
98
  return image
99
 
100