sfmig commited on
Commit
57ac38a
β€’
1 Parent(s): 3e1ea97

added dict for dlc model (fix not falling into any if case)

Browse files
Files changed (1) hide show
  1. app.py +12 -20
app.py CHANGED
@@ -30,7 +30,13 @@ FONTS = {'amiko': "font/Amiko-Regular.ttf",
30
 
31
  Megadet_Models = {'md_v5a': "megadet_model/md_v5a.0.0.pt",
32
  'md_v5b': "megadet_model/md_v5b.0.0.pt"}
33
-
 
 
 
 
 
 
34
  #########################################
35
  # Draw keypoints on image
36
  def draw_keypoints_on_image(image,
@@ -164,7 +170,7 @@ def predict_dlc(list_np_crops,
164
  #####################################################
165
  def predict_pipeline(img_input,
166
  mega_model_input,
167
- model_input_str,
168
  flag_dlc_only,
169
  flag_show_str_labels,
170
  bbox_likelihood_th,
@@ -178,23 +184,9 @@ def predict_pipeline(img_input,
178
  ############################################################
179
  ## Get DLC model and labels as strings
180
  # TODO: make a dict as for megadetector
181
- if model_input_str == 'full_cat':
182
- path_to_DLCmodel = "model/DLC_Cat_resnet_50_iteration-0_shuffle-0"
183
- pose_cfg_path = os.path.join(path_to_DLCmodel,'pose_cfg.yaml')
184
- elif model_input_str == 'full_dog':
185
- path_to_DLCmodel = "model/DLC_Dog_resnet_50_iteration-0_shuffle-0"
186
- pose_cfg_path = os.path.join(path_to_DLCmodel,'pose_cfg.yaml')
187
- elif model_input_str == 'primate_face':
188
- path_to_DLCmodel = "model/DLC_FacialLandmarks_resnet_50_iteration-1_shuffle-1"
189
- pose_cfg_path = os.path.join(path_to_DLCmodel,'pose_cfg.yaml')
190
- elif model_input_str == 'full_human':
191
- path_to_DLCmodel = "model/DLC_human_dancing_resnet_101_iteration-0_shuffle-1"
192
- pose_cfg_path = os.path.join(path_to_DLCmodel,'pose_cfg.yaml')
193
- elif model_input_str == 'full_macaque':
194
- path_to_DLCmodel = "model/DLC_monkey_resnet_50_iteration-0_shuffle-1"
195
- pose_cfg_path = os.path.join(path_to_DLCmodel,'pose_cfg.yaml')
196
 
197
-
198
  # extract map label ids to strings
199
  # pose_cfg_dict['all_joints'] is a list of one-element lists,
200
  with open(pose_cfg_path, "r") as stream:
@@ -283,11 +275,11 @@ gr_image_input = gr.inputs.Image(type="pil", label="Input Image")
283
 
284
 
285
  # Models
286
- gr_dlc_model_input = gr.inputs.Dropdown(choices=['full_cat','full_dog', 'primate_face', 'full_human', 'full_macaque'], # choices
287
  default='full_cat', # default option
288
  type='value', # Type of value to be returned by component. "value" returns the string of the choice selected, "index" returns the index of the choice selected.
289
  label='Select DeepLabCut model')
290
- gr_mega_model_input = gr.inputs.Dropdown(choices=['md_v5a','md_v5b'],
291
  default='md_v5a', # default option
292
  type='value', # Type of value to be returned by component. "value" returns the string of the choice selected, "index" returns the index of the choice selected.
293
  label='Select MegaDetector model')
 
30
 
31
  Megadet_Models = {'md_v5a': "megadet_model/md_v5a.0.0.pt",
32
  'md_v5b': "megadet_model/md_v5b.0.0.pt"}
33
+
34
+ DLC_models = {'full_cat': "model/DLC_Cat_resnet_50_iteration-0_shuffle-0",
35
+ 'full_dog': "model/DLC_Dog_resnet_50_iteration-0_shuffle-0",
36
+ 'primate_face': "model/DLC_FacialLandmarks_resnet_50_iteration-1_shuffle-1",
37
+ 'full_human': "model/DLC_human_dancing_resnet_101_iteration-0_shuffle-1",
38
+ 'full_macaque': 'model/DLC_monkey_resnet_50_iteration-0_shuffle-1'}
39
+
40
  #########################################
41
  # Draw keypoints on image
42
  def draw_keypoints_on_image(image,
 
170
  #####################################################
171
  def predict_pipeline(img_input,
172
  mega_model_input,
173
+ dlc_model_input_str,
174
  flag_dlc_only,
175
  flag_show_str_labels,
176
  bbox_likelihood_th,
 
184
  ############################################################
185
  ## Get DLC model and labels as strings
186
  # TODO: make a dict as for megadetector
187
+ path_to_DLCmodel = DLC_models[dlc_model_input_str]
188
+ pose_cfg_path = os.path.join(path_to_DLCmodel,'pose_cfg.yaml')
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
 
190
  # extract map label ids to strings
191
  # pose_cfg_dict['all_joints'] is a list of one-element lists,
192
  with open(pose_cfg_path, "r") as stream:
 
275
 
276
 
277
  # Models
278
+ gr_dlc_model_input = gr.inputs.Dropdown(choices=DLC_models.keys(), # choices
279
  default='full_cat', # default option
280
  type='value', # Type of value to be returned by component. "value" returns the string of the choice selected, "index" returns the index of the choice selected.
281
  label='Select DeepLabCut model')
282
+ gr_mega_model_input = gr.inputs.Dropdown(choices=Megadet_Models.keys(),
283
  default='md_v5a', # default option
284
  type='value', # Type of value to be returned by component. "value" returns the string of the choice selected, "index" returns the index of the choice selected.
285
  label='Select MegaDetector model')