# InstantSplat/utils/pose_utils.py
import math
import numpy as np
import torch
import torch.nn.functional as F
from typing import Tuple
from utils.stepfun import sample_np, sample
import scipy.interpolate  # splprep/splev are used below; a bare "import scipy" may not expose the submodule
def quad2rotation(q):
"""
Convert quaternion to rotation in batch. Since all operation in pytorch, support gradient passing.
Args:
quad (tensor, batch_size*4): quaternion.
Returns:
rot_mat (tensor, batch_size*3*3): rotation.
"""
if not isinstance(q, torch.Tensor):
q = torch.tensor(q).cuda()
norm = torch.sqrt(
q[:, 0] * q[:, 0] + q[:, 1] * q[:, 1] + q[:, 2] * q[:, 2] + q[:, 3] * q[:, 3]
)
q = q / norm[:, None]
rot = torch.zeros((q.size(0), 3, 3)).to(q)
r = q[:, 0]
x = q[:, 1]
y = q[:, 2]
z = q[:, 3]
rot[:, 0, 0] = 1 - 2 * (y * y + z * z)
rot[:, 0, 1] = 2 * (x * y - r * z)
rot[:, 0, 2] = 2 * (x * z + r * y)
rot[:, 1, 0] = 2 * (x * y + r * z)
rot[:, 1, 1] = 1 - 2 * (x * x + z * z)
rot[:, 1, 2] = 2 * (y * z - r * x)
rot[:, 2, 0] = 2 * (x * z - r * y)
rot[:, 2, 1] = 2 * (y * z + r * x)
rot[:, 2, 2] = 1 - 2 * (x * x + y * y)
return rot
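# Usage sketch (added for illustration): quaternions are w-first, so the
# identity quaternion [1, 0, 0, 0] maps to the 3x3 identity. The .cuda()
# fallback above only triggers for non-tensor inputs, so CPU tensors work too.
#   q = torch.tensor([[1.0, 0.0, 0.0, 0.0]])
#   R = quad2rotation(q)  # shape (1, 3, 3), equal to torch.eye(3)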
def get_camera_from_tensor(inputs):
"""
Convert quaternion and translation to transformation matrix.
"""
if not isinstance(inputs, torch.Tensor):
inputs = torch.tensor(inputs).cuda()
N = len(inputs.shape)
if N == 1:
inputs = inputs.unsqueeze(0)
    quad, T = inputs[:, :4], inputs[:, 4:]
    w2c = torch.eye(4).to(inputs).float()
    w2c[:3, :3] = quad2rotation(quad).squeeze(0)  # (1, 3, 3) -> (3, 3)
    w2c[:3, 3] = T.squeeze(0)
    return w2c
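# Usage sketch (added for illustration): a 7-vector [qw, qx, qy, qz, tx, ty,
# tz] becomes a single 4x4 matrix.
#   vec = torch.tensor([1.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0])
#   w2c = get_camera_from_tensor(vec)  # eye(4) with translation (0.5, 0, 0)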
def quadmultiply(q1, q2):
"""
Multiply two quaternions together using quaternion arithmetic
"""
# Extract scalar and vector parts of the quaternions
w1, x1, y1, z1 = q1.unbind(dim=-1)
w2, x2, y2, z2 = q2.unbind(dim=-1)
# Calculate the quaternion product
result_quaternion = torch.stack(
[
w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
],
dim=-1,
)
return result_quaternion
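# Sanity-check sketch (added for illustration): under the w-first Hamilton
# product implemented above, i * j = k.
#   qi = torch.tensor([0.0, 1.0, 0.0, 0.0])
#   qj = torch.tensor([0.0, 0.0, 1.0, 0.0])
#   quadmultiply(qi, qj)  # -> tensor([0., 0., 0., 1.]), i.e. k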
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
"""
Returns torch.sqrt(torch.max(0, x))
but with a zero subgradient where x is 0.
Source: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#matrix_to_quaternion
"""
ret = torch.zeros_like(x)
positive_mask = x > 0
ret[positive_mask] = torch.sqrt(x[positive_mask])
return ret
def rotation2quad(matrix: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as rotation matrices to quaternions.
Args:
matrix: Rotation matrices as tensor of shape (..., 3, 3).
Returns:
quaternions with real part first, as tensor of shape (..., 4).
Source: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#matrix_to_quaternion
"""
    if not isinstance(matrix, torch.Tensor):
        matrix = torch.tensor(matrix).cuda()
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
batch_dim = matrix.shape[:-2]
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
matrix.reshape(batch_dim + (9,)), dim=-1
)
q_abs = _sqrt_positive_part(
torch.stack(
[
1.0 + m00 + m11 + m22,
1.0 + m00 - m11 - m22,
1.0 - m00 + m11 - m22,
1.0 - m00 - m11 + m22,
],
dim=-1,
)
)
# we produce the desired quaternion multiplied by each of r, i, j, k
quat_by_rijk = torch.stack(
[
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
],
dim=-2,
)
# We floor here at 0.1 but the exact level is not important; if q_abs is small,
# the candidate won't be picked.
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
# forall i; we pick the best-conditioned one (with the largest denominator)
return quat_candidates[
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
].reshape(batch_dim + (4,))
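# Round-trip sketch (added for illustration): rotation2quad and quad2rotation
# invert each other up to quaternion sign.
#   R = torch.eye(3).unsqueeze(0)        # (1, 3, 3)
#   q = rotation2quad(R)                 # ~[1, 0, 0, 0], w-first
#   torch.allclose(quad2rotation(q), R)  # True up to float error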
def get_tensor_from_camera(RT, Tquad=False):
"""
Convert transformation matrix to quaternion and translation.
"""
if not isinstance(RT, torch.Tensor):
RT = torch.tensor(RT).cuda()
rot = RT[:3, :3].unsqueeze(0).detach()
quat = rotation2quad(rot).squeeze()
tran = RT[:3, 3].detach()
return torch.cat([quat, tran])
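# Usage sketch (added for illustration): get_tensor_from_camera inverts
# get_camera_from_tensor for a single pose, up to quaternion sign.
#   w2c = torch.eye(4)
#   vec = get_tensor_from_camera(w2c)   # (7,) = quat (w-first) + translation
#   torch.allclose(get_camera_from_tensor(vec), w2c)  # True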
def normalize(x):
return x / np.linalg.norm(x)
def viewmatrix(lookdir, up, position, subtract_position=False):
"""Construct lookat view matrix."""
vec2 = normalize((lookdir - position) if subtract_position else lookdir)
vec0 = normalize(np.cross(up, vec2))
vec1 = normalize(np.cross(vec2, vec0))
m = np.stack([vec0, vec1, vec2, position], axis=1)
return m
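# Usage sketch (added for illustration): the columns of the returned 3x4
# matrix are (right, up, forward, position) in world space.
#   m = viewmatrix(np.array([0.0, 0.0, 1.0]),  # look along +z
#                  np.array([0.0, 1.0, 0.0]),  # up hint
#                  np.zeros(3))                # camera at the origin
#   # m == [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]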
def poses_avg(poses):
"""New pose using average position, z-axis, and up vector of input poses."""
position = poses[:, :3, 3].mean(0)
z_axis = poses[:, :3, 2].mean(0)
up = poses[:, :3, 1].mean(0)
cam2world = viewmatrix(z_axis, up, position)
return cam2world
def focus_point_fn(poses):
"""Calculate nearest point to all focal axes in poses."""
directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4]
m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1])
mt_m = np.transpose(m, [0, 2, 1]) @ m
focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0]
return focus_pt
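# Derivation note (added): for a camera with unit focal axis d, the matrix
# (I - d d^T) projects a point onto the plane orthogonal to d, so the squared
# distance from point p to that camera's axis through origin o is
# |(I - d d^T)(p - o)|^2. Summing over all cameras and setting the gradient
# to zero gives exactly the normal equations solved in focus_point_fn above.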
def pad_poses(p):
"""Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1]."""
bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape)
return np.concatenate([p[..., :3, :4], bottom], axis=-2)
def unpad_poses(p):
"""Remove the homogeneous bottom row from [..., 4, 4] pose matrices."""
return p[..., :3, :4]
def transform_poses_pca(poses):
"""Transforms poses so principal components lie on XYZ axes.
Args:
poses: a (N, 3, 4) array containing the cameras' camera to world transforms.
Returns:
A tuple (poses, transform), with the transformed poses and the applied
camera_to_world transforms.
"""
t = poses[:, :3, 3]
t_mean = t.mean(axis=0)
t = t - t_mean
eigval, eigvec = np.linalg.eig(t.T @ t)
# Sort eigenvectors in order of largest to smallest eigenvalue.
inds = np.argsort(eigval)[::-1]
eigvec = eigvec[:, inds]
rot = eigvec.T
if np.linalg.det(rot) < 0:
rot = np.diag(np.array([1, 1, -1])) @ rot
transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1)
poses_recentered = unpad_poses(transform @ pad_poses(poses))
transform = np.concatenate([transform, np.eye(4)[3:]], axis=0)
# Flip coordinate system if z component of y-axis is negative
if poses_recentered.mean(axis=0)[2, 1] < 0:
poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered
transform = np.diag(np.array([1, -1, -1, 1])) @ transform
    # Just make sure it's in the [-1, 1]^3 cube.
scale_factor = 1. / np.max(np.abs(poses_recentered[:, :3, 3]))
poses_recentered[:, :3, 3] *= scale_factor
transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform
return poses_recentered, transform
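# Usage sketch (added, with hypothetical data): recenter a scattered set of
# camera-to-world poses.
#   poses = np.tile(np.eye(4)[:3], (5, 1, 1))        # (5, 3, 4)
#   poses[:, :3, 3] = np.random.randn(5, 3)          # scatter the positions
#   poses_pca, transform = transform_poses_pca(poses)
#   # transform (4x4) maps the original world frame into the recentered one;
#   # note it also carries the uniform scale applied to the translations.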
def recenter_poses(poses: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""Recenter poses around the origin."""
cam2world = poses_avg(poses)
transform = np.linalg.inv(pad_poses(cam2world))
poses = transform @ pad_poses(poses)
return unpad_poses(poses), transform
def generate_ellipse_path(views, n_frames=600, const_speed=True, z_variation=0., z_phase=0.):
poses = []
for view in views:
tmp_view = np.eye(4)
tmp_view[:3] = np.concatenate([view.R.T, view.T[:, None]], 1)
tmp_view = np.linalg.inv(tmp_view)
tmp_view[:, 1:3] *= -1
poses.append(tmp_view)
poses = np.stack(poses, 0)
poses, transform = transform_poses_pca(poses)
# Calculate the focal point for the path (cameras point toward this).
center = focus_point_fn(poses)
# Path height sits at z=0 (in middle of zero-mean capture pattern).
    offset = np.array([center[0], center[1], 0])
# Calculate scaling for ellipse axes based on input camera positions.
sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0)
# Use ellipse that is symmetric about the focal point in xy.
low = -sc + offset
high = sc + offset
    # Optional height variation (need not be symmetric).
z_low = np.percentile((poses[:, :3, 3]), 10, axis=0)
z_high = np.percentile((poses[:, :3, 3]), 90, axis=0)
def get_positions(theta):
# Interpolate between bounds with trig functions to get ellipse in x-y.
# Optionally also interpolate in z to change camera height along path.
return np.stack([
(low[0] + (high - low)[0] * (np.cos(theta) * .5 + .5)),
(low[1] + (high - low)[1] * (np.sin(theta) * .5 + .5)),
z_variation * (z_low[2] + (z_high - z_low)[2] *
(np.cos(theta + 2 * np.pi * z_phase) * .5 + .5)),
], -1)
theta = np.linspace(0, 2. * np.pi, n_frames + 1, endpoint=True)
positions = get_positions(theta)
if const_speed:
# Resample theta angles so that the velocity is closer to constant.
lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
theta = sample_np(None, theta, np.log(lengths), n_frames + 1)
positions = get_positions(theta)
# Throw away duplicated last position.
positions = positions[:-1]
# Set path's up vector to axis closest to average of input pose up vectors.
avg_up = poses[:, :3, 1].mean(0)
avg_up = avg_up / np.linalg.norm(avg_up)
ind_up = np.argmax(np.abs(avg_up))
up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up])
# up = normalize(poses[:, :3, 1].sum(0))
render_poses = []
for p in positions:
render_pose = np.eye(4)
render_pose[:3] = viewmatrix(p - center, up, p)
render_pose = np.linalg.inv(transform) @ render_pose
render_pose[:3, 1:3] *= -1
render_poses.append(np.linalg.inv(render_pose))
return render_poses
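# Usage note (added): `views` is assumed to be a list of camera objects
# exposing .R (3x3) and .T (3,) in the 3D Gaussian Splatting convention; the
# function returns a list of 4x4 world-to-camera matrices along the ellipse.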
def generate_spiral_path(poses_arr,
n_frames: int = 180,
n_rots: int = 2,
zrate: float = .5) -> np.ndarray:
"""Calculates a forward facing spiral path for rendering."""
poses = poses_arr[:, :-2].reshape([-1, 3, 5])
bounds = poses_arr[:, -2:]
fix_rotation = np.array([
[0, -1, 0, 0],
[1, 0, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
], dtype=np.float32)
poses = poses[:, :3, :4] @ fix_rotation
scale = 1. / (bounds.min() * .75)
poses[:, :3, 3] *= scale
bounds *= scale
poses, transform = recenter_poses(poses)
close_depth, inf_depth = bounds.min() * .9, bounds.max() * 5.
dt = .75
focal = 1 / (((1 - dt) / close_depth + dt / inf_depth))
# Get radii for spiral path using 90th percentile of camera positions.
positions = poses[:, :3, 3]
radii = np.percentile(np.abs(positions), 90, 0)
radii = np.concatenate([radii, [1.]])
# Generate poses for spiral path.
render_poses = []
cam2world = poses_avg(poses)
up = poses[:, :3, 1].mean(0)
for theta in np.linspace(0., 2. * np.pi * n_rots, n_frames, endpoint=False):
t = radii * [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]
position = cam2world @ t
lookat = cam2world @ [0, 0, -focal, 1.]
z_axis = position - lookat
render_pose = np.eye(4)
render_pose[:3] = viewmatrix(z_axis, up, position)
render_pose = np.linalg.inv(transform) @ render_pose
render_pose[:3, 1:3] *= -1
render_pose[:3, 3] /= scale
render_poses.append(np.linalg.inv(render_pose))
render_poses = np.stack(render_poses, axis=0)
return render_poses
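# Usage note (added): poses_arr is assumed to follow the LLFF poses_bounds
# layout, i.e. shape (N, 17) with a flattened 3x5 pose (rotation, translation,
# [height, width, focal]) plus near/far depth bounds per image.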
def generate_interpolated_path(
views,
n_interp,
spline_degree = 5,
smoothness = 0.03,
rot_weight = 0.1,
lock_up = False,
fixed_up_vector = None,
lookahead_i = None,
frames_per_colmap = None,
const_speed = False,
n_buffer = None,
periodic = False,
n_interp_as_total = False,
):
"""Creates a smooth spline path between input keyframe camera poses.
Spline is calculated with poses in format (position, lookat-point, up-point).
Args:
poses: (n, 3, 4) array of input pose keyframes.
n_interp: returned path will have n_interp * (n - 1) total poses.
spline_degree: polynomial degree of B-spline.
smoothness: parameter for spline smoothing, 0 forces exact interpolation.
rot_weight: relative weighting of rotation/translation in spline solve.
lock_up: if True, forced to use given Up and allow Lookat to vary.
fixed_up_vector: replace the interpolated `up` with a fixed vector.
lookahead_i: force the look direction to look at the pose `i` frames ahead.
frames_per_colmap: conversion factor for the desired average velocity.
const_speed: renormalize spline to have constant delta between each pose.
n_buffer: Number of buffer frames to insert at the start and end of the
path. Helps keep the ends of a spline path straight.
periodic: make the spline path periodic (perfect loop).
n_interp_as_total: use n_interp as total number of poses in path rather than
the number of poses to interpolate between each input.
Returns:
Array of new camera poses with shape (n_interp * (n - 1), 3, 4), or
(n_interp, 3, 4) if n_interp_as_total is set.
"""
poses = []
for view in views:
tmp_view = np.eye(4)
tmp_view[:3] = np.concatenate([view.R.T, view.T[:, None]], 1)
tmp_view = np.linalg.inv(tmp_view)
tmp_view[:, 1:3] *= -1
poses.append(tmp_view)
poses = np.stack(poses, 0)
def poses_to_points(poses, dist):
"""Converts from pose matrices to (position, lookat, up) format."""
pos = poses[:, :3, -1]
lookat = poses[:, :3, -1] - dist * poses[:, :3, 2]
up = poses[:, :3, -1] + dist * poses[:, :3, 1]
return np.stack([pos, lookat, up], 1)
def points_to_poses(points):
"""Converts from (position, lookat, up) format to pose matrices."""
poses = []
for i in range(len(points)):
pos, lookat_point, up_point = points[i]
if lookahead_i is not None:
if i + lookahead_i < len(points):
lookat = pos - points[i + lookahead_i][0]
else:
lookat = pos - lookat_point
up = (up_point - pos) if fixed_up_vector is None else fixed_up_vector
poses.append(viewmatrix(lookat, up, pos))
return np.array(poses)
def insert_buffer_poses(poses, n_buffer):
"""Insert extra poses at the start and end of the path."""
def average_distance(points):
distances = np.linalg.norm(points[1:] - points[0:-1], axis=-1)
return np.mean(distances)
def shift(pose, dz):
result = np.copy(pose)
z = result[:3, 2]
z /= np.linalg.norm(z)
# Move along forward-backward axis. -z is forward.
result[:3, 3] += z * dz
return result
dz = average_distance(poses[:, :3, 3])
prefix = np.stack([shift(poses[0], (i + 1) * dz) for i in range(n_buffer)])
prefix = prefix[::-1] # reverse order
suffix = np.stack(
[shift(poses[-1], -(i + 1) * dz) for i in range(n_buffer)]
)
result = np.concatenate([prefix, poses, suffix])
return result
def remove_buffer_poses(poses, u, n_frames, u_keyframes, n_buffer):
u_keyframes = u_keyframes[n_buffer:-n_buffer]
mask = (u >= u_keyframes[0]) & (u <= u_keyframes[-1])
poses = poses[mask]
u = u[mask]
n_frames = len(poses)
return poses, u, n_frames, u_keyframes
def interp(points, u, k, s):
"""Runs multidimensional B-spline interpolation on the input points."""
sh = points.shape
pts = np.reshape(points, (sh[0], -1))
k = min(k, sh[0] - 1)
tck, u_keyframes = scipy.interpolate.splprep(pts.T, k=k, s=s, per=periodic)
new_points = np.array(scipy.interpolate.splev(u, tck))
new_points = np.reshape(new_points.T, (len(u), sh[1], sh[2]))
return new_points, u_keyframes
if n_buffer is not None:
poses = insert_buffer_poses(poses, n_buffer)
points = poses_to_points(poses, dist=rot_weight)
if n_interp_as_total:
n_frames = n_interp + 1 # Add extra since final pose is discarded.
else:
n_frames = n_interp * (points.shape[0] - 1)
u = np.linspace(0, 1, n_frames, endpoint=True)
new_points, u_keyframes = interp(points, u=u, k=spline_degree, s=smoothness)
poses = points_to_poses(new_points)
if n_buffer is not None:
poses, u, n_frames, u_keyframes = remove_buffer_poses(
poses, u, n_frames, u_keyframes, n_buffer
)
# poses, transform = transform_poses_pca(poses)
if frames_per_colmap is not None:
# Recalculate the number of frames to achieve desired average velocity.
positions = poses[:, :3, -1]
lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
total_length_colmap = lengths.sum()
print('old n_frames:', n_frames)
print('total_length_colmap:', total_length_colmap)
n_frames = int(total_length_colmap * frames_per_colmap)
print('new n_frames:', n_frames)
u = np.linspace(
np.min(u_keyframes), np.max(u_keyframes), n_frames, endpoint=True
)
new_points, _ = interp(points, u=u, k=spline_degree, s=smoothness)
poses = points_to_poses(new_points)
if const_speed:
# Resample timesteps so that the velocity is nearly constant.
positions = poses[:, :3, -1]
lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
u = sample(None, u, np.log(lengths), n_frames + 1)
new_points, _ = interp(points, u=u, k=spline_degree, s=smoothness)
poses = points_to_poses(new_points)
# return poses[:-1], u[:-1], u_keyframes
return poses[:-1]
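# Usage sketch (added, with hypothetical `views`): interpolate 30 frames
# between each consecutive pair of keyframe cameras.
#   render_poses = generate_interpolated_path(views, n_interp=30)
#   # shape (30 * (len(views) - 1) - 1, 3, 4); the final pose is dropped by
#   # the `poses[:-1]` return above.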