chateauxai committed on
Commit
6e469c0
·
verified ·
1 Parent(s): 7f3d4c0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +268 -228
app.py CHANGED
@@ -1,244 +1,284 @@
1
  import gradio as gr
2
- import spaces
3
- from gradio_litmodel3d import LitModel3D
4
- import os
5
- import shutil
6
  import torch
7
  import numpy as np
 
 
8
  import imageio
9
- from easydict import EasyDict as edict
10
  from PIL import Image
11
- from trellis.pipelines import TrellisImageTo3DPipeline
12
- from trellis.representations import Gaussian, MeshExtractResult
13
- from trellis.utils import render_utils, postprocessing_utils
14
 
15
- # Add this before your preprocessing functions
16
- pipeline = TrellisImageTo3DPipeline()
 
 
 
 
 
 
 
 
 
 
17
 
18
  # Constants
19
  MAX_SEED = np.iinfo(np.int32).max
20
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
21
  os.makedirs(TMP_DIR, exist_ok=True)
22
 
23
- # Session management
24
- def start_session(req: gr.Request):
25
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
26
- os.makedirs(user_dir, exist_ok=True)
27
-
28
- def end_session(req: gr.Request):
29
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
30
- shutil.rmtree(user_dir)
31
-
32
- # Preprocessing functions
33
- def preprocess_image(image: Image.Image) -> Image.Image:
34
- return pipeline.preprocess_image(image)
35
-
36
- def preprocess_images(images: list) -> list:
37
- images = [image[0] for image in images]
38
- return [pipeline.preprocess_image(image) for image in images]
39
-
40
- # Utility functions
41
- def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
42
- return {
43
- 'gaussian': {
44
- **gs.init_params,
45
- '_xyz': gs._xyz.cpu().numpy(),
46
- '_features_dc': gs._features_dc.cpu().numpy(),
47
- '_scaling': gs._scaling.cpu().numpy(),
48
- '_rotation': gs._rotation.cpu().numpy(),
49
- '_opacity': gs._opacity.cpu().numpy(),
50
- },
51
- 'mesh': {
52
- 'vertices': mesh.vertices.cpu().numpy(),
53
- 'faces': mesh.faces.cpu().numpy(),
54
- },
55
- }
56
-
57
- def unpack_state(state: dict) -> tuple:
58
- gs = Gaussian(
59
- aabb=state['gaussian']['aabb'],
60
- sh_degree=state['gaussian']['sh_degree'],
61
- mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
62
- scaling_bias=state['gaussian']['scaling_bias'],
63
- opacity_bias=state['gaussian']['opacity_bias'],
64
- scaling_activation=state['gaussian']['scaling_activation'],
65
- )
66
- gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
67
- gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
68
- gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
69
- gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
70
- gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
71
-
72
- mesh = edict(
73
- vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
74
- faces=torch.tensor(state['mesh']['faces'], device='cuda'),
75
- )
76
-
77
- return gs, mesh
78
-
79
- def get_seed(randomize_seed: bool, seed: int) -> int:
80
- return np.random.randint(0, MAX_SEED) if randomize_seed else seed
81
-
82
- # Core functions
83
- @spaces.GPU
84
- def image_to_3d(
85
- image: Image.Image,
86
- multiimages: list,
87
- is_multiimage: bool,
88
- seed: int,
89
- ss_guidance_strength: float,
90
- ss_sampling_steps: int,
91
- slat_guidance_strength: float,
92
- slat_sampling_steps: int,
93
- multiimage_algo: str,
94
- req: gr.Request,
95
- ) -> tuple:
96
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
97
- if not is_multiimage:
98
- outputs = pipeline.run(
99
- image,
100
- seed=seed,
101
- formats=["gaussian", "mesh"],
102
- preprocess_image=False,
103
- sparse_structure_sampler_params={
104
- "steps": ss_sampling_steps,
105
- "cfg_strength": ss_guidance_strength,
106
- },
107
- slat_sampler_params={
108
- "steps": slat_sampling_steps,
109
- "cfg_strength": slat_guidance_strength,
110
- },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  )
112
- else:
113
- outputs = pipeline.run_multi_image(
114
- [image[0] for image in multiimages],
115
- seed=seed,
116
- formats=["gaussian", "mesh"],
117
- preprocess_image=False,
118
- sparse_structure_sampler_params={
119
- "steps": ss_sampling_steps,
120
- "cfg_strength": ss_guidance_strength,
121
- },
122
- slat_sampler_params={
123
- "steps": slat_sampling_steps,
124
- "cfg_strength": slat_guidance_strength,
125
- },
126
- mode=multiimage_algo,
127
  )
128
 
129
- video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
130
- video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
131
- video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
132
- video_path = os.path.join(user_dir, 'sample.mp4')
133
- imageio.mimsave(video_path, video, fps=15)
134
- state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
135
- torch.cuda.empty_cache()
136
- return state, video_path
137
-
138
- @spaces.GPU(duration=90)
139
- def extract_glb(
140
- state: dict,
141
- mesh_simplify: float,
142
- texture_size: int,
143
- req: gr.Request,
144
- ) -> tuple:
145
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
146
- gs, mesh = unpack_state(state)
147
-
148
- # Convert the mesh to polygonal surfaces (quads)
149
- mesh.vertices, mesh.faces = postprocessing_utils.remesh_to_quads(
150
- vertices=mesh.vertices.cpu().numpy(),
151
- faces=mesh.faces.cpu().numpy(),
152
- simplify=mesh_simplify
153
- )
154
-
155
- glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
156
- glb_path = os.path.join(user_dir, 'sample.glb')
157
- glb.export(glb_path)
158
- torch.cuda.empty_cache()
159
- return glb_path, glb_path
160
-
161
- @spaces.GPU
162
- def extract_gaussian(state: dict, req: gr.Request) -> tuple:
163
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
164
- gs, _ = unpack_state(state)
165
- gaussian_path = os.path.join(user_dir, 'sample.ply')
166
- gs.save_ply(gaussian_path)
167
- torch.cuda.empty_cache()
168
- return gaussian_path, gaussian_path
169
-
170
- # Gradio UI setup
171
- with gr.Blocks(theme=gr.themes.Default(), delete_cache=(600, 600)) as demo:
172
- with gr.Row():
173
- with gr.Column():
174
- with gr.Tabs() as input_tabs:
175
- with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
176
- image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
177
- with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
178
- multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
179
-
180
- with gr.Accordion(label="Generation Settings", open=False):
181
- seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
182
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
183
- with gr.Row():
184
- ss_guidance_strength = gr.Slider(0.0, 10.0, label="Sparse Guidance Strength", value=7.5, step=0.1)
185
- ss_sampling_steps = gr.Slider(1, 50, label="Sparse Sampling Steps", value=12, step=1)
186
- with gr.Row():
187
- slat_guidance_strength = gr.Slider(0.0, 10.0, label="Latent Guidance Strength", value=3.0, step=0.1)
188
- slat_sampling_steps = gr.Slider(1, 50, label="Latent Sampling Steps", value=12, step=1)
189
- multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
190
-
191
- generate_btn = gr.Button("Generate", variant="primary")
192
-
193
- with gr.Accordion(label="GLB Extraction Settings", open=False):
194
- mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
195
- texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
196
-
197
- with gr.Row():
198
- extract_glb_btn = gr.Button("Extract GLB", interactive=False)
199
- extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
200
 
201
- with gr.Column():
202
- video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
203
- model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
204
-
205
- with gr.Row():
206
- download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
207
- download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
208
-
209
- is_multiimage = gr.State(False)
210
- output_buf = gr.State()
211
-
212
- # Handlers
213
- demo.load(start_session)
214
- demo.unload(end_session)
215
-
216
- single_image_input_tab.select(lambda: False, outputs=[is_multiimage])
217
- multiimage_input_tab.select(lambda: True, outputs=[is_multiimage])
218
-
219
- image_prompt.upload(preprocess_image, inputs=[image_prompt], outputs=[image_prompt])
220
- multiimage_prompt.upload(preprocess_images, inputs=[multiimage_prompt], outputs=[multiimage_prompt])
221
-
222
- generate_btn.click(get_seed, inputs=[randomize_seed, seed], outputs=[seed]).then(
223
- image_to_3d,
224
- inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
225
- outputs=[output_buf, video_output],
226
- ).then(
227
- lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
228
- outputs=[extract_glb_btn, extract_gs_btn],
229
- )
230
-
231
- extract_glb_btn.click(
232
- extract_glb,
233
- inputs=[output_buf, mesh_simplify, texture_size],
234
- outputs=[model_output, download_glb],
235
- )
236
-
237
- extract_gs_btn.click(
238
- extract_gaussian,
239
- inputs=[output_buf],
240
- outputs=[model_output, download_gs],
241
- )
242
-
243
- # Launch the Gradio demo for Hugging Face Spaces
244
- demo.launch()
 
1
  import gradio as gr
 
 
 
 
2
  import torch
3
  import numpy as np
4
+ import os
5
+ import shutil
6
  import imageio
 
7
  from PIL import Image
 
 
 
8
 
9
+ # Ensure imports are available
10
# Optional heavy dependencies. The app is allowed to start without them so the
# UI can still come up and report the problem; every name imported here gets a
# None fallback so later references fail with a clear TypeError/None check
# instead of a NameError on an unbound module-level name.
try:
    from trellis.pipelines import TrellisImageTo3DPipeline
    from trellis.representations import Gaussian, MeshExtractResult
    from trellis.utils import render_utils, postprocessing_utils
    from easydict import EasyDict as edict
except ImportError as e:
    print(f"Error importing required libraries: {e}")
    print("Please install the following libraries:")
    print("- trellis-ai")
    print("- easydict")
    TrellisImageTo3DPipeline = None
    Gaussian = None
    MeshExtractResult = None
    render_utils = None
    postprocessing_utils = None
    edict = None
21
 
22
  # Constants
23
  MAX_SEED = np.iinfo(np.int32).max
24
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
25
  os.makedirs(TMP_DIR, exist_ok=True)
26
 
27
class ImageTo3DConverter:
    """Failure-tolerant wrapper around ``TrellisImageTo3DPipeline``.

    Public methods never propagate pipeline errors to the caller: they print
    a diagnostic and return ``None`` (or ``(None, None)``) so the UI layer can
    simply show "no result".
    """

    # Name of the sub-directory of TMP_DIR that receives generated files.
    # NOTE(review): this directory is shared by all callers, so concurrent
    # requests overwrite each other's sample.mp4 / sample.glb.
    _SESSION_DIR_NAME = 'temp_session'

    def __init__(self):
        """Create the pipeline; on any failure leave ``self.pipeline`` as None."""
        try:
            self.pipeline = TrellisImageTo3DPipeline()
        except Exception as e:
            # Also covers the case where the import failed and the name is None.
            print(f"Failed to initialize pipeline: {e}")
            self.pipeline = None

    def _output_dir(self):
        """Return the output directory path, creating it if necessary."""
        path = os.path.join(TMP_DIR, self._SESSION_DIR_NAME)
        os.makedirs(path, exist_ok=True)
        return path

    def validate_input(self, image, is_multiimage):
        """Validate the input and normalise it for the pipeline.

        Args:
            image: A single image, or (when ``is_multiimage`` is true) a
                sequence of images. Sequence items may be bare images or
                ``(image, caption)`` pairs given as a list or tuple.
            is_multiimage: Whether ``image`` is a sequence of images.

        Returns:
            The image unchanged in single-image mode, otherwise a list with
            the first element of every list/tuple item (other items are kept
            as they are).

        Raises:
            ValueError: If the pipeline is not initialised or no image was
                provided.
        """
        if self.pipeline is None:
            raise ValueError("Pipeline not initialized. Check library installation.")

        if not is_multiimage:
            if image is None:
                raise ValueError("No image provided")
            return image

        if not image:
            raise ValueError("No images provided for multi-image processing")
        # Gallery-style inputs wrap each image in an (image, caption) pair;
        # accept both tuples and lists so the pair never reaches the pipeline.
        return [
            item[0] if isinstance(item, (list, tuple)) else item
            for item in image
        ]

    def preprocess_image(self, image):
        """Run the pipeline's image preprocessing, falling back to the input.

        Returns the preprocessed image, or ``image`` unchanged when the
        pipeline is unavailable or preprocessing raises.
        """
        if self.pipeline is None:
            print("Image preprocessing error: pipeline not initialized")
            return image
        try:
            return self.pipeline.preprocess_image(image)
        except Exception as e:
            print(f"Image preprocessing error: {e}")
            return image

    @staticmethod
    def _pack_state(gaussian, mesh):
        """Copy the Gaussian and mesh tensors to NumPy arrays in a plain dict."""
        return {
            'gaussian': {
                **gaussian.init_params,
                '_xyz': gaussian._xyz.cpu().numpy(),
                '_features_dc': gaussian._features_dc.cpu().numpy(),
                '_scaling': gaussian._scaling.cpu().numpy(),
                '_rotation': gaussian._rotation.cpu().numpy(),
                '_opacity': gaussian._opacity.cpu().numpy(),
            },
            'mesh': {
                'vertices': mesh.vertices.cpu().numpy(),
                'faces': mesh.faces.cpu().numpy(),
            },
        }

    @staticmethod
    def _unpack_state(state):
        """Rebuild a CUDA-resident ``(Gaussian, mesh)`` pair from a packed state."""
        params = state['gaussian']
        gs = Gaussian(
            aabb=params['aabb'],
            sh_degree=params['sh_degree'],
            mininum_kernel_size=params['mininum_kernel_size'],
            scaling_bias=params.get('scaling_bias', 0.1),
            opacity_bias=params.get('opacity_bias', 0.1),
            scaling_activation=params.get('scaling_activation', 'softplus'),
        )
        for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
            setattr(gs, attr, torch.tensor(params[attr], device='cuda'))

        mesh = edict(
            vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
            faces=torch.tensor(state['mesh']['faces'], device='cuda'),
        )
        return gs, mesh

    def process_image(self,
                      image,
                      multiimages,
                      is_multiimage,
                      seed,
                      ss_guidance_strength,
                      ss_sampling_steps,
                      slat_guidance_strength,
                      slat_sampling_steps,
                      multiimage_algo):
        """Convert one image (or several) into a 3D asset and a preview video.

        Args:
            image: Input used in single-image mode.
            multiimages: Inputs used in multi-image mode.
            is_multiimage: Selects which of the two inputs is used.
            seed: Seed forwarded to the pipeline.
            ss_guidance_strength: ``cfg_strength`` of the sparse-structure sampler.
            ss_sampling_steps: ``steps`` of the sparse-structure sampler.
            slat_guidance_strength: ``cfg_strength`` of the SLAT sampler.
            slat_sampling_steps: ``steps`` of the SLAT sampler.
            multiimage_algo: ``mode`` passed to ``run_multi_image``.

        Returns:
            ``(state, video_path)`` on success, where ``state`` is the packed
            dict accepted by :meth:`extract_glb`; ``(None, None)`` if
            validation, generation or rendering fails.
        """
        try:
            processed_input = self.validate_input(
                multiimages if is_multiimage else image, is_multiimage
            )
        except ValueError as e:
            print(f"Input validation error: {e}")
            return None, None

        # Both pipeline entry points take the same sampler configuration.
        run_kwargs = {
            'seed': seed,
            'formats': ["gaussian", "mesh"],
            'preprocess_image': False,
            'sparse_structure_sampler_params': {
                "steps": ss_sampling_steps,
                "cfg_strength": ss_guidance_strength,
            },
            'slat_sampler_params': {
                "steps": slat_sampling_steps,
                "cfg_strength": slat_guidance_strength,
            },
        }
        try:
            if is_multiimage:
                outputs = self.pipeline.run_multi_image(
                    processed_input, mode=multiimage_algo, **run_kwargs
                )
            else:
                outputs = self.pipeline.run(processed_input, **run_kwargs)
        except Exception as e:
            print(f"3D conversion error: {e}")
            return None, None

        try:
            gaussian = outputs['gaussian'][0]
            mesh = outputs['mesh'][0]

            # Side-by-side preview: colour render on the left, normals on the right.
            color_frames = render_utils.render_video(gaussian, num_frames=120)['color']
            normal_frames = render_utils.render_video(mesh, num_frames=120)['normal']
            frames = [
                np.concatenate([color, normal], axis=1)
                for color, normal in zip(color_frames, normal_frames)
            ]

            video_path = os.path.join(self._output_dir(), 'sample.mp4')
            imageio.mimsave(video_path, frames, fps=15)

            return self._pack_state(gaussian, mesh), video_path
        except Exception as e:
            print(f"Video generation error: {e}")
            return None, None
        finally:
            # Hand cached GPU blocks back once the request is finished.
            torch.cuda.empty_cache()

    def extract_glb(self, state, mesh_simplify=0.95, texture_size=1024):
        """Export a packed state to ``sample.glb`` in the output directory.

        Args:
            state: Dict returned as the first element of :meth:`process_image`.
            mesh_simplify: Forwarded as ``simplify`` to the remeshing and GLB steps.
            texture_size: Forwarded as ``texture_size`` to the GLB step.

        Returns:
            Path of the written GLB file, or ``None`` if any step fails.
        """
        try:
            gs, mesh = self._unpack_state(state)

            mesh.vertices, mesh.faces = postprocessing_utils.remesh_to_quads(
                vertices=mesh.vertices.cpu().numpy(),
                faces=mesh.faces.cpu().numpy(),
                simplify=mesh_simplify
            )

            glb = postprocessing_utils.to_glb(
                gs, mesh,
                simplify=mesh_simplify,
                texture_size=texture_size,
                verbose=False,
            )

            glb_path = os.path.join(self._output_dir(), 'sample.glb')
            glb.export(glb_path)
            return glb_path
        except Exception as e:
            print(f"GLB extraction error: {e}")
            return None
        finally:
            torch.cuda.empty_cache()
195
+
196
+ # Gradio Interface Setup
197
def create_gradio_interface():
    """Build the Gradio UI for image-to-3D generation and GLB export.

    The generated asset is kept in a ``gr.State`` between the "generate" and
    "extract" clicks, so the extraction handler receives the result of the
    most recent successful generation (or ``None`` if there is none yet).

    Returns:
        The constructed ``gr.Blocks`` app; the caller is responsible for
        launching it.
    """
    converter = ImageTo3DConverter()

    def preprocess_gallery(items):
        """Preprocess every gallery entry; entries may be (image, caption) pairs."""
        if not items:
            return items
        images = [
            item[0] if isinstance(item, (list, tuple)) else item
            for item in items
        ]
        return [converter.preprocess_image(img) for img in images]

    def generate_3d(image, multi_image, seed, randomize, ss_guidance, ss_steps,
                    slat_guidance, slat_steps, multi_algo):
        """Run the converter and return ``(state, video_path)``."""
        # A non-empty gallery switches the request to multi-image mode.
        is_multi = isinstance(multi_image, list) and len(multi_image) > 0

        # `randomize` is the checkbox *value*; testing the component object
        # itself would always be truthy.
        if randomize:
            seed = np.random.randint(0, MAX_SEED)

        # Sliders may deliver floats; the seed and step counts must be ints.
        return converter.process_image(
            image, multi_image, is_multi, int(seed),
            ss_guidance, int(ss_steps),
            slat_guidance, int(slat_steps),
            multi_algo,
        )

    def extract_glb(state, simplify, texture_size):
        """Export the stored state to a GLB file, or return None without one."""
        if state is None:
            return None
        return converter.extract_glb(state, simplify, int(texture_size))

    with gr.Blocks() as demo:
        with gr.Row():
            # Input components
            with gr.Column():
                with gr.Tabs() as input_tabs:
                    with gr.Tab("Single Image"):
                        # The pipeline is fed the image object directly, so
                        # request PIL images and keep the alpha channel.
                        single_image = gr.Image(
                            label="Single Image Input",
                            format="png",
                            image_mode="RGBA",
                            type="pil",
                        )
                    with gr.Tab("Multiple Images"):
                        multi_images = gr.Gallery(
                            label="Multiple Image Input",
                            format="png",
                            type="pil",
                        )

                # Generation settings
                with gr.Accordion("Generation Settings"):
                    seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                    randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)

                    with gr.Row():
                        ss_guidance = gr.Slider(0, 10, label="Sparse Guidance Strength", value=7.5, step=0.1)
                        ss_steps = gr.Slider(1, 50, label="Sparse Sampling Steps", value=12, step=1)

                    with gr.Row():
                        slat_guidance = gr.Slider(0, 10, label="Latent Guidance Strength", value=3.0, step=0.1)
                        slat_steps = gr.Slider(1, 50, label="Latent Sampling Steps", value=12, step=1)

                    multi_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")

                # Buttons
                generate_btn = gr.Button("Generate 3D Model")

                # GLB Extraction
                with gr.Accordion("GLB Extraction"):
                    mesh_simplify = gr.Slider(0.9, 0.98, label="Mesh Simplify", value=0.95, step=0.01)
                    texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
                    extract_glb_btn = gr.Button("Extract GLB")

            # Output components
            with gr.Column():
                video_output = gr.Video(label="Generated 3D Asset Preview")
                glb_output = gr.File(label="Extracted GLB")

        # Holds the packed generation result between the two button clicks.
        state = gr.State()

        # The converter runs the pipeline with preprocess_image=False, so
        # uploads are preprocessed here, once, when they arrive.
        single_image.upload(
            converter.preprocess_image,
            inputs=[single_image],
            outputs=[single_image],
        )
        multi_images.upload(
            preprocess_gallery,
            inputs=[multi_images],
            outputs=[multi_images],
        )

        generate_btn.click(
            generate_3d,
            inputs=[single_image, multi_images, seed, randomize_seed,
                    ss_guidance, ss_steps, slat_guidance, slat_steps, multi_algo],
            outputs=[state, video_output],
        )

        extract_glb_btn.click(
            extract_glb,
            inputs=[state, mesh_simplify, texture_size],
            outputs=[glb_output],
        )

    return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
 
281
+ # Launch the interface
282
# Script entry point: build the UI and serve it only when this file is run
# directly, not when it is imported.
if __name__ == "__main__":
    create_gradio_interface().launch()