Jingkang Yang committed
Commit aaf4c51
1 Parent(s): d7a8f62

update: app

app.py CHANGED
@@ -23,6 +23,7 @@ import cv2
 import tqdm
 import numpy as np
 import gradio as gr
+from tools.util import *
 
 from detectron2.config import get_cfg
 
@@ -150,10 +151,16 @@ def greet_sailvos3d(rgb_input, depth_map_input, rage_matrices_input, class_candidates):
 
    Depth_Semantic_SAM_Mask = read_image('outputs/Depth_Semantic_SAM_Mask.png')
    RGB_Semantic_SAM_Mask = read_image('outputs/RGB_Semantic_SAM_Mask.png')
+   Depth_Semantic_SAM = read_image('outputs/Depth_Semantic_SAM.png')
+   RGB_Semantic_SAM = read_image('outputs/RGB_Semantic_SAM.png')
+   two_image_to_gif(Depth_Semantic_SAM_Mask, Depth_Semantic_SAM, 'Depth_Semantic_SAM_2D')
+   two_image_to_gif(RGB_Semantic_SAM_Mask, RGB_Semantic_SAM, 'RGB_Semantic_SAM_2D')
+   Depth_Semantic_SAM_2D = 'outputs/Depth_Semantic_SAM_2D.mp4'
+   RGB_Semantic_SAM_2D = 'outputs/RGB_Semantic_SAM_2D.mp4'
    Depth_map = read_image('outputs/Depth_rendered.png')
-   Depth_Semantic_SAM_Mask_gif = 'outputs/depth_3d_sam_mask.mp4'
-   RGB_Semantic_SAM_Mask_gif = 'outputs/rgb_3d_sam_mask.mp4'
-   return RGB_Semantic_SAM_Mask, RGB_Semantic_SAM_Mask_gif, Depth_map, Depth_Semantic_SAM_Mask, Depth_Semantic_SAM_Mask_gif
+   Depth_Semantic_SAM_Mask_gif = 'outputs/Depth_3D_All.mp4'
+   RGB_Semantic_SAM_Mask_gif = 'outputs/RGB_3D_All.mp4'
+   return RGB_Semantic_SAM_2D, RGB_Semantic_SAM_Mask_gif, Depth_map, Depth_Semantic_SAM_2D, Depth_Semantic_SAM_Mask_gif
 
 def greet_scannet(rgb_input, depth_map_input, class_candidates):
    rgb_input = rgb_input
@@ -192,10 +199,16 @@ def greet_scannet(rgb_input, depth_map_input, class_candidates):
 
    Depth_Semantic_SAM_Mask = read_image('outputs/Depth_Semantic_SAM_Mask.png')
    RGB_Semantic_SAM_Mask = read_image('outputs/RGB_Semantic_SAM_Mask.png')
+   Depth_Semantic_SAM = read_image('outputs/Depth_Semantic_SAM.png')
+   RGB_Semantic_SAM = read_image('outputs/RGB_Semantic_SAM.png')
+   two_image_to_gif(Depth_Semantic_SAM_Mask, Depth_Semantic_SAM, 'Depth_Semantic_SAM_2D')
+   two_image_to_gif(RGB_Semantic_SAM_Mask, RGB_Semantic_SAM, 'RGB_Semantic_SAM_2D')
+   Depth_Semantic_SAM_2D = 'outputs/Depth_Semantic_SAM_2D.mp4'
+   RGB_Semantic_SAM_2D = 'outputs/RGB_Semantic_SAM_2D.mp4'
    Depth_map = read_image('outputs/Depth_rendered.png')
-   Depth_Semantic_SAM_Mask_gif = 'outputs/depth_3d_sam_mask.mp4'
-   RGB_Semantic_SAM_Mask_gif = 'outputs/rgb_3d_sam_mask.mp4'
-   return RGB_Semantic_SAM_Mask, RGB_Semantic_SAM_Mask_gif, Depth_map, Depth_Semantic_SAM_Mask, Depth_Semantic_SAM_Mask_gif
+   Depth_Semantic_SAM_Mask_gif = 'outputs/Depth_3D_All.mp4'
+   RGB_Semantic_SAM_Mask_gif = 'outputs/RGB_3D_All.mp4'
+   return RGB_Semantic_SAM_2D, RGB_Semantic_SAM_Mask_gif, Depth_map, Depth_Semantic_SAM_2D, Depth_Semantic_SAM_Mask_gif
 
 
 with gr.Blocks(analytics_enabled=False) as segrgbd_iface:
@@ -211,20 +224,22 @@ with gr.Blocks(analytics_enabled=False) as segrgbd_iface:
        with gr.Column():
            with gr.Row():
                Input_RGB_Component = gr.Image(label = 'RGB_Input', type = 'filepath').style(width=320, height=200)
-               Depth_Map_Output_Component = gr.Image(label = "Depth_Map").style(width=320, height=200)
+               Depth_Map_Output_Component = gr.Image(label = "Vis_Depth_Map").style(width=320, height=200)
            with gr.Row():
-               Depth_Map_Input_Component = gr.File(label = 'Depth_map')
+               Depth_Map_Input_Component = gr.File(label = 'input_Depth_map')
                Component_2D_to_3D_Projection_Parameters = gr.File(label = '2D_to_3D_Projection_Parameters')
            with gr.Row():
                Class_Candidates_Component = gr.Text(label = 'Class_Candidates')
            vc_end_btn = gr.Button("Send")
        with gr.Tab(label='Result'):
            with gr.Row():
-               RGB_Semantic_SAM_Mask_Component = gr.Image(label = "RGB_Semantic_SAM_Mask").style(width=320, height=200)
+               RGB_Semantic_SAM_Mask_Component = gr.Video(label = "RGB_Semantic_SAM_Mask").style(width=320, height=200)
                RGB_Semantic_SAM_Mask_3D_Component = gr.Video(label = "Video_3D_RGB_Semantic_SAM_Mask").style(width=320, height=200)
            with gr.Row():
-               Depth_Semantic_SAM_Mask_Component = gr.Image(label = "Depth_Semantic_SAM_Mask").style(width=320, height=200)
+               Depth_Semantic_SAM_Mask_Component = gr.Video(label = "Depth_Semantic_SAM_Mask").style(width=320, height=200)
                Depth_Semantic_SAM_Mask_3D_Component = gr.Video(label = "Video_3D_Depth_Semantic_SAM_Mask").style(width=320, height=200)
+           with gr.Row():
+               gr.Markdown("<b> It takes around 2 to 5 minutes to get the final results. The framework initialization, SAM segmentation, zero-shot semantic segmentation and 3D results rendering take long time.</b>")
    gr.Examples(examples=[
        [
            'UI/sailvos3d/ex1/inputs/rgb_000160.bmp',
@@ -252,18 +267,20 @@ with gr.Blocks(analytics_enabled=False) as segrgbd_iface:
        with gr.Column():
            with gr.Row():
                Input_RGB_Component = gr.Image(label = 'RGB_Input', type = 'filepath').style(width=320, height=200)
-               Depth_Map_Output_Component = gr.Image(label = "Depth_Map").style(width=320, height=200)
+               Depth_Map_Output_Component = gr.Image(label = "Vis_Depth_Map").style(width=320, height=200)
            with gr.Row():
-               Depth_Map_Input_Component = gr.File(label = "Depth_Map")
+               Depth_Map_Input_Component = gr.File(label = "Input_Depth_Map")
                Class_Candidates_Component = gr.Text(label = 'Class_Candidates')
            vc_end_btn = gr.Button("Send")
        with gr.Tab(label='Result'):
            with gr.Row():
-               RGB_Semantic_SAM_Mask_Component = gr.Image(label = "RGB_Semantic_SAM_Mask").style(width=320, height=200)
+               RGB_Semantic_SAM_Mask_Component = gr.Video(label = "RGB_Semantic_SAM_Mask").style(width=320, height=200)
                RGB_Semantic_SAM_Mask_3D_Component = gr.Video(label = "Video_3D_RGB_Semantic_SAM_Mask").style(width=320, height=200)
            with gr.Row():
-               Depth_Semantic_SAM_Mask_Component = gr.Image(label = "Depth_Semantic_SAM_Mask").style(width=320, height=200)
+               Depth_Semantic_SAM_Mask_Component = gr.Video(label = "Depth_Semantic_SAM_Mask").style(width=320, height=200)
                Depth_Semantic_SAM_Mask_3D_Component = gr.Video(label = "Video_3D_Depth_Semantic_SAM_Mask").style(width=320, height=200)
+           with gr.Row():
+               gr.Markdown("<b> It takes around 2 to 5 minutes to get the final results. The framework initialization, SAM segmentation, zero-shot semantic segmentation and 3D results rendering take long time.</b>")
    gr.Examples(examples=[
        [
            'UI/scannetv2/examples/scene0000_00/color/1660.jpg',
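
Note: each greet_* handler now returns five values (a 2D cross-fade video between the SAM mask and the blended overlay, a combined 3D fly-around video, the rendered depth map, and their depth-branch counterparts), and the first component in each result row accordingly changed from gr.Image to gr.Video. The .click() wiring is outside the hunks shown; a sketch of how it would plausibly look, where the argument lists are assumptions rather than part of this commit:

```python
# Hypothetical wiring for the ScanNet tab: the five return values of
# greet_scannet must line up positionally with the five output components.
vc_end_btn.click(
    fn=greet_scannet,
    inputs=[Input_RGB_Component, Depth_Map_Input_Component, Class_Candidates_Component],
    outputs=[RGB_Semantic_SAM_Mask_Component,       # outputs/RGB_Semantic_SAM_2D.mp4
             RGB_Semantic_SAM_Mask_3D_Component,    # outputs/RGB_3D_All.mp4
             Depth_Map_Output_Component,            # outputs/Depth_rendered.png
             Depth_Semantic_SAM_Mask_Component,     # outputs/Depth_Semantic_SAM_2D.mp4
             Depth_Semantic_SAM_Mask_3D_Component], # outputs/Depth_3D_All.mp4
)
```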
open_vocab_seg/utils/__pycache__/predictor.cpython-39.pyc CHANGED
Binary files a/open_vocab_seg/utils/__pycache__/predictor.cpython-39.pyc and b/open_vocab_seg/utils/__pycache__/predictor.cpython-39.pyc differ
 
open_vocab_seg/utils/predictor.py CHANGED
@@ -7,6 +7,7 @@ import torchvision
 import imageio
 from tqdm import tqdm
 import os
+import cv2
 
 from pytorch3d.structures import Pointclouds
 from pytorch3d.renderer import look_at_view_transform
@@ -19,7 +20,6 @@ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
 import matplotlib.pyplot as plt
 import matplotlib as mpl
 from .pcd_rendering import unproject_pts_pt, get_coord_grids_pt, create_pcd_renderer
-import cv2
 
 
 class OVSegPredictor(DefaultPredictor):
@@ -179,7 +179,7 @@ class VisualizationDemo(object):
        if "sem_seg" in predictions:
            r = predictions["sem_seg"]
            pred_mask = r.argmax(dim=0).to('cpu')
-           pred_mask = np.array(pred_mask, dtype=int)
+           pred_mask = np.array(pred_mask, dtype=np.int)
 
            vis_output = visualizer.draw_sem_seg(
                pred_mask
@@ -223,6 +223,7 @@ class VisualizationDemo(object):
        )
        print('Using SAM to generate segments for the RGB image')
        masks_rgb = mask_generator_2.generate(image)
+       masks_rgb = sorted(masks_rgb, key=(lambda x: x['area']), reverse=True)
 
        print('Using SAM to generate segments for the Depth map')
        d, world_coord = self.project_2d_to_3d(depth_map_path, rage_matrices_path)
@@ -233,15 +234,16 @@ class VisualizationDemo(object):
        plt.axis('off')
        plt.savefig('outputs/Depth_rendered.png', bbox_inches='tight', pad_inches=0.0)
        masks_depth = mask_generator_2.generate(image_depth.astype(np.uint8)[:,:,:-1])
+       masks_depth = sorted(masks_depth, key=(lambda x: x['area']), reverse=True)
 
        if "sem_seg" in predictions:
            r = predictions["sem_seg"]
            pred_mask = r.argmax(dim=0).to('cpu')
-           pred_mask = np.array(pred_mask, dtype=int)
+           pred_mask = np.array(pred_mask, dtype=np.int)
 
            pred_mask_sam_rgb = pred_mask.copy()
            for mask in masks_rgb:
-               cls_tmp, cls_num = np.unique(pred_mask_sam_rgb[mask['segmentation']], return_counts=True)
+               cls_tmp, cls_num = np.unique(pred_mask[mask['segmentation']], return_counts=True)
                pred_mask_sam_rgb[mask['segmentation']] = cls_tmp[np.argmax(cls_num)]
                mask['class'] = cls_tmp[np.argmax(cls_num)]
 
@@ -254,7 +256,7 @@ class VisualizationDemo(object):
 
            pred_mask_sam_depth = pred_mask.copy()
            for mask in masks_depth:
-               cls_tmp, cls_num = np.unique(pred_mask_sam_depth[mask['segmentation']], return_counts=True)
+               cls_tmp, cls_num = np.unique(pred_mask[mask['segmentation']], return_counts=True)
                pred_mask_sam_depth[mask['segmentation']] = cls_tmp[np.argmax(cls_num)]
                mask['class'] = cls_tmp[np.argmax(cls_num)]
 
@@ -400,7 +402,8 @@ class VisualizationDemo(object):
        num_frames = 45
        degrees = np.linspace(120, 220, num_frames)
 
-       total = ['rgb_3d_sam_mask', 'depth_3d_sam_mask']
+       total = ['rgb_3d_sam', 'depth_3d_sam', 'rgb_3d_sam_mask', 'depth_3d_sam_mask']
+       frames_all = {}
 
        for j, name in enumerate(total):
            img = torch.from_numpy(xyzrgb[name][:, 3:] / 255.).to(device).float()
@@ -415,13 +418,21 @@ class VisualizationDemo(object):
                result = result.permute(0, 3, 1, 2)
                frame = (255. * result.detach().cpu().squeeze().permute(1, 2, 0).numpy()).astype(np.uint8)
                frames.append(frame)
+
+           frames_all[name] = frames
 
            # video_out_file = '{}.gif'.format(name)
-           # imageio.mimwrite(os.path.join('outputs', video_out_file), frames, duration=20)
+           # imageio.mimwrite(os.path.join('outputs', video_out_file), frames, fps=25)
 
            video_out_file = '{}.mp4'.format(name)
            imageio.mimwrite(os.path.join('outputs', video_out_file), frames, fps=25, quality=8)
-
+
+       video_out_file = '{}.mp4'.format('RGB_3D_All')
+       imageio.mimwrite(os.path.join('outputs', video_out_file), frames_all['rgb_3d_sam_mask']+frames_all['rgb_3d_sam'], fps=25, quality=8)
+
+       video_out_file = '{}.mp4'.format('Depth_3D_All')
+       imageio.mimwrite(os.path.join('outputs', video_out_file), frames_all['depth_3d_sam_mask']+frames_all['depth_3d_sam'], fps=25, quality=8)
+
 class VisualizationDemoIndoor(VisualizationDemo):
    def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False):
        super().__init__(cfg, instance_mode, parallel)
@@ -472,6 +483,7 @@ class VisualizationDemoIndoor(VisualizationDemo):
        )
        print('Using SAM to generate segments for the RGB image')
        masks_rgb = mask_generator_2.generate(image)
+       masks_rgb = sorted(masks_rgb, key=(lambda x: x['area']), reverse=True)
 
        print('Using SAM to generate segments for the Depth map')
        d = np.full(depth_img.shape, 0, dtype=float)
@@ -483,6 +495,7 @@ class VisualizationDemoIndoor(VisualizationDemo):
        plt.axis('off')
        plt.savefig('outputs/Depth_rendered.png')
        masks_depth = mask_generator_2.generate(colored_depth.astype(np.uint8)[:,:,:-1])
+       masks_depth = sorted(masks_depth, key=(lambda x: x['area']), reverse=True)
 
        if "sem_seg" in predictions:
            r = predictions["sem_seg"]
@@ -541,7 +554,7 @@ class VisualizationDemoIndoor(VisualizationDemo):
 
        output3D = {}
        output3D['rgb_3d_sem'] = np.stack((uv_depth, output2D['sem_seg_on_rgb'].get_image()), axis=2).reshape((depth_img.shape[0], depth_img.shape[1], 6))
-       output3D['depth_3d_sem'] = np.stack((uv_depth, output2D['sem_seg_on_rgb'].get_image()), axis=2).reshape((depth_img.shape[0], depth_img.shape[1], 6))
+       output3D['depth_3d_sem'] = np.stack((uv_depth, output2D['sem_seg_on_depth'].get_image()), axis=2).reshape((depth_img.shape[0], depth_img.shape[1], 6))
        output3D['rgb_3d_sam'] = np.stack((uv_depth, output2D['sam_seg_on_rgb'].get_image()), axis=2).reshape((depth_img.shape[0], depth_img.shape[1], 6))
        output3D['depth_3d_sam'] = np.stack((uv_depth, output2D['sam_seg_on_depth'].get_image()), axis=2).reshape((depth_img.shape[0], depth_img.shape[1], 6))
 
@@ -585,6 +598,7 @@ class VisualizationDemoIndoor(VisualizationDemo):
        )
        print('Using SAM to generate segments for the RGB image')
        masks_rgb = mask_generator_2.generate(image)
+       masks_rgb = sorted(masks_rgb, key=(lambda x: x['area']), reverse=True)
 
        print('Using SAM to generate segments for the Depth map')
        d = np.full(depth_img.shape, 0, dtype=float)
@@ -596,6 +610,7 @@ class VisualizationDemoIndoor(VisualizationDemo):
        plt.axis('off')
        plt.savefig('outputs/Depth_rendered.png')
        masks_depth = mask_generator_2.generate(colored_depth.astype(np.uint8)[:,:,:-1])
+       masks_depth = sorted(masks_depth, key=(lambda x: x['area']), reverse=True)
 
        if "sem_seg" in predictions:
            r = predictions["sem_seg"]
@@ -656,7 +671,7 @@ class VisualizationDemoIndoor(VisualizationDemo):
 
        output3D = {}
        output3D['rgb_3d_sem'] = np.stack((uv_depth, output2D['sem_seg_on_rgb'].get_image()), axis=2).reshape((depth_img.shape[0], depth_img.shape[1], 6))
-       output3D['depth_3d_sem'] = np.stack((uv_depth, output2D['sem_seg_on_rgb'].get_image()), axis=2).reshape((depth_img.shape[0], depth_img.shape[1], 6))
+       output3D['depth_3d_sem'] = np.stack((uv_depth, output2D['sem_seg_on_depth'].get_image()), axis=2).reshape((depth_img.shape[0], depth_img.shape[1], 6))
        output3D['rgb_3d_sam'] = np.stack((uv_depth, output2D['sam_seg_on_rgb'].get_image()), axis=2).reshape((depth_img.shape[0], depth_img.shape[1], 6))
        output3D['depth_3d_sam'] = np.stack((uv_depth, output2D['sam_seg_on_depth'].get_image()), axis=2).reshape((depth_img.shape[0], depth_img.shape[1], 6))
 
@@ -743,9 +758,10 @@ class VisualizationDemoIndoor(VisualizationDemo):
        radius = 1.5 / min(h, w) * 2.0
 
 
-       total = ['rgb_3d_sam_mask', 'depth_3d_sam_mask']
+       total = ['rgb_3d_sam', 'depth_3d_sam', 'rgb_3d_sam_mask', 'depth_3d_sam_mask']
        num_frames = 45
        degrees = np.linspace(120, 220, num_frames)
+       frames_all = {}
        for j, name in enumerate(total):
            img = torch.from_numpy(xyzrgb[name][:, :, 3:] / 255.).to(device).float()
            pcd = Pointclouds(points=[pts], features=[img.squeeze().reshape(-1, 3)])
@@ -761,11 +777,17 @@ class VisualizationDemoIndoor(VisualizationDemo):
                result = result.permute(0, 3, 1, 2)
                frame = (255. * result.detach().cpu().squeeze().permute(1, 2, 0).numpy()).astype(np.uint8)
                frames.append(frame)
+
+           frames_all[name] = frames
 
-           # video_out_file = '{}.gif'.format(name)
+           # video_out_file = '{}.mp4'.format(name)
            # imageio.mimwrite(os.path.join('outputs', video_out_file), frames, fps=25)
 
            video_out_file = '{}.mp4'.format(name)
            imageio.mimwrite(os.path.join('outputs', video_out_file), frames, fps=25, quality=8)
-
-
+
+       video_out_file = '{}.mp4'.format('RGB_3D_All')
+       imageio.mimwrite(os.path.join('outputs', video_out_file), frames_all['rgb_3d_sam_mask']+frames_all['rgb_3d_sam'], fps=25, quality=8)
+
+       video_out_file = '{}.mp4'.format('Depth_3D_All')
+       imageio.mimwrite(os.path.join('outputs', video_out_file), frames_all['depth_3d_sam_mask']+frames_all['depth_3d_sam'], fps=25, quality=8)
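
Note on the refinement logic changed in this file: each SAM segment is relabeled with the majority semantic class from the OVSeg prediction. The commit sorts the SAM masks by area in descending order, so that where masks overlap, smaller masks are painted last and fine structures survive, and the vote now reads from the untouched pred_mask rather than from the partially overwritten copy. One caveat: np.int was removed in NumPy 1.24, so dtype=np.int will raise on current NumPy, while plain int (as before this commit) behaves identically. A minimal, self-contained sketch of the voting step, assuming SAM-style mask dicts with a boolean 'segmentation' array and an integer 'area' field (vote_classes is an illustrative name, not a function in this repo):

```python
import numpy as np

def vote_classes(pred_mask, masks):
    """Relabel each SAM segment with its majority class from pred_mask."""
    refined = np.array(pred_mask, dtype=int)  # plain int: np.int is gone in NumPy >= 1.24
    # Largest segments first, so smaller (later) segments win on overlap.
    for m in sorted(masks, key=lambda m: m['area'], reverse=True):
        seg = m['segmentation']  # boolean HxW mask
        classes, counts = np.unique(pred_mask[seg], return_counts=True)
        refined[seg] = classes[np.argmax(counts)]
    return refined
```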
tools/__init__.py ADDED
File without changes
tools/util.py ADDED
@@ -0,0 +1,296 @@
+import numpy as np
+import torch
+import os
+import copy
+from PIL import Image
+import json
+import imageio
+# import clip
+
+
+SCANNET_COLOR_MAP_20 = {-1: (0., 0., 0.), 0: (174., 199., 232.), 1: (152., 223., 138.), 2: (31., 119., 180.), 3: (255., 187., 120.), 4: (188., 189., 34.), 5: (140., 86., 75.),
+                        6: (255., 152., 150.), 7: (214., 39., 40.), 8: (197., 176., 213.), 9: (148., 103., 189.), 10: (196., 156., 148.), 11: (23., 190., 207.), 12: (247., 182., 210.),
+                        13: (219., 219., 141.), 14: (255., 127., 14.), 15: (158., 218., 229.), 16: (44., 160., 44.), 17: (112., 128., 144.), 18: (227., 119., 194.), 19: (82., 84., 163.)}
+
+class Voxelize(object):
+    def __init__(self,
+                 voxel_size=0.05,
+                 hash_type="fnv",
+                 mode='train',
+                 keys=("coord", "normal", "color", "label"),
+                 return_discrete_coord=False,
+                 return_min_coord=False):
+        self.voxel_size = voxel_size
+        self.hash = self.fnv_hash_vec if hash_type == "fnv" else self.ravel_hash_vec
+        assert mode in ["train", "test"]
+        self.mode = mode
+        self.keys = keys
+        self.return_discrete_coord = return_discrete_coord
+        self.return_min_coord = return_min_coord
+
+    def __call__(self, data_dict):
+        assert "coord" in data_dict.keys()
+        discrete_coord = np.floor(data_dict["coord"] / np.array(self.voxel_size)).astype(np.int)
+        min_coord = discrete_coord.min(0) * np.array(self.voxel_size)
+        discrete_coord -= discrete_coord.min(0)
+        key = self.hash(discrete_coord)
+        idx_sort = np.argsort(key)
+        key_sort = key[idx_sort]
+        _, inverse, count = np.unique(key_sort, return_inverse=True, return_counts=True)
+        if self.mode == 'train':  # train mode
+            # idx_select = np.cumsum(np.insert(count, 0, 0)[0:-1]) + np.random.randint(0, count.max(), count.size) % count
+            idx_select = np.cumsum(np.insert(count, 0, 0)[0:-1])
+            idx_unique = idx_sort[idx_select]
+            if self.return_discrete_coord:
+                data_dict["discrete_coord"] = discrete_coord[idx_unique]
+            if self.return_min_coord:
+                data_dict["min_coord"] = min_coord.reshape([1, 3])
+            for key in self.keys:
+                data_dict[key] = data_dict[key][idx_unique]
+            return data_dict
+
+        elif self.mode == 'test':  # test mode
+            data_part_list = []
+            for i in range(count.max()):
+                idx_select = np.cumsum(np.insert(count, 0, 0)[0:-1]) + i % count
+                idx_part = idx_sort[idx_select]
+                data_part = dict(index=idx_part)
+                for key in data_dict.keys():
+                    if key in self.keys:
+                        data_part[key] = data_dict[key][idx_part]
+                    else:
+                        data_part[key] = data_dict[key]
+                if self.return_discrete_coord:
+                    data_part["discrete_coord"] = discrete_coord[idx_part]
+                if self.return_min_coord:
+                    data_part["min_coord"] = min_coord.reshape([1, 3])
+                data_part_list.append(data_part)
+            return data_part_list
+        else:
+            raise NotImplementedError
+
+    @staticmethod
+    def ravel_hash_vec(arr):
+        """
+        Ravel the coordinates after subtracting the min coordinates.
+        """
+        assert arr.ndim == 2
+        arr = arr.copy()
+        arr -= arr.min(0)
+        arr = arr.astype(np.uint64, copy=False)
+        arr_max = arr.max(0).astype(np.uint64) + 1
+
+        keys = np.zeros(arr.shape[0], dtype=np.uint64)
+        # Fortran style indexing
+        for j in range(arr.shape[1] - 1):
+            keys += arr[:, j]
+            keys *= arr_max[j + 1]
+        keys += arr[:, -1]
+        return keys
+
+    @staticmethod
+    def fnv_hash_vec(arr):
+        """
+        FNV64-1A
+        """
+        assert arr.ndim == 2
+        # Floor first for negative coordinates
+        arr = arr.copy()
+        arr = arr.astype(np.uint64, copy=False)
+        hashed_arr = np.uint64(14695981039346656037) * np.ones(arr.shape[0], dtype=np.uint64)
+        for j in range(arr.shape[1]):
+            hashed_arr *= np.uint64(1099511628211)
+            hashed_arr = np.bitwise_xor(hashed_arr, arr[:, j])
+        return hashed_arr
+
+
+def overlap_percentage(mask1, mask2):
+    intersection = np.logical_and(mask1, mask2)
+    area_intersection = np.sum(intersection)
+
+    area_mask1 = np.sum(mask1)
+    area_mask2 = np.sum(mask2)
+
+    smaller_area = min(area_mask1, area_mask2)
+
+    return area_intersection / smaller_area
+
+
+def remove_samll_masks(masks, ratio=0.8):
+    filtered_masks = []
+    skip_masks = set()
+
+    for i, mask1_dict in enumerate(masks):
+        if i in skip_masks:
+            continue
+
+        should_keep = True
+        for j, mask2_dict in enumerate(masks):
+            if i == j or j in skip_masks:
+                continue
+            mask1 = mask1_dict["segmentation"]
+            mask2 = mask2_dict["segmentation"]
+            overlap = overlap_percentage(mask1, mask2)
+            if overlap > ratio:
+                if np.sum(mask1) < np.sum(mask2):
+                    should_keep = False
+                    break
+                else:
+                    skip_masks.add(j)
+
+        if should_keep:
+            filtered_masks.append(mask1)
+
+    return filtered_masks
+
+
+def to_numpy(x):
+    if isinstance(x, torch.Tensor):
+        x = x.clone().detach().cpu().numpy()
+    assert isinstance(x, np.ndarray)
+    return x
+
+
+def save_point_cloud(coord, color=None, file_path="pc.ply", logger=None):
+    os.makedirs(os.path.dirname(file_path), exist_ok=True)
+    coord = to_numpy(coord)
+    if color is not None:
+        color = to_numpy(color)
+    pcd = o3d.geometry.PointCloud()
+    pcd.points = o3d.utility.Vector3dVector(coord)
+    pcd.colors = o3d.utility.Vector3dVector(np.ones_like(coord) if color is None else color)
+    o3d.io.write_point_cloud(file_path, pcd)
+    if logger is not None:
+        logger.info(f"Save Point Cloud to: {file_path}")
+
+
+def remove_small_group(group_ids, th):
+    unique_elements, counts = np.unique(group_ids, return_counts=True)
+    result = group_ids.copy()
+    for i, count in enumerate(counts):
+        if count < th:
+            result[group_ids == unique_elements[i]] = -1
+
+    return result
+
+
+def pairwise_indices(length):
+    return [[i, i + 1] if i + 1 < length else [i] for i in range(0, length, 2)]
+
+
+def num_to_natural(group_ids):
+    '''
+    Change the group number to natural number arrangement
+    '''
+    if np.all(group_ids == -1):
+        return group_ids
+    array = copy.deepcopy(group_ids)
+    unique_values = np.unique(array[array != -1])
+    mapping = np.full(np.max(unique_values) + 2, -1)
+    mapping[unique_values + 1] = np.arange(len(unique_values))
+    array = mapping[array + 1]
+    return array
+
+
+def get_matching_indices(source, pcd_tree, search_voxel_size, K=None):
+    match_inds = []
+    for i, point in enumerate(source.points):
+        [_, idx, _] = pcd_tree.search_radius_vector_3d(point, search_voxel_size)
+        if K is not None:
+            idx = idx[:K]
+        for j in idx:
+            # match_inds[i, j] = 1
+            match_inds.append((i, j))
+    return match_inds
+
+
+def visualize_3d(data_dict, text_feat_path, save_path):
+    text_feat = torch.load(text_feat_path)
+    group_logits = np.einsum('nc,mc->nm', data_dict["group_feat"], text_feat)
+    group_labels = np.argmax(group_logits, axis=-1)
+    labels = group_labels[data_dict["group"]]
+    labels[data_dict["group"] == -1] = -1
+    visualize_pcd(data_dict["coord"], data_dict["color"], labels, save_path)
+
+
+def visualize_pcd(coord, pcd_color, labels, save_path):
+    # alpha = 0.5
+    label_color = np.array([SCANNET_COLOR_MAP_20[label] for label in labels])
+    # overlay = (pcd_color * (1-alpha) + label_color * alpha).astype(np.uint8) / 255
+    label_color = label_color / 255
+    save_point_cloud(coord, label_color, save_path)
+
+
+def visualize_2d(img_color, labels, img_size, save_path):
+    import matplotlib.pyplot as plt
+    # from skimage.segmentation import mark_boundaries
+    # from skimage.color import label2rgb
+    label_names = ["wall", "floor", "cabinet", "bed", "chair",
+                   "sofa", "table", "door", "window", "bookshelf",
+                   "picture", "counter", "desk", "curtain", "refridgerator",
+                   "shower curtain", "toilet", "sink", "bathtub", "other"]
+    colors = np.array(list(SCANNET_COLOR_MAP_20.values()))[1:]
+    segmentation_color = np.zeros((img_size[0], img_size[1], 3))
+    for i, color in enumerate(colors):
+        segmentation_color[labels == i] = color
+    alpha = 1
+    overlay = (img_color * (1-alpha) + segmentation_color * alpha).astype(np.uint8)
+    fig, ax = plt.subplots()
+    ax.imshow(overlay)
+    patches = [plt.plot([], [], 's', color=np.array(color)/255, label=label)[0] for label, color in zip(label_names, colors)]
+    plt.legend(handles=patches, bbox_to_anchor=(0.5, -0.1), loc='upper center', ncol=4, fontsize='small')
+    plt.savefig(save_path, bbox_inches='tight')
+    plt.show()
+
+
+def visualize_partition(coord, group_id, save_path):
+    group_id = group_id.reshape(-1)
+    num_groups = group_id.max() + 1
+    group_colors = np.random.rand(num_groups, 3)
+    group_colors = np.vstack((group_colors, np.array([0,0,0])))
+    color = group_colors[group_id]
+    save_point_cloud(coord, color, save_path)
+
+
+def delete_invalid_group(group, group_feat):
+    indices = np.unique(group[group != -1])
+    group = num_to_natural(group)
+    group_feat = group_feat[indices]
+    return group, group_feat
+
+def group_sem_voting(semantic_label, seg_result, instance_num=0):
+    if instance_num == 0:
+        instance_num = seg_result.max() + 1
+    seg_labels = []
+    sem_map = -1 * torch.ones_like(semantic_label)
+    for n in range(instance_num):
+        mask = (seg_result == n)
+        if mask.sum() == 0:
+            sem_map[mask] = -1
+            seg_labels.append(-1)
+            continue
+        seg_label_n_cover, seg_label_n_nums = torch.unique(semantic_label[mask], return_counts=True)
+        seg_label_n = seg_label_n_cover[seg_label_n_nums.max(-1)[1]]
+        seg_labels.append(seg_label_n)
+        sem_map[mask] = seg_label_n
+
+    return sem_map
+
+def two_image_to_gif(image_1, image_2, name):
+    num_begin = 30
+    num_frames = 30
+    num_end = 30
+    frames = []
+    for i in range(num_begin):
+        frames.append(image_1)
+    for i in range(num_frames):
+        image_tmp = image_1 + (image_2 - image_1) * (i / (num_frames - 1))
+        frames.append(image_tmp.astype(np.uint8))
+    for i in range(num_end):
+        frames.append(image_2)
+
+    # video_out_file = '{}.gif'.format(name)
+    # imageio.mimwrite(os.path.join('outputs', video_out_file), frames, fps=25)
+
+    video_out_file = '{}.mp4'.format(name)
+    imageio.mimwrite(os.path.join('outputs', video_out_file), frames, fps=25, quality=8)
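
A few caveats on the new tools/util.py: save_point_cloud references o3d, which is never imported here, so the module presumably needs `import open3d as o3d` at the top; the np.int in Voxelize.__call__ has the same NumPy >= 1.24 problem noted above; and in remove_samll_masks (sic), mask1 is bound only inside the inner loop and the function appends bare segmentation arrays rather than mask dicts. Also, two_image_to_gif interpolates directly on its inputs, and with uint8 images (the usual result of read_image) image_2 - image_1 wraps around instead of going negative. A safe cross-fade sketch, assuming equal-shape HxWx3 uint8 frames; the name crossfade_to_mp4 and its defaults are illustrative, not the committed API:

```python
import os
import numpy as np
import imageio

def crossfade_to_mp4(image_1, image_2, name, hold=30, steps=30, fps=25):
    # Cast to float so the frame difference can go negative without wrapping.
    a = image_1.astype(np.float32)
    b = image_2.astype(np.float32)
    frames = [image_1] * hold                 # hold the first image
    for i in range(steps):                    # linear blend a -> b
        t = i / (steps - 1)
        frames.append((a + (b - a) * t).astype(np.uint8))
    frames += [image_2] * hold                # hold the second image
    out_path = os.path.join('outputs', '{}.mp4'.format(name))
    imageio.mimwrite(out_path, frames, fps=fps, quality=8)
```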