Spaces:
Runtime error
Runtime error
Commit
•
1a4c886
1
Parent(s):
c756bdd
Update inference.py
Browse files- inference.py +29 -21
inference.py
CHANGED
@@ -59,14 +59,16 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
|
|
59 |
# Detect boxes in frames
|
60 |
inference, image_shapes, width, height = do_detection(dataloader, model, device, gp=gp)
|
61 |
|
|
|
|
|
62 |
if config.associative_tracker == TrackerType.BYTETRACK:
|
63 |
|
64 |
# Find low confidence detections
|
65 |
-
low_outputs = do_suppression(inference,
|
66 |
low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, gp=gp)
|
67 |
|
68 |
# Find high confidence detections
|
69 |
-
high_outputs = do_suppression(inference,
|
70 |
high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, gp=gp)
|
71 |
|
72 |
# Perform associative tracking (ByteTrack)
|
@@ -167,7 +169,7 @@ def do_detection(dataloader, model, device, gp=None, batch_size=BATCH_SIZE, verb
|
|
167 |
|
168 |
return inference, image_shapes, width, height
|
169 |
|
170 |
-
def do_suppression(inference,
|
171 |
"""
|
172 |
Args:
|
173 |
frames_dir: a directory containing frames to be evaluated
|
@@ -186,7 +188,7 @@ def do_suppression(inference, pix2w, gp=None, batch_size=BATCH_SIZE, conf_thres=
|
|
186 |
if gp: gp(batch_i / len(inference), pbar.__str__())
|
187 |
|
188 |
with torch.no_grad():
|
189 |
-
output = non_max_suppression(inf_out,
|
190 |
|
191 |
|
192 |
outputs.append(output)
|
@@ -392,13 +394,32 @@ def json_dump_round_float(some_object, out_path, num_digits=4):
|
|
392 |
with patch('json.encoder._make_iterencode', wraps=inner):
|
393 |
return json.dump(some_object, open(out_path, 'w'), indent=2)
|
394 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
395 |
def non_max_suppression(
|
396 |
prediction,
|
397 |
-
pix2w,
|
398 |
conf_thres=0.25,
|
399 |
iou_thres=0.45,
|
400 |
-
max_det=300
|
401 |
-
max_length=0.8
|
402 |
):
|
403 |
"""Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
|
404 |
|
@@ -420,19 +441,6 @@ def non_max_suppression(
|
|
420 |
prediction = prediction.cpu()
|
421 |
bs = prediction.shape[0] # batch size
|
422 |
xc = prediction[..., 4] > conf_thres # candidates
|
423 |
-
print("pix2w", pix2w)
|
424 |
-
width = prediction[..., 2]*pix2w
|
425 |
-
wc = width < max_length
|
426 |
-
|
427 |
-
print(0.1, sum(sum(width < 0.1)))
|
428 |
-
print(0.8, sum(sum(width < 0.8)))
|
429 |
-
print(1.5, sum(sum(width < 1.5)))
|
430 |
-
print(10, sum(sum(width < 10 )))
|
431 |
-
print(10, sum(sum(width < 10 )))
|
432 |
-
print(prediction[1,512,:])
|
433 |
-
print("mean", torch.mean(width))
|
434 |
-
|
435 |
-
|
436 |
|
437 |
|
438 |
# Settings
|
@@ -446,7 +454,7 @@ def non_max_suppression(
|
|
446 |
|
447 |
|
448 |
# Keep boxes that pass confidence threshold
|
449 |
-
x = x[xc[xi]
|
450 |
|
451 |
# If none remain process next image
|
452 |
if not x.shape[0]:
|
|
|
59 |
# Detect boxes in frames
|
60 |
inference, image_shapes, width, height = do_detection(dataloader, model, device, gp=gp)
|
61 |
|
62 |
+
inference = filter_detection_size(inference, image_meter_width, width, 1.5)
|
63 |
+
|
64 |
if config.associative_tracker == TrackerType.BYTETRACK:
|
65 |
|
66 |
# Find low confidence detections
|
67 |
+
low_outputs = do_suppression(inference, conf_thres=config.byte_low_conf, iou_thres=config.nms_iou, gp=gp)
|
68 |
low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, gp=gp)
|
69 |
|
70 |
# Find high confidence detections
|
71 |
+
high_outputs = do_suppression(inference, conf_thres=config.byte_high_conf, iou_thres=config.nms_iou, gp=gp)
|
72 |
high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, gp=gp)
|
73 |
|
74 |
# Perform associative tracking (ByteTrack)
|
|
|
169 |
|
170 |
return inference, image_shapes, width, height
|
171 |
|
172 |
+
def do_suppression(inference, gp=None, batch_size=BATCH_SIZE, conf_thres=CONF_THRES, iou_thres=NMS_IOU, verbose=True):
|
173 |
"""
|
174 |
Args:
|
175 |
frames_dir: a directory containing frames to be evaluated
|
|
|
188 |
if gp: gp(batch_i / len(inference), pbar.__str__())
|
189 |
|
190 |
with torch.no_grad():
|
191 |
+
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres)
|
192 |
|
193 |
|
194 |
outputs.append(output)
|
|
|
394 |
with patch('json.encoder._make_iterencode', wraps=inner):
|
395 |
return json.dump(some_object, open(out_path, 'w'), indent=2)
|
396 |
|
397 |
+
|
398 |
+
def filter_detection_size(inference, image_meter_width, width, max_length):
|
399 |
+
|
400 |
+
outputs = []
|
401 |
+
for prediction in inference:
|
402 |
+
|
403 |
+
pix2width = image_meter_width/width
|
404 |
+
print("pix2w", pix2width)
|
405 |
+
width = prediction[..., 2]*pix2width
|
406 |
+
wc = width < max_length
|
407 |
+
bs = prediction.shape[0] # batches
|
408 |
+
|
409 |
+
output = [torch.zeros((0, 6), device=prediction.device)] * bs
|
410 |
+
for xi, x in enumerate(prediction):
|
411 |
+
x = x[wc[xi]] # confidence
|
412 |
+
output[xi] = x
|
413 |
+
|
414 |
+
outputs.append(output)
|
415 |
+
|
416 |
+
return outputs
|
417 |
+
|
418 |
def non_max_suppression(
|
419 |
prediction,
|
|
|
420 |
conf_thres=0.25,
|
421 |
iou_thres=0.45,
|
422 |
+
max_det=300
|
|
|
423 |
):
|
424 |
"""Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
|
425 |
|
|
|
441 |
prediction = prediction.cpu()
|
442 |
bs = prediction.shape[0] # batch size
|
443 |
xc = prediction[..., 4] > conf_thres # candidates
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
444 |
|
445 |
|
446 |
# Settings
|
|
|
454 |
|
455 |
|
456 |
# Keep boxes that pass confidence threshold
|
457 |
+
x = x[xc[xi]] # confidence
|
458 |
|
459 |
# If none remain process next image
|
460 |
if not x.shape[0]:
|