Spaces:
Runtime error
Runtime error
Commit
•
7162476
1
Parent(s):
94395df
Update inference.py
Browse files- inference.py +12 -5
inference.py
CHANGED
@@ -205,6 +205,11 @@ def format_predictions(image_shapes, outputs, width, height, gp=None, batch_size
|
|
205 |
gp: a callback function which takes as input 1 parameter, (int) percent complete
|
206 |
prep_for_marking: re-index fish for manual marking output
|
207 |
"""
|
|
|
|
|
|
|
|
|
|
|
208 |
|
209 |
if (gp): gp(0, "Formatting...")
|
210 |
# keep predictions to feed them ordered into the Tracker
|
@@ -216,7 +221,9 @@ def format_predictions(image_shapes, outputs, width, height, gp=None, batch_size
|
|
216 |
if gp: gp(batch_i / len(image_shapes), pbar.__str__())
|
217 |
|
218 |
batch_shapes = image_shapes[batch_i]
|
|
|
219 |
print(len(batch_shapes))
|
|
|
220 |
print(len(batch))
|
221 |
print(batch_shapes[0])
|
222 |
print(batch[0])
|
@@ -402,16 +409,16 @@ def json_dump_round_float(some_object, out_path, num_digits=4):
|
|
402 |
def filter_detection_size(inference, image_meter_width, width, max_length):
|
403 |
|
404 |
outputs = []
|
405 |
-
for
|
406 |
|
407 |
pix2width = image_meter_width/width
|
408 |
print("pix2w", pix2width)
|
409 |
-
width =
|
410 |
wc = width < max_length
|
411 |
-
bs =
|
412 |
|
413 |
-
output = [torch.zeros((0, 6), device=
|
414 |
-
for xi, x in enumerate(
|
415 |
x = x[wc[xi]] # confidence
|
416 |
output[xi] = x
|
417 |
|
|
|
205 |
gp: a callback function which takes as input 1 parameter, (int) percent complete
|
206 |
prep_for_marking: re-index fish for manual marking output
|
207 |
"""
|
208 |
+
|
209 |
+
print(type(image_shapes))
|
210 |
+
print(len(image_shapes))
|
211 |
+
print(type(outputs))
|
212 |
+
print(len(outputs))
|
213 |
|
214 |
if (gp): gp(0, "Formatting...")
|
215 |
# keep predictions to feed them ordered into the Tracker
|
|
|
221 |
if gp: gp(batch_i / len(image_shapes), pbar.__str__())
|
222 |
|
223 |
batch_shapes = image_shapes[batch_i]
|
224 |
+
print(type(batch_shapes))
|
225 |
print(len(batch_shapes))
|
226 |
+
print(type(batch))
|
227 |
print(len(batch))
|
228 |
print(batch_shapes[0])
|
229 |
print(batch[0])
|
|
|
409 |
def filter_detection_size(inference, image_meter_width, width, max_length):
|
410 |
|
411 |
outputs = []
|
412 |
+
for batch in inference:
|
413 |
|
414 |
pix2width = image_meter_width/width
|
415 |
print("pix2w", pix2width)
|
416 |
+
width = batch[..., 2]*pix2width
|
417 |
wc = width < max_length
|
418 |
+
bs = batch.shape[0] # batches
|
419 |
|
420 |
+
output = [torch.zeros((0, 6), device=batch.device)] * bs
|
421 |
+
for xi, x in enumerate(batch):
|
422 |
x = x[wc[xi]] # confidence
|
423 |
output[xi] = x
|
424 |
|