SyncDreamer / app.py
liuyuan-pal's picture
update
df62e57
raw
history blame
7.42 kB
from functools import partial
from PIL import Image
import numpy as np
import gradio as gr
import torch
import os
import fire
from ldm.util import add_margin
_TITLE = '''SyncDreamer: Generating Multiview-consistent Images from a Single-view Image'''
_DESCRIPTION = '''
<div>
<a style="display:inline-block" href="https://liuyuan-pal.github.io/SyncDreamer/"><img src="https://img.shields.io/badge/SyncDremer-Homepage-blue"></a>
<a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/2309.03453"><img src="https://img.shields.io/badge/2309.03453-f9f7f7?logo="></a>
<a style="display:inline-block; margin-left: .5em" href='https://github.com/liuyuan-pal/SyncDreamer'><img src='https://img.shields.io/github/stars/liuyuan-pal/SyncDreamer?style=social' /></a>
</div>
Given a single-view image, SyncDreamer is able to generate multiview-consistent images, which enables direct 3D reconstruction with NeuS or NeRF without SDS loss'''
_USER_GUIDE0 = "Step0: Please upload an image in the block above (or choose an example above). We use alpha values as object masks if given."
_USER_GUIDE1 = "Step1: Please select a crop size using the glider."
_USER_GUIDE2 = "Step2: Please choose a suitable elevation angle and then click the Generate button."
def mask_prediction(mask_predictor, image_in: "Image.Image"):
    """Produce the image shown in the "SAM output" panel.

    Only images that already carry an alpha channel are supported: their alpha
    values serve directly as the object mask, so the image is returned as-is.

    Args:
        mask_predictor: Placeholder for a mask predictor (e.g. SAM); currently
            unused (the demo passes ``None``).
        image_in: The uploaded image, or ``None`` when the input has been cleared.

    Returns:
        ``image_in`` unchanged when it is RGBA, or ``None`` when no image is given.

    Raises:
        NotImplementedError: If the image is not RGBA (automatic mask
            prediction is not implemented yet).
    """
    # The input's change event can deliver None (e.g. after the upload is
    # cleared); clear the output instead of crashing on ``None.mode``.
    if image_in is None:
        return None
    if image_in.mode != 'RGBA':
        raise NotImplementedError(
            f"Mask prediction is only implemented for RGBA images "
            f"(alpha is used as the object mask); got mode {image_in.mode!r}."
        )
    return image_in
def resize_inputs(image_input, crop_size):
    """Crop the object out of an RGBA image, rescale it and add a margin.

    The object's bounding box is the extent of the non-zero pixels of the alpha
    channel. The cropped object is resized with bicubic resampling so that its
    longer side is ``crop_size`` pixels (aspect ratio preserved), then passed to
    ``add_margin(..., size=256)`` (presumably pads it onto a 256x256 canvas —
    see ``ldm.util.add_margin``).

    Args:
        image_input: RGBA PIL image (the "SAM output"), or ``None`` if no image
            has been processed yet.
        crop_size: Target length, in pixels, of the object's longer side.

    Returns:
        The image returned by ``add_margin``, or ``None`` if ``image_input`` is
        ``None``.
    """
    # The crop-size slider can be moved before any image has been uploaded.
    if image_input is None:
        return None

    alpha_np = np.asarray(image_input)[:, :, 3]
    ys, xs = np.nonzero(alpha_np)
    if xs.size == 0:
        # Fully transparent image: there is no object to crop, so keep the
        # whole frame instead of failing on a min/max over an empty array.
        ref_img_ = image_input
    else:
        # PIL's crop box is (left, upper, right, lower) with right/lower
        # exclusive, so +1 keeps the last opaque column and row.
        box = (int(xs.min()), int(ys.min()), int(xs.max()) + 1, int(ys.max()) + 1)
        ref_img_ = image_input.crop(box)

    h, w = ref_img_.height, ref_img_.width
    scale = crop_size / max(h, w)
    # Clamp to 1 so very thin objects never yield a zero-sized target, which
    # PIL's resize rejects.
    h_, w_ = max(1, int(scale * h)), max(1, int(scale * w))
    ref_img_ = ref_img_.resize((w_, h_), resample=Image.BICUBIC)
    return add_margin(ref_img_, size=256)
def run_demo():
    """Build the SyncDreamer Gradio interface and launch it.

    The page has three columns:

    1. the input image, a step-by-step guide and clickable examples;
    2. the "SAM output" image and a crop-size slider;
    3. the image fed to SyncDreamer, an elevation slider and a run button.

    Changing the input image runs ``mask_prediction`` into the "SAM output"
    panel; moving the crop-size slider runs ``resize_inputs`` into the
    "Input to SyncDreamer" panel. After each step succeeds, the guide text
    advances to the next instruction.

    NOTE(review): the SyncDreamer model and the mask predictor are not loaded
    (``mask_predictor`` is ``None``), and the run button is disabled and has no
    handler, so generation itself is not available from this demo yet.
    """
    # TODO: load the SyncDreamer model, e.g.
    #   device = "cuda:0" if torch.cuda.is_available() else "cpu"
    #   models = init_model(device, os.path.join(code_dir, ckpt))
    # TODO: initialise SAM, e.g. sam_init(device_idx). mask_prediction does not
    # use the predictor yet, so None is passed through.
    mask_predictor = None

    # Resolve demo assets relative to this file rather than the current working
    # directory, so both the examples and the stylesheet are found no matter
    # where the script is launched from.
    demo_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'hf_demo')
    example_folder = os.path.join(demo_dir, 'examples')
    # NOTE: Examples must match inputs (each example is one image path for image_block).
    examples_full = [
        os.path.join(example_folder, fn)
        for fn in sorted(os.listdir(example_folder))
        if fn.endswith('.png')
    ]

    # Compose demo layout & data flow.
    with gr.Blocks(title=_TITLE, css=os.path.join(demo_dir, 'style.css')) as demo:
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown('# ' + _TITLE)
        gr.Markdown(_DESCRIPTION)

        with gr.Row(variant='panel'):
            with gr.Column(scale=1):
                image_block = gr.Image(type='pil', image_mode='RGBA', height=256, label='Input image', tool=None, interactive=True)
                guide_text = gr.Markdown(_USER_GUIDE0, visible=True)
                gr.Examples(
                    examples=examples_full,  # NOTE: elements must match inputs list!
                    inputs=[image_block],
                    outputs=[image_block],
                    cache_examples=False,
                    label='Examples (click one of the images below to start)',
                    examples_per_page=40,
                )
            with gr.Column(scale=1):
                sam_block = gr.Image(type='pil', image_mode='RGBA', label="SAM output", height=256, interactive=False)
                crop_size_slider = gr.Slider(120, 240, 200, step=10, label='Crop size', interactive=True)
            with gr.Column(scale=1):
                input_block = gr.Image(type='pil', image_mode='RGB', label="Input to SyncDreamer", height=256, interactive=False)
                elevation_slider = gr.Slider(-10, 40, 30, step=5, label='Elevation angle', interactive=True)
                # Disabled and not connected to any handler yet: the generation
                # backend is stubbed out above.
                run_btn = gr.Button('Run Generation', variant='primary', interactive=False)

        def update_guide(text):
            """Return a Gradio update that replaces the guide Markdown with ``text``."""
            return gr.update(value=text)

        # Step 0 -> 1: show the masked image, then ask for a crop size.
        image_block.change(
            fn=partial(mask_prediction, mask_predictor),
            inputs=[image_block], outputs=[sam_block], queue=False,
        ).success(fn=partial(update_guide, _USER_GUIDE1), outputs=[guide_text], queue=False)
        # Step 1 -> 2: crop/rescale the masked image, then ask for an elevation.
        crop_size_slider.change(
            fn=resize_inputs,
            inputs=[sam_block, crop_size_slider], outputs=[input_block], queue=False,
        ).success(fn=partial(update_guide, _USER_GUIDE2), outputs=[guide_text], queue=False)

    # Pass auth=("admin", os.environ['PASSWD']) to launch() to password-protect the demo.
    demo.queue().launch(share=False, max_threads=80)
if __name__ == "__main__":
    # python-fire exposes run_demo as the script's command-line entry point.
    fire.Fire(run_demo)