import argparse
# Copyright (c) OpenMMLab. All rights reserved.
import os
import random

# Install the local package in development mode at import time.
# NOTE(review): presumably this is what makes `models` / `demo` importable in
# the hosted demo environment - keep it ahead of the imports below.
os.system('python setup.py develop')

import gradio as gr
import numpy as np
import torch
from PIL import ImageDraw, Image
from matplotlib import pyplot as plt
from mmcv import Config
from mmcv.runner import load_checkpoint
from mmpose.core import wrap_fp16_model
from mmpose.models import build_posenet
from torchvision import transforms
from demo import Resize_Pad
from models import *

import matplotlib

# Non-interactive backend: figures are only rendered off-screen.
matplotlib.use('agg')


def plot_results(support_img, query_img, support_kp, support_w, query_kp,
                 query_w, skeleton, initial_proposals, prediction, radius=6):
    """Draw the predicted keypoints and limbs on top of the query image.

    Args:
        support_img: Support image array; only its height (``shape[0]``) is
            used, to scale the prediction to pixel coordinates.
        query_img: Query image array. It is min-max normalised to [0, 1]
            before being displayed.
        support_kp: Unused; kept for interface compatibility.
        support_w: Unused; kept for interface compatibility.
        query_kp: Unused; kept for interface compatibility.
        query_w: Per-keypoint weights. A keypoint is drawn only when its
            weight is > 0; weight == 1 is drawn red, anything else blue.
        skeleton: Iterable of ``(start_idx, end_idx)`` keypoint index pairs.
        initial_proposals: Unused; kept for interface compatibility.
        prediction: Tensor whose last entry along the first axis holds the
            keypoint coordinates in units of the image height.
        radius: Radius of each keypoint marker, in pixels.

    Returns:
        The ``matplotlib.pyplot`` module, whose current figure holds the
        rendered result.
    """
    img_height = support_img.shape[0]
    keypoints = prediction[-1].cpu().numpy() * img_height

    # Min-max normalise for display; guard against a constant image, which
    # would otherwise divide by zero and produce NaNs.
    lo, hi = np.min(query_img), np.max(query_img)
    if hi > lo:
        query_img = (query_img - lo) / (hi - lo)
    else:
        query_img = np.zeros_like(query_img)

    # Figures from earlier calls have already been consumed by the caller;
    # close them so they do not pile up over repeated evaluations.
    plt.close('all')
    _, axes = plt.subplots()
    axes.imshow(query_img)

    for k in range(keypoints.shape[0]):
        if query_w[k] > 0:
            kp = keypoints[k, :2]
            marker_color = ((1, 0, 0, 0.75) if query_w[k] == 1
                            else (0, 0, 1, 0.6))
            axes.add_patch(plt.Circle(kp, radius, color=marker_color))
            axes.text(kp[0], kp[1], k)
    plt.draw()

    xy = keypoints[:, :2]
    for limb_idx, limb in enumerate(skeleton):
        # Use the fixed palette first, then fall back to random colours.
        if limb_idx > len(COLORS) - 1:
            limb_color = [x / 255 for x in random.sample(range(0, 255), 3)]
        else:
            limb_color = [x / 255 for x in COLORS[limb_idx]]
        if query_w[limb[0]] > 0 and query_w[limb[1]] > 0:
            line = plt.Line2D([xy[limb[0], 0], xy[limb[1], 0]],
                              [xy[limb[0], 1], xy[limb[1], 1]],
                              linewidth=6, color=limb_color, alpha=0.6)
            axes.add_artist(line)

    plt.axis('off')  # command for hiding the axis.
    plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)
    return plt


# Limb colour palette (RGB, 0-255); limb i uses COLORS[i] while available.
COLORS = [
    [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0],
    [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255],
    [0, 85, 255], [0, 0, 255], [85, 0, 255], [170, 0, 255], [255, 0, 255],
    [255, 0, 170], [255, 0, 85], [255, 0, 0]
]

# Side length (pixels) of the square canvas the support image is resized to
# for clicking; clicked coordinates are expressed on this canvas.
SUPPORT_CANVAS_SIZE = 128

# --- Shared click state (module-level, mutated by the event handlers) ---
kp_src = []  # clicked keypoints as (row, col) on the support canvas
skeleton = []  # limbs as (keypoint_idx, keypoint_idx) pairs
count = 0  # 0: waiting for a limb's first endpoint, 1: waiting for second
color_idx = 0  # palette index of the next limb to draw
prev_pt = None  # (x, y) of the pending limb's first endpoint
prev_pt_idx = None  # keypoint index of the pending limb's first endpoint
prev_clicked = None  # last processed limb click, used to skip repeats
original_support_image = None  # full-resolution support image (BGR array)
checkpoint_path = ''


def process(query_img, cfg_path='configs/demo_b.py'):
    """Run pose estimation on the query image and return the rendered plot.

    Uses the module-level support image, keypoints and skeleton collected by
    the click handlers.

    Args:
        query_img: Query image (PIL image, as delivered by the UI).
        cfg_path: Path of the model config file.

    Returns:
        The value returned by ``plot_results`` (the pyplot module holding the
        rendered figure).

    Raises:
        gr.Error: If the support image, its keypoints or the query image are
            missing.
    """
    if original_support_image is None or not kp_src:
        raise gr.Error('Please upload a support image and mark at least one '
                       'keypoint before evaluating.')
    if query_img is None:
        raise gr.Error('Please upload a query image before evaluating.')

    cfg = Config.fromfile(cfg_path)
    img_size = cfg.model.encoder_config.img_size

    # Rescale clicks from the support canvas to the model input size, then
    # swap the columns from (row, col) to (x, y).
    kp_src_np = np.array(kp_src).copy().astype(np.float32)
    kp_src_np = kp_src_np / float(SUPPORT_CANVAS_SIZE) * img_size
    kp_src_np = np.flip(kp_src_np, 1).copy()
    kp_src_tensor = torch.tensor(kp_src_np).float()

    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        Resize_Pad(img_size, img_size)])

    # Fall back to a dummy self-limb when none were drawn. This is a local
    # value on purpose: the shared `skeleton` list must not be polluted with
    # the placeholder.
    limbs = list(skeleton) if skeleton else [(0, 0)]

    support_img = preprocess(original_support_image).flip(0)[None]
    np_query = np.array(query_img)[:, :, ::-1].copy()
    q_img = preprocess(np_query).flip(0)[None]

    # Create heatmap from keypoints
    genHeatMap = TopDownGenerateTargetFewShot()
    data_cfg = cfg.data_cfg
    data_cfg['image_size'] = np.array([img_size, img_size])
    data_cfg['joint_weights'] = None
    data_cfg['use_different_joint_weights'] = False
    num_kp = kp_src_tensor.shape[0]
    kp_src_3d = torch.cat(
        (kp_src_tensor, torch.zeros(num_kp, 1)), dim=-1)
    kp_src_3d_weight = torch.cat(
        (torch.ones_like(kp_src_tensor), torch.zeros(num_kp, 1)), dim=-1)
    target_s, target_weight_s = genHeatMap._msra_generate_target(
        data_cfg, kp_src_3d, kp_src_3d_weight, sigma=1)
    target_s = torch.tensor(target_s).float()[None]
    target_weight_s = torch.ones_like(
        torch.tensor(target_weight_s).float()[None])

    kp_center = kp_src_tensor.mean(dim=0)
    kp_extent = kp_src_tensor.max(dim=0)[0] - kp_src_tensor.min(dim=0)[0]
    data = {
        'img_s': [support_img],
        'img_q': q_img,
        'target_s': [target_s],
        'target_weight_s': [target_weight_s],
        'target_q': None,
        'target_weight_q': None,
        'return_loss': False,
        'img_metas': [{'sample_skeleton': [limbs],
                       'query_skeleton': limbs,
                       'sample_joints_3d': [kp_src_3d],
                       'query_joints_3d': kp_src_3d,
                       'sample_center': [kp_center],
                       'query_center': kp_center,
                       'sample_scale': [kp_extent],
                       'query_scale': kp_extent,
                       'sample_rotation': [0],
                       'query_rotation': 0,
                       'sample_bbox_score': [1],
                       'query_bbox_score': 1,
                       'query_image_file': '',
                       'sample_image_file': [''],
                       }]
    }

    # Load model
    model = build_posenet(cfg.model)
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    load_checkpoint(model, checkpoint_path, map_location='cpu')
    model.eval()

    with torch.no_grad():
        outputs = model(**data)

    # visualize results
    vis_s_weight = target_weight_s[0]
    vis_q_weight = target_weight_s[0]
    vis_s_image = support_img[0].detach().cpu().numpy().transpose(1, 2, 0)
    vis_q_image = q_img[0].detach().cpu().numpy().transpose(1, 2, 0)
    support_kp = kp_src_3d
    out = plot_results(vis_s_image,
                       vis_q_image,
                       support_kp,
                       vis_s_weight,
                       None,
                       vis_q_weight,
                       limbs,
                       None,
                       torch.tensor(outputs['points']).squeeze(0),
                       )
    return out


def get_select_coords(kp_support, limb_support, evt: gr.SelectData, r=0.015):
    """Record a clicked keypoint and mark it on both support canvases.

    Args:
        kp_support: Keypoint canvas (PIL image) that was clicked.
        limb_support: Limb canvas (PIL image), or ``None``, in which case the
            keypoint canvas is used for both outputs.
        evt: Selection event; ``evt.index`` is read as ``(x, y)``.
        r: Marker radius as a fraction of the canvas width.

    Returns:
        Tuple ``(keypoint_canvas, limb_canvas)`` with the marker drawn.
    """
    canvas_kp = kp_support
    canvas_limb = kp_support if limb_support is None else limb_support

    # Stored as (row, col), i.e. (y, x).
    pixel = (evt.index[1], evt.index[0])
    if pixel[0] is None or pixel[1] is None:
        # Nothing can be drawn without coordinates; leave canvases untouched.
        print("Invalid pixel")
        return canvas_kp, canvas_limb
    if pixel not in kp_src:
        kp_src.append(pixel)
    else:
        print("Invalid pixel")

    canvas_w, _ = canvas_kp.size
    marker_r = int(r * canvas_w)
    bbox = [(pixel[1] - marker_r, pixel[0] - marker_r),
            (pixel[1] + marker_r, pixel[0] + marker_r)]
    ImageDraw.Draw(canvas_kp).ellipse(bbox, fill=(255, 0, 0, 255))
    ImageDraw.Draw(canvas_limb).ellipse(bbox, fill=(255, 0, 0, 255))
    return canvas_kp, canvas_limb


def get_limbs(kp_support, evt: gr.SelectData, r=0.02, width=0.02):
    """Handle a click on the limb canvas: pick endpoints and draw limbs.

    The click snaps to the nearest recorded keypoint. The first click of a
    pair selects the limb's start; the second click either completes the
    limb (different keypoint) or cancels the selection (same keypoint).

    Args:
        kp_support: Limb canvas (PIL image) that was clicked.
        evt: Selection event; ``evt.index`` is read as ``(x, y)``.
        r: Marker radius as a fraction of the canvas width.
        width: Limb line width as a fraction of the canvas width.

    Returns:
        The (possibly updated) limb canvas.
    """
    global count, color_idx, prev_pt, prev_pt_idx, prev_clicked
    canvas_kp = kp_support
    if canvas_kp is None:
        return canvas_kp

    clicked = (evt.index[1], evt.index[0])
    # Ignore a repeat of the last processed click, and clicks made before any
    # keypoint exists.
    if clicked == prev_clicked or not kp_src:
        return canvas_kp
    prev_clicked = clicked

    canvas_w, _ = canvas_kp.size
    marker_r = int(r * canvas_w)
    line_w = int(width * canvas_w)

    # Snap to the nearest recorded keypoint (squared Euclidean distance).
    nearest_idx = min(
        range(len(kp_src)),
        key=lambda i: ((kp_src[i][0] - clicked[0]) ** 2
                       + (kp_src[i][1] - clicked[1]) ** 2))
    nearest = kp_src[nearest_idx]
    center_xy = (nearest[1], nearest[0])

    if color_idx < len(COLORS):
        limb_color = COLORS[color_idx]
    else:
        limb_color = random.choices(range(256), k=3)

    draw_limb = ImageDraw.Draw(canvas_kp)
    bbox = [(center_xy[0] - marker_r, center_xy[1] - marker_r),
            (center_xy[0] + marker_r, center_xy[1] + marker_r)]
    draw_limb.ellipse(bbox, fill=tuple(limb_color))

    if count == 0:
        # First endpoint of a new limb.
        prev_pt = center_xy
        prev_pt_idx = nearest_idx
        count = 1
    else:
        if prev_pt_idx != nearest_idx:
            # Create Line and add Limb
            draw_limb.line([prev_pt, center_xy],
                           fill=tuple(limb_color), width=line_w)
            skeleton.append((prev_pt_idx, nearest_idx))
            color_idx = color_idx + 1
        else:
            # Same keypoint clicked twice: cancel and restore the red marker.
            draw_limb.ellipse(bbox, fill=(255, 0, 0, 255))
        count = 0
    return canvas_kp


def set_query(support_img):
    """Store a newly uploaded support image and reset all annotations.

    Args:
        support_img: Uploaded support image (PIL image).

    Returns:
        Tuple of the resized canvas, used for both the keypoint canvas and
        the limb canvas.
    """
    global original_support_image
    global count, color_idx, prev_pt, prev_pt_idx, prev_clicked
    if support_img is None:
        return None, None

    skeleton.clear()
    kp_src.clear()
    # Also drop any half-finished limb and restart the palette, otherwise the
    # first click on the new image could complete a limb that started on the
    # previous one (with a keypoint index that no longer exists).
    count = 0
    color_idx = 0
    prev_pt = None
    prev_pt_idx = None
    prev_clicked = None

    # Keep the full-resolution image with its channel order reversed.
    original_support_image = np.array(support_img)[:, :, ::-1].copy()
    support_img = support_img.resize(
        (SUPPORT_CANVAS_SIZE, SUPPORT_CANVAS_SIZE), Image.Resampling.LANCZOS)
    return support_img, support_img


with gr.Blocks() as demo:
    gr.Markdown('''
    # Pose Anything Demo
    We present a novel approach to category agnostic pose estimation that leverages the inherent geometrical relations between keypoints through a newly designed Graph Transformer Decoder. By capturing and incorporating this crucial structural information, our method enhances the accuracy of keypoint localization, marking a significant departure from conventional CAPE techniques that treat keypoints as isolated entities.
    ### [Paper](https://arxiv.org/abs/2311.17891) | [Official Repo](https://github.com/orhir/PoseAnything)
    ![](/file=gradio_teaser.png)
    ## Instructions
    1. Upload an image of the object you want to pose on the **left** image.
    2. Click on the **left** image to mark keypoints.
    3. Click on the keypoints on the **right** image to mark limbs.
    4. Upload an image of the object you want to pose to the query image (**bottom**).
    5. Click **Evaluate** to pose the query image.
    ''')
    with gr.Row():
        support_img = gr.Image(label="Support Image",
                               type="pil",
                               info='Click to mark keypoints').style(
            height=256, width=256)
        posed_support = gr.Image(label="Posed Support Image",
                                 type="pil",
                                 interactive=False).style(height=256,
                                                          width=256)
    with gr.Row():
        query_img = gr.Image(label="Query Image",
                             type="pil").style(height=256, width=256)
    with gr.Row():
        eval_btn = gr.Button(value="Evaluate")
    with gr.Row():
        output_img = gr.Plot(label="Output Image", height=256, width=256)

    # Event wiring (must happen inside the Blocks context).
    support_img.select(get_select_coords,
                       [support_img, posed_support],
                       [support_img, posed_support],
                       )
    support_img.upload(set_query,
                       inputs=support_img,
                       outputs=[support_img, posed_support])
    posed_support.select(get_limbs,
                         posed_support,
                         posed_support)
    eval_btn.click(fn=process,
                   inputs=[query_img],
                   outputs=output_img)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Pose Anything Demo')
    parser.add_argument('--checkpoint',
                        help='checkpoint path',
                        default='1shot-swin_graph_split1.pth')
    args = parser.parse_args()
    checkpoint_path = args.checkpoint
    print("Loading checkpoint from {}".format(checkpoint_path))
    print(os.path.exists(checkpoint_path))
    demo.launch()