# ZeroShape/utils/camera.py
# zxhuang1698's picture
# initial commit
# 414b431
# raw
# history blame
# 7.68 kB
# partially from https://github.com/chenhsuanlin/signed-distance-SRN
import numpy as np
import torch
class Pose():
    """Stateless helpers for [..., 3, 4] rigid poses laid out as ``[R | t]``.

    Calling an instance assembles a pose tensor; the other methods invert
    and chain poses.
    """

    def __call__(self, R=None, t=None):
        """Assemble a pose from a rotation and/or a translation.

        R: [..., 3, 3] tensor or array-like; identity when omitted.
        t: [..., 3] tensor or array-like; zeros when omitted.
        Returns:
            pose: [..., 3, 4] float tensor
        """
        assert(R is not None or t is not None)
        # promote array-likes to tensors first, then fill in the missing part
        if R is not None and not isinstance(R, torch.Tensor):
            R = torch.tensor(R)
        if t is not None and not isinstance(t, torch.Tensor):
            t = torch.tensor(t)
        if R is None:
            R = torch.eye(3, device=t.device).repeat(*t.shape[:-1], 1, 1)
        elif t is None:
            t = torch.zeros(R.shape[:-1], device=R.device)
        assert(R.shape[:-1]==t.shape and R.shape[-2:]==(3, 3))
        rotation = R.float()
        translation = t.float().unsqueeze(-1)
        # [..., 3, 4]
        assembled = torch.cat([rotation, translation], dim=-1)
        assert(assembled.shape[-2:]==(3, 4))
        return assembled

    def invert(self, pose, use_inverse=False):
        """Return the inverse pose.

        By default the rotation block is inverted by transposition;
        use_inverse=True runs a general matrix inverse on it instead.
        """
        rot, trans = pose[..., :3], pose[..., 3:]
        if use_inverse:
            rot_inv = torch.inverse(rot)
        else:
            rot_inv = rot.transpose(-1, -2)
        # the translation is kept as a [..., 3, 1] column until the product is taken
        trans_inv = torch.matmul(-rot_inv, trans)[..., 0]
        return self(R=rot_inv, t=trans_inv)

    def compose(self, pose_list):
        """Chain poses so that pose_new(x) = poseN(...(pose2(pose1(x)))...)."""
        chained = pose_list[0]
        for following in pose_list[1:]:
            chained = self.compose_pair(chained, following)
        return chained

    def compose_pair(self, pose_a, pose_b):
        """Return the pose with pose_new(x) = pose_b(pose_a(x))."""
        rot_a, trans_a = pose_a[..., :3], pose_a[..., 3:]
        rot_b, trans_b = pose_b[..., :3], pose_b[..., 3:]
        rot_ab = torch.matmul(rot_b, rot_a)
        trans_ab = (torch.matmul(rot_b, trans_a) + trans_b)[..., 0]
        return self(R=rot_ab, t=trans_ab)


# shared module-level instance
pose = Pose()
# unit sphere normalization
def valid_norm_fac(seen_points, mask):
    '''
    Per-sample centroid and radius of the valid points, for unit sphere normalization.

    seen_points: [B, H*W, 3]
    mask: [B, 1, H, W], boolean
    Returns:
        means: [B, 3], mean of the valid points of each sample
        max_dists: [B], largest distance from that mean to a valid point
    '''
    batch_size, num_points = seen_points.shape[:2]
    # [B, H*W]; reshape (not view) so a mask that is a strided slice of a
    # larger tensor can still be flattened
    mask = mask.reshape(batch_size, num_points)
    means, max_dists = [], []
    for sample_points, sample_mask in zip(seen_points, mask):
        # [N_valid, 3]
        valid_points = sample_points[sample_mask]
        # [3]
        xyz_mean = valid_points.mean(dim=0)
        # scalar: radius of the smallest centroid-centred sphere holding all valid points
        max_dist = (valid_points - xyz_mean).norm(dim=1).max()
        means.append(xyz_mean)
        max_dists.append(max_dist)
    # [B, 3], [B]
    return torch.stack(means, dim=0), torch.stack(max_dists, dim=0)
def get_pixel_grid(opt, H, W):
    '''
    Homogeneous pixel coordinates (x, y, 1) of an H x W image on opt.device.

    Returns:
        xyz_grid: [H*W, 3], row-major (x varies fastest)
    '''
    col_idx = torch.arange(W, dtype=torch.float32).to(opt.device)
    row_idx = torch.arange(H, dtype=torch.float32).to(opt.device)
    # row-major flattening: x cycles through the columns once per row,
    # y stays constant for W consecutive entries
    xs = col_idx.repeat(H)
    ys = row_idx.repeat_interleave(W)
    ones = torch.ones_like(xs)
    return torch.stack([xs, ys, ones], dim=-1)
def unproj_depth(opt, depth, intr):
    '''
    Lift a depth map to 3D points by scaling each pixel's K^-1 ray by its depth.

    depth: [B, 1, H, W]
    intr: [B, 3, 3]
    Returns:
        seen_points: [B, H*W, 3], in camera coordinates
    '''
    B, _, H, W = depth.shape
    # only square maps at the configured resolution are accepted
    assert opt.H == H == W
    depth_map = depth.squeeze(1)
    # [B, 3, 3]
    intr_inv = torch.linalg.inv(intr).float()
    # [B, H*W, 3], the same pixel grid for every sample
    grid = get_pixel_grid(opt, H, W).unsqueeze(0).repeat(B, 1, 1)
    # [B, 3, H*W]
    rays = torch.matmul(intr_inv, grid.permute(0, 2, 1).contiguous())
    # [B, H*W, 3]
    rays = rays.permute(0, 2, 1).contiguous()
    # one depth per pixel, broadcast over xyz
    return rays * depth_map.view(B, H*W, 1)
def to_hom(X):
    '''
    Append a homogeneous coordinate of 1 along the last dimension.

    X: [B, N, 3]
    Returns:
        X_hom: [B, N, 4]
    '''
    # same dtype/device as X, one entry per point
    unit_col = torch.full_like(X[..., :1], 1)
    return torch.cat((X, unit_col), dim=-1)
def world2cam(X_world, pose):
    '''
    Apply a [R | t] pose to points: X_cam = R @ X_world + t.

    X_world: [B, N, 3]
    pose: [B, 3, 4]
    Returns:
        X_cam: [B, N, 3]
    '''
    # homogeneous points [B, N, 4]: a trailing 1 lets the translation column act
    unit_col = torch.ones_like(X_world[..., :1])
    X_hom = torch.cat([X_world, unit_col], dim=-1)
    # points are rows, so multiply by the transposed pose on the right
    pose_T = pose.transpose(-2, -1)
    return torch.matmul(X_hom, pose_T)
def cam2img(X_cam, cam_intr):
    '''
    Apply the intrinsics to camera-space points (no perspective division).

    X_cam: [B, N, 3]
    cam_intr: [B, 3, 3]
    Returns:
        X_img: [B, N, 3]
    '''
    # points are rows, so multiply by the transposed intrinsics on the right
    intr_T = cam_intr.transpose(-2, -1)
    return torch.matmul(X_cam, intr_T)
def proj_points(opt, points, intr, pose):
    '''
    Project world-space points to pixel coordinates.

    points: [B, N, 3]
    intr: [B, 3, 3]
    pose: [B, 3, 4]
    Returns:
        points_2D: [B, N, 2], image coordinates after perspective division
        depth: [B, N], z of each point in camera space
    opt is not used by this function.
    '''
    # [B, N, 3] in camera space, then through the intrinsics
    cam_pts = world2cam(points, pose)
    img_pts = cam2img(cam_pts, intr)
    # split homogeneous image coordinates into (u*w, v*w) and w
    uv_scaled, w = img_pts[..., :2], img_pts[..., 2:]
    return uv_scaled / w, cam_pts[..., 2]
def azim_to_rotation_matrix(azim, representation='angle'):
    """Azim is angle with vector +X, rotated in XZ plane

    azim: [B] angles for 'rad' (radians) and 'angle' (degrees),
        or [B, 2] (cos, sin) pairs for 'trig'
    representation: one of 'rad', 'angle', 'trig'
    Returns:
        R: [B, 3, 3], rows 0 and 2 set to [cos, 0, sin] and [-sin, 0, cos]
    Raises:
        ValueError: if representation is not one of the three supported values
    """
    if representation == 'trig':
        # [B, 2]
        cos, sin = azim[:, 0], azim[:, 1]
    elif representation in ('rad', 'angle'):
        # [B, ]
        if representation == 'angle':
            azim = azim * np.pi / 180
        cos, sin = torch.cos(azim), torch.sin(azim)
    else:
        # fail early with a clear message instead of an unbound `cos` below
        raise ValueError(
            f"unknown representation {representation!r}; expected 'rad', 'angle' or 'trig'")
    batch = len(azim)
    R = torch.eye(3, device=azim.device)[None].repeat(batch, 1, 1)
    # match cos in dtype/device so the stacks below never mix dtypes
    zeros = torch.zeros_like(cos)
    R[:, 0, :] = torch.stack([cos, zeros, sin], dim=-1)
    R[:, 2, :] = torch.stack([-sin, zeros, cos], dim=-1)
    return R
def elev_to_rotation_matrix(elev, representation='angle'):
    """Angle with vector +Z in YZ plane

    elev: [B] angles for 'rad' (radians) and 'angle' (degrees),
        or [B, 2] (cos, sin) pairs for 'trig'
    representation: one of 'rad', 'angle', 'trig'
    Returns:
        R: [B, 3, 3], lower-right 2x2 block set to [[cos, -sin], [sin, cos]]
    Raises:
        ValueError: if representation is not one of the three supported values
    """
    if representation == 'trig':
        # [B, 2]
        cos, sin = elev[:, 0], elev[:, 1]
    elif representation in ('rad', 'angle'):
        # [B, ]
        if representation == 'angle':
            elev = elev * np.pi / 180
        cos, sin = torch.cos(elev), torch.sin(elev)
    else:
        # fail early with a clear message instead of an unbound `cos` below
        raise ValueError(
            f"unknown representation {representation!r}; expected 'rad', 'angle' or 'trig'")
    R = torch.eye(3, device=elev.device)[None].repeat(len(elev), 1, 1)
    R[:, 1, 1:] = torch.stack([cos, -sin], dim=-1)
    R[:, 2, 1:] = torch.stack([sin, cos], dim=-1)
    return R
def roll_to_rotation_matrix(roll, representation='angle'):
    """Angle with vector +X in XY plane

    roll: [B] angles for 'rad' (radians) and 'angle' (degrees),
        or [B, 2] (cos, sin) pairs for 'trig'
    representation: one of 'rad', 'angle', 'trig'
    Returns:
        R: [B, 3, 3], upper-left 2x2 block set to [[cos, sin], [-sin, cos]]
    Raises:
        ValueError: if representation is not one of the three supported values
    """
    if representation == 'trig':
        # [B, 2]
        cos, sin = roll[:, 0], roll[:, 1]
    elif representation in ('rad', 'angle'):
        # [B, ]
        if representation == 'angle':
            roll = roll * np.pi / 180
        cos, sin = torch.cos(roll), torch.sin(roll)
    else:
        # fail early with a clear message instead of an unbound `cos` below
        raise ValueError(
            f"unknown representation {representation!r}; expected 'rad', 'angle' or 'trig'")
    R = torch.eye(3, device=roll.device)[None].repeat(len(roll), 1, 1)
    R[:, 0, :2] = torch.stack([cos, sin], dim=-1)
    R[:, 1, :2] = torch.stack([-sin, cos], dim=-1)
    return R
def get_rotation_sphere(azim_sample=4, elev_sample=4, roll_sample=4, scales=(1.0,), device='cuda'):
    '''
    Enumerate scaled rotation matrices on a regular azimuth/elevation/roll grid.

    azim_sample, elev_sample, roll_sample: number of evenly spaced angles
        taken from [0, 360) degrees for each axis (endpoint excluded)
    scales: iterable of scalar factors multiplied into each matrix
    device: device of the returned tensor
    Returns:
        [len(scales)*azim_sample*elev_sample*roll_sample, 3, 3] float tensor,
        ordered with scale outermost, then azim, elev, and roll innermost
    '''
    azim_range = [0, 360]
    elev_range = [0, 360]
    roll_range = [0, 360]
    azims = np.linspace(azim_range[0], azim_range[1], num=azim_sample, endpoint=False)
    elevs = np.linspace(elev_range[0], elev_range[1], num=elev_sample, endpoint=False)
    rolls = np.linspace(roll_range[0], roll_range[1], num=roll_sample, endpoint=False)
    # each per-axis matrix depends on its own angle only, so build every one
    # once up front instead of once per grid combination
    Rys = [azim_to_rotation_matrix(torch.tensor([azim])) for azim in azims]
    Rxs = [elev_to_rotation_matrix(torch.tensor([elev])) for elev in elevs]
    Rzs = [roll_to_rotation_matrix(torch.tensor([roll])) for roll in rolls]
    # constant signed axis permutation, the right-most factor of every product; [1, 3, 3]
    R_permute = torch.tensor([
        [-1, 0, 0],
        [0, 0, -1],
        [0, -1, 0]
    ]).float().unsqueeze(0)
    rotations = []
    for scale in scales:
        for Ry in Rys:
            for Rx in Rxs:
                for Rz in Rzs:
                    # `*` and `@` share precedence, so scale multiplies Rz first
                    rotations.append(scale * Rz@Rx@Ry@R_permute)
    # one concatenation and one device transfer for the whole grid
    return torch.cat(rotations, dim=0).to(device).float()