oskarastrom committed on
Commit
7162476
1 Parent(s): 94395df

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +12 -5
inference.py CHANGED
@@ -205,6 +205,11 @@ def format_predictions(image_shapes, outputs, width, height, gp=None, batch_size
205
  gp: a callback function which takes as input 1 parameter, (int) percent complete
206
  prep_for_marking: re-index fish for manual marking output
207
  """
 
 
 
 
 
208
 
209
  if (gp): gp(0, "Formatting...")
210
  # keep predictions to feed them ordered into the Tracker
@@ -216,7 +221,9 @@ def format_predictions(image_shapes, outputs, width, height, gp=None, batch_size
216
  if gp: gp(batch_i / len(image_shapes), pbar.__str__())
217
 
218
  batch_shapes = image_shapes[batch_i]
 
219
  print(len(batch_shapes))
 
220
  print(len(batch))
221
  print(batch_shapes[0])
222
  print(batch[0])
@@ -402,16 +409,16 @@ def json_dump_round_float(some_object, out_path, num_digits=4):
402
  def filter_detection_size(inference, image_meter_width, width, max_length):
403
 
404
  outputs = []
405
- for prediction in inference:
406
 
407
  pix2width = image_meter_width/width
408
  print("pix2w", pix2width)
409
- width = prediction[..., 2]*pix2width
410
  wc = width < max_length
411
- bs = prediction.shape[0] # batches
412
 
413
- output = [torch.zeros((0, 6), device=prediction.device)] * bs
414
- for xi, x in enumerate(prediction):
415
  x = x[wc[xi]] # confidence
416
  output[xi] = x
417
 
 
205
  gp: a callback function which takes as input 1 parameter, (int) percent complete
206
  prep_for_marking: re-index fish for manual marking output
207
  """
208
+
209
+ print(type(image_shapes))
210
+ print(len(image_shapes))
211
+ print(type(outputs))
212
+ print(len(outputs))
213
 
214
  if (gp): gp(0, "Formatting...")
215
  # keep predictions to feed them ordered into the Tracker
 
221
  if gp: gp(batch_i / len(image_shapes), pbar.__str__())
222
 
223
  batch_shapes = image_shapes[batch_i]
224
+ print(type(batch_shapes))
225
  print(len(batch_shapes))
226
+ print(type(batch))
227
  print(len(batch))
228
  print(batch_shapes[0])
229
  print(batch[0])
 
409
  def filter_detection_size(inference, image_meter_width, width, max_length):
410
 
411
  outputs = []
412
+ for batch in inference:
413
 
414
  pix2width = image_meter_width/width
415
  print("pix2w", pix2width)
416
+ width = batch[..., 2]*pix2width
417
  wc = width < max_length
418
+ bs = batch.shape[0] # batches
419
 
420
+ output = [torch.zeros((0, 6), device=batch.device)] * bs
421
+ for xi, x in enumerate(batch):
422
  x = x[wc[xi]] # confidence
423
  output[xi] = x
424