Spaces:
Runtime error
Runtime error
Add backbone choice
Browse files
app.py
CHANGED
@@ -328,13 +328,19 @@ def get_pose_estimation_prediction(pose_model, image, center, scale):
|
|
328 |
return preds
|
329 |
|
330 |
|
331 |
-
def main(image_bgr, box_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)):
|
332 |
CTX = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
333 |
|
334 |
|
335 |
box_model.to(CTX)
|
336 |
box_model.eval()
|
337 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
338 |
|
339 |
img_dimensions = (256, 192)
|
340 |
|
@@ -355,8 +361,7 @@ def main(image_bgr, box_model = torchvision.models.detection.fasterrcnn_resnet50
|
|
355 |
for kpt in pose_preds:
|
356 |
draw_pose(kpt, image_bgr) # draw the poses
|
357 |
|
358 |
-
|
359 |
-
return im
|
360 |
|
361 |
title = "TransPose"
|
362 |
description = "Gradio demo for TransPose: Keypoint localization via Transformer. Dataset: COCO train2017 & COCO val2017."
|
@@ -364,5 +369,5 @@ article = "<div style='text-align: center;'><a href='https://github.com/yangseni
|
|
364 |
|
365 |
examples = [["./examples/one.jpg"], ["./examples/two.jpg"]]
|
366 |
|
367 |
-
iface = gr.Interface(main, inputs=gr.inputs.Image(), outputs="image", description=description, article=article, title=title, examples=examples)
|
368 |
iface.launch(enable_queue=True, debug='True')
|
|
|
328 |
return preds
|
329 |
|
330 |
|
331 |
def main(image_bgr, backbone_choice="HRNet", box_model=None):
    """Detect people in a BGR image and run TransPose keypoint estimation.

    Args:
        image_bgr: input image in BGR channel order (as supplied by Gradio/OpenCV).
        backbone_choice: UI selection, "HRNet" or "ResNet"; mapped below to the
            corresponding torch.hub model name. Any value other than "HRNet"
            falls back to the ResNet variant.
        box_model: optional person detector. Defaults to a pretrained
            Faster R-CNN, built lazily on first call.
    """
    # Build the default detector lazily: a default argument of
    # fasterrcnn_resnet50_fpn(pretrained=True) is evaluated once at definition
    # time, downloading weights at import and sharing one model across calls.
    if box_model is None:
        box_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

    CTX = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    box_model.to(CTX)
    box_model.eval()

    # Map the UI choice to the torch.hub entry point.
    # NOTE(review): the original had a no-op comparison `backbone_choice == "ResNet"`
    # in the else branch; it has been removed — the assignment below is what matters.
    if backbone_choice == "HRNet":
        backbone_choice = "tph_a4_256x192"
    else:
        backbone_choice = "tpr_a4_256x192"
    model = torch.hub.load('yangsenius/TransPose:main', backbone_choice, pretrained=True)

    img_dimensions = (256, 192)
|
346 |
|
|
|
361 |
for kpt in pose_preds:
|
362 |
draw_pose(kpt, image_bgr) # draw the poses
|
363 |
|
364 |
+
return image_bgr
|
|
|
365 |
|
366 |
title = "TransPose"
description = "Gradio demo for TransPose: Keypoint localization via Transformer. Dataset: COCO train2017 & COCO val2017."

# Each example row must supply one value per input component (image + backbone
# radio). The interface now has two inputs, so one-element rows make Gradio
# raise at startup — the likely cause of the Space's "Runtime error".
examples = [["./examples/one.jpg", "HRNet"], ["./examples/two.jpg", "ResNet"]]

iface = gr.Interface(
    main,
    inputs=[gr.inputs.Image(), gr.inputs.Radio(["HRNet", "ResNet"])],
    outputs="image",
    description=description,
    article=article,
    title=title,
    examples=examples,
)
# `debug` expects a bool; the string 'True' only worked by being truthy.
iface.launch(enable_queue=True, debug=True)
|