Stable-X commited on
Commit
0332bda
1 Parent(s): 82b898c

feat: Add gs_utils for gs export

Browse files
Files changed (3) hide show
  1. app.py +20 -41
  2. gs_utils.py +106 -0
  3. requirements.txt +2 -1
app.py CHANGED
@@ -17,7 +17,7 @@ from torchvision import transforms
17
  from PIL import Image
18
  import open3d as o3d
19
  from backend_utils import improved_multiway_registration, pts2normal, point2mesh, combine_and_clean_point_clouds
20
-
21
 
22
  # Default values
23
  DEFAULT_CKPT_PATH = './checkpoints/spann3r.pth'
@@ -29,15 +29,8 @@ OPENGL = np.array([[1, 0, 0, 0],
29
  [0, 0, -1, 0],
30
  [0, 0, 0, 1]])
31
 
32
- def export_geometry(geometry, as_pointcloud=False):
33
- if as_pointcloud:
34
- if not isinstance(geometry, o3d.geometry.PointCloud):
35
- raise ValueError("Expected an Open3D PointCloud object when as_pointcloud is True")
36
- output_path = tempfile.mktemp(suffix='.ply')
37
- else:
38
- if not isinstance(geometry, o3d.geometry.TriangleMesh):
39
- raise ValueError("Expected an Open3D TriangleMesh object when as_pointcloud is False")
40
- output_path = tempfile.mktemp(suffix='.obj')
41
 
42
  # Apply rotation
43
  rot = np.eye(4)
@@ -45,11 +38,7 @@ def export_geometry(geometry, as_pointcloud=False):
45
  transform = np.linalg.inv(OPENGL @ rot)
46
  geometry.transform(transform)
47
 
48
- # Export the geometry
49
- if as_pointcloud:
50
- o3d.io.write_point_cloud(output_path, geometry, write_ascii=False, compressed=True)
51
- else:
52
- o3d.io.write_triangle_mesh(output_path, geometry, write_ascii=False, compressed=True)
53
 
54
  return output_path
55
 
@@ -176,7 +165,7 @@ def generate_mask(image: np.ndarray):
176
  return mask_np
177
  @torch.no_grad()
178
  def reconstruct(video_path, conf_thresh, kf_every,
179
- as_pointcloud=False, remove_background=False, refine=False):
180
  # Extract frames from video
181
  demo_path = extract_frames(video_path)
182
 
@@ -220,31 +209,21 @@ def reconstruct(video_path, conf_thresh, kf_every,
220
  pcds.append(pcd)
221
 
222
  pcd_combined = combine_and_clean_point_clouds(pcds, voxel_size=0.001)
223
-
224
- if as_pointcloud:
225
- o3d_geometry = pcd_combined
226
- else:
227
- o3d_geometry = point2mesh(pcd_combined)
228
 
229
  # Create coarse result
230
- coarse_output_path = export_geometry(o3d_geometry, as_pointcloud)
231
 
232
  yield coarse_output_path, None
233
 
234
- if refine:
235
- # Perform global optimization
236
- print("Performing global registration...")
237
- transformed_pcds, _, _ = improved_multiway_registration(pcds, voxel_size=0.001)
238
-
239
- if as_pointcloud:
240
- o3d_geometry = transformed_pcds
241
- else:
242
- o3d_geometry = point2mesh(transformed_pcds)
243
-
244
- # Create coarse result
245
- refined_output_path = export_geometry(o3d_geometry, as_pointcloud)
246
-
247
- yield coarse_output_path, refined_output_path
248
 
249
  # Clean up temporary directory
250
  os.system(f"rm -rf {demo_path}")
@@ -320,19 +299,19 @@ with gr.Blocks(
320
  kf_every = gr.Slider(1, 30, step=1, value=1, label="Keyframe Interval")
321
  with gr.Row():
322
  remove_background = gr.Checkbox(label="Remove Background", value=False)
323
- refine = gr.Checkbox(label="Enable Backend", value=False)
324
- as_pointcloud = gr.Checkbox(label="As Pointcloud", value=False)
325
  reconstruct_btn = gr.Button("Reconstruct")
326
 
327
  with gr.Column(scale=2):
328
  with gr.Tab("Coarse Model"):
329
- coarse_model = gr.Model3D(label="Coarse 3D Model", display_mode="solid", clear_color=[0.0, 0.0, 0.0, 0.0])
 
330
  with gr.Tab("Refined Model"):
331
- refined_model = gr.Model3D(label="Refined 3D Model", display_mode="solid", clear_color=[0.0, 0.0, 0.0, 0.0])
 
332
 
333
  reconstruct_btn.click(
334
  fn=reconstruct,
335
- inputs=[video_input, conf_thresh, kf_every, as_pointcloud, remove_background, refine],
336
  outputs=[coarse_model, refined_model]
337
  )
338
 
 
17
  from PIL import Image
18
  import open3d as o3d
19
  from backend_utils import improved_multiway_registration, pts2normal, point2mesh, combine_and_clean_point_clouds
20
+ from gs_utils import point2gs
21
 
22
  # Default values
23
  DEFAULT_CKPT_PATH = './checkpoints/spann3r.pth'
 
29
  [0, 0, -1, 0],
30
  [0, 0, 0, 1]])
31
 
32
+ def export_geometry(geometry):
33
+ output_path = tempfile.mktemp(suffix='.obj')
 
 
 
 
 
 
 
34
 
35
  # Apply rotation
36
  rot = np.eye(4)
 
38
  transform = np.linalg.inv(OPENGL @ rot)
39
  geometry.transform(transform)
40
 
41
+ o3d.io.write_triangle_mesh(output_path, geometry, write_ascii=False, compressed=True)
 
 
 
 
42
 
43
  return output_path
44
 
 
165
  return mask_np
166
  @torch.no_grad()
167
  def reconstruct(video_path, conf_thresh, kf_every,
168
+ remove_background=False):
169
  # Extract frames from video
170
  demo_path = extract_frames(video_path)
171
 
 
209
  pcds.append(pcd)
210
 
211
  pcd_combined = combine_and_clean_point_clouds(pcds, voxel_size=0.001)
212
+ o3d_geometry = point2mesh(pcd_combined)
 
 
 
 
213
 
214
  # Create coarse result
215
+ coarse_output_path = export_geometry(o3d_geometry)
216
 
217
  yield coarse_output_path, None
218
 
219
+ # Perform global optimization
220
+ print("Performing global registration...")
221
+ transformed_pcds, _, _ = improved_multiway_registration(pcds, voxel_size=0.01)
222
+
223
+ # Create coarse result
224
+ refined_output_path = tempfile.mktemp(suffix='.ply')
225
+ point2gs(refined_output_path, transformed_pcds)
226
+ yield coarse_output_path, refined_output_path
 
 
 
 
 
 
227
 
228
  # Clean up temporary directory
229
  os.system(f"rm -rf {demo_path}")
 
299
  kf_every = gr.Slider(1, 30, step=1, value=1, label="Keyframe Interval")
300
  with gr.Row():
301
  remove_background = gr.Checkbox(label="Remove Background", value=False)
 
 
302
  reconstruct_btn = gr.Button("Reconstruct")
303
 
304
  with gr.Column(scale=2):
305
  with gr.Tab("Coarse Model"):
306
+ coarse_model = gr.Model3D(label="Coarse 3D Model", display_mode="solid",
307
+ clear_color=[0.0, 0.0, 0.0, 0.0])
308
  with gr.Tab("Refined Model"):
309
+ refined_model = gr.Model3D(label="Refined Gaussian Splatting", display_mode="solid",
310
+ clear_color=[0.0, 0.0, 0.0, 0.0])
311
 
312
  reconstruct_btn.click(
313
  fn=reconstruct,
314
+ inputs=[video_input, conf_thresh, kf_every, remove_background],
315
  outputs=[coarse_model, refined_model]
316
  )
317
 
gs_utils.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+ from plyfile import PlyElement, PlyData
4
+ import open3d as o3d
5
+
6
+ def get_f_dc(colors):
7
+ return RGB2SH(colors)[:, :, np.newaxis]
8
+
9
+ def get_f_rest(points, max_sh_degree=3):
10
+ f_rest_shape = (points.shape[0], (max_sh_degree + 1) ** 2 - 1, 3)
11
+ return np.zeros(f_rest_shape)
12
+
13
+ def get_opacity(points):
14
+ return inverse_sigmoid(0.5 * np.ones((points.shape[0], 1)))
15
+
16
+ def get_scales(points):
17
+ scales = np.ones((points.shape[0], 3)) * 0.0015
18
+ scales[:, 2] = 1e-6
19
+
20
+ return np.log(scales)
21
+
22
+ def get_rotation(normals):
23
+ if normals is not None and np.any(normals):
24
+ return normal2rotation(normals)
25
+ else:
26
+ return np.zeros((normals.shape[0], 4))
27
+
28
+ def RGB2SH(rgb):
29
+ return (rgb - 0.5) / 0.28209479177387814
30
+
31
+ def inverse_sigmoid(x):
32
+ return np.log(x / (1 - x))
33
+
34
+ def normal2rotation(n):
35
+ n = n / np.linalg.norm(n, axis=1, keepdims=True)
36
+ w0 = np.tile([[1, 0, 0]], (n.shape[0], 1))
37
+ R0 = w0 - np.sum(w0 * n, axis=1, keepdims=True) * n
38
+ R0 *= np.sign(R0[:, :1])
39
+ R0 /= np.linalg.norm(R0, axis=1, keepdims=True)
40
+ R1 = np.cross(n, R0)
41
+ R1 *= np.sign(R1[:, 1:2]) * np.sign(n[:, 2:])
42
+ R = np.stack([R0, R1, n], axis=-1)
43
+ q = rotmat2quaternion(R)
44
+ return q
45
+
46
+ def rotmat2quaternion(R, normalize=False):
47
+ tr = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2] + 1e-6
48
+ r = np.sqrt(1 + tr) / 2
49
+ q = np.stack([
50
+ r,
51
+ (R[:, 2, 1] - R[:, 1, 2]) / (4 * r),
52
+ (R[:, 0, 2] - R[:, 2, 0]) / (4 * r),
53
+ (R[:, 1, 0] - R[:, 0, 1]) / (4 * r)
54
+ ], axis=-1)
55
+ if normalize:
56
+ q /= np.linalg.norm(q, axis=-1, keepdims=True)
57
+ return q
58
+
59
+ def point2gs(path, pcd, scale=None, max_sh_degree=1):
60
+ # Ensure the directory exists
61
+ os.makedirs(os.path.dirname(path), exist_ok=True)
62
+
63
+ # Get point cloud data
64
+ xyz = np.asarray(pcd.points)
65
+ normals = np.asarray(pcd.normals) if pcd.has_normals() else np.zeros_like(xyz)
66
+ colors = np.asarray(pcd.colors) if pcd.has_colors() else np.ones_like(xyz)
67
+
68
+ # Generate additional attributes
69
+ f_dc = get_f_dc(colors).reshape(xyz.shape[0], -1)
70
+ f_rest = get_f_rest(xyz, max_sh_degree).reshape(xyz.shape[0], -1)
71
+ opacities = get_opacity(xyz)
72
+ if scale is not None:
73
+ scale = np.log(scale)
74
+ else:
75
+ scale = get_scales(xyz)
76
+ rotation = get_rotation(normals)
77
+
78
+ # Construct list of attributes
79
+ attribute_names = ['x', 'y', 'z', 'nx', 'ny', 'nz']
80
+ attribute_names.extend([f'f_dc_{i}' for i in range(f_dc.shape[-1])])
81
+ attribute_names.extend([f'f_rest_{i}' for i in range(f_rest.shape[-1])])
82
+ attribute_names.append('opacity')
83
+ attribute_names.extend([f'scale_{i}' for i in range(scale.shape[1])])
84
+ attribute_names.extend([f'rot_{i}' for i in range(rotation.shape[1])])
85
+
86
+ # Create dtype for numpy structured array
87
+ dtype_full = [(attribute, 'f4') for attribute in attribute_names]
88
+
89
+ # Combine all attributes
90
+ attributes = np.concatenate((
91
+ xyz, normals,
92
+ f_dc,
93
+ f_rest,
94
+ opacities, scale, rotation
95
+ ), axis=1)
96
+
97
+ # Ensure attributes match the dtype
98
+ assert attributes.shape[1] == len(dtype_full), f"Mismatch in attribute count. Expected {len(dtype_full)}, got {attributes.shape[1]}"
99
+
100
+ # Create structured array
101
+ elements = np.empty(xyz.shape[0], dtype=dtype_full)
102
+ elements[:] = list(map(tuple, attributes))
103
+
104
+ # Create PlyElement and save
105
+ el = PlyElement.describe(elements, 'vertex')
106
+ PlyData([el]).write(path)
requirements.txt CHANGED
@@ -18,4 +18,5 @@ transformers
18
  kornia
19
  timm
20
  numpy==1.26.4
21
- open3d
 
 
18
  kornia
19
  timm
20
  numpy==1.26.4
21
+ open3d
22
+ plyfile