Ashraf committed on
Commit
b402669
·
verified ·
1 Parent(s): c2fff22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -84
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
  import spaces
3
  from gradio_litmodel3d import LitModel3D
4
-
5
  import os
6
  import shutil
7
  os.environ['SPCONV_ALGO'] = 'native'
@@ -15,22 +14,18 @@ from trellis.pipelines import TrellisImageTo3DPipeline
15
  from trellis.representations import Gaussian, MeshExtractResult
16
  from trellis.utils import render_utils, postprocessing_utils
17
 
18
-
19
  MAX_SEED = np.iinfo(np.int32).max
20
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
21
  os.makedirs(TMP_DIR, exist_ok=True)
22
 
23
-
24
  def start_session(req: gr.Request):
25
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
26
  os.makedirs(user_dir, exist_ok=True)
27
 
28
-
29
  def end_session(req: gr.Request):
30
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
31
  shutil.rmtree(user_dir)
32
 
33
-
34
  def preprocess_image(image: Image.Image) -> Image.Image:
35
  """
36
  Preprocess the input image.
@@ -42,7 +37,6 @@ def preprocess_image(image: Image.Image) -> Image.Image:
42
  processed_image = pipeline.preprocess_image(image)
43
  return processed_image
44
 
45
-
46
  def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
47
  """
48
  Preprocess a list of input images.
@@ -57,7 +51,6 @@ def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image
57
  processed_images = [pipeline.preprocess_image(image) for image in images]
58
  return processed_images
59
 
60
-
61
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
62
  return {
63
  'gaussian': {
@@ -74,7 +67,6 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
74
  },
75
  }
76
 
77
-
78
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
79
  gs = Gaussian(
80
  aabb=state['gaussian']['aabb'],
@@ -97,14 +89,12 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
97
 
98
  return gs, mesh
99
 
100
-
101
  def get_seed(randomize_seed: bool, seed: int) -> int:
102
  """
103
  Get the random seed.
104
  """
105
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
106
 
107
-
108
  @spaces.GPU
109
  def image_to_3d(
110
  image: Image.Image,
@@ -175,7 +165,6 @@ def image_to_3d(
175
  torch.cuda.empty_cache()
176
  return state, video_path
177
 
178
-
179
  @spaces.GPU(duration=90)
180
  def extract_glb(
181
  state: dict,
@@ -200,7 +189,6 @@ def extract_glb(
200
  torch.cuda.empty_cache()
201
  return glb_path, glb_path
202
 
203
-
204
  @spaces.GPU
205
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
206
  """
@@ -217,39 +205,7 @@ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
217
  torch.cuda.empty_cache()
218
  return gaussian_path, gaussian_path
219
 
220
-
221
- def prepare_multi_example() -> List[Image.Image]:
222
- multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
223
- images = []
224
- for case in multi_case:
225
- _images = []
226
- for i in range(1, 4):
227
- img = Image.open(f'assets/example_multi_image/{case}_{i}.png')
228
- W, H = img.size
229
- img = img.resize((int(W / H * 512), 512))
230
- _images.append(np.array(img))
231
- images.append(Image.fromarray(np.concatenate(_images, axis=1)))
232
- return images
233
-
234
-
235
- def split_image(image: Image.Image) -> List[Image.Image]:
236
- """
237
- Split an image into multiple views.
238
- """
239
- image = np.array(image)
240
- alpha = image[..., 3]
241
- alpha = np.any(alpha>0, axis=0)
242
- start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
243
- end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
244
- images = []
245
- for s, e in zip(start_pos, end_pos):
246
- images.append(Image.fromarray(image[:, s:e+1]))
247
- return [preprocess_image(image) for image in images]
248
-
249
-
250
- with gr.Blocks(delete_cache=(600, 600)) as demo:
251
- gr.Markdown("")
252
-
253
  with gr.Row():
254
  with gr.Column():
255
  with gr.Tabs() as input_tabs:
@@ -257,26 +213,19 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
257
  image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
258
  with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
259
  multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
260
- gr.Markdown("""
261
- Input different views of the object in separate images.
262
-
263
- *NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.*
264
- """)
265
 
266
  with gr.Accordion(label="Generation Settings", open=False):
267
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
268
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
269
- gr.Markdown("Stage 1: Sparse Structure Generation")
270
  with gr.Row():
271
  ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
272
  ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
273
- gr.Markdown("Stage 2: Structured Latent Generation")
274
  with gr.Row():
275
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
276
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
277
  multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
278
 
279
- generate_btn = gr.Button("Generate")
280
 
281
  with gr.Accordion(label="GLB Extraction Settings", open=False):
282
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
@@ -285,9 +234,6 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
285
  with gr.Row():
286
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
287
  extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
288
- gr.Markdown("""
289
- *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
290
- """)
291
 
292
  with gr.Column():
293
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
@@ -300,40 +246,17 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
300
  is_multiimage = gr.State(False)
301
  output_buf = gr.State()
302
 
303
- # Example images at the bottom of the page
304
- with gr.Row() as single_image_example:
305
- examples = gr.Examples(
306
- examples=[
307
- f'assets/example_image/{image}'
308
- for image in os.listdir("")
309
- ],
310
- inputs=[image_prompt],
311
- fn=preprocess_image,
312
- outputs=[image_prompt],
313
- run_on_click=True,
314
- examples_per_page=64,
315
- )
316
- with gr.Row(visible=False) as multiimage_example:
317
- examples_multi = gr.Examples(
318
- examples=prepare_multi_example(),
319
- inputs=[image_prompt],
320
- fn=split_image,
321
- outputs=[multiimage_prompt],
322
- run_on_click=True,
323
- examples_per_page=8,
324
- )
325
-
326
  # Handlers
327
  demo.load(start_session)
328
  demo.unload(end_session)
329
 
330
  single_image_input_tab.select(
331
- lambda: tuple([False, gr.Row.update(visible=True), gr.Row.update(visible=False)]),
332
- outputs=[is_multiimage, single_image_example, multiimage_example]
333
  )
334
  multiimage_input_tab.select(
335
- lambda: tuple([True, gr.Row.update(visible=False), gr.Row.update(visible=True)]),
336
- outputs=[is_multiimage, single_image_example, multiimage_example]
337
  )
338
 
339
  image_prompt.upload(
@@ -387,7 +310,6 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
387
  lambda: gr.Button(interactive=False),
388
  outputs=[download_glb],
389
  )
390
-
391
 
392
  # Launch the Gradio app
393
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import spaces
3
  from gradio_litmodel3d import LitModel3D
 
4
  import os
5
  import shutil
6
  os.environ['SPCONV_ALGO'] = 'native'
 
14
  from trellis.representations import Gaussian, MeshExtractResult
15
  from trellis.utils import render_utils, postprocessing_utils
16
 
 
17
  MAX_SEED = np.iinfo(np.int32).max
18
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
19
  os.makedirs(TMP_DIR, exist_ok=True)
20
 
 
21
  def start_session(req: gr.Request):
22
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
23
  os.makedirs(user_dir, exist_ok=True)
24
 
 
25
  def end_session(req: gr.Request):
26
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
27
  shutil.rmtree(user_dir)
28
 
 
29
  def preprocess_image(image: Image.Image) -> Image.Image:
30
  """
31
  Preprocess the input image.
 
37
  processed_image = pipeline.preprocess_image(image)
38
  return processed_image
39
 
 
40
  def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
41
  """
42
  Preprocess a list of input images.
 
51
  processed_images = [pipeline.preprocess_image(image) for image in images]
52
  return processed_images
53
 
 
54
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
55
  return {
56
  'gaussian': {
 
67
  },
68
  }
69
 
 
70
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
71
  gs = Gaussian(
72
  aabb=state['gaussian']['aabb'],
 
89
 
90
  return gs, mesh
91
 
 
92
  def get_seed(randomize_seed: bool, seed: int) -> int:
93
  """
94
  Get the random seed.
95
  """
96
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
97
 
 
98
  @spaces.GPU
99
  def image_to_3d(
100
  image: Image.Image,
 
165
  torch.cuda.empty_cache()
166
  return state, video_path
167
 
 
168
  @spaces.GPU(duration=90)
169
  def extract_glb(
170
  state: dict,
 
189
  torch.cuda.empty_cache()
190
  return glb_path, glb_path
191
 
 
192
  @spaces.GPU
193
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
194
  """
 
205
  torch.cuda.empty_cache()
206
  return gaussian_path, gaussian_path
207
 
208
+ with gr.Blocks(theme=gr.themes.Default(), delete_cache=(600, 600)) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  with gr.Row():
210
  with gr.Column():
211
  with gr.Tabs() as input_tabs:
 
213
  image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
214
  with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
215
  multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
 
 
 
 
 
216
 
217
  with gr.Accordion(label="Generation Settings", open=False):
218
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
219
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
 
220
  with gr.Row():
221
  ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
222
  ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
 
223
  with gr.Row():
224
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
225
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
226
  multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
227
 
228
+ generate_btn = gr.Button("Generate", variant="primary")
229
 
230
  with gr.Accordion(label="GLB Extraction Settings", open=False):
231
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
 
234
  with gr.Row():
235
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
236
  extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
 
 
 
237
 
238
  with gr.Column():
239
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
 
246
  is_multiimage = gr.State(False)
247
  output_buf = gr.State()
248
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  # Handlers
250
  demo.load(start_session)
251
  demo.unload(end_session)
252
 
253
  single_image_input_tab.select(
254
+ lambda: False,
255
+ outputs=[is_multiimage]
256
  )
257
  multiimage_input_tab.select(
258
+ lambda: True,
259
+ outputs=[is_multiimage]
260
  )
261
 
262
  image_prompt.upload(
 
310
  lambda: gr.Button(interactive=False),
311
  outputs=[download_glb],
312
  )
 
313
 
314
  # Launch the Gradio app
315
  if __name__ == "__main__":