import os

import numpy as np
import torch
from diff_gaussian_rasterization import GaussianRasterizationSettings as Camera

# open3d is imported lazily inside the KNN helpers below, so rendering-only
# use of this module does not require it.
def setup_camera(w, h, k, w2c, near=0.01, far=100.0, bg=(0, 0, 0), z_threshold=0.2, device='cuda'):
    """Build rasterization settings from image size, pinhole intrinsics k, and a world-to-camera matrix w2c."""
    fx, fy, cx, cy = k[0][0], k[1][1], k[0][2], k[1][2]
    w2c = torch.tensor(w2c, dtype=torch.float32, device=device)
    cam_center = torch.inverse(w2c)[:3, 3]
    w2c = w2c.unsqueeze(0).transpose(1, 2)
    # OpenGL-style projection matrix built from the intrinsics, stored
    # transposed to match the rasterizer's row-vector convention.
    opengl_proj = torch.tensor([[2 * fx / w, 0.0, -(w - 2 * cx) / w, 0.0],
                                [0.0, 2 * fy / h, -(h - 2 * cy) / h, 0.0],
                                [0.0, 0.0, far / (far - near), -(far * near) / (far - near)],
                                [0.0, 0.0, 1.0, 0.0]], device=device).unsqueeze(0).transpose(1, 2)
    full_proj = w2c.bmm(opengl_proj)
    cam = Camera(
        image_height=h,
        image_width=w,
        tanfovx=w / (2 * fx),
        tanfovy=h / (2 * fy),
        bg=torch.tensor(bg, dtype=torch.float32, device=device),
        scale_modifier=1.0,
        viewmatrix=w2c,
        projmatrix=full_proj,
        sh_degree=0,
        campos=cam_center,
        prefiltered=False,
        z_threshold=z_threshold,
    )
    return cam
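
# Hedged usage sketch for setup_camera; the intrinsics and pose below are
# made-up placeholder values, not data from this repo:
# w, h = 640, 360
# k = [[600.0, 0.0, 320.0], [0.0, 600.0, 180.0], [0.0, 0.0, 1.0]]  # pinhole intrinsics
# w2c = np.eye(4)                                                  # camera at the world origin
# cam = setup_camera(w, h, k, w2c, near=0.01, far=100.0)
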
def params2rendervar(params):
    # Map raw optimization parameters to the rasterizer's expected inputs:
    # rotations are normalized to unit quaternions, opacities squashed to
    # (0, 1) via sigmoid, and scales exponentiated back from log space.
    rendervar = {
        'means3D': params['means3D'],
        'colors_precomp': params['rgb_colors'],
        'rotations': torch.nn.functional.normalize(params['unnorm_rotations']),
        'opacities': torch.sigmoid(params['logit_opacities']),
        'scales': torch.exp(params['log_scales']),
        # The "+ 0" makes means2D a non-leaf tensor so the rasterizer can
        # populate its gradients (used for densification bookkeeping).
        'means2D': torch.zeros_like(params['means3D'], requires_grad=True, device="cuda") + 0
    }
    return rendervar
def params2rendervar_wt(params, t):
    # Same as params2rendervar, but indexes per-timestep parameters at time t.
    rendervar = {
        'means3D': params['means3D'][t],
        'colors_precomp': params['rgb_colors'][t],
        'rotations': torch.nn.functional.normalize(params['unnorm_rotations'][t]),
        'opacities': torch.sigmoid(params['logit_opacities'][t]),
        'scales': torch.exp(params['log_scales'][t]),
        'means2D': torch.zeros_like(params['means3D'][t], requires_grad=True, device="cuda") + 0
    }
    return rendervar
def params2rendervar_consistent_rgb(params, variables):
    # Variant that takes the (fixed) colors from `variables` instead of the
    # optimized parameters, keeping RGB consistent across timesteps.
    rendervar = {
        'means3D': params['means3D'],
        'colors_precomp': variables['rgb_colors'],
        'rotations': torch.nn.functional.normalize(params['unnorm_rotations']),
        'opacities': torch.sigmoid(params['logit_opacities']),
        'scales': torch.exp(params['log_scales']),
        'means2D': torch.zeros_like(params['means3D'], requires_grad=True, device="cuda") + 0
    }
    return rendervar
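
# Hedged rendering sketch combining setup_camera and params2rendervar.
# GaussianRasterizer is the standard companion class to
# GaussianRasterizationSettings; the exact return signature (two or three
# values) depends on the diff_gaussian_rasterization fork in use:
# from diff_gaussian_rasterization import GaussianRasterizer
# cam = setup_camera(w, h, k, w2c)
# rendervar = params2rendervar(params)
# im, radii = GaussianRasterizer(raster_settings=cam)(**rendervar)[:2]
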
def l1_loss_v1(x, y):
    # Mean absolute error over all elements.
    return torch.abs(x - y).mean()


def l1_loss_v2(x, y):
    # Absolute error summed over the last dimension, then averaged.
    return torch.abs(x - y).sum(-1).mean()


def weighted_l2_loss_v1(x, y, w):
    # Per-element weighted distance: sqrt(w * (x - y)^2), averaged.
    # The small epsilon keeps the sqrt differentiable at zero.
    return torch.sqrt(((x - y) ** 2) * w + 1e-20).mean()


def weighted_l2_loss_v2(x, y, w):
    # Weighted Euclidean distance over the last dimension, averaged.
    return torch.sqrt(((x - y) ** 2).sum(-1) * w + 1e-20).mean()
def quat_mult(q1, q2):
    # Hamilton product of batches of quaternions in (w, x, y, z) order.
    w1, x1, y1, z1 = q1.T
    w2, x2, y2, z2 = q2.T
    w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
    x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
    y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
    z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
    return torch.stack([w, x, y, z]).T
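
# Hedged sanity check for quat_mult (illustrative only): multiplying by the
# identity quaternion (1, 0, 0, 0) should return the input unchanged.
# q = torch.nn.functional.normalize(torch.randn(5, 4))
# identity = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(5, 4)
# assert torch.allclose(quat_mult(identity, q), q, atol=1e-6)
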
def o3d_knn(pts, num_knn):
    # Nearest neighbours via Open3D's KD-tree. The query point itself is the
    # first hit, so we search for num_knn + 1 and drop it.
    import open3d as o3d
    indices = []
    sq_dists = []
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(np.ascontiguousarray(pts, np.float64))
    if len(pcd.points) == 0:
        print("Point cloud is empty!")
    else:
        pcd_tree = o3d.geometry.KDTreeFlann(pcd)
        for p in pcd.points:
            [_, i, d] = pcd_tree.search_knn_vector_3d(p, num_knn + 1)
            indices.append(i[1:])
            sq_dists.append(d[1:])
    return np.array(sq_dists), np.array(indices)
def o3d_knn_tensor(pts_tensor, num_knn):
    # Tensor-friendly variant of o3d_knn: accepts a torch tensor and returns
    # torch tensors on the input's device.
    import open3d as o3d
    if pts_tensor.numel() == 0:
        print("Point cloud is empty!")
        return None, None
    # detach().cpu() also handles CPU tensors that still require grad.
    pts_np_cont = np.ascontiguousarray(pts_tensor.detach().cpu().numpy(), dtype=np.float64)
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(pts_np_cont)
    pcd_tree = o3d.geometry.KDTreeFlann(pcd)
    indices = []
    sq_dists = []
    for p in pts_np_cont:
        [_, idx, dist] = pcd_tree.search_knn_vector_3d(p, num_knn + 1)
        indices.append(idx[1:])  # skip the first hit: it is the query point itself
        sq_dists.append(dist[1:])
    return (torch.tensor(sq_dists, dtype=pts_tensor.dtype, device=pts_tensor.device),
            torch.tensor(indices, dtype=torch.long, device=pts_tensor.device))
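
# Hedged alternative: for point clouds that fit in GPU memory, the same
# self-excluding KNN can be computed without Open3D using torch.cdist and
# topk. A minimal sketch, not part of the original code:
# def torch_knn(pts, num_knn):
#     d2 = torch.cdist(pts, pts) ** 2                      # pairwise squared distances
#     sq_dists, indices = d2.topk(num_knn + 1, largest=False)
#     return sq_dists[:, 1:], indices[:, 1:]               # drop the query point itself
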
def params2cpu(params, is_initial_timestep):
    # Move parameters to CPU numpy arrays for saving. After the first
    # timestep, only the quantities that change over time are kept.
    if is_initial_timestep:
        res = {k: v.detach().cpu().contiguous().numpy() for k, v in params.items()}
    else:
        res = {k: v.detach().cpu().contiguous().numpy() for k, v in params.items() if
               k in ['means3D', 'rgb_colors', 'unnorm_rotations']}
    return res
def save_params(output_params, seq, exp):
    # Stack per-timestep arrays along a new leading time axis; quantities
    # saved only at the first timestep are stored as-is.
    to_save = {}
    for k in output_params[0].keys():
        if k in output_params[1].keys():
            to_save[k] = np.stack([params[k] for params in output_params])
        else:
            to_save[k] = output_params[0][k]
    os.makedirs(f"./output/{exp}/{seq}", exist_ok=True)
    np.savez(f"./output/{exp}/{seq}/params", **to_save)
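
# Hedged loading sketch: np.savez appends ".npz", so the file written above
# can be read back like this (exp/seq are whatever was passed to save_params):
# data = np.load(f"./output/{exp}/{seq}/params.npz")
# means3D = data['means3D']   # shape (T, N, 3) for time-varying keys
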
def farthest_point_sample(xyz, npoint):
    """
    Input:
        xyz: pointcloud data, [B, N, C]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    distance = torch.ones(B, N).to(device) * 1e10
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        # Record the current farthest point, update each point's distance to
        # its nearest selected centroid, then pick the next farthest point.
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids
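
# Hedged usage sketch for farthest_point_sample (illustrative values):
# pts = torch.rand(1, 10000, 3, device='cuda')   # one batch of 10k points
# idx = farthest_point_sample(pts, 512)          # (1, 512) indices
# subsampled = pts[0, idx[0]]                    # (512, 3) well-spread subset
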
def quat2mat(q):
    # Convert (w, x, y, z) quaternions to 3x3 rotation matrices,
    # normalizing the input first.
    norm = torch.sqrt(q[:, 0] * q[:, 0] + q[:, 1] * q[:, 1] + q[:, 2] * q[:, 2] + q[:, 3] * q[:, 3])
    q = q / norm[:, None]
    rot = torch.zeros((q.shape[0], 3, 3)).to(q.device)
    r = q[:, 0]
    x = q[:, 1]
    y = q[:, 2]
    z = q[:, 3]
    rot[:, 0, 0] = 1 - 2 * (y * y + z * z)
    rot[:, 0, 1] = 2 * (x * y - r * z)
    rot[:, 0, 2] = 2 * (x * z + r * y)
    rot[:, 1, 0] = 2 * (x * y + r * z)
    rot[:, 1, 1] = 1 - 2 * (x * x + z * z)
    rot[:, 1, 2] = 2 * (y * z - r * x)
    rot[:, 2, 0] = 2 * (x * z - r * y)
    rot[:, 2, 1] = 2 * (y * z + r * x)
    rot[:, 2, 2] = 1 - 2 * (x * x + y * y)
    return rot
def rot2quat(rot):
    # Convert rotation matrices back to (w, x, y, z) quaternions.
    # Note: this closed form divides by 4*w, so it is numerically unstable
    # when the trace approaches -1 (rotations near 180 degrees).
    q = torch.zeros((rot.shape[0], 4)).to(rot.device)
    # The clamp guards the sqrt against slightly negative traces caused by
    # floating-point error.
    q[:, 0] = 0.5 * torch.sqrt(torch.clamp(1 + rot[:, 0, 0] + rot[:, 1, 1] + rot[:, 2, 2], min=1e-12))
    q[:, 1] = (rot[:, 2, 1] - rot[:, 1, 2]) / (4 * q[:, 0])
    q[:, 2] = (rot[:, 0, 2] - rot[:, 2, 0]) / (4 * q[:, 0])
    q[:, 3] = (rot[:, 1, 0] - rot[:, 0, 1]) / (4 * q[:, 0])
    return q
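
# Hedged round-trip sanity check for quat2mat / rot2quat (illustrative only);
# valid for rotations away from 180 degrees, where rot2quat is stable:
# q = torch.nn.functional.normalize(torch.randn(8, 4))
# q = q * torch.sign(q[:, :1])   # fix the sign ambiguity (q and -q encode the same rotation)
# assert torch.allclose(rot2quat(quat2mat(q)), q, atol=1e-4)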