import os
import random
import tempfile
from typing import Any, Union

import gradio as gr
import numpy as np
import spaces
import torch
from gradio_image_prompter import ImagePrompter
from gradio_litmodel3d import LitModel3D
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor

from midi.pipelines.pipeline_midi import MIDIPipeline
from scripts.grounding_sam import plot_segmentation, segment
from scripts.inference_midi import run_midi
# Constants
MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
DTYPE = torch.bfloat16
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
REPO_ID = "VAST-AI/MIDI-3D"
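# Note: bfloat16 assumes hardware with native bf16 support (e.g. an Ampere or
# newer GPU); the CPU fallback path loads but may be impractically slow.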
MARKDOWN = """
## Image to 3D Scene with [MIDI-3D](https://huanngzh.github.io/MIDI-Page/)
<b>Important!</b> Please check out our [instruction video](https://github.com/user-attachments/assets/4fc8aea4-010f-40c7-989d-6b1d9d3e3e09)!
1. Upload an image, and draw bounding boxes for each instance by holding and dragging the mouse. Then click "Run Segmentation" to generate the segmentation result. <b>Ensure that instances are not too small and that bounding boxes fit snugly around each instance.</b>
2. <b>Check "Do image padding" in "Generation Settings" if instances in your image are too close to the image border.</b> Then click "Run Generation" to generate a 3D scene from the image and segmentation result.
3. If you find the generated 3D scene satisfactory, download it by clicking the "Download GLB" button.
"""
EXAMPLES = [
    [
        {"image": "assets/example_data/Cartoon-Style/00_rgb.png"},
        "assets/example_data/Cartoon-Style/00_seg.png",
        42,
        False,
        False,
    ],
    [
        {"image": "assets/example_data/Cartoon-Style/01_rgb.png"},
        "assets/example_data/Cartoon-Style/01_seg.png",
        42,
        False,
        False,
    ],
    [
        {"image": "assets/example_data/Cartoon-Style/03_rgb.png"},
        "assets/example_data/Cartoon-Style/03_seg.png",
        42,
        False,
        False,
    ],
    [
        {"image": "assets/example_data/Realistic-Style/00_rgb.png"},
        "assets/example_data/Realistic-Style/00_seg.png",
        42,
        False,
        True,
    ],
    [
        {"image": "assets/example_data/Realistic-Style/01_rgb.png"},
        "assets/example_data/Realistic-Style/01_seg.png",
        42,
        False,
        True,
    ],
    [
        {"image": "assets/example_data/Realistic-Style/02_rgb.png"},
        "assets/example_data/Realistic-Style/02_seg.png",
        42,
        False,
        False,
    ],
    [
        {"image": "assets/example_data/Realistic-Style/05_rgb.png"},
        "assets/example_data/Realistic-Style/05_seg.png",
        42,
        False,
        False,
    ],
]
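# Each example row maps positionally onto the `inputs` of gr.Examples below:
# [ImagePrompter value (a dict with "image"), segmentation map path, seed,
#  randomize_seed, do_image_padding].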
os.makedirs(TMP_DIR, exist_ok=True)

# Prepare models
## Grounding SAM
segmenter_id = "facebook/sam-vit-base"
sam_processor = AutoProcessor.from_pretrained(segmenter_id)
sam_segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(
    DEVICE, DTYPE
)
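# SAM is used here in box-prompted mode: the rectangles drawn in the UI serve
# as the prompts, so no text-grounded detector is needed in this interactive flow.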
## MIDI-3D
local_dir = "pretrained_weights/MIDI-3D"
snapshot_download(repo_id=REPO_ID, local_dir=local_dir)
pipe: MIDIPipeline = MIDIPipeline.from_pretrained(local_dir).to(DEVICE, DTYPE)
pipe.init_custom_adapter(
    set_self_attn_module_names=[
        "blocks.8",
        "blocks.9",
        "blocks.10",
        "blocks.11",
        "blocks.12",
    ]
)
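# Per the MIDI-3D paper, this adapter extends the self-attention of the listed
# transformer blocks into multi-instance attention, so the latents of all
# instances in a scene are denoised jointly rather than as independent objects.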
# Utils
def split_rgb_mask(rgb_image, seg_image):
    """Split a scene image into per-instance RGB crops and binary masks.

    `seg_image` is a label map in which pixel value k marks instance k and
    0 marks the background.
    """
    if isinstance(rgb_image, str):
        rgb_image = Image.open(rgb_image)
    if isinstance(seg_image, str):
        seg_image = Image.open(seg_image)
    rgb_image = rgb_image.convert("RGB")
    seg_image = seg_image.convert("L")

    rgb_array = np.array(rgb_image)
    seg_array = np.array(seg_image)
    label_ids = np.unique(seg_array)
    label_ids = label_ids[label_ids > 0]  # drop the background label

    instance_rgbs, instance_masks, scene_rgbs = [], [], []
    for segment_id in sorted(label_ids):
        # Here we set the background to white
        white_background = np.ones_like(rgb_array) * 255
        mask = np.zeros_like(seg_array, dtype=np.uint8)
        mask[seg_array == segment_id] = 255
        segment_rgb = white_background.copy()
        segment_rgb[mask == 255] = rgb_array[mask == 255]
        segment_rgb_image = Image.fromarray(segment_rgb)
        segment_mask_image = Image.fromarray(mask)
        instance_rgbs.append(segment_rgb_image)
        instance_masks.append(segment_mask_image)
        scene_rgbs.append(rgb_image)
    return instance_rgbs, instance_masks, scene_rgbs
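# A minimal usage sketch for split_rgb_mask (hypothetical file names):
#
#   rgbs, masks, scenes = split_rgb_mask("scene_rgb.png", "scene_seg.png")
#   assert len(rgbs) == len(masks) == len(scenes)  # one entry per instance
#   rgbs[0].save("instance_0_rgb.png")   # first instance on a white background
#   masks[0].save("instance_0_mask.png") # its binary (0/255) mask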
def run_segmentation(image_prompts: Any, polygon_refinement: bool) -> Image.Image:
    rgb_image = image_prompts["image"].convert("RGB")

    # pre-process the layers and get the xyxy boxes of each layer
    if len(image_prompts["points"]) == 0:
        raise gr.Error("Please draw bounding boxes for each instance on the image.")
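    # ImagePrompter encodes each drawn rectangle as a 6-tuple
    # [x1, y1, type, x2, y2, type], so indices 0, 1, 3 and 4 hold the xyxy
    # corners; the outer list adds the batch dimension expected downstream.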
    boxes = [
        [
            [int(box[0]), int(box[1]), int(box[3]), int(box[4])]
            for box in image_prompts["points"]
        ]
    ]

    # run the segmentation
    detections = segment(
        sam_processor,
        sam_segmentator,
        rgb_image,
        boxes=boxes,  # already batched above; wrapping it again would over-nest
        polygon_refinement=polygon_refinement,
    )
    seg_map_pil = plot_segmentation(rgb_image, detections)

    torch.cuda.empty_cache()
    return seg_map_pil
# @spaces.GPU()
def run_generation(
    rgb_image: Any,
    seg_image: Union[str, Image.Image],
    seed: int,
    randomize_seed: bool = False,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.0,
    do_image_padding: bool = False,
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    if not isinstance(rgb_image, Image.Image) and "image" in rgb_image:
        rgb_image = rgb_image["image"]

    scene = run_midi(
        pipe,
        rgb_image,
        seg_image,
        seed,
        num_inference_steps,
        guidance_scale,
        do_image_padding,
    )
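    # run_midi returns a scene object with an .export method (a trimesh scene
    # in the MIDI-3D codebase); write it to a uniquely named GLB in TMP_DIR so
    # the same file can feed both the 3D viewer and the download button.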
    _, tmp_path = tempfile.mkstemp(suffix=".glb", prefix="midi3d_", dir=TMP_DIR)
    scene.export(tmp_path)

    torch.cuda.empty_cache()
    return tmp_path, tmp_path, seed
# Demo
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                image_prompts = ImagePrompter(label="Input Image", type="pil")
                seg_image = gr.Image(
                    label="Segmentation Result", type="pil", format="png"
                )

            with gr.Accordion("Segmentation Settings", open=False):
                polygon_refinement = gr.Checkbox(
                    label="Polygon Refinement", value=False
                )
            seg_button = gr.Button("Run Segmentation")

            with gr.Accordion("Generation Settings", open=False):
                do_image_padding = gr.Checkbox(label="Do image padding", value=False)
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=50,
                )
                guidance_scale = gr.Slider(
                    label="CFG scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.0,
                )
            gen_button = gr.Button("Run Generation", variant="primary")

        with gr.Column():
            model_output = LitModel3D(label="Generated GLB", exposure=1.0, height=500)
            download_glb = gr.DownloadButton(label="Download GLB", interactive=False)

    with gr.Row():
        gr.Examples(
            examples=EXAMPLES,
            fn=run_generation,
            inputs=[image_prompts, seg_image, seed, randomize_seed, do_image_padding],
            outputs=[model_output, download_glb, seed],
            # With caching disabled, clicking an example only populates the
            # inputs above; run_generation is not invoked automatically.
            cache_examples=False,
        )
    seg_button.click(
        run_segmentation,
        inputs=[image_prompts, polygon_refinement],
        outputs=[seg_image],
    ).then(lambda: gr.Button(interactive=True), outputs=[gen_button])
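    # Each .then step fires only after the preceding event completes, enabling
    # the next button in the flow (notably download_glb, which starts disabled).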
    gen_button.click(
        run_generation,
        inputs=[
            image_prompts,
            seg_image,
            seed,
            randomize_seed,
            num_inference_steps,
            guidance_scale,
            do_image_padding,
        ],
        outputs=[model_output, download_glb, seed],
    ).then(lambda: gr.Button(interactive=True), outputs=[download_glb])

demo.launch()