oskarastrom committed on
Commit
c9d11b2
1 Parent(s): 4b0fc85
Files changed (2) hide show
  1. inference.py +6 -6
  2. scripts/infer_frames.py +17 -13
inference.py CHANGED
@@ -102,7 +102,7 @@ def setup_model(weights_fp=WEIGHTS, imgsz=896, batch_size=32):
102
 
103
  return model, device
104
 
105
- def do_detection(dataloader, model, device, gp=None, batch_size=BATCH_SIZE):
106
  """
107
  Args:
108
  frames_dir: a directory containing frames to be evaluated
@@ -115,7 +115,7 @@ def do_detection(dataloader, model, device, gp=None, batch_size=BATCH_SIZE):
115
 
116
  inference = []
117
  # Run detection
118
- with tqdm(total=len(dataloader)*batch_size, desc="Running detection", ncols=0) as pbar:
119
  for batch_i, (img, _, shapes) in enumerate(dataloader):
120
 
121
  if gp: gp(batch_i / len(dataloader), pbar.__str__())
@@ -134,7 +134,7 @@ def do_detection(dataloader, model, device, gp=None, batch_size=BATCH_SIZE):
134
 
135
  return inference, width, height
136
 
137
- def do_suppression(dataloader, inference, width, height, gp=None, batch_size=BATCH_SIZE, conf_thres=CONF_THRES, iou_thres=NMS_IOU):
138
  """
139
  Args:
140
  frames_dir: a directory containing frames to be evaluated
@@ -147,7 +147,7 @@ def do_suppression(dataloader, inference, width, height, gp=None, batch_size=BAT
147
  # keep predictions to feed them ordered into the Tracker
148
  # TODO: how to deal with large files?
149
  all_preds = {}
150
- with tqdm(total=len(dataloader)*batch_size, desc="Running suppression", ncols=0) as pbar:
151
  for batch_i, (img, _, shapes) in enumerate(dataloader):
152
 
153
  if gp: gp(batch_i / len(dataloader), pbar.__str__())
@@ -180,7 +180,7 @@ def do_suppression(dataloader, inference, width, height, gp=None, batch_size=BAT
180
 
181
  return all_preds, real_width, real_height
182
 
183
- def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH):
184
 
185
  if (gp): gp(0, "Tracking...")
186
 
@@ -194,7 +194,7 @@ def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None, max_a
194
  tracker = Tracker(clip_info, args={ 'max_age': max_age, 'min_hits': 0, 'iou_threshold': iou_thres}, min_hits=min_hits)
195
 
196
  # Run tracking
197
- with tqdm(total=len(all_preds), desc="Running tracking", ncols=0) as pbar:
198
  for i, key in enumerate(sorted(all_preds.keys())):
199
  if gp: gp(i / len(all_preds), pbar.__str__())
200
  boxes = all_preds[key]
 
102
 
103
  return model, device
104
 
105
+ def do_detection(dataloader, model, device, gp=None, batch_size=BATCH_SIZE, verbose=True):
106
  """
107
  Args:
108
  frames_dir: a directory containing frames to be evaluated
 
115
 
116
  inference = []
117
  # Run detection
118
+ with tqdm(total=len(dataloader)*batch_size, desc="Running detection", ncols=0, disable=not verbose) as pbar:
119
  for batch_i, (img, _, shapes) in enumerate(dataloader):
120
 
121
  if gp: gp(batch_i / len(dataloader), pbar.__str__())
 
134
 
135
  return inference, width, height
136
 
137
+ def do_suppression(dataloader, inference, width, height, gp=None, batch_size=BATCH_SIZE, conf_thres=CONF_THRES, iou_thres=NMS_IOU, verbose=True):
138
  """
139
  Args:
140
  frames_dir: a directory containing frames to be evaluated
 
147
  # keep predictions to feed them ordered into the Tracker
148
  # TODO: how to deal with large files?
149
  all_preds = {}
150
+ with tqdm(total=len(dataloader)*batch_size, desc="Running suppression", ncols=0, disable=not verbose) as pbar:
151
  for batch_i, (img, _, shapes) in enumerate(dataloader):
152
 
153
  if gp: gp(batch_i / len(dataloader), pbar.__str__())
 
180
 
181
  return all_preds, real_width, real_height
182
 
183
+ def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH, verbose=True):
184
 
185
  if (gp): gp(0, "Tracking...")
186
 
 
194
  tracker = Tracker(clip_info, args={ 'max_age': max_age, 'min_hits': 0, 'iou_threshold': iou_thres}, min_hits=min_hits)
195
 
196
  # Run tracking
197
+ with tqdm(total=len(all_preds), desc="Running tracking", ncols=0, disable=not verbose) as pbar:
198
  for i, key in enumerate(sorted(all_preds.keys())):
199
  if gp: gp(i / len(all_preds), pbar.__str__())
200
  boxes = all_preds[key]
scripts/infer_frames.py CHANGED
@@ -8,9 +8,10 @@ from aris import create_manual_marking, create_metadata_dictionary, prep_for_mm
8
  from inference import setup_model, do_suppression, do_detection, do_tracking, json_dump_round_float
9
  from visualizer import generate_video_batches
10
  import json
 
11
 
12
 
13
- def main(args, config={}):
14
  """
15
  Main processing task to be run in gradio
16
  - Writes aris frames to dirname(filepath)/frames/{i}.jpg
@@ -54,15 +55,18 @@ def main(args, config={}):
54
 
55
  seq_list = os.listdir(in_loc_dir)
56
  idx = 1
57
- for seq in seq_list:
58
- print(" ")
59
- print("(" + str(idx) + "/" + str(len(seq_list)) + ") " + seq)
60
- print(" ")
61
- idx += 1
62
- in_seq_dir = os.path.join(in_loc_dir, seq)
63
- infer_seq(in_seq_dir, out_dir, config, seq, model, device, metadata_path)
64
-
65
- def infer_seq(in_dir, out_dir, config, seq_name, model, device, metadata_path):
 
 
 
66
 
67
  #progress_log = lambda p, m: 0
68
 
@@ -83,16 +87,16 @@ def infer_seq(in_dir, out_dir, config, seq_name, model, device, metadata_path):
83
  dataloader = create_dataloader_frames_only(in_dir)
84
 
85
  try:
86
- inference, width, height = do_detection(dataloader, model, device)
87
  except:
88
  print("Error in " + seq_name)
89
  with open(os.path.join(out_dir, "ERROR_" + seq_name + ".txt"), 'w') as f:
90
  f.write("ERROR")
91
  return
92
 
93
- all_preds, real_width, real_height = do_suppression(dataloader, inference, width, height, conf_thres=config['conf_threshold'], iou_thres=config['nms_iou'])
94
 
95
- results = do_tracking(all_preds, image_meter_width, image_meter_height, min_length=config['min_length'], max_age=config['max_age'], iou_thres=config['iou_threshold'], min_hits=config['min_hits'])
96
 
97
  mot_rows = []
98
  for frame in results['frames']:
 
8
  from inference import setup_model, do_suppression, do_detection, do_tracking, json_dump_round_float
9
  from visualizer import generate_video_batches
10
  import json
11
+ from tqdm import tqdm
12
 
13
 
14
+ def main(args, config={}, verbose=True):
15
  """
16
  Main processing task to be run in gradio
17
  - Writes aris frames to dirname(filepath)/frames/{i}.jpg
 
55
 
56
  seq_list = os.listdir(in_loc_dir)
57
  idx = 1
58
+ with tqdm(total=len(seq_list), desc="...", ncols=0) as pbar:
59
+ for seq in seq_list:
60
+ pbar.update(1)
61
+ pbar.set_description("Processing " + seq)
62
+ print(" ")
63
+ print("(" + str(idx) + "/" + str(len(seq_list)) + ") " + seq)
64
+ print(" ")
65
+ idx += 1
66
+ in_seq_dir = os.path.join(in_loc_dir, seq)
67
+ infer_seq(in_seq_dir, out_dir, config, seq, model, device, metadata_path, verbose)
68
+
69
+ def infer_seq(in_dir, out_dir, config, seq_name, model, device, metadata_path, verbose):
70
 
71
  #progress_log = lambda p, m: 0
72
 
 
87
  dataloader = create_dataloader_frames_only(in_dir)
88
 
89
  try:
90
+ inference, width, height = do_detection(dataloader, model, device, verbose=verbose)
91
  except:
92
  print("Error in " + seq_name)
93
  with open(os.path.join(out_dir, "ERROR_" + seq_name + ".txt"), 'w') as f:
94
  f.write("ERROR")
95
  return
96
 
97
+ all_preds, real_width, real_height = do_suppression(dataloader, inference, width, height, conf_thres=config['conf_threshold'], iou_thres=config['nms_iou'], verbose=verbose)
98
 
99
+ results = do_tracking(all_preds, image_meter_width, image_meter_height, min_length=config['min_length'], max_age=config['max_age'], iou_thres=config['iou_threshold'], min_hits=config['min_hits'], verbose=verbose)
100
 
101
  mot_rows = []
102
  for frame in results['frames']: