glenn-jocher committed
Commit 57812df
1 Parent(s): 4200674

New Colors() class (#2963)

Files changed (3)
  1. detect.py +3 -6
  2. models/common.py +2 -3
  3. utils/plots.py +16 -8
detect.py CHANGED
@@ -11,7 +11,7 @@ from models.experimental import attempt_load
 from utils.datasets import LoadStreams, LoadImages
 from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
     scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
-from utils.plots import plot_one_box
+from utils.plots import colors, plot_one_box
 from utils.torch_utils import select_device, load_classifier, time_synchronized
 
 
@@ -34,6 +34,7 @@ def detect(opt):
     model = attempt_load(weights, map_location=device)  # load FP32 model
     stride = int(model.stride.max())  # model stride
     imgsz = check_img_size(imgsz, s=stride)  # check img_size
+    names = model.module.names if hasattr(model, 'module') else model.names  # get class names
     if half:
         model.half()  # to FP16
 
@@ -52,10 +53,6 @@ def detect(opt):
     else:
         dataset = LoadImages(source, img_size=imgsz, stride=stride)
 
-    # Get names and colors
-    names = model.module.names if hasattr(model, 'module') else model.names
-    colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]
-
     # Run inference
     if device.type != 'cpu':
         model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # run once
@@ -112,7 +109,7 @@ def detect(opt):
                         c = int(cls)  # integer class
                         label = None if opt.hide_labels else (names[c] if opt.hide_conf else f'{names[c]} {conf:.2f}')
 
-                        plot_one_box(xyxy, im0, label=label, color=colors[c], line_thickness=opt.line_thickness)
+                        plot_one_box(xyxy, im0, label=label, color=colors(c, True), line_thickness=opt.line_thickness)
                         if opt.save_crop:
                             save_one_box(xyxy, im0s, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
 
 
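In detect.py the per-run list of random RGB colors is dropped in favor of the shared palette instance, and the new call site passes bgr=True because plot_one_box draws with OpenCV on a BGR image. A minimal sketch of the new behavior (the class index 0 and the commented call are illustrative only):

    from utils.plots import colors, plot_one_box  # import added by this commit

    c = 0            # example integer class index
    colors(c)        # fixed RGB tuple from the palette, e.g. (31, 119, 180) for the first Tableau color
    colors(c, True)  # same color in BGR order, as passed to plot_one_box for OpenCV drawing
    # plot_one_box(xyxy, im0, label=label, color=colors(c, True), line_thickness=3)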
models/common.py CHANGED
@@ -14,7 +14,7 @@ from torch.cuda import amp
 
 from utils.datasets import letterbox
 from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh, save_one_box
-from utils.plots import color_list, plot_one_box
+from utils.plots import colors, plot_one_box
 from utils.torch_utils import time_synchronized
 
 
@@ -312,7 +312,6 @@ class Detections:
         self.s = shape  # inference BCHW shape
 
     def display(self, pprint=False, show=False, save=False, crop=False, render=False, save_dir=Path('')):
-        colors = color_list()
         for i, (im, pred) in enumerate(zip(self.imgs, self.pred)):
             str = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} '
             if pred is not None:
@@ -325,7 +324,7 @@
                         if crop:
                             save_one_box(box, im, file=save_dir / 'crops' / self.names[int(cls)] / self.files[i])
                         else:  # all others
-                            plot_one_box(box, im, label=label, color=colors[int(cls) % 10])
+                            plot_one_box(box, im, label=label, color=colors(cls))
 
             im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im  # from np
             if pprint:
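Note that in Detections.display the cls value comes straight from the prediction tensor, which is why the old code indexed with int(cls) % 10; Colors.__call__ casts internally with int(i) % self.n, so colors(cls) can take the raw value. A small illustration, assuming the repo's utils.plots after this commit:

    import torch
    from utils.plots import colors

    cls = torch.tensor(2.)           # class id as it appears in Detections predictions
    assert colors(cls) == colors(2)  # int(i) % n inside __call__ handles float/tensor input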
utils/plots.py CHANGED
@@ -26,12 +26,22 @@ matplotlib.rc('font', **{'size': 11})
 matplotlib.use('Agg')  # for writing to files only
 
 
-def color_list():
-    # Return first 10 plt colors as (r,g,b) https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb
-    def hex2rgb(h):
+class Colors:
+    # Ultralytics color palette https://ultralytics.com/
+    def __init__(self):
+        self.palette = [self.hex2rgb(c) for c in matplotlib.colors.TABLEAU_COLORS.values()]
+        self.n = len(self.palette)
+
+    def __call__(self, i, bgr=False):
+        c = self.palette[int(i) % self.n]
+        return (c[2], c[1], c[0]) if bgr else c
+
+    @staticmethod
+    def hex2rgb(h):  # rgb order (PIL)
         return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
 
-    return [hex2rgb(h) for h in matplotlib.colors.TABLEAU_COLORS.values()]  # or BASE_ (8), CSS4_ (148), XKCD_ (949)
+
+colors = Colors()  # create instance for 'from utils.plots import colors'
 
 
 def hist2d(x, y, n=100):
@@ -137,7 +147,6 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max
     h = math.ceil(scale_factor * h)
     w = math.ceil(scale_factor * w)
 
-    colors = color_list()  # list of colors
     mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
     for i, img in enumerate(images):
         if i == max_subplots:  # if last batch has fewer images than we expect
@@ -168,7 +177,7 @@
             boxes[[1, 3]] += block_y
             for j, box in enumerate(boxes.T):
                 cls = int(classes[j])
-                color = colors[cls % len(colors)]
+                color = colors(cls)
                 cls = names[cls] if names else cls
                 if labels or conf[j] > 0.25:  # 0.25 conf thresh
                     label = '%s' % cls if labels else '%s %.1f' % (cls, conf[j])
@@ -276,7 +285,6 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
     print('Plotting labels... ')
     c, b = labels[:, 0], labels[:, 1:].transpose()  # classes, boxes
     nc = int(c.max() + 1)  # number of classes
-    colors = color_list()
     x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
 
     # seaborn correlogram
@@ -302,7 +310,7 @@
     labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
     img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
     for cls, *box in labels[:1000]:
-        ImageDraw.Draw(img).rectangle(box, width=1, outline=colors[int(cls) % 10])  # plot
+        ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls))  # plot
     ax[1].imshow(img)
     ax[1].axis('off')
 
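A short usage sketch of the new class (values assume matplotlib's 10-color TABLEAU_COLORS palette referenced in __init__; the tuples shown are its first entry, 'tab:blue'):

    from utils.plots import Colors, colors  # 'colors' is the module-level instance above

    c = Colors()
    c.n                # 10 palette entries
    c(0)               # (31, 119, 180)  RGB order, suitable for PIL drawing
    c(0, bgr=True)     # (180, 119, 31)  same color in BGR order, suitable for OpenCV
    c(12) == c(2)      # True -- indices wrap modulo the palette size
    colors(0) == c(0)  # True -- callers elsewhere in the repo use the shared instance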