sfmig committed on
Commit 596887e • 1 Parent(s): ffd4fc3

Refactoring: separated functions into library scripts; skip MegaDetector inference when only DLC is selected; do not download a model again if it already exists locally.

Files changed (6):
  1. DLC_models/download_utils.py +61 -0
  2. app.py +83 -310
  3. detection_utils.py +116 -0
  4. save_results.py +0 -56
  5. ui_utils.py +81 -0
  6. viz_utils.py +165 -0
DLC_models/download_utils.py ADDED
@@ -0,0 +1,61 @@
+ import urllib.request
+ import tarfile
+ from tqdm import tqdm
+ import os
+ import yaml
+ from ruamel.yaml import YAML
+
+ def read_plainconfig(configname):
+     if not os.path.exists(configname):
+         raise FileNotFoundError(
+             f"Config {configname} is not found. Please make sure that the file exists."
+         )
+     with open(configname) as file:
+         return YAML().load(file)
+
+ def DownloadModel(modelname,
+                   target_dir):
+     """
+     Downloads a DeepLabCut Model Zoo project.
+     """
+
+     def show_progress(count, block_size, total_size):
+         pbar.update(block_size)
+
+     def tarfilenamecutting(tarf):
+         """Auxiliary function to extract the parent folder path,
+         i.e. /xyz-trainsetxyshufflez/
+         """
+         for memberid, member in enumerate(tarf.getmembers()):
+             if memberid == 0:
+                 parent = str(member.path)
+                 l = len(parent) + 1
+             if member.path.startswith(parent):
+                 member.path = member.path[l:]
+                 yield member
+
+     neturls = read_plainconfig("DLC_models/pretrained_model_urls.yaml") #FIXME
+
+     if modelname in neturls.keys():
+         url = neturls[modelname]
+         print(url)
+         response = urllib.request.urlopen(url)
+         print(
+             "Downloading the model from the DeepLabCut server @Harvard -> Go Crimson!!! {}....".format(
+                 url
+             )
+         )
+         total_size = int(response.getheader("Content-Length"))
+         pbar = tqdm(unit="B", total=total_size, position=0)
+         filename, _ = urllib.request.urlretrieve(url, reporthook=show_progress)
+         with tarfile.open(filename, mode="r:gz") as tar:
+             tar.extractall(target_dir, members=tarfilenamecutting(tar))
+     else:
+         models = [
+             fn
+             for fn in neturls.keys()
+             if "resnet_" not in fn and "mobilenet_" not in fn
+         ]
+         print("Model does not exist: ", modelname)
+         print("Pick one of the following: ", models)
+     return target_dir
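
For reference, a minimal usage sketch of the helper above. The model key and target directory are illustrative (they match the DLC_models_dict in app.py below); the real keys live in DLC_models/pretrained_model_urls.yaml:

    from DLC_models.download_utils import DownloadModel

    # Downloads the 'full_cat' Model Zoo tarball and extracts it into the
    # target directory, stripping the tarball's top-level folder.
    # Returns the target directory path.
    model_dir = DownloadModel('full_cat', 'DLC_models/DLC_Cat/')
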
app.py CHANGED
@@ -2,215 +2,41 @@
 # Built from https://huggingface.co/spaces/sofmi/MegaDetector_DLClive/blob/main/app.py
 # Built from https://huggingface.co/spaces/Neslihan/megadetector_dlcmodels/blob/main/app.py

-
- from tkinter import W
- import gradio as gr
- from matplotlib import cm
- import torch
- import torchvision
- from dlclive import DLCLive, Processor
- import matplotlib
- from PIL import Image, ImageColor, ImageFont, ImageDraw
- # check git lfs pull!!
- import numpy as np
- import math
-
- # import json
 import os
 import yaml
- from model.models import DownloadModel
- from save_results import save_results
- import pdb
-
- #########################################
- # Input params
- FONTS = {'amiko': "font/Amiko-Regular.ttf",
-          'nature': "font/LoveNature.otf",
-          'painter': "font/PainterDecorator.otf",
-          'animals': "font/UncialAnimals.ttf",
-          'zen': "font/ZEN.TTF"}
-
- Megadet_Models = {'md_v5a': "megadet_model/md_v5a.0.0.pt",
-                   'md_v5b': "megadet_model/md_v5b.0.0.pt"}
-
- DLC_folders = {'full_cat': "model/DLC_Cat/",
-                'full_dog': "model/DLC_Dog/",
-                'primate_face': "model/DLC_FacialLandmarks/",
-                'full_human': "model/DLC_human_dancing/",
-                'full_macaque': 'model/DLC_monkey/'}
-
- DLC_models_list = ['full_cat', 'full_dog', 'primate_face', 'full_human', 'full_macaque']
- #########################################
- # Draw keypoints on image
- def draw_keypoints_on_image(image,
-                             keypoints,
-                             map_label_id_to_str,
-                             flag_show_str_labels,
-                             use_normalized_coordinates=True,
-                             font_style='amiko',
-                             font_size=8,
-                             keypt_color="#ff0000",
-                             marker_size=2,
-                             ):
-     """Draws keypoints on an image.
-     Modified from:
-     https://www.programcreek.com/python/?code=fjchange%2Fobject_centric_VAD%2Fobject_centric_VAD-master%2Fobject_detection%2Futils%2Fvisualization_utils.py
-     Args:
-       image: a PIL.Image object.
-       keypoints: a numpy array with shape [num_keypoints, 2].
-       map_label_id_to_str: dict with keys=label number and values=label string
-       flag_show_str_labels: boolean to select whether or not to show string labels
-       color: color to draw the keypoints with. Default is red.
-       radius: keypoint radius. Default value is 2.
-       use_normalized_coordinates: if True (default), treat keypoint values as
-         relative to the image. Otherwise treat them as absolute.
-     """
-     # get a drawing context
-     draw = ImageDraw.Draw(image, "RGBA")
-
-     im_width, im_height = image.size
-     keypoints_x = [k[0] for k in keypoints]
-     keypoints_y = [k[1] for k in keypoints]
-     alpha = [k[2] for k in keypoints]
-     norm = matplotlib.colors.Normalize(vmin=0, vmax=255)
-
-     names_for_color = [i for i in map_label_id_to_str.keys()]
-     colores = np.linspace(0, 255, num=len(names_for_color), dtype=int)
-
-     # adjust keypoints coords if required
-     if use_normalized_coordinates:
-         keypoints_x = tuple([im_width * x for x in keypoints_x])
-         keypoints_y = tuple([im_height * y for y in keypoints_y])
-
-     #cmap = matplotlib.cm.get_cmap('hsv')
-     cmap2 = matplotlib.cm.get_cmap('Greys')
-     # draw ellipses around keypoints
-     for i, (keypoint_x, keypoint_y) in enumerate(zip(keypoints_x, keypoints_y)):
-         round_fill = list(cm.viridis(norm(colores[i]), bytes=True))  #[round(num*255) for num in list(cmap(i))[:3]] #check!
-         if not np.isnan(alpha[i]):
-             round_fill[3] = round(alpha[i] * 255)
-         #print(round_fill)
-         #round_outline = [round(num*255) for num in list(cmap2(alpha[i]))[:3]]
-         draw.ellipse([(keypoint_x - marker_size, keypoint_y - marker_size),
-                       (keypoint_x + marker_size, keypoint_y + marker_size)],
-                      fill=tuple(round_fill), outline='black', width=1)  # fill and outline: [0,255]
-
-         # add string labels around keypoints
-         if flag_show_str_labels:
-             font = ImageFont.truetype(FONTS[font_style],
-                                       font_size)
-             draw.text((keypoint_x + marker_size, keypoint_y + marker_size),
-                       map_label_id_to_str[i],
-                       ImageColor.getcolor(keypt_color, "RGB"),  # rgb
-                       font=font)
-
- ############################################
- # Predict detections with MegaDetector v5a model
- def predict_md(im,
-                mega_model_input,
-                size=640):
-
-     # resize image
-     g = (size / max(im.size))  # multiplicative factor to make the max size of the image equal to the input size
-     im = im.resize((int(x * g) for x in im.size),
-                    Image.ANTIALIAS)  # resize
-     MD_model = torch.hub.load('ultralytics/yolov5', 'custom', Megadet_Models[mega_model_input])
-
-     ## detect objects
-     results = MD_model(im)  # inference; vars(results).keys() = dict_keys(['imgs', 'pred', 'names', 'files', 'times', 'xyxy', 'xywh', 'xyxyn', 'xywhn', 'n', 't', 's'])
-
-     return results
-
- ##########################################
- def crop_animal_detections(img_in,
-                            yolo_results,
-                            likelihood_th):
-
-     ## Extract animal crops
-     list_labels_as_str = [i for i in yolo_results.names.values()]  # ['animal', 'person', 'vehicle']
-     list_np_animal_crops = []
-
-     # image to crop (scale as input for megadetector)
-     img_in = img_in.resize((yolo_results.ims[0].shape[1],
-                             yolo_results.ims[0].shape[0]))
-     # for every detection in the img
-     for det_array in yolo_results.xyxy:
-
-         # for every detection
-         for j in range(det_array.shape[0]):
-
-             # compute coords around bbox rounded to the nearest integer (for pasting later)
-             xmin_rd = int(math.floor(det_array[j, 0]))  # int() should suffice?
-             ymin_rd = int(math.floor(det_array[j, 1]))
-
-             xmax_rd = int(math.ceil(det_array[j, 2]))
-             ymax_rd = int(math.ceil(det_array[j, 3]))
-
-             pred_llk = det_array[j, 4]
-             pred_label = det_array[j, 5]
-             # keep animal crops above threshold
-             if (pred_label == list_labels_as_str.index('animal')) and \
-                (pred_llk >= likelihood_th):
-                 area = (xmin_rd, ymin_rd, xmax_rd, ymax_rd)
-
-                 crop = img_in.crop(area)  # Image.fromarray(img_in).crop(area)
-                 crop_np = np.asarray(crop)
-
-                 # add to list
-                 list_np_animal_crops.append(crop_np)
-
-     return list_np_animal_crops
-
- def draw_rectangle_text(img, results, font_style='amiko', font_size=8, keypt_color="white"):
-     bbxyxy = results
-     w, h = bbxyxy[2], bbxyxy[3]
-     shape = [(bbxyxy[0], bbxyxy[1]), (w, h)]
-     imgR = ImageDraw.Draw(img)
-     imgR.rectangle(shape, outline="red", width=5)  ## bb for animal
-
-     confidence = bbxyxy[4]
-     string_bb = 'animal ' + str(round(confidence, 2))
-     font = ImageFont.truetype(FONTS[font_style], font_size)
-
-     text_size = font.getsize(string_bb)  # (h,w)
-     position = (bbxyxy[0], bbxyxy[1] - text_size[1] - 2)
-     left, top, right, bottom = imgR.textbbox(position, string_bb, font=font)
-     imgR.rectangle((left, top - 5, right + 5, bottom + 5), fill="red")
-     imgR.text((bbxyxy[0] + 3, bbxyxy[1] - text_size[1] - 2), string_bb, font=font, fill="black")
-
-     return imgR
-
- ##########################################
- def predict_dlc(list_np_crops,
-                 kpts_likelihood_th,
-                 DLCmodel,
-                 dlc_proc):
-
-     # run dlc thru list of crops
-     dlc_live = DLCLive(DLCmodel, processor=dlc_proc)
-     dlc_live.init_inference(list_np_crops[0])
-
-     list_kpts_per_crop = []
-     all_kypts = []
-     np_aux = np.empty((1, 3))  # can I avoid hardcoding here?
-     for crop in list_np_crops:
-         # scale crop here?
-         keypts_xyp = dlc_live.get_pose(crop)  # third column is llk!
-         # set kpts below threshold to nan
-         keypts_xyp[keypts_xyp[:, -1] < kpts_likelihood_th, :] = np_aux.fill(np.nan)
-         # add kpts of this crop to list
-         list_kpts_per_crop.append(keypts_xyp)
-         all_kypts.append(keypts_xyp)
-
-     return list_kpts_per_crop
+ import numpy as np
+ from matplotlib import cm
+ import gradio as gr
+
+ from PIL import Image, ImageColor, ImageFont, ImageDraw
+
+ from DLC_models.download_utils import DownloadModel
+ from dlclive import DLCLive, Processor
+
+ from viz_utils import save_results_as_json, draw_keypoints_on_image, draw_bbox_w_text
+ from detection_utils import predict_md, crop_animal_detections, predict_dlc
+ from ui_utils import gradio_inputs_for_MD_DLC, gradio_outputs_for_MD_DLC, gradio_description_and_examples
+
+ # import pdb
+ #########################################
+ # Input params - Global vars
+
+ MD_models_dict = {'md_v5a': "MD_models/md_v5a.0.0.pt",
+                   'md_v5b': "MD_models/md_v5b.0.0.pt"}
+
+ # DLC models target dirs
+ DLC_models_dict = {'full_cat': "DLC_models/DLC_Cat/",
+                    'full_dog': "DLC_models/DLC_Dog/",
+                    'primate_face': "DLC_models/DLC_FacialLandmarks/",
+                    'full_human': "DLC_models/DLC_human_dancing/",
+                    'full_macaque': 'DLC_models/DLC_monkey/'}
+
+ # FONTS = {'amiko': "fonts/Amiko-Regular.ttf",
+ #          'nature': "fonts/LoveNature.otf",
+ #          'painter': "fonts/PainterDecorator.otf",
+ #          'animals': "fonts/UncialAnimals.ttf",
+ #          'zen': "fonts/ZEN.TTF"}
 #####################################################
 def predict_pipeline(img_input,
                      mega_model_input,
@@ -225,36 +51,41 @@ def predict_pipeline(img_input,
                      marker_size,
                      ):

+     if not flag_dlc_only:
+         ############################################################
+         # ### Run Megadetector
+         md_results = predict_md(img_input,
+                                 MD_models_dict[mega_model_input],  # mega_model_input,
+                                 size=640)  # Image.fromarray(results.imgs[0])
+
+         ################################################################
+         # Obtain animal crops for bboxes with confidence above th
+         list_crops = crop_animal_detections(img_input,
+                                             md_results,
+                                             bbox_likelihood_th)
+
     ############################################################
-     ## Get DLC model and labels as strings
-     # TODO: make a dict as for megadetector
-     # pdb.set_trace()
-     path_to_DLCmodel = DownloadModel(dlc_model_input_str, str(DLC_folders[dlc_model_input_str]))
-     pose_cfg_path = str(DLC_folders[dlc_model_input_str]) + 'pose_cfg.yaml'
-     # pdb.set_trace()
+     ## Get DLC model and label map
+
+     # If model is found: do not download (previous execution is likely within same day)
+     # TODO: can we ask the user whether to reload dlc model if a directory is found?
+     if os.path.isdir(DLC_models_dict[dlc_model_input_str]) and \
+        len(os.listdir(DLC_models_dict[dlc_model_input_str])) > 0:
+         path_to_DLCmodel = DLC_models_dict[dlc_model_input_str]
+     else:
+         path_to_DLCmodel = DownloadModel(dlc_model_input_str,
+                                          DLC_models_dict[dlc_model_input_str])
+
     # extract map label ids to strings
-     # pose_cfg_dict['all_joints'] is a list of one-element lists,
+     pose_cfg_path = os.path.join(DLC_models_dict[dlc_model_input_str],
+                                  'pose_cfg.yaml')
     with open(pose_cfg_path, "r") as stream:
         pose_cfg_dict = yaml.safe_load(stream)
-
-     map_label_id_to_str = dict([(k, v) for k, v in zip([el[0] for el in pose_cfg_dict['all_joints']],
+     map_label_id_to_str = dict([(k, v) for k, v in zip([el[0] for el in pose_cfg_dict['all_joints']],  # pose_cfg_dict['all_joints'] is a list of one-element lists,
                                                         pose_cfg_dict['all_joints_names'])])

-     ############################################################
-     # ### Run Megadetector
-     md_results = predict_md(img_input,
-                             mega_model_input,
-                             size=640)  # Image.fromarray(results.imgs[0])
-     # pdb.set_trace()
-     ################################################################
-     # Obtain animal crops for bboxes with confidence above th
-
-     list_crops = crop_animal_detections(img_input,
-                                         md_results,
-                                         bbox_likelihood_th)
-
     ##############################################################
-     # Run DLC
+     # Run DLC and visualise results
     dlc_proc = Processor()

     # if required: ignore MD crops and run DLC on full image [mostly for testing]
@@ -284,14 +115,17 @@ def predict_pipeline(img_input,
                                 path_to_DLCmodel,
                                 dlc_proc)

-     img_background = img_input.resize((md_results.ims[0].shape[1], md_results.ims[0].shape[0]))
-     print('I have ' + str(len(list_crops)) + ' bounding box')
+     # resize input image to match megadetector output
+     img_background = img_input.resize((md_results.ims[0].shape[1],
+                                        md_results.ims[0].shape[0]))

+     # draw keypoints on each crop and paste to background img
     for ic, (np_crop, kpts_crop) in enumerate(zip(list_crops,
                                                   list_kpts_per_crop)):

-         ## Draw keypts on crop
         img_crop = Image.fromarray(np_crop)
+
+         # Draw keypts on crop
         draw_keypoints_on_image(img_crop,
                                 kpts_crop,  # a numpy array with shape [num_keypoints, 2].
                                 map_label_id_to_str,
@@ -302,106 +136,45 @@ def predict_pipeline(img_input,
                                 keypt_color=keypt_color,
                                 marker_size=marker_size)

-         ## Paste crop in original image https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.paste
-         img_background.paste(img_crop, box=tuple([int(t) for t in md_results.xyxy[0][ic, :2]]))
-
+         # Paste crop in original image
+         img_background.paste(img_crop,
+                              box=tuple([int(t) for t in md_results.xyxy[0][ic, :2]]))

+         # Plot bbox
         bb_per_animal = md_results.xyxy[0].tolist()[ic]
         pred = md_results.xyxy[0].tolist()[ic][4]
         if bbox_likelihood_th < pred:
-             draw_rectangle_text(img_background, bb_per_animal, font_style=font_style, font_size=font_size, keypt_color=keypt_color)
+             draw_bbox_w_text(img_background,
+                              bb_per_animal,
+                              font_style=font_style,
+                              font_size=font_size)  # TODO: add selectable color for bbox?

-     download_file = save_results(md_results, list_kpts_per_crop, map_label_id_to_str, bbox_likelihood_th)
+     # Save detection results as json
+     download_file = save_results_as_json(md_results,
+                                          list_kpts_per_crop,
+                                          map_label_id_to_str,
+                                          bbox_likelihood_th)

     return img_background, download_file

- #############################################
- # User interface: inputs
- # Input image
- gr_image_input = gr.inputs.Image(type="pil", label="Input Image")
-
- # Models
- gr_dlc_model_input = gr.inputs.Dropdown(choices=list(DLC_models_list),  # choices
-                                         default='full_cat',  # default option
-                                         type='value',  # "value" returns the string of the choice selected; "index" returns its index.
-                                         label='Select DeepLabCut model')
- gr_mega_model_input = gr.inputs.Dropdown(choices=list(Megadet_Models.keys()),
-                                          default='md_v5a',  # default option
-                                          type='value',
-                                          label='Select MegaDetector model')
- # Other inputs
- gr_dlc_only_checkbox = gr.inputs.Checkbox(False,
-                                           label='Run DLClive only, directly on input image?')
- gr_str_labels_checkbox = gr.inputs.Checkbox(True,
-                                             label='Show bodypart labels?')
-
- gr_slider_conf_bboxes = gr.inputs.Slider(0, 1, .02, 0.8,
-                                          label='Set confidence threshold for animal detections')
- gr_slider_conf_keypoints = gr.inputs.Slider(0, 1, .05, 0,
-                                             label='Set confidence threshold for keypoints')
-
- # Data viz
- gr_keypt_color = gr.ColorPicker(label="choose color for keypoint label")
-
- gr_labels_font_style = gr.inputs.Dropdown(choices=['amiko', 'nature', 'painter', 'animals', 'zen'],
-                                           default='amiko',
-                                           type='value',
-                                           label='Select keypoint label font')
- gr_slider_font_size = gr.inputs.Slider(5, 30, 1, 8,
-                                        label='Set font size')
- gr_slider_marker_size = gr.inputs.Slider(1, 20, 1, 5,
-                                          label='Set marker size')
-
- # list of inputs
- inputs = [gr_image_input,
-           gr_mega_model_input,
-           gr_dlc_model_input,
-           gr_dlc_only_checkbox,
-           gr_str_labels_checkbox,
-           gr_slider_conf_bboxes,
-           gr_slider_conf_keypoints,
-           gr_labels_font_style,
-           gr_slider_font_size,
-           gr_keypt_color,
-           gr_slider_marker_size,
-           ]
- ####################################################
- # %%
- # User interface: outputs
- gr_image_output = gr.outputs.Image(type="pil", label="Output Image")
- out_smpl_npy_download = gr.File(label="Download JSON file")
- outputs = [gr_image_output, out_smpl_npy_download]
-
- ##############################################
- # User interface: description
- gr_title = "MegaDetector v5 + DeepLabCut-Live!"
- gr_description = "Contributed by Sofia Minano, Neslihan Wittek, Nirel Kadzo, VicShaoChih Chiang, Sabrina Benas -- DLC AI Residents 2022.\
-     This App detects and estimates the pose of animals in camera trap images using <a href='https://github.com/microsoft/CameraTraps'>MegaDetector v5a</a> + <a href='https://github.com/DeepLabCut/DeepLabCut-live'>DeepLabCut-live</a>. \
-     We host models from the <a href='http://www.mackenziemathislab.org/dlc-modelzoo'>DeepLabCut ModelZoo Project</a>, and two <a href='https://github.com/microsoft/CameraTraps/blob/main/megadetector.md'>MegaDetector Models</a>. Please carefully check their licensing information if you use this project. The App additionally builds upon work from <a href='https://huggingface.co/spaces/hlydecker/MegaDetector_v5'>hlydecker/MegaDetector_v5</a>, \
-     <a href='https://huggingface.co/spaces/sofmi/MegaDetector_DLClive'>sofmi/MegaDetector_DLClive</a>, \
-     <a href='https://huggingface.co/spaces/Neslihan/megadetector_dlcmodels'>Neslihan/megadetector_dlcmodels</a>."
-
- # article = "<p style='text-align: center'>This app makes predictions using a YOLOv5x6 model that was trained to detect animals, humans, and vehicles in camera trap images; find out more about the project on <a href='https://github.com/microsoft/CameraTraps'>GitHub</a>. This app was built by Henry Lydecker but really depends on code and models developed by <a href='http://ecologize.org/'>Ecologize</a> and <a href='http://aka.ms/aiforearth'>Microsoft AI for Earth</a>. Find out more about the YOLO model from the original creator, <a href='https://pjreddie.com/darknet/yolo/'>Joseph Redmon</a>. YOLOv5 is a family of compound-scaled object detection models trained on the COCO dataset and developed by Ultralytics, and includes simple functionality for Test Time Augmentation (TTA), model ensembling, hyperparameter evolution, and export to ONNX, CoreML and TFLite. <a href='https://github.com/ultralytics/yolov5'>Source code</a> | <a href='https://pytorch.org/hub/ultralytics_yolov5'>PyTorch Hub</a></p>"
-
- examples = [['example/monkey_full.jpg', 'md_v5a', 'full_macaque', False, True, 0.5, 0.3, 'amiko', 9, 'blue', 3],
-             ['example/dog.jpeg', 'md_v5a', 'full_dog', False, True, 0.5, 0.00, 'amiko', 9, 'yellow', 3],
-             ['example/cat.jpg', 'md_v5a', 'full_cat', False, True, 0.5, 0.05, 'amiko', 9, 'purple', 3]]
-
- ################################################
- # %% Define and launch gradio interface
+ #########################################################
+ # Define user interface and launch
+ inputs = gradio_inputs_for_MD_DLC(list(MD_models_dict.keys()),
+                                   list(DLC_models_dict.keys()))
+ outputs = gradio_outputs_for_MD_DLC()
+ [gr_title,
+  gr_description,
+  examples] = gradio_description_and_examples()
+
+ # launch
 demo = gr.Interface(predict_pipeline,
                     inputs=inputs,
                     outputs=outputs,
                     title=gr_title,
                     description=gr_description,
                     examples=examples,
-                     theme="huggingface",
-                     #live=True
-                     )
+                     theme="huggingface")

 demo.launch(enable_queue=True, share=True)

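In outline, the control flow this commit introduces in predict_pipeline looks as follows. This is a condensed sketch, not the verbatim function; the DLC-only branch is an assumption based on the "# if required: ignore MD crops and run DLC on full image" context line, and the helper name pipeline_outline is hypothetical:

    import numpy as np
    from detection_utils import predict_md, crop_animal_detections

    def pipeline_outline(img_input, flag_dlc_only, mega_model_path, bbox_th):
        # MegaDetector inference is skipped entirely when only DLC is requested
        if not flag_dlc_only:
            md_results = predict_md(img_input, mega_model_path, size=640)
            list_crops = crop_animal_detections(img_input, md_results, bbox_th)
        else:
            # assumption: DLC-only mode treats the full image as a single crop
            list_crops = [np.asarray(img_input)]
        return list_crops
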
detection_utils.py ADDED
@@ -0,0 +1,116 @@
+
+ from tkinter import W
+ import gradio as gr
+ from matplotlib import cm
+ import torch
+ import torchvision
+ from dlclive import DLCLive, Processor
+ import matplotlib
+ from PIL import Image, ImageColor, ImageFont, ImageDraw
+ import numpy as np
+ import math
+
+ import yaml
+ import pdb
+
+ ############################################
+ # Predict detections with MegaDetector v5a model
+ def predict_md(im,
+                megadetector_model,  # Megadet_Models[mega_model_input]
+                size=640):
+
+     # resize image
+     g = (size / max(im.size))  # multiplicative factor to make the max size of the image equal to the input size
+     im = im.resize((int(x * g) for x in im.size),
+                    Image.ANTIALIAS)  # resize
+     # device
+     if torch.cuda.is_available():
+         md_device = torch.device('cuda')
+     else:
+         md_device = torch.device('cpu')
+
+     # megadetector
+     MD_model = torch.hub.load('ultralytics/yolov5',  # repo_or_dir
+                               'custom',  # model
+                               megadetector_model,  # args for callable model
+                               force_reload=True,
+                               device=md_device)
+
+     # send model to gpu if possible
+     if md_device == torch.device('cuda'):
+         print('Sending model to GPU')
+         MD_model.to(md_device)
+
+     ## detect objects
+     results = MD_model(im)  # inference; vars(results).keys() = dict_keys(['imgs', 'pred', 'names', 'files', 'times', 'xyxy', 'xywh', 'xyxyn', 'xywhn', 'n', 't', 's'])
+
+     return results
+
+ ##########################################
+ def crop_animal_detections(img_in,
+                            yolo_results,
+                            likelihood_th):
+
+     ## Extract animal crops
+     list_labels_as_str = [i for i in yolo_results.names.values()]  # ['animal', 'person', 'vehicle']
+     list_np_animal_crops = []
+
+     # image to crop (scale as input for megadetector)
+     img_in = img_in.resize((yolo_results.ims[0].shape[1],
+                             yolo_results.ims[0].shape[0]))
+     # for every detection in the img
+     for det_array in yolo_results.xyxy:
+
+         # for every detection
+         for j in range(det_array.shape[0]):
+
+             # compute coords around bbox rounded to the nearest integer (for pasting later)
+             xmin_rd = int(math.floor(det_array[j, 0]))  # int() should suffice?
+             ymin_rd = int(math.floor(det_array[j, 1]))
+
+             xmax_rd = int(math.ceil(det_array[j, 2]))
+             ymax_rd = int(math.ceil(det_array[j, 3]))
+
+             pred_llk = det_array[j, 4]
+             pred_label = det_array[j, 5]
+             # keep animal crops above threshold
+             if (pred_label == list_labels_as_str.index('animal')) and \
+                (pred_llk >= likelihood_th):
+                 area = (xmin_rd, ymin_rd, xmax_rd, ymax_rd)
+
+                 crop = img_in.crop(area)  # Image.fromarray(img_in).crop(area)
+                 crop_np = np.asarray(crop)
+
+                 # add to list
+                 list_np_animal_crops.append(crop_np)
+
+     return list_np_animal_crops
+
+ ##########################################
+ def predict_dlc(list_np_crops,
+                 kpts_likelihood_th,
+                 DLCmodel,
+                 dlc_proc):
+
+     # run dlc thru list of crops
+     dlc_live = DLCLive(DLCmodel, processor=dlc_proc)
+     dlc_live.init_inference(list_np_crops[0])
+
+     list_kpts_per_crop = []
+     all_kypts = []
+     np_aux = np.empty((1, 3))  # can I avoid hardcoding here?
+     for crop in list_np_crops:
+         # scale crop here?
+         keypts_xyp = dlc_live.get_pose(crop)  # third column is llk!
+         # set kpts below threshold to nan
+         keypts_xyp[keypts_xyp[:, -1] < kpts_likelihood_th, :] = np_aux.fill(np.nan)
+         # add kpts of this crop to list
+         list_kpts_per_crop.append(keypts_xyp)
+         all_kypts.append(keypts_xyp)
+
+     return list_kpts_per_crop
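
Taken together, a minimal sketch of how these three helpers chain (paths and thresholds are illustrative; the functions are the ones defined above, and at least one animal detection above the bbox threshold is assumed so that predict_dlc has a crop to work with):

    from PIL import Image
    from dlclive import Processor
    from detection_utils import predict_md, crop_animal_detections, predict_dlc

    img = Image.open('examples/cat.jpg')
    md_results = predict_md(img, 'MD_models/md_v5a.0.0.pt', size=640)
    crops = crop_animal_detections(img, md_results, likelihood_th=0.8)
    kpts_per_crop = predict_dlc(crops, 0.05, 'DLC_models/DLC_Cat/', Processor())
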
save_results.py DELETED
@@ -1,56 +0,0 @@
- import json
- import numpy as np
- import pdb
-
- dict_pred = {0: 'animal', 1: 'person', 2: 'vehicle'}
-
-
- def save_results(md_results, dlc_outputs, map_label_id_to_str, thr, output_file='dowload_predictions.json'):
-
-     """
-     write json
-     """
-     info = {}
-     ## info megaDetector
-     info['file'] = md_results.files[0]
-     number_bb = len(md_results.xyxy[0].tolist())
-     info['number_of_bb'] = number_bb
-     number_bb_thr = len(dlc_outputs)
-     labels = [n for n in map_label_id_to_str.values()]
-     #pdb.set_trace()
-     new_index = []
-     for i in range(number_bb):
-         corner_x1, corner_y1, corner_x2, corner_y2, confidence, _ = md_results.xyxy[0].tolist()[i]
-
-         if confidence > thr:
-             new_index.append(i)
-
-     for i in range(number_bb_thr):
-         aux = {}
-         corner_x1, corner_y1, corner_x2, corner_y2, confidence, _ = md_results.xyxy[0].tolist()[new_index[i]]
-         aux['corner_1'] = (corner_x1, corner_y1)
-         aux['corner_2'] = (corner_x2, corner_y2)
-         aux['predict MD'] = md_results.names[0]
-         aux['confidence MD'] = confidence
-
-         ## info dlc
-         kypts = []
-         for s in dlc_outputs[i]:
-             aux1 = []
-             for j in s:
-                 aux1.append(float(j))
-
-             kypts.append(aux1)
-         aux['dlc_pred'] = dict(zip(labels, kypts))
-         info['bb_' + str(new_index[i])] = aux
-
-     with open(output_file, 'w') as f:
-         json.dump(info, f, indent=1)
-     print('Output file saved at {}'.format(output_file))
-
-     return output_file
ui_utils.py ADDED
@@ -0,0 +1,81 @@
+ import gradio as gr
+
+ ##############################
+ def gradio_inputs_for_MD_DLC(md_models_list,   # list(MD_models_dict.keys())
+                              dlc_models_list,  # list(DLC_models_dict.keys())
+                              ):
+     # Input image
+     gr_image_input = gr.inputs.Image(type="pil", label="Input Image")
+
+     # Models
+     gr_mega_model_input = gr.inputs.Dropdown(choices=md_models_list,
+                                              default='md_v5a',  # default option
+                                              type='value',  # "value" returns the string of the choice selected; "index" returns its index.
+                                              label='Select MegaDetector model')
+     gr_dlc_model_input = gr.inputs.Dropdown(choices=dlc_models_list,  # choices
+                                             default='full_cat',  # default option
+                                             type='value',
+                                             label='Select DeepLabCut model')
+
+     # Other inputs
+     gr_dlc_only_checkbox = gr.inputs.Checkbox(False,
+                                               label='Run DLClive only, directly on input image?')
+     gr_str_labels_checkbox = gr.inputs.Checkbox(True,
+                                                 label='Show bodypart labels?')
+
+     gr_slider_conf_bboxes = gr.inputs.Slider(0, 1, .02, 0.8,
+                                              label='Set confidence threshold for animal detections')
+     gr_slider_conf_keypoints = gr.inputs.Slider(0, 1, .05, 0,
+                                                 label='Set confidence threshold for keypoints')
+
+     # Data viz
+     gr_keypt_color = gr.ColorPicker(label="choose color for keypoint label")
+
+     gr_labels_font_style = gr.inputs.Dropdown(choices=['amiko', 'nature', 'painter', 'animals', 'zen'],
+                                               default='amiko',
+                                               type='value',
+                                               label='Select keypoint label font')
+     gr_slider_font_size = gr.inputs.Slider(5, 30, 1, 8,
+                                            label='Set font size')
+     gr_slider_marker_size = gr.inputs.Slider(1, 20, 1, 5,
+                                              label='Set marker size')
+
+     # list of inputs
+     return [gr_image_input,
+             gr_mega_model_input,
+             gr_dlc_model_input,
+             gr_dlc_only_checkbox,
+             gr_str_labels_checkbox,
+             gr_slider_conf_bboxes,
+             gr_slider_conf_keypoints,
+             gr_labels_font_style,
+             gr_slider_font_size,
+             gr_keypt_color,
+             gr_slider_marker_size]
+
+ ####################################################
+ def gradio_outputs_for_MD_DLC():
+     # User interface: outputs
+     gr_image_output = gr.outputs.Image(type="pil", label="Output Image")
+     gr_file_download = gr.File(label="Download JSON file")
+     return [gr_image_output,
+             gr_file_download]
+
+ ##############################################
+ # User interface: description
+ def gradio_description_and_examples():
+     title = "MegaDetector v5 + DeepLabCut-Live!"
+     description = "Contributed by Sofia Minano, Neslihan Wittek, Nirel Kadzo, VicShaoChih Chiang, Sabrina Benas -- DLC AI Residents 2022.\
+         This App detects and estimates the pose of animals in camera trap images using <a href='https://github.com/microsoft/CameraTraps'>MegaDetector v5a</a> + <a href='https://github.com/DeepLabCut/DeepLabCut-live'>DeepLabCut-live</a>. \
+         We host models from the <a href='http://www.mackenziemathislab.org/dlc-modelzoo'>DeepLabCut ModelZoo Project</a>, and two <a href='https://github.com/microsoft/CameraTraps/blob/main/megadetector.md'>MegaDetector Models</a>. Please carefully check their licensing information if you use this project. The App additionally builds upon work from <a href='https://huggingface.co/spaces/hlydecker/MegaDetector_v5'>hlydecker/MegaDetector_v5</a>, \
+         <a href='https://huggingface.co/spaces/sofmi/MegaDetector_DLClive'>sofmi/MegaDetector_DLClive</a>, \
+         <a href='https://huggingface.co/spaces/Neslihan/megadetector_dlcmodels'>Neslihan/megadetector_dlcmodels</a>."
+
+     # article = "<p style='text-align: center'>This app makes predictions using a YOLOv5x6 model that was trained to detect animals, humans, and vehicles in camera trap images; find out more about the project on <a href='https://github.com/microsoft/CameraTraps'>GitHub</a>. This app was built by Henry Lydecker but really depends on code and models developed by <a href='http://ecologize.org/'>Ecologize</a> and <a href='http://aka.ms/aiforearth'>Microsoft AI for Earth</a>. Find out more about the YOLO model from the original creator, <a href='https://pjreddie.com/darknet/yolo/'>Joseph Redmon</a>. YOLOv5 is a family of compound-scaled object detection models trained on the COCO dataset and developed by Ultralytics, and includes simple functionality for Test Time Augmentation (TTA), model ensembling, hyperparameter evolution, and export to ONNX, CoreML and TFLite. <a href='https://github.com/ultralytics/yolov5'>Source code</a> | <a href='https://pytorch.org/hub/ultralytics_yolov5'>PyTorch Hub</a></p>"
+
+     examples = [['examples/monkey_full.jpg', 'md_v5a', 'full_macaque', False, True, 0.5, 0.3, 'amiko', 9, 'blue', 3],
+                 ['examples/dog.jpeg', 'md_v5a', 'full_dog', False, True, 0.5, 0.00, 'amiko', 9, 'yellow', 3],
+                 ['examples/cat.jpg', 'md_v5a', 'full_cat', False, True, 0.5, 0.05, 'amiko', 9, 'purple', 3]]
+
+     return [title, description, examples]
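
As a rough sketch, each row of the examples list above maps one-to-one onto the arguments of predict_pipeline in app.py. The call below mirrors the cat example row; note that importing app also executes its module-level demo.launch, so calling the function directly outside Gradio is an untested assumption here:

    from PIL import Image
    from app import predict_pipeline

    img = Image.open('examples/cat.jpg')
    img_out, json_file = predict_pipeline(img,
                                          'md_v5a',    # MegaDetector model
                                          'full_cat',  # DLC model
                                          False,       # run DLClive only?
                                          True,        # show bodypart labels?
                                          0.5,         # bbox confidence threshold
                                          0.05,        # keypoint confidence threshold
                                          'amiko',     # label font
                                          9,           # font size
                                          'purple',    # keypoint color
                                          3)           # marker size
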
viz_utils.py ADDED
@@ -0,0 +1,165 @@
+ import json
+
+ from matplotlib import cm
+ import matplotlib
+ from PIL import Image, ImageColor, ImageFont, ImageDraw
+ import numpy as np
+
+ FONTS = {'amiko': "fonts/Amiko-Regular.ttf",
+          'nature': "fonts/LoveNature.otf",
+          'painter': "fonts/PainterDecorator.otf",
+          'animals': "fonts/UncialAnimals.ttf",
+          'zen': "fonts/ZEN.TTF"}
+
+ #########################################
+ # Draw keypoints on image
+ def draw_keypoints_on_image(image,
+                             keypoints,
+                             map_label_id_to_str,
+                             flag_show_str_labels,
+                             use_normalized_coordinates=True,
+                             font_style='amiko',
+                             font_size=8,
+                             keypt_color="#ff0000",
+                             marker_size=2,
+                             ):
+     """Draws keypoints on an image.
+     Modified from:
+     https://www.programcreek.com/python/?code=fjchange%2Fobject_centric_VAD%2Fobject_centric_VAD-master%2Fobject_detection%2Futils%2Fvisualization_utils.py
+     Args:
+       image: a PIL.Image object.
+       keypoints: a numpy array with shape [num_keypoints, 2].
+       map_label_id_to_str: dict with keys=label number and values=label string
+       flag_show_str_labels: boolean to select whether or not to show string labels
+       color: color to draw the keypoints with. Default is red.
+       radius: keypoint radius. Default value is 2.
+       use_normalized_coordinates: if True (default), treat keypoint values as
+         relative to the image. Otherwise treat them as absolute.
+     """
+     # get a drawing context
+     draw = ImageDraw.Draw(image, "RGBA")
+
+     im_width, im_height = image.size
+     keypoints_x = [k[0] for k in keypoints]
+     keypoints_y = [k[1] for k in keypoints]
+     alpha = [k[2] for k in keypoints]
+     norm = matplotlib.colors.Normalize(vmin=0, vmax=255)
+
+     names_for_color = [i for i in map_label_id_to_str.keys()]
+     colores = np.linspace(0, 255, num=len(names_for_color), dtype=int)
+
+     # adjust keypoints coords if required
+     if use_normalized_coordinates:
+         keypoints_x = tuple([im_width * x for x in keypoints_x])
+         keypoints_y = tuple([im_height * y for y in keypoints_y])
+
+     #cmap = matplotlib.cm.get_cmap('hsv')
+     cmap2 = matplotlib.cm.get_cmap('Greys')
+     # draw ellipses around keypoints
+     for i, (keypoint_x, keypoint_y) in enumerate(zip(keypoints_x, keypoints_y)):
+         round_fill = list(cm.viridis(norm(colores[i]), bytes=True))  #[round(num*255) for num in list(cmap(i))[:3]] #check!
+         if not np.isnan(alpha[i]):
+             round_fill[3] = round(alpha[i] * 255)
+         #print(round_fill)
+         #round_outline = [round(num*255) for num in list(cmap2(alpha[i]))[:3]]
+         draw.ellipse([(keypoint_x - marker_size, keypoint_y - marker_size),
+                       (keypoint_x + marker_size, keypoint_y + marker_size)],
+                      fill=tuple(round_fill), outline='black', width=1)  # fill and outline: [0,255]
+
+         # add string labels around keypoints
+         if flag_show_str_labels:
+             font = ImageFont.truetype(FONTS[font_style],
+                                       font_size)
+             draw.text((keypoint_x + marker_size, keypoint_y + marker_size),
+                       map_label_id_to_str[i],
+                       ImageColor.getcolor(keypt_color, "RGB"),  # rgb
+                       font=font)
+
+ #########################################
+ # Draw bboxes on image
+ def draw_bbox_w_text(img,
+                      results,
+                      font_style='amiko',
+                      font_size=8):  # TODO: select color too?
+     #pdb.set_trace()
+     bbxyxy = results
+     w, h = bbxyxy[2], bbxyxy[3]
+     shape = [(bbxyxy[0], bbxyxy[1]), (w, h)]
+     imgR = ImageDraw.Draw(img)
+     imgR.rectangle(shape, outline="red", width=5)  ## bb for animal
+
+     confidence = bbxyxy[4]
+     string_bb = 'animal ' + str(round(confidence, 2))
+     font = ImageFont.truetype(FONTS[font_style], font_size)
+
+     text_size = font.getsize(string_bb)  # (h,w)
+     position = (bbxyxy[0], bbxyxy[1] - text_size[1] - 2)
+     left, top, right, bottom = imgR.textbbox(position, string_bb, font=font)
+     imgR.rectangle((left, top - 5, right + 5, bottom + 5), fill="red")
+     imgR.text((bbxyxy[0] + 3, bbxyxy[1] - text_size[1] - 2), string_bb, font=font, fill="black")
+
+     return imgR
+
+ ###########################################
+ def save_results_as_json(md_results,
+                          dlc_outputs,
+                          map_dlc_label_id_to_str,
+                          thr,
+                          path_to_output_file='dowload_predictions.json'):
+     """
+     Output detections as json file
+     """
+     # initialise dict to save to json
+     info = {}
+     # info from megaDetector
+     info['file'] = md_results.files[0]
+     number_bb = len(md_results.xyxy[0].tolist())
+     info['number_of_bb'] = number_bb
+     # info from DLC
+     number_bb_thr = len(dlc_outputs)
+     labels = [n for n in map_dlc_label_id_to_str.values()]
+
+     # create list of bboxes above th
+     new_index = []
+     for i in range(number_bb):
+         corner_x1, corner_y1, corner_x2, corner_y2, confidence, _ = md_results.xyxy[0].tolist()[i]
+
+         if confidence > thr:
+             new_index.append(i)
+
+     # define aux dict for every bounding box above threshold
+     for i in range(number_bb_thr):
+         aux = {}
+         # MD output
+         corner_x1, corner_y1, corner_x2, corner_y2, confidence, _ = md_results.xyxy[0].tolist()[new_index[i]]
+         aux['corner_1'] = (corner_x1, corner_y1)
+         aux['corner_2'] = (corner_x2, corner_y2)
+         aux['predict MD'] = md_results.names[0]
+         aux['confidence MD'] = confidence
+
+         # DLC output
+         kypts = []
+         for s in dlc_outputs[i]:
+             aux1 = []
+             for j in s:
+                 aux1.append(float(j))
+
+             kypts.append(aux1)
+         aux['dlc_pred'] = dict(zip(labels, kypts))
+         info['bb_' + str(new_index[i])] = aux
+
+     # save dict as json
+     with open(path_to_output_file, 'w') as f:
+         json.dump(info, f, indent=1)
+     print('Output file saved at {}'.format(path_to_output_file))
+
+     return path_to_output_file
+
+ ###########################################
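
For orientation, the JSON written by save_results_as_json has roughly this shape. The values below are illustrative placeholders, not real output; one "bb_i" entry is produced per bounding box above the confidence threshold, and each dlc_pred entry is a [x, y, likelihood] triple per bodypart:

    {
     "file": "cat.jpg",
     "number_of_bb": 1,
     "bb_0": {
      "corner_1": [12.0, 34.0],
      "corner_2": [240.0, 310.0],
      "predict MD": "animal",
      "confidence MD": 0.97,
      "dlc_pred": {"Nose": [120.0, 80.0, 0.95]}
     }
    }
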