oskarastrom commited on
Commit
7e4e0ac
1 Parent(s): 73ba285

ByteTrack for UI

Browse files
app.py CHANGED
@@ -14,6 +14,8 @@ from gradio_scripts.result_ui import Result_Gradio, update_result, table_headers
14
  from dataloader import create_dataloader_aris
15
  from aris import BEAM_WIDTH_DIR
16
 
 
 
17
  #Initialize State & Result
18
  state = {
19
  'files': [],
@@ -27,24 +29,29 @@ result = {}
27
 
28
 
29
  # Called when an Aris file is uploaded for inference
30
- def on_aris_input(file_list, model_id, conf_thresh, iou_thresh, min_hits, max_age, use_associative, boost_power, boost_decay, min_length, min_travel):
31
 
32
  # Reset Result
33
  reset_state(result, state)
34
  state['files'] = file_list
35
  state['total'] = len(file_list)
 
36
  state['hyperparams'] = {
37
  'model': models[model_id] if model_id in models else models['master'],
38
  'conf_thresh': conf_thresh,
39
  'iou_thresh': iou_thresh,
40
  'min_hits': min_hits,
41
  'max_age': max_age,
42
- 'use_associative_tracking': use_associative,
43
- 'boost_power': boost_power,
44
- 'boost_decay': boost_decay,
45
  'min_length': min_length,
46
- 'min_travel': min_travel
 
47
  }
 
 
 
 
 
 
48
 
49
  print(" ")
50
  print("Running with:")
@@ -69,6 +76,7 @@ def on_result_upload(zip_list, aris_list):
69
 
70
 
71
  reset_state(result, state)
 
72
 
73
  component_updates = {
74
  master_tabs: gr.update(selected=1),
 
14
  from dataloader import create_dataloader_aris
15
  from aris import BEAM_WIDTH_DIR
16
 
17
+ WEBAPP_VERSION = "1.0"
18
+
19
  #Initialize State & Result
20
  state = {
21
  'files': [],
 
29
 
30
 
31
  # Called when an Aris file is uploaded for inference
32
+ def on_aris_input(file_list, model_id, conf_thresh, iou_thresh, min_hits, max_age, associative_tracker, boost_power, boost_decay, byte_low_conf, byte_high_conf, min_length, min_travel):
33
 
34
  # Reset Result
35
  reset_state(result, state)
36
  state['files'] = file_list
37
  state['total'] = len(file_list)
38
+ state['version'] = WEBAPP_VERSION
39
  state['hyperparams'] = {
40
  'model': models[model_id] if model_id in models else models['master'],
41
  'conf_thresh': conf_thresh,
42
  'iou_thresh': iou_thresh,
43
  'min_hits': min_hits,
44
  'max_age': max_age,
 
 
 
45
  'min_length': min_length,
46
+ 'min_travel': min_travel,
47
+ 'associative_tracker': associative_tracker,
48
  }
49
+ if (associative_tracker == "Confidence Boost"):
50
+ state['hyperparams']['boost_power'] = boost_power
51
+ state['hyperparams']['boost_decay'] = boost_decay
52
+ elif (associative_tracker == "ByteTrack"):
53
+ state['hyperparams']['byte_low_conf'] = byte_low_conf
54
+ state['hyperparams']['byte_high_conf'] = byte_high_conf
55
 
56
  print(" ")
57
  print("Running with:")
 
76
 
77
 
78
  reset_state(result, state)
79
+ state['version'] = WEBAPP_VERSION
80
 
81
  component_updates = {
82
  master_tabs: gr.update(selected=1),
aris.py CHANGED
@@ -441,6 +441,15 @@ def create_metadata_table(result, table_headers, info_headers):
441
  else:
442
  metadata = { 'FISH': [] }
443
 
 
 
 
 
 
 
 
 
 
444
  # Create fish table
445
  table = []
446
  for fish in metadata["FISH"]:
 
441
  else:
442
  metadata = { 'FISH': [] }
443
 
444
+ # Calculate detection dropout
445
+ for fish in metadata['FISH']:
446
+ count = 0
447
+ for frame in result['frames'][fish['START_FRAME']:fish['END_FRAME']+1]:
448
+ for ann in frame['fish']:
449
+ if ann['fish_id'] == fish['TOTAL']:
450
+ count += 1
451
+ fish['DETECTION_DROPOUT'] = 1 - count / (fish['END_FRAME'] + 1 - fish['START_FRAME'])
452
+
453
  # Create fish table
454
  table = []
455
  for fish in metadata["FISH"]:
gradio_scripts/pdf_handler.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import numpy as np
3
+ from matplotlib.backends.backend_pdf import PdfPages
4
+ from matplotlib import collections as mc
5
+ import matplotlib.pyplot as plt
6
+ import math
7
+ from aris import BEAM_WIDTH_DIR
8
+ import cv2
9
+
10
+ from dataloader import create_dataloader_aris
11
+
12
+
13
+ STANDARD_FIG_SIZE = (16, 9)
14
+ OUT_PDF_FILE_NAME = 'multipage_pdf.pdf'
15
+
16
+
17
+ def make_pdf(i, state, result, table_headers):
18
+
19
+ fish_info = result["fish_info"][i]
20
+ fish_table = result["fish_table"][i]
21
+ json_result = result['json_result'][i]
22
+ metadata = json_result['metadata']
23
+ aris_input = result["aris_input"][i]
24
+
25
+ with PdfPages(OUT_PDF_FILE_NAME) as pdf:
26
+ plt.rcParams['text.usetex'] = False
27
+
28
+ generate_title_page(pdf, metadata, state)
29
+
30
+ generate_global_result(pdf, fish_info)
31
+
32
+ generate_fish_list(pdf, table_headers, fish_table)
33
+
34
+
35
+ dataset = None
36
+ if (aris_input is not None):
37
+ dataloader, dataset = create_dataloader_aris(aris_input, BEAM_WIDTH_DIR, None)
38
+
39
+ for i, fish in enumerate(json_result['fish']):
40
+ calculate_fish_paths(json_result, dataset, i)
41
+
42
+ draw_combined_fish_graphs(pdf, json_result)
43
+
44
+ for i, fish in enumerate(json_result['fish']):
45
+ generate_fish_tracks(pdf, json_result, i)
46
+
47
+ # We can also set the file's metadata via the PdfPages object:
48
+ d = pdf.infodict()
49
+ d['Title'] = 'Multipage PDF Example'
50
+ d['Author'] = 'Oskar Åström'
51
+ d['Subject'] = 'How to create a multipage pdf file and set its metadata'
52
+ d['Keywords'] = ''
53
+ d['CreationDate'] = datetime.datetime.today()
54
+ d['ModDate'] = datetime.datetime.today()
55
+
56
+
57
+ def generate_title_page(pdf, metadata, state):
58
+ # set up figure that will be used to display the opening banner
59
+ fig = plt.figure(figsize=STANDARD_FIG_SIZE)
60
+ plt.axis('off')
61
+
62
+ title_font_size = 40
63
+ minor_font_size = 20
64
+
65
+ # stuff to be printed out on the first page of the report
66
+ plt.text(0.5,-0.5,f'{metadata["FILE_NAME"].split("/")[-1]}',fontsize=title_font_size, horizontalalignment='center')
67
+
68
+ plt.text(0,1,f'Duration: {metadata["TOTAL_TIME"]}',fontsize=minor_font_size)
69
+ plt.text(0,1.5,f'Frames: {metadata["TOTAL_FRAMES"]}',fontsize=minor_font_size)
70
+ plt.text(0,2,f'Frame Rate: {metadata["FRAME_RATE"]}',fontsize=minor_font_size)
71
+
72
+ plt.text(0.5,1,f'Time of filming: {metadata["DATE"]} ({metadata["START"]} - {metadata["END"]})',fontsize=minor_font_size)
73
+ plt.text(0.5,1.5,f'Web app version: {state["version"]}',fontsize=minor_font_size)
74
+
75
+ plt.text(1.1,4.5,f'PDF generated on {datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}',fontsize=minor_font_size, horizontalalignment='right')
76
+
77
+ plt.ylim([-1, 4])
78
+ plt.xlim([0, 1])
79
+ plt.gca().invert_yaxis()
80
+
81
+ pdf.savefig(fig)
82
+ plt.close(fig)
83
+
84
+ def generate_global_result(pdf, fish_info):
85
+ # set up figure that will be used to display the opening banner
86
+ fig = plt.figure(figsize=STANDARD_FIG_SIZE)
87
+ plt.axis('off')
88
+ # stuff to be printed out on the first page of the report
89
+
90
+ minor_font_size = 18
91
+
92
+ headers = ["Result", "Camera Info", "Hyperparameters"]
93
+ info_col_1 = []
94
+ info_col_2 = []
95
+ info_col = info_col_1
96
+ row_state = -1
97
+ for row in fish_info:
98
+ if row_state >= 0:
99
+ info_col.append([row[0].replace("**","").replace("_", " ").lower(), row[1], 'normal'])
100
+ if (row[0] == "****"):
101
+ row_state += 1
102
+ if row_state == 2: info_col = info_col_2
103
+ info_col.append([headers[row_state], "", 'bold'])
104
+ for row_i, row in enumerate(info_col_1):
105
+ h = -1 + 5*row_i/len(info_col_1)
106
+ plt.text(0, h, row[0], fontsize=minor_font_size, weight=row[2])
107
+ plt.text(0.25, h, row[1], fontsize=minor_font_size, weight=row[2])
108
+ for row_i, row in enumerate(info_col_2):
109
+ h = -1 + 5*row_i/len(info_col_2)
110
+ plt.text(0.5, h, row[0], fontsize=minor_font_size, weight=row[2])
111
+ plt.text(0.75, h, row[1], fontsize=minor_font_size, weight=row[2])
112
+ plt.ylim([-1, 4])
113
+ plt.xlim([0, 1])
114
+ plt.gca().invert_yaxis()
115
+
116
+ pdf.savefig(fig)
117
+ plt.close(fig)
118
+
119
+ def generate_fish_list(pdf, table_headers, fish_table):
120
+ # set up figure that will be used to display the opening banner
121
+ fig = plt.figure(figsize=STANDARD_FIG_SIZE)
122
+ plt.axis('off')
123
+ # stuff to be printed out on the first page of the report
124
+
125
+ title_font_size = 40
126
+ header_font_size = 12
127
+ body_font_size = 20
128
+
129
+ # Title
130
+ plt.text(0.5,-1.3,f'{"Identified Fish"}',fontsize=title_font_size, horizontalalignment='center', weight='bold')
131
+
132
+ # Identified fish
133
+ row_h = 0.25
134
+ col_start = 0
135
+ row_l = 1
136
+ dropout_i = None
137
+ for col_i, col in enumerate(table_headers):
138
+ x = col_start + row_l*(col_i+0.5)/len(table_headers)
139
+ if col == "TOTAL": col = "ID"
140
+ if col == "DETECTION_DROPOUT":
141
+ col = "frame drop rate"
142
+ dropout_i = col_i
143
+ col = col.lower().replace("_", " ")
144
+ plt.text(x, -1, col, fontsize=header_font_size, horizontalalignment='center', weight="bold")
145
+ plt.plot([col_start*2, -col_start*2 + row_l], [-1 + 0.05, -1 + 0.05], color='black')
146
+
147
+ for row_i, row in enumerate(fish_table):
148
+ y = -1 + (row_i+1)*row_h
149
+ plt.plot([col_start*2, -col_start*2 + row_l], [y+0.05, y+0.05], color='black')
150
+ for col_i, col in enumerate(row):
151
+ x = col_start + row_l*(col_i+0.5)/len(row)
152
+ if (col_i == dropout_i):
153
+ col = str(int(col*100)) + "%"
154
+ elif type(col) == float:
155
+ col = "{:.4f}".format(col)
156
+ plt.text(x, y, col, fontsize=body_font_size, horizontalalignment='center')
157
+ plt.ylim([-1, 4])
158
+ plt.xlim([0, 1])
159
+ plt.gca().invert_yaxis()
160
+
161
+ pdf.savefig(fig)
162
+ plt.close(fig)
163
+
164
+ def calculate_fish_paths(result, dataset, id):
165
+
166
+ fish = result['metadata']['FISH'][id]
167
+ start_frame = fish['START_FRAME']
168
+ end_frame = fish['END_FRAME']
169
+
170
+ # Extract base frame (first frame for that fish)
171
+ w, h = 1, 2
172
+ img = None
173
+ if (dataset is not None):
174
+
175
+ images = dataset.didson.load_frames(start_frame=start_frame, end_frame=start_frame+1)
176
+ img = images[0]
177
+
178
+ frame_height = 2
179
+ scale_factor = frame_height / h
180
+ h = frame_height
181
+ w = int(scale_factor*w)
182
+
183
+ fish['base_frame'] = img
184
+ fish['scaled_frame_size'] = (w, h)
185
+
186
+
187
+ # Find frames for this fish
188
+ bboxes = []
189
+ for frame in result['frames'][start_frame:end_frame+1]:
190
+ bbox = None
191
+ for ann in frame['fish']:
192
+ if ann['fish_id'] == id+1:
193
+ bbox = ann
194
+ bboxes.append(bbox)
195
+
196
+
197
+ # Calculate tracks through frames
198
+ missed = 0
199
+ X = []
200
+ Y = []
201
+ V = []
202
+ certainty = []
203
+ for bbox in bboxes:
204
+ if bbox is not None:
205
+
206
+ # Find fish centers
207
+ x = (bbox['bbox'][0] + bbox['bbox'][2])/2
208
+ y = (bbox['bbox'][1] + bbox['bbox'][3])/2
209
+
210
+ # Calculate velocity
211
+ v = None
212
+ if len(X) > 0:
213
+ last_x = X[-1]
214
+ last_y = Y[-1]
215
+ dx = result['image_meter_width']*(last_x - x)/(missed+1)
216
+ dy = result['image_meter_height']*(last_y - y)/(missed+1)
217
+ v = math.sqrt(dx*dx + dy*dy)
218
+
219
+ # Interpolate over missing frames
220
+ if missed > 0:
221
+ for i in range(missed):
222
+ p = (i+1)/(missed+1)
223
+ X.append(last_x*(1-p) + p*x)
224
+ Y.append(last_y*(1-p) + p*y)
225
+ V.append(v)
226
+ certainty.append(False)
227
+
228
+ # Append new track frame
229
+ X.append(x)
230
+ Y.append(y)
231
+ if v is not None: V.append(v)
232
+ certainty.append(True)
233
+ missed = 0
234
+ else:
235
+ missed += 1
236
+
237
+ fish['path'] = {
238
+ 'X': X,
239
+ 'Y': Y,
240
+ 'certainty': certainty,
241
+ 'V': V
242
+ }
243
+
244
+
245
+ def draw_combined_fish_graphs(pdf, result):
246
+ vel = []
247
+ log_vel = []
248
+ for fish in result['metadata']['FISH']:
249
+ vel += fish['path']['V']
250
+ log_vel += [math.log(v) for v in fish['path']['V']]
251
+
252
+ fig, axs = plt.subplots(2, 2, sharey=True, figsize=STANDARD_FIG_SIZE)
253
+ axs[0,0].hist(log_vel, bins=20)
254
+ axs[0,0].set_title('Fish Log-Velocities between frames')
255
+ axs[0,0].set_xlabel("Log-Velocity (log(m/frame))")
256
+ axs[0,1].hist(vel, bins=20)
257
+ axs[0,1].set_title('Fish Velocities between frames')
258
+ axs[0,1].set_xlabel("Velocity (m/frame)")
259
+
260
+ pdf.savefig(fig)
261
+ plt.close(fig)
262
+
263
+
264
+ def generate_fish_tracks(pdf, result, id):
265
+
266
+ fish = result['metadata']['FISH'][id]
267
+ start_frame = fish['START_FRAME']
268
+ end_frame = fish['END_FRAME']
269
+
270
+ fig, ax = plt.subplots(figsize=STANDARD_FIG_SIZE)
271
+ plt.axis('off')
272
+
273
+ w, h = fish['scaled_frame_size']
274
+ if (fish['base_frame'] is not None):
275
+ img = fish['base_frame']
276
+ img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
277
+ plt.imshow(img, extent=(0, h, 0, w), cmap=plt.colormaps['Greys'])
278
+
279
+ # Title
280
+ plt.text(h/2,1.1,f'Fish {id+1} (frames {start_frame}-{end_frame})',fontsize=40, color="red", horizontalalignment='center', zorder=5)
281
+
282
+ X = fish['path']['X']
283
+ Y = fish['path']['Y']
284
+ certainty = fish['path']['certainty']
285
+
286
+ plt.text(h*(1-Y[0]), w*(1-X[0]), "Start", fontsize=15, weight="bold")
287
+ plt.text(h*(1-Y[-1]), w*(1-X[-1]), "End", fontsize=15, weight="bold")
288
+
289
+ colors = []
290
+ for i in range(1, len(X)):
291
+
292
+ certain = certainty[i]
293
+ fully_certain = certain
294
+ half_certain = certain
295
+ if i > 0:
296
+ fully_certain &= certainty[i-1]
297
+ half_certain |= certainty[i-1]
298
+
299
+ #color = 'yellow' if certain else 'orangered'
300
+ #plt.plot(h*(1-y), w*(1-x), marker='o', markersize=3, color=color, zorder=3)
301
+ col = 'yellow' if fully_certain else ('darkorange' if half_certain else 'orangered')
302
+ colors.append(col)
303
+ ax.plot([h*(1-Y[i-1]), h*(1-Y[i])], [w*(1-X[i-1]), w*(1-X[i])], color=col, linewidth=1)
304
+
305
+ for i in range(1, len(X)):
306
+ ax.plot(h*(1-Y[i]), w*(1-X[i]), color=colors[i], marker='o', markersize=3)
307
+
308
+
309
+ plt.ylim([0, w])
310
+ plt.xlim([0, h])
311
+ pdf.savefig(fig)
312
+ plt.close(fig)
gradio_scripts/result_ui.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  import numpy as np
 
3
 
4
  js_update_tab_labels = """
5
  async () => {
@@ -14,7 +15,7 @@ js_update_tab_labels = """
14
  }
15
  """
16
 
17
- table_headers = ["TOTAL", "START_FRAME", "END_FRAME", "DIR", "R", "THETA", "L", "TRAVEL"]
18
  info_headers = [
19
  "TOTAL_TIME", "DATE", "START", "END", "FRAME_RATE", "",
20
  "TOTAL_FISH", "UPSTREAM_FISH", "DOWNSTREAM_FISH", "NONDIRECTIONAL_FISH", "",
@@ -38,6 +39,8 @@ def update_result(i, state, result, inference_handler):
38
 
39
  annotation_avaliable = not (result["aris_input"][i] == None)
40
 
 
 
41
  # Send update to UI, and to inference_handler to start next file inference
42
  return {
43
  zip_out: gr.update(value=result["path_zip"]),
 
1
  import gradio as gr
2
  import numpy as np
3
+ from gradio_scripts.pdf_handler import make_pdf
4
 
5
  js_update_tab_labels = """
6
  async () => {
 
15
  }
16
  """
17
 
18
+ table_headers = ["TOTAL", "START_FRAME", "END_FRAME", "DETECTION_DROPOUT", "DIR", "R", "THETA", "L", "TRAVEL"]
19
  info_headers = [
20
  "TOTAL_TIME", "DATE", "START", "END", "FRAME_RATE", "",
21
  "TOTAL_FISH", "UPSTREAM_FISH", "DOWNSTREAM_FISH", "NONDIRECTIONAL_FISH", "",
 
39
 
40
  annotation_avaliable = not (result["aris_input"][i] == None)
41
 
42
+ make_pdf(state['index']-1, state, result, table_headers)
43
+
44
  # Send update to UI, and to inference_handler to start next file inference
45
  return {
46
  zip_out: gr.update(value=result["path_zip"]),
gradio_scripts/upload_ui.py CHANGED
@@ -31,12 +31,16 @@ def Upload_Gradio(gradio_components):
31
  settings.append(gr.Slider(0, 100, value=16, label="Min Hits", info="Minimum number of frames a fish has to appear in to count"))
32
  settings.append(gr.Slider(0, 100, value=14, label="Max Age", info="Max age of occlusion before track is split"))
33
 
34
- with gr.Row():
35
- gr.Markdown("Associative Tracking")
36
- settings.append(gr.Checkbox(value=False, label="Enabled"))
37
- with gr.Row():
38
  settings.append(gr.Slider(0, 5, value=1, label="Boost Power", info=""))
39
  settings.append(gr.Slider(0, 1, value=1, label="Boost Decay", info=""))
 
 
 
 
 
40
 
41
  gr.Markdown("Other")
42
  with gr.Row():
 
31
  settings.append(gr.Slider(0, 100, value=16, label="Min Hits", info="Minimum number of frames a fish has to appear in to count"))
32
  settings.append(gr.Slider(0, 100, value=14, label="Max Age", info="Max age of occlusion before track is split"))
33
 
34
+ tracker = gr.Dropdown(["None", "Confidence Boost", "ByteTrack"], label="Associative Tracking", value="None")
35
+ settings.append(tracker)
36
+ with gr.Row(visible=False) as track_row:
 
37
  settings.append(gr.Slider(0, 5, value=1, label="Boost Power", info=""))
38
  settings.append(gr.Slider(0, 1, value=1, label="Boost Decay", info=""))
39
+ tracker.change(lambda x: gr.update(visible=(x=="Confidence Boost")), tracker, track_row)
40
+ with gr.Row(visible=False) as track_row:
41
+ settings.append(gr.Slider(0, 1, value=0.1, label="Low Conf Threshold", info=""))
42
+ settings.append(gr.Slider(0, 1, value=0.3, label="High Conf Threshold", info=""))
43
+ tracker.change(lambda x: gr.update(visible=(x=="ByteTrack")), tracker, track_row)
44
 
45
  gr.Markdown("Other")
46
  with gr.Row():
inference.py CHANGED
@@ -58,12 +58,9 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
58
  if 'iou_thresh' not in hyperparams: hyperparams['iou_thresh'] = NMS_IOU
59
  if 'min_hits' not in hyperparams: hyperparams['min_hits'] = MIN_HITS
60
  if 'max_age' not in hyperparams: hyperparams['max_age'] = MAX_AGE
61
- if 'use_associative_tracking' not in hyperparams: hyperparams['use_associative_tracking'] = False
62
- if 'boost_power' not in hyperparams: hyperparams['boost_power'] = 1
63
- if 'boost_decay' not in hyperparams: hyperparams['maxboost_decay_age'] = 1
64
- if 'AT_decay' not in hyperparams: hyperparams['AT_decay'] = MIN_HITS
65
  if 'min_length' not in hyperparams: hyperparams['min_length'] = MIN_LENGTH
66
  if 'min_travel' not in hyperparams: hyperparams['min_travel'] = MIN_TRAVEL
 
67
 
68
  model, device = setup_model(hyperparams['model'])
69
 
@@ -95,16 +92,41 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
95
 
96
  outputs = do_suppression(inference, conf_thres=hyperparams['conf_thresh'], iou_thres=hyperparams['iou_thresh'], gp=gp)
97
 
98
- if hyperparams['use_associative_tracking']:
 
 
99
 
100
- do_confidence_boost(inference, outputs, conf_power=hyperparams['boost_power'], conf_decay=hyperparams['boost_decay'], gp=gp)
 
101
 
102
- outputs = do_suppression(inference, conf_thres=hyperparams['conf_thresh'], iou_thres=hyperparams['iou_thresh'], gp=gp)
 
103
 
104
- all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
- results = do_tracking(all_preds, image_meter_width, image_meter_height, min_hits=hyperparams['min_hits'], max_age=hyperparams['max_age'], min_length=hyperparams['min_length'], min_travel=hyperparams['min_travel'], gp=gp)
107
 
 
 
 
 
 
 
108
  return results
109
 
110
 
 
58
  if 'iou_thresh' not in hyperparams: hyperparams['iou_thresh'] = NMS_IOU
59
  if 'min_hits' not in hyperparams: hyperparams['min_hits'] = MIN_HITS
60
  if 'max_age' not in hyperparams: hyperparams['max_age'] = MAX_AGE
 
 
 
 
61
  if 'min_length' not in hyperparams: hyperparams['min_length'] = MIN_LENGTH
62
  if 'min_travel' not in hyperparams: hyperparams['min_travel'] = MIN_TRAVEL
63
+ if 'associative_tracker' not in hyperparams: hyperparams['associative_tracker'] = "None"
64
 
65
  model, device = setup_model(hyperparams['model'])
66
 
 
92
 
93
  outputs = do_suppression(inference, conf_thres=hyperparams['conf_thresh'], iou_thres=hyperparams['iou_thresh'], gp=gp)
94
 
95
+ if hyperparams['associative_tracker'] == "ByteTrack":
96
+ if 'byte_low_conf' not in hyperparams: hyperparams['byte_low_conf'] = 0.1
97
+ if 'byte_high_conf' not in hyperparams: hyperparams['byte_high_conf'] = 0.3
98
 
99
+ low_outputs = do_suppression(inference, conf_thres=hyperparams['low_conf_threshold'], iou_thres=hyperparams['iou_thresh'], gp=gp)
100
+ low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, gp=gp)
101
 
102
+ high_outputs = do_suppression(inference, conf_thres=hyperparams['high_conf_threshold'], iou_thres=hyperparams['iou_thresh'], gp=gp)
103
+ high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, gp=gp)
104
 
105
+ results = do_associative_tracking(
106
+ low_preds, high_preds, image_meter_width, image_meter_height,
107
+ reverse=False, min_length=hyperparams['min_length'], min_travel=hyperparams['min_travel'],
108
+ max_age=hyperparams['max_age'], min_hits=hyperparams['min_hits'],
109
+ gp=gp)
110
+ else:
111
+
112
+ outputs = do_suppression(inference, conf_thres=hyperparams['conf_threshold'], iou_thres=hyperparams['iou_thresh'], gp=gp)
113
+
114
+ if hyperparams['associative_tracker'] == "Confidence Boost":
115
+ if 'boost_power' not in hyperparams: hyperparams['boost_power'] = 1
116
+ if 'boost_decay' not in hyperparams: hyperparams['boost_decay'] = 1
117
+
118
+ do_confidence_boost(inference, outputs, boost_power=hyperparams['boost_power'], boost_decay=hyperparams['boost_decay'], gp=gp)
119
+
120
+ outputs = do_suppression(inference, conf_thres=hyperparams['conf_threshold'], iou_thres=hyperparams['iou_thresh'], gp=gp)
121
 
122
+ all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
123
 
124
+ results = do_tracking(
125
+ all_preds, image_meter_width, image_meter_height,
126
+ min_length=hyperparams['min_length'], min_travel=hyperparams['min_travel'],
127
+ max_age=hyperparams['max_age'], iou_thres=hyperparams['iou_threshold'], min_hits=hyperparams['min_hits'],
128
+ gp=gp)
129
+
130
  return results
131
 
132
 
multipage_pdf.pdf ADDED
Binary file (940 kB). View file