Spaces:
Runtime error
Runtime error
Commit
•
2482ba4
1
Parent(s):
d57b89b
Updated inference
Browse files- inference.py +6 -5
- scripts/infer_frames.py +19 -9
inference.py
CHANGED
@@ -24,8 +24,8 @@ WEIGHTS = 'models/v5m_896_300best.pt'
|
|
24 |
# will need to configure these based on GPU hardware
|
25 |
BATCH_SIZE = 32
|
26 |
|
27 |
-
conf_thres = 0.
|
28 |
-
iou_thres = 0.
|
29 |
min_length = 0.3 # minimum fish length, in meters
|
30 |
###
|
31 |
|
@@ -72,7 +72,7 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
|
|
72 |
return
|
73 |
|
74 |
|
75 |
-
all_preds = do_suppression(dataloader, inference, width, height, gp=gp)
|
76 |
|
77 |
results = do_tracking(all_preds, image_meter_width, image_meter_height, gp=gp)
|
78 |
|
@@ -122,7 +122,6 @@ def do_detection(dataloader, model, device, gp=None, batch_size=BATCH_SIZE):
|
|
122 |
size = tuple(img.shape)
|
123 |
nb, _, height, width = size # batch size, channels, height, width
|
124 |
|
125 |
-
print(nb, _, height, width)
|
126 |
# Run model & NMS
|
127 |
with torch.no_grad():
|
128 |
inf_out, _ = model(img, augment=False)
|
@@ -166,6 +165,8 @@ def do_suppression(dataloader, inference, width, height, gp=None, batch_size=BAT
|
|
166 |
# confidence score currently not used by tracker; set to 1.0
|
167 |
boxes = None
|
168 |
if box.shape[0]:
|
|
|
|
|
169 |
do_norm = partial(norm, w=shapes[si][0][1], h=shapes[si][0][0])
|
170 |
normed = list((map(do_norm, box[:, :4].tolist())))
|
171 |
boxes = np.stack([ [*bb, conf] for bb, conf in zip(normed, confs) ])
|
@@ -174,7 +175,7 @@ def do_suppression(dataloader, inference, width, height, gp=None, batch_size=BAT
|
|
174 |
|
175 |
pbar.update(1*batch_size)
|
176 |
|
177 |
-
return all_preds
|
178 |
|
179 |
def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None):
|
180 |
|
|
|
24 |
# will need to configure these based on GPU hardware
|
25 |
BATCH_SIZE = 32
|
26 |
|
27 |
+
conf_thres = 0.3 # detection
|
28 |
+
iou_thres = 0.3 # NMS IOU
|
29 |
min_length = 0.3 # minimum fish length, in meters
|
30 |
###
|
31 |
|
|
|
72 |
return
|
73 |
|
74 |
|
75 |
+
all_preds, real_width, real_height = do_suppression(dataloader, inference, width, height, gp=gp)
|
76 |
|
77 |
results = do_tracking(all_preds, image_meter_width, image_meter_height, gp=gp)
|
78 |
|
|
|
122 |
size = tuple(img.shape)
|
123 |
nb, _, height, width = size # batch size, channels, height, width
|
124 |
|
|
|
125 |
# Run model & NMS
|
126 |
with torch.no_grad():
|
127 |
inf_out, _ = model(img, augment=False)
|
|
|
165 |
# confidence score currently not used by tracker; set to 1.0
|
166 |
boxes = None
|
167 |
if box.shape[0]:
|
168 |
+
real_width = shapes[si][0][1]
|
169 |
+
real_height = shapes[si][0][0]
|
170 |
do_norm = partial(norm, w=shapes[si][0][1], h=shapes[si][0][0])
|
171 |
normed = list((map(do_norm, box[:, :4].tolist())))
|
172 |
boxes = np.stack([ [*bb, conf] for bb, conf in zip(normed, confs) ])
|
|
|
175 |
|
176 |
pbar.update(1*batch_size)
|
177 |
|
178 |
+
return all_preds, real_width, real_height
|
179 |
|
180 |
def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None):
|
181 |
|
scripts/infer_frames.py
CHANGED
@@ -8,7 +8,6 @@ from aris import create_manual_marking, create_metadata_dictionary, prep_for_mm
|
|
8 |
from inference import setup_model, do_suppression, do_detection, do_tracking, json_dump_round_float
|
9 |
from visualizer import generate_video_batches
|
10 |
import json
|
11 |
-
from tracker import Tracker
|
12 |
|
13 |
|
14 |
def main(args):
|
@@ -29,7 +28,7 @@ def main(args):
|
|
29 |
|
30 |
dirname = args.frames
|
31 |
|
32 |
-
locations = ["
|
33 |
for loc in locations:
|
34 |
|
35 |
in_loc_dir = os.path.join(dirname, loc)
|
@@ -73,9 +72,15 @@ def infer_seq(in_dir, out_dir, seq_name, weights, metadata_path):
|
|
73 |
# run detection + tracking
|
74 |
model, device = setup_model(weights)
|
75 |
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
|
78 |
-
all_preds = do_suppression(dataloader, inference, width, height, gp=gradio_progress)
|
79 |
|
80 |
results = do_tracking(all_preds, image_meter_width, image_meter_height, gp=gradio_progress)
|
81 |
|
@@ -84,12 +89,17 @@ def infer_seq(in_dir, out_dir, seq_name, weights, metadata_path):
|
|
84 |
for fish in frame['fish']:
|
85 |
bbox = fish['bbox']
|
86 |
row = []
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
88 |
row.append(str(fish['fish_id'] + 1))
|
89 |
-
row.append(str(int(
|
90 |
-
row.append(str(int(
|
91 |
-
row.append(str(int(
|
92 |
-
row.append(str(int(
|
93 |
row.append("-1")
|
94 |
row.append("-1")
|
95 |
row.append("-1")
|
|
|
8 |
from inference import setup_model, do_suppression, do_detection, do_tracking, json_dump_round_float
|
9 |
from visualizer import generate_video_batches
|
10 |
import json
|
|
|
11 |
|
12 |
|
13 |
def main(args):
|
|
|
28 |
|
29 |
dirname = args.frames
|
30 |
|
31 |
+
locations = ["test"]
|
32 |
for loc in locations:
|
33 |
|
34 |
in_loc_dir = os.path.join(dirname, loc)
|
|
|
72 |
# run detection + tracking
|
73 |
model, device = setup_model(weights)
|
74 |
|
75 |
+
try:
|
76 |
+
inference, width, height = do_detection(dataloader, model, device, gp=gradio_progress)
|
77 |
+
except:
|
78 |
+
print("Error in " + seq_name)
|
79 |
+
with open(os.path.join(out_dir, "ERROR_" + seq_name + ".txt"), 'w') as f:
|
80 |
+
f.write("ERROR")
|
81 |
+
return
|
82 |
|
83 |
+
all_preds, real_width, real_height = do_suppression(dataloader, inference, width, height, gp=gradio_progress)
|
84 |
|
85 |
results = do_tracking(all_preds, image_meter_width, image_meter_height, gp=gradio_progress)
|
86 |
|
|
|
89 |
for fish in frame['fish']:
|
90 |
bbox = fish['bbox']
|
91 |
row = []
|
92 |
+
right = bbox[0]*real_width
|
93 |
+
top = bbox[1]*real_height
|
94 |
+
w = bbox[2]*real_width - bbox[0]*real_width
|
95 |
+
h = bbox[3]*real_height - bbox[1]*real_height
|
96 |
+
|
97 |
+
row.append(str(frame['frame_num'] + 1))
|
98 |
row.append(str(fish['fish_id'] + 1))
|
99 |
+
row.append(str(int(right)))
|
100 |
+
row.append(str(int(top)))
|
101 |
+
row.append(str(int(w)))
|
102 |
+
row.append(str(int(h)))
|
103 |
row.append("-1")
|
104 |
row.append("-1")
|
105 |
row.append("-1")
|