# ID-Pose / app.py
import spaces
import os
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from functools import partial
import gradio as gr
from huggingface_hub import hf_hub_download
import torch
from torchvision import transforms
import rembg
import cv2
from pytorch_lightning import seed_everything
from src.visualizer import CameraVisualizer
from src.pose_estimation import load_model_from_config, estimate_poses, estimate_elevs
from src.pose_funcs import find_optimal_poses
from src.utils import spherical_to_cartesian, elu_to_c2w
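# Select the compute device: use the first CUDA GPU when available, otherwise fall back to CPU.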
if torch.cuda.is_available():
_device_ = 'cuda:0'
else:
_device_ = 'cpu'
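# Download the Zero123-XL diffusion checkpoint and the image-matching checkpoint used for
# elevation estimation from the Hugging Face Hub, then load the model and put it in eval mode.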
_config_path_ = 'src/configs/sd-objaverse-finetune-c_concat-256.yaml'
_ckpt_path_ = hf_hub_download(repo_id='tokenid/ID-Pose', filename='ckpts/zero123-xl.ckpt', repo_type='model')
_matcher_ckpt_path_ = hf_hub_download(repo_id='tokenid/ID-Pose', filename='ckpts/indoor_ds_new.ckpt', repo_type='model')
_config_ = OmegaConf.load(_config_path_)
_model_ = load_model_from_config(_config_, _ckpt_path_, device='cpu')
_model_ = _model_.to(_device_)
_model_.eval()
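# Composite an RGBA image over a white background and return it as an RGB PIL image.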
def rgba_to_rgb(img):
assert img.mode == 'RGBA'
img = np.asarray(img, dtype=np.float32)
img[:, :, :3] = img[:, :, :3] * (img[..., 3:]/255.) + (255-img[..., 3:])
img = img.clip(0, 255).astype(np.uint8)
return Image.fromarray(img[:, :, :3])
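# Remove the image background with rembg. Segmentation is skipped when the image already
# carries a meaningful alpha channel, unless force=True.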
def remove_background(image, rembg_session=None, force=False, **rembg_kwargs):
do_remove = True
if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
do_remove = False
do_remove = do_remove or force
if do_remove:
image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
return image
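# Crop each RGBA image around its alpha mask with a shared scale so the objects keep their
# relative sizes, composite onto bkg_color, and resize to 256x256 RGB arrays.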
def group_recenter(images, ratio=1.5, mask_thres=127, bkg_color=[255, 255, 255, 255]):
ws = []
hs = []
images = [ np.asarray(img) for img in images ]
for img in images:
alpha = img[:, :, 3]
yy, xx = np.where(alpha > mask_thres)
y0, y1 = yy.min(), yy.max()
x0, x1 = xx.min(), xx.max()
        # normalize the bounding-box extent by the image width (x) and height (y) respectively
        ws.append(float(x1 - x0) / img.shape[1])
        hs.append(float(y1 - y0) / img.shape[0])
sz_w = np.max(ws)
sz_h = np.max(hs)
sz = max(ratio*sz_w, ratio*sz_h)
out_rgbs = []
for rgba in images:
rgb = rgba[:, :, :3]
alpha = rgba[:, :, 3]
yy, xx = np.where(alpha > mask_thres)
y0, y1 = yy.min(), yy.max()
x0, x1 = xx.min(), xx.max()
height, width, chn = rgb.shape
cy = (y0 + y1) // 2
cx = (x0 + x1) // 2
y0 = cy - int(np.floor(sz * rgba.shape[0] / 2))
y1 = cy + int(np.ceil(sz * rgba.shape[0] / 2))
x0 = cx - int(np.floor(sz * rgba.shape[1] / 2))
x1 = cx + int(np.ceil(sz * rgba.shape[1] / 2))
out = rgba[ max(y0, 0) : min(y1, height) , max(x0, 0) : min(x1, width), : ].copy()
pads = [(max(0-y0, 0), max(y1-height, 0)), (max(0-x0, 0), max(x1-width, 0)), (0, 0)]
out = np.pad(out, pads, mode='constant', constant_values=0)
out[:, :, :3] = out[:, :, :3] * (out[..., 3:]/255.) + np.array(bkg_color)[None, None, :3] * (1-out[..., 3:]/255.)
out[:, :, -1] = bkg_color[-1]
out = cv2.resize(out.astype(np.uint8), (256, 256))
out = out[:, :, :3]
out_rgbs.append(out)
return out_rgbs
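# Optionally strip backgrounds and jointly recenter both inputs before pose estimation.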
def run_preprocess(image1, image2, preprocess_chk, seed_value):
seed_everything(seed_value)
if preprocess_chk:
rembg_session = rembg.new_session()
image1 = remove_background(image1, force=True, rembg_session = rembg_session)
image2 = remove_background(image2, force=True, rembg_session = rembg_session)
rgbs = group_recenter([image1, image2])
image1 = Image.fromarray(rgbs[0])
image2 = Image.fromarray(rgbs[1])
return image1, image2
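# Convert an image to a (1, 3, H, W) tensor normalized to [-1, 1] and resized to 256x256.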
def image_to_tensor(img, width=256, height=256):
img = transforms.ToTensor()(img).unsqueeze(0)
img = img * 2 - 1
img = transforms.functional.resize(img, [height, width])
return img
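# Stage 1 (ZeroGPU): estimate elevation angles with the matcher, then explore candidate
# relative poses (polar, azimuth, radius offsets) by probing the diffusion model.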
@spaces.GPU(duration=110)
def run_pose_exploration(image1, image2, probe_bsz, adj_bsz, adj_iters, seed_value):
seed_everything(seed_value)
image1 = image_to_tensor(image1).to(_device_)
image2 = image_to_tensor(image2).to(_device_)
images = [image1, image2]
elevs, elev_ranges = estimate_elevs(
_model_, images,
est_type='all',
matcher_ckpt_path=_matcher_ckpt_path_
)
anchor_polar = elevs[0]
if torch.mean(torch.abs(image1 - image2)) < 0.005:
theta = azimuth = radius = 0
print('Identical images found!')
else:
noise = np.random.randn(probe_bsz, 4, 32, 32)
result_poses, aux_data = estimate_poses(
_model_, images,
seed_cand_num=8,
explore_type='triangular',
refine_type='triangular',
probe_ts_range=[0.2, 0.21],
ts_range=[0.2, 0.21],
probe_bsz=probe_bsz,
adjust_factor=10.0,
adjust_iters=adj_iters,
adjust_bsz=adj_bsz,
refine_factor=1.0,
refine_iters=0,
refine_bsz=4,
noise=noise,
elevs=elevs,
elev_ranges=elev_ranges
)
theta, azimuth, radius = result_poses[0]
if anchor_polar is None:
anchor_polar = np.pi/2
explored_sph = (float(theta), float(azimuth), float(radius))
return float(anchor_polar), explored_sph
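# Stage 2 (ZeroGPU): refine the explored pose with iterative optimization and plot both
# camera frustums (anchor in red, target in blue) with CameraVisualizer.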
@spaces.GPU(duration=110)
def run_pose_refinement(image1, image2, est_result, refine_iters, seed_value):
seed_everything(seed_value)
anchor_polar = est_result[0]
explored_sph = est_result[1]
images = [image_to_tensor(image1).to(_device_), image_to_tensor(image2).to(_device_)]
images = [ img.permute(0, 2, 3, 1) for img in images ]
out_poses, _, loss = find_optimal_poses(
_model_, images,
1.0,
bsz=1,
n_iter=refine_iters,
init_poses={1: explored_sph},
ts_range=[0.2, 0.21],
combinations=[(0, 1), (1, 0)],
avg_last_n=20,
print_n=100
)
final_sph = out_poses[0]
theta, azimuth, radius = final_sph
xyz0 = spherical_to_cartesian((anchor_polar, 0., 4.))
c2w0 = elu_to_c2w(xyz0, np.zeros(3), np.array([0., 0., 1.]))
xyz1 = spherical_to_cartesian((theta + anchor_polar, 0. + azimuth, 4. + radius))
c2w1 = elu_to_c2w(xyz1, np.zeros(3), np.array([0., 0., 1.]))
cam_vis = CameraVisualizer([c2w0, c2w1], ['Image 1', 'Image 2'], ['red', 'blue'], images=[np.asarray(image1, dtype=np.uint8), np.asarray(image2, dtype=np.uint8)])
fig = cam_vis.update_figure(5, base_radius=-1.2, font_size=16, show_background=True, show_grid=True, show_ticklabels=True)
return (anchor_polar, final_sph), fig
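# Pipeline used for the cached example galleries: preprocess with background removal,
# then run pose exploration with the default hyperparameters.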
def run_example(image1, image2):
image1, image2 = run_preprocess(image1, image2, True, 0)
anchor_polar, explored_sph = run_pose_exploration(image1, image2, 16, 4, 10, 0)
return (anchor_polar, explored_sph), image1, image2
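# Reuse a cached estimation result when one exists (e.g. from a clicked example); otherwise
# explore the pose, then visualize both cameras and enable the Refine button.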
def run_or_visualize(image1, image2, probe_bsz, adj_bsz, adj_iters, seed_value, est_result):
if est_result is None:
anchor_polar, explored_sph = run_pose_exploration(image1, image2, probe_bsz, adj_bsz, adj_iters, seed_value)
else:
anchor_polar = est_result[0]
explored_sph = est_result[1]
print('Using cache result.')
xyz0 = spherical_to_cartesian((anchor_polar, 0., 4.))
c2w0 = elu_to_c2w(xyz0, np.zeros(3), np.array([0., 0., 1.]))
xyz1 = spherical_to_cartesian((explored_sph[0] + anchor_polar, 0. + explored_sph[1], 4. + explored_sph[2]))
c2w1 = elu_to_c2w(xyz1, np.zeros(3), np.array([0., 0., 1.]))
cam_vis = CameraVisualizer([c2w0, c2w1], ['Image 1', 'Image 2'], ['red', 'blue'], images=[np.asarray(image1, dtype=np.uint8), np.asarray(image2, dtype=np.uint8)])
fig = cam_vis.update_figure(5, base_radius=-1.2, font_size=16, show_background=True, show_grid=True, show_ticklabels=True)
return (anchor_polar, explored_sph), fig, gr.update(interactive=True)
_HEADER_ = '''
# Official 🤗 Gradio Demo for [ID-Pose: Sparse-view Camera Pose Estimation By Inverting Diffusion Models](https://github.com/xt4d/id-pose)
- ID-Pose works even when the two input images share NO overlapping appearance.
- Estimation takes about 1 minute. ZeroGPU may interrupt processing if usage quota limits are reached.
'''
_FOOTER_ = '''
[Project Page](https://xt4d.github.io/id-pose-web/) | ⭐ [Github](https://github.com/xt4d/id-pose) ⭐ [![GitHub Stars](https://img.shields.io/github/stars/xt4d/id-pose?style=social)](https://github.com/xt4d/id-pose)
---
'''
_CITE_ = r"""
```bibtex
@article{cheng2023id,
title={ID-Pose: Sparse-view Camera Pose Estimation by Inverting Diffusion Models},
author={Cheng, Weihao and Cao, Yan-Pei and Shan, Ying},
journal={arXiv preprint arXiv:2306.17140},
year={2023}
}
```
"""
def run_demo():
demo = gr.Blocks(title='ID-Pose: Sparse-view Camera Pose Estimation By Inverting Diffusion Models')
with demo:
est_result = gr.JSON(visible=False)
gr.Markdown(_HEADER_)
with gr.Row(variant='panel'):
with gr.Column(scale=1):
with gr.Row():
input_image1 = gr.Image(type='pil', image_mode='RGBA', label='Input Image 1')
input_image2 = gr.Image(type='pil', image_mode='RGBA', label='Input Image 2')
with gr.Row():
processed_image1 = gr.Image(type='numpy', image_mode='RGB', label='Processed Image 1', interactive=False)
processed_image2 = gr.Image(type='numpy', image_mode='RGB', label='Processed Image 2', interactive=False)
with gr.Row():
preprocess_chk = gr.Checkbox(True, label='Remove background and recenter object')
with gr.Accordion('Advanced options', open=False):
probe_bsz = gr.Slider(4, 32, value=16, step=4, label='Probe Batch Size')
adj_bsz = gr.Slider(1, 8, value=4, step=1, label='Adjust Batch Size')
adj_iters = gr.Slider(1, 20, value=10, step=1, label='Adjust Iterations')
seed_value = gr.Number(value=0, label="Seed Value", precision=0)
with gr.Row():
run_btn = gr.Button('Estimate', variant='primary', interactive=True)
with gr.Row():
refine_iters = gr.Slider(0, 1000, value=0, step=50, label='Refinement Iterations')
with gr.Row():
refine_btn = gr.Button('Refine', variant='primary', interactive=False)
with gr.Row():
gr.Markdown(_FOOTER_)
with gr.Row():
gr.Markdown(_CITE_)
with gr.Column(scale=1.4):
with gr.Row():
vis_output = gr.Plot(label='Camera Pose Results: anchor (red) and target (blue)')
with gr.Row():
with gr.Column(min_width=200):
gr.Examples(
examples = [
['data/gradio_demo/duck_0.png', 'data/gradio_demo/duck_1.png'],
['data/gradio_demo/chair_0.png', 'data/gradio_demo/chair_1.png'],
['data/gradio_demo/foosball_0.png', 'data/gradio_demo/foosball_1.png'],
['data/gradio_demo/bunny_0.png', 'data/gradio_demo/bunny_1.png'],
['data/gradio_demo/circo_0.png', 'data/gradio_demo/circo_1.png'],
],
inputs=[input_image1, input_image2],
fn=run_example,
outputs=[est_result, processed_image1, processed_image2],
label='Examples (Captured)',
cache_examples='lazy',
examples_per_page=5
)
with gr.Column(min_width=200):
gr.Examples(
examples = [
['data/gradio_demo/arc_0.png', 'data/gradio_demo/arc_1.png'],
['data/gradio_demo/husky_0.png', 'data/gradio_demo/husky_1.png'],
['data/gradio_demo/cybertruck_0.png', 'data/gradio_demo/cybertruck_1.png'],
['data/gradio_demo/plane_0.png', 'data/gradio_demo/plane_1.png'],
['data/gradio_demo/christ_0.png', 'data/gradio_demo/christ_1.png'],
],
inputs=[input_image1, input_image2],
fn=run_example,
outputs=[est_result, processed_image1, processed_image2],
label='Examples (Internet)',
cache_examples='lazy',
examples_per_page=5
)
with gr.Column(min_width=200):
gr.Examples(
examples = [
['data/gradio_demo/status_0.png', 'data/gradio_demo/status_1.png'],
['data/gradio_demo/cat_0.png', 'data/gradio_demo/cat_1.png'],
['data/gradio_demo/ferrari_0.png', 'data/gradio_demo/ferrari_1.png'],
['data/gradio_demo/elon_0.png', 'data/gradio_demo/elon_1.png'],
['data/gradio_demo/ride_horse_0.png', 'data/gradio_demo/ride_horse_1.png'],
],
inputs=[input_image1, input_image2],
fn=run_example,
outputs=[est_result, processed_image1, processed_image2],
label='Examples (Generated)',
cache_examples='lazy',
examples_per_page=5
)
run_btn.click(
fn=run_preprocess,
inputs=[input_image1, input_image2, preprocess_chk, seed_value],
outputs=[processed_image1, processed_image2],
).success(
fn=run_or_visualize,
inputs=[processed_image1, processed_image2, probe_bsz, adj_bsz, adj_iters, seed_value, est_result],
outputs=[est_result, vis_output, refine_btn]
)
refine_btn.click(
fn=run_pose_refinement,
inputs=[processed_image1, processed_image2, est_result, refine_iters, seed_value],
outputs=[est_result, vis_output]
)
input_image1.clear(
fn=lambda: None,
outputs=[est_result]
)
input_image2.clear(
fn=lambda: None,
outputs=[est_result]
)
demo.launch()
if __name__ == '__main__':
run_demo()