chateauxai committed on
Commit
b138d19
·
verified ·
1 Parent(s): 250f2e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -145
app.py CHANGED
@@ -1,10 +1,8 @@
1
  import gradio as gr
2
  import spaces
3
  from gradio_litmodel3d import LitModel3D
4
-
5
  import os
6
  import shutil
7
- import trimesh # New import
8
  os.environ['SPCONV_ALGO'] = 'native'
9
  from typing import *
10
  import torch
@@ -15,164 +13,161 @@ from PIL import Image
15
  from trellis.pipelines import TrellisImageTo3DPipeline
16
  from trellis.representations import Gaussian, MeshExtractResult
17
  from trellis.utils import render_utils, postprocessing_utils
18
- from scipy.spatial import ConvexHull # New import
19
-
20
- # [Previous imports and constants remain the same...]
21
-
22
- def optimize_building_mesh(mesh, angle_threshold=15, planar_threshold=0.02):
23
- """
24
- Optimize a building mesh by preserving architectural features while reducing complexity.
25
- """
26
- # Convert vertices to numpy array for processing
27
- vertices = np.array(mesh.vertices)
28
- faces = np.array(mesh.faces)
29
-
30
- # 1. Detect planar surfaces
31
- normals = mesh.face_normals
32
- planar_groups = []
33
- processed = set()
34
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  for i in range(len(faces)):
36
- if i in processed:
37
  continue
38
-
39
- # Find connected faces with similar normals
40
- similar_faces = {i}
41
- stack = [i]
42
- while stack:
43
- current = stack.pop()
44
- neighbors = mesh.face_adjacency[mesh.face_adjacency[:,0] == current][:,1]
45
- for n in neighbors:
46
- if n not in processed:
47
- angle = np.arccos(np.dot(normals[current], normals[n])) * 180 / np.pi
48
- if angle < angle_threshold:
49
- similar_faces.add(n)
50
- stack.append(n)
51
- processed.add(n)
52
-
53
- if len(similar_faces) > 0:
54
- planar_groups.append(list(similar_faces))
55
-
56
- # 2. Simplify each planar group while preserving edges
57
- new_vertices = []
58
  new_faces = []
59
- vertex_map = {}
60
-
61
- for group in planar_groups:
62
- # Get vertices for this group
63
- group_faces = faces[group]
64
- group_verts = vertices[np.unique(group_faces)]
65
-
66
- # Find best fitting plane
67
- centroid = np.mean(group_verts, axis=0)
68
- _, _, vh = np.linalg.svd(group_verts - centroid)
69
- normal = vh[2]
70
-
71
- # Project vertices to plane and simplify
72
- projected = group_verts - np.dot(group_verts - centroid, normal)[:, np.newaxis] * normal
73
-
74
- # Create simplified convex hull for this section
75
- hull = ConvexHull(projected[:,:2])
76
- hull_vertices = projected[hull.vertices]
77
-
78
- # Add to new mesh
79
- start_idx = len(new_vertices)
80
- new_vertices.extend(hull_vertices)
81
-
82
- # Triangulate the hull
83
- for i in range(1, len(hull_vertices) - 1):
84
- new_faces.append([start_idx, start_idx + i, start_idx + i + 1])
85
-
86
- # 3. Create new optimized mesh
87
- optimized_mesh = trimesh.Trimesh(
88
- vertices=np.array(new_vertices),
89
- faces=np.array(new_faces)
90
  )
91
-
92
- return optimized_mesh
93
 
94
- # Modify the existing extract_glb function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  @spaces.GPU(duration=90)
96
- def extract_glb(
97
- state: dict,
98
- mesh_simplify: float,
99
- texture_size: int,
100
- is_building: bool, # New parameter
101
- angle_threshold: float, # New parameter
102
- planar_threshold: float, # New parameter
103
- req: gr.Request,
104
- ) -> Tuple[str, str]:
105
- """
106
- Extract a GLB file from the 3D model with optional building optimization.
107
- """
108
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
109
  gs, mesh = unpack_state(state)
110
-
111
- if is_building:
112
- # Convert to trimesh for optimization
113
- trimesh_mesh = trimesh.Trimesh(
114
- vertices=mesh.vertices.cpu().numpy(),
115
- faces=mesh.faces.cpu().numpy()
116
- )
117
-
118
- # Apply building-specific optimization
119
- optimized_mesh = optimize_building_mesh(
120
- trimesh_mesh,
121
- angle_threshold=angle_threshold,
122
- planar_threshold=planar_threshold
123
- )
124
-
125
- # Convert back to original format
126
- mesh.vertices = torch.tensor(optimized_mesh.vertices, device='cuda')
127
- mesh.faces = torch.tensor(optimized_mesh.faces, device='cuda')
128
-
129
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
130
  glb_path = os.path.join(user_dir, 'sample.glb')
131
  glb.export(glb_path)
132
  torch.cuda.empty_cache()
133
  return glb_path, glb_path
134
 
135
- # Modify the main UI code section
136
- with gr.Blocks(delete_cache=(600, 600)) as demo:
137
- # [Previous UI code remains the same until GLB Extraction Settings...]
138
-
139
- with gr.Accordion(label="GLB Extraction Settings", open=False):
140
- mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
141
- texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
142
- # Add new building optimization controls
143
- with gr.Row():
144
- is_building = gr.Checkbox(label="Enable Building Optimization", value=False)
145
- with gr.Column(visible=False) as building_settings:
146
- angle_threshold = gr.Slider(5, 45, label="Edge Angle Threshold", value=15, step=1)
147
- planar_threshold = gr.Slider(0.01, 0.1, label="Planar Surface Threshold", value=0.02, step=0.01)
148
-
149
- # [Rest of the UI code remains the same until the event handlers...]
150
-
151
- # Add visibility toggle for building settings
152
- is_building.change(
153
- lambda x: gr.Column.update(visible=x),
154
- inputs=[is_building],
155
- outputs=[building_settings]
156
- )
157
-
158
- # Modify the extract_glb button click handler
159
- extract_glb_btn.click(
160
- extract_glb,
161
- inputs=[output_buf, mesh_simplify, texture_size, is_building, angle_threshold, planar_threshold],
162
- outputs=[model_output, download_glb],
163
- ).then(
164
- lambda: gr.Button(interactive=True),
165
- outputs=[download_glb],
166
- )
167
-
168
- # [Rest of the code remains the same...]
169
-
170
  # Launch the Gradio app
171
  if __name__ == "__main__":
172
  pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
173
  pipeline.cuda()
174
- try:
175
- pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
176
- except:
177
- pass
178
- demo.launch()
 
1
  import gradio as gr
2
  import spaces
3
  from gradio_litmodel3d import LitModel3D
 
4
  import os
5
  import shutil
 
6
  os.environ['SPCONV_ALGO'] = 'native'
7
  from typing import *
8
  import torch
 
13
  from trellis.pipelines import TrellisImageTo3DPipeline
14
  from trellis.representations import Gaussian, MeshExtractResult
15
  from trellis.utils import render_utils, postprocessing_utils
16
+ from scipy.spatial.transform import Rotation
17
+
18
+ # Constants
19
+ MAX_SEED = np.iinfo(np.int32).max
20
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
21
+ os.makedirs(TMP_DIR, exist_ok=True)
22
+
23
+ # Session Management
24
def start_session(req: gr.Request):
    """Create the scratch directory for the requesting session.

    The directory lives under ``TMP_DIR`` and is named after the request's
    ``session_hash``. An already existing directory is left as it is.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)
27
+
28
def end_session(req: gr.Request):
    """Delete the scratch directory of the requesting session.

    The directory lives under ``TMP_DIR`` and is named after the request's
    ``session_hash``. A directory that does not exist (never created, or
    already removed) is not an error; any other failure still propagates.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    try:
        shutil.rmtree(user_dir)
    except FileNotFoundError:
        # Nothing to clean up: the session never produced a directory, or it
        # has already been removed.
        pass
31
+
32
+ # Utility Functions
33
def preprocess_image(image: Image.Image) -> Image.Image:
    """Pass a single image through the global pipeline's ``preprocess_image``."""
    processed = pipeline.preprocess_image(image)
    return processed
35
+
36
def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
    """Preprocess the image part of every ``(image, str)`` tuple.

    Only the first element of each tuple is used; the second element is
    ignored. Results are returned in input order.
    """
    # Pull out all images first, then preprocess them one by one.
    pictures = [pair[0] for pair in images]
    processed = []
    for picture in pictures:
        processed.append(pipeline.preprocess_image(picture))
    return processed
39
+
40
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
    """Snapshot a Gaussian and a mesh into a nested dict of NumPy arrays.

    Returns a dict with two entries:

    * ``'gaussian'``: everything in ``gs.init_params`` plus the tensors
      ``_xyz``, ``_features_dc``, ``_scaling``, ``_rotation`` and
      ``_opacity``, each moved to the CPU and converted to NumPy.
    * ``'mesh'``: the mesh ``vertices`` and ``faces``, converted the same way.
    """
    gaussian_state = dict(gs.init_params)
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        gaussian_state[attr] = getattr(gs, attr).cpu().numpy()

    mesh_state = {
        'vertices': mesh.vertices.cpu().numpy(),
        'faces': mesh.faces.cpu().numpy(),
    }
    return {'gaussian': gaussian_state, 'mesh': mesh_state}
55
+
56
def unpack_state(state: dict) -> Tuple[Gaussian, MeshExtractResult]:
    """Rebuild the Gaussian and mesh objects from a dict made by ``pack_state``.

    The Gaussian is constructed from its stored init parameters, then its
    tensors are restored on the ``'cuda'`` device. The mesh is rebuilt from
    the stored ``vertices`` and ``faces`` arrays, also on ``'cuda'``.
    """
    gaussian_state = state['gaussian']
    init_keys = (
        'aabb',
        'sh_degree',
        'mininum_kernel_size',
        'scaling_bias',
        'opacity_bias',
        'scaling_activation',
    )
    gs = Gaussian(**{key: gaussian_state[key] for key in init_keys})

    # Restore the tensors in the same order they were packed.
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        setattr(gs, attr, torch.tensor(gaussian_state[attr], device='cuda'))

    mesh_state = state['mesh']
    mesh = MeshExtractResult(
        vertices=torch.tensor(mesh_state['vertices'], device='cuda'),
        faces=torch.tensor(mesh_state['faces'], device='cuda'),
    )
    return gs, mesh
77
+
78
def _unit_face_normals(vertices, faces):
    """Return one unit normal per triangle; degenerate triangles get a zero vector."""
    v0 = vertices[faces[:, 0]]
    v1 = vertices[faces[:, 1]]
    v2 = vertices[faces[:, 2]]
    normals = np.cross(v1 - v0, v2 - v0)
    lengths = np.linalg.norm(normals, axis=1, keepdims=True)
    # Avoid division by zero for zero-area triangles.
    lengths[lengths == 0] = 1e-10
    return normals / lengths


def _coplanar_face_order(vertices, faces, normal_threshold):
    """Return face indices ordered so that each seed face is followed by its group."""
    face_count = len(faces)
    if face_count == 0:
        return np.zeros(0, dtype=np.intp)

    normals = _unit_face_normals(vertices, faces)
    unused = np.ones(face_count, dtype=bool)
    order = []

    for seed in range(face_count):
        if not unused[seed]:
            continue
        unused[seed] = False
        order.append(seed)

        # Only later, still-unassigned faces can join this seed's group.
        candidates = seed + 1 + np.flatnonzero(unused[seed + 1:])
        if candidates.size == 0:
            continue

        # abs() treats opposite-facing normals as parallel too.
        aligned = np.abs(normals[candidates] @ normals[seed]) > normal_threshold
        candidates = candidates[aligned]
        if candidates.size == 0:
            continue

        # Count how many distinct vertices of the seed occur in each candidate;
        # two or more means the faces share an edge.
        shared = np.zeros(candidates.size, dtype=np.intp)
        for vertex in np.unique(faces[seed]):
            shared += (faces[candidates] == vertex).any(axis=1)

        members = candidates[shared >= 2]
        unused[members] = False
        order.extend(members.tolist())

    return np.asarray(order, dtype=np.intp)


def convert_to_poly_surfaces(mesh: MeshExtractResult, normal_threshold: float = 0.95) -> MeshExtractResult:
    """Reorder the mesh faces so that near-coplanar neighbouring faces are contiguous.

    Faces are visited in index order. Each face that is not yet assigned
    becomes the seed of a new group. Every later unassigned face joins that
    group when both conditions hold against the *seed* (grouping is not
    transitive):

    * the absolute dot product of the two unit normals exceeds
      ``normal_threshold``;
    * the two faces share at least two vertex indices.

    NOTE: no faces are merged, added or removed. The result contains exactly
    the input triangles, permuted group by group (seed first, then members
    in ascending index order). Vertices are passed through unchanged.

    Args:
        mesh: Mesh whose ``vertices`` and ``faces`` are tensors; ``faces``
            is indexed as triangles (columns 0, 1, 2).
        normal_threshold: Minimum ``|n_seed . n_other|`` for two faces to be
            considered parallel.

    Returns:
        A new ``MeshExtractResult`` holding the original ``vertices`` tensor
        and the reordered faces on the same device as ``mesh.faces``. An
        input without faces yields an empty face tensor of the same dtype.

    The seed loop is still quadratic in the number of faces in the worst
    case; only the per-seed comparison is vectorized.
    """
    vertices = mesh.vertices.cpu().numpy()
    faces = mesh.faces.cpu().numpy()

    order = _coplanar_face_order(vertices, faces, normal_threshold)
    return MeshExtractResult(
        vertices=mesh.vertices,
        faces=torch.tensor(faces[order], device=mesh.faces.device),
    )
 
 
133
 
134
+ # Main Functions
135
@spaces.GPU
def image_to_3d(image: Image.Image, multiimages: List[Tuple[Image.Image, str]], is_multiimage: bool, seed: int,
                ss_guidance_strength: float, ss_sampling_steps: int, slat_guidance_strength: float,
                slat_sampling_steps: int, multiimage_algo: Literal["multidiffusion", "stochastic"], req: gr.Request) -> Tuple[dict, str]:
    """Generate a 3D asset from one or several images and render a preview video.

    Args:
        image: Input image, used when ``is_multiimage`` is false.
        multiimages: ``(image, str)`` tuples, used when ``is_multiimage`` is
            true; only the first element of each tuple is passed on.
        is_multiimage: Selects the multi-image pipeline entry point.
        seed: Seed forwarded to the pipeline.
        ss_guidance_strength: ``cfg_strength`` of the sparse-structure sampler.
        ss_sampling_steps: ``steps`` of the sparse-structure sampler.
        slat_guidance_strength: ``cfg_strength`` of the SLAT sampler.
        slat_sampling_steps: ``steps`` of the SLAT sampler.
        multiimage_algo: Forwarded as ``mode`` to the multi-image pipeline.
        req: Gradio request; its ``session_hash`` names the output directory.

    Returns:
        ``(state, video_path)``: the packed Gaussian/mesh state of the first
        output and the path of the rendered ``sample.mp4``.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # Make sure the output directory exists before the expensive generation,
    # even if no session-start hook created it.
    os.makedirs(user_dir, exist_ok=True)

    run_kwargs = {
        'seed': seed,
        'formats': ["gaussian", "mesh"],
        'preprocess_image': False,
        'sparse_structure_sampler_params': {"steps": ss_sampling_steps, "cfg_strength": ss_guidance_strength},
        'slat_sampler_params': {"steps": slat_sampling_steps, "cfg_strength": slat_guidance_strength},
    }
    if is_multiimage:
        # Previously these inputs were accepted but ignored, so multi-image
        # requests silently ran the single-image path.
        outputs = pipeline.run_multi_image(
            [entry[0] for entry in multiimages],
            mode=multiimage_algo,
            **run_kwargs,
        )
    else:
        outputs = pipeline.run(image, **run_kwargs)

    video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
    video_path = os.path.join(user_dir, 'sample.mp4')
    imageio.mimsave(video_path, video, fps=15)
    state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
    torch.cuda.empty_cache()
    return state, video_path
150
+
151
@spaces.GPU(duration=90)
def extract_glb(state: dict, mesh_simplify: float, texture_size: int, use_poly_surfaces: bool,
                normal_threshold: float, req: gr.Request) -> Tuple[str, str]:
    """Export the generated model of a session as ``sample.glb``.

    Args:
        state: Packed Gaussian/mesh state as produced by ``pack_state``.
        mesh_simplify: Forwarded as ``simplify`` to ``postprocessing_utils.to_glb``.
        texture_size: Forwarded as ``texture_size`` to ``to_glb``.
        use_poly_surfaces: When true, the mesh is first passed through
            ``convert_to_poly_surfaces``.
        normal_threshold: Threshold handed to ``convert_to_poly_surfaces``.
        req: Gradio request; its ``session_hash`` names the output directory.

    Returns:
        The GLB path twice (one value per output component).
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # Ensure the target directory exists up front so the export cannot fail
    # on a missing directory after the GLB has already been baked.
    os.makedirs(user_dir, exist_ok=True)

    gs, mesh = unpack_state(state)

    if use_poly_surfaces:
        mesh = convert_to_poly_surfaces(mesh, normal_threshold)

    glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
    glb_path = os.path.join(user_dir, 'sample.glb')
    glb.export(glb_path)
    torch.cuda.empty_cache()
    return glb_path, glb_path
165
 
166
+ # Gradio Interface
167
+ demo = gr.Blocks()
168
+ # Add UI elements similar to the original code
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  # Launch the Gradio app
170
  if __name__ == "__main__":
171
  pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
172
  pipeline.cuda()
173
+ demo.launch()