Ahsen Khaliq committed on
Commit
80df0b9
1 Parent(s): 5b9f07d

Upload onnx_inference.py

Files changed (1)
  1. onnx_inference.py +160 -0
onnx_inference.py ADDED
@@ -0,0 +1,160 @@
import argparse
import os

import cv2
import numpy as np
from loguru import logger

import onnxruntime

from yolox.data.data_augment import preproc as preprocess
from yolox.utils import multiclass_nms, demo_postprocess
from yolox.utils.visualize import plot_tracking
from yolox.tracker.byte_tracker import BYTETracker
from yolox.tracking_utils.timer import Timer


def make_parser():
    parser = argparse.ArgumentParser("onnxruntime inference sample")
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default="bytetrack_s.onnx",
        help="Input your onnx model.",
    )
    parser.add_argument(
        "-i",
        "--video_path",
        type=str,
        default='../../videos/palace.mp4',
        help="Path to your input video.",
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        type=str,
        default='demo_output',
        help="Path to your output directory.",
    )
    parser.add_argument(
        "-s",
        "--score_thr",
        type=float,
        default=0.1,
        help="Score threshold to filter the result.",
    )
    parser.add_argument(
        "-n",
        "--nms_thr",
        type=float,
        default=0.7,
        help="NMS threshold.",
    )
    parser.add_argument(
        "--input_shape",
        type=str,
        default="608,1088",
        help="Specify an input shape for inference.",
    )
    parser.add_argument(
        "--with_p6",
        action="store_true",
        help="Whether your model uses p6 in FPN/PAN.",
    )
    # tracking args
    parser.add_argument("--track_thresh", type=float, default=0.5, help="tracking confidence threshold")
    parser.add_argument("--track_buffer", type=int, default=30, help="frames to keep lost tracks alive")
    parser.add_argument("--match_thresh", type=float, default=0.8, help="matching threshold for tracking")
    parser.add_argument('--min-box-area', type=float, default=10, help='filter out tiny boxes')
    parser.add_argument("--mot20", dest="mot20", default=False, action="store_true", help="test mot20.")
    return parser
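
def make_tracker_args(track_thresh=0.5, track_buffer=30, match_thresh=0.8, mot20=False):
    # Hedged convenience, not part of the original script: it assumes BYTETracker
    # only consults track_thresh, track_buffer, match_thresh, and mot20 on the
    # args object, so a plain namespace can stand in for parsed CLI args when
    # this module is driven programmatically.
    from types import SimpleNamespace
    return SimpleNamespace(track_thresh=track_thresh, track_buffer=track_buffer,
                           match_thresh=match_thresh, mot20=mot20)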

class Predictor(object):
    def __init__(self, args):
        # ImageNet mean/std used by the YOLOX preprocessing pipeline.
        self.rgb_means = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)
        self.args = args
        self.session = onnxruntime.InferenceSession(args.model)
        self.input_shape = tuple(map(int, args.input_shape.split(',')))

    def inference(self, ori_img, timer):
        img_info = {"id": 0}
        height, width = ori_img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        img_info["raw_img"] = ori_img

        # Resize, pad, and normalize; `ratio` maps network coordinates back to
        # the original image scale.
        img, ratio = preprocess(ori_img, self.input_shape, self.rgb_means, self.std)
        img_info["ratio"] = ratio
        ort_inputs = {self.session.get_inputs()[0].name: img[None, :, :, :]}
        timer.tic()
        output = self.session.run(None, ort_inputs)
        # Decode the raw grid outputs into (cx, cy, w, h, obj_conf, cls_conf, ...) rows.
        predictions = demo_postprocess(output[0], self.input_shape, p6=self.args.with_p6)[0]

        boxes = predictions[:, :4]
        scores = predictions[:, 4:5] * predictions[:, 5:]

        # Convert center-format boxes to corner format (x1, y1, x2, y2) and
        # rescale to the original image.
        boxes_xyxy = np.ones_like(boxes)
        boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.
        boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.
        boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.
        boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.
        boxes_xyxy /= ratio
        dets = multiclass_nms(boxes_xyxy, scores, nms_thr=self.args.nms_thr, score_thr=self.args.score_thr)
        if dets is None:
            # Nothing survived NMS; return an empty (0, 5) array so the tracker
            # still receives valid input instead of crashing on None.
            return np.empty((0, 5)), img_info
        return dets[:, :-1], img_info
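
def make_session(model_path):
    # Hedged sketch, not called by the original script: newer onnxruntime-gpu
    # builds (1.9+) require an explicit execution-provider list when creating a
    # session. Providers are tried left to right, so CPU serves as the fallback.
    return onnxruntime.InferenceSession(
        model_path,
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )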

def imageflow_demo(predictor, args):
    cap = cv2.VideoCapture(args.video_path)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
    fps = cap.get(cv2.CAP_PROP_FPS)
    save_folder = args.output_dir
    os.makedirs(save_folder, exist_ok=True)
    save_path = os.path.join(save_folder, os.path.basename(args.video_path))
    logger.info(f"video save_path is {save_path}")
    vid_writer = cv2.VideoWriter(
        save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
    )
    tracker = BYTETracker(args, frame_rate=30)
    timer = Timer()
    frame_id = 0
    results = []
    while True:
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))
        ret_val, frame = cap.read()
        if not ret_val:
            break
        outputs, img_info = predictor.inference(frame, timer)
        # Associate this frame's detections with existing tracks.
        online_targets = tracker.update(outputs, [img_info['height'], img_info['width']], [img_info['height'], img_info['width']])
        online_tlwhs = []
        online_ids = []
        online_scores = []
        for t in online_targets:
            tlwh = t.tlwh
            tid = t.track_id
            # Drop tiny boxes and boxes much wider than tall, which are unlikely
            # to be pedestrians.
            vertical = tlwh[2] / tlwh[3] > 1.6
            if tlwh[2] * tlwh[3] > args.min_box_area and not vertical:
                online_tlwhs.append(tlwh)
                online_ids.append(tid)
                online_scores.append(t.score)
        timer.toc()
        # Kept in memory only; see the write_results sketch below for dumping them.
        results.append((frame_id + 1, online_tlwhs, online_ids, online_scores))
        online_im = plot_tracking(img_info['raw_img'], online_tlwhs, online_ids, frame_id=frame_id + 1,
                                  fps=1. / timer.average_time)
        vid_writer.write(online_im)
        # Only effective if a display window is added via cv2.imshow; kept so the
        # loop can be interrupted in that case.
        ch = cv2.waitKey(1)
        if ch == 27 or ch == ord("q") or ch == ord("Q"):
            break
        frame_id += 1
    cap.release()
    vid_writer.release()
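
def write_results(filename, results):
    # Hedged sketch, not called by the original script: dumps the accumulated
    # `results` in MOT-challenge text format, one row per track per frame:
    # frame,id,x,y,w,h,score,-1,-1,-1.
    with open(filename, "w") as f:
        for frame_id, tlwhs, track_ids, scores in results:
            for tlwh, tid, score in zip(tlwhs, track_ids, scores):
                x, y, w, h = tlwh
                f.write(f"{frame_id},{tid},{x:.2f},{y:.2f},{w:.2f},{h:.2f},{score:.2f},-1,-1,-1\n")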

if __name__ == '__main__':
    args = make_parser().parse_args()

    predictor = Predictor(args)
    imageflow_demo(predictor, args)
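
# Example invocation with the defaults above (model and video paths assume the
# ByteTrack repo layout; adjust to your checkout):
#     python3 onnx_inference.py -m bytetrack_s.onnx -i ../../videos/palace.mp4 -o demo_output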