# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import random

# Install the bundled package in development mode so its modules are
# importable when the Space starts.
os.system('python setup.py develop')

import gradio as gr
import matplotlib
import numpy as np
import torch

matplotlib.use('agg')  # headless backend; set before pyplot is used
from matplotlib import pyplot as plt
from mmcv import Config
from mmcv.runner import load_checkpoint
from mmpose.core import wrap_fp16_model
from mmpose.models import build_posenet
from PIL import Image, ImageDraw
from torchvision import transforms

from demo import Resize_Pad
from models import *  # provides TopDownGenerateTargetFewShot, among others
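# plot_results(): draws the predicted keypoints and skeleton on top of the
# (normalized) query image and returns the matplotlib figure shown in the
# "Output Image" panel.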
def plot_results(support_img, query_img, support_kp, support_w, query_kp,
                 query_w, skeleton,
                 initial_proposals, prediction, radius=6):
    h, w, c = support_img.shape
    prediction = prediction[-1].cpu().numpy() * h
    query_img = (query_img - np.min(query_img)) / (
            np.max(query_img) - np.min(query_img))
    for _, (img, w, keypoint) in enumerate(zip([query_img],
                                               [query_w],
                                               [prediction])):
        f, axes = plt.subplots()
        plt.imshow(img)
        for k in range(keypoint.shape[0]):
            if w[k] > 0:
                kp = keypoint[k, :2]
                c = (1, 0, 0, 0.75) if w[k] == 1 else (0, 0, 1, 0.6)
                patch = plt.Circle(kp, radius, color=c)
                axes.add_patch(patch)
                axes.text(kp[0], kp[1], k)
                plt.draw()
        for l, limb in enumerate(skeleton):
            kp = keypoint[:, :2]
            if l > len(COLORS) - 1:
                c = [x / 255 for x in random.sample(range(0, 255), 3)]
            else:
                c = [x / 255 for x in COLORS[l]]
            if w[limb[0]] > 0 and w[limb[1]] > 0:
                patch = plt.Line2D([kp[limb[0], 0], kp[limb[1], 0]],
                                   [kp[limb[0], 1], kp[limb[1], 1]],
                                   linewidth=6, color=c, alpha=0.6)
                axes.add_artist(patch)
        plt.axis('off')  # hide the axes
        plt.subplots_adjust(0, 0, 1, 1, 0, 0)
    return plt
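# Fixed color palette for skeleton limbs; limbs beyond the palette length get
# a random color.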
COLORS = [
    [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],
    [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],
    [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], [170, 0, 255],
    [255, 0, 255], [255, 0, 170], [255, 0, 85], [255, 0, 0]
]
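# process(): main inference routine wired to the "Evaluate" button. It builds
# support heatmaps from the user-annotated keypoints, packs the support/query
# pair into the format expected by the few-shot pose network, runs the model,
# and returns a matplotlib figure with the predicted pose on the query image.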
def process(query_img, state,
            cfg_path='configs/demo_b.py'):
    cfg = Config.fromfile(cfg_path)
    width, height, _ = state['original_support_image'].shape
    kp_src_np = np.array(state['kp_src']).copy().astype(np.float32)
    kp_src_np[:, 0] = (kp_src_np[:, 0] / (width // 4)
                       * cfg.model.encoder_config.img_size)
    kp_src_np[:, 1] = (kp_src_np[:, 1] / (height // 4)
                       * cfg.model.encoder_config.img_size)
    kp_src_np = np.flip(kp_src_np, 1).copy()
    kp_src_tensor = torch.tensor(kp_src_np).float()
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        Resize_Pad(cfg.model.encoder_config.img_size,
                   cfg.model.encoder_config.img_size)])

    if len(state['skeleton']) == 0:
        state['skeleton'] = [(0, 0)]

    support_img = preprocess(state['original_support_image']).flip(0)[None]
    np_query = np.array(query_img)[:, :, ::-1].copy()
    q_img = preprocess(np_query).flip(0)[None]

    # Create heatmap from keypoints
    genHeatMap = TopDownGenerateTargetFewShot()
    data_cfg = cfg.data_cfg
    data_cfg['image_size'] = np.array([cfg.model.encoder_config.img_size,
                                       cfg.model.encoder_config.img_size])
    data_cfg['joint_weights'] = None
    data_cfg['use_different_joint_weights'] = False
    kp_src_3d = torch.cat(
        (kp_src_tensor, torch.zeros(kp_src_tensor.shape[0], 1)), dim=-1)
    kp_src_3d_weight = torch.cat(
        (torch.ones_like(kp_src_tensor),
         torch.zeros(kp_src_tensor.shape[0], 1)), dim=-1)
    target_s, target_weight_s = genHeatMap._msra_generate_target(
        data_cfg, kp_src_3d, kp_src_3d_weight, sigma=1)
    target_s = torch.tensor(target_s).float()[None]
    target_weight_s = torch.ones_like(
        torch.tensor(target_weight_s).float()[None])

    data = {
        'img_s': [support_img],
        'img_q': q_img,
        'target_s': [target_s],
        'target_weight_s': [target_weight_s],
        'target_q': None,
        'target_weight_q': None,
        'return_loss': False,
        'img_metas': [{'sample_skeleton': [state['skeleton']],
                       'query_skeleton': state['skeleton'],
                       'sample_joints_3d': [kp_src_3d],
                       'query_joints_3d': kp_src_3d,
                       'sample_center': [kp_src_tensor.mean(dim=0)],
                       'query_center': kp_src_tensor.mean(dim=0),
                       'sample_scale': [
                           kp_src_tensor.max(dim=0)[0] -
                           kp_src_tensor.min(dim=0)[0]],
                       'query_scale': kp_src_tensor.max(dim=0)[0] -
                                      kp_src_tensor.min(dim=0)[0],
                       'sample_rotation': [0],
                       'query_rotation': 0,
                       'sample_bbox_score': [1],
                       'query_bbox_score': 1,
                       'query_image_file': '',
                       'sample_image_file': [''],
                       }]
    }
    # Load model (`checkpoint_path` is set in the __main__ block below)
    model = build_posenet(cfg.model)
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    load_checkpoint(model, checkpoint_path, map_location='cpu')
    model.eval()

    with torch.no_grad():
        outputs = model(**data)

    # Visualize results
    vis_s_weight = target_weight_s[0]
    vis_q_weight = target_weight_s[0]
    vis_s_image = support_img[0].detach().cpu().numpy().transpose(1, 2, 0)
    vis_q_image = q_img[0].detach().cpu().numpy().transpose(1, 2, 0)
    support_kp = kp_src_3d
    out = plot_results(vis_s_image,
                       vis_q_image,
                       support_kp,
                       vis_s_weight,
                       None,
                       vis_q_weight,
                       state['skeleton'],
                       None,
                       torch.tensor(outputs['points']).squeeze(0),
                       )
    return out, state
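# update_examples(): loads a pre-annotated example into the UI by redrawing
# the stored keypoints and skeleton on the support images, so the example can
# be evaluated without manual annotation.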
def update_examples(support_img, posed_support, query_img, state,
                    r=0.015, width=0.02):
    state['color_idx'] = 0
    state['original_support_image'] = np.array(support_img)[:, :, ::-1].copy()
    support_img, posed_support, _ = set_query(support_img, state, example=True)
    w, h = support_img.size
    draw_pose = ImageDraw.Draw(support_img)
    draw_limb = ImageDraw.Draw(posed_support)
    r = int(r * w)
    width = int(width * w)
    for pixel in state['kp_src']:
        leftUpPoint = (pixel[1] - r, pixel[0] - r)
        rightDownPoint = (pixel[1] + r, pixel[0] + r)
        twoPointList = [leftUpPoint, rightDownPoint]
        draw_pose.ellipse(twoPointList, fill=(255, 0, 0, 255))
        draw_limb.ellipse(twoPointList, fill=(255, 0, 0, 255))
    for limb in state['skeleton']:
        point_a = state['kp_src'][limb[0]][::-1]
        point_b = state['kp_src'][limb[1]][::-1]
        if state['color_idx'] < len(COLORS):
            c = COLORS[state['color_idx']]
            state['color_idx'] += 1
        else:
            c = random.choices(range(256), k=3)
        draw_limb.line([point_a, point_b], fill=tuple(c), width=width)
    return support_img, posed_support, query_img, state
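# get_select_coords(): click handler for the support image. Each click is
# stored as a new keypoint (in (row, col) order) and drawn as a red dot on
# both support canvases.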
def get_select_coords(kp_support,
                      limb_support,
                      state,
                      evt: gr.SelectData,
                      r=0.015):
    pixels_in_queue = set()
    pixels_in_queue.add((evt.index[1], evt.index[0]))
    # Initialize the canvases up front so the function always has something to
    # return, even if the clicked pixel is already a registered keypoint.
    canvas_kp = kp_support
    canvas_limb = kp_support if limb_support is None else limb_support
    while len(pixels_in_queue) > 0:
        pixel = pixels_in_queue.pop()
        if pixel[0] is not None and pixel[1] is not None and \
                pixel not in state['kp_src']:
            state['kp_src'].append(pixel)
        else:
            continue
        w, h = canvas_kp.size
        draw_pose = ImageDraw.Draw(canvas_kp)
        draw_limb = ImageDraw.Draw(canvas_limb)
        r = int(r * w)
        leftUpPoint = (pixel[1] - r, pixel[0] - r)
        rightDownPoint = (pixel[1] + r, pixel[0] + r)
        twoPointList = [leftUpPoint, rightDownPoint]
        draw_pose.ellipse(twoPointList, fill=(255, 0, 0, 255))
        draw_limb.ellipse(twoPointList, fill=(255, 0, 0, 255))
    return canvas_kp, canvas_limb, state
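# get_limbs(): click handler for the posed support image. Clicks snap to the
# nearest existing keypoint; every second click on a different keypoint adds a
# limb to the skeleton and draws it in the next palette color.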
def get_limbs(kp_support,
              state,
              evt: gr.SelectData,
              r=0.02, width=0.02):
    curr_pixel = (evt.index[1], evt.index[0])
    pixels_in_queue = set()
    pixels_in_queue.add((evt.index[1], evt.index[0]))
    canvas_kp = kp_support
    w, h = canvas_kp.size
    r = int(r * w)
    width = int(width * w)
    while len(pixels_in_queue) > 0 and curr_pixel != state['prev_clicked']:
        pixel = pixels_in_queue.pop()
        state['prev_clicked'] = pixel
        closest_point = min(state['kp_src'],
                            key=lambda p: (p[0] - pixel[0]) ** 2 +
                                          (p[1] - pixel[1]) ** 2)
        closest_point_index = state['kp_src'].index(closest_point)
        draw_limb = ImageDraw.Draw(canvas_kp)
        if state['color_idx'] < len(COLORS):
            c = COLORS[state['color_idx']]
        else:
            c = random.choices(range(256), k=3)
        leftUpPoint = (closest_point[1] - r, closest_point[0] - r)
        rightDownPoint = (closest_point[1] + r, closest_point[0] + r)
        twoPointList = [leftUpPoint, rightDownPoint]
        draw_limb.ellipse(twoPointList, fill=tuple(c))
        if state['count'] == 0:
            state['prev_pt'] = closest_point[1], closest_point[0]
            state['prev_pt_idx'] = closest_point_index
            state['count'] = state['count'] + 1
        else:
            if state['prev_pt_idx'] != closest_point_index:
                # Create line and add limb
                draw_limb.line(
                    [state['prev_pt'], (closest_point[1], closest_point[0])],
                    fill=tuple(c),
                    width=width)
                state['skeleton'].append(
                    (state['prev_pt_idx'], closest_point_index))
                state['color_idx'] = state['color_idx'] + 1
            else:
                draw_limb.ellipse(twoPointList, fill=(255, 0, 0, 255))
            state['count'] = 0
    return canvas_kp, state
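# set_query(): called when a new support image is uploaded. Resets the
# annotation state (unless loading an example) and downsamples the image to a
# small square canvas used for annotation.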
def set_query(support_img, state, example=False):
    if not example:
        state['skeleton'].clear()
        state['kp_src'].clear()
    state['original_support_image'] = np.array(support_img)[:, :, ::-1].copy()
    width, height = support_img.size
    support_img = support_img.resize((width // 4, width // 4),
                                     Image.Resampling.LANCZOS)
    return support_img, support_img, state
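# Gradio UI: support image (click to add keypoints), posed support image
# (click keypoints to connect limbs), query image, Evaluate button, output
# plot, and pre-annotated examples.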
with gr.Blocks() as demo:
    state = gr.State({
        'kp_src': [],
        'skeleton': [],
        'count': 0,
        'color_idx': 0,
        'prev_pt': None,
        'prev_pt_idx': None,
        'prev_clicked': None,
        'original_support_image': None,
    })

    gr.Markdown('''
# Pose Anything Demo
We present a novel approach to category-agnostic pose estimation that
leverages the inherent geometrical relations between keypoints through a
newly designed Graph Transformer Decoder. By capturing and incorporating
this crucial structural information, our method enhances the accuracy of
keypoint localization, marking a significant departure from conventional
CAPE techniques that treat keypoints as isolated entities.
### [Paper](https://arxiv.org/abs/2311.17891) | [Official Repo](https://github.com/orhir/PoseAnything)
## Instructions
1. Upload an image of the object you want to pose as the **left** (support) image.
2. Click on the **left** image to mark keypoints.
3. Click on the keypoints in the **right** image to mark limbs.
4. Upload the image you want to pose as the query image (**bottom**).
5. Click **Evaluate** to pose the query image.
''')
    with gr.Row():
        support_img = gr.Image(label="Support Image",
                               type="pil",
                               info='Click to mark keypoints').style(
            height=400, width=400)
        posed_support = gr.Image(label="Posed Support Image",
                                 type="pil",
                                 interactive=False).style(height=400,
                                                          width=400)
    with gr.Row():
        query_img = gr.Image(label="Query Image",
                             type="pil").style(height=400, width=400)
    with gr.Row():
        eval_btn = gr.Button(value="Evaluate")
    with gr.Row():
        output_img = gr.Plot(label="Output Image", height=400, width=400)
    with gr.Row():
        gr.Markdown("## Examples")
    with gr.Row():
        gr.Examples(
            examples=[
                ['examples/dog2.png',
                 'examples/dog2.png',
                 'examples/dog1.png',
                 {'kp_src': [(50, 58), (51, 78), (66, 57), (118, 79),
                             (154, 79), (217, 74), (218, 103), (156, 104),
                             (152, 151), (215, 162), (213, 191),
                             (152, 174), (108, 171)],
                  'skeleton': [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5),
                               (3, 7), (7, 6), (3, 12), (12, 8), (8, 9),
                               (12, 11), (11, 10)],
                  'count': 0,
                  'color_idx': 0,
                  'prev_pt': (174, 152),
                  'prev_pt_idx': 11,
                  'prev_clicked': (207, 186),
                  'original_support_image': None,
                  }],
                ['examples/sofa1.jpg',
                 'examples/sofa1.jpg',
                 'examples/sofa2.jpg',
                 {'kp_src': [(82, 28), (65, 30), (52, 26), (65, 50),
                             (84, 52), (53, 54), (43, 52), (45, 71),
                             (81, 69), (77, 39), (57, 43), (58, 64),
                             (46, 42), (49, 65)],
                  'skeleton': [(0, 1), (3, 1), (3, 4), (10, 9), (11, 8),
                               (1, 10), (10, 11), (11, 3), (1, 2), (7, 6),
                               (5, 13), (5, 3), (13, 11), (12, 10), (12, 2),
                               (6, 10), (7, 11)],
                  'count': 0,
                  'color_idx': 23,
                  'prev_pt': (71, 45),
                  'prev_pt_idx': 7,
                  'prev_clicked': (56, 63),
                  'original_support_image': None,
                  }],
                ['examples/person1.jpeg',
                 'examples/person1.jpeg',
                 'examples/person2.jpeg',
                 {'kp_src': [(121, 95), (122, 160), (154, 130), (184, 106),
                             (181, 153)],
                  'skeleton': [(0, 1), (1, 2), (0, 2), (2, 3), (2, 4),
                               (4, 3)],
                  'count': 0,
                  'color_idx': 6,
                  'prev_pt': (153, 181),
                  'prev_pt_idx': 4,
                  'prev_clicked': (181, 108),
                  'original_support_image': None,
                  }],
            ],
            inputs=[support_img, posed_support, query_img, state],
            outputs=[support_img, posed_support, query_img, state],
            fn=update_examples,
            run_on_click=True,
        )
    support_img.select(get_select_coords,
                       [support_img, posed_support, state],
                       [support_img, posed_support, state])
    support_img.upload(set_query,
                       inputs=[support_img, state],
                       outputs=[support_img, posed_support, state])
    posed_support.select(get_limbs,
                         [posed_support, state],
                         [posed_support, state])
    eval_btn.click(fn=process,
                   inputs=[query_img, state],
                   outputs=[output_img, state])
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Pose Anything Demo')
    parser.add_argument('--checkpoint',
                        help='checkpoint path',
                        default='1shot-swin_graph_split1.pth')
    args = parser.parse_args()
    checkpoint_path = args.checkpoint
    print("Loading checkpoint from {}".format(checkpoint_path))
    print("Checkpoint exists: {}".format(os.path.exists(checkpoint_path)))
    demo.launch()
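# Example local invocation (the script file name below is assumed; pass
# --checkpoint to point at a downloaded Pose Anything checkpoint):
#   python app.py --checkpoint 1shot-swin_graph_split1.pth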