Spaces:
Runtime error
Runtime error
oskarastrom
commited on
Commit
•
7e4e0ac
1
Parent(s):
73ba285
ByteTrack for UI
Browse files- app.py +13 -5
- aris.py +9 -0
- gradio_scripts/pdf_handler.py +312 -0
- gradio_scripts/result_ui.py +4 -1
- gradio_scripts/upload_ui.py +8 -4
- inference.py +31 -9
- multipage_pdf.pdf +0 -0
app.py
CHANGED
@@ -14,6 +14,8 @@ from gradio_scripts.result_ui import Result_Gradio, update_result, table_headers
|
|
14 |
from dataloader import create_dataloader_aris
|
15 |
from aris import BEAM_WIDTH_DIR
|
16 |
|
|
|
|
|
17 |
#Initialize State & Result
|
18 |
state = {
|
19 |
'files': [],
|
@@ -27,24 +29,29 @@ result = {}
|
|
27 |
|
28 |
|
29 |
# Called when an Aris file is uploaded for inference
|
30 |
-
def on_aris_input(file_list, model_id, conf_thresh, iou_thresh, min_hits, max_age,
|
31 |
|
32 |
# Reset Result
|
33 |
reset_state(result, state)
|
34 |
state['files'] = file_list
|
35 |
state['total'] = len(file_list)
|
|
|
36 |
state['hyperparams'] = {
|
37 |
'model': models[model_id] if model_id in models else models['master'],
|
38 |
'conf_thresh': conf_thresh,
|
39 |
'iou_thresh': iou_thresh,
|
40 |
'min_hits': min_hits,
|
41 |
'max_age': max_age,
|
42 |
-
'use_associative_tracking': use_associative,
|
43 |
-
'boost_power': boost_power,
|
44 |
-
'boost_decay': boost_decay,
|
45 |
'min_length': min_length,
|
46 |
-
'min_travel': min_travel
|
|
|
47 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
print(" ")
|
50 |
print("Running with:")
|
@@ -69,6 +76,7 @@ def on_result_upload(zip_list, aris_list):
|
|
69 |
|
70 |
|
71 |
reset_state(result, state)
|
|
|
72 |
|
73 |
component_updates = {
|
74 |
master_tabs: gr.update(selected=1),
|
|
|
14 |
from dataloader import create_dataloader_aris
|
15 |
from aris import BEAM_WIDTH_DIR
|
16 |
|
17 |
+
WEBAPP_VERSION = "1.0"
|
18 |
+
|
19 |
#Initialize State & Result
|
20 |
state = {
|
21 |
'files': [],
|
|
|
29 |
|
30 |
|
31 |
# Called when an Aris file is uploaded for inference
|
32 |
+
def on_aris_input(file_list, model_id, conf_thresh, iou_thresh, min_hits, max_age, associative_tracker, boost_power, boost_decay, byte_low_conf, byte_high_conf, min_length, min_travel):
|
33 |
|
34 |
# Reset Result
|
35 |
reset_state(result, state)
|
36 |
state['files'] = file_list
|
37 |
state['total'] = len(file_list)
|
38 |
+
state['version'] = WEBAPP_VERSION
|
39 |
state['hyperparams'] = {
|
40 |
'model': models[model_id] if model_id in models else models['master'],
|
41 |
'conf_thresh': conf_thresh,
|
42 |
'iou_thresh': iou_thresh,
|
43 |
'min_hits': min_hits,
|
44 |
'max_age': max_age,
|
|
|
|
|
|
|
45 |
'min_length': min_length,
|
46 |
+
'min_travel': min_travel,
|
47 |
+
'associative_tracker': associative_tracker,
|
48 |
}
|
49 |
+
if (associative_tracker == "Confidence Boost"):
|
50 |
+
state['hyperparams']['boost_power'] = boost_power
|
51 |
+
state['hyperparams']['boost_decay'] = boost_decay
|
52 |
+
elif (associative_tracker == "ByteTrack"):
|
53 |
+
state['hyperparams']['byte_low_conf'] = byte_low_conf
|
54 |
+
state['hyperparams']['byte_high_conf'] = byte_high_conf
|
55 |
|
56 |
print(" ")
|
57 |
print("Running with:")
|
|
|
76 |
|
77 |
|
78 |
reset_state(result, state)
|
79 |
+
state['version'] = WEBAPP_VERSION
|
80 |
|
81 |
component_updates = {
|
82 |
master_tabs: gr.update(selected=1),
|
aris.py
CHANGED
@@ -441,6 +441,15 @@ def create_metadata_table(result, table_headers, info_headers):
|
|
441 |
else:
|
442 |
metadata = { 'FISH': [] }
|
443 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
444 |
# Create fish table
|
445 |
table = []
|
446 |
for fish in metadata["FISH"]:
|
|
|
441 |
else:
|
442 |
metadata = { 'FISH': [] }
|
443 |
|
444 |
+
# Calculate detection dropout
|
445 |
+
for fish in metadata['FISH']:
|
446 |
+
count = 0
|
447 |
+
for frame in result['frames'][fish['START_FRAME']:fish['END_FRAME']+1]:
|
448 |
+
for ann in frame['fish']:
|
449 |
+
if ann['fish_id'] == fish['TOTAL']:
|
450 |
+
count += 1
|
451 |
+
fish['DETECTION_DROPOUT'] = 1 - count / (fish['END_FRAME'] + 1 - fish['START_FRAME'])
|
452 |
+
|
453 |
# Create fish table
|
454 |
table = []
|
455 |
for fish in metadata["FISH"]:
|
gradio_scripts/pdf_handler.py
ADDED
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import numpy as np
|
3 |
+
from matplotlib.backends.backend_pdf import PdfPages
|
4 |
+
from matplotlib import collections as mc
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import math
|
7 |
+
from aris import BEAM_WIDTH_DIR
|
8 |
+
import cv2
|
9 |
+
|
10 |
+
from dataloader import create_dataloader_aris
|
11 |
+
|
12 |
+
|
13 |
+
STANDARD_FIG_SIZE = (16, 9)
|
14 |
+
OUT_PDF_FILE_NAME = 'multipage_pdf.pdf'
|
15 |
+
|
16 |
+
|
17 |
+
def make_pdf(i, state, result, table_headers):
|
18 |
+
|
19 |
+
fish_info = result["fish_info"][i]
|
20 |
+
fish_table = result["fish_table"][i]
|
21 |
+
json_result = result['json_result'][i]
|
22 |
+
metadata = json_result['metadata']
|
23 |
+
aris_input = result["aris_input"][i]
|
24 |
+
|
25 |
+
with PdfPages(OUT_PDF_FILE_NAME) as pdf:
|
26 |
+
plt.rcParams['text.usetex'] = False
|
27 |
+
|
28 |
+
generate_title_page(pdf, metadata, state)
|
29 |
+
|
30 |
+
generate_global_result(pdf, fish_info)
|
31 |
+
|
32 |
+
generate_fish_list(pdf, table_headers, fish_table)
|
33 |
+
|
34 |
+
|
35 |
+
dataset = None
|
36 |
+
if (aris_input is not None):
|
37 |
+
dataloader, dataset = create_dataloader_aris(aris_input, BEAM_WIDTH_DIR, None)
|
38 |
+
|
39 |
+
for i, fish in enumerate(json_result['fish']):
|
40 |
+
calculate_fish_paths(json_result, dataset, i)
|
41 |
+
|
42 |
+
draw_combined_fish_graphs(pdf, json_result)
|
43 |
+
|
44 |
+
for i, fish in enumerate(json_result['fish']):
|
45 |
+
generate_fish_tracks(pdf, json_result, i)
|
46 |
+
|
47 |
+
# We can also set the file's metadata via the PdfPages object:
|
48 |
+
d = pdf.infodict()
|
49 |
+
d['Title'] = 'Multipage PDF Example'
|
50 |
+
d['Author'] = 'Oskar Åström'
|
51 |
+
d['Subject'] = 'How to create a multipage pdf file and set its metadata'
|
52 |
+
d['Keywords'] = ''
|
53 |
+
d['CreationDate'] = datetime.datetime.today()
|
54 |
+
d['ModDate'] = datetime.datetime.today()
|
55 |
+
|
56 |
+
|
57 |
+
def generate_title_page(pdf, metadata, state):
|
58 |
+
# set up figure that will be used to display the opening banner
|
59 |
+
fig = plt.figure(figsize=STANDARD_FIG_SIZE)
|
60 |
+
plt.axis('off')
|
61 |
+
|
62 |
+
title_font_size = 40
|
63 |
+
minor_font_size = 20
|
64 |
+
|
65 |
+
# stuff to be printed out on the first page of the report
|
66 |
+
plt.text(0.5,-0.5,f'{metadata["FILE_NAME"].split("/")[-1]}',fontsize=title_font_size, horizontalalignment='center')
|
67 |
+
|
68 |
+
plt.text(0,1,f'Duration: {metadata["TOTAL_TIME"]}',fontsize=minor_font_size)
|
69 |
+
plt.text(0,1.5,f'Frames: {metadata["TOTAL_FRAMES"]}',fontsize=minor_font_size)
|
70 |
+
plt.text(0,2,f'Frame Rate: {metadata["FRAME_RATE"]}',fontsize=minor_font_size)
|
71 |
+
|
72 |
+
plt.text(0.5,1,f'Time of filming: {metadata["DATE"]} ({metadata["START"]} - {metadata["END"]})',fontsize=minor_font_size)
|
73 |
+
plt.text(0.5,1.5,f'Web app version: {state["version"]}',fontsize=minor_font_size)
|
74 |
+
|
75 |
+
plt.text(1.1,4.5,f'PDF generated on {datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}',fontsize=minor_font_size, horizontalalignment='right')
|
76 |
+
|
77 |
+
plt.ylim([-1, 4])
|
78 |
+
plt.xlim([0, 1])
|
79 |
+
plt.gca().invert_yaxis()
|
80 |
+
|
81 |
+
pdf.savefig(fig)
|
82 |
+
plt.close(fig)
|
83 |
+
|
84 |
+
def generate_global_result(pdf, fish_info):
|
85 |
+
# set up figure that will be used to display the opening banner
|
86 |
+
fig = plt.figure(figsize=STANDARD_FIG_SIZE)
|
87 |
+
plt.axis('off')
|
88 |
+
# stuff to be printed out on the first page of the report
|
89 |
+
|
90 |
+
minor_font_size = 18
|
91 |
+
|
92 |
+
headers = ["Result", "Camera Info", "Hyperparameters"]
|
93 |
+
info_col_1 = []
|
94 |
+
info_col_2 = []
|
95 |
+
info_col = info_col_1
|
96 |
+
row_state = -1
|
97 |
+
for row in fish_info:
|
98 |
+
if row_state >= 0:
|
99 |
+
info_col.append([row[0].replace("**","").replace("_", " ").lower(), row[1], 'normal'])
|
100 |
+
if (row[0] == "****"):
|
101 |
+
row_state += 1
|
102 |
+
if row_state == 2: info_col = info_col_2
|
103 |
+
info_col.append([headers[row_state], "", 'bold'])
|
104 |
+
for row_i, row in enumerate(info_col_1):
|
105 |
+
h = -1 + 5*row_i/len(info_col_1)
|
106 |
+
plt.text(0, h, row[0], fontsize=minor_font_size, weight=row[2])
|
107 |
+
plt.text(0.25, h, row[1], fontsize=minor_font_size, weight=row[2])
|
108 |
+
for row_i, row in enumerate(info_col_2):
|
109 |
+
h = -1 + 5*row_i/len(info_col_2)
|
110 |
+
plt.text(0.5, h, row[0], fontsize=minor_font_size, weight=row[2])
|
111 |
+
plt.text(0.75, h, row[1], fontsize=minor_font_size, weight=row[2])
|
112 |
+
plt.ylim([-1, 4])
|
113 |
+
plt.xlim([0, 1])
|
114 |
+
plt.gca().invert_yaxis()
|
115 |
+
|
116 |
+
pdf.savefig(fig)
|
117 |
+
plt.close(fig)
|
118 |
+
|
119 |
+
def generate_fish_list(pdf, table_headers, fish_table):
|
120 |
+
# set up figure that will be used to display the opening banner
|
121 |
+
fig = plt.figure(figsize=STANDARD_FIG_SIZE)
|
122 |
+
plt.axis('off')
|
123 |
+
# stuff to be printed out on the first page of the report
|
124 |
+
|
125 |
+
title_font_size = 40
|
126 |
+
header_font_size = 12
|
127 |
+
body_font_size = 20
|
128 |
+
|
129 |
+
# Title
|
130 |
+
plt.text(0.5,-1.3,f'{"Identified Fish"}',fontsize=title_font_size, horizontalalignment='center', weight='bold')
|
131 |
+
|
132 |
+
# Identified fish
|
133 |
+
row_h = 0.25
|
134 |
+
col_start = 0
|
135 |
+
row_l = 1
|
136 |
+
dropout_i = None
|
137 |
+
for col_i, col in enumerate(table_headers):
|
138 |
+
x = col_start + row_l*(col_i+0.5)/len(table_headers)
|
139 |
+
if col == "TOTAL": col = "ID"
|
140 |
+
if col == "DETECTION_DROPOUT":
|
141 |
+
col = "frame drop rate"
|
142 |
+
dropout_i = col_i
|
143 |
+
col = col.lower().replace("_", " ")
|
144 |
+
plt.text(x, -1, col, fontsize=header_font_size, horizontalalignment='center', weight="bold")
|
145 |
+
plt.plot([col_start*2, -col_start*2 + row_l], [-1 + 0.05, -1 + 0.05], color='black')
|
146 |
+
|
147 |
+
for row_i, row in enumerate(fish_table):
|
148 |
+
y = -1 + (row_i+1)*row_h
|
149 |
+
plt.plot([col_start*2, -col_start*2 + row_l], [y+0.05, y+0.05], color='black')
|
150 |
+
for col_i, col in enumerate(row):
|
151 |
+
x = col_start + row_l*(col_i+0.5)/len(row)
|
152 |
+
if (col_i == dropout_i):
|
153 |
+
col = str(int(col*100)) + "%"
|
154 |
+
elif type(col) == float:
|
155 |
+
col = "{:.4f}".format(col)
|
156 |
+
plt.text(x, y, col, fontsize=body_font_size, horizontalalignment='center')
|
157 |
+
plt.ylim([-1, 4])
|
158 |
+
plt.xlim([0, 1])
|
159 |
+
plt.gca().invert_yaxis()
|
160 |
+
|
161 |
+
pdf.savefig(fig)
|
162 |
+
plt.close(fig)
|
163 |
+
|
164 |
+
def calculate_fish_paths(result, dataset, id):
|
165 |
+
|
166 |
+
fish = result['metadata']['FISH'][id]
|
167 |
+
start_frame = fish['START_FRAME']
|
168 |
+
end_frame = fish['END_FRAME']
|
169 |
+
|
170 |
+
# Extract base frame (first frame for that fish)
|
171 |
+
w, h = 1, 2
|
172 |
+
img = None
|
173 |
+
if (dataset is not None):
|
174 |
+
|
175 |
+
images = dataset.didson.load_frames(start_frame=start_frame, end_frame=start_frame+1)
|
176 |
+
img = images[0]
|
177 |
+
|
178 |
+
frame_height = 2
|
179 |
+
scale_factor = frame_height / h
|
180 |
+
h = frame_height
|
181 |
+
w = int(scale_factor*w)
|
182 |
+
|
183 |
+
fish['base_frame'] = img
|
184 |
+
fish['scaled_frame_size'] = (w, h)
|
185 |
+
|
186 |
+
|
187 |
+
# Find frames for this fish
|
188 |
+
bboxes = []
|
189 |
+
for frame in result['frames'][start_frame:end_frame+1]:
|
190 |
+
bbox = None
|
191 |
+
for ann in frame['fish']:
|
192 |
+
if ann['fish_id'] == id+1:
|
193 |
+
bbox = ann
|
194 |
+
bboxes.append(bbox)
|
195 |
+
|
196 |
+
|
197 |
+
# Calculate tracks through frames
|
198 |
+
missed = 0
|
199 |
+
X = []
|
200 |
+
Y = []
|
201 |
+
V = []
|
202 |
+
certainty = []
|
203 |
+
for bbox in bboxes:
|
204 |
+
if bbox is not None:
|
205 |
+
|
206 |
+
# Find fish centers
|
207 |
+
x = (bbox['bbox'][0] + bbox['bbox'][2])/2
|
208 |
+
y = (bbox['bbox'][1] + bbox['bbox'][3])/2
|
209 |
+
|
210 |
+
# Calculate velocity
|
211 |
+
v = None
|
212 |
+
if len(X) > 0:
|
213 |
+
last_x = X[-1]
|
214 |
+
last_y = Y[-1]
|
215 |
+
dx = result['image_meter_width']*(last_x - x)/(missed+1)
|
216 |
+
dy = result['image_meter_height']*(last_y - y)/(missed+1)
|
217 |
+
v = math.sqrt(dx*dx + dy*dy)
|
218 |
+
|
219 |
+
# Interpolate over missing frames
|
220 |
+
if missed > 0:
|
221 |
+
for i in range(missed):
|
222 |
+
p = (i+1)/(missed+1)
|
223 |
+
X.append(last_x*(1-p) + p*x)
|
224 |
+
Y.append(last_y*(1-p) + p*y)
|
225 |
+
V.append(v)
|
226 |
+
certainty.append(False)
|
227 |
+
|
228 |
+
# Append new track frame
|
229 |
+
X.append(x)
|
230 |
+
Y.append(y)
|
231 |
+
if v is not None: V.append(v)
|
232 |
+
certainty.append(True)
|
233 |
+
missed = 0
|
234 |
+
else:
|
235 |
+
missed += 1
|
236 |
+
|
237 |
+
fish['path'] = {
|
238 |
+
'X': X,
|
239 |
+
'Y': Y,
|
240 |
+
'certainty': certainty,
|
241 |
+
'V': V
|
242 |
+
}
|
243 |
+
|
244 |
+
|
245 |
+
def draw_combined_fish_graphs(pdf, result):
|
246 |
+
vel = []
|
247 |
+
log_vel = []
|
248 |
+
for fish in result['metadata']['FISH']:
|
249 |
+
vel += fish['path']['V']
|
250 |
+
log_vel += [math.log(v) for v in fish['path']['V']]
|
251 |
+
|
252 |
+
fig, axs = plt.subplots(2, 2, sharey=True, figsize=STANDARD_FIG_SIZE)
|
253 |
+
axs[0,0].hist(log_vel, bins=20)
|
254 |
+
axs[0,0].set_title('Fish Log-Velocities between frames')
|
255 |
+
axs[0,0].set_xlabel("Log-Velocity (log(m/frame))")
|
256 |
+
axs[0,1].hist(vel, bins=20)
|
257 |
+
axs[0,1].set_title('Fish Velocities between frames')
|
258 |
+
axs[0,1].set_xlabel("Velocity (m/frame)")
|
259 |
+
|
260 |
+
pdf.savefig(fig)
|
261 |
+
plt.close(fig)
|
262 |
+
|
263 |
+
|
264 |
+
def generate_fish_tracks(pdf, result, id):
|
265 |
+
|
266 |
+
fish = result['metadata']['FISH'][id]
|
267 |
+
start_frame = fish['START_FRAME']
|
268 |
+
end_frame = fish['END_FRAME']
|
269 |
+
|
270 |
+
fig, ax = plt.subplots(figsize=STANDARD_FIG_SIZE)
|
271 |
+
plt.axis('off')
|
272 |
+
|
273 |
+
w, h = fish['scaled_frame_size']
|
274 |
+
if (fish['base_frame'] is not None):
|
275 |
+
img = fish['base_frame']
|
276 |
+
img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
|
277 |
+
plt.imshow(img, extent=(0, h, 0, w), cmap=plt.colormaps['Greys'])
|
278 |
+
|
279 |
+
# Title
|
280 |
+
plt.text(h/2,1.1,f'Fish {id+1} (frames {start_frame}-{end_frame})',fontsize=40, color="red", horizontalalignment='center', zorder=5)
|
281 |
+
|
282 |
+
X = fish['path']['X']
|
283 |
+
Y = fish['path']['Y']
|
284 |
+
certainty = fish['path']['certainty']
|
285 |
+
|
286 |
+
plt.text(h*(1-Y[0]), w*(1-X[0]), "Start", fontsize=15, weight="bold")
|
287 |
+
plt.text(h*(1-Y[-1]), w*(1-X[-1]), "End", fontsize=15, weight="bold")
|
288 |
+
|
289 |
+
colors = []
|
290 |
+
for i in range(1, len(X)):
|
291 |
+
|
292 |
+
certain = certainty[i]
|
293 |
+
fully_certain = certain
|
294 |
+
half_certain = certain
|
295 |
+
if i > 0:
|
296 |
+
fully_certain &= certainty[i-1]
|
297 |
+
half_certain |= certainty[i-1]
|
298 |
+
|
299 |
+
#color = 'yellow' if certain else 'orangered'
|
300 |
+
#plt.plot(h*(1-y), w*(1-x), marker='o', markersize=3, color=color, zorder=3)
|
301 |
+
col = 'yellow' if fully_certain else ('darkorange' if half_certain else 'orangered')
|
302 |
+
colors.append(col)
|
303 |
+
ax.plot([h*(1-Y[i-1]), h*(1-Y[i])], [w*(1-X[i-1]), w*(1-X[i])], color=col, linewidth=1)
|
304 |
+
|
305 |
+
for i in range(1, len(X)):
|
306 |
+
ax.plot(h*(1-Y[i]), w*(1-X[i]), color=colors[i], marker='o', markersize=3)
|
307 |
+
|
308 |
+
|
309 |
+
plt.ylim([0, w])
|
310 |
+
plt.xlim([0, h])
|
311 |
+
pdf.savefig(fig)
|
312 |
+
plt.close(fig)
|
gradio_scripts/result_ui.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
|
|
3 |
|
4 |
js_update_tab_labels = """
|
5 |
async () => {
|
@@ -14,7 +15,7 @@ js_update_tab_labels = """
|
|
14 |
}
|
15 |
"""
|
16 |
|
17 |
-
table_headers = ["TOTAL", "START_FRAME", "END_FRAME", "DIR", "R", "THETA", "L", "TRAVEL"]
|
18 |
info_headers = [
|
19 |
"TOTAL_TIME", "DATE", "START", "END", "FRAME_RATE", "",
|
20 |
"TOTAL_FISH", "UPSTREAM_FISH", "DOWNSTREAM_FISH", "NONDIRECTIONAL_FISH", "",
|
@@ -38,6 +39,8 @@ def update_result(i, state, result, inference_handler):
|
|
38 |
|
39 |
annotation_avaliable = not (result["aris_input"][i] == None)
|
40 |
|
|
|
|
|
41 |
# Send update to UI, and to inference_handler to start next file inference
|
42 |
return {
|
43 |
zip_out: gr.update(value=result["path_zip"]),
|
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
3 |
+
from gradio_scripts.pdf_handler import make_pdf
|
4 |
|
5 |
js_update_tab_labels = """
|
6 |
async () => {
|
|
|
15 |
}
|
16 |
"""
|
17 |
|
18 |
+
table_headers = ["TOTAL", "START_FRAME", "END_FRAME", "DETECTION_DROPOUT", "DIR", "R", "THETA", "L", "TRAVEL"]
|
19 |
info_headers = [
|
20 |
"TOTAL_TIME", "DATE", "START", "END", "FRAME_RATE", "",
|
21 |
"TOTAL_FISH", "UPSTREAM_FISH", "DOWNSTREAM_FISH", "NONDIRECTIONAL_FISH", "",
|
|
|
39 |
|
40 |
annotation_avaliable = not (result["aris_input"][i] == None)
|
41 |
|
42 |
+
make_pdf(state['index']-1, state, result, table_headers)
|
43 |
+
|
44 |
# Send update to UI, and to inference_handler to start next file inference
|
45 |
return {
|
46 |
zip_out: gr.update(value=result["path_zip"]),
|
gradio_scripts/upload_ui.py
CHANGED
@@ -31,12 +31,16 @@ def Upload_Gradio(gradio_components):
|
|
31 |
settings.append(gr.Slider(0, 100, value=16, label="Min Hits", info="Minimum number of frames a fish has to appear in to count"))
|
32 |
settings.append(gr.Slider(0, 100, value=14, label="Max Age", info="Max age of occlusion before track is split"))
|
33 |
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
with gr.Row():
|
38 |
settings.append(gr.Slider(0, 5, value=1, label="Boost Power", info=""))
|
39 |
settings.append(gr.Slider(0, 1, value=1, label="Boost Decay", info=""))
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
gr.Markdown("Other")
|
42 |
with gr.Row():
|
|
|
31 |
settings.append(gr.Slider(0, 100, value=16, label="Min Hits", info="Minimum number of frames a fish has to appear in to count"))
|
32 |
settings.append(gr.Slider(0, 100, value=14, label="Max Age", info="Max age of occlusion before track is split"))
|
33 |
|
34 |
+
tracker = gr.Dropdown(["None", "Confidence Boost", "ByteTrack"], label="Associative Tracking", value="None")
|
35 |
+
settings.append(tracker)
|
36 |
+
with gr.Row(visible=False) as track_row:
|
|
|
37 |
settings.append(gr.Slider(0, 5, value=1, label="Boost Power", info=""))
|
38 |
settings.append(gr.Slider(0, 1, value=1, label="Boost Decay", info=""))
|
39 |
+
tracker.change(lambda x: gr.update(visible=(x=="Confidence Boost")), tracker, track_row)
|
40 |
+
with gr.Row(visible=False) as track_row:
|
41 |
+
settings.append(gr.Slider(0, 1, value=0.1, label="Low Conf Threshold", info=""))
|
42 |
+
settings.append(gr.Slider(0, 1, value=0.3, label="High Conf Threshold", info=""))
|
43 |
+
tracker.change(lambda x: gr.update(visible=(x=="ByteTrack")), tracker, track_row)
|
44 |
|
45 |
gr.Markdown("Other")
|
46 |
with gr.Row():
|
inference.py
CHANGED
@@ -58,12 +58,9 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
|
|
58 |
if 'iou_thresh' not in hyperparams: hyperparams['iou_thresh'] = NMS_IOU
|
59 |
if 'min_hits' not in hyperparams: hyperparams['min_hits'] = MIN_HITS
|
60 |
if 'max_age' not in hyperparams: hyperparams['max_age'] = MAX_AGE
|
61 |
-
if 'use_associative_tracking' not in hyperparams: hyperparams['use_associative_tracking'] = False
|
62 |
-
if 'boost_power' not in hyperparams: hyperparams['boost_power'] = 1
|
63 |
-
if 'boost_decay' not in hyperparams: hyperparams['maxboost_decay_age'] = 1
|
64 |
-
if 'AT_decay' not in hyperparams: hyperparams['AT_decay'] = MIN_HITS
|
65 |
if 'min_length' not in hyperparams: hyperparams['min_length'] = MIN_LENGTH
|
66 |
if 'min_travel' not in hyperparams: hyperparams['min_travel'] = MIN_TRAVEL
|
|
|
67 |
|
68 |
model, device = setup_model(hyperparams['model'])
|
69 |
|
@@ -95,16 +92,41 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
|
|
95 |
|
96 |
outputs = do_suppression(inference, conf_thres=hyperparams['conf_thresh'], iou_thres=hyperparams['iou_thresh'], gp=gp)
|
97 |
|
98 |
-
if hyperparams['
|
|
|
|
|
99 |
|
100 |
-
|
|
|
101 |
|
102 |
-
|
|
|
103 |
|
104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
-
|
107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
return results
|
109 |
|
110 |
|
|
|
58 |
if 'iou_thresh' not in hyperparams: hyperparams['iou_thresh'] = NMS_IOU
|
59 |
if 'min_hits' not in hyperparams: hyperparams['min_hits'] = MIN_HITS
|
60 |
if 'max_age' not in hyperparams: hyperparams['max_age'] = MAX_AGE
|
|
|
|
|
|
|
|
|
61 |
if 'min_length' not in hyperparams: hyperparams['min_length'] = MIN_LENGTH
|
62 |
if 'min_travel' not in hyperparams: hyperparams['min_travel'] = MIN_TRAVEL
|
63 |
+
if 'associative_tracker' not in hyperparams: hyperparams['associative_tracker'] = "None"
|
64 |
|
65 |
model, device = setup_model(hyperparams['model'])
|
66 |
|
|
|
92 |
|
93 |
outputs = do_suppression(inference, conf_thres=hyperparams['conf_thresh'], iou_thres=hyperparams['iou_thresh'], gp=gp)
|
94 |
|
95 |
+
if hyperparams['associative_tracker'] == "ByteTrack":
|
96 |
+
if 'byte_low_conf' not in hyperparams: hyperparams['byte_low_conf'] = 0.1
|
97 |
+
if 'byte_high_conf' not in hyperparams: hyperparams['byte_high_conf'] = 0.3
|
98 |
|
99 |
+
low_outputs = do_suppression(inference, conf_thres=hyperparams['low_conf_threshold'], iou_thres=hyperparams['iou_thresh'], gp=gp)
|
100 |
+
low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, gp=gp)
|
101 |
|
102 |
+
high_outputs = do_suppression(inference, conf_thres=hyperparams['high_conf_threshold'], iou_thres=hyperparams['iou_thresh'], gp=gp)
|
103 |
+
high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, gp=gp)
|
104 |
|
105 |
+
results = do_associative_tracking(
|
106 |
+
low_preds, high_preds, image_meter_width, image_meter_height,
|
107 |
+
reverse=False, min_length=hyperparams['min_length'], min_travel=hyperparams['min_travel'],
|
108 |
+
max_age=hyperparams['max_age'], min_hits=hyperparams['min_hits'],
|
109 |
+
gp=gp)
|
110 |
+
else:
|
111 |
+
|
112 |
+
outputs = do_suppression(inference, conf_thres=hyperparams['conf_threshold'], iou_thres=hyperparams['iou_thresh'], gp=gp)
|
113 |
+
|
114 |
+
if hyperparams['associative_tracker'] == "Confidence Boost":
|
115 |
+
if 'boost_power' not in hyperparams: hyperparams['boost_power'] = 1
|
116 |
+
if 'boost_decay' not in hyperparams: hyperparams['boost_decay'] = 1
|
117 |
+
|
118 |
+
do_confidence_boost(inference, outputs, boost_power=hyperparams['boost_power'], boost_decay=hyperparams['boost_decay'], gp=gp)
|
119 |
+
|
120 |
+
outputs = do_suppression(inference, conf_thres=hyperparams['conf_threshold'], iou_thres=hyperparams['iou_thresh'], gp=gp)
|
121 |
|
122 |
+
all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
|
123 |
|
124 |
+
results = do_tracking(
|
125 |
+
all_preds, image_meter_width, image_meter_height,
|
126 |
+
min_length=hyperparams['min_length'], min_travel=hyperparams['min_travel'],
|
127 |
+
max_age=hyperparams['max_age'], iou_thres=hyperparams['iou_threshold'], min_hits=hyperparams['min_hits'],
|
128 |
+
gp=gp)
|
129 |
+
|
130 |
return results
|
131 |
|
132 |
|
multipage_pdf.pdf
ADDED
Binary file (940 kB). View file
|
|