# zero123plus-demo-space / gradio_app.py
# Last commit (chaoxu): "normal pred minor bug fix" (5fc36c4)
# File size: 16.1 kB
import os
import copy
import torch
import fire
import gradio as gr
from PIL import Image
from functools import partial
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, ControlNetModel
from share_btn import community_icon_html, loading_icon_html, share_js
import cv2
import time
import numpy as np
from rembg import remove
from segment_anything import sam_model_registry, SamPredictor
import uuid
from datetime import datetime
# Title of the demo; used as the page title and as the top-level Markdown heading.
_TITLE = '''Zero123++: a Single Image to Consistent Multi-view Diffusion Base Model'''
# HTML rendered below the title: arXiv badge, GitHub stars badge, and a link to One-2-3-45++.
_DESCRIPTION = '''
<div>
<a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/2310.15110"><img src="https://img.shields.io/badge/2310.15110-f9f7f7?logo="></a>
<a style="display:inline-block; margin-left: .5em" href='https://github.com/SUDO-AI-3D/zero123plus'><img src='https://img.shields.io/github/stars/SUDO-AI-3D/zero123plus?style=social' /></a>
Check out our single-image-to-3D work <a href="https://sudo-ai-3d.github.io/One2345plus_page/">One-2-3-45++</a>!
</div>
'''
# Index of the CUDA device that hosts the SAM model and both diffusion pipelines.
_GPU_ID = 0
# Compatibility shim: when the imported PIL ``Image`` module has no ``Resampling``
# attribute, alias it to the module itself so ``Image.Resampling.LANCZOS``
# resolves to ``Image.LANCZOS``.
try:
    Image.Resampling
except AttributeError:
    Image.Resampling = Image
def sam_init():
    """Load the SAM ``vit_h`` model on the configured GPU and wrap it in a predictor.

    The checkpoint is read from ``tmp/sam_vit_h_4b8939.pth`` inside the
    directory that contains this file.

    Returns:
        A ``SamPredictor`` built around the loaded model.
    """
    checkpoint_path = os.path.join(os.path.dirname(__file__), "tmp", "sam_vit_h_4b8939.pth")
    build_model = sam_model_registry["vit_h"]
    sam_model = build_model(checkpoint=checkpoint_path)
    sam_model = sam_model.to(device=f"cuda:{_GPU_ID}")
    return SamPredictor(sam_model)
def sam_segment(predictor, input_image, *bbox_coords):
    """Cut out the content of a bounding box with SAM and return it as RGBA.

    Args:
        predictor: Object exposing ``set_image`` and ``predict`` (in this app,
            the predictor returned by ``sam_init``).
        input_image: Image convertible with ``np.asarray``; its pixels are
            copied into the first three channels of the result.
        *bbox_coords: Box prompt values, forwarded to the predictor as one
            array (the caller in this file passes x_min, y_min, x_max, y_max).

    Returns:
        An RGBA ``PIL.Image`` whose alpha channel is the last mask returned by
        the predictor, scaled to 0/255.
    """
    pixels = np.asarray(input_image)
    box_prompt = np.array(bbox_coords)

    tic = time.time()
    predictor.set_image(pixels)
    masks, _scores, _logits = predictor.predict(box=box_prompt, multimask_output=True)
    print(f"SAM Time: {time.time() - tic:.3f}s")

    height, width = pixels.shape[:2]
    rgba = np.zeros((height, width, 4), dtype=np.uint8)
    rgba[..., :3] = pixels
    # The last of the returned masks becomes the alpha channel.
    rgba[..., 3] = masks[-1].astype(np.uint8) * 255

    torch.cuda.empty_cache()
    return Image.fromarray(rgba, mode='RGBA')
def expand2square(pil_img, background_color):
    """Pad an image to a square, keeping it centred along the shorter axis.

    Args:
        pil_img: Image exposing ``size`` and ``mode``.
        background_color: Fill colour for the padding area.

    Returns:
        ``pil_img`` itself when it is already square; otherwise a new image
        whose side equals the longer dimension, with ``pil_img`` pasted in the
        middle of the shorter axis.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    side = max(width, height)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    # Exactly one of the two offsets is non-zero: the one along the shorter axis.
    offset = ((side - width) // 2, (side - height) // 2)
    canvas.paste(pil_img, offset)
    return canvas
def _rescale_and_recenter(rgba_image, max_res, ratio=0.75):
    """Centre the visible object on a square white canvas.

    The bounding box of all pixels with non-zero alpha is cropped, placed in
    the middle of a square canvas sized so that the longer side of the box
    spans ``ratio`` of it, resized, and composited over white.

    Args:
        rgba_image: PIL image; converted to RGBA when it is in another mode.
        max_res: Upper bound for the output side length.
        ratio: Fraction of the canvas side taken by the object's longer side.

    Returns:
        A square RGB ``PIL.Image``, or ``None`` when no pixel has non-zero
        alpha (there is nothing to centre).
    """
    if rgba_image.mode != 'RGBA':
        # The code below indexes four channels and reads alpha from the last band.
        rgba_image = rgba_image.convert('RGBA')
    image_arr = np.array(rgba_image)
    in_h, in_w = image_arr.shape[:2]  # numpy order is (rows, cols)
    out_res = min(max_res, max(in_w, in_h))

    _, mask = cv2.threshold(np.array(rgba_image.split()[-1]), 0, 255, cv2.THRESH_BINARY)
    x, y, w, h = cv2.boundingRect(mask)
    if w == 0 or h == 0:
        # Fully transparent image: a zero-sized canvas cannot be resized.
        return None

    side_len = int(max(w, h) / ratio)
    padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
    center = side_len // 2
    top, left = center - h // 2, center - w // 2
    padded_image[top:top + h, left:left + w] = image_arr[y:y + h, x:x + w]

    rgba = Image.fromarray(padded_image).resize((out_res, out_res), Image.Resampling.LANCZOS)
    rgba_arr = np.array(rgba) / 255.0
    alpha = rgba_arr[..., -1:]
    # Alpha-composite over a white background.
    rgb = rgba_arr[..., :3] * alpha + (1 - alpha)
    return Image.fromarray((rgb * 255).astype(np.uint8))


def preprocess(predictor, input_image, chk_group=None, segment=True, rescale=False):
    """Prepare an image for the multi-view pipeline.

    Steps, in order:
      1. Downscale a copy of the input so it fits in 1024x1024.
      2. If segmentation is enabled, run ``rembg`` to locate the foreground,
         then pass that bounding box to SAM (``sam_segment``) to get an RGBA
         cut-out.
      3. If rescaling is enabled, centre the object on a square white canvas;
         otherwise (or when nothing is visible) pad the image to a square.

    Args:
        predictor: SAM predictor forwarded to ``sam_segment``.
        input_image: PIL image to process. It is not modified.
        chk_group: Optional collection of option labels. When given it
            overrides ``segment`` ("Background Removal") and ``rescale``
            ("Rescale").
        segment: Whether to remove the background.
        rescale: Whether to rescale and recentre the object.

    Returns:
        A tuple ``(processed_image, preview)`` where ``preview`` is the
        processed image resized to 320x320.
    """
    RES = 1024
    # thumbnail() resizes in place; work on a copy so the caller's image is untouched.
    input_image = input_image.copy()
    input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS)

    if chk_group is not None:
        segment = "Background Removal" in chk_group
        rescale = "Rescale" in chk_group

    if segment:
        image_nobg = remove(input_image.convert('RGBA'), alpha_matting=True)
        alpha = np.asarray(image_nobg)[:, :, -1]
        cols = np.nonzero(alpha.sum(axis=0))[0]  # columns containing any foreground
        rows = np.nonzero(alpha.sum(axis=1))[0]  # rows containing any foreground
        if cols.size and rows.size:
            x_min, x_max = int(cols.min()), int(cols.max())
            y_min, y_max = int(rows.min()), int(rows.max())
        else:
            # No foreground found: min()/max() would raise on the empty index
            # arrays, so prompt SAM with the whole frame instead.
            x_min, y_min = 0, 0
            x_max, y_max = input_image.width - 1, input_image.height - 1
        input_image = sam_segment(predictor, input_image.convert('RGB'), x_min, y_min, x_max, y_max)

    # Rescale and recenter
    recentered = _rescale_and_recenter(input_image, RES) if rescale else None
    if recentered is None:
        input_image = expand2square(input_image, (127, 127, 127, 0))
    else:
        input_image = recentered
    return input_image, input_image.resize((320, 320), Image.Resampling.LANCZOS)
def save_image(image, original_image):
    """Upload a generated image and its source image to the remote log endpoint.

    Both images are written as PNG files under ``tmp/`` (relative to the
    current working directory), posted with ``curl`` as the multipart fields
    ``out`` and ``in`` to ``https://3d.skis.ltd/log``, and then deleted. The
    upload is best-effort: its exit status is ignored.

    Args:
        image: Generated image; must provide ``save(path)``.
        original_image: Input image; must provide ``save(path)``.
    """
    # Local import keeps this function self-contained.
    import subprocess

    file_prefix = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + "_" + str(uuid.uuid4())[:4]
    os.makedirs("tmp", exist_ok=True)
    out_path = f"tmp/{file_prefix}_output.png"
    in_path = f"tmp/{file_prefix}_input.png"
    try:
        image.save(out_path)
        original_image.save(in_path)
        try:
            # Argument list instead of a shell string: no shell parsing of the paths.
            subprocess.run(
                ["curl", "-F", f"in=@{in_path}", "-F", f"out=@{out_path}", "https://3d.skis.ltd/log"],
                check=False,
            )
        except OSError as exc:
            # e.g. curl is not installed; logging must not break generation.
            print(f"Log upload skipped: {exc}")
    finally:
        # Always clean up, even when saving or uploading failed part-way.
        for path in (out_path, in_path):
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
def _crop_tiles(image, side_len):
    """Split ``image`` into ``side_len``-square tiles, left to right, top to bottom."""
    return [
        image.crop((x, y, x + side_len, y + side_len))
        for y in range(0, image.height, side_len)
        for x in range(0, image.width, side_len)
    ]


def gen_multiview(pipeline, pipeline_normal, predictor, input_image, scale_slider, steps_slider, seed, output_processing=False, original_image=None, out_normal=True):
    """Generate the multi-view image grid (and optionally normal maps) for an input image.

    Args:
        pipeline: Callable diffusion pipeline with a ``device`` attribute; its
            result exposes ``images``.
        pipeline_normal: Callable pipeline used for normal prediction; it
            receives the generated grid as ``depth_image``.
        predictor: SAM predictor, used only when output background removal is on.
        input_image: Conditioning image passed to both pipelines.
        scale_slider: Guidance scale for ``pipeline``.
        steps_slider: Number of inference steps for ``pipeline``.
        seed: Random seed (converted with ``int``).
        output_processing: Either a bool flag or a collection of option
            labels; background removal is applied when it is ``True`` or
            contains "Background Removal".
        original_image: Unprocessed input used for logging via ``save_image``.
            When ``None`` nothing is logged.
        out_normal: Whether to run ``pipeline_normal``.

    Returns:
        A 13-element list: six view tiles, the full grid image, and six normal
        tiles (empty ``gr.Image`` placeholders when ``out_normal`` is false).
    """
    seed = int(seed)
    torch.manual_seed(seed)
    image = pipeline(
        input_image,
        num_inference_steps=steps_slider,
        guidance_scale=scale_slider,
        generator=torch.Generator(pipeline.device).manual_seed(seed),
    ).images[0]
    # The grid is two tiles wide.
    side_len = image.width // 2
    subimages = _crop_tiles(image, side_len)

    # normal images
    if out_normal:
        image_normal = pipeline_normal(
            input_image, depth_image=image,
            prompt='', guidance_scale=1, num_inference_steps=50, width=640, height=960,
        ).images[0]
        subimages_normal = _crop_tiles(image_normal, side_len)
        out_images_normal = subimages_normal
    else:
        out_images_normal = [gr.Image(None) for _ in range(6)]

    # The default is the bool False, which does not support ``in``; accept both
    # a plain flag and a collection of labels.
    if isinstance(output_processing, bool):
        remove_background = output_processing
    else:
        remove_background = "Background Removal" in (output_processing or ())

    if remove_background:
        out_images = []
        merged_image = Image.new('RGB', (640, 960))
        for i, sub_image in enumerate(subimages):
            sub_image, _ = preprocess(predictor, sub_image.convert('RGB'), segment=True, rescale=False)
            out_images.append(sub_image)
            # Merge into a 2x3 grid
            # NOTE(review): tiles are cropped row by row, but this fills the
            # left column with tiles 0-2 and the right column with tiles 3-5,
            # so the merged layout differs from the pipeline's grid. Confirm
            # whether that is intended.
            x = 0 if i < 3 else 320
            y = (i % 3) * 320
            merged_image.paste(sub_image, (x, y))
        if original_image is not None:
            save_image(merged_image, original_image)
        if out_normal:
            out_images_normal = [
                preprocess(predictor, tile.convert('RGB'), segment=True, rescale=False)[0]
                for tile in subimages_normal
            ]
        return out_images + [merged_image] + out_images_normal

    if original_image is not None:
        save_image(image, original_image)
    return subimages + [image] + out_images_normal
def run_demo():
    """Load the models, build the Gradio interface and launch the server.

    Loads the Zero123++ v1.2 pipeline, a ControlNet-augmented copy of it for
    normal prediction, and the SAM predictor, all on ``cuda:{_GPU_ID}``. The
    UI wires the "Generate" button to ``preprocess`` followed by
    ``gen_multiview``. The app is served on ``0.0.0.0:7860``.

    The Hugging Face token is read from the ``HF_TOKEN`` environment variable
    when present.
    """
    # A missing variable yields None instead of a bare KeyError at start-up
    # (e.g. in a freshly duplicated Space without secrets).
    hf_token = os.environ.get("HF_TOKEN")

    # Load the pipeline
    pipeline = DiffusionPipeline.from_pretrained(
        "sudo-ai/zero123plus-v1.2", custom_pipeline="sudo-ai/zero123plus-pipeline",
        torch_dtype=torch.float16, use_auth_token=hf_token
    )
    # Feel free to tune the scheduler
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipeline.scheduler.config, timestep_spacing='trailing'
    )
    pipeline.to(f'cuda:{_GPU_ID}')

    # copy.copy is a shallow copy: the normal pipeline is a second pipeline
    # object that then gets the ControlNet attached.
    normal_pipeline = copy.copy(pipeline)
    controlnet = ControlNetModel.from_pretrained(
        "sudo-ai/controlnet-zp12-normal-gen-v1",
        torch_dtype=torch.float16, use_auth_token=hf_token
    )
    normal_pipeline.add_controlnet(controlnet, conditioning_scale=1.0)
    normal_pipeline.to(f'cuda:{_GPU_ID}')

    predictor = sam_init()

    custom_theme = gr.themes.Soft(primary_hue="blue").set(
        button_secondary_background_fill="*neutral_100",
        button_secondary_background_fill_hover="*neutral_200")

    def new_tile():
        """Create one read-only output tile."""
        return gr.Image(interactive=False, height=240, show_label=False)

    def show_share_btn():
        return gr.Group(visible=True)

    def hide_share_btn():
        return gr.Group(visible=False)

    with gr.Blocks(title=_TITLE, theme=custom_theme, css="style.css") as demo:
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown('# ' + _TITLE)
            with gr.Column(scale=0):
                gr.DuplicateButton(value='Duplicate Space for private use',
                                   elem_id='duplicate-button')
        gr.Markdown(_DESCRIPTION)

        with gr.Row(variant='panel'):
            # Left column: input, examples and options.
            with gr.Column(scale=1):
                input_image = gr.Image(type='pil', image_mode='RGBA', height=320, label='Input image', elem_id="input_image")

                example_folder = os.path.join(os.path.dirname(__file__), "./resources/examples")
                # Sorted so the example gallery order does not depend on the file system.
                example_fns = [os.path.join(example_folder, example) for example in sorted(os.listdir(example_folder))]
                gr.Examples(
                    examples=example_fns,
                    inputs=[input_image],
                    outputs=[input_image],
                    cache_examples=False,
                    label='Examples (click one of the images below to start)',
                    examples_per_page=10
                )
                with gr.Row():
                    out_normal = gr.Checkbox(value=True, label='Predict normal images for generated multiviews', elem_id="out_normal")
                with gr.Accordion('Advanced options', open=False):
                    with gr.Row():
                        with gr.Column():
                            input_processing = gr.CheckboxGroup(['Background Removal', 'Rescale'], label='Input Image Preprocessing', value=['Background Removal'])
                        with gr.Column():
                            output_processing = gr.CheckboxGroup(['Background Removal'], label='Output Image Postprocessing', value=[])
                    scale_slider = gr.Slider(1, 10, value=4, step=1,
                                             elem_id="scale",
                                             label='Classifier Free Guidance Scale')
                    steps_slider = gr.Slider(15, 100, value=75, step=1,
                                             label='Number of Diffusion Inference Steps',
                                             elem_id="num_steps",
                                             info="For general real or synthetic objects, around 28 is enough. For objects with delicate details such as faces (either realistic or illustration), you may need 75 or more steps.")
                    seed = gr.Number(42, label='Seed', elem_id="seed")
                run_btn = gr.Button('Generate', variant='primary', interactive=True)

            # Right column: processed input, six views, six normal maps.
            with gr.Column(scale=1):
                processed_image = gr.Image(type='pil', label="Processed Image", interactive=False, height=320, image_mode='RGBA', elem_id="disp_image")
                processed_image_highres = gr.Image(type='pil', image_mode='RGBA', visible=False)
                view_tiles = []
                for _ in range(2):
                    with gr.Row():
                        view_tiles.extend(new_tile() for _ in range(3))
                normal_tiles = []
                for _ in range(2):
                    with gr.Row():
                        normal_tiles.extend(new_tile() for _ in range(3))
                full_view = gr.Image(visible=False, interactive=False, elem_id="six_view")
                with gr.Group(elem_id="share-btn-container", visible=False) as share_group:
                    gr.HTML(community_icon_html)
                    gr.HTML(loading_icon_html)
                    share_button = gr.Button("Share to community", elem_id="share-btn")

        # The share button is hidden whenever the input changes or a new run
        # starts, and shown again only after a run completes successfully.
        input_image.change(hide_share_btn, outputs=share_group, queue=False)
        run_btn.click(hide_share_btn, outputs=share_group, queue=False
        ).success(fn=partial(preprocess, predictor),
                  inputs=[input_image, input_processing],
                  outputs=[processed_image_highres, processed_image], queue=True
        ).success(fn=partial(gen_multiview, pipeline, normal_pipeline, predictor),
                  inputs=[processed_image_highres, scale_slider, steps_slider, seed, output_processing, input_image, out_normal],
                  outputs=view_tiles + [full_view] + normal_tiles, queue=True
        ).success(show_share_btn, outputs=share_group, queue=False)
        share_button.click(None, [], [], _js=share_js)

    demo.queue().launch(share=False, max_threads=80, server_name="0.0.0.0", server_port=7860)
# Script entry point: hand run_demo to python-fire when the file is executed directly.
if __name__ == '__main__':
    fire.Fire(run_demo)