glenn-jocher committed
Commit c09964c · unverified · 1 parent: ab2da5e

Update inference default to multi_label=False (#2252)


* Update inference default to multi_label=False

* bug fix

* Update plots.py

* Update plots.py

Files changed (4)
  1. models/common.py +1 -1
  2. test.py +4 -4
  3. utils/general.py +5 -4
  4. utils/plots.py +1 -1
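
To make the behaviour change concrete, here is a minimal sketch of how a direct caller of non_max_suppression sees it. The single dummy box, its three-class layout and the scores are made up for illustration, and the import assumes a YOLOv5 checkout on PYTHONPATH:

import torch
from utils.general import non_max_suppression  # assumes a YOLOv5 checkout on PYTHONPATH

# One image, one candidate box, 3 classes: (cx, cy, w, h, obj, c0, c1, c2)
pred = torch.tensor([[[50., 50., 20., 20., 0.9, 0.8, 0.6, 0.01]]])

# New default (multi_label=False): the box keeps only its best class
# -> one row, cls 0, conf = 0.9 * 0.8 = 0.72
det = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)

# Previous behaviour, now opt-in: every class clearing conf_thres becomes its own row
# -> two rows here (conf 0.72 for cls 0 and 0.54 for cls 1)
det_ml = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45, multi_label=True)

print(det[0].shape, det_ml[0].shape)  # torch.Size([1, 6]) vs torch.Size([2, 6])

test.py opts back into the old behaviour by passing multi_label=True explicitly, so mAP is computed from the same multi-label detections as before, while callers that rely on the default (such as the inference wrapper in models/common.py) now get the cheaper single-label output.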
models/common.py CHANGED
@@ -7,7 +7,7 @@ import numpy as np
  import requests
  import torch
  import torch.nn as nn
- from PIL import Image, ImageDraw
+ from PIL import Image
 
  from utils.datasets import letterbox
  from utils.general import non_max_suppression, make_divisible, scale_coords, xyxy2xywh
test.py CHANGED
@@ -106,7 +106,7 @@ def test(data,
  with torch.no_grad():
  # Run model
  t = time_synchronized()
- inf_out, train_out = model(img, augment=augment) # inference and training outputs
+ out, train_out = model(img, augment=augment) # inference and training outputs
  t0 += time_synchronized() - t
 
  # Compute loss
@@ -117,11 +117,11 @@ def test(data,
  targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) # to pixels
  lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling
  t = time_synchronized()
- output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, labels=lb)
+ out = non_max_suppression(out, conf_thres=conf_thres, iou_thres=iou_thres, labels=lb, multi_label=True)
  t1 += time_synchronized() - t
 
  # Statistics per image
- for si, pred in enumerate(output):
+ for si, pred in enumerate(out):
  labels = targets[targets[:, 0] == si, 1:]
  nl = len(labels)
  tcls = labels[:, 0].tolist() if nl else [] # target class
@@ -209,7 +209,7 @@ def test(data,
  f = save_dir / f'test_batch{batch_i}_labels.jpg' # labels
  Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start()
  f = save_dir / f'test_batch{batch_i}_pred.jpg' # predictions
- Thread(target=plot_images, args=(img, output_to_target(output), paths, f, names), daemon=True).start()
+ Thread(target=plot_images, args=(img, output_to_target(out), paths, f, names), daemon=True).start()
 
  # Compute statistics
  stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
utils/general.py CHANGED
@@ -390,11 +390,12 @@ def wh_iou(wh1, wh2):
  return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter)
 
 
- def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
- """Performs Non-Maximum Suppression (NMS) on inference results
+ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
+                         labels=()):
+ """Runs Non-Maximum Suppression (NMS) on inference results
 
  Returns:
- detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
+ list of detections, on (n,6) tensor per image [xyxy, conf, cls]
  """
 
  nc = prediction.shape[2] - 5 # number of classes
@@ -406,7 +407,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
  max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
  time_limit = 10.0 # seconds to quit after
  redundant = True # require redundant detections
- multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
+ multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
  merge = False # use merge-NMS
 
  t = time.time()
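
A note on the changed guard: the old line enabled multi-label NMS for any model with more than one class, while the new multi_label &= nc > 1 makes it opt-in yet still switches it off for single-class models, where there is only one class to assign anyway. A minimal sketch of the guard in isolation, with a hypothetical single-class model:

multi_label = True   # caller asked for multi-label NMS (as test.py now does)
nc = 1               # hypothetical single-class model: prediction.shape[2] - 5 == 1
multi_label &= nc > 1
print(multi_label)   # False: with one class, the multi-label path (the comment prices it at ~0.5 ms/img) is skipped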
utils/plots.py CHANGED
@@ -54,7 +54,7 @@ def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
  return filtfilt(b, a, data) # forward-backward filter
 
 
- def plot_one_box(x, img, color=None, label=None, line_thickness=None):
+ def plot_one_box(x, img, color=None, label=None, line_thickness=3):
  # Plots one bounding box on image img
  tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
  color = color or [random.randint(0, 255) for _ in range(3)]
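
The plots.py change flips plot_one_box from an adaptive default to a fixed one: with line_thickness=3 the expression line_thickness or round(...) short-circuits to 3, while passing None still falls through to the image-size-based rule. A small usage sketch; the 640x640 canvas, box coordinates and labels are made up, and the import assumes a YOLOv5 checkout on PYTHONPATH:

import cv2
import numpy as np
from utils.plots import plot_one_box  # assumes a YOLOv5 checkout on PYTHONPATH

img = np.zeros((640, 640, 3), dtype=np.uint8)  # dummy canvas

# New default: a fixed 3 px border regardless of image size
plot_one_box([100, 100, 300, 300], img, label='person 0.92')

# Passing None restores the old adaptive rule: round(0.002 * (640 + 640) / 2) + 1 = 2 px here
plot_one_box([350, 350, 500, 500], img, label='car 0.80', line_thickness=None)

cv2.imwrite('boxes.jpg', img)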