# StableRecon / gs_utils.py
# Author: Stable-X. Commit 0332bda: "feat: Add gs_utils for gs export" (3.47 kB).
import numpy as np
import os
from plyfile import PlyElement, PlyData
import open3d as o3d
def get_f_dc(colors):
    """Convert per-point colours to degree-0 SH ("DC") coefficients.

    The colours are mapped through ``RGB2SH`` and a length-1 axis is inserted
    at position 2, so an ``(N, 3)`` colour array becomes ``(N, 3, 1)``.
    """
    sh_coeffs = RGB2SH(colors)
    return sh_coeffs[:, :, None]
def get_f_rest(points, max_sh_degree=3):
    """Return all-zero higher-order SH coefficients, one slab per point.

    The result has shape ``(points.shape[0], (max_sh_degree + 1) ** 2 - 1, 3)``.
    The middle axis covers every SH coefficient above degree 0. The trailing
    axis of size 3 presumably corresponds to the colour channels; confirm
    against the consumer.
    """
    coeffs_above_dc = (max_sh_degree + 1) ** 2 - 1
    num_points = points.shape[0]
    return np.zeros((num_points, coeffs_above_dc, 3))
def get_opacity(points):
    """Return per-point opacity logits for a uniform opacity of 0.5.

    Output shape is ``(points.shape[0], 1)``. Every entry is
    ``inverse_sigmoid(0.5)``, i.e. ``log(1) == 0``.
    """
    half_opaque = np.full((points.shape[0], 1), 0.5)
    return inverse_sigmoid(half_opaque)
def get_scales(points, base_scale=0.0015, flat_scale=1e-6):
    """Return log-scales for flat, disc-shaped Gaussians, one row per point.

    Each point gets linear scales ``(base_scale, base_scale, flat_scale)``.
    The tiny third scale makes each Gaussian a thin disc; ``normal2rotation``
    places the point normal on that third axis. The values are returned as
    natural logarithms.

    Args:
        points: array whose first dimension is the number of points.
        base_scale: linear scale of the two in-plane axes. This is in the same
            units as the point positions, so it usually depends on the scene.
        flat_scale: linear scale of the third (thin) axis.

    Returns:
        ``(N, 3)`` float64 array of log-scales.
    """
    scales = np.empty((points.shape[0], 3))
    scales[:, :2] = base_scale
    scales[:, 2] = flat_scale
    return np.log(scales)
def get_rotation(normals):
    """Return one rotation quaternion per point, derived from its normal.

    Rows with a non-zero normal are converted with ``normal2rotation``. Rows
    whose normal is all zeros carry no orientation and get the identity
    quaternion ``(1, 0, 0, 0)``. The order is ``(w, x, y, z)``, matching
    ``rotmat2quaternion``. An all-zero quaternion is not a valid rotation:
    its norm is 0, so normalising it divides by zero.

    Args:
        normals: ``(N, 3)`` array of per-point normals (need not be unit).

    Returns:
        ``(N, 4)`` float64 array of quaternions.

    Raises:
        ValueError: if ``normals`` is None, since the point count is unknown.
    """
    if normals is None:
        raise ValueError("get_rotation needs a normals array (use zeros for 'no normals')")
    normals = np.asarray(normals)

    rotation = np.zeros((normals.shape[0], 4))
    rotation[:, 0] = 1.0  # identity by default

    # Only rows with an actual direction go through normal2rotation; zero rows
    # would otherwise be divided by a zero length there.
    has_normal = np.any(normals != 0, axis=1)
    if np.any(has_normal):
        rotation[has_normal] = normal2rotation(normals[has_normal])
    return rotation
def RGB2SH(rgb):
    """Map colours to degree-0 spherical-harmonic coefficients.

    Computes ``(rgb - 0.5) / C0``, where ``C0 = 0.28209479177387814`` is the
    degree-0 real SH basis constant ``1 / (2 * sqrt(pi))``. A value of 0.5
    maps to 0. This assumes colours lie in [0, 1] (the 0.5 offset centres that
    range); confirm against callers.
    """
    sh_c0 = 0.28209479177387814
    centred = rgb - 0.5
    return centred / sh_c0
def inverse_sigmoid(x):
    """Return the logit ``log(x / (1 - x))``, the inverse of the logistic sigmoid.

    Works element-wise on scalars and NumPy arrays.
    """
    odds = x / (1 - x)
    return np.log(odds)
def normal2rotation(n):
    """Build quaternions whose rotation maps the local third axis onto each normal.

    The rotation matrix has columns ``[t0, t1, n_hat]``:
    - ``n_hat`` is the normalised input.
    - ``t0`` is the world x axis with its component along ``n_hat`` removed.
      The world y axis is used instead when ``n_hat`` is (anti)parallel to x.
    - ``t1 = n_hat x t0`` completes a right-handed orthonormal frame.

    Args:
        n: ``(N, 3)`` array of normals; they need not be unit length.
            Zero-length normals are treated as +z.

    Returns:
        ``(N, 4)`` quaternions in ``(w, x, y, z)`` order, as produced by
        ``rotmat2quaternion``.
    """
    n = np.asarray(n, dtype=np.float64)
    length = np.linalg.norm(n, axis=1, keepdims=True)
    zero = length[:, 0] == 0
    # A zero-length normal has no direction: use +z (identity rotation)
    # rather than dividing by zero.
    n = n / np.where(zero[:, None], 1.0, length)
    n[zero] = (0.0, 0.0, 1.0)

    # First tangent: world x axis projected onto the plane orthogonal to n.
    ref = np.zeros_like(n)
    ref[:, 0] = 1.0
    R0 = ref - np.sum(ref * n, axis=1, keepdims=True) * n
    # That projection vanishes when n is (anti)parallel to x; use the world
    # y axis as the reference for those rows instead of producing NaNs.
    parallel = np.linalg.norm(R0, axis=1) < 1e-8
    if np.any(parallel):
        n_par = n[parallel]
        ref_y = np.zeros_like(n_par)
        ref_y[:, 1] = 1.0
        R0[parallel] = ref_y - np.sum(ref_y * n_par, axis=1, keepdims=True) * n_par
    # Keep the tangent's x component non-negative. Unlike multiplying by
    # np.sign, this never scales by sign(0) == 0 and wipes the vector out.
    R0 *= np.where(R0[:, :1] < 0, -1.0, 1.0)
    R0 /= np.linalg.norm(R0, axis=1, keepdims=True)

    # n x R0 is unit length and orthogonal to both, so det(R) == +1.
    # No sign flip is applied here. A flip by sign(R1_y) * sign(n_z) would be
    # +1 whenever n_z != 0 (R1_y has the sign of n_z) and 0 when n_z == 0,
    # collapsing R1 for normals that lie in the xy plane.
    R1 = np.cross(n, R0)
    R = np.stack([R0, R1, n], axis=-1)
    return rotmat2quaternion(R)
def rotmat2quaternion(R, normalize=False):
    """Convert a batch of 3x3 rotation matrices to quaternions.

    Uses Shepperd's method. The quaternion component with the largest
    magnitude is recovered from the diagonal, and the others from
    off-diagonal sums/differences divided by it. The plain trace formula
    divides by ``sqrt(1 + trace)``, which goes to zero for rotations near
    180 degrees; this method stays well conditioned for every rotation.

    Args:
        R: ``(N, 3, 3)`` array of rotation matrices.
        normalize: if True, rescale each quaternion to unit length.

    Returns:
        ``(N, 4)`` float64 array of quaternions in ``(w, x, y, z)`` order,
        with the sign chosen so that ``w >= 0``.
    """
    R = np.asarray(R, dtype=np.float64)
    m00, m01, m02 = R[:, 0, 0], R[:, 0, 1], R[:, 0, 2]
    m10, m11, m12 = R[:, 1, 0], R[:, 1, 1], R[:, 1, 2]
    m20, m21, m22 = R[:, 2, 0], R[:, 2, 1], R[:, 2, 2]

    # For a rotation with quaternion q = (w, x, y, z), K = 4 * outer(q, q):
    # the diagonal holds 4w^2, 4x^2, 4y^2, 4z^2 and the off-diagonals 4*q_i*q_j.
    K = np.stack([
        np.stack([1.0 + m00 + m11 + m22, m21 - m12, m02 - m20, m10 - m01], axis=-1),
        np.stack([m21 - m12, 1.0 + m00 - m11 - m22, m01 + m10, m02 + m20], axis=-1),
        np.stack([m02 - m20, m01 + m10, 1.0 - m00 + m11 - m22, m12 + m21], axis=-1),
        np.stack([m10 - m01, m02 + m20, m12 + m21, 1.0 - m00 - m11 + m22], axis=-1),
    ], axis=1)

    rows = np.arange(K.shape[0])
    # Row of the largest component: 4 * q_best^2 >= 1 for a proper rotation.
    best = np.argmax(np.diagonal(K, axis1=1, axis2=2), axis=1)
    pivot = K[rows, best]  # 4 * q_best * q
    # 4 * |q_best|; the clamp only matters for malformed (non-rotation) input.
    denom = 2.0 * np.sqrt(np.maximum(K[rows, best, best], 1e-12))
    q = pivot / denom[:, None]
    # q and -q are the same rotation; keep the w >= 0 convention.
    q = np.where(q[:, :1] < 0, -q, q)

    if normalize:
        q = q / np.linalg.norm(q, axis=-1, keepdims=True)
    return q
def point2gs(path, pcd, scale=None, max_sh_degree=1):
    """Write a point cloud to ``path`` as a Gaussian-splat PLY file.

    Every point becomes one Gaussian. The ``vertex`` element carries these
    fields, all float32:
    - ``x, y, z`` and ``nx, ny, nz``;
    - ``f_dc_*`` (degree-0 SH colour, from ``get_f_dc``);
    - ``f_rest_*`` (higher-order SH, from ``get_f_rest``);
    - ``opacity``;
    - ``scale_*`` (log scales);
    - ``rot_*`` (quaternion from ``get_rotation``).

    Args:
        path: output ``.ply`` path. Missing parent directories are created.
        pcd: point cloud exposing ``points``, ``normals``, ``colors``,
            ``has_normals()`` and ``has_colors()``. This is presumably an
            ``open3d.geometry.PointCloud``. Missing normals default to zeros;
            missing colours default to white ``(1, 1, 1)``.
        scale: optional linear (not log) scales. An ``(N, k)`` array is used
            as-is. A scalar or a length-3 per-axis vector is broadcast to
            every point. Defaults to ``get_scales``.
        max_sh_degree: SH degree that determines how many ``f_rest_*``
            columns are written.
    """
    from numpy.lib import recfunctions as rfn

    # os.path.dirname("splats.ply") is "" and os.makedirs("") raises, so only
    # create a directory when the path actually names one.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Get point cloud data
    xyz = np.asarray(pcd.points)
    num_points = xyz.shape[0]
    normals = np.asarray(pcd.normals) if pcd.has_normals() else np.zeros_like(xyz)
    colors = np.asarray(pcd.colors) if pcd.has_colors() else np.ones_like(xyz)

    # Generate additional attributes
    f_dc = _flatten_per_point(get_f_dc(colors))
    f_rest = _flatten_per_point(get_f_rest(xyz, max_sh_degree))
    opacities = get_opacity(xyz)
    if scale is not None:
        # The caller supplies linear scales; the file stores log scales.
        scale = np.log(np.asarray(scale, dtype=np.float64))
        if scale.ndim < 2:
            # A scalar or per-axis (3,) scale applies to every point.
            scale = np.broadcast_to(scale, (num_points, 3))
    else:
        scale = get_scales(xyz)
    rotation = get_rotation(normals)

    # Construct list of attributes, in the same order as the columns below.
    attribute_names = ['x', 'y', 'z', 'nx', 'ny', 'nz']
    attribute_names.extend([f'f_dc_{i}' for i in range(f_dc.shape[-1])])
    attribute_names.extend([f'f_rest_{i}' for i in range(f_rest.shape[-1])])
    attribute_names.append('opacity')
    attribute_names.extend([f'scale_{i}' for i in range(scale.shape[1])])
    attribute_names.extend([f'rot_{i}' for i in range(rotation.shape[1])])
    dtype_full = [(attribute, 'f4') for attribute in attribute_names]

    attributes = np.concatenate((
        xyz, normals,
        f_dc,
        f_rest,
        opacities, scale, rotation
    ), axis=1)
    # Internal consistency check: the names above are derived from the same arrays.
    assert attributes.shape[1] == len(dtype_full), f"Mismatch in attribute count. Expected {len(dtype_full)}, got {attributes.shape[1]}"

    # Column i -> field i in one vectorised copy, instead of building one
    # Python tuple per point.
    elements = rfn.unstructured_to_structured(
        attributes.astype(np.float32), dtype=np.dtype(dtype_full)
    )

    # Create PlyElement and save
    el = PlyElement.describe(elements, 'vertex')
    PlyData([el]).write(path)


def _flatten_per_point(values):
    """Reshape ``(N, ...)`` to ``(N, M)``; unlike ``reshape(N, -1)`` this also works for N == 0."""
    return values.reshape(values.shape[0], int(np.prod(values.shape[1:])))