Update inference default to multi_label=False (#2252)
Browse files* Update inference default to multi_label=False
* bug fix
* Update plots.py
* Update plots.py
- models/common.py +1 -1
- test.py +4 -4
- utils/general.py +5 -4
- utils/plots.py +1 -1
models/common.py
CHANGED
@@ -7,7 +7,7 @@ import numpy as np
|
|
7 |
import requests
|
8 |
import torch
|
9 |
import torch.nn as nn
|
10 |
-
from PIL import Image
|
11 |
|
12 |
from utils.datasets import letterbox
|
13 |
from utils.general import non_max_suppression, make_divisible, scale_coords, xyxy2xywh
|
|
|
7 |
import requests
|
8 |
import torch
|
9 |
import torch.nn as nn
|
10 |
+
from PIL import Image
|
11 |
|
12 |
from utils.datasets import letterbox
|
13 |
from utils.general import non_max_suppression, make_divisible, scale_coords, xyxy2xywh
|
test.py
CHANGED
@@ -106,7 +106,7 @@ def test(data,
|
|
106 |
with torch.no_grad():
|
107 |
# Run model
|
108 |
t = time_synchronized()
|
109 |
-
|
110 |
t0 += time_synchronized() - t
|
111 |
|
112 |
# Compute loss
|
@@ -117,11 +117,11 @@ def test(data,
|
|
117 |
targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) # to pixels
|
118 |
lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling
|
119 |
t = time_synchronized()
|
120 |
-
|
121 |
t1 += time_synchronized() - t
|
122 |
|
123 |
# Statistics per image
|
124 |
-
for si, pred in enumerate(
|
125 |
labels = targets[targets[:, 0] == si, 1:]
|
126 |
nl = len(labels)
|
127 |
tcls = labels[:, 0].tolist() if nl else [] # target class
|
@@ -209,7 +209,7 @@ def test(data,
|
|
209 |
f = save_dir / f'test_batch{batch_i}_labels.jpg' # labels
|
210 |
Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start()
|
211 |
f = save_dir / f'test_batch{batch_i}_pred.jpg' # predictions
|
212 |
-
Thread(target=plot_images, args=(img, output_to_target(
|
213 |
|
214 |
# Compute statistics
|
215 |
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
|
|
|
106 |
with torch.no_grad():
|
107 |
# Run model
|
108 |
t = time_synchronized()
|
109 |
+
out, train_out = model(img, augment=augment) # inference and training outputs
|
110 |
t0 += time_synchronized() - t
|
111 |
|
112 |
# Compute loss
|
|
|
117 |
targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) # to pixels
|
118 |
lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling
|
119 |
t = time_synchronized()
|
120 |
+
out = non_max_suppression(out, conf_thres=conf_thres, iou_thres=iou_thres, labels=lb, multi_label=True)
|
121 |
t1 += time_synchronized() - t
|
122 |
|
123 |
# Statistics per image
|
124 |
+
for si, pred in enumerate(out):
|
125 |
labels = targets[targets[:, 0] == si, 1:]
|
126 |
nl = len(labels)
|
127 |
tcls = labels[:, 0].tolist() if nl else [] # target class
|
|
|
209 |
f = save_dir / f'test_batch{batch_i}_labels.jpg' # labels
|
210 |
Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start()
|
211 |
f = save_dir / f'test_batch{batch_i}_pred.jpg' # predictions
|
212 |
+
Thread(target=plot_images, args=(img, output_to_target(out), paths, f, names), daemon=True).start()
|
213 |
|
214 |
# Compute statistics
|
215 |
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
|
utils/general.py
CHANGED
@@ -390,11 +390,12 @@ def wh_iou(wh1, wh2):
|
|
390 |
return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter)
|
391 |
|
392 |
|
393 |
-
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False,
|
394 |
-
|
|
|
395 |
|
396 |
Returns:
|
397 |
-
|
398 |
"""
|
399 |
|
400 |
nc = prediction.shape[2] - 5 # number of classes
|
@@ -406,7 +407,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
|
|
406 |
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
|
407 |
time_limit = 10.0 # seconds to quit after
|
408 |
redundant = True # require redundant detections
|
409 |
-
multi_label
|
410 |
merge = False # use merge-NMS
|
411 |
|
412 |
t = time.time()
|
|
|
390 |
return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter)
|
391 |
|
392 |
|
393 |
+
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
|
394 |
+
labels=()):
|
395 |
+
"""Runs Non-Maximum Suppression (NMS) on inference results
|
396 |
|
397 |
Returns:
|
398 |
+
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
|
399 |
"""
|
400 |
|
401 |
nc = prediction.shape[2] - 5 # number of classes
|
|
|
407 |
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
|
408 |
time_limit = 10.0 # seconds to quit after
|
409 |
redundant = True # require redundant detections
|
410 |
+
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
|
411 |
merge = False # use merge-NMS
|
412 |
|
413 |
t = time.time()
|
utils/plots.py
CHANGED
@@ -54,7 +54,7 @@ def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
|
|
54 |
return filtfilt(b, a, data) # forward-backward filter
|
55 |
|
56 |
|
57 |
-
def plot_one_box(x, img, color=None, label=None, line_thickness=
|
58 |
# Plots one bounding box on image img
|
59 |
tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
|
60 |
color = color or [random.randint(0, 255) for _ in range(3)]
|
|
|
54 |
return filtfilt(b, a, data) # forward-backward filter
|
55 |
|
56 |
|
57 |
+
def plot_one_box(x, img, color=None, label=None, line_thickness=3):
|
58 |
# Plots one bounding box on image img
|
59 |
tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
|
60 |
color = color or [random.randint(0, 255) for _ in range(3)]
|