oskarastrom committed
Commit 058f18b
1 Parent(s): 2a572c2

Output formats

Files changed (4)
  1. app.py +7 -2
  2. gradio_scripts/upload_ui.py +4 -1
  3. inference.py +58 -147
  4. main.py +13 -12
app.py CHANGED
@@ -24,19 +24,23 @@ state = {
     'total': 1,
     'annotation_index': -1,
     'frame_index': 0,
-    'config': None
+    'outputs': [],
+    'config': None,
 }
 result = {}
 
 
 # Called when an Aris file is uploaded for inference
-def on_aris_input(file_list, model_id, conf_thresh, iou_thresh, min_hits, max_age, associative_tracker, boost_power, boost_decay, byte_low_conf, byte_high_conf, min_length, min_travel):
+def on_aris_input(file_list, model_id, conf_thresh, iou_thresh, min_hits, max_age, associative_tracker, boost_power, boost_decay, byte_low_conf, byte_high_conf, min_length, min_travel, output_formats):
+
+    print(output_formats)
 
     # Reset Result
     reset_state(result, state)
     state['files'] = file_list
     state['total'] = len(file_list)
     state['version'] = WEBAPP_VERSION
+    state['outputs'] = output_formats
     state['config'] = InferenceConfig(
         weights = models[model_id] if model_id in models else models['master'],
         conf_thresh = conf_thresh,
@@ -170,6 +174,7 @@ def infer_next(_, progress=gr.Progress()):
     json_result, json_filepath, zip_filepath, video_filepath, marking_filepath = predict_task(
         file_path,
         config = state['config'],
+        output_formats = state['outputs'],
         gradio_progress = set_progress
     )
 
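Note (outside this diff): the event binding that invokes on_aris_input is not shown in this commit. Gradio passes component values to a callback positionally, in the order of its inputs list, so the new trailing output_formats parameter lines up with the CheckboxGroup only because that component is appended to settings last. A minimal sketch of the assumed wiring; the .upload() call, the gradio_components names, and the empty outputs list are all assumptions here:

# Sketch only; the real binding lives elsewhere in app.py.
# The CheckboxGroup delivers a list of selected labels,
# e.g. ["Annotated Video", "PDF"], into the trailing `output_formats` slot.
gradio_components['input'].upload(
    on_aris_input,
    inputs=[gradio_components['input']] + gradio_components['hyperparams'],
    outputs=[],  # placeholder; the commit does not show the real outputs
)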
 
gradio_scripts/upload_ui.py CHANGED
@@ -17,8 +17,8 @@ def Upload_Gradio(gradio_components):
 
     gr.HTML("<p align='center' style='font-size: large;font-style: italic;'>Submit an .aris file to analyze result.</p>")
 
+    settings = []
     with gr.Accordion("Advanced Settings", open=False):
-        settings = []
         settings.append(gr.Dropdown(label="Model", value="master", choices=list(models.keys())))
 
         gr.Markdown("Detection Parameters")
@@ -49,6 +49,9 @@ def Upload_Gradio(gradio_components):
 
     gradio_components['hyperparams'] = settings
 
+    with gr.Row():
+        settings.append(gr.CheckboxGroup(["Annotated Video", "Manual Marking", "PDF"], label="Output formats", interactive=True, value=["Annotated Video", "Manual Marking"]))
+
     #Input field for aris submission
     gradio_components['input'] = File(file_types=[".aris", ".ddf"], type="binary", label="ARIS Input", file_count="multiple")
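One subtlety worth an example: the CheckboxGroup is appended after gradio_components['hyperparams'] = settings has already run, which still works because both names reference the same list object. The component's value is the list of currently selected labels, which is what makes membership tests like "Annotated Video" in output_formats possible downstream. A standalone toy demo, not repo code:

import gradio as gr

# Toy demo of the new control, separate from the app: the value delivered to
# a callback is the list of selected labels, which is why predict_task can
# test `"Annotated Video" in output_formats`.
def show(output_formats):
    return f"Selected: {output_formats}"

with gr.Blocks() as demo:
    formats = gr.CheckboxGroup(
        ["Annotated Video", "Manual Marking", "PDF"],
        label="Output formats",
        value=["Annotated Video", "Manual Marking"],  # defaults match the commit
    )
    selected = gr.Textbox(label="Result")
    formats.change(show, inputs=formats, outputs=selected)

# demo.launch()  # uncomment to try locally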
 
inference.py CHANGED
@@ -53,44 +53,23 @@ def norm(bbox, w, h):
 
 def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None, config=InferenceConfig()):
 
+    # Set up model
     model, device = setup_model(config.weights)
 
-    load = False
-    save = False
-
-    if load:
-        with open('static/example/inference_output.json', 'r') as f:
-            json_object = json.load(f)
-            inference = json_object['inference']
-            width = json_object['width']
-            height = json_object['height']
-            image_shapes = json_object['image_shapes']
-    else:
-        inference, image_shapes, width, height = do_detection(dataloader, model, device, gp=gp)
-
-    if save:
-        json_object = {
-            'inference': inference,
-            'width': width,
-            'height': height,
-            'image_shapes': image_shapes
-        }
-        json_text = json.dumps(json_object, indent=4)
-        with open('static/example/inference_output.json', 'w') as f:
-            f.write(json_text)
-        return
-
-
-    outputs = do_suppression(inference, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, gp=gp)
+    # Detect boxes in frames
+    inference, image_shapes, width, height = do_detection(dataloader, model, device, gp=gp)
 
     if config.associative_tracker == TrackerType.BYTETRACK:
 
+        # Find low confidence detections
         low_outputs = do_suppression(inference, conf_thres=config.byte_low_conf, iou_thres=config.nms_iou, gp=gp)
         low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, gp=gp)
 
+        # Find high confidence detections
         high_outputs = do_suppression(inference, conf_thres=config.byte_high_conf, iou_thres=config.nms_iou, gp=gp)
         high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, gp=gp)
 
+        # Perform associative tracking (ByteTrack)
         results = do_associative_tracking(
            low_preds, high_preds, image_meter_width, image_meter_height,
            reverse=False, min_length=config.min_length, min_travel=config.min_travel,
@@ -98,17 +77,21 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
            gp=gp)
     else:
 
-
+        # Find confident detections
         outputs = do_suppression(inference, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, gp=gp)
 
         if config.associative_tracker == TrackerType.CONF_BOOST:
 
+            # Boost confidence based on found confident detections
             do_confidence_boost(inference, outputs, boost_power=config.boost_power, boost_decay=config.boost_decay, gp=gp)
 
+            # Find confident detections from boosted list
             outputs = do_suppression(inference, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, gp=gp)
 
+        # Format confident detections
         all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
 
+        # Perform SORT tracking
         results = do_tracking(
            all_preds, image_meter_width, image_meter_height,
            min_length=config.min_length, min_travel=config.min_travel,
@@ -118,6 +101,9 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
     return results
 
 
+
+
+
 def setup_model(weights_fp=WEIGHTS, imgsz=896, batch_size=32):
     if torch.cuda.is_available():
         device = select_device('0', batch_size=batch_size)
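The ByteTrack branch above runs do_suppression twice over the same raw inference, once per threshold, so the low-confidence set always contains the high-confidence set. A toy illustration of that relationship, with illustrative numbers rather than repo code:

# Borderline detections survive only the low threshold; associative
# tracking can use them to extend existing tracks, while high-confidence
# detections anchor new ones.
scores = [0.05, 0.2, 0.45, 0.9]
byte_low_conf, byte_high_conf = 0.1, 0.3   # hypothetical settings

low  = [s for s in scores if s > byte_low_conf]    # [0.2, 0.45, 0.9]
high = [s for s in scores if s > byte_high_conf]   # [0.45, 0.9]
assert set(high) <= set(low)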
@@ -252,13 +238,44 @@ def format_predictions(image_shapes, outputs, width, height, gp=None, batch_size
 
     return all_preds, real_width, real_height
 
+
+# ---------------------------------------- TRACKING ------------------------------------------
+
+def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH, min_travel=MIN_TRAVEL, verbose=True):
+    """
+    Perform SORT tracking based on formatted detections
+    """
+
+    if (gp): gp(0, "Tracking...")
+
+    # Initialize tracker
+    clip_info = {
+        'start_frame': 0,
+        'end_frame': len(all_preds),
+        'image_meter_width': image_meter_width,
+        'image_meter_height': image_meter_height
+    }
+    tracker = Tracker(clip_info, args={ 'max_age': max_age, 'min_hits': 0, 'iou_threshold': iou_thres}, min_hits=min_hits)
+
+    # Run tracking
+    with tqdm(total=len(all_preds), desc="Running tracking", ncols=0, disable=not verbose) as pbar:
+        for i, key in enumerate(sorted(all_preds.keys())):
+            if gp: gp(i / len(all_preds), pbar.__str__())
+            boxes = all_preds[key]
+            if boxes is not None:
+                tracker.update(boxes)
+            else:
+                tracker.update()
+            pbar.update(1)
+
+    json_data = tracker.finalize(min_length=min_length, min_travel=min_travel)
+
+    return json_data
+
 def do_confidence_boost(inference, safe_preds, gp=None, batch_size=BATCH_SIZE, boost_power=1, boost_decay=1, verbose=True):
     """
-    Args:
-        frames_dir: a directory containing frames to be evaluated
-        image_meter_width: the width of each image, in meters (used for fish length calculation)
-        gp: a callback function which takes as input 1 parameter, (int) percent complete
-        prep_for_marking: re-index fish for manual marking output
+    Takes in the full YOLO detections 'inference' and formatted non-max suppressed detections 'safe_preds'
+    and boosts the confidence of detections around identified fish that are close in space in neighbouring frames.
     """
 
     if (gp): gp(0, "Confidence Boost...")
@@ -303,9 +320,11 @@
                 boost_frame(safe_frame, temp_frame, dt, power=boost_scale, decay=boost_decay)
 
             pbar.update(1*batch_size)
-
 
 def boost_frame(safe_frame, base_frame, dt, power=1, decay=1):
+    """
+    Boosts confidence of base_frame based on confidence in safe_frame, iou, and the time difference between frames.
+    """
     safe_boxes = safe_frame[:, :4]
     boxes = xywh2xyxy(base_frame[:, :4]) # (center_x, center_y, width, height) to (x1, y1, x2, y2)
 
@@ -320,34 +339,7 @@
     base_frame[:, 4] *= 1 + power*(score)*math.exp(-decay*(dt*dt-1))
     return base_frame
 
-def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH, min_travel=MIN_TRAVEL, verbose=True):
-
-    if (gp): gp(0, "Tracking...")
-
-    # Initialize tracker
-    clip_info = {
-        'start_frame': 0,
-        'end_frame': len(all_preds),
-        'image_meter_width': image_meter_width,
-        'image_meter_height': image_meter_height
-    }
-    tracker = Tracker(clip_info, args={ 'max_age': max_age, 'min_hits': 0, 'iou_threshold': iou_thres}, min_hits=min_hits)
-
-    # Run tracking
-    with tqdm(total=len(all_preds), desc="Running tracking", ncols=0, disable=not verbose) as pbar:
-        for i, key in enumerate(sorted(all_preds.keys())):
-            if gp: gp(i / len(all_preds), pbar.__str__())
-            boxes = all_preds[key]
-            if boxes is not None:
-                tracker.update(boxes)
-            else:
-                tracker.update()
-            pbar.update(1)
-
-    json_data = tracker.finalize(min_length=min_length, min_travel=min_travel)
-
-    return json_data
-
+# ByteTrack
 def do_associative_tracking(low_preds, high_preds, image_meter_width, image_meter_height, reverse=False, gp=None, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH, min_travel=MIN_TRAVEL, verbose=True):
 
     if (gp): gp(0, "Tracking...")
@@ -379,6 +371,8 @@
 
     return json_data
 
+
+
 @patch('json.encoder.c_make_encoder', None)
 def json_dump_round_float(some_object, out_path, num_digits=4):
     """Write a json file to disk with a specified level of precision.
@@ -396,8 +390,6 @@ def json_dump_round_float(some_object, out_path, num_digits=4):
     with patch('json.encoder._make_iterencode', wraps=inner):
         return json.dump(some_object, open(out_path, 'w'), indent=2)
 
-
-
 def non_max_suppression(
     prediction,
     conf_thres=0.25,
@@ -406,6 +398,8 @@
 ):
     """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
 
+    NOTE: SIMPLIFIED FOR SINGLE CLASS DETECTION
+
     Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
     """
@@ -481,86 +475,3 @@
 
     return output
 
-
-def no_suppression(
-    prediction,
-    conf_thres=0.25,
-    iou_thres=0.45,
-    max_det=300,
-):
-    """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
-
-    Returns:
-        list of detections, on (n,6) tensor per image [xyxy, conf, cls]
-    """
-
-    # Checks
-    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
-    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
-    if isinstance(prediction, (list, tuple)): # YOLOv5 model in validation model, output = (inference_out, loss_out)
-        prediction = prediction[0] # select only inference output
-
-    device = prediction.device
-    mps = 'mps' in device.type # Apple MPS
-    if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
-        prediction = prediction.cpu()
-    bs = prediction.shape[0] # batch size
-    xc = prediction[..., 4] > conf_thres # candidates
-
-    # Settings
-    # min_wh = 2 # (pixels) minimum box width and height
-    max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
-    redundant = True # require redundant detections
-    merge = False # use merge-NMS
-
-    output = [torch.zeros((0, 6), device=prediction.device)] * bs
-    for xi, x in enumerate(prediction): # image index, image inference
-
-
-        # Keep boxes that pass confidence threshold
-        x = x[xc[xi]] # confidence
-
-        # If none remain process next image
-        if not x.shape[0]:
-            continue
-
-        # Compute conf
-        x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
-
-
-        # Box/Mask
-        box = xywh2xyxy(x[:, :4]) # (center_x, center_y, width, height) to (x1, y1, x2, y2)
-        mask = x[:, 6:] # zero columns if no masks
-
-        # Detections matrix nx6 (xyxy, conf, cls)
-        conf, j = x[:, 5:6].max(1, keepdim=True)
-        x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
-
-
-        # Check shape
-        n = x.shape[0] # number of boxes
-        if not n: # no boxes
-            continue
-        x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence and remove excess boxes
-
-        # Batched NMS
-        boxes = x[:, :4] # boxes (offset by class), scores
-        scores = x[:, 4]
-        i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
-
-        i = i[:max_det] # limit detections
-        if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
-            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
-            iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
-            weights = iou * scores[None] # box weights
-            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
-        if redundant:
-            i = i[iou.sum(1) > 1] # require redundancy
-
-        output[xi] = x[i]
-        if mps:
-            output[xi] = output[xi].to(device)
-
-    logging = False
-
-    return output
main.py CHANGED
@@ -7,7 +7,7 @@ from dataloader import create_dataloader_aris
 from inference import do_full_inference, json_dump_round_float
 from visualizer import generate_video_batches
 
-def predict_task(filepath, config, gradio_progress=None):
+def predict_task(filepath, config, output_formats=[], gradio_progress=None):
     """
     Main processing task to be run in gradio
     - Writes aris frames to dirname(filepath)/frames/{i}.jpg
@@ -17,12 +17,11 @@ def predict_task(filepath, config, gradio_progress=None):
     - Zips all results to dirname(filepath)/{filename}_results.zip
     Args:
         filepath (str): path to aris file
-
-    TODO: Separate into subtasks in different queues; have a GPU-only queue.
     """
     if (gradio_progress): gradio_progress(0, "In task...")
     print("Cuda available in task?", torch.cuda.is_available())
 
+    # Set up save directory and define file names
     dirname = os.path.dirname(filepath)
     filename = os.path.basename(filepath).replace(".aris","").replace(".ddf","")
     results_filepath = os.path.join(dirname, f"{filename}_results.json")
@@ -31,11 +30,11 @@ def predict_task(filepath, config, gradio_progress=None):
     zip_filepath = os.path.join(dirname, f"{filename}_results.zip")
     os.makedirs(dirname, exist_ok=True)
 
-    # create dataloader
+    # Create dataloader
     if (gradio_progress): gradio_progress(0, "Initializing Dataloader...")
     dataloader, dataset = create_dataloader_aris(filepath, BEAM_WIDTH_DIR, None)
 
-    # extract aris/didson info. didson does not yet have pixel-meter info
+    # Extract aris/didson info. Didson does not yet have pixel-meter info
     if ".ddf" in filepath:
         image_meter_width = -1
         image_meter_height = -1
@@ -47,28 +46,30 @@ def predict_task(filepath, config, gradio_progress=None):
     # run detection + tracking
     results = do_full_inference(dataloader, image_meter_width, image_meter_height, gp=gradio_progress, config=config)
 
-    # re-index results if desired - this should be done before writing the file
+    # Generate Metadata and extra inference information
     results = prep_for_mm(results)
     results = add_metadata_to_result(filepath, results)
     results['metadata']['hyperparameters'] = config.to_dict()
 
-    # write output to disk
+    # Create JSON result file
     json_dump_round_float(results, results_filepath)
 
-    if dataset.didson.info['version'][3] == 5: # ARIS only
+    # Create Manual Marking file
+    if "Manual Marking" in output_formats and dataset.didson.info['version'][3] == 5:
         create_manual_marking(results, out_path=marking_filepath)
 
-    # generate a video with tracking results
-    generate_video_batches(dataset.didson, results, frame_rate, video_filepath,
+    # Create Annotated Video
+    if "Annotated Video" in output_formats:
+        generate_video_batches(dataset.didson, results, frame_rate, video_filepath,
                            image_meter_width=image_meter_width, image_meter_height=image_meter_height, gp=gradio_progress)
 
-    # zip up the results
+    # Zip up the results
     with ZipFile(zip_filepath, 'w') as z:
         for file in [results_filepath, marking_filepath, video_filepath, os.path.join(dirname, 'bg_start.jpg')]:
             if os.path.exists(file):
                 z.write(file, arcname=os.path.basename(file))
 
-    # release GPU memory
+    # Release GPU memory
    torch.cuda.empty_cache()
 
     return results, results_filepath, zip_filepath, video_filepath, marking_filepath
 
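Usage sketch (hypothetical path and values, not part of the commit): the selection list from the UI flows through unchanged, and a missing name simply skips that step. Note that predict_task only gates the manual-marking and annotated-video steps here; the "PDF" choice offered by the new CheckboxGroup has no corresponding branch yet.

# Hypothetical call; InferenceConfig() defaults stand in for real settings.
results, json_fp, zip_fp, video_fp, marking_fp = predict_task(
    "/tmp/example.aris",                 # hypothetical input file
    config=InferenceConfig(),
    output_formats=["Annotated Video"],  # omitting "Manual Marking" skips it
    gradio_progress=None,
)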