Spaces:
Runtime error
Runtime error
Create predict.py
Browse files- predict.py +271 -0
predict.py
ADDED
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hydra
|
2 |
+
import torch
|
3 |
+
import argparse
|
4 |
+
import time
|
5 |
+
from pathlib import Path
|
6 |
+
import math
|
7 |
+
import cv2
|
8 |
+
import torch
|
9 |
+
import torch.backends.cudnn as cudnn
|
10 |
+
from numpy import random
|
11 |
+
from ultralytics.yolo.engine.predictor import BasePredictor
|
12 |
+
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, ops
|
13 |
+
from ultralytics.yolo.utils.checks import check_imgsz
|
14 |
+
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
|
15 |
+
|
16 |
+
import cv2
|
17 |
+
from deep_sort_pytorch.utils.parser import get_config
|
18 |
+
from deep_sort_pytorch.deep_sort import DeepSort
|
19 |
+
from collections import deque
|
20 |
+
import numpy as np
|
21 |
+
palette = (2 ** 11 - 1, 2 ** 15 - 1, 2 ** 20 - 1)
|
22 |
+
cars_deque = {}
|
23 |
+
|
24 |
+
|
25 |
+
deepsort = None
|
26 |
+
|
27 |
+
object_counter = {}
|
28 |
+
|
29 |
+
speed_line_queue = {}
|
30 |
+
def estimatespeed(Location1, Location2):
|
31 |
+
#Euclidean Distance Formula
|
32 |
+
d_pixel = math.sqrt(math.pow(Location2[0] - Location1[0], 2) + math.pow(Location2[1] - Location1[1], 2))
|
33 |
+
# defining thr pixels per meter
|
34 |
+
ppm = 8
|
35 |
+
d_meters = d_pixel/ppm
|
36 |
+
time_constant = 15*3.6
|
37 |
+
#distance = speed/time
|
38 |
+
speed = d_meters * time_constant
|
39 |
+
|
40 |
+
return int(speed)
|
41 |
+
def init_tracker():
|
42 |
+
global deepsort
|
43 |
+
cfg_deep = get_config()
|
44 |
+
cfg_deep.merge_from_file("deep_sort_pytorch/configs/deep_sort.yaml")
|
45 |
+
|
46 |
+
deepsort= DeepSort(cfg_deep.DEEPSORT.REID_CKPT,
|
47 |
+
max_dist=cfg_deep.DEEPSORT.MAX_DIST, min_confidence=cfg_deep.DEEPSORT.MIN_CONFIDENCE,
|
48 |
+
nms_max_overlap=cfg_deep.DEEPSORT.NMS_MAX_OVERLAP, max_iou_distance=cfg_deep.DEEPSORT.MAX_IOU_DISTANCE,
|
49 |
+
max_age=cfg_deep.DEEPSORT.MAX_AGE, n_init=cfg_deep.DEEPSORT.N_INIT, nn_budget=cfg_deep.DEEPSORT.NN_BUDGET,
|
50 |
+
use_cuda=True)
|
51 |
+
##########################################################################################
|
52 |
+
def xyxy_to_xywh(*xyxy):
|
53 |
+
"""" Calculates the relative bounding box from absolute pixel values. """
|
54 |
+
bbox_left = min([xyxy[0].item(), xyxy[2].item()])
|
55 |
+
bbox_top = min([xyxy[1].item(), xyxy[3].item()])
|
56 |
+
bbox_w = abs(xyxy[0].item() - xyxy[2].item())
|
57 |
+
bbox_h = abs(xyxy[1].item() - xyxy[3].item())
|
58 |
+
x_c = (bbox_left + bbox_w / 2)
|
59 |
+
y_c = (bbox_top + bbox_h / 2)
|
60 |
+
w = bbox_w
|
61 |
+
h = bbox_h
|
62 |
+
return x_c, y_c, w, h
|
63 |
+
|
64 |
+
|
65 |
+
def compute_color_for_labels(label):
|
66 |
+
"""
|
67 |
+
Simple function that adds fixed color depending on the class
|
68 |
+
"""
|
69 |
+
if label == 0: #person
|
70 |
+
color = (85,45,255)
|
71 |
+
elif label == 2: # Car
|
72 |
+
color = (222,82,175)
|
73 |
+
elif label == 3: # Motobike
|
74 |
+
color = (0, 204, 255)
|
75 |
+
elif label == 5: # Bus
|
76 |
+
color = (0, 149, 255)
|
77 |
+
else:
|
78 |
+
color = [int((p * (label ** 2 - label + 1)) % 255) for p in palette]
|
79 |
+
return tuple(color)
|
80 |
+
|
81 |
+
def draw_border(img, pt1, pt2, color, thickness, r, d):
|
82 |
+
x1,y1 = pt1
|
83 |
+
x2,y2 = pt2
|
84 |
+
# Top left
|
85 |
+
cv2.line(img, (x1 + r, y1), (x1 + r + d, y1), color, thickness)
|
86 |
+
cv2.line(img, (x1, y1 + r), (x1, y1 + r + d), color, thickness)
|
87 |
+
cv2.ellipse(img, (x1 + r, y1 + r), (r, r), 180, 0, 90, color, thickness)
|
88 |
+
# Top right
|
89 |
+
cv2.line(img, (x2 - r, y1), (x2 - r - d, y1), color, thickness)
|
90 |
+
cv2.line(img, (x2, y1 + r), (x2, y1 + r + d), color, thickness)
|
91 |
+
cv2.ellipse(img, (x2 - r, y1 + r), (r, r), 270, 0, 90, color, thickness)
|
92 |
+
# Bottom left
|
93 |
+
cv2.line(img, (x1 + r, y2), (x1 + r + d, y2), color, thickness)
|
94 |
+
cv2.line(img, (x1, y2 - r), (x1, y2 - r - d), color, thickness)
|
95 |
+
cv2.ellipse(img, (x1 + r, y2 - r), (r, r), 90, 0, 90, color, thickness)
|
96 |
+
# Bottom right
|
97 |
+
cv2.line(img, (x2 - r, y2), (x2 - r - d, y2), color, thickness)
|
98 |
+
cv2.line(img, (x2, y2 - r), (x2, y2 - r - d), color, thickness)
|
99 |
+
cv2.ellipse(img, (x2 - r, y2 - r), (r, r), 0, 0, 90, color, thickness)
|
100 |
+
|
101 |
+
cv2.rectangle(img, (x1 + r, y1), (x2 - r, y2), color, -1, cv2.LINE_AA)
|
102 |
+
cv2.rectangle(img, (x1, y1 + r), (x2, y2 - r - d), color, -1, cv2.LINE_AA)
|
103 |
+
|
104 |
+
cv2.circle(img, (x1 +r, y1+r), 2, color, 12)
|
105 |
+
cv2.circle(img, (x2 -r, y1+r), 2, color, 12)
|
106 |
+
cv2.circle(img, (x1 +r, y2-r), 2, color, 12)
|
107 |
+
cv2.circle(img, (x2 -r, y2-r), 2, color, 12)
|
108 |
+
|
109 |
+
return img
|
110 |
+
|
111 |
+
def UI_box(x, img, color=None, label=None, line_thickness=None):
|
112 |
+
# Plots one bounding box on image img
|
113 |
+
tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
|
114 |
+
color = color or [random.randint(0, 255) for _ in range(3)]
|
115 |
+
c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
|
116 |
+
cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
|
117 |
+
if label:
|
118 |
+
tf = max(tl - 1, 1) # font thickness
|
119 |
+
t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
|
120 |
+
|
121 |
+
img = draw_border(img, (c1[0], c1[1] - t_size[1] -3), (c1[0] + t_size[0], c1[1]+3), color, 1, 8, 2)
|
122 |
+
|
123 |
+
cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
|
124 |
+
|
125 |
+
|
126 |
+
def ccw(A,B,C):
|
127 |
+
return (C[1]-A[1]) * (B[0]-A[0]) > (B[1]-A[1]) * (C[0]-A[0])
|
128 |
+
|
129 |
+
|
130 |
+
def draw_boxes(img, bbox, names,object_id, identities=None, offset=(0, 0)):
|
131 |
+
#cv2.line(img, line[0], line[1], (46,162,112), 3)
|
132 |
+
cv2.putText(img, f'Number of cars: {len(cars_deque)}', (11, 35), 0, 1, [0, 255, 0], thickness=2, lineType=cv2.LINE_AA)
|
133 |
+
height, width, _ = img.shape
|
134 |
+
# remove tracked point from buffer if object is lost
|
135 |
+
for key in list(cars_deque):
|
136 |
+
if key not in identities:
|
137 |
+
cars_deque.pop(key)
|
138 |
+
|
139 |
+
for i, box in enumerate(bbox):
|
140 |
+
obj_name = names[object_id[i]]
|
141 |
+
if obj_name == 'car':
|
142 |
+
x1, y1, x2, y2 = [int(i) for i in box]
|
143 |
+
x1 += offset[0]
|
144 |
+
x2 += offset[0]
|
145 |
+
y1 += offset[1]
|
146 |
+
y2 += offset[1]
|
147 |
+
|
148 |
+
# code to find center of bottom edge
|
149 |
+
center = (int((x2+x1)/ 2), int((y2+y2)/2))
|
150 |
+
|
151 |
+
# get ID of object
|
152 |
+
id = int(identities[i]) if identities is not None else 0
|
153 |
+
|
154 |
+
# create new buffer for new object
|
155 |
+
if id not in cars_deque:
|
156 |
+
cars_deque[id] = deque(maxlen= 64)
|
157 |
+
speed_line_queue[id] = []
|
158 |
+
color = compute_color_for_labels(object_id[i])
|
159 |
+
|
160 |
+
|
161 |
+
label = '{}{:d}'.format("", id) + ":"+ '%s' % (obj_name)
|
162 |
+
|
163 |
+
|
164 |
+
# add center to buffer
|
165 |
+
cars_deque[id].appendleft(center)
|
166 |
+
if len(cars_deque[id]) >= 2:
|
167 |
+
object_speed = estimatespeed(cars_deque[id][1], cars_deque[id][0])
|
168 |
+
speed_line_queue[id].append(object_speed)
|
169 |
+
if obj_name not in object_counter:
|
170 |
+
object_counter[obj_name] = 1
|
171 |
+
|
172 |
+
|
173 |
+
try:
|
174 |
+
label = label + " " + str(sum(speed_line_queue[id])//len(speed_line_queue[id])) + "km/h"
|
175 |
+
except:
|
176 |
+
pass
|
177 |
+
UI_box(box, img, label=label, color=color, line_thickness=2)
|
178 |
+
|
179 |
+
|
180 |
+
return img
|
181 |
+
|
182 |
+
|
183 |
+
class DetectionPredictor(BasePredictor):
|
184 |
+
|
185 |
+
def get_annotator(self, img):
|
186 |
+
return Annotator(img, line_width=self.args.line_thickness, example=str(self.model.names))
|
187 |
+
|
188 |
+
def preprocess(self, img):
|
189 |
+
img = torch.from_numpy(img).to(self.model.device)
|
190 |
+
img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
|
191 |
+
img /= 255 # 0 - 255 to 0.0 - 1.0
|
192 |
+
return img
|
193 |
+
|
194 |
+
def postprocess(self, preds, img, orig_img):
|
195 |
+
preds = ops.non_max_suppression(preds,
|
196 |
+
self.args.conf,
|
197 |
+
self.args.iou,
|
198 |
+
agnostic=self.args.agnostic_nms,
|
199 |
+
max_det=self.args.max_det)
|
200 |
+
|
201 |
+
for i, pred in enumerate(preds):
|
202 |
+
shape = orig_img[i].shape if self.webcam else orig_img.shape
|
203 |
+
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
|
204 |
+
|
205 |
+
return preds
|
206 |
+
|
207 |
+
def write_results(self, idx, preds, batch):
|
208 |
+
p, im, im0 = batch
|
209 |
+
all_outputs = []
|
210 |
+
log_string = ""
|
211 |
+
if len(im.shape) == 3:
|
212 |
+
im = im[None] # expand for batch dim
|
213 |
+
self.seen += 1
|
214 |
+
im0 = im0.copy()
|
215 |
+
if self.webcam: # batch_size >= 1
|
216 |
+
log_string += f'{idx}: '
|
217 |
+
frame = self.dataset.count
|
218 |
+
else:
|
219 |
+
frame = getattr(self.dataset, 'frame', 0)
|
220 |
+
|
221 |
+
self.data_path = p
|
222 |
+
save_path = str(self.save_dir / p.name) # im.jpg
|
223 |
+
self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
|
224 |
+
log_string += '%gx%g ' % im.shape[2:] # print string
|
225 |
+
self.annotator = self.get_annotator(im0)
|
226 |
+
|
227 |
+
det = preds[idx]
|
228 |
+
all_outputs.append(det)
|
229 |
+
if len(det) == 0:
|
230 |
+
return log_string
|
231 |
+
for c in det[:, 5].unique():
|
232 |
+
n = (det[:, 5] == c).sum() # detections per class
|
233 |
+
log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, "
|
234 |
+
# write
|
235 |
+
gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
|
236 |
+
xywh_bboxs = []
|
237 |
+
confs = []
|
238 |
+
oids = []
|
239 |
+
outputs = []
|
240 |
+
for *xyxy, conf, cls in reversed(det):
|
241 |
+
x_c, y_c, bbox_w, bbox_h = xyxy_to_xywh(*xyxy)
|
242 |
+
xywh_obj = [x_c, y_c, bbox_w, bbox_h]
|
243 |
+
xywh_bboxs.append(xywh_obj)
|
244 |
+
confs.append([conf.item()])
|
245 |
+
oids.append(int(cls))
|
246 |
+
xywhs = torch.Tensor(xywh_bboxs)
|
247 |
+
confss = torch.Tensor(confs)
|
248 |
+
|
249 |
+
outputs = deepsort.update(xywhs, confss, oids, im0)
|
250 |
+
if len(outputs) > 0:
|
251 |
+
bbox_xyxy = outputs[:, :4]
|
252 |
+
identities = outputs[:, -2]
|
253 |
+
object_id = outputs[:, -1]
|
254 |
+
|
255 |
+
draw_boxes(im0, bbox_xyxy, self.model.names, object_id,identities)
|
256 |
+
|
257 |
+
return log_string
|
258 |
+
|
259 |
+
|
260 |
+
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
261 |
+
def predict(cfg):
|
262 |
+
init_tracker()
|
263 |
+
cfg.model = cfg.model or "yolov8n.pt"
|
264 |
+
cfg.imgsz = check_imgsz(cfg.imgsz, min_dim=2) # check image size
|
265 |
+
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets"
|
266 |
+
predictor = DetectionPredictor(cfg)
|
267 |
+
predictor()
|
268 |
+
|
269 |
+
|
270 |
+
if __name__ == "__main__":
|
271 |
+
predict()
|