JasonYinnnn's picture
init
afea36f
# SPDX-FileCopyrightText: 2026 Ze-Xin Yin, Robot labs of Horizon Robotics, and D-Robotics
# SPDX-License-Identifier: Apache-2.0
# See the LICENSE file in the project root for full license information.
import os
import json
import cv2
import torch
from PIL import Image
import imageio
import numpy as np
import open3d as o3d
from einops import rearrange
def voxelize_mesh(points, faces, clip_range_first=False, return_mask=True, resolution=64):
    """Voxelize a triangle mesh inside the unit cube [-0.5, 0.5]^3.

    Args:
        points: (V, 3) vertex positions accepted by ``o3d.utility.Vector3dVector``.
        faces: (F, 3) triangle vertex indices, or an existing Open3D
            ``Vector3iVector``.
        clip_range_first: if True, clamp vertices just inside the unit cube first.
        return_mask: if True, also return the occupancy with each 4x4x4 block of
            voxels moved onto the channel axis.
        resolution: voxels per axis (R). Must be divisible by 4 when
            ``return_mask`` is True.

    Returns:
        ``ss``: (1, R, R, R) long occupancy grid; and, if ``return_mask``,
        ``ss_mask``: (64, R/4, R/4, R/4) float tensor.
    """
    if clip_range_first:
        points = np.clip(points, -0.5 + 1e-6, 0.5 - 1e-6)
    mesh = o3d.geometry.TriangleMesh()
    mesh.vertices = o3d.utility.Vector3dVector(points)
    # `o3d.utility` aliases the backend-specific pybind module (CPU or CUDA build);
    # `o3d.cuda.pybind...` only exists in CUDA builds of Open3D.
    if isinstance(faces, o3d.utility.Vector3iVector):
        mesh.triangles = faces
    else:
        mesh.triangles = o3d.utility.Vector3iVector(faces)
    # Voxelize directly at the requested resolution (previously hard-coded to 64).
    voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(
        mesh, voxel_size=1 / resolution,
        min_bound=(-0.5, -0.5, -0.5), max_bound=(0.5, 0.5, 0.5))
    # reshape keeps a (0, 3) shape when no voxel is hit, so the indexing below still works.
    grid_idx = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()],
                        dtype=np.int64).reshape(-1, 3)
    assert np.all(grid_idx >= 0) and np.all(grid_idx < resolution), "Some vertices are out of bounds"
    coords = torch.from_numpy(grid_idx)
    ss = torch.zeros(1, resolution, resolution, resolution, dtype=torch.long)
    ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
    if return_mask:
        ss_mask = rearrange(ss, 'c (x n1) (y n2) (z n3) -> (n1 n2 n3 c) x y z', n1=4, n2=4, n3=4).float()
        return ss, ss_mask
    return ss
def transform_vertices(vertices, ops, params):
    """Apply a sequence of elementwise affine operations to ``vertices``.

    Args:
        vertices: array-like supporting ``+`` and ``*`` (numpy array or tensor).
        ops: operation names, each ``'scale'`` or ``'translation'``.
        params: operands paired position-by-position with ``ops``; each is
            multiplied in (scale) or added (translation). Extra entries in the
            longer of ``ops``/``params`` are ignored.

    Returns:
        The transformed vertices (the input object itself if no op is applied).

    Raises:
        NotImplementedError: for any other operation name.
    """
    result = vertices
    for name, operand in zip(ops, params):
        if name == 'translation':
            result = result + operand
        elif name == 'scale':
            result = result * operand
        else:
            raise NotImplementedError
    return result
def normalize_vertices(vertices, scale_factor=1.0):
    """Center ``vertices`` on their bounding box and divide by its longest side.

    With ``scale_factor=1.0`` the result spans roughly [-0.5, 0.5] along the
    longest axis; ``scale_factor=2.0`` gives roughly [-0.25, 0.25].

    Returns:
        (normalized, trans_pos, scale_pos): ``trans_pos`` is the (1, D) bbox
        center that was subtracted; ``scale_pos`` is the longest bbox side times
        ``scale_factor`` (the result is divided by ``scale_pos + 1e-6``).
    """
    lo = np.min(vertices, axis=0)
    hi = np.max(vertices, axis=0)
    center = (lo + hi)[None] / 2.0
    extent = np.max(hi - lo) * scale_factor
    normalized = (vertices + (-center)) * (1.0 / (extent + 1e-6))
    return normalized, center, extent
def renormalize_vertices(vertices, val_range=0.5, scale_factor=1.25):
    """Re-center and shrink ``vertices`` only if they leave [-val_range, val_range].

    Vertices already inside the range are returned unchanged (same object).
    Otherwise they are centered on their bounding box and divided by
    ``longest_side * scale_factor + 1e-6``, so the longest side becomes about
    ``1 / scale_factor`` (0.8 for the default 1.25).
    """
    lo = np.min(vertices, axis=0)
    hi = np.max(vertices, axis=0)
    out_of_range = (lo < -val_range).any() or (hi > val_range).any()
    if not out_of_range:
        return vertices
    center = (lo + hi)[None] / 2.0
    extent = np.max(hi - lo) * scale_factor
    return (vertices + (-center)) * (1.0 / (extent + 1e-6))
def rot_vertices(vertices, rot_angles, axis_list=('z',)):
    """Rotate ``vertices`` about the origin by successive single-axis rotations (Open3D).

    Args:
        vertices: (N, 3) points accepted by ``o3d.utility.Vector3dVector``.
        rot_angles: rotation angles in radians, paired element-wise with ``axis_list``.
        axis_list: axis name per angle (``'x'``, ``'y'`` or ``'z'``); rotations are
            applied in this order. Surplus entries in the longer sequence are
            ignored (``zip`` semantics). A tuple default avoids a shared mutable list.

    Returns:
        (N, 3) numpy array of rotated points.

    Raises:
        NotImplementedError: for an unknown axis name.
    """
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(vertices)
    for ang, axis in zip(rot_angles, axis_list):
        if axis == 'x':
            euler = (ang, 0, 0)
        elif axis == 'y':
            euler = (0, ang, 0)
        elif axis == 'z':
            euler = (0, 0, ang)
        else:
            raise NotImplementedError(f"Unknown axis {axis}")
        R = pcd.get_rotation_matrix_from_xyz(euler)
        pcd.rotate(R, center=(0., 0., 0.))
    return np.array(pcd.points)
def _rotmat_x(a: torch.Tensor) -> torch.Tensor:
    """Return the 3x3 rotation matrix about the x axis for angle ``a`` (radians).

    ``a`` is expected to be a 0-dim tensor; the entries are built from it so the
    result follows its device.
    """
    c, s = torch.cos(a), torch.sin(a)
    zero, one = torch.zeros_like(a), torch.ones_like(a)
    rows = [
        [one, zero, zero],
        [zero, c, -s],
        [zero, s, c],
    ]
    return torch.stack([torch.stack(r) for r in rows])  # [3, 3]
def _rotmat_y(a: torch.Tensor) -> torch.Tensor:
    """Return the 3x3 rotation matrix about the y axis for angle ``a`` (radians).

    ``a`` is expected to be a 0-dim tensor.
    """
    c, s = torch.cos(a), torch.sin(a)
    zero, one = torch.zeros_like(a), torch.ones_like(a)
    rows = [
        [c, zero, s],
        [zero, one, zero],
        [-s, zero, c],
    ]
    return torch.stack([torch.stack(r) for r in rows])  # [3, 3]
def _rotmat_z(a: torch.Tensor) -> torch.Tensor:
    """Return the 3x3 rotation matrix about the z axis for angle ``a`` (radians).

    ``a`` is expected to be a 0-dim tensor.
    """
    c, s = torch.cos(a), torch.sin(a)
    zero, one = torch.zeros_like(a), torch.ones_like(a)
    rows = [
        [c, -s, zero],
        [s, c, zero],
        [zero, zero, one],
    ]
    return torch.stack([torch.stack(r) for r in rows])  # [3, 3]
def rot_vertices_torch(vertices, rot_angles, axis_list=('z',), center=(0.0, 0.0, 0.0)):
    """Torch counterpart of the Open3D-based ``rot_vertices``.

    Args:
        vertices: (N, 3) numpy array or torch tensor. Non-floating inputs are
            promoted to torch's default float dtype; otherwise the angles would be
            cast to integers and the matmul would fail on mixed dtypes.
        rot_angles: iterable of angles in radians, paired element-wise with ``axis_list``.
        axis_list: iterable like ``['x', 'y', 'z']``; rotations are applied in order.
        center: rotation center, default origin (0, 0, 0), as in the Open3D version.

    Returns:
        torch.Tensor (N, 3) on the same device as ``vertices``.

    Raises:
        NotImplementedError: for an unknown axis name.
    """
    v = torch.as_tensor(vertices)
    if not torch.is_floating_point(v):
        v = v.to(torch.get_default_dtype())
    device, dtype = v.device, v.dtype
    c = torch.tensor(center, device=device, dtype=dtype).view(1, 3)
    v = v - c  # translate to center
    for ang, axis in zip(rot_angles, axis_list):
        a = torch.as_tensor(ang, device=device, dtype=dtype)
        if axis == 'x':
            R = _rotmat_x(a)
        elif axis == 'y':
            R = _rotmat_y(a)
        elif axis == 'z':
            R = _rotmat_z(a)
        else:
            raise NotImplementedError(f"Unknown axis {axis}")
        # Open3D rotates column vectors (R @ p); for row-vector points that is v @ R^T.
        v = v @ R.T
    return v + c
def get_instance_mask(instance_mask_path):
    """Load an instance-index image and decode it to integer instance ids.

    Raw pixel values are rescaled from [0, 65535] to [0, 100] and rounded; the
    maximum of 100 objects is hard-coded.

    Returns:
        (index_mask, instance_list): per-pixel rounded ids (float32 array) and
        the sorted unique ids cast to uint8.
    """
    raw = imageio.v3.imread(instance_mask_path)
    scaled = raw.astype(np.float32) / 65535 * 100.0  # hand coded, max obj nums = 100
    index_mask = np.rint(scaled)
    instance_list = np.unique(index_mask).astype(np.uint8)
    return index_mask, instance_list
def get_gt_depth(gt_depth_path, metadata):
    """Load a 16-bit encoded depth image and map it back to the stored depth range.

    Pixel values are normalized by 65535 and then mapped linearly onto
    ``[metadata['depth']['min'], metadata['depth']['max']]``.

    Returns:
        torch.float32 tensor of decoded depth.
    """
    normalized = imageio.v3.imread(gt_depth_path).astype(np.float32) / 65535.
    d_lo = metadata['depth']['min']
    d_hi = metadata['depth']['max']
    depth = normalized * (d_hi - d_lo) + d_lo
    return torch.from_numpy(depth).to(dtype=torch.float32)
def get_est_depth(est_depth_path):
    """Load an estimated depth map and its validity mask from an ``.npz`` archive.

    The archive must contain ``'depth'`` and ``'mask'`` arrays. NaN/Inf depths are
    replaced by 0 and removed from the mask.

    Returns:
        (est_depth, est_depth_mask): torch.float32 depth tensor and numpy bool mask.
    """
    # NpzFile keeps the archive open until closed; the context manager releases it.
    with np.load(est_depth_path) as npz:
        est_depth = npz['depth']
        est_depth_mask = npz['mask']
    est_depth = torch.from_numpy(est_depth).to(dtype=torch.float32)
    invalid_mask = torch.logical_or(torch.isnan(est_depth), torch.isinf(est_depth))
    est_depth_mask = np.logical_and(est_depth_mask, ~invalid_mask.detach().cpu().numpy())
    est_depth = torch.where(invalid_mask, 0.0, est_depth)
    return est_depth, est_depth_mask
def _sanitize_est_depth(est_depth, est_depth_mask):
    """Cast depth to a float32 tensor, zero its NaN/Inf entries and clear them from the mask.

    Args:
        est_depth: depth map as a numpy array or torch tensor.
        est_depth_mask: numpy boolean mask broadcast-compatible with ``est_depth``.

    Returns:
        (torch.float32 depth with non-finite values set to 0, numpy bool mask).
    """
    est_depth = torch.as_tensor(est_depth).to(dtype=torch.float32)
    invalid_mask = torch.logical_or(torch.isnan(est_depth), torch.isinf(est_depth))
    est_depth_mask = np.logical_and(est_depth_mask, ~invalid_mask.detach().cpu().numpy())
    est_depth = torch.where(invalid_mask, 0.0, est_depth)
    return est_depth, est_depth_mask


def get_mix_est_depth(est_depth_path, image_size):
    """Load an estimated depth map, dispatching on the estimator name in the path.

    Supported sources (checked in this order):
      * ``'MoGe'``: npz with ``'depth'`` and ``'mask'``.
      * ``'DAv2_'`` or ``'ml-depth-pro'``: npz with ``'depth'``; the mask is the
        finite entries.
      * ``'VGGT_1B'``: npz with ``'depth'`` and ``'depth_conf'``. The mask is
        ``depth_conf > 2.0`` and finite. Depth and mask are nearest-resized to
        ``image_size`` x ``image_size``.

    Returns:
        (est_depth, est_depth_mask): torch.float32 depth (non-finite values set to
        0) and numpy bool mask.

    Raises:
        ValueError: if the path matches none of the supported sources (previously
            the function silently returned None).
    """
    if 'MoGe' in est_depth_path:
        with np.load(est_depth_path) as npz:
            est_depth = npz['depth']
            est_depth_mask = npz['mask']
        return _sanitize_est_depth(est_depth, est_depth_mask)

    if 'DAv2_' in est_depth_path or 'ml-depth-pro' in est_depth_path:
        with np.load(est_depth_path) as npz:
            est_depth = npz['depth']
        return _sanitize_est_depth(est_depth, np.isfinite(est_depth))

    if 'VGGT_1B' in est_depth_path:
        with np.load(est_depth_path) as npz:
            est_depth = npz['depth']
            depth_conf = npz['depth_conf']
        valid_depth_mask = np.isfinite(est_depth)
        est_depth_mask = np.logical_and(depth_conf > 2.0, valid_depth_mask)
        est_depth = np.where(valid_depth_mask, est_depth, 0.0)
        # Resize a [0, 1]-normalized copy with PIL, then restore the original range.
        depth_min, depth_max = np.min(est_depth), np.max(est_depth)
        est_depth = (est_depth - depth_min) / (depth_max - depth_min + 1e-6)
        est_depth = Image.fromarray(est_depth).resize((image_size, image_size), Image.Resampling.NEAREST)
        est_depth = torch.tensor(np.array(est_depth)).to(dtype=torch.float32)
        est_depth = est_depth * (depth_max - depth_min) + depth_min
        est_depth_mask = Image.fromarray(est_depth_mask.astype(np.float32)).resize(
            (image_size, image_size), Image.Resampling.NEAREST)
        est_depth_mask = np.array(est_depth_mask) > 0.5
        return _sanitize_est_depth(est_depth, est_depth_mask)

    raise ValueError(f"Unrecognized depth estimator in path: {est_depth_path}")
def lstsq_align_depth(est_depth, gt_depth, mask):
    """Scale ``est_depth`` by the least-squares factor that best matches ``gt_depth``.

    Only pixels where ``mask`` is nonzero are used to fit the scalar ``s``
    minimizing ||s * est - gt||^2. If the mask selects no pixel, ``s = 1``.

    Returns:
        A new tensor ``est_depth * s``.
    """
    coords = torch.nonzero(mask)
    if coords.shape[0] == 0:
        return est_depth * 1.0
    rows, cols = coords[:, 0], coords[:, 1]
    A = est_depth[rows, cols][None, :, None]
    B = gt_depth[rows, cols][None, :, None]
    scale = torch.linalg.lstsq(A, B).solution.item()
    return est_depth * scale
def get_cam_poses(frame_info, H, W):
    """Build pinhole intrinsics and the camera-to-world matrix for one frame.

    The focal length is derived from the horizontal field of view
    ``frame_info['camera_angle_x']`` and the width ``W``. It is used for both axes,
    and the principal point is the image center.

    Returns:
        (K, c2w): float32 tensors; ``K`` is 3x3 and ``c2w`` has the shape of
        ``frame_info['transform_matrix']`` (presumably 4x4).
    """
    fov_x = float(frame_info['camera_angle_x'])
    f = .5 * W / np.tan(.5 * fov_x)
    cx, cy = 0.5 * W, 0.5 * H
    intrinsics = np.array([[f, 0, cx], [0, f, cy], [0, 0, 1]])
    K = torch.from_numpy(intrinsics).float()
    c2w = torch.from_numpy(np.array(frame_info['transform_matrix'])).float()
    return K, c2w
def edge_mask_morph_gradient(mask, kernel, iterations=1):
    """Return the morphological gradient (dilation minus erosion) of a binary mask.

    Args:
        mask: HxW numpy array (bool/uint8); pixels whose uint8 value is > 0 are
            foreground.
        kernel: structuring element passed to ``cv2.dilate``/``cv2.erode``; a larger
            kernel gives a thicker edge band.
        iterations: number of dilation/erosion iterations.

    Returns:
        HxW uint8 array, 1 on the edge band and 0 elsewhere. Out-of-image pixels are
        treated as background (constant border 0), so foreground touching the
        image border may also be marked as edge.
    """
    binary = (mask.astype(np.uint8) > 0).astype(np.uint8)
    grown = cv2.dilate(binary, kernel, iterations=iterations, borderType=cv2.BORDER_CONSTANT, borderValue=0.0)
    shrunk = cv2.erode(binary, kernel, iterations=iterations, borderType=cv2.BORDER_CONSTANT, borderValue=0.0)
    # Both results are 0/1, so "differs" is the same as "dilation - erosion > 0".
    return (grown != shrunk).astype(np.uint8)
def process_scene_image(image: Image.Image, instance_mask: np.ndarray, image_size: int,
                        resize_perturb: bool = False, resize_perturb_ratio: float = 0.0):
    """Resize a full scene image and build RGB / RGB+alpha tensors with a masked alpha.

    Alpha is the image's "A" channel (> 0), or all-ones if the image has no alpha
    channel, AND-ed with ``instance_mask``. When ``resize_perturb`` is set, with
    probability ``resize_perturb_ratio`` both RGB and alpha are downsampled to a
    random resolution in [32, image_size) and upsampled back.

    Returns:
        (img_t, img4): (3, S, S) RGB and (4, S, S) RGB+alpha float tensors in [0, 1],
        where S = ``image_size``.
    """
    size = (image_size, image_size)
    try:
        visible = np.array(image.getchannel("A")) > 0
    except ValueError:
        visible = np.ones_like(np.array(image.getchannel(0))) > 0
    alpha = np.logical_and(visible, instance_mask).astype(np.uint8) * 255

    rgb_img = image.resize(size, Image.Resampling.LANCZOS).convert("RGB")
    alpha_img = Image.fromarray(alpha, mode="L").resize(size, Image.Resampling.NEAREST)

    if resize_perturb and np.random.rand() < resize_perturb_ratio:
        low = np.random.randint(32, image_size)
        low_size = (low, low)
        rgb_img = rgb_img.resize(low_size, Image.Resampling.LANCZOS).resize(size, Image.Resampling.LANCZOS)
        alpha_img = alpha_img.resize(low_size, Image.Resampling.NEAREST).resize(size, Image.Resampling.NEAREST)

    rgb = torch.from_numpy(np.array(rgb_img, dtype=np.uint8)).permute(2, 0, 1).float() / 255.0
    alpha_t = torch.from_numpy(np.array(alpha_img, dtype=np.uint8)).unsqueeze(0).float() / 255.0
    return rgb, torch.cat([rgb, alpha_t], dim=0)
def get_rays(i, j, K, c2w):
    """Cast camera rays through pixel centers.

    The camera looks down -z, with the image y axis flipped so that +y points up.

    Args:
        i, j: tensors of pixel x (column) and y (row) indices, same shape.
        K: 3x3 intrinsics (indexed as ``K[r][c]``).
        c2w: camera-to-world matrix; rotation ``c2w[:3, :3]``, origin ``c2w[:3, -1]``.

    Returns:
        (rays_o, rays_d) with shape ``i.shape + (3,)``. ``rays_d`` is not
        normalized; ``rays_o`` is an expanded view of the camera origin.
    """
    x = i.float() + 0.5
    y = j.float() + 0.5
    fx, fy = K[0][0], K[1][1]
    cx, cy = K[0][2], K[1][2]
    cam_dirs = torch.stack([(x - cx) / fx, -(y - cy) / fy, -torch.ones_like(x)], -1)
    # Per-ray R @ d, written as a broadcasted multiply-and-sum over the last axis.
    rays_d = (cam_dirs[..., None, :] * c2w[:3, :3]).sum(-1)
    rays_o = c2w[:3, -1].expand(rays_d.shape)
    return rays_o, rays_d
def get_rays_fast(u: torch.Tensor, v: torch.Tensor, K: torch.Tensor, c2w: torch.Tensor):
    """Vectorized ray generation for 1D lists of pixel coordinates.

    u, v: 1D tensors of pixel x / y coordinates (integer dtype, e.g. int64 or int32)
    K: intrinsics; only entries [0,0], [1,1], [0,2], [1,2] are read
    c2w: (4,4) or (3,4) camera-to-world; uses [:3,:3] and [:3,3]
    return:
        rays_o: (N,3) camera origin, an expanded view (not a copy)
        rays_d: (N,3) un-normalized world-space directions
    """
    # Convert to float32 and shift to pixel centers.
    px = u.to(dtype=torch.float32) + 0.5
    py = v.to(dtype=torch.float32) + 0.5
    # Camera-frame directions: image y flipped to point up, camera looking down -z.
    x = (px - K[0, 2]) / K[0, 0]
    y = -(py - K[1, 2]) / K[1, 1]
    z = -torch.ones_like(px)
    dirs = torch.stack([x, y, z], dim=-1)  # (N, 3)
    # Rotate into world space: the row-vector form of R @ d for all rays at once.
    rays_d = dirs @ c2w[:3, :3].T
    # Every ray starts at the camera center.
    rays_o = c2w[:3, 3].expand_as(rays_d)
    return rays_o, rays_d
def process_instance_image(image: Image.Image, instance_mask: np.ndarray, color_mask: np.ndarray, depth_map: torch.Tensor,
                           K: torch.Tensor, c2w: torch.Tensor, image_size: int):
    """Crop one instance out of a scene image and gather rays for its visible pixels.

    Alpha is the "A" channel (> 0), or all-ones if the image has no alpha, AND-ed
    with ``instance_mask``. The crop is a square centered on the alpha bounding box,
    with half-size 1.2x half of the box's larger side.

    Returns:
        img_t: (3, S, S) RGB crop in [0, 1], S = ``image_size``.
        a_t: (1, S, S) alpha crop in [0, 1].
        rays_o, rays_d: rays for every visible pixel, in full-image pixel coordinates.
        rays_color: ``color_mask`` at those pixels, as float32.
        rays_t: ``depth_map`` at those pixels.
    """
    try:
        visible = np.asarray(image.getchannel("A")) > 0
    except ValueError:
        visible = np.ones_like(np.array(image.getchannel(0))) > 0
    alpha = np.logical_and(visible, instance_mask).astype(np.uint8) * 255
    rows, cols = np.array(alpha).nonzero()

    x0, x1 = cols.min(), cols.max()
    y0, y1 = rows.min(), rows.max()
    cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
    half = max(x1 - x0, y1 - y0) / 2 * 1.2
    crop_box = [int(cx - half), int(cy - half), int(cx + half), int(cy + half)]

    rays_o, rays_d = get_rays(torch.from_numpy(cols), torch.from_numpy(rows), K, c2w)
    rays_color = color_mask[rows, cols].astype(np.float32)
    rays_t = depth_map[rows, cols]

    size = (image_size, image_size)
    rgb_crop = image.crop(crop_box).convert("RGB").resize(size, Image.Resampling.LANCZOS)
    alpha_crop = Image.fromarray(alpha, mode="L").crop(crop_box).resize(size, Image.Resampling.NEAREST)
    img_t = torch.from_numpy(np.asarray(rgb_crop, dtype=np.uint8)).permute(2, 0, 1).float() / 255.0
    a_t = torch.from_numpy(np.asarray(alpha_crop, dtype=np.uint8)).unsqueeze(0).float() / 255.0
    return img_t, a_t, rays_o, rays_d, rays_color, rays_t
def get_crop_area_rays(image: Image.Image, instance_mask: np.ndarray, K: torch.Tensor, c2w: torch.Tensor, image_size):
    """Sample an ``image_size`` x ``image_size`` grid of rays covering the instance crop.

    The image must have an "A" channel. Alpha = (A > 0), AND-ed with
    ``instance_mask`` when it is not None. The crop box is a square centered on the
    alpha bounding box, with half-size 1.2x half of the box's larger side. Pixel
    coordinates are evenly spaced from the box's left/top to right-1/bottom-1.

    NOTE(review): the grid uses 'ij' indexing with x first, so the returned rays
    are laid out as [x_index, y_index], i.e. transposed relative to image
    [row, col] order. Confirm callers expect this.

    Returns:
        (rays_o, rays_d), each of shape (S, S, 3).
    """
    visible = np.asarray(image.getchannel("A")) > 0
    if instance_mask is None:
        alpha = visible.astype(np.float32)
    else:
        alpha = np.logical_and(visible, instance_mask).astype(np.float32)
    rows, cols = np.array(alpha).nonzero()

    x0, x1 = cols.min(), cols.max()
    y0, y1 = rows.min(), rows.max()
    cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
    half = max(x1 - x0, y1 - y0) / 2 * 1.2
    left, top = int(cx - half), int(cy - half)
    right, bottom = int(cx + half), int(cy + half)

    grid_x, grid_y = torch.meshgrid(
        torch.linspace(left, right - 1, steps=image_size),
        torch.linspace(top, bottom - 1, steps=image_size),
        indexing='ij',
    )
    return get_rays(grid_x, grid_y, K, c2w)
def process_instance_image_crop(image: Image.Image, instance_mask: np.ndarray, color_mask: np.ndarray,
                                depth_map: torch.Tensor,
                                gt_depth_map: torch.Tensor,
                                K: torch.Tensor, c2w: torch.Tensor, image_size: int,
                                edge_mask_morph_gradient_fn):
    """Crop an instance and resample image, alpha, depths and color mask to ``image_size``.

    The image must have an "A" channel. Alpha = (A > 0) AND ``instance_mask``. The
    crop box is a square centered on the alpha bounding box, with half-size 1.2x
    half of its larger side. All single-channel maps are cropped and resized with
    nearest-neighbour sampling.

    NOTE(review): ``rays_o``/``rays_d`` come from an 'ij' meshgrid with x first, so
    they are laid out [x_index, y_index]. ``rays_t``/``fg_mask`` are in image
    [row, col] order. Confirm callers reconcile this.

    Returns (in order):
        img_t: (3, S, S) RGB crop in [0, 1].
        a_t: (1, S, S) alpha crop as float 0/1 (not rescaled).
        fg_mask: uint8 mask of alpha pixels not on the edge band.
        rays_o, rays_d: (S, S, 3) rays over the crop grid.
        rays_color: float32, 1.0 on ``fg_mask`` plus 0.5 on the edge band (assuming
            ``edge_mask_morph_gradient_fn`` returns a 0/1 mask).
        rays_t: float32 tensor of the resized ``depth_map``.
        valid_mask: ``fg_mask.nonzero()`` tuple.
        depth_map_resized, gt_depth_map_resized, color_mask_resized: resized PIL "F" images.
    """
    visible = np.asarray(image.getchannel("A")) > 0
    alpha = np.logical_and(visible, instance_mask).astype(np.float32)
    rows, cols = np.array(alpha).nonzero()

    x0, x1 = cols.min(), cols.max()
    y0, y1 = rows.min(), rows.max()
    cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
    half = max(x1 - x0, y1 - y0) / 2 * 1.2
    crop_box = [int(cx - half), int(cy - half), int(cx + half), int(cy + half)]

    grid_x, grid_y = torch.meshgrid(
        torch.linspace(crop_box[0], crop_box[2] - 1, steps=image_size),
        torch.linspace(crop_box[1], crop_box[3] - 1, steps=image_size),
        indexing='ij',
    )
    rays_o, rays_d = get_rays(grid_x, grid_y, K, c2w)

    size = (image_size, image_size)

    def crop_float(arr):
        # Crop and nearest-resize a single-channel float32 map as a PIL "F" image.
        return Image.fromarray(arr, mode="F").crop(crop_box).resize(size, Image.Resampling.NEAREST)

    image_resized = image.crop(crop_box).convert("RGB").resize(size, Image.Resampling.LANCZOS)
    alpha_resized = crop_float(alpha)
    depth_map_resized = crop_float(depth_map.detach().cpu().numpy())
    gt_depth_map_resized = crop_float(gt_depth_map.detach().cpu().numpy())
    color_mask_resized = crop_float(color_mask.astype(np.float32))

    img_t = torch.from_numpy(np.asarray(image_resized, dtype=np.uint8)).permute(2, 0, 1).float() / 255.0
    a_np = np.asarray(alpha_resized, dtype=np.float32).astype(dtype=np.uint8)
    edge_mask = edge_mask_morph_gradient_fn((a_np > 0).astype(np.uint8))
    fg_mask = (a_np > edge_mask).astype(np.uint8)
    rays_color = fg_mask.astype(np.float32) + edge_mask.astype(np.float32) * 0.5
    valid_mask = fg_mask.nonzero()
    rays_t = torch.from_numpy(np.asarray(depth_map_resized).astype(np.float32))
    a_t = torch.from_numpy(a_np).unsqueeze(0).float()
    return (img_t, a_t, fg_mask, rays_o, rays_d, rays_color, rays_t, valid_mask,
            depth_map_resized, gt_depth_map_resized, color_mask_resized)
def process_instance_image_only(image: Image.Image, instance_mask: np.ndarray, image_size: int):
    """Crop an instance from an RGBA image and return resized RGB and alpha tensors.

    The image must have an "A" channel. Alpha = (A > 0) AND ``instance_mask``. The
    crop is a square centered on the alpha bounding box, with half-size 1.2x half of
    its larger side. It is resized to ``image_size`` (LANCZOS for RGB, nearest for
    alpha).

    Returns:
        (img_t, a_t): (3, S, S) and (1, S, S) float tensors in [0, 1].
    """
    visible = np.asarray(image.getchannel("A")) > 0
    alpha = np.logical_and(visible, instance_mask).astype(np.uint8) * 255
    rows, cols = np.array(alpha).nonzero()

    x0, x1 = cols.min(), cols.max()
    y0, y1 = rows.min(), rows.max()
    cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
    half = max(x1 - x0, y1 - y0) / 2 * 1.2
    box = [int(cx - half), int(cy - half), int(cx + half), int(cy + half)]

    size = (image_size, image_size)
    rgb = image.crop(box).convert("RGB").resize(size, Image.Resampling.LANCZOS)
    mask_img = Image.fromarray(alpha, mode="L").crop(box).resize(size, Image.Resampling.NEAREST)
    img_t = torch.from_numpy(np.asarray(rgb, dtype=np.uint8)).permute(2, 0, 1).float() / 255.0
    a_t = torch.from_numpy(np.asarray(mask_img, dtype=np.uint8)).unsqueeze(0).float() / 255.0
    return img_t, a_t
def crop_depth_image(depth_image, aug_bbox, image_size):
    """Crop a 2D depth tensor to ``aug_bbox`` and nearest-resize it to ``image_size``.

    ``aug_bbox`` is a PIL crop box (left, upper, right, lower). The depth goes
    through a PIL "F" (float32) image.

    Returns:
        (image_size, image_size) float32 CPU tensor.
    """
    as_float = depth_image.cpu().numpy().astype(np.float32)
    resized = (
        Image.fromarray(as_float, mode="F")
        .crop(aug_bbox)
        .resize((image_size, image_size), Image.Resampling.NEAREST)
    )
    return torch.from_numpy(np.asarray(resized, dtype=np.float32))
def proj_depth2pcd(mask, depth, image, rays_o, rays_d):
    """Back-project masked pixels to 3D points using precomputed per-pixel rays.

    Args:
        mask: 2D tensor; nonzero entries select pixels (row, col).
        depth: 2D depth indexed as ``depth[row, col]``.
        image: (C, H, W) tensor giving per-pixel colors.
        rays_o, rays_d: per-pixel ray origins/directions indexed as ``[row, col]``.

    Returns:
        (points, colors) numpy arrays: ``rays_o + rays_d * depth`` and the image
        values at the selected pixels.
    """
    idx = torch.nonzero(mask)
    r = idx[:, 0].detach().cpu().numpy()
    c = idx[:, 1].detach().cpu().numpy()
    t = depth[r, c]
    colors = image.detach().permute(1, 2, 0)[r, c]
    points = rays_o[r, c] + rays_d[r, c] * t[:, None]
    return points.detach().cpu().numpy(), colors.detach().cpu().numpy()
def vox2pts(ss, resolution=64):
    """Return the centers of occupied voxels in ``ss[0]`` as a numpy array.

    Voxel index ``k`` along an axis maps to ``(k + 0.5) / resolution - 0.5``, i.e.
    cell centers of a grid spanning [-0.5, 0.5].
    """
    occupied = (ss[0] > 0).nonzero()
    centers = (occupied.float() + 0.5) / resolution - 0.5
    return centers.detach().cpu().numpy()
def voxelize_pcd(points, points_color=None, clip_range_first=False, return_mask=True, resolution=64):
    """Voxelize a point cloud in [-0.5, 0.5]^3, optionally averaging per-point values.

    Args:
        points: (N, 3) numpy array of positions.
        points_color: optional (N,) numpy array of per-point scalars, averaged per
            voxel (empty voxels get 0). Points outside the cube are clamped into
            the boundary voxels for this average.
        clip_range_first: clamp points just inside the unit cube before voxelizing.
        return_mask: also return the occupancy (or the per-voxel means when
            ``points_color`` is given) with 4x4x4 voxel blocks moved onto channels.
        resolution: voxels per axis (R); must be divisible by 4 when ``return_mask``.

    Returns:
        ``ss``: (1, R, R, R) long occupancy; and, if ``return_mask``, ``ss_mask``:
        (64, R/4, R/4, R/4) float tensor.
    """
    if clip_range_first:
        points = np.clip(points, -0.5 + 1e-6, 0.5 - 1e-6)
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud_within_bounds(
        pcd, voxel_size=1 / resolution,
        min_bound=(-0.5, -0.5, -0.5), max_bound=(0.5, 0.5, 0.5))
    # reshape keeps a (0, 3) shape when no voxel is occupied (empty or out-of-bounds input).
    grid_idx = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()],
                        dtype=np.int64).reshape(-1, 3)
    assert np.all(grid_idx >= 0) and np.all(grid_idx < resolution), "Some vertices are out of bounds"
    coords = torch.from_numpy(grid_idx)
    ss = torch.zeros(1, resolution, resolution, resolution, dtype=torch.long)
    ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
    if points_color is not None:
        points_t = torch.from_numpy(points).to(torch.float32)
        colors_t = torch.from_numpy(points_color).to(torch.float32)
        cell = torch.floor((points_t + 0.5) * resolution).to(torch.long)
        cell = torch.clamp(cell, 0, resolution - 1)
        # Linear cell index in [0, R^3).
        lin = cell[:, 0] * (resolution * resolution) + cell[:, 1] * resolution + cell[:, 2]
        n_cells = resolution * resolution * resolution
        sum_color = torch.zeros(n_cells, dtype=torch.float32)
        sum_color.index_add_(0, lin, colors_t)
        count = torch.zeros(n_cells, dtype=torch.long)
        count.index_add_(0, lin, torch.ones_like(lin))
        # Empty cells divide by 1 and therefore stay 0.
        mean_color = sum_color / torch.clamp(count.to(torch.float32), min=1.0)
        color_mean = mean_color.view(resolution, resolution, resolution, 1).permute(3, 0, 1, 2).contiguous()
    if return_mask:
        source = ss if points_color is None else color_mean
        ss_mask = rearrange(source, 'c (x n1) (y n2) (z n3) -> (n1 n2 n3 c) x y z', n1=4, n2=4, n3=4).float()
        return ss, ss_mask
    return ss
def voxelize_pcd_pt(points, points_color=None, clip_range_first=False, return_mask=True, resolution=64):
    """Torch version of ``voxelize_pcd``; results live on ``points.device``.

    ``points`` (N, 3 tensor) and ``points_color`` are sanitized with
    ``torch.nan_to_num``. ``points_color`` may be a tensor or a numpy array of
    per-point scalars (N,); it is moved to ``points.device``. Occupancy is computed
    by Open3D on a CPU copy of the points.

    Returns:
        ``ss``: (1, R, R, R) long occupancy; and, if ``return_mask``, ``ss_mask``:
        the (64, R/4, R/4, R/4) rearrangement (4x4x4 voxel blocks onto channels) of
        the occupancy, or of the per-voxel mean color when ``points_color`` is given.
    """
    points = torch.nan_to_num(points)
    device = points.device
    if points_color is not None:
        points_color = torch.nan_to_num(torch.as_tensor(points_color, device=device))
    if clip_range_first:
        points = torch.clip(points, -0.5 + 1e-6, 0.5 - 1e-6)
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points.detach().cpu().numpy())
    voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud_within_bounds(
        pcd, voxel_size=1 / resolution,
        min_bound=(-0.5, -0.5, -0.5), max_bound=(0.5, 0.5, 0.5))
    # reshape keeps a (0, 3) shape when no voxel is occupied (empty or out-of-bounds input).
    grid_idx = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()],
                        dtype=np.int64).reshape(-1, 3)
    assert np.all(grid_idx >= 0) and np.all(grid_idx < resolution), "Some vertices are out of bounds"
    coords = torch.from_numpy(grid_idx).to(device)
    ss = torch.zeros(1, resolution, resolution, resolution, dtype=torch.long, device=device)
    ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
    if points_color is not None:
        points_t = points.to(torch.float32)
        colors_t = points_color.to(torch.float32)
        cell = torch.floor((points_t + 0.5) * resolution).to(torch.long)
        cell = torch.clamp(cell, 0, resolution - 1)
        # Linear cell index in [0, R^3).
        lin = cell[:, 0] * (resolution * resolution) + cell[:, 1] * resolution + cell[:, 2]
        n_cells = resolution * resolution * resolution
        sum_color = torch.zeros(n_cells, dtype=torch.float32, device=device)
        sum_color.index_add_(0, lin, colors_t)
        count = torch.zeros(n_cells, dtype=torch.long, device=device)
        count.index_add_(0, lin, torch.ones_like(lin))
        # Empty cells divide by 1 and therefore stay 0.
        mean_color = sum_color / torch.clamp(count.to(torch.float32), min=1.0)
        color_mean = mean_color.view(resolution, resolution, resolution, 1).permute(3, 0, 1, 2).contiguous()
    if return_mask:
        source = ss if points_color is None else color_mean
        ss_mask = rearrange(source, 'c (x n1) (y n2) (z n3) -> (n1 n2 n3 c) x y z', n1=4, n2=4, n3=4).float()
        return ss, ss_mask
    return ss
def get_std_cond(root, instance, crop_size, return_mask=False):
    """Load one random rendered view of ``instance`` and crop it around the object.

    Frames are listed in ``<root>/renders_cond/<instance>/transforms.json`` when that
    file exists, otherwise in ``<root>/renders/<instance>/transforms.json``. The
    chosen RGBA frame is cropped to a square centered on the alpha bounding box,
    with half-size 1.2x half of its larger side. The crop is resized to
    ``crop_size`` (LANCZOS), and its RGB is multiplied by the alpha.

    Returns:
        (3, crop_size, crop_size) float image in [0, 1]; plus the
        (1, crop_size, crop_size) alpha if ``return_mask``.
    """
    image_root = os.path.join(root, 'renders_cond', instance)
    if not os.path.exists(os.path.join(image_root, 'transforms.json')):
        image_root = os.path.join(root, 'renders', instance)
    with open(os.path.join(image_root, 'transforms.json')) as f:
        metadata = json.load(f)

    frames = metadata['frames']
    frame = frames[np.random.randint(len(frames))]
    image = Image.open(os.path.join(image_root, frame['file_path']))

    rows, cols = np.array(image.getchannel(3)).nonzero()
    x0, x1 = cols.min(), cols.max()
    y0, y1 = rows.min(), rows.max()
    cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
    half = max(x1 - x0, y1 - y0) / 2 * 1.2
    crop_box = [int(cx - half), int(cy - half), int(cx + half), int(cy + half)]

    image = image.crop(crop_box).resize((crop_size, crop_size), Image.Resampling.LANCZOS)
    alpha = torch.tensor(np.array(image.getchannel(3))).float() / 255.0
    rgb = torch.tensor(np.array(image.convert('RGB'))).permute(2, 0, 1).float() / 255.0
    rgb = rgb * alpha.unsqueeze(0)
    if return_mask:
        return rgb, alpha.unsqueeze(0)
    return rgb
def map_rotated_slat2canonical_pose(vertices, rot_slat_info):
    """Scale, translate, then apply inverse axis rotations described by ``rot_slat_info``.

    ``vertices`` are multiplied by ``rot_slat_info['scale']`` and shifted by
    ``rot_slat_info['translation']``. The angles in ``rot_slat_info['rotate']``
    (x, y, z; radians) are then undone about the origin by applying the negated
    z, y and x rotations in that order. This presumably inverts a forward x->y->z
    rotation; confirm against the augmentation code.

    Returns:
        numpy array of the transformed points.
    """
    angles = rot_slat_info['rotate']
    shifted = vertices * rot_slat_info['scale'] + np.array(rot_slat_info['translation'])
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(shifted)
    inverse_steps = (
        (0, 0, -angles[2]),
        (0, -angles[1], 0),
        (-angles[0], 0, 0),
    )
    for euler in inverse_steps:
        pcd.rotate(pcd.get_rotation_matrix_from_xyz(euler), center=(0., 0., 0.))
    return np.asarray(pcd.points)
def project2ply(mask, depth, image, K, c2w):
    """Back-project the nonzero pixels of ``mask`` to world-space points with colors.

    Rays are generated by ``get_rays`` at (col, row) of each selected pixel, and
    each point is ``rays_o + rays_d * depth[row, col]``.

    Returns:
        (points, colors): numpy arrays; colors are taken from the (C, H, W) ``image``.
    """
    idx = torch.nonzero(mask)
    rays_o, rays_d = get_rays(idx[:, 1], idx[:, 0], K, c2w)
    r = idx[:, 0].detach().cpu().numpy()
    c = idx[:, 1].detach().cpu().numpy()
    t = depth[r, c]
    colors = image.detach().permute(1, 2, 0).cpu().numpy()[r, c]
    points = (rays_o + rays_d * t[:, None]).detach().cpu().numpy()
    return points, colors