# OmniPart / app.py
# Source listing header: author "omnipart", commit "init" (491eded).
import gradio as gr
import spaces
import os
import shutil
os.environ['SPCONV_ALGO'] = 'native'
from huggingface_hub import hf_hub_download
from app_utils import (
generate_parts,
prepare_models,
process_image,
apply_merge,
DEFAULT_SIZE_TH,
TMP_ROOT,
)
# Rows offered by the gr.Examples widget below. Each row supplies, in order,
# the values for: input image path, minimum segment size, merge groups, seed.
# Merge groups use the same syntax as the "Merge Groups" textbox: segment ids
# separated by ",", groups separated by ";".
EXAMPLES = [
["assets/example_data/knight.png", 1800, "6,0,26,20,7;13,1,22,11,12,2,21,27,3,24,23;5,18;4,17;19,16,14,25,28", 42],
["assets/example_data/car.png", 2000, "12,10,2,11;1,7", 42],
["assets/example_data/warhammer.png", 1800, "7,1,0,8", 0],
["assets/example_data/snake.png", 3000, "2,3;0,1;4,5,6,7", 42],
["assets/example_data/Batman.png", 1800, "4,5", 42],
["assets/example_data/robot1.jpeg", 1600, "0,5;10,14,3;1,12,2;13,11,4;7,15", 42],
["assets/example_data/astronaut.png", 2000, "0,4,6;1,8,9,7;2,5", 42],
["assets/example_data/crossbow.jpg", 2000, "2,9;10,12,0,7,11,8,13;4,3", 42],
["assets/example_data/robot.jpg", 1600, "7,19;15,0;6,18", 42],
["assets/example_data/robot_dog.jpg", 1000, "21,9;2,12,10,15,17;11,7;1,0;13,19;4,16", 0],
["assets/example_data/crossbow.jpg", 1600, "9,2;10,15,13;7,14,8,11;0,12,16;5,3,1", 42],
["assets/example_data/robot.jpg", 1800, "1,2,3,5,4,16,17;11,7,19;10,14;18,6,0,15;13,9;12,8", 0],
# NOTE(review): segment 11 appears in two groups ("11,7" and "1,0,21,9,11")
# in the row below — confirm this is intended.
["assets/example_data/robot_dog.jpg", 1000, "2,12,10,15,17,8,3,5,13,19,6,14;11,7;1,0,21,9,11;4,16", 0],
]
# Markdown rendered at the top of the page via gr.Markdown(HEADER): title,
# one-line pitch, and the step-by-step usage instructions.
HEADER = """
# OmniPart: Part-Aware 3D Generation with Semantic Decoupling and Structural Cohesion
🔮 Generate **part-aware 3D content** from a single 2D image with **2D mask control**.
## How to Use
**🚀 Quick Start**: Select an example below and click **"▶️ Run Example"**
**📋 Custom Image Processing**:
1. **Upload Image** - Select your image file
2. **Click "Segment Image"** - Get initial 2D segmentation
3. **Merge Segments** - Enter merge groups like `0,1;3,4` and click **"Apply Merge"** (Recommend keeping **2-15 parts**)
4. **Click "Generate 3D Model"** - Create the final 3D results
"""
def start_session(req: gr.Request):
    """Create the scratch directory for the requesting session.

    The directory lives under ``TMP_ROOT`` and is named after the request's
    ``session_hash``. Calling this again for an existing session is a no-op.
    """
    session_dir = os.path.join(TMP_ROOT, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)
def end_session(req: gr.Request):
    """Remove the scratch directory of the session that is being closed.

    The directory is ``TMP_ROOT/<session_hash>``, the same path that
    ``start_session`` creates. Cleanup is best-effort: if the directory is
    already gone (or was never created), the call does nothing instead of
    raising from inside the unload hook.
    """
    user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
    # ignore_errors: a missing or partially removed directory must not turn
    # session teardown into an exception.
    shutil.rmtree(user_dir, ignore_errors=True)
# Build the Gradio UI: input, merge and generation controls in the left column,
# segmentation previews and 3D viewers in the right column, examples below.
# NOTE(review): the nesting here was reconstructed from a listing that had lost
# its indentation — confirm the examples row sits directly under the Blocks root.
with gr.Blocks(title="OmniPart") as demo:
    gr.Markdown(HEADER)

    # Dict passed between the segment -> merge -> generate handlers.
    state = gr.State({})

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("<div style='text-align: center'>\n\n## Input\n\n</div>")
            input_image = gr.Image(label="Upload Image", type="filepath", height=250, width=250)

            with gr.Row():
                segment_btn = gr.Button("Segment Image", variant="primary", size="lg")
                run_example_btn = gr.Button("▶️ Run Example", variant="secondary", size="lg")

            size_threshold = gr.Slider(
                minimum=600,
                maximum=4000,
                value=DEFAULT_SIZE_TH,
                step=200,
                label="Minimum Segment Size (pixels)",
                info="Segments smaller than this will be ignored"
            )

            gr.Markdown("### Merge Controls")
            merge_input = gr.Textbox(
                label="Merge Groups",
                placeholder="0,1;3,4",
                lines=2,
                info="Specify which segments to merge (e.g., '0,1;3,4' merges segments 0&1 together and 3&4 together)"
            )
            merge_btn = gr.Button("Apply Merge", variant="primary", size="lg")

            gr.Markdown("### 3D Generation Controls")
            seed_slider = gr.Slider(
                minimum=0,
                maximum=10000,
                value=42,
                step=1,
                label="Generation Seed",
                info="Random seed for 3D model generation"
            )
            cfg_slider = gr.Slider(
                minimum=0.0,
                maximum=15.0,
                value=7.5,
                step=0.5,
                label="CFG Strength",
                info="Classifier-Free Guidance strength"
            )
            generate_mesh_btn = gr.Button("Generate 3D Model", variant="secondary", size="lg")

        with gr.Column(scale=2):
            gr.Markdown("<div style='text-align: center'>\n\n## Results Display\n\n</div>")

            with gr.Row():
                initial_seg = gr.Image(label="Init Seg", height=220, width=220)
                pre_merge_vis = gr.Image(label="Pre-merge", height=220, width=220)
                merged_seg = gr.Image(label="Merged Seg", height=220, width=220)

            with gr.Row():
                bbox_mesh = gr.Model3D(label="Bounding Boxes", height=350)
                whole_mesh = gr.Model3D(label="Combined Parts", height=350)
                exploded_mesh = gr.Model3D(label="Exploded Parts", height=350)

            with gr.Row():
                combined_gs = gr.Model3D(label="Combined 3D Gaussians", clear_color=(0.0, 0.0, 0.0, 0.0), height=350)
                exploded_gs = gr.Model3D(label="Exploded 3D Gaussians", clear_color=(0.0, 0.0, 0.0, 0.0), height=350)

    with gr.Row():
        # Selecting an example only fills the four input widgets; nothing is
        # precomputed (cache_examples=False). "Run Example" starts the pipeline.
        examples = gr.Examples(
            examples=EXAMPLES,
            inputs=[input_image, size_threshold, merge_input, seed_slider],
            cache_examples=False,
        )

    # Create the session scratch directory on page load, remove it on unload.
    demo.load(start_session)
    demo.unload(end_session)

    # Manual, step-by-step workflow: one button per stage.
    segment_btn.click(
        process_image,
        inputs=[input_image, size_threshold],
        outputs=[initial_seg, pre_merge_vis, state]
    )
    merge_btn.click(
        apply_merge,
        inputs=[merge_input, state],
        outputs=[merged_seg, state]
    )
    generate_mesh_btn.click(
        generate_parts,
        inputs=[state, seed_slider, cfg_slider],
        outputs=[bbox_mesh, whole_mesh, exploded_mesh, combined_gs, exploded_gs]
    )

    # One-click workflow: segment -> merge -> generate. Each stage is chained
    # with .success() so a failed stage stops the pipeline instead of running
    # the following stages on stale or empty state.
    run_example_btn.click(
        fn=process_image,
        inputs=[input_image, size_threshold],
        outputs=[initial_seg, pre_merge_vis, state]
    ).success(
        fn=apply_merge,
        inputs=[merge_input, state],
        outputs=[merged_seg, state]
    ).success(
        fn=generate_parts,
        inputs=[state, seed_slider, cfg_slider],
        outputs=[bbox_mesh, whole_mesh, exploded_mesh, combined_gs, exploded_gs]
    )
if __name__ == "__main__":
    # Download the three module checkpoints into ./ckpt (in this order), pass
    # their local paths to prepare_models, then start the Gradio server.
    os.makedirs("ckpt", exist_ok=True)
    sam_ckpt_path, partfield_ckpt_path, bbox_gen_ckpt_path = [
        hf_hub_download(
            repo_id="omnipart/OmniPart_modules",
            filename=ckpt_name,
            local_dir="ckpt",
        )
        for ckpt_name in (
            "sam_vit_h_4b8939.pth",
            "partfield_encoder.ckpt",
            "bbox_gen.ckpt",
        )
    ]
    prepare_models(sam_ckpt_path, partfield_ckpt_path, bbox_gen_ckpt_path)
    demo.launch()