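"""Demo that lifts 2D masks from Segment Anything (SAM) onto a ScanNet 3D
point cloud: a user-selected 3D point is projected into the scene's RGB-D
frames, SAM is prompted at the projected pixels, and the predicted masks are
re-projected and matched back onto the ground-truth point cloud."""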
import os
import copy
import random
from PIL import Image
import torch
import matplotlib.pyplot as plt
import numpy as np
from plyfile import PlyData
from segment_anything import SamPredictor, sam_model_registry
def get_image_ids(path):
files = os.listdir(path)
    files = [f.split('.')[0] for f in files if os.path.isfile(os.path.join(path, f))]  # keep only regular files
return sorted(files)
def load_align_matrix_from_txt(path):
lines = open(path).readlines()
# test set data doesn't have align_matrix
axis_align_matrix = np.eye(4)
for line in lines:
if 'axisAlignment' in line:
            axis_align_matrix = [
                float(x)
                for x in line.split('=')[1].strip().split()
            ]
break
axis_align_matrix = np.array(axis_align_matrix).reshape((4, 4))
return axis_align_matrix
def load_matrix_from_txt(path, shape=(4, 4)):
with open(path) as f:
txt = f.readlines()
txt = ''.join(txt).replace('\n', ' ')
matrix = [float(v) for v in txt.split()]
return np.array(matrix).reshape(shape)
def load_image(path):
image = Image.open(path)
return np.array(image)
def convert_from_uvd(u, v, d, intr, pose, align):
    # Back-project pixel (u, v) with raw depth d into the axis-aligned world
    # frame using the depth intrinsics, the camera-to-world pose, and the
    # scene alignment matrix. Depth is assumed to be in millimeters.
    if d == 0:
        return None, None, None
    fx = intr[0, 0]
    fy = intr[1, 1]
    cx = intr[0, 2]
    cy = intr[1, 2]
    depth_scale = 1000
    z = d / depth_scale
    x = (u - cx) * z / fx
    y = (v - cy) * z / fy
world = (align @ pose @ np.array([x, y, z, 1]))
return world[:3] / world[3]
# Find the closest point(s) in the point cloud to the selected point.
def find_closest_point(point, point_cloud, num=1):
    # Euclidean distance between the query point and every point in the cloud.
    distances = np.linalg.norm(point_cloud - point, axis=1)
    # Indices of the `num` nearest points.
    closest_index = np.argsort(distances)[:num]
    # The corresponding coordinates.
    closest_vector = point_cloud[closest_index]
return closest_index, closest_vector
def plot_3d(xdata, ydata, zdata, color=None, b_min=2, b_max=8, view=(45, 45)):
fig, ax = plt.subplots(subplot_kw={"projection": "3d"}, dpi=200)
ax.view_init(view[0], view[1])
ax.set_xlim(b_min, b_max)
ax.set_ylim(b_min, b_max)
ax.set_zlim(b_min, b_max)
    # `color` is expected to hold per-point RGB triples rather than scalar
    # values, so no colormap is passed ('rgb' is not a valid colormap name).
    ax.scatter3D(xdata, ydata, zdata, c=color, s=0.1)
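# SAM3DDemo wires the full pipeline together: given a 3D query point it
# (1) projects the nearest ground-truth points into the scene's RGB-D frames,
# (2) prompts SAM with a projected point in each valid frame to get 2D masks,
# (3) back-projects the masked pixels into 3D, and
# (4) transfers the mask colors onto the ground-truth point cloud.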
class SAM3DDemo(object):
def __init__(self, sam_model, sam_ckpt, scene_name):
sam = sam_model_registry[sam_model](checkpoint=sam_ckpt)
self.predictor = SamPredictor(sam)
self.scene_name = scene_name
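        # Expects a preprocessed ScanNet scene under ./scannet_data/<scene_name>
        # with color/, depth/, pose/ and intrinsic/ folders plus the
        # <scene_name>.txt meta file that holds the axis-alignment matrix.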
scene_path = os.path.join('./scannet_data', scene_name)
self.color_path = os.path.join(scene_path, 'color')
self.depth_path = os.path.join(scene_path, 'depth')
self.pose_path = os.path.join(scene_path, 'pose')
self.intrinsic_path = os.path.join(scene_path, 'intrinsic')
        self.align_matrix_path = f'{scene_path}/{scene_name}.txt'
        self.img_ids = get_image_ids(self.color_path)
        self.align_matrix = load_align_matrix_from_txt(self.align_matrix_path)
        self.intrinsic_depth = load_matrix_from_txt(os.path.join(self.intrinsic_path, 'intrinsic_depth.txt'))
        self.poses = [load_matrix_from_txt(os.path.join(self.pose_path, f'{i}.txt')) for i in self.img_ids]
        self.rgb_images = [load_image(os.path.join(self.color_path, f'{i}.jpg')) for i in self.img_ids]
        self.depth_images = [load_image(os.path.join(self.depth_path, f'{i}.png')) for i in self.img_ids]
def project_3D_to_images(self, select_points, valid_margin=20):
valid_img_ids = []
valid_points = {}
for img_i in range(len(self.img_ids)):
rgb_img = self.rgb_images[img_i]
depth_img = self.depth_images[img_i]
extrinsics = self.poses[img_i]
projection_matrix = self.intrinsic_depth @ np.linalg.inv(self.align_matrix @ extrinsics)
raw_points = np.vstack((select_points.T, np.ones((1, select_points.T.shape[1]))))
raw_points = np.dot(projection_matrix, raw_points)
            # Keep the frame only if every projected point falls inside the
            # image (with a safety margin) and lies in front of the camera.
            points = raw_points[:2, :] / raw_points[2, :]
            points = np.round(points).astype(np.int32)
            valid = (points[0] >= valid_margin).all() & (points[1] >= valid_margin).all() \
                & (points[0] < (rgb_img.shape[1] - valid_margin)).all() & (points[1] < (rgb_img.shape[0] - valid_margin)).all() \
                & (raw_points[2, :] > 0).all()
            if valid:
                # Occlusion check: the projected depth of the first selected
                # point must roughly agree with the sensor depth at that pixel
                # (color and depth frames are assumed to share a resolution).
                depth_margin = 0.4
                gt_depths = depth_img[points[1], points[0]] / 1000
                proj_depths = raw_points[2, :]
                if (proj_depths[0] > (1 - depth_margin / 2.0) * gt_depths[0]) & (proj_depths[0] < (1 + depth_margin / 2.0) * gt_depths[0]):
valid_img_ids.append(img_i)
valid_points[img_i] = points
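        # Visualize the projected prompt points on the last valid frame
        # (assumes at least one frame passed the checks above).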
show_id = valid_img_ids[-1]
show_points = valid_points[show_id]
rgb_img = self.rgb_images[show_id]
fig, ax = plt.subplots()
ax.imshow(rgb_img)
for x, y in zip(show_points[0], show_points[1]):
ax.plot(x, y, 'ro')
canvas = fig.canvas
canvas.draw()
w, h = canvas.get_width_height()
rgb_img_w_points = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(h, w, 3)
print("projecting 3D point to images successfully...")
return valid_img_ids, valid_points, rgb_img_w_points
def process_img_w_sam(self, valid_img_ids, valid_points, granularity):
mask_colors = []
for img_i in range(len(self.img_ids)):
rgb_img = self.rgb_images[img_i]
msk_color = np.full(rgb_img.shape, 0.5)
if img_i in valid_img_ids:
self.predictor.set_image(rgb_img)
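                # Prompt SAM with only the first projected point, labeled as a
                # positive (foreground) point.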
point_coor = valid_points[img_i].T[0][None]
masks, _, _ = self.predictor.predict(point_coords=point_coor, point_labels=np.array([1]))
# fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(10, 5))
# for i in range(3):
# mask_img = masks[i][:,:,None] * rgb_img
# axs[i].set_title(f'granularity {i}')
# axs[i].imshow(mask_img)
m = masks[granularity]
msk_color[m] = [0, 0, 1.0]
mask_colors.append(msk_color)
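        # `masks` still holds the prediction for the last valid frame, which is
        # valid_img_ids[-1] because frames are processed in ascending order.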
show_id = valid_img_ids[-1]
rgb_img = self.rgb_images[show_id]
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(24, 8))
for i in range(3):
mask_img = masks[i][:,:,None] * rgb_img
axs[i].set_title(f'granularity {i}')
axs[i].imshow(mask_img)
canvas = fig.canvas
canvas.draw()
w, h = canvas.get_width_height()
rgb_img_w_masks = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(h, w, 3)
print("processing images with SAM successfully...")
return mask_colors, rgb_img_w_masks
def project_mask_to_3d(self, mask_colors, sample_ratio=0.002):
x_data, y_data, z_data, c_data = [], [], [], []
        for img_i in range(len(self.img_ids)):
            # Per-frame depth, color, camera pose and SAM mask color map.
            d = self.depth_images[img_i]
            c = self.rgb_images[img_i]
            p = self.poses[img_i]
            msk_color = mask_colors[img_i]
            # Back-project a random subset (sample_ratio) of depth pixels into
            # the aligned world frame and attach the SAM mask color.
for i in range(d.shape[0]):
for j in range(d.shape[1]):
if random.random() < sample_ratio:
x, y, z = convert_from_uvd(j, i, d[i, j], self.intrinsic_depth, p, self.align_matrix)
if x is None:
continue
x_data.append(x)
y_data.append(y)
z_data.append(z)
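                        # Map the depth pixel onto the (possibly differently
                        # sized) color/mask image by nearest-neighbor scaling.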
ci = int(i * c.shape[0] / d.shape[0])
cj = int(j * c.shape[1] / d.shape[1])
c_data.append([msk_color[ci, cj]])
print("reprojecting images to 3D points successfully...")
return x_data, y_data, z_data, c_data
def match_projected_point_to_gt_point(self, x_data, y_data, z_data, c_data, gt_coords):
c_data = torch.tensor(np.concatenate(c_data, axis=0))
img_coords = np.array([x_data, y_data, z_data], dtype=np.float32).T
        # Quantize both point sets into 0.2 m voxels so they can be matched.
        gt_quant_coords = np.floor_divide(gt_coords, 0.2)
        img_quant_coords = np.floor_divide(img_coords, 0.2)
        # Remove the redundant coords
        unique_gt_coords, gt_inverse_indices = np.unique(gt_quant_coords, axis=0, return_inverse=True)
        unique_img_coords, img_inverse_indices = np.unique(img_quant_coords, axis=0, return_inverse=True)
        # Match the coords in gt_coords to img_coords
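        # find_loc returns the index of the matching row in unique_img_coords
        # (at most one, since rows are unique), wrapped in a 0-d object array
        # so that np.apply_along_axis can hold empty results as well.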
def find_loc(vec):
obj = np.empty((), dtype=object)
out = np.where((unique_img_coords == vec).all(1))[0]
obj[()] = out
return obj
gt_2_img_map = np.apply_along_axis(find_loc, 1, unique_gt_coords)
        # Some ground-truth voxels receive no projected point; fall back to
        # the most recent non-empty match (simple nearest-previous fill).
        gt_2_img_map_filled = []
        start_id = np.array([0])
        for loc in gt_2_img_map:
            if loc.size == 0:
                loc = start_id
            else:
                start_id = loc
            gt_2_img_map_filled.append(int(loc[0]))
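        # Average the mask colors of all projected points that fall into the
        # same voxel of the image-side point set.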
mean_colors = []
for i in range(unique_img_coords.shape[0]):
valid_locs = np.where(img_inverse_indices == i)
mean_f = torch.mean(c_data[valid_locs], axis=0)
# mean_f, _ = torch.mode(c_data[valid_locs], dim=0)
mean_colors.append(mean_f.unsqueeze(0))
mean_colors = torch.cat(mean_colors)
        # Project the averaged colors back onto the ground-truth point cloud.
        img_2_gt_colors = mean_colors[gt_2_img_map_filled]
        projected_gt_colors = img_2_gt_colors[gt_inverse_indices]
        print("Matched projected points to GT points successfully.")
return projected_gt_colors
def render_point_cloud(self, data, color):
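        # Write per-point colors (floats in [0, 1]) into a copy of the PLY
        # vertex data as 8-bit red/green/blue channels.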
data_copy = copy.copy(data)
uint_color = torch.round(torch.tensor(color) * 255).to(torch.uint8)
data_copy['red'] = uint_color[:, 0]
data_copy['green'] = uint_color[:, 1]
data_copy['blue'] = uint_color[:, 2]
return data_copy
    def run_with_coord(self, point, granularity):
        plydata = PlyData.read(f"./scannet_data/{self.scene_name}/{self.scene_name}.ply")
        data = plydata.elements[0].data
        # gt_coords holds the ground-truth point cloud coordinates.
        gt_coords = np.array([data['x'], data['y'], data['z']], dtype=np.float32).T
        gt_color = np.array([data['red'], data['green'], data['blue']], dtype=np.float32).T
blank_color = np.full(gt_color.shape, 0.5)
select_index, select_points = find_closest_point(point, gt_coords, num=10)
point_select_color = blank_color.copy()
point_select_color[select_index] = [1.0, 0, 0]
data_point_select = self.render_point_cloud(data, point_select_color)
valid_img_ids, valid_points, rgb_img_w_points = self.project_3D_to_images(select_points)
mask_colors, rgb_img_w_masks = self.process_img_w_sam(valid_img_ids, valid_points, granularity)
x_data, y_data, z_data, c_data = self.project_mask_to_3d(mask_colors)
projected_gt_colors = self.match_projected_point_to_gt_point(x_data, y_data, z_data, c_data, gt_coords)
data_final = self.render_point_cloud(data, projected_gt_colors)
return data_point_select, rgb_img_w_points, rgb_img_w_masks, data_final
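# Minimal usage sketch: the checkpoint path, scene name and query coordinate
# below are illustrative assumptions, not part of the original script; any
# preprocessed ScanNet scene placed under ./scannet_data should work.
if __name__ == "__main__":
    demo = SAM3DDemo('vit_h', 'sam_vit_h_4b8939.pth', 'scene0000_00')
    # Query a 3D point in the axis-aligned world frame and pick one of the
    # three SAM mask proposals (granularity 0-2).
    query_point = np.array([1.0, 1.0, 1.0], dtype=np.float32)
    data_point_select, img_w_points, img_w_masks, data_final = \
        demo.run_with_coord(query_point, granularity=2)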