"""Gradio demo for ID-Pose: sparse-view camera pose estimation by inverting
diffusion models (https://github.com/xt4d/id-pose).

The app takes two object images, optionally removes their backgrounds and
recenters the objects, then estimates the relative camera pose between the
two views with a Zero123 diffusion model, visualizing both cameras in 3D.
"""
import spaces
import os
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from functools import partial
import gradio as gr
from huggingface_hub import hf_hub_download
import torch
from torchvision import transforms
import rembg
import cv2

from src.visualizer import CameraVisualizer
from src.pose_estimation import load_model_from_config, estimate_poses, estimate_elevs
from src.pose_funcs import find_optimal_poses
from src.utils import spherical_to_cartesian, elu_to_c2w


# Prefer GPU when available; the diffusion model is impractically slow on CPU.
if torch.cuda.is_available():
    _device_ = 'cuda:0'
else:
    _device_ = 'cpu'

_config_path_ = 'src/configs/sd-objaverse-finetune-c_concat-256.yaml'
# Checkpoints are fetched (and cached) from the Hugging Face Hub at startup.
_ckpt_path_ = hf_hub_download(repo_id='tokenid/ID-Pose', filename='ckpts/zero123-xl.ckpt', repo_type='model')
_matcher_ckpt_path_ = hf_hub_download(repo_id='tokenid/ID-Pose', filename='ckpts/indoor_ds_new.ckpt', repo_type='model')

_config_ = OmegaConf.load(_config_path_)
# Load weights on CPU first, then move to the target device, so checkpoint
# deserialization does not spike GPU memory.
_model_ = load_model_from_config(_config_, _ckpt_path_, device='cpu')
_model_ = _model_.to(_device_)
_model_.eval()


def rgba_to_rgb(img):
    """Composite an RGBA PIL image over a white background.

    Args:
        img: PIL image in 'RGBA' mode.

    Returns:
        PIL image in RGB mode with transparency flattened onto white.
    """
    assert img.mode == 'RGBA'
    img = np.asarray(img, dtype=np.float32)
    # Alpha-blend against white: rgb * a + 255 * (1 - a), with a in [0, 1].
    img[:, :, :3] = img[:, :, :3] * (img[..., 3:]/255.) + (255-img[..., 3:])
    img = img.clip(0, 255).astype(np.uint8)
    return Image.fromarray(img[:, :, :3])


def remove_background(image, rembg_session=None, force=False, **rembg_kwargs):
    """Remove the image background with rembg.

    Removal is skipped when the image already carries a meaningful alpha
    channel (some pixel with alpha < 255), unless ``force`` is True.

    Args:
        image: input PIL image.
        rembg_session: optional pre-created rembg session to reuse.
        force: always run rembg even if the image already has transparency.
        **rembg_kwargs: forwarded to ``rembg.remove``.

    Returns:
        PIL image with the background removed (RGBA) or the input unchanged.
    """
    do_remove = True
    if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
        # Transparency already present — assume background was removed upstream.
        do_remove = False
    do_remove = do_remove or force
    if do_remove:
        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
    return image


def group_recenter(images, ratio=1.5, mask_thres=127, bkg_color=(255, 255, 255, 255)):
    """Jointly recenter a group of RGBA images around their alpha masks.

    A shared square crop size is derived from the largest foreground bounding
    box across all images (scaled by ``ratio``) so the objects keep a
    consistent relative scale, then each image is cropped/padded around its
    mask center, composited onto ``bkg_color``, and resized to 256x256.

    Args:
        images: list of RGBA PIL images.
        ratio: margin factor applied to the largest mask extent.
        mask_thres: alpha values above this count as foreground.
        bkg_color: RGBA background fill color.
            (Fixed: was a mutable list default; a tuple avoids the shared
            mutable-default-argument pitfall and is read-only here.)

    Returns:
        List of 256x256x3 uint8 numpy arrays.
    """
    ws = []
    hs = []
    images = [np.asarray(img) for img in images]
    # Pass 1: measure every mask's bounding box to pick one shared crop size.
    for img in images:
        alpha = img[:, :, 3]
        yy, xx = np.where(alpha > mask_thres)
        y0, y1 = yy.min(), yy.max()
        x0, x1 = xx.min(), xx.max()
        ws.append(x1 - x0)
        hs.append(y1 - y0)
    sz_w = np.max(ws)
    sz_h = np.max(hs)
    sz = int(max(ratio*sz_w, ratio*sz_h))
    out_rgbs = []
    # Pass 2: crop each image around its mask center, padding where the crop
    # window extends past the image borders.
    for rgba in images:
        rgb = rgba[:, :, :3]
        alpha = rgba[:, :, 3]
        yy, xx = np.where(alpha > mask_thres)
        y0, y1 = yy.min(), yy.max()
        x0, x1 = xx.min(), xx.max()
        height, width, chn = rgb.shape
        cy = (y0 + y1) // 2
        cx = (x0 + x1) // 2
        y0 = cy - int(np.floor(sz / 2))
        y1 = cy + int(np.ceil(sz / 2))
        x0 = cx - int(np.floor(sz / 2))
        x1 = cx + int(np.ceil(sz / 2))
        out = rgba[max(y0, 0):min(y1, height), max(x0, 0):min(x1, width), :].copy()
        pads = [(max(0-y0, 0), max(y1-height, 0)), (max(0-x0, 0), max(x1-width, 0)), (0, 0)]
        out = np.pad(out, pads, mode='constant', constant_values=0)
        assert(out.shape[:2] == (sz, sz))
        # Composite onto the background color using the alpha channel.
        out[:, :, :3] = out[:, :, :3] * (out[..., 3:]/255.) + np.array(bkg_color)[None, None, :3] * (1-out[..., 3:]/255.)
        out[:, :, -1] = bkg_color[-1]
        out = cv2.resize(out.astype(np.uint8), (256, 256))
        out = out[:, :, :3]
        out_rgbs.append(out)
    return out_rgbs


def run_preprocess(image1, image2, preprocess_chk):
    """Optionally strip backgrounds, then recenter both images jointly.

    Args:
        image1, image2: input PIL images.
        preprocess_chk: when True, run rembg background removal first.

    Returns:
        Tuple of two preprocessed 256x256 RGB PIL images.
    """
    if preprocess_chk:
        # Share one rembg session across both images to avoid loading the
        # segmentation model twice.
        rembg_session = rembg.new_session()
        image1 = remove_background(image1, force=True, rembg_session=rembg_session)
        image2 = remove_background(image2, force=True, rembg_session=rembg_session)
    rgbs = group_recenter([image1, image2])
    image1 = Image.fromarray(rgbs[0])
    image2 = Image.fromarray(rgbs[1])
    return image1, image2


def image_to_tensor(img, width=256, height=256):
    """Convert a PIL image to a (1, C, H, W) tensor scaled to [-1, 1]."""
    img = transforms.ToTensor()(img).unsqueeze(0)
    img = img * 2 - 1
    img = transforms.functional.resize(img, [height, width])
    return img


@spaces.GPU
def run_pose_exploration_a(cam_vis, image1, image2):
    """Stage A: estimate elevation candidates for both input images.

    Args:
        cam_vis: CameraVisualizer instance (bound via functools.partial);
            unused in this stage but kept for a uniform handler signature.
        image1, image2: preprocessed PIL/array images.

    Returns:
        (elevs, elev_ranges, fig) where fig is None — the plot is only
        updated in stage B.
    """
    image1 = image_to_tensor(image1).to(_device_)
    image2 = image_to_tensor(image2).to(_device_)
    images = [image1, image2]
    elevs, elev_ranges = estimate_elevs(
        _model_, images,
        est_type='all',
        matcher_ckpt_path=_matcher_ckpt_path_
    )
    fig = None
    return elevs, elev_ranges, fig


@spaces.GPU
def run_pose_exploration_b(cam_vis, image1, image2, elevs, elev_ranges, probe_bsz, adj_bsz, adj_iters):
    """Stage B: explore candidate relative poses and visualize the result.

    Args:
        cam_vis: CameraVisualizer updated in place with the two camera poses.
        image1, image2: preprocessed images (numpy-convertible).
        elevs, elev_ranges: elevation estimates from stage A.
        probe_bsz, adj_bsz, adj_iters: search batch sizes / iteration count.

    Returns:
        (anchor_polar, explored_sph, fig, refine_btn_update) — the estimated
        anchor elevation, the relative pose as (theta, azimuth, radius), the
        updated plot, and a gr.update enabling the Refine button.
    """
    # Fixed probe noise shared across candidates; shape matches the model's
    # 4x32x32 latent space.
    noise = np.random.randn(probe_bsz, 4, 32, 32)
    cam_vis.set_images([np.asarray(image1, dtype=np.uint8), np.asarray(image2, dtype=np.uint8)])
    image1 = image_to_tensor(image1).to(_device_)
    image2 = image_to_tensor(image2).to(_device_)
    images = [image1, image2]
    result_poses, aux_data = estimate_poses(
        _model_, images,
        seed_cand_num=8,
        explore_type='triangular',
        refine_type='triangular',
        probe_ts_range=[0.2, 0.21],
        ts_range=[0.2, 0.21],
        probe_bsz=probe_bsz,
        adjust_factor=10.0,
        adjust_iters=adj_iters,
        adjust_bsz=adj_bsz,
        refine_factor=1.0,
        refine_iters=0,
        refine_bsz=4,
        noise=noise,
        elevs=elevs,
        elev_ranges=elev_ranges
    )
    theta, azimuth, radius = result_poses[0]
    anchor_polar = aux_data['elev'][0]
    if anchor_polar is None:
        # Fall back to a horizontal viewpoint when elevation estimation failed.
        anchor_polar = np.pi/2
    # Anchor camera at (polar, azimuth=0, radius=4); target is expressed
    # relative to it — see spherical_to_cartesian / elu_to_c2w in src.utils.
    xyz0 = spherical_to_cartesian((anchor_polar, 0., 4.))
    c2w0 = elu_to_c2w(xyz0, np.zeros(3), np.array([0., 0., 1.]))
    xyz1 = spherical_to_cartesian((theta + anchor_polar, 0. + azimuth, 4. + radius))
    c2w1 = elu_to_c2w(xyz1, np.zeros(3), np.array([0., 0., 1.]))
    cam_vis._poses = [c2w0, c2w1]
    fig = cam_vis.update_figure(5, base_radius=-1.2, font_size=16, show_background=True, show_grid=True, show_ticklabels=True)
    explored_sph = (theta, azimuth, radius)
    return anchor_polar, explored_sph, fig, gr.update(interactive=True)


@spaces.GPU
def run_pose_refinement(cam_vis, image1, image2, anchor_polar, explored_sph, refine_iters):
    """Refine the explored pose with gradient-based optimization.

    Args:
        cam_vis: CameraVisualizer updated in place.
        image1, image2: preprocessed images.
        anchor_polar: anchor camera elevation from stage B.
        explored_sph: initial (theta, azimuth, radius) from stage B.
        refine_iters: number of optimization iterations.

    Returns:
        (final_sph, fig) — the refined spherical pose and the updated plot.
    """
    cam_vis.set_images([np.asarray(image1, dtype=np.uint8), np.asarray(image2, dtype=np.uint8)])
    image1 = image_to_tensor(image1).to(_device_)
    image2 = image_to_tensor(image2).to(_device_)
    images = [image1, image2]
    # find_optimal_poses expects channel-last images.
    images = [img.permute(0, 2, 3, 1) for img in images]
    out_poses, _, loss = find_optimal_poses(
        _model_, images, 1.0,
        bsz=1,
        n_iter=refine_iters,
        init_poses={1: explored_sph},
        ts_range=[0.2, 0.21],
        combinations=[(0, 1), (1, 0)],
        avg_last_n=20,
        print_n=100
    )
    final_sph = out_poses[0]
    theta, azimuth, radius = final_sph
    xyz0 = spherical_to_cartesian((anchor_polar, 0., 4.))
    c2w0 = elu_to_c2w(xyz0, np.zeros(3), np.array([0., 0., 1.]))
    xyz1 = spherical_to_cartesian((theta + anchor_polar, 0. + azimuth, 4. + radius))
    c2w1 = elu_to_c2w(xyz1, np.zeros(3), np.array([0., 0., 1.]))
    cam_vis._poses = [c2w0, c2w1]
    fig = cam_vis.update_figure(5, base_radius=-1.2, font_size=16, show_background=True, show_grid=True, show_ticklabels=True)
    return final_sph, fig


_HEADER_ = '''
# Official 🤗 Gradio Demo for [ID-Pose: Sparse-view Camera Pose Estimation By Inverting Diffusion Models](https://github.com/xt4d/id-pose)
- ID-Pose accepts input images with NO overlapping appearance.
- The estimation takes about 1 minute. ZeroGPU may be halted during processing due to quota restrictions.
'''

_FOOTER_ = '''
- Project Page: [https://xt4d.github.io/id-pose-web/](https://xt4d.github.io/id-pose-web/)
- Github: [https://github.com/xt4d/id-pose](https://github.com/xt4d/id-pose)
'''

_CITE_ = r"""
```bibtex
@article{cheng2023id,
    title={ID-Pose: Sparse-view Camera Pose Estimation by Inverting Diffusion Models},
    author={Cheng, Weihao and Cao, Yan-Pei and Shan, Ying},
    journal={arXiv preprint arXiv:2306.17140},
    year={2023}
}
```
"""


def run_demo():
    """Build and launch the Gradio Blocks UI for the demo."""
    demo = gr.Blocks(title='ID-Pose: Sparse-view Camera Pose Estimation By Inverting Diffusion Models')
    with demo:
        gr.Markdown(_HEADER_)
        with gr.Row(variant='panel'):
            with gr.Column(scale=1):
                with gr.Row():
                    with gr.Column(min_width=280):
                        input_image1 = gr.Image(type='pil', image_mode='RGBA', label='Input Image 1', width=280)
                    with gr.Column(min_width=280):
                        input_image2 = gr.Image(type='pil', image_mode='RGBA', label='Input Image 2', width=280)
                with gr.Row():
                    with gr.Column(min_width=280):
                        processed_image1 = gr.Image(type='numpy', image_mode='RGB', label='Processed Image 1', width=280, interactive=False)
                    with gr.Column(min_width=280):
                        processed_image2 = gr.Image(type='numpy', image_mode='RGB', label='Processed Image 2', width=280, interactive=False)
                with gr.Row():
                    preprocess_chk = gr.Checkbox(True, label='Remove background and recenter object')
                with gr.Accordion('Advanced options', open=False):
                    probe_bsz = gr.Slider(4, 32, value=16, step=4, label='Probe Batch Size')
                    adj_bsz = gr.Slider(1, 8, value=4, step=1, label='Adjust Batch Size')
                    adj_iters = gr.Slider(1, 20, value=5, step=1, label='Adjust Iterations')
                with gr.Row():
                    run_btn = gr.Button('Estimate', variant='primary', interactive=True)
                with gr.Row():
                    refine_iters = gr.Slider(0, 1000, value=0, step=50, label='Refinement Iterations')
                with gr.Row():
                    # Enabled only after a successful estimation run.
                    refine_btn = gr.Button('Refine', variant='primary', interactive=False)
                with gr.Row():
                    gr.Markdown(_FOOTER_)
                with gr.Row():
                    gr.Markdown(_CITE_)
            # NOTE(review): float `scale` values are rejected by Gradio 4.x
            # (scale must be an integer) — confirm the pinned gradio version.
            with gr.Column(scale=1.4):
                with gr.Row():
                    vis_output = gr.Plot(label='Camera Pose Results: anchor (red) and target (blue)')
                with gr.Row():
                    with gr.Column(min_width=200):
                        gr.Examples(
                            examples=[
                                ['data/gradio_demo/duck_0.png', 'data/gradio_demo/duck_1.png'],
                                ['data/gradio_demo/chair_0.png', 'data/gradio_demo/chair_1.png'],
                                ['data/gradio_demo/foosball_0.png', 'data/gradio_demo/foosball_1.png'],
                            ],
                            inputs=[input_image1, input_image2],
                            label='Examples (Self-captured)',
                            cache_examples=False,
                            examples_per_page=3
                        )
                    with gr.Column(min_width=200):
                        gr.Examples(
                            examples=[
                                ['data/gradio_demo/bunny_0.png', 'data/gradio_demo/bunny_1.png'],
                                ['data/gradio_demo/bus_0.png', 'data/gradio_demo/bus_1.png'],
                                ['data/gradio_demo/circo_0.png', 'data/gradio_demo/circo_1.png'],
                            ],
                            inputs=[input_image1, input_image2],
                            label='Examples (Images from NAVI)',
                            cache_examples=False,
                            examples_per_page=3
                        )
                    with gr.Column(min_width=200):
                        gr.Examples(
                            examples=[
                                ['data/gradio_demo/status_0.png', 'data/gradio_demo/status_1.png'],
                                ['data/gradio_demo/bag_0.png', 'data/gradio_demo/bag_1.png'],
                                ['data/gradio_demo/cat_0.png', 'data/gradio_demo/cat_1.png'],
                            ],
                            inputs=[input_image1, input_image2],
                            label='Examples (Generated)',
                            cache_examples=False,
                            examples_per_page=3
                        )
        cam_vis = CameraVisualizer([np.eye(4), np.eye(4)], ['Image 1', 'Image 2'], ['red', 'blue'])
        # Cross-step state threaded between the chained event handlers.
        explored_sph = gr.State()
        anchor_polar = gr.State()
        refined_sph = gr.State()
        elevs = gr.State()
        elev_ranges = gr.State()
        # Pipeline: preprocess -> elevation estimation (A) -> pose search (B).
        # Each stage only fires if the previous one succeeded.
        run_btn.click(
            fn=run_preprocess,
            inputs=[input_image1, input_image2, preprocess_chk],
            outputs=[processed_image1, processed_image2],
        ).success(
            fn=partial(run_pose_exploration_a, cam_vis),
            inputs=[processed_image1, processed_image2],
            outputs=[elevs, elev_ranges, vis_output]
        ).success(
            fn=partial(run_pose_exploration_b, cam_vis),
            inputs=[processed_image1, processed_image2, elevs, elev_ranges, probe_bsz, adj_bsz, adj_iters],
            outputs=[anchor_polar, explored_sph, vis_output, refine_btn]
        )
        refine_btn.click(
            fn=partial(run_pose_refinement, cam_vis),
            inputs=[processed_image1, processed_image2, anchor_polar, explored_sph, refine_iters],
            outputs=[refined_sph, vis_output]
        )
    demo.launch()


if __name__ == '__main__':
    run_demo()