Spaces:
Build error
Build error
sabrinabenas
commited on
Commit
β’
7578ea3
1
Parent(s):
1b9c782
add confidence as alpha
Browse files- app.py +27 -18
- save_results.py +3 -2
app.py
CHANGED
@@ -5,7 +5,7 @@
|
|
5 |
|
6 |
from tkinter import W
|
7 |
import gradio as gr
|
8 |
-
|
9 |
import torch
|
10 |
import torchvision
|
11 |
from dlclive import DLCLive, Processor
|
@@ -33,11 +33,12 @@ FONTS = {'amiko': "font/Amiko-Regular.ttf",
|
|
33 |
Megadet_Models = {'md_v5a': "megadet_model/md_v5a.0.0.pt",
|
34 |
'md_v5b': "megadet_model/md_v5b.0.0.pt"}
|
35 |
|
36 |
-
|
37 |
-
'full_dog': "model/
|
38 |
-
'primate_face': "model/
|
39 |
-
'full_human': "model/
|
40 |
-
'full_macaque': 'model/
|
|
|
41 |
DLC_models_list = ['full_cat', 'full_dog','primate_face', 'full_human', 'full_macaque']
|
42 |
#########################################
|
43 |
# Draw keypoints on image
|
@@ -67,27 +68,34 @@ def draw_keypoints_on_image(image,
|
|
67 |
|
68 |
"""
|
69 |
# get a drawing context
|
70 |
-
draw = ImageDraw.Draw(image)
|
71 |
|
72 |
im_width, im_height = image.size
|
73 |
keypoints_x = [k[0] for k in keypoints]
|
74 |
keypoints_y = [k[1] for k in keypoints]
|
75 |
alpha = [k[2] for k in keypoints]
|
|
|
|
|
|
|
|
|
76 |
|
77 |
# adjust keypoints coords if required
|
78 |
if use_normalized_coordinates:
|
79 |
keypoints_x = tuple([im_width * x for x in keypoints_x])
|
80 |
keypoints_y = tuple([im_height * y for y in keypoints_y])
|
81 |
|
82 |
-
cmap = matplotlib.cm.get_cmap('hsv')
|
83 |
cmap2 = matplotlib.cm.get_cmap('Greys')
|
84 |
# draw ellipses around keypoints
|
85 |
for i, (keypoint_x, keypoint_y) in enumerate(zip(keypoints_x, keypoints_y)):
|
86 |
-
round_fill = [round(num*255) for num in list(cmap(i
|
87 |
-
|
|
|
|
|
|
|
88 |
draw.ellipse([(keypoint_x - marker_size, keypoint_y - marker_size),
|
89 |
(keypoint_x + marker_size, keypoint_y + marker_size)],
|
90 |
-
fill=tuple(round_fill), outline=
|
91 |
|
92 |
# add string labels around keypoints
|
93 |
if flag_show_str_labels:
|
@@ -95,7 +103,7 @@ def draw_keypoints_on_image(image,
|
|
95 |
font_size)
|
96 |
draw.text((keypoint_x + marker_size, keypoint_y + marker_size),#(0.5*im_width, 0.5*im_height), #-------
|
97 |
map_label_id_to_str[i],
|
98 |
-
ImageColor.getcolor(keypt_color, "RGB"), # rgb
|
99 |
font=font)
|
100 |
|
101 |
############################################
|
@@ -199,7 +207,7 @@ def predict_dlc(list_np_crops,
|
|
199 |
# add kpts of this crop to list
|
200 |
list_kpts_per_crop.append(keypts_xyp)
|
201 |
all_kypts.append(keypts_xyp)
|
202 |
-
|
203 |
return list_kpts_per_crop
|
204 |
|
205 |
|
@@ -221,8 +229,8 @@ def predict_pipeline(img_input,
|
|
221 |
## Get DLC model and labels as strings
|
222 |
# TODO: make a dict as for megadetector
|
223 |
# pdb.set_trace()
|
224 |
-
path_to_DLCmodel = DownloadModel(dlc_model_input_str,
|
225 |
-
pose_cfg_path = '
|
226 |
#pdb.set_trece()
|
227 |
# extract map label ids to strings
|
228 |
# pose_cfg_dict['all_joints'] is a list of one-element lists,
|
@@ -266,7 +274,8 @@ def predict_pipeline(img_input,
|
|
266 |
font_size=font_size,
|
267 |
keypt_color=keypt_color,
|
268 |
marker_size=marker_size)
|
269 |
-
|
|
|
270 |
|
271 |
else:
|
272 |
# Compute kpts for each crop
|
@@ -297,13 +306,13 @@ def predict_pipeline(img_input,
|
|
297 |
img_background.paste(img_crop, box = tuple([int(t) for t in md_results.xyxy[0][ic,:2]]))
|
298 |
|
299 |
|
300 |
-
|
301 |
bb_per_animal = md_results.xyxy[0].tolist()[ic]
|
302 |
pred = md_results.xyxy[0].tolist()[ic][4]
|
303 |
if bbox_likelihood_th < pred:
|
304 |
draw_rectangle_text(img_background, bb_per_animal ,font_style=font_style,font_size=font_size, keypt_color=keypt_color)
|
305 |
|
306 |
-
|
307 |
|
308 |
download_file = save_results(md_results,list_kpts_per_crop,map_label_id_to_str,bbox_likelihood_th)
|
309 |
|
|
|
5 |
|
6 |
from tkinter import W
|
7 |
import gradio as gr
|
8 |
+
from matplotlib import cm
|
9 |
import torch
|
10 |
import torchvision
|
11 |
from dlclive import DLCLive, Processor
|
|
|
33 |
Megadet_Models = {'md_v5a': "megadet_model/md_v5a.0.0.pt",
|
34 |
'md_v5b': "megadet_model/md_v5b.0.0.pt"}
|
35 |
|
36 |
+
DLC_folders = {'full_cat': "model/DLC_Cat/",
|
37 |
+
'full_dog': "model/DLC_Dog/",
|
38 |
+
'primate_face': "model/DLC_FacialLandmarks/",
|
39 |
+
'full_human': "model/DLC_human_dancing/",
|
40 |
+
'full_macaque': 'model/DLC_monkey/'}
|
41 |
+
|
42 |
DLC_models_list = ['full_cat', 'full_dog','primate_face', 'full_human', 'full_macaque']
|
43 |
#########################################
|
44 |
# Draw keypoints on image
|
|
|
68 |
|
69 |
"""
|
70 |
# get a drawing context
|
71 |
+
draw = ImageDraw.Draw(image,"RGBA")
|
72 |
|
73 |
im_width, im_height = image.size
|
74 |
keypoints_x = [k[0] for k in keypoints]
|
75 |
keypoints_y = [k[1] for k in keypoints]
|
76 |
alpha = [k[2] for k in keypoints]
|
77 |
+
norm = matplotlib.colors.Normalize(vmin=0, vmax=255)
|
78 |
+
|
79 |
+
names_for_color = [i for i in map_label_id_to_str.keys()]
|
80 |
+
colores = np.linspace(0, 255, num=len(names_for_color),dtype= int)
|
81 |
|
82 |
# adjust keypoints coords if required
|
83 |
if use_normalized_coordinates:
|
84 |
keypoints_x = tuple([im_width * x for x in keypoints_x])
|
85 |
keypoints_y = tuple([im_height * y for y in keypoints_y])
|
86 |
|
87 |
+
#cmap = matplotlib.cm.get_cmap('hsv')
|
88 |
cmap2 = matplotlib.cm.get_cmap('Greys')
|
89 |
# draw ellipses around keypoints
|
90 |
for i, (keypoint_x, keypoint_y) in enumerate(zip(keypoints_x, keypoints_y)):
|
91 |
+
round_fill = list(cm.viridis(norm(colores[i]),bytes=True))#[round(num*255) for num in list(cmap(i))[:3]] #check!
|
92 |
+
if np.isnan(alpha[i]) == False :
|
93 |
+
round_fill[3] = round(alpha[i] *255)
|
94 |
+
#print(round_fill)
|
95 |
+
#round_outline = [round(num*255) for num in list(cmap2(alpha[i]))[:3]]
|
96 |
draw.ellipse([(keypoint_x - marker_size, keypoint_y - marker_size),
|
97 |
(keypoint_x + marker_size, keypoint_y + marker_size)],
|
98 |
+
fill=tuple(round_fill), outline= 'black', width=1) #fill and outline: [0,255]
|
99 |
|
100 |
# add string labels around keypoints
|
101 |
if flag_show_str_labels:
|
|
|
103 |
font_size)
|
104 |
draw.text((keypoint_x + marker_size, keypoint_y + marker_size),#(0.5*im_width, 0.5*im_height), #-------
|
105 |
map_label_id_to_str[i],
|
106 |
+
ImageColor.getcolor(keypt_color, "RGB"), # rgb #
|
107 |
font=font)
|
108 |
|
109 |
############################################
|
|
|
207 |
# add kpts of this crop to list
|
208 |
list_kpts_per_crop.append(keypts_xyp)
|
209 |
all_kypts.append(keypts_xyp)
|
210 |
+
|
211 |
return list_kpts_per_crop
|
212 |
|
213 |
|
|
|
229 |
## Get DLC model and labels as strings
|
230 |
# TODO: make a dict as for megadetector
|
231 |
# pdb.set_trace()
|
232 |
+
path_to_DLCmodel = DownloadModel(dlc_model_input_str, str(DLC_folders[dlc_model_input_str]) )
|
233 |
+
pose_cfg_path = str(DLC_folders[dlc_model_input_str]) +'pose_cfg.yaml'
|
234 |
#pdb.set_trece()
|
235 |
# extract map label ids to strings
|
236 |
# pose_cfg_dict['all_joints'] is a list of one-element lists,
|
|
|
274 |
font_size=font_size,
|
275 |
keypt_color=keypt_color,
|
276 |
marker_size=marker_size)
|
277 |
+
|
278 |
+
return img_input, []
|
279 |
|
280 |
else:
|
281 |
# Compute kpts for each crop
|
|
|
306 |
img_background.paste(img_crop, box = tuple([int(t) for t in md_results.xyxy[0][ic,:2]]))
|
307 |
|
308 |
|
309 |
+
|
310 |
bb_per_animal = md_results.xyxy[0].tolist()[ic]
|
311 |
pred = md_results.xyxy[0].tolist()[ic][4]
|
312 |
if bbox_likelihood_th < pred:
|
313 |
draw_rectangle_text(img_background, bb_per_animal ,font_style=font_style,font_size=font_size, keypt_color=keypt_color)
|
314 |
|
315 |
+
|
316 |
|
317 |
download_file = save_results(md_results,list_kpts_per_crop,map_label_id_to_str,bbox_likelihood_th)
|
318 |
|
save_results.py
CHANGED
@@ -38,7 +38,8 @@ def save_results(md_results, dlc_outputs,map_label_id_to_str,thr,output_file = '
|
|
38 |
|
39 |
## info dlc
|
40 |
kypts = []
|
41 |
-
pdb.set_trace()
|
|
|
42 |
for s in dlc_outputs[i]:
|
43 |
#print(s)
|
44 |
aux1 = []
|
@@ -46,7 +47,7 @@ def save_results(md_results, dlc_outputs,map_label_id_to_str,thr,output_file = '
|
|
46 |
aux1.append(float(j))
|
47 |
|
48 |
kypts.append(aux1)
|
49 |
-
pdb.set_trace()
|
50 |
aux['dlc_pred'] = dict(zip(labels,kypts))
|
51 |
info['bb_' + str(new_index[i]) ]=aux
|
52 |
|
|
|
38 |
|
39 |
## info dlc
|
40 |
kypts = []
|
41 |
+
#pdb.set_trace()
|
42 |
+
print(dlc_outputs[i])
|
43 |
for s in dlc_outputs[i]:
|
44 |
#print(s)
|
45 |
aux1 = []
|
|
|
47 |
aux1.append(float(j))
|
48 |
|
49 |
kypts.append(aux1)
|
50 |
+
#pdb.set_trace()
|
51 |
aux['dlc_pred'] = dict(zip(labels,kypts))
|
52 |
info['bb_' + str(new_index[i]) ]=aux
|
53 |
|