|
import torch |
|
import argparse |
|
import gradio as gr |
|
from functools import partial |
|
from my.config import BaseConf, dispatch_gradio |
|
from run_3DFuse import SJC_3DFuse |
|
import numpy as np |
|
from PIL import Image |
|
from pc_project import point_e |
|
from diffusers import UnCLIPPipeline, DiffusionPipeline |
|
from pc_project import point_e_gradio |
|
import numpy as np |
|
import plotly.graph_objs as go |
|
from my.utils.seed import seed_everything |
|
import os |
|
import wget |
|
|
|
# Markdown/HTML banner shown in the demo: warns that the shared Space is slow
# and links to the "Duplicate Space" action and to the project page.
# Plain (non-f) string on purpose: nothing is interpolated into it.
SHARED_UI_WARNING = '''### [NOTE] Training may be very slow in this shared UI.
You can duplicate and use it with a paid private GPU.
<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/jyseo/3DFuse?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a>
Alternatively, you can also use the Colab demo on our project page.
<a style="display:inline-block" href="https://ku-cvlab.github.io/3DFuse/"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/Project%20Page-online-brightgreen"></a>
'''
|
|
|
class Intermediate:
    """Mutable holder for results handed from point-cloud generation to 3D generation.

    The point-cloud callbacks store their outputs here and ``gen_3d`` reads
    them back, so the two UI steps can run as separate button clicks.
    """

    def __init__(self):
        # Initial image(s) the point cloud was inferred from; None until generated.
        self.images = None
        # Point cloud returned by ``point_e_gradio``; None until generated.
        self.points = None
        # True while ``gen_3d`` is running.
        self.is_generating = False
|
|
|
|
|
def gen_3d(model, intermediate, prompt, keyword, seed, ti_step, pt_step):
    """Run 3D generation from the previously generated point cloud, streaming progress.

    This is a generator: it re-yields everything produced by
    ``SJC_3DFuse.run_gradio`` (presumably tuples matching the click handler's
    outputs: intermediate image, log text, result video -- confirm against
    ``run_gradio``).

    Args:
        model: Ignored; a fresh model is built on every call. Kept so the
            existing ``partial(gen_3d, model, intermediate)`` bindings work.
        intermediate: ``Intermediate`` holding the images and point cloud
            produced by one of the ``gen_pc_*`` callbacks.
        prompt: Text prompt describing the object.
        keyword: Keyword taken from the prompt, forwarded to the model config.
        seed: Random seed, used for ``seed_everything`` and the model config.
        ti_step: Step count forwarded to ``dispatch_gradio``.
        pt_step: Step count forwarded to ``dispatch_gradio``.

    Raises:
        gr.Error: If no point cloud / initial image has been generated yet.
    """
    images, points = intermediate.images, intermediate.points
    # Validate before touching any state so a rejected request cannot leave
    # ``is_generating`` stuck at True.
    if images is None or points is None:
        raise gr.Error("Please generate point cloud first")

    # Drop the local reference to any caller-supplied model before building a
    # new one.
    del model

    intermediate.is_generating = True
    try:
        seed_everything(seed)
        model = dispatch_gradio(SJC_3DFuse, prompt, keyword, ti_step, pt_step, seed)
        yield from model.run_gradio(points, images)
    finally:
        # Also runs when generation raises or the consumer closes the
        # generator early, so the flag always ends up cleared.
        intermediate.is_generating = False
|
|
|
|
|
|
|
def gen_pc_from_prompt(intermediate, num_initial_image, prompt, keyword, type, bg_preprocess, seed):
    """Generate initial image(s) from a text prompt and infer a point cloud from the first one.

    Args:
        intermediate: ``Intermediate`` that receives the generated images and
            point cloud for the later ``gen_3d`` step.
        num_initial_image: Number of initial images to request from ``gen_init``.
        prompt: Text prompt; must contain ``keyword``.
        keyword: Single word contained in ``prompt``.
        type: Backbone choice string forwarded to ``gen_init``. (Shadows the
            builtin, but the name is part of the existing interface.)
        bg_preprocess: Whether ``gen_init`` should whiten the image background.
        seed: Random seed.

    Returns:
        ``(first_image, plotly_figure, gr.update(interactive=True))`` -- the
        last item enables the "Generate 3D" button wired as the third output.

    Raises:
        gr.Error: If the keyword is empty, contains a space, or is not part of
            the prompt.
    """
    seed_everything(seed=seed)

    if keyword not in prompt:
        raise gr.Error("Prompt should contain keyword!")
    # An empty string is "in" every prompt, so it has to be rejected explicitly.
    if not keyword or " " in keyword:
        raise gr.Error("Keyword should be one word!")

    images = gen_init(
        num_initial_image=num_initial_image,
        prompt=prompt,
        seed=seed,
        type=type,
        bg_preprocess=bg_preprocess,
    )
    # Only the first generated image drives the point-cloud inference.
    points = point_e_gradio(images[0], 'cuda')

    intermediate.images = images
    intermediate.points = points

    coords = np.array(points.coords)
    trace = go.Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode='markers',
        marker=dict(size=2),
    )

    # All three axes share the same "draw nothing" styling.
    hidden_axis = dict(
        title="",
        showgrid=False,
        zeroline=False,
        showline=False,
        ticks='',
        showticklabels=False,
    )
    layout = go.Layout(
        scene=dict(
            xaxis=dict(hidden_axis),
            yaxis=dict(hidden_axis),
            zaxis=dict(hidden_axis),
        ),
        margin=dict(l=0, r=0, b=0, t=0),
        showlegend=False,
    )

    fig = go.Figure(data=[trace], layout=layout)

    return images[0], fig, gr.update(interactive=True)
|
|
|
|
|
def gen_pc_from_image(intermediate, image, prompt, keyword, bg_preprocess, seed):
    """Infer a point cloud from a user-supplied image.

    Args:
        intermediate: ``Intermediate`` that receives the image and point cloud
            for the later ``gen_3d`` step.
        image: Uploaded PIL image, or ``None`` when nothing was uploaded.
        prompt: Text prompt; must contain ``keyword``.
        keyword: Single word contained in ``prompt``.
        bg_preprocess: When true, run carvekit and paint the masked-out pixels
            white before point-cloud inference.
        seed: Random seed.

    Returns:
        ``(image, plotly_figure, gr.update(interactive=True))`` -- the last
        item enables the "Generate 3D" button wired as the third output.

    Raises:
        gr.Error: If no image was uploaded, or the keyword is empty, contains
            a space, or is not part of the prompt.
    """
    seed_everything(seed=seed)

    if keyword not in prompt:
        raise gr.Error("Prompt should contain keyword!")
    # An empty string is "in" every prompt, so it has to be rejected explicitly.
    if not keyword or " " in keyword:
        raise gr.Error("Keyword should be one word!")
    # Without this guard a missing upload fails later with an opaque error
    # inside carvekit / point-e.
    if image is None:
        raise gr.Error("Please upload an image first")

    if bg_preprocess:
        # NOTE(review): cv2 is not referenced in this function; the import is
        # retained in case carvekit relies on it being loaded -- confirm, then drop.
        import cv2  # noqa: F401
        from carvekit.api.high import HiInterface

        interface = HiInterface(
            object_type="object",
            batch_size_seg=5,
            batch_size_matting=1,
            device='cuda' if torch.cuda.is_available() else 'cpu',
            seg_mask_size=640,
            matting_mask_size=2048,
            trimap_prob_threshold=231,
            trimap_dilation=30,
            trimap_erosion_iters=5,
            fp16=False,
        )

        img_without_background = interface([image])
        # NOTE(review): assumes the carvekit output converts to an array whose
        # ``> 127`` mask can index the image's pixels -- confirm channel layout.
        mask = np.array(img_without_background[0]) > 127
        pixels = np.array(image)
        pixels[~mask] = [255., 255., 255.]
        image = Image.fromarray(pixels)

    points = point_e_gradio(image, 'cuda')

    intermediate.images = [image]
    intermediate.points = points

    coords = np.array(points.coords)
    trace = go.Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode='markers',
        marker=dict(size=2),
    )

    # All three axes share the same "draw nothing" styling.
    hidden_axis = dict(
        title="",
        showgrid=False,
        zeroline=False,
        showline=False,
        ticks='',
        showticklabels=False,
    )
    layout = go.Layout(
        scene=dict(
            xaxis=dict(hidden_axis),
            yaxis=dict(hidden_axis),
            zaxis=dict(hidden_axis),
        ),
        margin=dict(l=0, r=0, b=0, t=0),
        showlegend=False,
    )

    fig = go.Figure(data=[trace], layout=layout)

    return image, fig, gr.update(interactive=True)
|
|
|
def gen_init(num_initial_image, prompt, seed, type="Karlo", bg_preprocess=False):
    """Generate the initial view image(s) for a prompt with a 2D diffusion pipeline.

    Args:
        num_initial_image: Number of images to generate. Image ``i`` is
            prompted with the ``i % 4``-th of the front / overhead / side /
            back view prefixes.
        prompt: Text prompt describing the object.
        seed: Seed for the CUDA ``torch.Generator`` passed to the pipeline.
        type: Backbone choice. ``"Karlo"`` or the UI label
            ``"Karlo (Recommended)"`` selects Karlo (fp16); any other value
            selects Stable Diffusion v1.5. (Shadows the builtin, but the name
            is part of the existing interface.)
        bg_preprocess: When true, run carvekit on every image and paint the
            masked-out pixels white.

    Returns:
        List with one image per requested view (PIL images when
        ``bg_preprocess`` is set; otherwise whatever the pipeline returns).
    """
    # Accept both the short name (the default) and the radio-button label so
    # that relying on the default really selects Karlo.
    karlo_choices = ("Karlo", "Karlo (Recommended)")
    if type in karlo_choices:
        pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
    else:
        pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    pipe = pipe.to('cuda')

    view_prompt = ["front view of ", "overhead view of ", "side view of ", "back view of "]
    # Appended to the first prompt only, regardless of ``bg_preprocess``.
    first_view_suffix = ", white background"

    if bg_preprocess:
        # NOTE(review): cv2 is not referenced in this function; the import is
        # retained in case carvekit relies on it being loaded -- confirm, then drop.
        import cv2  # noqa: F401
        from carvekit.api.high import HiInterface

        interface = HiInterface(
            object_type="object",
            batch_size_seg=5,
            batch_size_matting=1,
            device='cuda' if torch.cuda.is_available() else 'cpu',
            seg_mask_size=640,
            matting_mask_size=2048,
            trimap_prob_threshold=231,
            trimap_dilation=30,
            trimap_erosion_iters=5,
            fp16=False,
        )

    images = []
    generator = torch.Generator(device='cuda').manual_seed(seed)
    for i in range(num_initial_image):
        view = view_prompt[i % 4]
        if i == 0:
            prompt_ = f"{view}{prompt}{first_view_suffix}"
        else:
            prompt_ = f"{view}{prompt}"

        image = pipe(prompt_, generator=generator).images[0]

        if bg_preprocess:
            img_without_background = interface([image])
            # NOTE(review): assumes the carvekit output converts to an array
            # whose ``> 127`` mask can index the image's pixels -- confirm
            # channel layout.
            mask = np.array(img_without_background[0]) > 127
            pixels = np.array(image)
            pixels[~mask] = [255., 255., 255.]
            image = Image.fromarray(pixels)

        images.append(image)

    return images
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse CLI flags, make sure the depth-injector
    # checkpoint is present locally, build the Gradio UI and launch it.

    parser = argparse.ArgumentParser()
    parser.add_argument('--share', action='store_true', help="public url")
    args = parser.parse_args()

    weights_dir = './weights'
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(weights_dir, exist_ok=True)
    weights_path = os.path.join(weights_dir, '3DFuse_sparse_depth_injector.ckpt')

    if not os.path.isfile(weights_path):
        url = 'https://huggingface.co/jyseo/3DFuse_weights/resolve/main/models/3DFuse_sparse_depth_injector.ckpt'
        wget.download(url, weights_path)
        print(f'{weights_path} downloaded.')
    else:
        print(f'{weights_path} already exists.')

    # ``gen_3d`` builds its own model on every call, so this stays None; it is
    # only passed along to satisfy gen_3d's signature.
    model = None
    # NOTE(review): this single instance is bound into every callback below,
    # so the generated images / point cloud are process-wide rather than
    # per-browser-session. Consider gr.State if per-user isolation is needed.
    intermediate = Intermediate()
    demo = gr.Blocks(title="3DFuse Interactive Demo")

    with demo:
        with gr.Box():
            gr.Markdown(SHARED_UI_WARNING)

        gr.Markdown("# 3DFuse Interactive Demo")
        gr.Markdown("### Official Implementation of the paper \"Let 2D Diffusion Model Know 3D-Consistency for Robust Text-to-3D Generation\"")
        gr.Markdown("Enter your own prompt and enjoy! With this demo, you can preview the point cloud before 3D generation and determine the desired shape.")

        with gr.Row():
            # Left panel: inputs plus the point-cloud / initial-image preview.
            with gr.Column(scale=1., variant='panel'):

                with gr.Tab("Text to 3D"):
                    prompt_input = gr.Textbox(label="Prompt", value="a comfortable bed", interactive=True)
                    word_input = gr.Textbox(label="Keyword for initialization (should be a noun included in the prompt)", value="bed", interactive=True)
                    semantic_model_choice = gr.Radio(["Karlo (Recommended)", "Stable Diffusion"], label="Backbone for initial image generation", value="Karlo (Recommended)", interactive=True)
                    seed = gr.Slider(label="Seed", minimum=0, maximum=2100000000, step=1, randomize=True)
                    preprocess_choice = gr.Checkbox(True, label="Preprocess initially-generated image by removing background", interactive=True)
                    with gr.Accordion("Advanced Options", open=False):
                        opt_step = gr.Slider(0, 1000, value=500, step=1, label='Number of text embedding optimization step')
                        pivot_step = gr.Slider(0, 1000, value=500, step=1, label='Number of pivotal tuning step for LoRA')
                    with gr.Row():
                        button_gen_pc = gr.Button("1. Generate Point Cloud", interactive=True, variant='secondary')
                        # Disabled until a point cloud exists; the gen_pc
                        # callback re-enables it through its third output.
                        button_gen_3d = gr.Button("2. Generate 3D", interactive=False, variant='primary')

                with gr.Tab("Image to 3D"):
                    image_input = gr.Image(source='upload', type="pil", interactive=True)
                    prompt_input_2 = gr.Textbox(label="Prompt", value="a dog", interactive=True)
                    word_input_2 = gr.Textbox(label="Keyword for initialization (should be a noun included in the prompt)", value="dog", interactive=True)
                    seed_2 = gr.Slider(label="Seed", minimum=0, maximum=2100000000, step=1, randomize=True)
                    preprocess_choice_2 = gr.Checkbox(True, label="Preprocess initially-generated image by removing background", interactive=False)
                    with gr.Accordion("Advanced Options", open=False):
                        opt_step_2 = gr.Slider(0, 1000, value=500, step=1, label='Number of text embedding optimization step')
                        pivot_step_2 = gr.Slider(0, 1000, value=500, step=1, label='Number of pivotal tuning step for LoRA')
                    with gr.Row():
                        button_gen_pc_2 = gr.Button("1. Generate Point Cloud", interactive=True, variant='secondary')
                        button_gen_3d_2 = gr.Button("2. Generate 3D", interactive=False, variant='primary')
                    gr.Markdown("Note: A photo showing the entire object in a front view is recommended. Also, our framework is not designed with the goal of single shot reconstruction, so it may be difficult to reflect the details of the given image.")

                with gr.Row(scale=1.):
                    with gr.Column(scale=1.):
                        pc_plot = gr.Plot(label="Inferred point cloud")
                    with gr.Column(scale=1.):
                        init_output = gr.Image(label='Generated initial image', interactive=False)

            # Right panel: streamed progress and the final video.
            with gr.Column(scale=1., variant='panel'):
                with gr.Row():
                    with gr.Column(scale=1.):
                        intermediate_output = gr.Image(label="Intermediate Output", interactive=False)
                    with gr.Column(scale=1.):
                        logs = gr.Textbox(label="logs", lines=8, max_lines=8, interactive=False)
                with gr.Row():
                    video_result = gr.Video(label="Video result for generated 3D", interactive=False)

        gr.Markdown("Note: Keyword is used for Textual Inversion. Please choose one important noun included in the prompt. This demo may be slower than directly running run_3DFuse.py.")

        # Event wiring. ``partial`` pre-binds the shared state (and, for the
        # text tab, the number of initial images) ahead of the widget values.
        button_gen_pc.click(
            fn=partial(gen_pc_from_prompt, intermediate, 4),
            inputs=[prompt_input, word_input, semantic_model_choice, preprocess_choice, seed],
            outputs=[init_output, pc_plot, button_gen_3d],
        )
        button_gen_3d.click(
            fn=partial(gen_3d, model, intermediate),
            inputs=[prompt_input, word_input, seed, opt_step, pivot_step],
            outputs=[intermediate_output, logs, video_result],
        )

        button_gen_pc_2.click(
            fn=partial(gen_pc_from_image, intermediate),
            inputs=[image_input, prompt_input_2, word_input_2, preprocess_choice_2, seed_2],
            outputs=[init_output, pc_plot, button_gen_3d_2],
        )
        button_gen_3d_2.click(
            fn=partial(gen_3d, model, intermediate),
            inputs=[prompt_input_2, word_input_2, seed_2, opt_step_2, pivot_step_2],
            outputs=[intermediate_output, logs, video_result],
        )

    # One job at a time through the queue.
    demo.queue(concurrency_count=1)
    demo.launch(share=args.share)
|
|
|
|
|
|
|
|
|
|
|
|