Spaces:
Runtime error
Runtime error
from .tracker.byte_tracker import BYTETracker | |
import cv2 | |
import numpy as np | |
class ByteTrack(object):
    """Couple an object detector with a ``BYTETracker`` for multi-object tracking.

    Each call to :meth:`inference` runs ``detector.detect`` on one frame,
    feeds the detections to the tracker and returns the surviving tracks.
    """

    def __init__(self, detector, min_box_area=10, frame_rate=30,
                 aspect_ratio_thresh=1.6):
        """Build the wrapper around ``detector``.

        Args:
            detector: object exposing ``detect(image, conf_thres=...,
                input_shape=..., classes=...)`` returning ``(dets, image_info)``
                and ``model.get_inputs()`` (looks like an ONNX Runtime
                session — TODO confirm).
            min_box_area: tracks whose ``w * h`` is not greater than this are
                dropped from the output.
            frame_rate: frame rate handed to ``BYTETracker`` (previously
                hard-coded to 30, which remains the default).
            aspect_ratio_thresh: tracks whose ``w / h`` exceeds this are
                dropped from the output (previously hard-coded to 1.6, which
                remains the default).
        """
        self.min_box_area = min_box_area
        self.aspect_ratio_thresh = aspect_ratio_thresh
        # Mean/std constants; not read anywhere in this class, kept as public
        # attributes in case callers rely on them.
        self.rgb_means = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)
        self.detector = detector
        # Trailing dims (shape[2:]) of the model's first input; presumably
        # (height, width) of an NCHW tensor — confirm against the model.
        self.input_shape = tuple(detector.model.get_inputs()[0].shape[2:])
        self.tracker = BYTETracker(frame_rate=frame_rate)

    def inference(self, image, conf_thresh=0.25, classes=None):
        """Detect objects in ``image`` and advance the tracker by one frame.

        Args:
            image: frame passed unchanged to ``detector.detect``.
            conf_thresh: confidence threshold, forwarded as ``conf_thres``.
            classes: optional class filter forwarded to the detector.

        Returns:
            ``(bboxes, ids, scores, class_ids)``. ``bboxes`` are the kept
            tracks' ``tlwh`` boxes; ``ids`` and ``scores`` are aligned with
            them. ``class_ids`` is the last column of the raw detections and
            is NOT aligned with the tracks (its length may differ). All four
            are empty lists when there are no detections.
        """
        bboxes, ids, scores, class_ids = [], [], [], []
        dets, image_info = self.detector.detect(
            image,
            conf_thres=conf_thresh,
            input_shape=self.input_shape,
            classes=classes,
        )
        # NOTE(review): on frames without detections the tracker is not
        # called at all, so it never sees this frame (presumably existing
        # tracks are not aged/lost). Confirm whether BYTETracker.update
        # accepts an empty array before changing this.
        if isinstance(dets, np.ndarray) and len(dets) > 0:
            class_ids = dets[:, -1].tolist()
            bboxes, ids, scores = self._tracker_update(dets, image_info)
        return bboxes, ids, scores, class_ids

    def get_id_color(self, index):
        """Map a track id to a deterministic 3-tuple of ints in [0, 254].

        Negative ids map to the same color as their absolute value.
        """
        temp_index = abs(int(index)) * 3
        return (
            (37 * temp_index) % 255,
            (17 * temp_index) % 255,
            (29 * temp_index) % 255,
        )

    def draw_tracking_info(
        self,
        image,
        tlwhs,
        ids,
        scores,
        frame_id=0,
        elapsed_time=0.,
    ):
        """Draw each track's box and id label onto ``image`` and return it.

        Args:
            image: image drawn on via ``cv2.rectangle`` / ``cv2.putText``.
            tlwhs: per-track ``(top, left, width, height)`` boxes.
            ids: track ids aligned with ``tlwhs``.
            scores: accepted for API compatibility; not used.
            frame_id: accepted for API compatibility; not used.
            elapsed_time: accepted for API compatibility; not used.

        Returns:
            The same ``image`` object.
        """
        text_scale = 1.5
        text_thickness = 2
        line_thickness = 2
        for tlwh, track_id in zip(tlwhs, ids):
            x1, y1 = int(tlwh[0]), int(tlwh[1])
            x2, y2 = x1 + int(tlwh[2]), y1 + int(tlwh[3])
            color = self.get_id_color(track_id)
            cv2.rectangle(image, (x1, y1), (x2, y2), color, line_thickness)
            text = str(track_id)
            # A thicker black pass first gives the white label an outline.
            cv2.putText(image, text, (x1, y1 - 5), cv2.FONT_HERSHEY_PLAIN,
                        text_scale, (0, 0, 0), text_thickness + 3)
            cv2.putText(image, text, (x1, y1 - 5), cv2.FONT_HERSHEY_PLAIN,
                        text_scale, (255, 255, 255), text_thickness)
        return image

    def _tracker_update(self, dets, image_info):
        """Feed detections to the tracker and filter the returned tracks.

        Args:
            dets: detection array whose last column (class id) is stripped
                before being passed to ``BYTETracker.update``, or ``None``.
            image_info: mapping with ``'height'`` and ``'width'`` keys.

        Returns:
            ``(tlwhs, ids, scores)`` lists for tracks with positive height,
            ``w * h > min_box_area`` and ``w / h <= aspect_ratio_thresh``.
        """
        online_targets = []
        if dets is not None:
            online_targets = self.tracker.update(
                dets[:, :-1],
                [image_info['height'], image_info['width']],
                [image_info['height'], image_info['width']],
            )

        online_tlwhs, online_ids, online_scores = [], [], []
        for target in online_targets:
            tlwh = target.tlwh
            width, height = tlwh[2], tlwh[3]
            # Degenerate boxes can never pass the area test; skipping them
            # here also avoids dividing by a zero height below.
            if height <= 0:
                continue
            too_wide = width / height > self.aspect_ratio_thresh
            if width * height > self.min_box_area and not too_wide:
                online_tlwhs.append(tlwh)
                online_ids.append(target.track_id)
                online_scores.append(target.score)
        return online_tlwhs, online_ids, online_scores