satpalsr commited on
Commit
a4d5b87
1 Parent(s): a97dc32

Add backbone choice

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -328,13 +328,19 @@ def get_pose_estimation_prediction(pose_model, image, center, scale):
328
  return preds
329
 
330
 
331
- def main(image_bgr, box_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)):
332
  CTX = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
333
 
334
 
335
  box_model.to(CTX)
336
  box_model.eval()
337
- model = torch.hub.load('yangsenius/TransPose:main', 'tph_a4_256x192', pretrained=True)
 
 
 
 
 
 
338
 
339
  img_dimensions = (256, 192)
340
 
@@ -355,8 +361,7 @@ def main(image_bgr, box_model = torchvision.models.detection.fasterrcnn_resnet50
355
  for kpt in pose_preds:
356
  draw_pose(kpt, image_bgr) # draw the poses
357
 
358
- im = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
359
- return im
360
 
361
  title = "TransPose"
362
  description = "Gradio demo for TransPose: Keypoint localization via Transformer. Dataset: COCO train2017 & COCO val2017."
@@ -364,5 +369,5 @@ article = "<div style='text-align: center;'><a href='https://github.com/yangseni
364
 
365
  examples = [["./examples/one.jpg"], ["./examples/two.jpg"]]
366
 
367
- iface = gr.Interface(main, inputs=gr.inputs.Image(), outputs="image", description=description, article=article, title=title, examples=examples)
368
  iface.launch(enable_queue=True, debug='True')
 
328
  return preds
329
 
330
 
331
+ def main(image_bgr, backbone_choice, box_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)):
332
  CTX = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
333
 
334
 
335
  box_model.to(CTX)
336
  box_model.eval()
337
+
338
+ if backbone_choice == "HRNet":
339
+ backbone_choice = "tph_a4_256x192"
340
+ else:
341
+ backbone_choice == "ResNet"
342
+ backbone_choice = "tpr_a4_256x192"
343
+ model = torch.hub.load('yangsenius/TransPose:main', backbone_choice , pretrained=True)
344
 
345
  img_dimensions = (256, 192)
346
 
 
361
  for kpt in pose_preds:
362
  draw_pose(kpt, image_bgr) # draw the poses
363
 
364
+ return image_bgr
 
365
 
366
  title = "TransPose"
367
  description = "Gradio demo for TransPose: Keypoint localization via Transformer. Dataset: COCO train2017 & COCO val2017."
 
369
 
370
  examples = [["./examples/one.jpg"], ["./examples/two.jpg"]]
371
 
372
+ iface = gr.Interface(main, inputs=[gr.inputs.Image(), gr.inputs.Radio(["HRNet", "ResNet"])], outputs="image", description=description, article=article, title=title, examples=examples)
373
  iface.launch(enable_queue=True, debug='True')