Zengyf-CVer commited on
Commit
bfac58a
1 Parent(s): e021555

v04 add colot

Browse files
Files changed (1) hide show
  1. app.py +25 -5
app.py CHANGED
@@ -160,23 +160,25 @@ def export_json(results, img_size):
160
 
161
 
162
  # frame conversion
163
- def pil_draw(img, countdown_msg, textFont, xyxy, font_size, opt):
164
 
165
  img_pil = ImageDraw.Draw(img)
166
 
167
- img_pil.rectangle(xyxy, fill=None, outline="green") # bounding box
168
 
169
  if "label" in opt:
170
  text_w, text_h = textFont.getsize(countdown_msg) # Label size
 
171
  img_pil.rectangle(
172
  (xyxy[0], xyxy[1], xyxy[0] + text_w, xyxy[1] + text_h),
173
- fill="green",
174
- outline="green",
175
  ) # label background
 
176
  img_pil.multiline_text(
177
  (xyxy[0], xyxy[1]),
178
  countdown_msg,
179
- fill=(205, 250, 255),
180
  font=textFont,
181
  align="center",
182
  )
@@ -184,6 +186,16 @@ def pil_draw(img, countdown_msg, textFont, xyxy, font_size, opt):
184
  return img
185
 
186
 
 
 
 
 
 
 
 
 
 
 
187
  # YOLOv5 image detection function
188
  def yolo_det_img(img, device, model_name, infer_size, conf, iou, max_num, model_cls, opt):
189
 
@@ -210,6 +222,8 @@ def yolo_det_img(img, device, model_name, infer_size, conf, iou, max_num, model_
210
  model.max_det = int(max_num) # Maximum number of detection frames
211
  model.classes = model_cls # model classes
212
 
 
 
213
  img_size = img.size # frame size
214
 
215
  results = model(img, size=infer_size) # detection
@@ -260,6 +274,8 @@ def yolo_det_img(img, device, model_name, infer_size, conf, iou, max_num, model_
260
  [x0, y0, x1, y1],
261
  FONTSIZE,
262
  opt,
 
 
263
  )
264
 
265
  # ----------add object size----------
@@ -332,6 +348,8 @@ def yolo_det_video(video, device, model_name, infer_size, conf, iou, max_num, mo
332
  model.max_det = int(max_num) # Maximum number of detection frames
333
  model.classes = model_cls # model classes
334
 
 
 
335
  # ----------------Load fonts----------------
336
  yaml_index = cls_name.index(".yaml")
337
  cls_name_lang = cls_name[yaml_index - 2:yaml_index]
@@ -393,6 +411,8 @@ def yolo_det_video(video, device, model_name, infer_size, conf, iou, max_num, mo
393
  [x0, y0, x1, y1],
394
  FONTSIZE,
395
  opt,
 
 
396
  )
397
 
398
  frame = cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR)
 
160
 
161
 
162
  # frame conversion
163
+ def pil_draw(img, countdown_msg, textFont, xyxy, font_size, opt, obj_cls_index, color_list):
164
 
165
  img_pil = ImageDraw.Draw(img)
166
 
167
+ img_pil.rectangle(xyxy, fill=None, outline=color_list[obj_cls_index]) # bounding box
168
 
169
  if "label" in opt:
170
  text_w, text_h = textFont.getsize(countdown_msg) # Label size
171
+
172
  img_pil.rectangle(
173
  (xyxy[0], xyxy[1], xyxy[0] + text_w, xyxy[1] + text_h),
174
+ fill=color_list[obj_cls_index],
175
+ outline=color_list[obj_cls_index],
176
  ) # label background
177
+
178
  img_pil.multiline_text(
179
  (xyxy[0], xyxy[1]),
180
  countdown_msg,
181
+ fill=(255, 255, 255),
182
  font=textFont,
183
  align="center",
184
  )
 
186
  return img
187
 
188
 
189
+ def color_set(cls_num):
190
+ color_list = []
191
+ for i in range(cls_num):
192
+ color = tuple(np.random.choice(range(256), size=3))
193
+ # color = ["#"+''.join([random.choice('0123456789ABCDEF') for j in range(6)])]
194
+ color_list.append(color)
195
+
196
+ return color_list
197
+
198
+
199
  # YOLOv5 image detection function
200
  def yolo_det_img(img, device, model_name, infer_size, conf, iou, max_num, model_cls, opt):
201
 
 
222
  model.max_det = int(max_num) # Maximum number of detection frames
223
  model.classes = model_cls # model classes
224
 
225
+ color_list = color_set(len(model_cls_name_cp)) # 设置颜色
226
+
227
  img_size = img.size # frame size
228
 
229
  results = model(img, size=infer_size) # detection
 
274
  [x0, y0, x1, y1],
275
  FONTSIZE,
276
  opt,
277
+ obj_cls_index,
278
+ color_list,
279
  )
280
 
281
  # ----------add object size----------
 
348
  model.max_det = int(max_num) # Maximum number of detection frames
349
  model.classes = model_cls # model classes
350
 
351
+ color_list = color_set(len(model_cls_name_cp)) # 设置颜色
352
+
353
  # ----------------Load fonts----------------
354
  yaml_index = cls_name.index(".yaml")
355
  cls_name_lang = cls_name[yaml_index - 2:yaml_index]
 
411
  [x0, y0, x1, y1],
412
  FONTSIZE,
413
  opt,
414
+ obj_cls_index,
415
+ color_list,
416
  )
417
 
418
  frame = cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR)