chongjie commited on
Commit
823f35c
1 Parent(s): fa4f45e

Add pcd2grid

Browse files
Files changed (1) hide show
  1. app.py +47 -3
app.py CHANGED
@@ -13,6 +13,7 @@ import mcc_model
13
  import util.misc as misc
14
  from engine_mcc import prepare_data
15
  from plyfile import PlyData, PlyElement
 
16
 
17
  def run_inference(model, samples, device, temperature, args):
18
  model.eval()
@@ -122,6 +123,25 @@ def normalize(seen_xyz):
122
  seen_xyz = seen_xyz - seen_xyz[torch.isfinite(seen_xyz.sum(dim=-1))].mean(axis=0)
123
  return seen_xyz
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  def infer(
126
  image,
127
  depth_image,
@@ -203,10 +223,33 @@ def infer(
203
  PlyData([element], text=True).write(f)
204
  temp_file_name = f.name
205
 
206
- return temp_file_name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
  if __name__ == '__main__':
209
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
210
 
211
  parser = main_mcc.get_args_parser()
212
  parser.set_defaults(eval=True)
@@ -231,7 +274,8 @@ if __name__ == '__main__':
231
  gr.Slider(minimum=0.05, maximum=0.5, step=0.05, value=0.2, label="Grain Size"),
232
  gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.1, label="Color Temperature")
233
  ],
234
- outputs=[gr.outputs.File(label="Point Cloud")],
 
235
  examples=[["demo/quest2.jpg", "demo/quest2_depth.png", "demo/quest2_seg.png", 0.2, 0.1]],
236
  cache_examples=True)
237
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
13
  import util.misc as misc
14
  from engine_mcc import prepare_data
15
  from plyfile import PlyData, PlyElement
16
+ import trimesh
17
 
18
  def run_inference(model, samples, device, temperature, args):
19
  model.eval()
 
123
  seen_xyz = seen_xyz - seen_xyz[torch.isfinite(seen_xyz.sum(dim=-1))].mean(axis=0)
124
  return seen_xyz
125
 
126
+ def voxel_grid_downsample(points, colors, voxel_size):
127
+ # Compute voxel indices
128
+ voxel_indices = np.floor(points / voxel_size).astype(int)
129
+
130
+ # Remove duplicate voxel indices
131
+ unique_voxel_indices, inverse_indices = np.unique(voxel_indices, axis=0, return_inverse=True)
132
+
133
+ # Compute the centroid of the points and the average color in each voxel
134
+ centroids = np.empty_like(unique_voxel_indices, dtype=float)
135
+ avg_colors = np.empty((len(unique_voxel_indices), colors.shape[1]), dtype=colors.dtype)
136
+ for i in range(len(unique_voxel_indices)):
137
+ centroids[i] = points[inverse_indices == i].mean(axis=0)
138
+ avg_colors[i] = colors[inverse_indices == i].mean(axis=0)
139
+
140
+ # Convert colors from RGB to BGR
141
+ avg_colors = avg_colors[:, ::-1]
142
+
143
+ return centroids, avg_colors
144
+
145
  def infer(
146
  image,
147
  depth_image,
 
223
  PlyData([element], text=True).write(f)
224
  temp_file_name = f.name
225
 
226
+ # Perform voxel grid downsampling
227
+ voxel_size = 0.2 # Change this to the size of your cubes
228
+ downsampled_xyz, downsampled_colors = voxel_grid_downsample(unseen_xyz, pred_colors, voxel_size)
229
+
230
+ meshes = []
231
+ for point, color in zip(downsampled_xyz, downsampled_colors):
232
+ # Create a cube mesh at the given point
233
+ cube = trimesh.creation.box(extents=[voxel_size]*3)
234
+ cube.apply_translation(point)
235
+
236
+ # Assign the average color to the vertices
237
+ cube.visual.vertex_colors = np.hstack([color, 255]) # Set alpha to 255
238
+ meshes.append(cube)
239
+
240
+ # Save point cloud data to a temporary file
241
+ with tempfile.NamedTemporaryFile(suffix=".obj", delete=False) as f:
242
+ temp_obj_file = f.name
243
+ print(temp_obj_file)
244
+ # Combine all the cubes into a single mesh
245
+ combined = trimesh.util.concatenate(meshes)
246
+ # Save the combined mesh to a file
247
+ combined.export(temp_obj_file)
248
+ return temp_file_name, temp_obj_file
249
 
250
  if __name__ == '__main__':
251
+ device = "cpu"
252
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
253
 
254
  parser = main_mcc.get_args_parser()
255
  parser.set_defaults(eval=True)
 
274
  gr.Slider(minimum=0.05, maximum=0.5, step=0.05, value=0.2, label="Grain Size"),
275
  gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.1, label="Color Temperature")
276
  ],
277
+ outputs=[gr.outputs.File(label="Point Cloud"),
278
+ gr.Model3D( clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model")],
279
  examples=[["demo/quest2.jpg", "demo/quest2_depth.png", "demo/quest2_seg.png", 0.2, 0.1]],
280
  cache_examples=True)
281
  demo.launch(server_name="0.0.0.0", server_port=7860)