Rodin / app.py
skkk's picture
fix a bug
5ff67bd
import os
# Reinstall the gradio_fake3d wheel before `Fake3D` is imported below.
# The wheel path is relative, so it is resolved against the current working
# directory. The exit status of both commands is ignored.
os.system('pip uninstall -y gradio_fake3d')
os.system('pip install gradio_fake3d-0.0.3-py3-none-any.whl')
import gradio as gr
import re
from gradio_fake3d import Fake3D
from PIL import Image
from Rodin import Generator, crop_image, log
from constant import *
# Single Generator instance whose `preprocess` and `generate_mesh` methods
# are wired to the UI events below. USER, PASSWORD and TOKEN are not defined
# in this file; presumably they come from `from constant import *`.
generator = Generator(USER, PASSWORD, TOKEN)
# JS snippet: relabel the element with id 'button_generate' to 'Redo'
# and return an empty string.
change_button_name = """
function updateButton(input) {
var buttonGenerate = document.getElementById('button_generate');
buttonGenerate.innerText = 'Redo';
return '';
}
"""
# JS snippet: relabel the element with id 'button_generate' to
# 'Generating...' and return an empty string.
change_button_name_to_generating = """
function updateButton(input) {
var buttonGenerate = document.getElementById('button_generate');
buttonGenerate.innerText = 'Generating...';
return '';
}
"""
# JS snippet: restore the 'button_generate' label to 'Generate'
# and return an empty string.
reset_button_name = """
function updateButton(input) {
var buttonGenerate = document.getElementById('button_generate');
buttonGenerate.innerText = 'Generate';
return '';
}
"""
# JS snippet: when the incoming value contains 'OpenClay', open the OpenCLAY
# GitHub page in a new tab. It always returns "Rodin Gen-1(0525)", whatever
# the input was.
jump_to_rodin = """
function redirectToGithub(input) {
if (input.includes('OpenClay')) {
window.open("https://github.com/CLAY-3D/OpenCLAY", "_blank");
}
return "Rodin Gen-1(0525)";
}
"""
# Centered page header: a title plus links to the Rodin site and the
# OpenCLAY GitHub repository. Rendered through gr.HTML in the layout below.
html_content = """
<div style="text-align: center;">
<h1>Rodin Gen-1</h1>
<div style="display: flex; justify-content: space-around;">
<p><strong>Rodin Gen-1:</strong> <a href="https://hyperhuman.top/rodin" target="_blank">https://hyperhuman.top/rodin</a></p>
<p><strong>Github:</strong> <a href="https://github.com/CLAY-3D/OpenCLAY" target="_blank">https://github.com/CLAY-3D/OpenCLAY</a></p>
</div>
</div>
"""
# Choices offered by the "Model Card" dropdown. The first entry is also the
# dropdown's initial value and the value handle_selection always returns.
options = [
"Rodin Gen-1(0525)",
"OpenClay(600M) - Coming soon",
"OpenClay(200M) - Coming soon"
]
# Example rows for gr.Examples: one image path per row, matching its single
# input component (block_image). Note the upper-case extension on 24.PNG.
example = [
["assets/00.png"],
["assets/08.png"],
["assets/13.png"],
["assets/24.PNG"],
["assets/30.png"],
["assets/42.png"],
["assets/46.png"]
]
def do_nothing(text):
    """Discard ``text`` and return an empty string.

    Used as the Python-side callback for events in this file that also
    attach a ``js`` snippet to relabel the generate button.
    """
    return str()
def handle_selection(selection):
    """Return the fixed model-card value, ignoring ``selection``.

    The result is always the string "Rodin Gen-1(0525)".
    """
    pinned_model_card = "Rodin Gen-1(0525)"
    return pinned_model_card
def hint_in_prompt(hint, prompt):
    """Report whether ``hint``, minus its last character, occurs in ``prompt``.

    The final character of ``hint`` is dropped before matching (presumably a
    trailing period -- TODO confirm against PROMPT_HINT_LIST). This lets a
    hint be found even when the prompt has different punctuation after it.

    The remaining text is matched literally. It used to be interpolated into
    a regular expression unescaped, so hints containing characters such as
    ``(``, ``+`` or ``.`` could fail to match or raise ``re.error``.

    Args:
        hint: Hint text, including its final character.
        prompt: Prompt text to search.

    Returns:
        True when the trimmed hint is a substring of ``prompt``.
    """
    return hint[:-1] in prompt
def prompt_remove_hint(prompt, hint):
    """Remove every occurrence of ``hint`` from ``prompt``.

    The last character of ``hint`` is dropped before matching (presumably a
    trailing period -- TODO confirm against PROMPT_HINT_LIST). Each match also
    consumes the whitespace in front of it and any run of ``.`` or ``,``
    directly after it.

    Args:
        prompt: Prompt text to clean.
        hint: Hint text, including its final character.

    Returns:
        ``prompt`` with all matches removed. It is returned unchanged when
        the trimmed hint is empty.
    """
    stem = hint[:-1]
    if not stem:
        # An empty stem would reduce the pattern to r"\s*[.,]*", which strips
        # all whitespace and punctuation from the prompt.
        return prompt
    # Escape the stem so regex metacharacters in a hint are matched literally
    # and cannot raise re.error.
    pattern = rf"\s*{re.escape(stem)}[.,]*"
    return re.sub(pattern, "", prompt)
def handle_hint_change(prompt: str, prompt_hint):
    """Rewrite ``prompt`` so it agrees with the selected hints.

    The prompt is stripped and, when non-empty, given a closing period. Then
    every hint in PROMPT_HINT_LIST is visited in order. A hint found in
    ``prompt_hint`` is appended (after a space) unless hint_in_prompt reports
    it is already there. A hint not in ``prompt_hint`` is removed through
    prompt_remove_hint.

    Args:
        prompt: Current prompt text.
        prompt_hint: Collection of currently selected hint values.

    Returns:
        The updated prompt with surrounding whitespace stripped.
    """
    text = prompt.strip()
    if text != "" and not text.endswith("."):
        text += "."
    for _label, hint in PROMPT_HINT_LIST:
        if hint not in prompt_hint:
            # Deselected hint: drop it from the prompt.
            text = prompt_remove_hint(text, hint)
            continue
        if not hint_in_prompt(hint, text):
            # Selected hint that is still missing: add it at the end.
            text = " ".join((text, hint))
    return text.strip()
def handle_prompt_change(prompt):
    """List the hints from PROMPT_HINT_LIST that ``prompt`` already contains.

    Args:
        prompt: Current prompt text.

    Returns:
        The hint values, in PROMPT_HINT_LIST order, for which hint_in_prompt
        is true.
    """
    return [
        hint
        for _label, hint in PROMPT_HINT_LIST
        if hint_in_prompt(hint, prompt)
    ]
def clear_task(task_input=None):
    """Log the reset and return blank values for the task-related components.

    Args:
        task_input: Accepted but never read.

    Returns:
        The 5-tuple ``("", "", "", [], "assets/white_image.png")``. Callers in
        this file route it to [cache_image_base64, cache_task_uuid,
        block_prompt, block_prompt_hint, fake3d].
    """
    log("INFO", "Clearing task...")
    blank_image_base64 = ""
    blank_task_uuid = ""
    blank_prompt = ""
    no_hints = []
    placeholder_preview = "assets/white_image.png"
    return (
        blank_image_base64,
        blank_task_uuid,
        blank_prompt,
        no_hints,
        placeholder_preview,
    )
def clear_task_id():
    """Return an empty string to blank out a task id."""
    blank_task_id = ""
    return blank_task_id
def return_render(image):
    """Convert an array into a PIL image and pair it with its DEFAULT crop.

    Args:
        image: Array accepted by ``Image.fromarray``.

    Returns:
        A 2-tuple: the PIL image, and ``crop_image`` applied to it with
        DEFAULT.
    """
    pil_image = Image.fromarray(image)
    default_crop = crop_image(pil_image, DEFAULT)
    return pil_image, default_crop
def crop_image_default(image):
    """Return ``crop_image(image, DEFAULT)``; handler for the "Default" button."""
    mode = DEFAULT
    return crop_image(image, mode)
def crop_image_metal(image):
    """Return ``crop_image(image, METAL)``; handler for the "Metal" button."""
    mode = METAL
    return crop_image(image, mode)
def crop_image_contrast(image):
    """Return ``crop_image(image, CONTRAST)``; handler for the "Contrast" button."""
    mode = CONTRAST
    return crop_image(image, mode)
def crop_image_normal(image):
    """Return ``crop_image(image, NORMAL)``; handler for the "Normal" button."""
    mode = NORMAL
    return crop_image(image, mode)
# Build the UI. Every component created here is referenced by the event
# wiring further down.
with gr.Blocks() as demo:
# Static page header (title and links).
gr.HTML(html_content)
with gr.Row():
with gr.Column():
# Upload-only image input; handlers receive a file path (type="filepath").
block_image = gr.Image(height=256, image_mode="RGB", sources="upload", elem_classes="elem_imageupload", type="filepath")
# Model selector; its change event always resets it to "Rodin Gen-1(0525)".
block_model_card = gr.Dropdown(choices=options, label="Model Card", value="Rodin Gen-1(0525)", interactive=True)
with gr.Group():
# Single-line prompt; generator.preprocess writes to it, and the hint
# checkboxes keep it in sync.
block_prompt = gr.Textbox(
value="",
placeholder="Auto generated description of Image",
lines=1,
show_label=True,
label="Prompt",
)
# Hint checkboxes built from PROMPT_HINT_LIST.
# NOTE(review): "Labels" is passed as the initial *value*, not as a label;
# confirm whether `label="Labels"` was intended.
block_prompt_hint = gr.CheckboxGroup(value="Labels", choices=PROMPT_HINT_LIST, show_label=False)
with gr.Column():
with gr.Group():
# Read-only preview component from the gradio_fake3d package.
fake3d = Fake3D(interactive=False, label="3D Preview")
with gr.Row():
# elem_id must stay 'button_generate': the JS snippets look the button
# up by that id to change its label.
button_generate = gr.Button(value="Generate", variant="primary", elem_id="button_generate")
with gr.Column(min_width=200, scale=20):
# Four buttons, each re-cropping the cached raw render with a
# different crop_image mode.
with gr.Row():
block_default = gr.Button("Default", min_width=0)
block_metal = gr.Button("Metal", min_width=0)
with gr.Row():
block_contrast = gr.Button("Contrast", min_width=0)
block_normal = gr.Button("Normal", min_width=0)
# Link button; `rodin_url` is not defined in this file, presumably it comes
# from `from constant import *`.
button_more = gr.Button(value="Download from Rodin", variant="primary", link=rodin_url)
gr.Markdown("""
**TIPS**:
1. Upload an image to generate 3D geometry.
2. Click Redo to regenerate the model.
3. 4 buttons to switch the view.
4. Swipe to rotate the model.
""")
# Hidden components used to pass state between chained events.
# Task id: written by generator.generate_mesh, blanked by clear_task.
cache_task_uuid = gr.Text(value="", visible=False)
# Raw render written by generator.generate_mesh; input to the four crop
# buttons.
cache_raw_image = gr.Image(visible=False, type="pil")
# Dummy input/output for the do_nothing + js events (name spelled
# "cacha_empty" throughout the file).
cacha_empty = gr.Text(visible=False)
# Image payload written by generator.preprocess and read by
# generator.generate_mesh.
cache_image_base64 = gr.Text(visible=False)
# Example gallery. Each row of `example` feeds block_image. With
# run_on_click=True, clear_task is also run and its five return values go to
# the listed outputs.
block_example = gr.Examples(
examples=example,
fn=clear_task,
inputs=[block_image],
outputs=[cache_image_base64, cache_task_uuid, block_prompt, block_prompt_hint, fake3d],
run_on_click=True,
cache_examples=True,
label="Examples"
)
# Uploading an image starts generation automatically:
#   1. do_nothing + JS: relabel the generate button to "Generating...".
#   2. generator.preprocess: writes to block_prompt and cache_image_base64.
#   3. generator.generate_mesh: writes to cache_raw_image, cache_task_uuid
#      and fake3d (the only queued step).
#   4. do_nothing + JS: relabel the generate button to "Redo".
# Step 4 is chained with .then() rather than .success(), so the label is
# reset even when generate_mesh fails instead of staying on "Generating...".
# This matches the button_generate.click chain.
block_image.upload(
    fn=do_nothing,
    js=change_button_name_to_generating,
    inputs=[cacha_empty],
    outputs=[cacha_empty],
    queue=False,
).success(
    fn=generator.preprocess,
    inputs=[block_prompt, block_image, cache_image_base64, cache_task_uuid],
    outputs=[block_prompt, cache_image_base64],
    show_progress="minimal",
    queue=False,
).success(
    fn=generator.generate_mesh,
    inputs=[block_prompt, cache_image_base64, cache_task_uuid],
    outputs=[cache_raw_image, cache_task_uuid, fake3d],
    queue=True,
).then(
    fn=do_nothing,
    js=change_button_name,
    inputs=[cacha_empty],
    outputs=[cacha_empty],
    queue=False,
)
# Clearing the image: first restore the generate button label to "Generate"
# through the JS snippet, then (whether or not that step succeeded) blank the
# cached state, prompt, hints and preview with clear_task.
block_image.clear(
fn=do_nothing,
js=reset_button_name,
inputs=[cacha_empty],
outputs=[cacha_empty],
queue=False
).then(
fn=clear_task,
outputs=[cache_image_base64, cache_task_uuid, block_prompt, block_prompt_hint, fake3d],
show_progress="hidden",
queue=False
)
# Manual generate / redo. The steps mirror the block_image.upload chain:
# relabel the button to "Generating...", run generator.preprocess, run
# generator.generate_mesh (the only queued step), then relabel to "Redo".
# The last step uses .then(), so it is not conditional on generate_mesh
# succeeding.
button_generate.click(
fn=do_nothing,
js=change_button_name_to_generating,
inputs=[cacha_empty],
outputs=[cacha_empty],
queue=False
).success(
fn=generator.preprocess,
inputs=[block_prompt, block_image, cache_image_base64, cache_task_uuid],
outputs=[block_prompt, cache_image_base64],
show_progress="minimal",
queue=False
).success(
fn=generator.generate_mesh,
inputs=[block_prompt, cache_image_base64, cache_task_uuid],
outputs=[cache_raw_image, cache_task_uuid, fake3d],
queue=True
).then(
fn=do_nothing,
js=change_button_name,
inputs=[cacha_empty],
outputs=[cacha_empty],
queue=False
)
# View switches: each button re-crops the cached raw render with its own
# crop_image mode and shows the result in the fake3d preview.
block_default.click(fn=crop_image_default, inputs=[cache_raw_image], outputs=fake3d, show_progress="minimal")
block_metal.click(fn=crop_image_metal, inputs=[cache_raw_image], outputs=fake3d, show_progress="minimal")
block_contrast.click(fn=crop_image_contrast, inputs=[cache_raw_image], outputs=fake3d, show_progress="minimal")
block_normal.click(fn=crop_image_normal, inputs=[cache_raw_image], outputs=fake3d, show_progress="minimal")
# NOTE(review): no handler is passed here; the button already carries
# `link=rodin_url`, so this call looks redundant. Confirm before removing.
button_more.click()
# Two-way sync between the hint checkboxes and the prompt text.
# Checkbox input -> handle_hint_change rewrites the prompt.
block_prompt_hint.input(
fn=handle_hint_change, inputs=[block_prompt, block_prompt_hint], outputs=[block_prompt],
show_progress="hidden",
queue=False,
)
# Prompt change -> handle_prompt_change recomputes which checkboxes are
# ticked.
block_prompt.change(
fn=handle_prompt_change,
inputs=[block_prompt],
outputs=[block_prompt_hint],
trigger_mode="always_last",
show_progress="hidden",
)
# Changing the dropdown: the jump_to_rodin JS opens the OpenCLAY GitHub page
# for 'OpenClay' entries, and handle_selection writes "Rodin Gen-1(0525)"
# back into the dropdown.
block_model_card.change(fn=handle_selection, inputs=[block_model_card], outputs=[block_model_card], show_progress="hidden", js=jump_to_rodin)
# Script entry point: enable the event queue, then start the app with the
# API page hidden and errors surfaced in the UI.
if __name__ == "__main__":
    demo.queue()
    demo.launch(show_api=False, show_error=True)