Spaces:
Runtime error
Runtime error
Commit
•
c9d11b2
1
Parent(s):
4b0fc85
inference
Browse files- inference.py +6 -6
- scripts/infer_frames.py +17 -13
inference.py
CHANGED
@@ -102,7 +102,7 @@ def setup_model(weights_fp=WEIGHTS, imgsz=896, batch_size=32):
|
|
102 |
|
103 |
return model, device
|
104 |
|
105 |
-
def do_detection(dataloader, model, device, gp=None, batch_size=BATCH_SIZE):
|
106 |
"""
|
107 |
Args:
|
108 |
frames_dir: a directory containing frames to be evaluated
|
@@ -115,7 +115,7 @@ def do_detection(dataloader, model, device, gp=None, batch_size=BATCH_SIZE):
|
|
115 |
|
116 |
inference = []
|
117 |
# Run detection
|
118 |
-
with tqdm(total=len(dataloader)*batch_size, desc="Running detection", ncols=0) as pbar:
|
119 |
for batch_i, (img, _, shapes) in enumerate(dataloader):
|
120 |
|
121 |
if gp: gp(batch_i / len(dataloader), pbar.__str__())
|
@@ -134,7 +134,7 @@ def do_detection(dataloader, model, device, gp=None, batch_size=BATCH_SIZE):
|
|
134 |
|
135 |
return inference, width, height
|
136 |
|
137 |
-
def do_suppression(dataloader, inference, width, height, gp=None, batch_size=BATCH_SIZE, conf_thres=CONF_THRES, iou_thres=NMS_IOU):
|
138 |
"""
|
139 |
Args:
|
140 |
frames_dir: a directory containing frames to be evaluated
|
@@ -147,7 +147,7 @@ def do_suppression(dataloader, inference, width, height, gp=None, batch_size=BAT
|
|
147 |
# keep predictions to feed them ordered into the Tracker
|
148 |
# TODO: how to deal with large files?
|
149 |
all_preds = {}
|
150 |
-
with tqdm(total=len(dataloader)*batch_size, desc="Running suppression", ncols=0) as pbar:
|
151 |
for batch_i, (img, _, shapes) in enumerate(dataloader):
|
152 |
|
153 |
if gp: gp(batch_i / len(dataloader), pbar.__str__())
|
@@ -180,7 +180,7 @@ def do_suppression(dataloader, inference, width, height, gp=None, batch_size=BAT
|
|
180 |
|
181 |
return all_preds, real_width, real_height
|
182 |
|
183 |
-
def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH):
|
184 |
|
185 |
if (gp): gp(0, "Tracking...")
|
186 |
|
@@ -194,7 +194,7 @@ def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None, max_a
|
|
194 |
tracker = Tracker(clip_info, args={ 'max_age': max_age, 'min_hits': 0, 'iou_threshold': iou_thres}, min_hits=min_hits)
|
195 |
|
196 |
# Run tracking
|
197 |
-
with tqdm(total=len(all_preds), desc="Running tracking", ncols=0) as pbar:
|
198 |
for i, key in enumerate(sorted(all_preds.keys())):
|
199 |
if gp: gp(i / len(all_preds), pbar.__str__())
|
200 |
boxes = all_preds[key]
|
|
|
102 |
|
103 |
return model, device
|
104 |
|
105 |
+
def do_detection(dataloader, model, device, gp=None, batch_size=BATCH_SIZE, verbose=True):
|
106 |
"""
|
107 |
Args:
|
108 |
frames_dir: a directory containing frames to be evaluated
|
|
|
115 |
|
116 |
inference = []
|
117 |
# Run detection
|
118 |
+
with tqdm(total=len(dataloader)*batch_size, desc="Running detection", ncols=0, disable=not verbose) as pbar:
|
119 |
for batch_i, (img, _, shapes) in enumerate(dataloader):
|
120 |
|
121 |
if gp: gp(batch_i / len(dataloader), pbar.__str__())
|
|
|
134 |
|
135 |
return inference, width, height
|
136 |
|
137 |
+
def do_suppression(dataloader, inference, width, height, gp=None, batch_size=BATCH_SIZE, conf_thres=CONF_THRES, iou_thres=NMS_IOU, verbose=True):
|
138 |
"""
|
139 |
Args:
|
140 |
frames_dir: a directory containing frames to be evaluated
|
|
|
147 |
# keep predictions to feed them ordered into the Tracker
|
148 |
# TODO: how to deal with large files?
|
149 |
all_preds = {}
|
150 |
+
with tqdm(total=len(dataloader)*batch_size, desc="Running suppression", ncols=0, disable=not verbose) as pbar:
|
151 |
for batch_i, (img, _, shapes) in enumerate(dataloader):
|
152 |
|
153 |
if gp: gp(batch_i / len(dataloader), pbar.__str__())
|
|
|
180 |
|
181 |
return all_preds, real_width, real_height
|
182 |
|
183 |
+
def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH, verbose=True):
|
184 |
|
185 |
if (gp): gp(0, "Tracking...")
|
186 |
|
|
|
194 |
tracker = Tracker(clip_info, args={ 'max_age': max_age, 'min_hits': 0, 'iou_threshold': iou_thres}, min_hits=min_hits)
|
195 |
|
196 |
# Run tracking
|
197 |
+
with tqdm(total=len(all_preds), desc="Running tracking", ncols=0, disable=not verbose) as pbar:
|
198 |
for i, key in enumerate(sorted(all_preds.keys())):
|
199 |
if gp: gp(i / len(all_preds), pbar.__str__())
|
200 |
boxes = all_preds[key]
|
scripts/infer_frames.py
CHANGED
@@ -8,9 +8,10 @@ from aris import create_manual_marking, create_metadata_dictionary, prep_for_mm
|
|
8 |
from inference import setup_model, do_suppression, do_detection, do_tracking, json_dump_round_float
|
9 |
from visualizer import generate_video_batches
|
10 |
import json
|
|
|
11 |
|
12 |
|
13 |
-
def main(args, config={}):
|
14 |
"""
|
15 |
Main processing task to be run in gradio
|
16 |
- Writes aris frames to dirname(filepath)/frames/{i}.jpg
|
@@ -54,15 +55,18 @@ def main(args, config={}):
|
|
54 |
|
55 |
seq_list = os.listdir(in_loc_dir)
|
56 |
idx = 1
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
66 |
|
67 |
#progress_log = lambda p, m: 0
|
68 |
|
@@ -83,16 +87,16 @@ def infer_seq(in_dir, out_dir, config, seq_name, model, device, metadata_path):
|
|
83 |
dataloader = create_dataloader_frames_only(in_dir)
|
84 |
|
85 |
try:
|
86 |
-
inference, width, height = do_detection(dataloader, model, device)
|
87 |
except:
|
88 |
print("Error in " + seq_name)
|
89 |
with open(os.path.join(out_dir, "ERROR_" + seq_name + ".txt"), 'w') as f:
|
90 |
f.write("ERROR")
|
91 |
return
|
92 |
|
93 |
-
all_preds, real_width, real_height = do_suppression(dataloader, inference, width, height, conf_thres=config['conf_threshold'], iou_thres=config['nms_iou'])
|
94 |
|
95 |
-
results = do_tracking(all_preds, image_meter_width, image_meter_height, min_length=config['min_length'], max_age=config['max_age'], iou_thres=config['iou_threshold'], min_hits=config['min_hits'])
|
96 |
|
97 |
mot_rows = []
|
98 |
for frame in results['frames']:
|
|
|
8 |
from inference import setup_model, do_suppression, do_detection, do_tracking, json_dump_round_float
|
9 |
from visualizer import generate_video_batches
|
10 |
import json
|
11 |
+
from tqdm import tqdm
|
12 |
|
13 |
|
14 |
+
def main(args, config={}, verbose=True):
|
15 |
"""
|
16 |
Main processing task to be run in gradio
|
17 |
- Writes aris frames to dirname(filepath)/frames/{i}.jpg
|
|
|
55 |
|
56 |
seq_list = os.listdir(in_loc_dir)
|
57 |
idx = 1
|
58 |
+
with tqdm(total=len(seq_list), desc="...", ncols=0) as pbar:
|
59 |
+
for seq in seq_list:
|
60 |
+
pbar.update(1)
|
61 |
+
pbar.set_description("Processing " + seq)
|
62 |
+
print(" ")
|
63 |
+
print("(" + str(idx) + "/" + str(len(seq_list)) + ") " + seq)
|
64 |
+
print(" ")
|
65 |
+
idx += 1
|
66 |
+
in_seq_dir = os.path.join(in_loc_dir, seq)
|
67 |
+
infer_seq(in_seq_dir, out_dir, config, seq, model, device, metadata_path, verbose)
|
68 |
+
|
69 |
+
def infer_seq(in_dir, out_dir, config, seq_name, model, device, metadata_path, verbose):
|
70 |
|
71 |
#progress_log = lambda p, m: 0
|
72 |
|
|
|
87 |
dataloader = create_dataloader_frames_only(in_dir)
|
88 |
|
89 |
try:
|
90 |
+
inference, width, height = do_detection(dataloader, model, device, verbose=verbose)
|
91 |
except:
|
92 |
print("Error in " + seq_name)
|
93 |
with open(os.path.join(out_dir, "ERROR_" + seq_name + ".txt"), 'w') as f:
|
94 |
f.write("ERROR")
|
95 |
return
|
96 |
|
97 |
+
all_preds, real_width, real_height = do_suppression(dataloader, inference, width, height, conf_thres=config['conf_threshold'], iou_thres=config['nms_iou'], verbose=verbose)
|
98 |
|
99 |
+
results = do_tracking(all_preds, image_meter_width, image_meter_height, min_length=config['min_length'], max_age=config['max_age'], iou_thres=config['iou_threshold'], min_hits=config['min_hits'], verbose=verbose)
|
100 |
|
101 |
mot_rows = []
|
102 |
for frame in results['frames']:
|