aiqtech committed · Commit 23a34ba · verified · 1 Parent(s): b3a304c

Update app.py

Files changed (1):
  1. app.py +129 -360
app.py CHANGED
@@ -1,7 +1,10 @@
 import gradio as gr
 import spaces
 from gradio_litmodel3d import LitModel3D
+
 import os
+ os.environ['SPCONV_ALGO'] = 'native'
+ from typing import *
 import torch
 import numpy as np
 import imageio
@@ -11,162 +14,28 @@ from PIL import Image
 from trellis.pipelines import TrellisImageTo3DPipeline
 from trellis.representations import Gaussian, MeshExtractResult
 from trellis.utils import render_utils, postprocessing_utils
- from transformers import pipeline as translation_pipeline
- from diffusers import FluxPipeline
- from typing import *


 MAX_SEED = np.iinfo(np.int32).max
 TMP_DIR = "/tmp/Trellis-demo"
- os.makedirs(TMP_DIR, exist_ok=True)
-
- # GPU memory-related environment variables
- os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'  # set to a smaller value
- os.environ['CUDA_VISIBLE_DEVICES'] = '0'
- os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
- os.environ['PYTORCH_NO_CUDA_MEMORY_CACHING'] = '1'
- os.environ['CUDA_CACHE_DISABLE'] = '1'
-
- def initialize_models():
-     global pipeline, translator, flux_pipe
-
-     try:
-         # CUDA settings
-         if torch.cuda.is_available():
-             torch.backends.cudnn.benchmark = True
-             torch.backends.cuda.matmul.allow_tf32 = True
-             torch.backends.cudnn.allow_tf32 = True
-
-         print("Initializing Trellis pipeline...")
-         try:
-             pipeline = TrellisImageTo3DPipeline.from_pretrained(
-                 "JeffreyXiang/TRELLIS-image-large"
-             )
-
-             if pipeline is None:
-                 raise ValueError("Pipeline initialization returned None")
-
-             if torch.cuda.is_available():
-                 pipeline = pipeline.to("cuda")
-                 # convert to half precision
-                 pipeline = pipeline.half()
-
-         except Exception as e:
-             print(f"Error initializing Trellis pipeline: {str(e)}")
-             raise
-
-         print("Initializing translator...")
-         try:
-             translator = translation_pipeline(
-                 "translation",
-                 model="Helsinki-NLP/opus-mt-ko-en",
-                 device=0 if torch.cuda.is_available() else -1
-             )
-         except Exception as e:
-             print(f"Error initializing translator: {str(e)}")
-             raise
-
-         flux_pipe = None
-
-         print("Models initialized successfully")
-         return True
-
-     except Exception as e:
-         print(f"Model initialization error: {str(e)}")
-         free_memory()
-         return False
-
- def get_flux_pipe():
-     """Load the Flux pipeline only when it is needed."""
-     global flux_pipe
-     if flux_pipe is None:
-         try:
-             free_memory()
-             flux_pipe = FluxPipeline.from_pretrained(
-                 "black-forest-labs/FLUX.1-dev",
-                 use_safetensors=True
-             )
-             if torch.cuda.is_available():
-                 flux_pipe = flux_pipe.to("cuda")
-                 flux_pipe.enable_model_cpu_offload()  # enable CPU offloading
-         except Exception as e:
-             print(f"Error loading Flux pipeline: {e}")
-             return None
-     return flux_pipe
-
- def free_memory():
-     """Aggressive memory cleanup."""
-     import gc
-     import os
-
-     # Python garbage collection
-     gc.collect()
-
-     # release CUDA memory
-     if torch.cuda.is_available():
-         torch.cuda.empty_cache()
-         torch.cuda.synchronize()
-
-     # clean up temporary files
-     tmp_dirs = ['/tmp/transformers_cache', '/tmp/torch_home',
-                 '/tmp/huggingface', '/tmp/cache', TMP_DIR]
-
-     for dir_path in tmp_dirs:
-         if os.path.exists(dir_path):
-             try:
-                 for file in os.listdir(dir_path):
-                     file_path = os.path.join(dir_path, file)
-                     if os.path.isfile(file_path):
-                         try:
-                             os.unlink(file_path)
-                         except:
-                             pass
-             except:
-                 pass

-
- def setup_gpu_model(model):
-     """Move a model to the GPU when one is available."""
-     if torch.cuda.is_available():
-         model = model.to("cuda")
-     return model
-
-
- def translate_if_korean(text):
-     if any(ord('가') <= ord(char) <= ord('힣') for char in text):
-         translated = translator(text)[0]['translation_text']
-         return translated
-     return text
+
+ os.makedirs(TMP_DIR, exist_ok=True)


 def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
-     if image is None:
-         return None, None
-
-     try:
-         trial_id = str(uuid.uuid4())
-
-         # limit the image size
-         max_size = 768
-         if max(image.size) > max_size:
-             ratio = max_size / max(image.size)
-             new_size = tuple(int(dim * ratio) for dim in image.size)
-             image = image.resize(new_size, Image.LANCZOS)
-
-         # preprocess the image
-         processed_image = pipeline.preprocess_image(image)
-         if processed_image is None:
-             raise Exception("Failed to process image")
-
-         # save to a temporary file
-         save_path = os.path.join(TMP_DIR, f"{trial_id}.png")
-         processed_image.save(save_path)
-
-         return trial_id, processed_image
-
-     except Exception as e:
-         print(f"Error in preprocess_image: {str(e)}")
-         return None, None
+     """
+     Preprocess the input image.
+     Args:
+         image (Image.Image): The input image.
+     Returns:
+         str: uuid of the trial.
+         Image.Image: The preprocessed image.
+     """
+     trial_id = str(uuid.uuid4())
+     processed_image = pipeline.preprocess_image(image)
+     processed_image.save(f"{TMP_DIR}/{trial_id}.png")
+     return trial_id, processed_image
+

 def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
     return {
@@ -184,7 +53,8 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
         },
         'trial_id': trial_id,
     }
-
+
+
 def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
     gs = Gaussian(
         aabb=state['gaussian']['aabb'],
@@ -207,190 +77,116 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:

     return gs, mesh, state['trial_id']

- def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_strength: float,
-                 ss_sampling_steps: int, slat_guidance_strength: float, slat_sampling_steps: int):
-     try:
-         if randomize_seed:
-             seed = np.random.randint(0, MAX_SEED)
-
-         input_image = Image.open(f"{TMP_DIR}/{trial_id}.png")
-
-         # image size limit tuned for the L40S
-         max_size = 768
-         if max(input_image.size) > max_size:
-             ratio = max_size / max(input_image.size)
-             input_image = input_image.resize(
-                 (int(input_image.size[0] * ratio),
-                  int(input_image.size[1] * ratio)),
-                 Image.LANCZOS
-             )
-
-         if torch.cuda.is_available():
-             pipeline.to("cuda")
-
-         try:
-             outputs = pipeline.run(
-                 input_image,
-                 seed=seed,
-                 formats=["gaussian", "mesh"],
-                 preprocess_image=False,
-                 sparse_structure_sampler_params={
-                     "steps": min(ss_sampling_steps, 20),
-                     "cfg_strength": ss_guidance_strength,
-                 },
-                 slat_sampler_params={
-                     "steps": min(slat_sampling_steps, 20),
-                     "cfg_strength": slat_guidance_strength,
-                 }
-             )
-         except RuntimeError as e:
-             print(f"Runtime error in pipeline.run: {str(e)}")
-             free_memory()
-             raise e
-
-         # render the video
-         video = render_utils.render_video(outputs['gaussian'][0], num_frames=40)['color']
-         video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=40)['normal']
-         video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
-
-         trial_id = str(uuid.uuid4())
-         video_path = f"{TMP_DIR}/{trial_id}.mp4"
-         os.makedirs(os.path.dirname(video_path), exist_ok=True)
-         imageio.mimsave(video_path, video, fps=20)
-
-         state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
-
-         if torch.cuda.is_available():
-             pipeline.to("cpu")
-
-         return state, video_path
-
-     except Exception as e:
-         print(f"Error in image_to_3d: {str(e)}")
-         if torch.cuda.is_available():
-             pipeline.to("cpu")
-         raise e
-
-
- def generate_image_from_text(prompt, height, width, guidance_scale, num_steps):
-     try:
-         free_memory()
-
-         flux_pipe = get_flux_pipe()
-         if flux_pipe is None:
-             raise Exception("Failed to load Flux pipeline")
-
-         # size limits tuned for the L40S
-         height = min(height, 1024)
-         width = min(width, 1024)
-
-         translated_prompt = translate_if_korean(prompt)
-         final_prompt = f"{translated_prompt}, wbgmsst, 3D, white background"
-
-         with torch.cuda.amp.autocast():
-             output = flux_pipe(
-                 prompt=[final_prompt],
-                 height=height,
-                 width=width,
-                 guidance_scale=guidance_scale,
-                 num_inference_steps=num_steps,
-                 generator=torch.Generator(device='cuda')
-             )
-
-         image = output.images[0]
-
-         free_memory()
-         return image
-
-     except Exception as e:
-         print(f"Error in generate_image_from_text: {str(e)}")
-         free_memory()
-         raise e
-
+ @spaces.GPU
+ def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_strength: float, ss_sampling_steps: int, slat_guidance_strength: float, slat_sampling_steps: int) -> Tuple[dict, str]:
+     """
+     Convert an image to a 3D model.
+     Args:
+         trial_id (str): The uuid of the trial.
+         seed (int): The random seed.
+         randomize_seed (bool): Whether to randomize the seed.
+         ss_guidance_strength (float): The guidance strength for sparse structure generation.
+         ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
+         slat_guidance_strength (float): The guidance strength for structured latent generation.
+         slat_sampling_steps (int): The number of sampling steps for structured latent generation.
+     Returns:
+         dict: The information of the generated 3D model.
+         str: The path to the video of the 3D model.
+     """
+     if randomize_seed:
+         seed = np.random.randint(0, MAX_SEED)
+     outputs = pipeline.run(
+         Image.open(f"{TMP_DIR}/{trial_id}.png"),
+         seed=seed,
+         formats=["gaussian", "mesh"],
+         preprocess_image=False,
+         sparse_structure_sampler_params={
+             "steps": ss_sampling_steps,
+             "cfg_strength": ss_guidance_strength,
+         },
+         slat_sampler_params={
+             "steps": slat_sampling_steps,
+             "cfg_strength": slat_guidance_strength,
+         },
+     )
+     video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
+     video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
+     video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
+     trial_id = uuid.uuid4()
+     video_path = f"{TMP_DIR}/{trial_id}.mp4"
+     os.makedirs(os.path.dirname(video_path), exist_ok=True)
+     imageio.mimsave(video_path, video, fps=15)
+     state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
+     return state, video_path
+
+
+ @spaces.GPU
 def extract_glb(state: dict, mesh_simplify: float, texture_size: int) -> Tuple[str, str]:
+     """
+     Extract a GLB file from the 3D model.
+     Args:
+         state (dict): The state of the generated 3D model.
+         mesh_simplify (float): The mesh simplification factor.
+         texture_size (int): The texture resolution.
+     Returns:
+         str: The path to the extracted GLB file.
+     """
     gs, mesh, trial_id = unpack_state(state)
     glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
     glb_path = f"{TMP_DIR}/{trial_id}.glb"
     glb.export(glb_path)
     return glb_path, glb_path

+
 def activate_button() -> gr.Button:
     return gr.Button(interactive=True)

+
 def deactivate_button() -> gr.Button:
     return gr.Button(interactive=False)

- css = """
- footer {
-     visibility: hidden;
- }
- """

- # Define the Gradio interface
- with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
+ with gr.Blocks() as demo:
     gr.Markdown("""
-     # Craft3D : 3D Asset Creation & Text-to-Image Generation
+     ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
+     * Upload an image and click "Generate" to create a 3D asset. If the image has an alpha channel, it will be used as the mask. Otherwise, we use `rembg` to remove the background.
+     * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
     """)

-     with gr.Tabs():
-         with gr.TabItem("Image to 3D"):
-             with gr.Row():
-                 with gr.Column():
-                     image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil", height=300)
-
-                     with gr.Accordion(label="Generation Settings", open=False):
-                         seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
-                         randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
-                         gr.Markdown("Stage 1: Sparse Structure Generation")
-                         with gr.Row():
-                             ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
-                             ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
-                         gr.Markdown("Stage 2: Structured Latent Generation")
-                         with gr.Row():
-                             slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
-                             slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
-
-                     generate_btn = gr.Button("Generate")
-
-                     with gr.Accordion(label="GLB Extraction Settings", open=False):
-                         mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
-                         texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
-
-                     extract_glb_btn = gr.Button("Extract GLB", interactive=False)
-
-                 with gr.Column():
-                     video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
-                     model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
-                     download_glb = gr.DownloadButton(label="Download GLB", interactive=False)

-         with gr.TabItem("Text to Image"):
-             with gr.Row():
-                 with gr.Column():
-                     text_prompt = gr.Textbox(
-                         label="Text Prompt",
-                         placeholder="Enter your image description...",
-                         lines=3
-                     )
-
-                     with gr.Row():
-                         txt2img_height = gr.Slider(256, 1024, value=512, step=64, label="Height")
-                         txt2img_width = gr.Slider(256, 1024, value=512, step=64, label="Width")
-
-                     with gr.Row():
-                         guidance_scale = gr.Slider(1.0, 20.0, value=7.5, label="Guidance Scale")
-                         num_steps = gr.Slider(1, 50, value=20, label="Number of Steps")
-
-                     generate_txt2img_btn = gr.Button("Generate Image")
-
-                 with gr.Column():
-                     txt2img_output = gr.Image(label="Generated Image")
-
+     with gr.Row():
+         with gr.Column():
+             image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil", height=300)
+
+             with gr.Accordion(label="Generation Settings", open=False):
+                 seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
+                 randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+                 gr.Markdown("Stage 1: Sparse Structure Generation")
+                 with gr.Row():
+                     ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
+                     ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
+                 gr.Markdown("Stage 2: Structured Latent Generation")
+                 with gr.Row():
+                     slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
+                     slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
+
+             generate_btn = gr.Button("Generate")
+
+             with gr.Accordion(label="GLB Extraction Settings", open=False):
+                 mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
+                 texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
+
+             extract_glb_btn = gr.Button("Extract GLB", interactive=False)
+
+         with gr.Column():
+             video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
+             model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
+             download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
+
     trial_id = gr.Textbox(visible=False)
     output_buf = gr.State()

-
-     # Example images
+     # Example images at the bottom of the page
     with gr.Row():
         examples = gr.Examples(
             examples=[
@@ -401,11 +197,8 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
             fn=preprocess_image,
             outputs=[trial_id, image_prompt],
             run_on_click=True,
-             examples_per_page=32,  # fewer examples
-             cache_examples=False   # example caching is configured on the Examples component
+             examples_per_page=64,
         )
-
-

     # Handlers
     image_prompt.upload(
@@ -413,7 +206,6 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
         inputs=[image_prompt],
         outputs=[trial_id, image_prompt],
     )
-
     image_prompt.clear(
         lambda: '',
         outputs=[trial_id],
@@ -421,62 +213,39 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:

     generate_btn.click(
         image_to_3d,
-         inputs=[trial_id, seed, randomize_seed, ss_guidance_strength, ss_sampling_steps,
-                 slat_guidance_strength, slat_sampling_steps],
+         inputs=[trial_id, seed, randomize_seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
         outputs=[output_buf, video_output],
-         concurrency_limit=1
     ).then(
         activate_button,
-         outputs=[extract_glb_btn]
+         outputs=[extract_glb_btn],
+     )
+
+     video_output.clear(
+         deactivate_button,
+         outputs=[extract_glb_btn],
     )

     extract_glb_btn.click(
         extract_glb,
         inputs=[output_buf, mesh_simplify, texture_size],
         outputs=[model_output, download_glb],
-         concurrency_limit=1
     ).then(
         activate_button,
-         outputs=[download_glb]
-     )
-     generate_txt2img_btn.click(
-         generate_image_from_text,
-         inputs=[text_prompt, txt2img_height, txt2img_width, guidance_scale, num_steps],
-         outputs=[txt2img_output],
-         concurrency_limit=1,
-         show_progress=True  # show progress
+         outputs=[download_glb],
     )

+     model_output.clear(
+         deactivate_button,
+         outputs=[download_glb],
+     )
+

+ # Launch the Gradio app
 if __name__ == "__main__":
-     import warnings
-     warnings.filterwarnings('ignore')
-
-     # check the CUDA setup
-     if torch.cuda.is_available():
-         print(f"Using GPU: {torch.cuda.get_device_name()}")
-         print(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
-
-         # CUDA memory settings
-         torch.cuda.set_per_process_memory_fraction(0.8)  # cap GPU memory usage
-
-     # create the directory
-     os.makedirs(TMP_DIR, exist_ok=True)
-
-     # clean up memory
-     free_memory()
-
-     # initialize the models
-     if not initialize_models():
-         print("Failed to initialize models")
-         exit(1)
-
-     # run the Gradio app
-     demo.queue(max_size=1).launch(
-         share=True,
-         max_threads=2,
-         show_error=True,
-         server_port=7860,
-         server_name="0.0.0.0",
-         enable_queue=True
-     )
1
  import gradio as gr
2
  import spaces
3
  from gradio_litmodel3d import LitModel3D
4
+
5
  import os
6
+ os.environ['SPCONV_ALGO'] = 'native'
7
+ from typing import *
8
  import torch
9
  import numpy as np
10
  import imageio
 
14
  from trellis.pipelines import TrellisImageTo3DPipeline
15
  from trellis.representations import Gaussian, MeshExtractResult
16
  from trellis.utils import render_utils, postprocessing_utils
 
 
 
17
 
18
 
19
  MAX_SEED = np.iinfo(np.int32).max
20
  TMP_DIR = "/tmp/Trellis-demo"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ os.makedirs(TMP_DIR, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
 
25
  def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
26
+ """
27
+ Preprocess the input image.
28
+ Args:
29
+ image (Image.Image): The input image.
30
+ Returns:
31
+ str: uuid of the trial.
32
+ Image.Image: The preprocessed image.
33
+ """
34
+ trial_id = str(uuid.uuid4())
35
+ processed_image = pipeline.preprocess_image(image)
36
+ processed_image.save(f"{TMP_DIR}/{trial_id}.png")
37
+ return trial_id, processed_image
38
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
41
  return {
 
53
  },
54
  'trial_id': trial_id,
55
  }
56
+
57
+
58
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
59
  gs = Gaussian(
60
  aabb=state['gaussian']['aabb'],
 
77
 
78
  return gs, mesh, state['trial_id']
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
+ @spaces.GPU
82
+ def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_strength: float, ss_sampling_steps: int, slat_guidance_strength: float, slat_sampling_steps: int) -> Tuple[dict, str]:
83
+ """
84
+ Convert an image to a 3D model.
85
+ Args:
86
+ trial_id (str): The uuid of the trial.
87
+ seed (int): The random seed.
88
+ randomize_seed (bool): Whether to randomize the seed.
89
+ ss_guidance_strength (float): The guidance strength for sparse structure generation.
90
+ ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
91
+ slat_guidance_strength (float): The guidance strength for structured latent generation.
92
+ slat_sampling_steps (int): The number of sampling steps for structured latent generation.
93
+ Returns:
94
+ dict: The information of the generated 3D model.
95
+ str: The path to the video of the 3D model.
96
+ """
97
+ if randomize_seed:
98
+ seed = np.random.randint(0, MAX_SEED)
99
+ outputs = pipeline.run(
100
+ Image.open(f"{TMP_DIR}/{trial_id}.png"),
101
+ seed=seed,
102
+ formats=["gaussian", "mesh"],
103
+ preprocess_image=False,
104
+ sparse_structure_sampler_params={
105
+ "steps": ss_sampling_steps,
106
+ "cfg_strength": ss_guidance_strength,
107
+ },
108
+ slat_sampler_params={
109
+ "steps": slat_sampling_steps,
110
+ "cfg_strength": slat_guidance_strength,
111
+ },
112
+ )
113
+ video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
114
+ video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
115
+ video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
116
+ trial_id = uuid.uuid4()
117
+ video_path = f"{TMP_DIR}/{trial_id}.mp4"
118
+ os.makedirs(os.path.dirname(video_path), exist_ok=True)
119
+ imageio.mimsave(video_path, video, fps=15)
120
+ state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
121
+ return state, video_path
122
+
123
+
124
+ @spaces.GPU
125
  def extract_glb(state: dict, mesh_simplify: float, texture_size: int) -> Tuple[str, str]:
126
+ """
127
+ Extract a GLB file from the 3D model.
128
+ Args:
129
+ state (dict): The state of the generated 3D model.
130
+ mesh_simplify (float): The mesh simplification factor.
131
+ texture_size (int): The texture resolution.
132
+ Returns:
133
+ str: The path to the extracted GLB file.
134
+ """
135
  gs, mesh, trial_id = unpack_state(state)
136
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
137
  glb_path = f"{TMP_DIR}/{trial_id}.glb"
138
  glb.export(glb_path)
139
  return glb_path, glb_path
140
 
141
+
142
  def activate_button() -> gr.Button:
143
  return gr.Button(interactive=True)
144
 
145
+
146
  def deactivate_button() -> gr.Button:
147
  return gr.Button(interactive=False)
148
 
 
 
 
 
 
149
 
150
+ with gr.Blocks() as demo:
 
151
  gr.Markdown("""
152
+ ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
153
+ * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
154
+ * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
155
  """)
156
 
157
+ with gr.Row():
158
+ with gr.Column():
159
+ image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil", height=300)
160
+
161
+ with gr.Accordion(label="Generation Settings", open=False):
162
+ seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
163
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
164
+ gr.Markdown("Stage 1: Sparse Structure Generation")
165
+ with gr.Row():
166
+ ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
167
+ ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
168
+ gr.Markdown("Stage 2: Structured Latent Generation")
169
+ with gr.Row():
170
+ slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
171
+ slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
172
+
173
+ generate_btn = gr.Button("Generate")
174
+
175
+ with gr.Accordion(label="GLB Extraction Settings", open=False):
176
+ mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
177
+ texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
178
+
179
+ extract_glb_btn = gr.Button("Extract GLB", interactive=False)
 
 
 
 
 
 
 
180
 
181
+ with gr.Column():
182
+ video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
183
+ model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
184
+ download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
185
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  trial_id = gr.Textbox(visible=False)
187
  output_buf = gr.State()
188
 
189
+ # Example images at the bottom of the page
 
190
  with gr.Row():
191
  examples = gr.Examples(
192
  examples=[
 
197
  fn=preprocess_image,
198
  outputs=[trial_id, image_prompt],
199
  run_on_click=True,
200
+ examples_per_page=64,
 
201
  )
 
 
202
 
203
  # Handlers
204
  image_prompt.upload(
 
206
  inputs=[image_prompt],
207
  outputs=[trial_id, image_prompt],
208
  )
 
209
  image_prompt.clear(
210
  lambda: '',
211
  outputs=[trial_id],
 
213
 
214
  generate_btn.click(
215
  image_to_3d,
216
+ inputs=[trial_id, seed, randomize_seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
 
217
  outputs=[output_buf, video_output],
 
218
  ).then(
219
  activate_button,
220
+ outputs=[extract_glb_btn],
221
+ )
222
+
223
+ video_output.clear(
224
+ deactivate_button,
225
+ outputs=[extract_glb_btn],
226
  )
227
 
228
  extract_glb_btn.click(
229
  extract_glb,
230
  inputs=[output_buf, mesh_simplify, texture_size],
231
  outputs=[model_output, download_glb],
 
232
  ).then(
233
  activate_button,
234
+ outputs=[download_glb],
 
 
 
 
 
 
 
235
  )
236
 
237
+ model_output.clear(
238
+ deactivate_button,
239
+ outputs=[download_glb],
240
+ )
241
+
242
 
243
+ # Launch the Gradio app
244
  if __name__ == "__main__":
+     pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
+     pipeline.cuda()
+     try:
+         pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))  # Preload rembg
+     except:
+         pass
+     demo.launch()
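
For reference, the flow this commit wires into Gradio can be exercised headlessly. The sketch below is illustrative only and not part of the commit: it assumes the TRELLIS API exactly as used in the diff (TrellisImageTo3DPipeline, render_utils, postprocessing_utils), a CUDA-capable machine, and a hypothetical input file input.png; OUT_DIR is an illustrative name that mirrors the app's TMP_DIR.

import os
import uuid

import imageio
import numpy as np
from PIL import Image

from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.utils import render_utils, postprocessing_utils

OUT_DIR = "/tmp/Trellis-demo"  # hypothetical scratch dir, mirrors the app's TMP_DIR
os.makedirs(OUT_DIR, exist_ok=True)

# Same checkpoint and device placement as the commit's __main__ block.
pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
pipeline.cuda()

trial_id = str(uuid.uuid4())
image = pipeline.preprocess_image(Image.open("input.png"))  # rembg-based background removal

# Mirrors image_to_3d(), with the UI's default slider values.
outputs = pipeline.run(
    image,
    seed=0,
    formats=["gaussian", "mesh"],
    preprocess_image=False,
    sparse_structure_sampler_params={"steps": 12, "cfg_strength": 7.5},
    slat_sampler_params={"steps": 12, "cfg_strength": 3.0},
)

# Side-by-side color/normal preview video, as rendered by the app.
color = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
normal = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
frames = [np.concatenate([c, n], axis=1) for c, n in zip(color, normal)]
imageio.mimsave(f"{OUT_DIR}/{trial_id}.mp4", frames, fps=15)

# GLB export, as in extract_glb(), using the UI defaults for simplify/texture size.
glb = postprocessing_utils.to_glb(
    outputs['gaussian'][0], outputs['mesh'][0],
    simplify=0.95, texture_size=1024, verbose=False,
)
glb.export(f"{OUT_DIR}/{trial_id}.glb")

Note that this sketch skips the pack_state/unpack_state round-trip: in the app that pair exists only to keep the Gaussian and mesh serializable inside gr.State between the Generate and Extract GLB clicks, so headless use can pass the pipeline outputs straight to to_glb.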