Files changed (3) hide show
  1. Dockerfile +58 -0
  2. README.md +2 -3
  3. app.py +51 -75
Dockerfile ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
2
+ # you will also find guides on how best to write your Dockerfile
3
+
4
+ FROM nvidia/cuda:12.1.0-devel-ubuntu22.04
5
+
6
+ # Set compute capability for nerfacc and tiny-cuda-nn
7
+ # See https://developer.nvidia.com/cuda-gpus and limit number to speed-up build
8
+ ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX"
9
+ ENV TCNN_CUDA_ARCHITECTURES=90;89;86;80;75;70;61;60
10
+ # Speed-up build for RTX 30xx
11
+ # ENV TORCH_CUDA_ARCH_LIST="8.6"
12
+ # ENV TCNN_CUDA_ARCHITECTURES=86
13
+ # Speed-up build for RTX 40xx
14
+ # ENV TORCH_CUDA_ARCH_LIST="8.9"
15
+ # ENV TCNN_CUDA_ARCHITECTURES=89
16
+
17
+ # apt install by root user
18
+ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
19
+ build-essential \
20
+ curl \
21
+ cmake \
22
+ git \
23
+ git-lfs \
24
+ ffmpeg \
25
+ libegl1-mesa-dev \
26
+ libgl1-mesa-dev \
27
+ libgles2-mesa-dev \
28
+ libglib2.0-0 \
29
+ libgl1-mesa-glx \
30
+ libsm6 \
31
+ libxext6 \
32
+ libxrender1 \
33
+ python-is-python3 \
34
+ python3.10-dev \
35
+ python3-pip \
36
+ rsync \
37
+ wget \
38
+ && rm -rf /var/lib/apt/lists/*
39
+
40
+ RUN useradd -m -u 1000 user
41
+ USER user
42
+
43
+ ENV CUDA_HOME=/usr/local/cuda
44
+ ENV PATH=${CUDA_HOME}/bin:/home/user/.local/bin:${PATH}
45
+ ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
46
+ ENV LIBRARY_PATH=${CUDA_HOME}/lib64/stubs:${LIBRARY_PATH}
47
+
48
+ WORKDIR /app
49
+
50
+ RUN pip install torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cu121
51
+ RUN pip install --no-cache-dir datasets "huggingface-hub>=0.19" "hf-transfer>=0.1.4" "protobuf<4" "click<8.1" "pydantic~=1.0"
52
+ RUN pip install --no-cache-dir gradio[oauth]==4.44.1 "uvicorn>=0.14.0" spaces
53
+
54
+ COPY --chown=user ./requirements.txt requirements.txt
55
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
56
+
57
+ COPY --chown=user . /app
58
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,10 +1,9 @@
1
  ---
2
  title: MIDI 3D
3
  emoji: 📚
4
- colorFrom: purple
5
  colorTo: red
6
- sdk: gradio
7
- sdk_version: 4.44.1
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
1
  ---
2
  title: MIDI 3D
3
  emoji: 📚
4
+ colorFrom: gray
5
  colorTo: red
6
+ sdk: docker
 
7
  app_file: app.py
8
  pinned: false
9
  license: apache-2.0
app.py CHANGED
@@ -6,20 +6,18 @@ from typing import Any, List, Union
6
 
7
  import gradio as gr
8
  import numpy as np
9
- import spaces
10
  import torch
11
- import trimesh
12
  from gradio_image_prompter import ImagePrompter
13
  from gradio_litmodel3d import LitModel3D
14
  from huggingface_hub import snapshot_download
15
  from PIL import Image
16
- from skimage import measure
17
  from transformers import AutoModelForMaskGeneration, AutoProcessor
18
 
19
  from midi.pipelines.pipeline_midi import MIDIPipeline
20
- from midi.utils.smoothing import smooth_gpu
21
  from scripts.grounding_sam import plot_segmentation, segment
22
- from scripts.inference_midi import preprocess_image, split_rgb_mask
 
 
23
 
24
  # Constants
25
  MAX_SEED = np.iinfo(np.int32).max
@@ -30,7 +28,7 @@ REPO_ID = "VAST-AI/MIDI-3D"
30
 
31
  MARKDOWN = """
32
  ## Image to 3D Scene with [MIDI-3D](https://huanngzh.github.io/MIDI-Page/)
33
- <b>Important!</b> Please check out our [instruction video](https://github.com/user-attachments/assets/814c046e-f5c3-47cf-bb56-60154be8374c)!
34
  1. Upload an image, and draw bounding boxes for each instance by holding and dragging the mouse. Then clik "Run Segmentation" to generate the segmentation result. <b>Ensure instances should not be too small and bounding boxes fit snugly around each instance.</b>
35
  2. <b>Check "Do image padding" in "Generation Settings" if instances in your image are too close to the image border.</b> Then click "Run Generation" to generate a 3D scene from the image and segmentation result.
36
  3. If you find the generated 3D scene satisfactory, download it by clicking the "Download GLB" button.
@@ -39,9 +37,9 @@ MARKDOWN = """
39
  EXAMPLES = [
40
  [
41
  {
42
- "image": "assets/example_data/Cartoon-Style/03_rgb.png",
43
  },
44
- "assets/example_data/Cartoon-Style/03_seg.png",
45
  42,
46
  False,
47
  False,
@@ -57,39 +55,39 @@ EXAMPLES = [
57
  ],
58
  [
59
  {
60
- "image": "assets/example_data/Realistic-Style/02_rgb.png",
61
  },
62
- "assets/example_data/Realistic-Style/02_seg.png",
63
  42,
64
  False,
65
  False,
66
  ],
67
  [
68
  {
69
- "image": "assets/example_data/Cartoon-Style/00_rgb.png",
70
  },
71
- "assets/example_data/Cartoon-Style/00_seg.png",
72
  42,
73
  False,
74
- False,
75
  ],
76
  [
77
  {
78
- "image": "assets/example_data/Realistic-Style/00_rgb.png",
79
  },
80
- "assets/example_data/Realistic-Style/00_seg.png",
81
  42,
82
  False,
83
  True,
84
  ],
85
  [
86
  {
87
- "image": "assets/example_data/Realistic-Style/01_rgb.png",
88
  },
89
- "assets/example_data/Realistic-Style/01_seg.png",
90
  42,
91
  False,
92
- True,
93
  ],
94
  [
95
  {
@@ -127,10 +125,38 @@ pipe.init_custom_adapter(
127
 
128
 
129
  # Utils
130
- def get_random_hex():
131
- random_bytes = os.urandom(8)
132
- random_hex = random_bytes.hex()
133
- return random_hex
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
 
136
  @spaces.GPU()
@@ -164,37 +190,7 @@ def run_segmentation(image_prompts: Any, polygon_refinement: bool) -> Image.Imag
164
  return seg_map_pil
165
 
166
 
167
- @torch.no_grad()
168
- def run_midi(
169
- pipe: Any,
170
- rgb_image: Union[str, Image.Image],
171
- seg_image: Union[str, Image.Image],
172
- seed: int,
173
- num_inference_steps: int = 50,
174
- guidance_scale: float = 7.0,
175
- do_image_padding: bool = False,
176
- ) -> trimesh.Scene:
177
- if do_image_padding:
178
- rgb_image, seg_image = preprocess_image(rgb_image, seg_image)
179
- instance_rgbs, instance_masks, scene_rgbs = split_rgb_mask(rgb_image, seg_image)
180
-
181
- num_instances = len(instance_rgbs)
182
- outputs = pipe(
183
- image=instance_rgbs,
184
- mask=instance_masks,
185
- image_scene=scene_rgbs,
186
- attention_kwargs={"num_instances": num_instances},
187
- generator=torch.Generator(device=pipe.device).manual_seed(seed),
188
- num_inference_steps=num_inference_steps,
189
- guidance_scale=guidance_scale,
190
- decode_progressive=True,
191
- return_dict=False,
192
- )
193
-
194
- return outputs
195
-
196
-
197
- @spaces.GPU(duration=180)
198
  @torch.no_grad()
199
  @torch.autocast(device_type=DEVICE, dtype=torch.bfloat16)
200
  def run_generation(
@@ -212,7 +208,7 @@ def run_generation(
212
  if not isinstance(rgb_image, Image.Image) and "image" in rgb_image:
213
  rgb_image = rgb_image["image"]
214
 
215
- outputs = run_midi(
216
  pipe,
217
  rgb_image,
218
  seg_image,
@@ -222,27 +218,7 @@ def run_generation(
222
  do_image_padding,
223
  )
224
 
225
- # marching cubes
226
- trimeshes = []
227
- for _, (logits_, grid_size, bbox_size, bbox_min, bbox_max) in enumerate(
228
- zip(*outputs)
229
- ):
230
- grid_logits = logits_.view(grid_size)
231
- grid_logits = smooth_gpu(grid_logits, method="gaussian", sigma=1)
232
- torch.cuda.empty_cache()
233
- vertices, faces, normals, _ = measure.marching_cubes(
234
- grid_logits.float().cpu().numpy(), 0, method="lewiner"
235
- )
236
- vertices = vertices / grid_size * bbox_size + bbox_min
237
-
238
- # Trimesh
239
- mesh = trimesh.Trimesh(vertices.astype(np.float32), np.ascontiguousarray(faces))
240
- trimeshes.append(mesh)
241
-
242
- # compose the output meshes
243
- scene = trimesh.Scene(trimeshes)
244
-
245
- tmp_path = os.path.join(TMP_DIR, f"midi3d_{get_random_hex()}.glb")
246
  scene.export(tmp_path)
247
 
248
  torch.cuda.empty_cache()
 
6
 
7
  import gradio as gr
8
  import numpy as np
 
9
  import torch
 
10
  from gradio_image_prompter import ImagePrompter
11
  from gradio_litmodel3d import LitModel3D
12
  from huggingface_hub import snapshot_download
13
  from PIL import Image
 
14
  from transformers import AutoModelForMaskGeneration, AutoProcessor
15
 
16
  from midi.pipelines.pipeline_midi import MIDIPipeline
 
17
  from scripts.grounding_sam import plot_segmentation, segment
18
+ from scripts.inference_midi import run_midi
19
+
20
+ import spaces
21
 
22
  # Constants
23
  MAX_SEED = np.iinfo(np.int32).max
 
28
 
29
  MARKDOWN = """
30
  ## Image to 3D Scene with [MIDI-3D](https://huanngzh.github.io/MIDI-Page/)
31
+ <b>Important!</b> Please check out our [instruction video](https://github.com/user-attachments/assets/4fc8aea4-010f-40c7-989d-6b1d9d3e3e09)!
32
  1. Upload an image, and draw bounding boxes for each instance by holding and dragging the mouse. Then clik "Run Segmentation" to generate the segmentation result. <b>Ensure instances should not be too small and bounding boxes fit snugly around each instance.</b>
33
  2. <b>Check "Do image padding" in "Generation Settings" if instances in your image are too close to the image border.</b> Then click "Run Generation" to generate a 3D scene from the image and segmentation result.
34
  3. If you find the generated 3D scene satisfactory, download it by clicking the "Download GLB" button.
 
37
  EXAMPLES = [
38
  [
39
  {
40
+ "image": "assets/example_data/Cartoon-Style/00_rgb.png",
41
  },
42
+ "assets/example_data/Cartoon-Style/00_seg.png",
43
  42,
44
  False,
45
  False,
 
55
  ],
56
  [
57
  {
58
+ "image": "assets/example_data/Cartoon-Style/03_rgb.png",
59
  },
60
+ "assets/example_data/Cartoon-Style/03_seg.png",
61
  42,
62
  False,
63
  False,
64
  ],
65
  [
66
  {
67
+ "image": "assets/example_data/Realistic-Style/00_rgb.png",
68
  },
69
+ "assets/example_data/Realistic-Style/00_seg.png",
70
  42,
71
  False,
72
+ True,
73
  ],
74
  [
75
  {
76
+ "image": "assets/example_data/Realistic-Style/01_rgb.png",
77
  },
78
+ "assets/example_data/Realistic-Style/01_seg.png",
79
  42,
80
  False,
81
  True,
82
  ],
83
  [
84
  {
85
+ "image": "assets/example_data/Realistic-Style/02_rgb.png",
86
  },
87
+ "assets/example_data/Realistic-Style/02_seg.png",
88
  42,
89
  False,
90
+ False,
91
  ],
92
  [
93
  {
 
125
 
126
 
127
  # Utils
128
+ def split_rgb_mask(rgb_image, seg_image):
129
+ if isinstance(rgb_image, str):
130
+ rgb_image = Image.open(rgb_image)
131
+ if isinstance(seg_image, str):
132
+ seg_image = Image.open(seg_image)
133
+ rgb_image = rgb_image.convert("RGB")
134
+ seg_image = seg_image.convert("L")
135
+
136
+ rgb_array = np.array(rgb_image)
137
+ seg_array = np.array(seg_image)
138
+
139
+ label_ids = np.unique(seg_array)
140
+ label_ids = label_ids[label_ids > 0]
141
+
142
+ instance_rgbs, instance_masks, scene_rgbs = [], [], []
143
+
144
+ for segment_id in sorted(label_ids):
145
+ # Here we set the background to white
146
+ white_background = np.ones_like(rgb_array) * 255
147
+
148
+ mask = np.zeros_like(seg_array, dtype=np.uint8)
149
+ mask[seg_array == segment_id] = 255
150
+ segment_rgb = white_background.copy()
151
+ segment_rgb[mask == 255] = rgb_array[mask == 255]
152
+
153
+ segment_rgb_image = Image.fromarray(segment_rgb)
154
+ segment_mask_image = Image.fromarray(mask)
155
+ instance_rgbs.append(segment_rgb_image)
156
+ instance_masks.append(segment_mask_image)
157
+ scene_rgbs.append(rgb_image)
158
+
159
+ return instance_rgbs, instance_masks, scene_rgbs
160
 
161
 
162
  @spaces.GPU()
 
190
  return seg_map_pil
191
 
192
 
193
+ # @spaces.GPU()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  @torch.no_grad()
195
  @torch.autocast(device_type=DEVICE, dtype=torch.bfloat16)
196
  def run_generation(
 
208
  if not isinstance(rgb_image, Image.Image) and "image" in rgb_image:
209
  rgb_image = rgb_image["image"]
210
 
211
+ scene = run_midi(
212
  pipe,
213
  rgb_image,
214
  seg_image,
 
218
  do_image_padding,
219
  )
220
 
221
+ _, tmp_path = tempfile.mkstemp(suffix=".glb", prefix="midi3d_", dir=TMP_DIR)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  scene.export(tmp_path)
223
 
224
  torch.cuda.empty_cache()