kisa-misa committed
Commit ddff15f
1 parent: de639ab

Create predict.py

Files changed (1):
  1. predict.py +271 -0

predict.py ADDED
@@ -0,0 +1,271 @@
import math
from collections import deque

import cv2
import hydra
import torch
from numpy import random

from ultralytics.yolo.engine.predictor import BasePredictor
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, ops
from ultralytics.yolo.utils.checks import check_imgsz
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box

from deep_sort_pytorch.utils.parser import get_config
from deep_sort_pytorch.deep_sort import DeepSort

palette = (2 ** 11 - 1, 2 ** 15 - 1, 2 ** 20 - 1)

cars_deque = {}        # per-track deque of recent bottom-edge center points
deepsort = None        # global DeepSORT instance, created by init_tracker()
object_counter = {}    # per-class count of tracked objects
speed_line_queue = {}  # per-track history of speed estimates


def estimatespeed(Location1, Location2):
    """Estimate object speed in km/h from two consecutive center points."""
    # Euclidean distance between the two points, in pixels
    d_pixel = math.sqrt(math.pow(Location2[0] - Location1[0], 2) + math.pow(Location2[1] - Location1[1], 2))
    # assumed pixels-per-meter calibration of the scene
    ppm = 8
    d_meters = d_pixel / ppm
    # speed = distance / time: assume ~15 fps, then 3.6 converts m/s to km/h
    time_constant = 15 * 3.6
    speed = d_meters * time_constant
    return int(speed)
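
# Worked example with the constants above: a center that moves 8 px between
# consecutive frames covers 8 / 8 = 1 m per frame, i.e. 15 m/s at the assumed
# 15 fps, reported as 15 * 3.6 = 54 km/h. Both ppm and the frame rate are
# scene-dependent and would need calibration for real measurements.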


def init_tracker():
    """Create the global DeepSORT tracker from its YAML config file."""
    global deepsort
    cfg_deep = get_config()
    cfg_deep.merge_from_file("deep_sort_pytorch/configs/deep_sort.yaml")

    deepsort = DeepSort(cfg_deep.DEEPSORT.REID_CKPT,
                        max_dist=cfg_deep.DEEPSORT.MAX_DIST, min_confidence=cfg_deep.DEEPSORT.MIN_CONFIDENCE,
                        nms_max_overlap=cfg_deep.DEEPSORT.NMS_MAX_OVERLAP, max_iou_distance=cfg_deep.DEEPSORT.MAX_IOU_DISTANCE,
                        max_age=cfg_deep.DEEPSORT.MAX_AGE, n_init=cfg_deep.DEEPSORT.N_INIT, nn_budget=cfg_deep.DEEPSORT.NN_BUDGET,
                        use_cuda=True)


def xyxy_to_xywh(*xyxy):
    """Convert absolute corner coordinates (x1, y1, x2, y2) to center format (x_c, y_c, w, h)."""
    bbox_left = min([xyxy[0].item(), xyxy[2].item()])
    bbox_top = min([xyxy[1].item(), xyxy[3].item()])
    bbox_w = abs(xyxy[0].item() - xyxy[2].item())
    bbox_h = abs(xyxy[1].item() - xyxy[3].item())
    x_c = bbox_left + bbox_w / 2
    y_c = bbox_top + bbox_h / 2
    return x_c, y_c, bbox_w, bbox_h
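
# For example, scalar tensors for corners (10, 20) and (50, 60) map to center
# (30.0, 40.0) with width 40.0 and height 40.0.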


def compute_color_for_labels(label):
    """Return a fixed color for common COCO classes and a hashed color otherwise."""
    if label == 0:    # person
        color = (85, 45, 255)
    elif label == 2:  # car
        color = (222, 82, 175)
    elif label == 3:  # motorbike
        color = (0, 204, 255)
    elif label == 5:  # bus
        color = (0, 149, 255)
    else:
        color = [int((p * (label ** 2 - label + 1)) % 255) for p in palette]
    return tuple(color)


def draw_border(img, pt1, pt2, color, thickness, r, d):
    """Draw a filled rounded rectangle, used as the label background."""
    x1, y1 = pt1
    x2, y2 = pt2
    # Top left
    cv2.line(img, (x1 + r, y1), (x1 + r + d, y1), color, thickness)
    cv2.line(img, (x1, y1 + r), (x1, y1 + r + d), color, thickness)
    cv2.ellipse(img, (x1 + r, y1 + r), (r, r), 180, 0, 90, color, thickness)
    # Top right
    cv2.line(img, (x2 - r, y1), (x2 - r - d, y1), color, thickness)
    cv2.line(img, (x2, y1 + r), (x2, y1 + r + d), color, thickness)
    cv2.ellipse(img, (x2 - r, y1 + r), (r, r), 270, 0, 90, color, thickness)
    # Bottom left
    cv2.line(img, (x1 + r, y2), (x1 + r + d, y2), color, thickness)
    cv2.line(img, (x1, y2 - r), (x1, y2 - r - d), color, thickness)
    cv2.ellipse(img, (x1 + r, y2 - r), (r, r), 90, 0, 90, color, thickness)
    # Bottom right
    cv2.line(img, (x2 - r, y2), (x2 - r - d, y2), color, thickness)
    cv2.line(img, (x2, y2 - r), (x2, y2 - r - d), color, thickness)
    cv2.ellipse(img, (x2 - r, y2 - r), (r, r), 0, 0, 90, color, thickness)

    cv2.rectangle(img, (x1 + r, y1), (x2 - r, y2), color, -1, cv2.LINE_AA)
    cv2.rectangle(img, (x1, y1 + r), (x2, y2 - r - d), color, -1, cv2.LINE_AA)

    cv2.circle(img, (x1 + r, y1 + r), 2, color, 12)
    cv2.circle(img, (x2 - r, y1 + r), 2, color, 12)
    cv2.circle(img, (x1 + r, y2 - r), 2, color, 12)
    cv2.circle(img, (x2 - r, y2 - r), 2, color, 12)

    return img


def UI_box(x, img, color=None, label=None, line_thickness=None):
    """Plot one bounding box on img, with a rounded label background."""
    tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
    color = color or [random.randint(0, 255) for _ in range(3)]
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
    cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
    if label:
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
        img = draw_border(img, (c1[0], c1[1] - t_size[1] - 3), (c1[0] + t_size[0], c1[1] + 3), color, 1, 8, 2)
        cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)


def ccw(A, B, C):
    """Return True if points A, B, C are in counter-clockwise order."""
    return (C[1] - A[1]) * (B[0] - A[0]) > (B[1] - A[1]) * (C[0] - A[0])
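

# ccw is the standard orientation test. It is not called anywhere in this file,
# but it is normally paired with a segment-intersection check such as the sketch
# below (an assumption about its intended use for line-crossing counting, not
# part of the original logic):
def intersect(A, B, C, D):
    """Return True if segments AB and CD intersect (general position)."""
    return ccw(A, C, D) != ccw(B, C, D) and ccw(A, B, C) != ccw(A, B, D)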


def draw_boxes(img, bbox, names, object_id, identities=None, offset=(0, 0)):
    """Draw tracked car boxes, labeled with track id and average speed."""
    cv2.putText(img, f'Number of cars: {len(cars_deque)}', (11, 35), 0, 1, [0, 255, 0], thickness=2, lineType=cv2.LINE_AA)

    # remove tracked points from the buffer if the object is lost
    current_ids = identities if identities is not None else []
    for key in list(cars_deque):
        if key not in current_ids:
            cars_deque.pop(key)

    for i, box in enumerate(bbox):
        obj_name = names[object_id[i]]
        if obj_name != 'car':
            continue
        x1, y1, x2, y2 = [int(coord) for coord in box]
        x1 += offset[0]
        x2 += offset[0]
        y1 += offset[1]
        y2 += offset[1]

        # center of the bottom edge of the box
        center = (int((x1 + x2) / 2), int(y2))

        # track id of the object
        track_id = int(identities[i]) if identities is not None else 0

        # create new buffers for a newly tracked object and count it
        if track_id not in cars_deque:
            cars_deque[track_id] = deque(maxlen=64)
            speed_line_queue[track_id] = []
            object_counter[obj_name] = object_counter.get(obj_name, 0) + 1

        color = compute_color_for_labels(object_id[i])
        label = f'{track_id}:{obj_name}'

        # add the new center to the buffer and estimate speed from the last two points
        cars_deque[track_id].appendleft(center)
        if len(cars_deque[track_id]) >= 2:
            object_speed = estimatespeed(cars_deque[track_id][1], cars_deque[track_id][0])
            speed_line_queue[track_id].append(object_speed)

        if speed_line_queue[track_id]:
            label += " " + str(sum(speed_line_queue[track_id]) // len(speed_line_queue[track_id])) + "km/h"

        UI_box(box, img, label=label, color=color, line_thickness=2)

    return img


class DetectionPredictor(BasePredictor):

    def get_annotator(self, img):
        return Annotator(img, line_width=self.args.line_thickness, example=str(self.model.names))

    def preprocess(self, img):
        img = torch.from_numpy(img).to(self.model.device)
        img = img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32
        img /= 255  # 0 - 255 to 0.0 - 1.0
        return img

    def postprocess(self, preds, img, orig_img):
        preds = ops.non_max_suppression(preds,
                                        self.args.conf,
                                        self.args.iou,
                                        agnostic=self.args.agnostic_nms,
                                        max_det=self.args.max_det)

        for i, pred in enumerate(preds):
            shape = orig_img[i].shape if self.webcam else orig_img.shape
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()

        return preds

    def write_results(self, idx, preds, batch):
        p, im, im0 = batch
        all_outputs = []
        log_string = ""
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        self.seen += 1
        im0 = im0.copy()
        if self.webcam:  # batch_size >= 1
            log_string += f'{idx}: '
            frame = self.dataset.count
        else:
            frame = getattr(self.dataset, 'frame', 0)

        self.data_path = p
        save_path = str(self.save_dir / p.name)  # im.jpg
        self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
        log_string += '%gx%g ' % im.shape[2:]  # print string
        self.annotator = self.get_annotator(im0)

        det = preds[idx]
        all_outputs.append(det)
        if len(det) == 0:
            return log_string
        for c in det[:, 5].unique():
            n = (det[:, 5] == c).sum()  # detections per class
            log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, "

        # collect detections in the (x_c, y_c, w, h) format DeepSORT expects
        xywh_bboxs = []
        confs = []
        oids = []
        for *xyxy, conf, cls in reversed(det):
            x_c, y_c, bbox_w, bbox_h = xyxy_to_xywh(*xyxy)
            xywh_bboxs.append([x_c, y_c, bbox_w, bbox_h])
            confs.append([conf.item()])
            oids.append(int(cls))
        xywhs = torch.Tensor(xywh_bboxs)
        confss = torch.Tensor(confs)

        # update the tracker and draw the tracked boxes on the original frame
        outputs = deepsort.update(xywhs, confss, oids, im0)
        if len(outputs) > 0:
            bbox_xyxy = outputs[:, :4]
            identities = outputs[:, -2]
            object_id = outputs[:, -1]
            draw_boxes(im0, bbox_xyxy, self.model.names, object_id, identities)

        return log_string


@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def predict(cfg):
    init_tracker()
    cfg.model = cfg.model or "yolov8n.pt"
    cfg.imgsz = check_imgsz(cfg.imgsz, min_dim=2)  # check image size
    cfg.source = cfg.source if cfg.source is not None else ROOT / "assets"
    predictor = DetectionPredictor(cfg)
    predictor()


if __name__ == "__main__":
    predict()
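
# Example invocation (a sketch: arguments use Hydra's key=value override syntax;
# the exact keys come from ultralytics' default config and may differ across
# versions; traffic.mp4 is a placeholder path):
#   python predict.py model=yolov8n.pt source=traffic.mp4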