# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

import os
import ast
import time
import random
from PIL import Image

import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from plyfile import PlyData
import gradio as gr
import plotly.graph_objs as go

from sam_3d import SAM3DDemo

# ScanNet scenes bundled with the demo. Used both to preload the point clouds
# and to populate the scene dropdown, so the two lists cannot drift apart.
SCENE_NAMES = ['scene0000_00', 'scene0001_00', 'scene0002_00']


def pc_to_plot(pc):
    """Render a coloured point cloud as a plotly 3D scatter figure.

    Args:
        pc: point-cloud record indexable by the field names 'x', 'y', 'z',
            'red', 'green' and 'blue' (e.g. the first element's data of a
            PLY file read with ``PlyData``, as done below).

    Returns:
        A ``plotly.graph_objs.Figure`` with one size-2 marker per point,
        coloured by its RGB fields, and all three scene axes hidden.
    """
    colors = [f'rgb({r},{g},{b})' for r, g, b in zip(pc['red'], pc['green'], pc['blue'])]
    return go.Figure(
        data=[
            go.Scatter3d(
                x=pc['x'],
                y=pc['y'],
                z=pc['z'],
                mode='markers',
                marker=dict(size=2, color=colors),
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False),
            )
        ),
    )


def _parse_coords(coords):
    """Turn the 'Coordinates' input into a Python value for SAM3DDemo.

    The Textbox normally delivers text such as "[0, -2.5, 0.7]", but example
    rows are declared with real lists, so non-string values are passed
    through unchanged instead of being fed to ``ast.literal_eval`` (which
    only accepts strings or AST nodes).

    Raises:
        gr.Error: if the text is not a valid Python literal; Gradio shows the
            message to the user instead of a bare traceback.
    """
    if not isinstance(coords, str):
        return coords
    try:
        # strip(): before Python 3.10 literal_eval rejects leading whitespace
        # with IndentationError, which is easy to type into a textbox.
        return ast.literal_eval(coords.strip())
    except (ValueError, TypeError, SyntaxError) as err:
        raise gr.Error(
            f"Could not parse coordinates {coords!r}; "
            "expected a list such as [0, -2.5, 0.7]."
        ) from err


def inference(scene_name, granularity, coords, plot):
    """Run 3D SAM segmentation around a user-supplied anchor point.

    Args:
        scene_name: name of the ScanNet scene to process (one of SCENE_NAMES).
        granularity: mask granularity chosen in the dropdown (0, 1 or 2);
            converted with ``int()`` before being passed on.
        coords: anchor point, as text (e.g. "[0, -2.5, 0.7]") or an already
            parsed list (from the examples table).
        plot: the input point-cloud figure. Unused here; it is only shown in
            the UI, but Gradio passes every input component positionally.

    Returns:
        A 4-tuple matching the interface outputs: figure of the selected
        points, PIL image with projected points, PIL image with SAM masks,
        and figure of the final segmented point cloud.

    Raises:
        gr.Error: if the scene or granularity is missing, or the coordinates
            cannot be parsed.
    """
    print(scene_name, coords)
    # Validate inputs before constructing SAM3DDemo so bad input fails fast.
    if not scene_name:
        raise gr.Error('Please select a scene.')
    if granularity is None or granularity == '':
        raise gr.Error('Please select a mask granularity.')
    coords = _parse_coords(coords)

    # NOTE(review): a new SAM3DDemo is built for every request. If its
    # constructor loads the checkpoint, caching one instance per scene would
    # speed this up -- only safe if run_with_coord does not mutate the
    # instance (not verified here).
    sam_3d = SAM3DDemo('vit_b', 'sam_vit_b_01ec64.pth', scene_name)
    data_point_select, rgb_img_w_points, rgb_img_w_masks, data_final = sam_3d.run_with_coord(
        coords, int(granularity)
    )
    return (
        pc_to_plot(data_point_select),
        Image.fromarray(rgb_img_w_points),
        Image.fromarray(rgb_img_w_masks),
        pc_to_plot(data_final),
    )


# Preload each scene's point cloud once at start-up for the example previews.
plydatas = []
for scene_name in SCENE_NAMES:
    plydata = PlyData.read(f"./scannet_data/{scene_name}/{scene_name}.ply")
    plydatas.append(plydata.elements[0].data)

# One example row per scene: (scene, granularity, anchor point, preview plot).
example_anchors = [[0, -2.5, 0.7], [0, -2.5, 1], [0, -2.5, 1]]
examples = [
    [name, 0, anchor, pc_to_plot(pc)]
    for name, anchor, pc in zip(SCENE_NAMES, example_anchors, plydatas)
]

title = 'Segment_Anything on 3D in-door point clouds'
description = """
Gradio Demo for Segment Anything on 3D indoor scenes (ScanNet supported). \n
The logic is straighforward: 1) Find a point in 3D; 2) project the 3D point to valid images; 3) perform 2D SAM on valid images; 4) reproject 2D results back to 3D; 5) Visualization.
Unfortunatly, it does not support click the point cloud to generate coordinates automatically.
You may want to write down the coordinates and put it manually. \n
"""
# NOTE(review): this article text names a different project than the demo
# title above -- confirm it is the intended reference/link.
article = """
Open-Vocabulary Semantic Segmentation with Mask-adapted CLIP | Github Repo
"""

demo = gr.Interface(
    inference,
    inputs=[
        gr.Dropdown(choices=SCENE_NAMES, label="Scannet scene name (limited scenes supported)"),
        # value=0 preselects a granularity so an untouched form never sends None.
        gr.Dropdown(choices=[0, 1, 2], value=0, label="Mask granularity from 0 (most coarse) to 2 (most precise)"),
        gr.Textbox(lines=1, label='Coordinates'),
        gr.Plot(label="Input Point cloud (For visualization and point finding only, click responce not supported yet.)"),
    ],
    outputs=[
        gr.Plot(label='Selected point(s): red points show the top 10 cloest points for your input anchor point'),
        gr.Image(label='Selected image with projected points'),
        gr.Image(label='Selected image processed after SAM'),
        gr.Plot(label='Output Point cloud: blue points represent the mask'),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch()