import os
import copy
import random

from PIL import Image
import torch
import matplotlib.pyplot as plt
import numpy as np
from plyfile import PlyData
from segment_anything import SamPredictor, sam_model_registry


def get_image_ids(path):
    files = os.listdir(path)
    # Keep only files (skip directories) and strip the extension.
    files = [f.split('.')[0] for f in files if os.path.isfile(path + '/' + f)]
    return sorted(files)


def load_align_matrix_from_txt(path):
    lines = open(path).readlines()
    # Test-set scenes don't have an axis-alignment matrix; fall back to identity.
    axis_align_matrix = np.eye(4)
    for line in lines:
        if 'axisAlignment' in line:
            axis_align_matrix = [
                float(x) for x in line.rstrip().strip('axisAlignment = ').split(' ')
            ]
            break
    axis_align_matrix = np.array(axis_align_matrix).reshape((4, 4))
    return axis_align_matrix


def load_matrix_from_txt(path, shape=(4, 4)):
    with open(path) as f:
        txt = f.readlines()
    txt = ''.join(txt).replace('\n', ' ')
    matrix = [float(v) for v in txt.split()]
    return np.array(matrix).reshape(shape)


def load_image(path):
    image = Image.open(path)
    return np.array(image)


def convert_from_uvd(u, v, d, intr, pose, align):
    # Back-project pixel (u, v) with raw depth d into the axis-aligned world frame.
    if d == 0:
        return None, None, None
    fx = intr[0, 0]
    fy = intr[1, 1]
    cx = intr[0, 2]
    cy = intr[1, 2]
    depth_scale = 1000  # depth images store millimetres
    z = d / depth_scale
    x = (u - cx) * z / fx
    y = (v - cy) * z / fy
    world = align @ pose @ np.array([x, y, z, 1])
    return world[:3] / world[3]


# Find the closest points in the cloud to the selected point.
def find_closest_point(point, point_cloud, num=1):
    # Euclidean distance between the query point and every point in the cloud.
    distances = np.linalg.norm(point_cloud - point, axis=1)
    # Indices of the `num` nearest points.
    closest_index = np.argsort(distances)[:num]
    # The nearest points themselves.
    closest_vector = point_cloud[closest_index]
    return closest_index, closest_vector


def plot_3d(xdata, ydata, zdata, color=None, b_min=2, b_max=8, view=(45, 45)):
    fig, ax = plt.subplots(subplot_kw={"projection": "3d"}, dpi=200)
    ax.view_init(view[0], view[1])
    ax.set_xlim(b_min, b_max)
    ax.set_ylim(b_min, b_max)
    ax.set_zlim(b_min, b_max)
    # `color` already holds per-point RGB triplets, so no colormap is needed.
    ax.scatter3D(xdata, ydata, zdata, c=color, s=0.1)
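# ---------------------------------------------------------------------------
# Minimal sketch of how convert_from_uvd and find_closest_point compose
# (illustration only; the identity pose/align matrices and the hypothetical
# `point_cloud` array are assumptions, not values from any particular scene):
#
#   intr = np.eye(4)
#   intr[0, 0] = intr[1, 1] = 577.0          # fx, fy
#   intr[0, 2], intr[1, 2] = 320.0, 240.0    # cx, cy
#   xyz = convert_from_uvd(320, 240, 1500, intr, np.eye(4), np.eye(4))
#   # -> roughly (0, 0, 1.5): the principal-point pixel at 1.5 m depth
#   idx, nearest = find_closest_point(np.asarray(xyz), point_cloud, num=5)
# ---------------------------------------------------------------------------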
class SAM3DDemo(object):
    def __init__(self, sam_model, sam_ckpt, scene_name):
        sam = sam_model_registry[sam_model](checkpoint=sam_ckpt)
        self.predictor = SamPredictor(sam)
        self.scene_name = scene_name
        scene_path = os.path.join('./scannet_data', scene_name)
        self.color_path = os.path.join(scene_path, 'color')
        self.depth_path = os.path.join(scene_path, 'depth')
        self.pose_path = os.path.join(scene_path, 'pose')
        self.intrinsic_path = os.path.join(scene_path, 'intrinsic')
        self.align_matrix_path = f'{scene_path}/{scene_name}.txt'
        self.img_ids = get_image_ids(self.color_path)
        self.align_matrix = load_align_matrix_from_txt(self.align_matrix_path)
        self.intrinsic_depth = load_matrix_from_txt(os.path.join(self.intrinsic_path, 'intrinsic_depth.txt'))
        self.poses = [load_matrix_from_txt(os.path.join(self.pose_path, f'{i}.txt')) for i in self.img_ids]
        self.rgb_images = [load_image(os.path.join(self.color_path, f'{i}.jpg')) for i in self.img_ids]
        self.depth_images = [load_image(os.path.join(self.depth_path, f'{i}.png')) for i in self.img_ids]

    def project_3D_to_images(self, select_points, valid_margin=20):
        valid_img_ids = []
        valid_points = {}
        for img_i in range(len(self.img_ids)):
            rgb_img = self.rgb_images[img_i]
            depth_img = self.depth_images[img_i]
            extrinsics = self.poses[img_i]
            projection_matrix = self.intrinsic_depth @ np.linalg.inv(self.align_matrix @ extrinsics)
            raw_points = np.vstack((select_points.T, np.ones((1, select_points.T.shape[1]))))
            raw_points = np.dot(projection_matrix, raw_points)
            # Simple bounds check: keep views where every selected point projects
            # inside the image (with a margin) and lies in front of the camera.
            points = raw_points[:2, :] / raw_points[2, :]
            points = np.round(points).astype(np.int32)
            valid = (points[0] >= valid_margin).all() & (points[1] >= valid_margin).all() \
                & (points[0] < (rgb_img.shape[1] - valid_margin)).all() & (points[1] < (rgb_img.shape[0] - valid_margin)).all() \
                & (raw_points[2, :] > 0).all()
            if valid:
                # Reject occluded views: the projected depth must stay within
                # +/- 20% of the depth recorded in the depth image.
                depth_margin = 0.4
                gt_depths = depth_img[points[1], points[0]] / 1000
                proj_depths = raw_points[2, :]
                if (proj_depths[0] > (1 - depth_margin / 2.0) * gt_depths[0]) & (proj_depths[0] < (1 + depth_margin / 2.0) * gt_depths[0]):
                    valid_img_ids.append(img_i)
                    valid_points[img_i] = points
        show_id = valid_img_ids[-1]
        show_points = valid_points[show_id]
        rgb_img = self.rgb_images[show_id]
        fig, ax = plt.subplots()
        ax.imshow(rgb_img)
        for x, y in zip(show_points[0], show_points[1]):
            ax.plot(x, y, 'ro')
        canvas = fig.canvas
        canvas.draw()
        w, h = canvas.get_width_height()
        rgb_img_w_points = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(h, w, 3)
        print("projected 3D points to images successfully...")
        return valid_img_ids, valid_points, rgb_img_w_points

    def process_img_w_sam(self, valid_img_ids, valid_points, granularity):
        mask_colors = []
        for img_i in range(len(self.img_ids)):
            rgb_img = self.rgb_images[img_i]
            msk_color = np.full(rgb_img.shape, 0.5)
            if img_i in valid_img_ids:
                self.predictor.set_image(rgb_img)
                point_coor = valid_points[img_i].T[0][None]
                masks, _, _ = self.predictor.predict(point_coords=point_coor, point_labels=np.array([1]))
                m = masks[granularity]
                msk_color[m] = [0, 0, 1.0]
            mask_colors.append(msk_color)
        # Visualize the three mask granularities for the last valid view
        # (`masks` still holds the prediction for show_id).
        show_id = valid_img_ids[-1]
        rgb_img = self.rgb_images[show_id]
        fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(24, 8))
        for i in range(3):
            mask_img = masks[i][:, :, None] * rgb_img
            axs[i].set_title(f'granularity {i}')
            axs[i].imshow(mask_img)
        canvas = fig.canvas
        canvas.draw()
        w, h = canvas.get_width_height()
        rgb_img_w_masks = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(h, w, 3)
        print("processed images with SAM successfully...")
        return mask_colors, rgb_img_w_masks

    def project_mask_to_3d(self, mask_colors, sample_ratio=0.002):
        x_data, y_data, z_data, c_data = [], [], [], []
        for img_i in range(len(self.img_ids)):
            # RGB-D frame and its camera pose.
            d = self.depth_images[img_i]
            c = self.rgb_images[img_i]
            p = self.poses[img_i]
            msk_color = mask_colors[img_i]
            # Back-project a random subsample of depth pixels into the point space
            # and carry the mask color along with each point.
            for i in range(d.shape[0]):
                for j in range(d.shape[1]):
                    if random.random() < sample_ratio:
                        x, y, z = convert_from_uvd(j, i, d[i, j], self.intrinsic_depth, p, self.align_matrix)
                        if x is None:
                            continue
                        x_data.append(x)
                        y_data.append(y)
                        z_data.append(z)
                        ci = int(i * c.shape[0] / d.shape[0])
                        cj = int(j * c.shape[1] / d.shape[1])
                        c_data.append([msk_color[ci, cj]])
        print("reprojected images to 3D points successfully...")
        return x_data, y_data, z_data, c_data
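    # ------------------------------------------------------------------
    # Illustration only (the numbers below are made up, not from a scene):
    # the method that follows matches the reprojected points to the
    # ground-truth cloud by quantizing both into 0.2 m voxels, e.g.
    #   np.floor_divide(np.array([1.23, 0.58, -0.07]), 0.2)
    #   # -> array([ 6.,  2., -1.])
    # Points falling in the same voxel are treated as the same location, so
    # each ground-truth point inherits the mean mask color of the reprojected
    # points that share its voxel.
    # ------------------------------------------------------------------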
    def match_projected_point_to_gt_point(self, x_data, y_data, z_data, c_data, gt_coords):
        c_data = torch.tensor(np.concatenate(c_data, axis=0))
        img_coords = np.array([x_data, y_data, z_data], dtype=np.float32).T
        gt_quant_coords = np.floor_divide(gt_coords, 0.2)
        img_quant_coords = np.floor_divide(img_coords, 0.2)
        # Remove the redundant coords.
        unique_gt_coords, gt_inverse_indices = np.unique(gt_quant_coords, axis=0, return_inverse=True)
        unique_img_coords, img_inverse_indices = np.unique(img_quant_coords, axis=0, return_inverse=True)

        # Match the coords in gt_coords to img_coords.
        def find_loc(vec):
            obj = np.empty((), dtype=object)
            out = np.where((unique_img_coords == vec).all(1))[0]
            obj[()] = out
            return obj

        gt_2_img_map = np.apply_along_axis(find_loc, 1, unique_gt_coords)

        # Since some voxels have no match, fall back to the most recent match
        # (a simple nearest-previous interpolation).
        gt_2_img_map_filled = []
        start_id = np.array([0])
        for loc in gt_2_img_map:
            if not np.any(loc):
                loc = start_id
            else:
                start_id = loc
            gt_2_img_map_filled.append(int(loc))

        mean_colors = []
        for i in range(unique_img_coords.shape[0]):
            valid_locs = np.where(img_inverse_indices == i)
            mean_f = torch.mean(c_data[valid_locs], axis=0)
            # mean_f, _ = torch.mode(c_data[valid_locs], dim=0)
            mean_colors.append(mean_f.unsqueeze(0))
        mean_colors = torch.cat(mean_colors)

        # Project the averaged features back to the ground-truth point cloud.
        img_2_gt_colors = mean_colors[gt_2_img_map_filled]
        projected_gt_colors = img_2_gt_colors[gt_inverse_indices]
        print("converted projected points to GT points successfully...")
        return projected_gt_colors

    def render_point_cloud(self, data, color):
        data_copy = copy.copy(data)
        uint_color = torch.round(torch.tensor(color) * 255).to(torch.uint8)
        data_copy['red'] = uint_color[:, 0]
        data_copy['green'] = uint_color[:, 1]
        data_copy['blue'] = uint_color[:, 2]
        return data_copy

    def run_with_coord(self, point, granularity):
        plydata = PlyData.read(f"./scannet_data/{self.scene_name}/{self.scene_name}.ply")
        data = plydata.elements[0].data
        # gt_coords holds the ground-truth point cloud coordinates.
        gt_coords = np.array([data['x'], data['y'], data['z']], dtype=np.float32).T
        gt_color = np.array([data['red'], data['green'], data['blue']], dtype=np.float32).T
        blank_color = np.full(gt_color.shape, 0.5)
        select_index, select_points = find_closest_point(point, gt_coords, num=10)
        point_select_color = blank_color.copy()
        point_select_color[select_index] = [1.0, 0, 0]
        data_point_select = self.render_point_cloud(data, point_select_color)
        valid_img_ids, valid_points, rgb_img_w_points = self.project_3D_to_images(select_points)
        mask_colors, rgb_img_w_masks = self.process_img_w_sam(valid_img_ids, valid_points, granularity)
        x_data, y_data, z_data, c_data = self.project_mask_to_3d(mask_colors)
        projected_gt_colors = self.match_projected_point_to_gt_point(x_data, y_data, z_data, c_data, gt_coords)
        data_final = self.render_point_cloud(data, projected_gt_colors)
        return data_point_select, rgb_img_w_points, rgb_img_w_masks, data_final
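# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original demo): the checkpoint path,
# scene name, query coordinate, granularity and output filename below are
# placeholders to replace with your own setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo = SAM3DDemo(
        sam_model="vit_h",                    # key into sam_model_registry
        sam_ckpt="./sam_vit_h_4b8939.pth",    # assumed SAM checkpoint location
        scene_name="scene0000_00",            # assumed ScanNet scene id
    )
    # Query a 3D coordinate in the axis-aligned scene frame and pick a SAM
    # mask granularity (0, 1, or 2).
    data_point_select, img_points, img_masks, data_final = demo.run_with_coord(
        np.array([2.0, 2.0, 1.0]), granularity=2
    )
    # The returned structured arrays can be written back to PLY for viewing,
    # e.g. with plyfile (output path is a placeholder):
    # from plyfile import PlyElement
    # PlyData([PlyElement.describe(data_final, 'vertex')]).write('segmented.ply')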