chateauxai committed on
Commit
b0325d3
·
verified ·
1 Parent(s): 39589ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -15
app.py CHANGED
@@ -17,9 +17,6 @@ MAX_SEED = np.iinfo(np.int32).max
17
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
18
  os.makedirs(TMP_DIR, exist_ok=True)
19
 
20
- # Initialize pipeline
21
- pipeline = TrellisImageTo3DPipeline()
22
-
23
  # Session management
24
  def start_session(req: gr.Request):
25
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
@@ -68,10 +65,12 @@ def unpack_state(state: dict) -> tuple:
68
  gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
69
  gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
70
  gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
 
71
  mesh = edict(
72
  vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
73
  faces=torch.tensor(state['mesh']['faces'], device='cuda'),
74
  )
 
75
  return gs, mesh
76
 
77
  def get_seed(randomize_seed: bool, seed: int) -> int:
@@ -123,6 +122,7 @@ def image_to_3d(
123
  },
124
  mode=multiimage_algo,
125
  )
 
126
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
127
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
128
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
@@ -141,12 +141,14 @@ def extract_glb(
141
  ) -> tuple:
142
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
143
  gs, mesh = unpack_state(state)
 
144
  # Convert the mesh to polygonal surfaces (quads)
145
  mesh.vertices, mesh.faces = postprocessing_utils.remesh_to_quads(
146
  vertices=mesh.vertices.cpu().numpy(),
147
  faces=mesh.faces.cpu().numpy(),
148
  simplify=mesh_simplify
149
  )
 
150
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
151
  glb_path = os.path.join(user_dir, 'sample.glb')
152
  glb.export(glb_path)
@@ -156,7 +158,7 @@ def extract_glb(
156
  @spaces.GPU
157
  def extract_gaussian(state: dict, req: gr.Request) -> tuple:
158
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
159
- gs, * = unpack_state(state)
160
  gaussian_path = os.path.join(user_dir, 'sample.ply')
161
  gs.save_ply(gaussian_path)
162
  torch.cuda.empty_cache()
@@ -171,6 +173,7 @@ with gr.Blocks(theme=gr.themes.Default(), delete_cache=(600, 600)) as demo:
171
  image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
172
  with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
173
  multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
 
174
  with gr.Accordion(label="Generation Settings", open=False):
175
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
176
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
@@ -181,34 +184,38 @@ with gr.Blocks(theme=gr.themes.Default(), delete_cache=(600, 600)) as demo:
181
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Latent Guidance Strength", value=3.0, step=0.1)
182
  slat_sampling_steps = gr.Slider(1, 50, label="Latent Sampling Steps", value=12, step=1)
183
  multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
 
184
  generate_btn = gr.Button("Generate", variant="primary")
 
185
  with gr.Accordion(label="GLB Extraction Settings", open=False):
186
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
187
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
 
188
  with gr.Row():
189
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
190
  extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
 
191
  with gr.Column():
192
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
193
  model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
 
194
  with gr.Row():
195
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
196
  download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
197
-
198
- # State Management
199
  is_multiimage = gr.State(False)
200
  output_buf = gr.State()
201
-
202
  # Handlers
203
  demo.load(start_session)
204
  demo.unload(end_session)
205
-
206
  single_image_input_tab.select(lambda: False, outputs=[is_multiimage])
207
  multiimage_input_tab.select(lambda: True, outputs=[is_multiimage])
208
-
209
  image_prompt.upload(preprocess_image, inputs=[image_prompt], outputs=[image_prompt])
210
  multiimage_prompt.upload(preprocess_images, inputs=[multiimage_prompt], outputs=[multiimage_prompt])
211
-
212
  generate_btn.click(get_seed, inputs=[randomize_seed, seed], outputs=[seed]).then(
213
  image_to_3d,
214
  inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
@@ -217,18 +224,18 @@ with gr.Blocks(theme=gr.themes.Default(), delete_cache=(600, 600)) as demo:
217
  lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
218
  outputs=[extract_glb_btn, extract_gs_btn],
219
  )
220
-
221
  extract_glb_btn.click(
222
  extract_glb,
223
  inputs=[output_buf, mesh_simplify, texture_size],
224
  outputs=[model_output, download_glb],
225
  )
226
-
227
  extract_gs_btn.click(
228
  extract_gaussian,
229
  inputs=[output_buf],
230
- outputs=[model_output, download_gs]
231
  )
232
 
233
- # Launch the app
234
- demo.launch()
 
17
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
18
  os.makedirs(TMP_DIR, exist_ok=True)
19
 
 
 
 
20
  # Session management
21
  def start_session(req: gr.Request):
22
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
65
  gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
66
  gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
67
  gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
68
+
69
  mesh = edict(
70
  vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
71
  faces=torch.tensor(state['mesh']['faces'], device='cuda'),
72
  )
73
+
74
  return gs, mesh
75
 
76
  def get_seed(randomize_seed: bool, seed: int) -> int:
 
122
  },
123
  mode=multiimage_algo,
124
  )
125
+
126
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
127
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
128
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
 
141
  ) -> tuple:
142
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
143
  gs, mesh = unpack_state(state)
144
+
145
  # Convert the mesh to polygonal surfaces (quads)
146
  mesh.vertices, mesh.faces = postprocessing_utils.remesh_to_quads(
147
  vertices=mesh.vertices.cpu().numpy(),
148
  faces=mesh.faces.cpu().numpy(),
149
  simplify=mesh_simplify
150
  )
151
+
152
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
153
  glb_path = os.path.join(user_dir, 'sample.glb')
154
  glb.export(glb_path)
 
158
  @spaces.GPU
159
  def extract_gaussian(state: dict, req: gr.Request) -> tuple:
160
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
161
+ gs, _ = unpack_state(state)
162
  gaussian_path = os.path.join(user_dir, 'sample.ply')
163
  gs.save_ply(gaussian_path)
164
  torch.cuda.empty_cache()
 
173
  image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
174
  with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
175
  multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
176
+
177
  with gr.Accordion(label="Generation Settings", open=False):
178
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
179
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
 
184
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Latent Guidance Strength", value=3.0, step=0.1)
185
  slat_sampling_steps = gr.Slider(1, 50, label="Latent Sampling Steps", value=12, step=1)
186
  multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
187
+
188
  generate_btn = gr.Button("Generate", variant="primary")
189
+
190
  with gr.Accordion(label="GLB Extraction Settings", open=False):
191
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
192
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
193
+
194
  with gr.Row():
195
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
196
  extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
197
+
198
  with gr.Column():
199
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
200
  model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
201
+
202
  with gr.Row():
203
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
204
  download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
205
+
 
206
  is_multiimage = gr.State(False)
207
  output_buf = gr.State()
208
+
209
  # Handlers
210
  demo.load(start_session)
211
  demo.unload(end_session)
212
+
213
  single_image_input_tab.select(lambda: False, outputs=[is_multiimage])
214
  multiimage_input_tab.select(lambda: True, outputs=[is_multiimage])
215
+
216
  image_prompt.upload(preprocess_image, inputs=[image_prompt], outputs=[image_prompt])
217
  multiimage_prompt.upload(preprocess_images, inputs=[multiimage_prompt], outputs=[multiimage_prompt])
218
+
219
  generate_btn.click(get_seed, inputs=[randomize_seed, seed], outputs=[seed]).then(
220
  image_to_3d,
221
  inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
 
224
  lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
225
  outputs=[extract_glb_btn, extract_gs_btn],
226
  )
227
+
228
  extract_glb_btn.click(
229
  extract_glb,
230
  inputs=[output_buf, mesh_simplify, texture_size],
231
  outputs=[model_output, download_glb],
232
  )
233
+
234
  extract_gs_btn.click(
235
  extract_gaussian,
236
  inputs=[output_buf],
237
+ outputs=[model_output, download_gs],
238
  )
239
 
240
+ # Launch the Gradio demo for Hugging Face Spaces
241
+ demo.launch()