sabrinabenas commited on
Commit
7578ea3
β€’
1 Parent(s): 1b9c782

add confidence as alpha

Browse files
Files changed (2) hide show
  1. app.py +27 -18
  2. save_results.py +3 -2
app.py CHANGED
@@ -5,7 +5,7 @@
5
 
6
  from tkinter import W
7
  import gradio as gr
8
-
9
  import torch
10
  import torchvision
11
  from dlclive import DLCLive, Processor
@@ -33,11 +33,12 @@ FONTS = {'amiko': "font/Amiko-Regular.ttf",
33
  Megadet_Models = {'md_v5a': "megadet_model/md_v5a.0.0.pt",
34
  'md_v5b': "megadet_model/md_v5b.0.0.pt"}
35
 
36
- DLC_models = {'full_cat': "model/DLC_Cat_resnet_50_iteration-0_shuffle-0",
37
- 'full_dog': "model/DLC_Dog_resnet_50_iteration-0_shuffle-0",
38
- 'primate_face': "model/DLC_FacialLandmarks_resnet_50_iteration-1_shuffle-1",
39
- 'full_human': "model/DLC_human_dancing_resnet_101_iteration-0_shuffle-1",
40
- 'full_macaque': 'model/DLC_monkey_resnet_50_iteration-0_shuffle-1'}
 
41
  DLC_models_list = ['full_cat', 'full_dog','primate_face', 'full_human', 'full_macaque']
42
  #########################################
43
  # Draw keypoints on image
@@ -67,27 +68,34 @@ def draw_keypoints_on_image(image,
67
 
68
  """
69
  # get a drawing context
70
- draw = ImageDraw.Draw(image)
71
 
72
  im_width, im_height = image.size
73
  keypoints_x = [k[0] for k in keypoints]
74
  keypoints_y = [k[1] for k in keypoints]
75
  alpha = [k[2] for k in keypoints]
 
 
 
 
76
 
77
  # adjust keypoints coords if required
78
  if use_normalized_coordinates:
79
  keypoints_x = tuple([im_width * x for x in keypoints_x])
80
  keypoints_y = tuple([im_height * y for y in keypoints_y])
81
 
82
- cmap = matplotlib.cm.get_cmap('hsv')
83
  cmap2 = matplotlib.cm.get_cmap('Greys')
84
  # draw ellipses around keypoints
85
  for i, (keypoint_x, keypoint_y) in enumerate(zip(keypoints_x, keypoints_y)):
86
- round_fill = [round(num*255) for num in list(cmap(i*10))[:3]] #check!
87
- round_outline = [round(num*255) for num in list(cmap2(alpha[i]))[:3]]
 
 
 
88
  draw.ellipse([(keypoint_x - marker_size, keypoint_y - marker_size),
89
  (keypoint_x + marker_size, keypoint_y + marker_size)],
90
- fill=tuple(round_fill), outline= tuple(round_outline), width=2) #fill and outline: [0,255]
91
 
92
  # add string labels around keypoints
93
  if flag_show_str_labels:
@@ -95,7 +103,7 @@ def draw_keypoints_on_image(image,
95
  font_size)
96
  draw.text((keypoint_x + marker_size, keypoint_y + marker_size),#(0.5*im_width, 0.5*im_height), #-------
97
  map_label_id_to_str[i],
98
- ImageColor.getcolor(keypt_color, "RGB"), # rgb
99
  font=font)
100
 
101
  ############################################
@@ -199,7 +207,7 @@ def predict_dlc(list_np_crops,
199
  # add kpts of this crop to list
200
  list_kpts_per_crop.append(keypts_xyp)
201
  all_kypts.append(keypts_xyp)
202
- #return confidence here
203
  return list_kpts_per_crop
204
 
205
 
@@ -221,8 +229,8 @@ def predict_pipeline(img_input,
221
  ## Get DLC model and labels as strings
222
  # TODO: make a dict as for megadetector
223
  # pdb.set_trace()
224
- path_to_DLCmodel = DownloadModel(dlc_model_input_str, 'model/')
225
- pose_cfg_path = 'model/pose_cfg.yaml'
226
  #pdb.set_trece()
227
  # extract map label ids to strings
228
  # pose_cfg_dict['all_joints'] is a list of one-element lists,
@@ -266,7 +274,8 @@ def predict_pipeline(img_input,
266
  font_size=font_size,
267
  keypt_color=keypt_color,
268
  marker_size=marker_size)
269
- return img_input
 
270
 
271
  else:
272
  # Compute kpts for each crop
@@ -297,13 +306,13 @@ def predict_pipeline(img_input,
297
  img_background.paste(img_crop, box = tuple([int(t) for t in md_results.xyxy[0][ic,:2]]))
298
 
299
 
300
- #set trh!! FIXME
301
  bb_per_animal = md_results.xyxy[0].tolist()[ic]
302
  pred = md_results.xyxy[0].tolist()[ic][4]
303
  if bbox_likelihood_th < pred:
304
  draw_rectangle_text(img_background, bb_per_animal ,font_style=font_style,font_size=font_size, keypt_color=keypt_color)
305
 
306
- print(pred)
307
 
308
  download_file = save_results(md_results,list_kpts_per_crop,map_label_id_to_str,bbox_likelihood_th)
309
 
 
5
 
6
  from tkinter import W
7
  import gradio as gr
8
+ from matplotlib import cm
9
  import torch
10
  import torchvision
11
  from dlclive import DLCLive, Processor
 
33
  Megadet_Models = {'md_v5a': "megadet_model/md_v5a.0.0.pt",
34
  'md_v5b': "megadet_model/md_v5b.0.0.pt"}
35
 
36
+ DLC_folders = {'full_cat': "model/DLC_Cat/",
37
+ 'full_dog': "model/DLC_Dog/",
38
+ 'primate_face': "model/DLC_FacialLandmarks/",
39
+ 'full_human': "model/DLC_human_dancing/",
40
+ 'full_macaque': 'model/DLC_monkey/'}
41
+
42
  DLC_models_list = ['full_cat', 'full_dog','primate_face', 'full_human', 'full_macaque']
43
  #########################################
44
  # Draw keypoints on image
 
68
 
69
  """
70
  # get a drawing context
71
+ draw = ImageDraw.Draw(image,"RGBA")
72
 
73
  im_width, im_height = image.size
74
  keypoints_x = [k[0] for k in keypoints]
75
  keypoints_y = [k[1] for k in keypoints]
76
  alpha = [k[2] for k in keypoints]
77
+ norm = matplotlib.colors.Normalize(vmin=0, vmax=255)
78
+
79
+ names_for_color = [i for i in map_label_id_to_str.keys()]
80
+ colores = np.linspace(0, 255, num=len(names_for_color),dtype= int)
81
 
82
  # adjust keypoints coords if required
83
  if use_normalized_coordinates:
84
  keypoints_x = tuple([im_width * x for x in keypoints_x])
85
  keypoints_y = tuple([im_height * y for y in keypoints_y])
86
 
87
+ #cmap = matplotlib.cm.get_cmap('hsv')
88
  cmap2 = matplotlib.cm.get_cmap('Greys')
89
  # draw ellipses around keypoints
90
  for i, (keypoint_x, keypoint_y) in enumerate(zip(keypoints_x, keypoints_y)):
91
+ round_fill = list(cm.viridis(norm(colores[i]),bytes=True))#[round(num*255) for num in list(cmap(i))[:3]] #check!
92
+ if np.isnan(alpha[i]) == False :
93
+ round_fill[3] = round(alpha[i] *255)
94
+ #print(round_fill)
95
+ #round_outline = [round(num*255) for num in list(cmap2(alpha[i]))[:3]]
96
  draw.ellipse([(keypoint_x - marker_size, keypoint_y - marker_size),
97
  (keypoint_x + marker_size, keypoint_y + marker_size)],
98
+ fill=tuple(round_fill), outline= 'black', width=1) #fill and outline: [0,255]
99
 
100
  # add string labels around keypoints
101
  if flag_show_str_labels:
 
103
  font_size)
104
  draw.text((keypoint_x + marker_size, keypoint_y + marker_size),#(0.5*im_width, 0.5*im_height), #-------
105
  map_label_id_to_str[i],
106
+ ImageColor.getcolor(keypt_color, "RGB"), # rgb #
107
  font=font)
108
 
109
  ############################################
 
207
  # add kpts of this crop to list
208
  list_kpts_per_crop.append(keypts_xyp)
209
  all_kypts.append(keypts_xyp)
210
+
211
  return list_kpts_per_crop
212
 
213
 
 
229
  ## Get DLC model and labels as strings
230
  # TODO: make a dict as for megadetector
231
  # pdb.set_trace()
232
+ path_to_DLCmodel = DownloadModel(dlc_model_input_str, str(DLC_folders[dlc_model_input_str]) )
233
+ pose_cfg_path = str(DLC_folders[dlc_model_input_str]) +'pose_cfg.yaml'
234
  #pdb.set_trece()
235
  # extract map label ids to strings
236
  # pose_cfg_dict['all_joints'] is a list of one-element lists,
 
274
  font_size=font_size,
275
  keypt_color=keypt_color,
276
  marker_size=marker_size)
277
+
278
+ return img_input, []
279
 
280
  else:
281
  # Compute kpts for each crop
 
306
  img_background.paste(img_crop, box = tuple([int(t) for t in md_results.xyxy[0][ic,:2]]))
307
 
308
 
309
+
310
  bb_per_animal = md_results.xyxy[0].tolist()[ic]
311
  pred = md_results.xyxy[0].tolist()[ic][4]
312
  if bbox_likelihood_th < pred:
313
  draw_rectangle_text(img_background, bb_per_animal ,font_style=font_style,font_size=font_size, keypt_color=keypt_color)
314
 
315
+
316
 
317
  download_file = save_results(md_results,list_kpts_per_crop,map_label_id_to_str,bbox_likelihood_th)
318
 
save_results.py CHANGED
@@ -38,7 +38,8 @@ def save_results(md_results, dlc_outputs,map_label_id_to_str,thr,output_file = '
38
 
39
  ## info dlc
40
  kypts = []
41
- pdb.set_trace()
 
42
  for s in dlc_outputs[i]:
43
  #print(s)
44
  aux1 = []
@@ -46,7 +47,7 @@ def save_results(md_results, dlc_outputs,map_label_id_to_str,thr,output_file = '
46
  aux1.append(float(j))
47
 
48
  kypts.append(aux1)
49
- pdb.set_trace()
50
  aux['dlc_pred'] = dict(zip(labels,kypts))
51
  info['bb_' + str(new_index[i]) ]=aux
52
 
 
38
 
39
  ## info dlc
40
  kypts = []
41
+ #pdb.set_trace()
42
+ print(dlc_outputs[i])
43
  for s in dlc_outputs[i]:
44
  #print(s)
45
  aux1 = []
 
47
  aux1.append(float(j))
48
 
49
  kypts.append(aux1)
50
+ #pdb.set_trace()
51
  aux['dlc_pred'] = dict(zip(labels,kypts))
52
  info['bb_' + str(new_index[i]) ]=aux
53