Spaces:
Runtime error
Runtime error
Commit
•
93b9874
1
Parent(s):
6bfcb22
Update inference.py
Browse files- inference.py +15 -9
inference.py
CHANGED
@@ -64,11 +64,11 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
|
|
64 |
if config.associative_tracker == TrackerType.BYTETRACK:
|
65 |
|
66 |
# Find low confidence detections
|
67 |
-
low_outputs = do_suppression(inference, conf_thres=config.byte_low_conf, iou_thres=config.nms_iou, gp=gp)
|
68 |
low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, gp=gp)
|
69 |
|
70 |
# Find high confidence detections
|
71 |
-
high_outputs = do_suppression(inference, conf_thres=config.byte_high_conf, iou_thres=config.nms_iou, gp=gp)
|
72 |
high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, gp=gp)
|
73 |
|
74 |
# Perform associative tracking (ByteTrack)
|
@@ -80,7 +80,7 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
|
|
80 |
else:
|
81 |
|
82 |
# Find confident detections
|
83 |
-
outputs = do_suppression(inference, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, gp=gp)
|
84 |
|
85 |
if config.associative_tracker == TrackerType.CONF_BOOST:
|
86 |
|
@@ -88,7 +88,7 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
|
|
88 |
do_confidence_boost(inference, outputs, boost_power=config.boost_power, boost_decay=config.boost_decay, gp=gp)
|
89 |
|
90 |
# Find confident detections from boosted list
|
91 |
-
outputs = do_suppression(inference, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, gp=gp)
|
92 |
|
93 |
# Format confident detections
|
94 |
all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
|
@@ -169,7 +169,7 @@ def do_detection(dataloader, model, device, gp=None, batch_size=BATCH_SIZE, verb
|
|
169 |
|
170 |
return inference, image_shapes, width, height
|
171 |
|
172 |
-
def do_suppression(inference, gp=None, batch_size=BATCH_SIZE, conf_thres=CONF_THRES, iou_thres=NMS_IOU, verbose=True):
|
173 |
"""
|
174 |
Args:
|
175 |
frames_dir: a directory containing frames to be evaluated
|
@@ -188,7 +188,7 @@ def do_suppression(inference, gp=None, batch_size=BATCH_SIZE, conf_thres=CONF_TH
|
|
188 |
if gp: gp(batch_i / len(inference), pbar.__str__())
|
189 |
|
190 |
with torch.no_grad():
|
191 |
-
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres)
|
192 |
|
193 |
|
194 |
outputs.append(output)
|
@@ -422,13 +422,13 @@ def filter_detection_size(inference, image_meter_width, width, max_length):
|
|
422 |
print(wc.shape)
|
423 |
bs = batch.shape[0] # batches
|
424 |
|
425 |
-
output =
|
426 |
print("wc")
|
427 |
print(batch.shape)
|
428 |
for xi, x in enumerate(batch):
|
429 |
x = x[wc[xi]] # confidence
|
430 |
print(x.shape)
|
431 |
-
output[xi] = x
|
432 |
|
433 |
output = torch.tensor(output)
|
434 |
print("output len", output.shape)
|
@@ -439,6 +439,9 @@ def filter_detection_size(inference, image_meter_width, width, max_length):
|
|
439 |
|
440 |
def non_max_suppression(
|
441 |
prediction,
|
|
|
|
|
|
|
442 |
conf_thres=0.25,
|
443 |
iou_thres=0.45,
|
444 |
max_det=300
|
@@ -463,6 +466,9 @@ def non_max_suppression(
|
|
463 |
prediction = prediction.cpu()
|
464 |
bs = prediction.shape[0] # batch size
|
465 |
xc = prediction[..., 4] > conf_thres # candidates
|
|
|
|
|
|
|
466 |
|
467 |
|
468 |
# Settings
|
@@ -476,7 +482,7 @@ def non_max_suppression(
|
|
476 |
|
477 |
|
478 |
# Keep boxes that pass confidence threshold
|
479 |
-
x = x[xc[xi]] # confidence
|
480 |
|
481 |
# If none remain process next image
|
482 |
if not x.shape[0]:
|
|
|
64 |
if config.associative_tracker == TrackerType.BYTETRACK:
|
65 |
|
66 |
# Find low confidence detections
|
67 |
+
low_outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.byte_low_conf, iou_thres=config.nms_iou, gp=gp)
|
68 |
low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, gp=gp)
|
69 |
|
70 |
# Find high confidence detections
|
71 |
+
high_outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.byte_high_conf, iou_thres=config.nms_iou, gp=gp)
|
72 |
high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, gp=gp)
|
73 |
|
74 |
# Perform associative tracking (ByteTrack)
|
|
|
80 |
else:
|
81 |
|
82 |
# Find confident detections
|
83 |
+
outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, gp=gp)
|
84 |
|
85 |
if config.associative_tracker == TrackerType.CONF_BOOST:
|
86 |
|
|
|
88 |
do_confidence_boost(inference, outputs, boost_power=config.boost_power, boost_decay=config.boost_decay, gp=gp)
|
89 |
|
90 |
# Find confident detections from boosted list
|
91 |
+
outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, gp=gp)
|
92 |
|
93 |
# Format confident detections
|
94 |
all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
|
|
|
169 |
|
170 |
return inference, image_shapes, width, height
|
171 |
|
172 |
+
def do_suppression(inference, image_meter_width, image_pixel_width, gp=None, batch_size=BATCH_SIZE, conf_thres=CONF_THRES, iou_thres=NMS_IOU, max_length=1.5, verbose=True):
|
173 |
"""
|
174 |
Args:
|
175 |
frames_dir: a directory containing frames to be evaluated
|
|
|
188 |
if gp: gp(batch_i / len(inference), pbar.__str__())
|
189 |
|
190 |
with torch.no_grad():
|
191 |
+
output = non_max_suppression(inf_out, image_meter_width, image_pixel_width, conf_thres=conf_thres, iou_thres=iou_thres, max_length=max_length)
|
192 |
|
193 |
|
194 |
outputs.append(output)
|
|
|
422 |
print(wc.shape)
|
423 |
bs = batch.shape[0] # batches
|
424 |
|
425 |
+
output = torch.zeros((bs, 0, 6), device=batch.device)
|
426 |
print("wc")
|
427 |
print(batch.shape)
|
428 |
for xi, x in enumerate(batch):
|
429 |
x = x[wc[xi]] # confidence
|
430 |
print(x.shape)
|
431 |
+
output[xi, :, :] = x
|
432 |
|
433 |
output = torch.tensor(output)
|
434 |
print("output len", output.shape)
|
|
|
439 |
|
440 |
def non_max_suppression(
|
441 |
prediction,
|
442 |
+
image_meter_width,
|
443 |
+
image_pixel_width,
|
444 |
+
max_length=1.5,
|
445 |
conf_thres=0.25,
|
446 |
iou_thres=0.45,
|
447 |
max_det=300
|
|
|
466 |
prediction = prediction.cpu()
|
467 |
bs = prediction.shape[0] # batch size
|
468 |
xc = prediction[..., 4] > conf_thres # candidates
|
469 |
+
pix2width = image_meter_width/width
|
470 |
+
width = prediction[..., 2]*pix2width
|
471 |
+
wc = width < max_length
|
472 |
|
473 |
|
474 |
# Settings
|
|
|
482 |
|
483 |
|
484 |
# Keep boxes that pass confidence threshold
|
485 |
+
x = x[xc[xi] * wc[xi]] # confidence
|
486 |
|
487 |
# If none remain process next image
|
488 |
if not x.shape[0]:
|