# qitaoz's picture
# Upload 57 files
# 4562a06 verified
from tkinter import FALSE
import cv2
import ipdb # noqa: F401
import numpy as np
import torch
from pytorch3d.renderer import PerspectiveCameras, RayBundle
from pytorch3d.transforms import Rotate, Translate
from diffusionsfm.utils.normalize import (
compute_optical_axis_intersection,
intersect_skew_line_groups,
first_camera_transform,
intersect_skew_lines_high_dim,
)
from diffusionsfm.utils.distortion import apply_distortion_tensor
class Rays(object):
    """Batch of rays stored in one of three interchangeable representations.

    ``self._mode`` selects how the last dimension of ``self.rays`` is read:

    * ``"ray"``: ``[..., :3]`` origins, ``[..., 3:]`` directions.
    * ``"plucker"``: ``[..., :3]`` directions, ``[..., 3:]`` moments
      (origin x direction, see ``to_plucker``).
    * ``"segment"``: ``[..., :3]`` per-patch origins, ``[..., 3:]`` segment end
      points packed ``depth_resolution**2`` per patch (see ``patches_to_rows``).

    Optional companions (``depths``, ``ndc_coordinates``, ``unprojected`` and
    camera-coordinate ray directions) are carried alongside when provided.
    """

    def __init__(
        self,
        rays=None,
        origins=None,
        directions=None,
        moments=None,
        segments=None,
        depths=None,
        moments_rescale=1.0,
        ndc_coordinates=None,
        crop_parameters=None,
        num_patches_x=16,
        num_patches_y=16,
        distortion_coeffs=None,
        camera_coordinate_rays=None,
        mode=None,
        unprojected=None,
        depth_resolution=1,
        row_form=False,
    ):
        """
        Ray class to keep track of current ray representation.
        Args:
            rays: (..., 6). Used as-is; requires ``mode``.
            origins: (..., 3).
            directions: (..., 3).
            moments: (..., 3).
            segments: segment end points combined with ``origins``. Unless
                ``row_form`` is True they are packed with ``patches_to_rows``,
                which expects (B, depth_resolution**2 * num_patches_x *
                num_patches_y, C).
            depths: optional per-ray depths; the given ``mode`` is stored.
            mode: One of "ray", "plucker" or "segment". Must not be None.
            moments_rescale: Rescale the moment component of the rays by a scalar.
                Any value other than 1.0 calls the deprecated
                ``rescale_moments``, which asserts.
            ndc_coordinates: (..., 2): NDC coordinates of each ray.
            crop_parameters: if given (and ``ndc_coordinates`` is not), NDC
                coordinates are computed from them with ``compute_ndc_coordinates``.
            num_patches_x, num_patches_y: patch grid size.
            distortion_coeffs: forwarded to ``compute_ndc_coordinates``.
            camera_coordinate_rays: optional ray directions in camera coordinates.
            unprojected: optional unprojected points, returned by ``get_segments``.
            depth_resolution: number of segment samples per patch side.
            row_form: True if ``segments`` are already packed per patch.
        """
        self.depth_resolution = depth_resolution
        self.num_patches_x = num_patches_x
        self.num_patches_y = num_patches_y
        if rays is not None:
            self.rays = rays
            assert mode is not None
            self._mode = mode
        elif segments is not None:
            if not row_form:
                # Pack each patch's depth_resolution x depth_resolution end points
                # into one row so they line up with the per-patch origins.
                segments = Rays.patches_to_rows(
                    segments,
                    num_patches_x=num_patches_x,
                    num_patches_y=num_patches_y,
                    depth_resolution=depth_resolution,
                )
            self.rays = torch.cat((origins, segments), dim=-1)
            self._mode = "segment"
        elif origins is not None and directions is not None:
            self.rays = torch.cat((origins, directions), dim=-1)
            self._mode = "ray"
        elif directions is not None and moments is not None:
            self.rays = torch.cat((directions, moments), dim=-1)
            self._mode = "plucker"
        else:
            raise Exception("Invalid combination of arguments")
        if depths is not None:
            # With depths, the explicitly passed mode overrides the inferred one.
            self._mode = mode
            self.depths = depths
        else:
            self.depths = None
        assert mode is not None
        if unprojected is not None:
            self.unprojected = unprojected
        else:
            self.unprojected = None
        if moments_rescale != 1.0:
            self.rescale_moments(moments_rescale)
        if ndc_coordinates is not None:
            self.ndc_coordinates = ndc_coordinates
        elif crop_parameters is not None:
            # (..., H, W, 2)
            xy_grid = compute_ndc_coordinates(
                crop_parameters,
                num_patches_x=num_patches_x,
                num_patches_y=num_patches_y,
                distortion_coeffs=distortion_coeffs,
            )[..., :2]
            xy_grid = xy_grid.reshape(*xy_grid.shape[:-3], -1, 2)
            self.ndc_coordinates = xy_grid
        else:
            self.ndc_coordinates = None
        if camera_coordinate_rays is not None:
            self.camera_ray_coordinates = True
            self.camera_coordinate_ray_directions = camera_coordinate_rays
        else:
            self.camera_ray_coordinates = False

    def __getitem__(self, index):
        """Index the leading dimension(s) of the rays and every companion tensor."""
        cam_coord_rays = None
        if self.camera_ray_coordinates:
            cam_coord_rays = self.camera_coordinate_ray_directions[index]
        return Rays(
            rays=self.rays[index],
            mode=self._mode,
            camera_coordinate_rays=cam_coord_rays,
            ndc_coordinates=(
                self.ndc_coordinates[index]
                if self.ndc_coordinates is not None
                else None
            ),
            num_patches_x=self.num_patches_x,
            num_patches_y=self.num_patches_y,
            # Each optional tensor is guarded by its own None-check. These used
            # to be keyed on ndc_coordinates, which crashed when `unprojected`
            # was None and dropped depths when ndc_coordinates was None.
            depths=self.depths[index] if self.depths is not None else None,
            unprojected=(
                self.unprojected[index] if self.unprojected is not None else None
            ),
            depth_resolution=self.depth_resolution,
        )

    def __len__(self):
        return self.rays.shape[0]

    def to_spatial(
        self, include_ndc_coordinates=False, include_depths=False, use_homogeneous=False
    ):
        """
        Converts rays to spatial representation: (..., H * W, 6) --> (..., 6, H, W)
        If use_homogeneous is True, then each 3D component will be 4D and normalized.
        Optional depth (1 channel) and NDC (2 channels) planes are appended.
        Requires a square ray grid (H == W).
        Returns:
            torch.Tensor: (..., 6, H, W)
        """
        if self._mode == "ray":
            rays = self.to_plucker().rays
        else:
            rays = self.rays
        *batch_dims, P, D = rays.shape
        H = W = int(np.sqrt(P))
        assert H * W == P
        if use_homogeneous:
            # Append w=1 to every 3-vector, then normalize the 4-vector.
            rays_reshaped = rays.reshape(*batch_dims, P, D // 3, 3)
            ones = torch.ones_like(rays_reshaped[..., :1])
            rays_reshaped = torch.cat((rays_reshaped, ones), dim=-1)
            rays = torch.nn.functional.normalize(rays_reshaped, dim=-1)
            D = (4 * D) // 3
            rays = rays.reshape(*batch_dims, P, D)
        rays = torch.transpose(rays, -1, -2)  # (..., 6, H * W)
        rays = rays.reshape(*batch_dims, D, H, W)
        if include_depths:
            depths = self.depths.unsqueeze(1)
            rays = torch.cat((rays, depths), dim=-3)
        if include_ndc_coordinates:
            ndc_coords = self.ndc_coordinates.transpose(-1, -2)  # (..., 2, H * W)
            ndc_coords = ndc_coords.reshape(*batch_dims, 2, H, W)
            rays = torch.cat((rays, ndc_coords), dim=-3)
        return rays

    def to_spatial_with_camera_coordinate_rays(
        self,
        I_camera,
        crop_params,
        moments_rescale=1.0,
        include_ndc_coordinates=False,
        use_homogeneous=False,
    ):
        """
        Converts rays to spatial representation: (..., H * W, 6) --> (..., 6, H, W)
        and prepends 3 channels of camera-coordinate ray directions computed from
        ``I_camera``. Note: this calls the deprecated ``rescale_moments``.
        Returns:
            torch.Tensor: (N, 3 + 6, H, W)
        """
        rays = self.to_spatial(
            include_ndc_coordinates=include_ndc_coordinates,
            use_homogeneous=use_homogeneous,
        )
        N, _, H, W = rays.shape
        camera_coord_rays = (
            cameras_to_rays(
                cameras=I_camera,
                num_patches_x=H,
                num_patches_y=W,
                crop_parameters=crop_params,
            )
            .rescale_moments(1 / moments_rescale)
            .get_directions()
        )
        self.camera_coordinate_ray_directions = camera_coord_rays
        camera_coord_rays = torch.transpose(camera_coord_rays, -1, -2)
        camera_coord_rays = camera_coord_rays.reshape(N, 3, H, W)
        rays = torch.cat((camera_coord_rays, rays), dim=-3)
        return rays

    def rescale_moments(self, scale):
        """
        Rescale the moment component of the rays by a scalar. Might be desirable since
        moments may come from a very narrow distribution.
        Note that this modifies in place!
        Deprecated: the leading assertion always fires unless Python runs with -O.
        """
        assert False, "Deprecated"
        if self._mode == "plucker":
            self.rays[..., 3:] *= scale
            return self
        else:
            return self.to_plucker().rescale_moments(scale)

    def to_spatial_with_camera_coordinate_rays_object(
        self,
        I_camera,
        crop_params,
        moments_rescale=1.0,
        include_ndc_coordinates=False,
        use_homogeneous=False,
    ):
        """
        Converts rays to spatial representation: (..., H * W, 6) --> (..., 6, H, W)
        and prepends 3 channels of camera-coordinate ray directions (same as
        ``to_spatial_with_camera_coordinate_rays``; this used to compute the
        result but return None). Note: this calls the deprecated
        ``rescale_moments``.
        Returns:
            torch.Tensor: (N, 3 + 6, H, W)
        """
        rays = self.to_spatial(include_ndc_coordinates, use_homogeneous=use_homogeneous)
        N, _, H, W = rays.shape
        camera_coord_rays = (
            cameras_to_rays(
                cameras=I_camera,
                num_patches_x=H,
                num_patches_y=W,
                crop_parameters=crop_params,
            )
            .rescale_moments(1 / moments_rescale)
            .get_directions()
        )
        self.camera_coordinate_ray_directions = camera_coord_rays
        camera_coord_rays = torch.transpose(camera_coord_rays, -1, -2)
        camera_coord_rays = camera_coord_rays.reshape(N, 3, H, W)
        rays = torch.cat((camera_coord_rays, rays), dim=-3)
        return rays

    @classmethod
    def patches_to_rows(cls, x, num_patches_x=16, num_patches_y=16, depth_resolution=1):
        """Pack a high-resolution grid into one row per patch.

        (B, (dr*npx) * (dr*npy), C) -> (B, npx * npy, dr * dr * C), dr = depth_resolution.
        Inverse of ``rows_to_patches``.
        """
        B, P, C = x.shape
        assert P == (depth_resolution**2 * num_patches_x * num_patches_y)
        x = x.reshape(
            B,
            depth_resolution * num_patches_x,
            depth_resolution * num_patches_y,
            C,
        )
        new = x.unfold(1, depth_resolution, depth_resolution).unfold(
            2, depth_resolution, depth_resolution
        )
        new = new.permute((0, 1, 2, 4, 5, 3))
        new = new.reshape(
            (B, num_patches_x * num_patches_y, depth_resolution * depth_resolution * C)
        )
        return new

    @classmethod
    def rows_to_patches(cls, x, num_patches_x=16, num_patches_y=16, depth_resolution=1):
        """Unpack per-patch rows back into a flattened high-resolution grid.

        (B, npx * npy, dr * dr * C) -> (B, (dr*npx) * (dr*npy), C).
        Inverse of ``patches_to_rows``.
        """
        B, P, CP = x.shape
        assert P == (num_patches_x * num_patches_y)
        C = CP // (depth_resolution**2)
        HP, WP = num_patches_x * depth_resolution, num_patches_y * depth_resolution
        x = x.reshape(
            B, num_patches_x, num_patches_y, depth_resolution, depth_resolution, C
        )
        x = x.permute(0, 1, 3, 2, 4, 5)
        x = x.reshape(B, HP * WP, C)
        return x

    @classmethod
    def upsample_origins(
        cls, x, num_patches_x=16, num_patches_y=16, depth_resolution=1
    ):
        """Upsample per-patch origins (B, P, C) to (B, P * dr**2, C) with
        ``torch.nn.functional.interpolate`` (default nearest mode)."""
        B, P, C = x.shape
        origins = x.permute((0, 2, 1))
        origins = origins.reshape((B, C, num_patches_x, num_patches_y))
        origins = torch.nn.functional.interpolate(
            origins, scale_factor=(depth_resolution, depth_resolution)
        )
        origins = origins.permute((0, 2, 3, 1)).reshape(
            (B, P * depth_resolution * depth_resolution, C)
        )
        return origins

    @classmethod
    def from_spatial_with_camera_coordinate_rays(
        cls, rays, mode, moments_rescale=1.0, use_homogeneous=False
    ):
        """
        Converts rays from spatial representation: (..., 6, H, W) --> (..., H * W, 6)
        The first 3 channels are split off as camera-coordinate ray directions.
        ``use_homogeneous`` is accepted but unused.
        Args:
            rays: (..., 3 + 6, H, W)
        Returns:
            Rays: (..., H * W, 6)
        """
        *batch_dims, D, H, W = rays.shape
        rays = rays.reshape(*batch_dims, D, H * W)
        rays = torch.transpose(rays, -1, -2)
        camera_coordinate_ray_directions = rays[..., :3]
        rays = rays[..., 3:]
        return cls(
            rays=rays,
            mode=mode,
            moments_rescale=moments_rescale,
            camera_coordinate_rays=camera_coordinate_ray_directions,
        )

    @classmethod
    def from_spatial(
        cls,
        rays,
        mode,
        moments_rescale=1.0,
        ndc_coordinates=None,
        num_patches_x=16,
        num_patches_y=16,
        use_homogeneous=False,
    ):
        """
        Converts rays from spatial representation: (..., 6, H, W) --> (..., H * W, 6)
        Channel counts are interpreted as follows (homogeneous counts in brackets):
        6 [8] plain rays; 7 [9] rays plus a trailing depth channel; more than
        7 segment rays (origin + end points), converted to ``mode``.
        Args:
            rays: (..., D, H, W)
        Returns:
            Rays: (..., H * W, 6)
        """
        *batch_dims, D, H, W = rays.shape
        rays = rays.reshape(*batch_dims, D, H * W)
        rays = torch.transpose(rays, -1, -2)
        if use_homogeneous:
            D -= 2
        if D == 7:
            if use_homogeneous:
                # Read the depth channel (index 8 of the 9 homogeneous channels)
                # before `rays` is replaced by the dehomogenized 6-vector.
                depths = rays[..., 8]
                r1 = rays[..., :3] / (rays[..., 3:4] + 1e-6)
                r2 = rays[..., 4:7] / (rays[..., 7:8] + 1e-6)
                rays = torch.cat((r1, r2), dim=-1)
            else:
                depths = rays[..., 6]
                rays = rays[..., :6]
            return cls(
                rays=rays,
                mode=mode,
                moments_rescale=moments_rescale,
                ndc_coordinates=ndc_coordinates,
                depths=depths.reshape(*batch_dims, H, W),
                num_patches_x=num_patches_x,
                num_patches_y=num_patches_y,
            )
        elif D > 7:
            if use_homogeneous:
                # Undo the -2 above: here every 3-vector carries its own w.
                D += 2
                rays_reshaped = rays.reshape((*batch_dims, H * W, D // 4, 4))
                rays_not_homo = rays_reshaped / rays_reshaped[..., :, 3].unsqueeze(-1)
                rays = rays_not_homo[..., :, :3].reshape(
                    (*batch_dims, H * W, (D // 4) * 3)
                )
                D = (D // 4) * 3
            ray = cls(
                origins=rays[:, :, :3],
                segments=rays[:, :, 3:],
                mode="segment",
                moments_rescale=moments_rescale,
                ndc_coordinates=ndc_coordinates,
                row_form=True,
                num_patches_x=num_patches_x,
                num_patches_y=num_patches_y,
                depth_resolution=int(((D - 3) // 3) ** 0.5),
            )
            if mode == "ray":
                return ray.to_point_direction()
            elif mode == "plucker":
                return ray.to_plucker()
            elif mode == "segment":
                return ray
            else:
                assert False
        else:
            if use_homogeneous:
                r1 = rays[..., :3] / (rays[..., 3:4] + 1e-6)
                r2 = rays[..., 4:7] / (rays[..., 7:8] + 1e-6)
                rays = torch.cat((r1, r2), dim=-1)
            return cls(
                rays=rays,
                mode=mode,
                moments_rescale=moments_rescale,
                ndc_coordinates=ndc_coordinates,
                num_patches_x=num_patches_x,
                num_patches_y=num_patches_y,
            )

    def to_point_direction(self, normalize_moment=True):
        """
        Convert to point direction representation <O, D>.
        Plucker rays use the origin D x M. Segment rays use the (upsampled)
        origins and the normalized origin -> end-point directions.
        Returns:
            rays: (..., 6).
        """
        if self._mode == "plucker":
            direction = torch.nn.functional.normalize(self.rays[..., :3], dim=-1)
            moment = self.rays[..., 3:]
            if normalize_moment:
                c = torch.linalg.norm(direction, dim=-1, keepdim=True)
                moment = moment / c
            points = torch.cross(direction, moment, dim=-1)
            return Rays(
                rays=torch.cat((points, direction), dim=-1),
                mode="ray",
                ndc_coordinates=self.ndc_coordinates,
                num_patches_x=self.num_patches_x,
                num_patches_y=self.num_patches_y,
                depths=self.depths,
                unprojected=self.unprojected,
                depth_resolution=self.depth_resolution,
            )
        elif self._mode == "segment":
            origins = self.get_origins(high_res=True)
            direction = self.get_segments() - origins
            direction = torch.nn.functional.normalize(direction, dim=-1)
            return Rays(
                rays=torch.cat((origins, direction), dim=-1),
                mode="ray",
                ndc_coordinates=self.ndc_coordinates,
                num_patches_x=self.num_patches_x,
                num_patches_y=self.num_patches_y,
                depths=self.depths,
                unprojected=self.unprojected,
                depth_resolution=self.depth_resolution,
            )
        else:
            return self

    def to_plucker(self):
        """
        Convert to plucker representation <D, OxD> with unit directions D.
        """
        if self._mode == "plucker":
            return self
        elif self._mode == "ray":
            ray = self.rays.clone()
            ray_origins = ray[..., :3]
            ray_directions = ray[..., 3:]
            # Normalize ray directions to unit vectors
            ray_directions = ray_directions / torch.linalg.vector_norm(
                ray_directions, dim=-1, keepdim=True
            )
            plucker_normal = torch.cross(ray_origins, ray_directions, dim=-1)
            new_ray = torch.cat([ray_directions, plucker_normal], dim=-1)
            return Rays(
                rays=new_ray,
                mode="plucker",
                ndc_coordinates=self.ndc_coordinates,
                num_patches_x=self.num_patches_x,
                num_patches_y=self.num_patches_y,
                depths=self.depths,
                unprojected=self.unprojected,
                depth_resolution=self.depth_resolution,
            )
        elif self._mode == "segment":
            return self.to_point_direction().to_plucker()

    def get_directions(self, normalize=True):
        """Return ray directions (..., 3), unit-normalized by default."""
        if self._mode == "plucker":
            directions = self.rays[..., :3]
        elif self._mode == "segment":
            directions = self.to_point_direction().get_directions()
        else:
            directions = self.rays[..., 3:]
        if normalize:
            directions = torch.nn.functional.normalize(directions, dim=-1)
        return directions

    def get_camera_coordinate_rays(self, normalize=True):
        """Return the stored camera-coordinate ray directions."""
        directions = self.camera_coordinate_ray_directions
        if normalize:
            directions = torch.nn.functional.normalize(directions, dim=-1)
        return directions

    def get_origins(self, high_res=False):
        """Return ray origins (..., 3); segment origins are upsampled to the
        depth resolution. ``high_res`` is only forwarded for plucker rays."""
        if self._mode == "plucker":
            origins = self.to_point_direction().get_origins(high_res=high_res)
        elif self._mode == "ray":
            origins = self.rays[..., :3]
        elif self._mode == "segment":
            origins = Rays.upsample_origins(
                self.rays[..., :3],
                num_patches_x=self.num_patches_x,
                num_patches_y=self.num_patches_y,
                depth_resolution=self.depth_resolution,
            )
        else:
            assert False
        return origins

    def get_moments(self):
        """Return plucker moments (..., 3)."""
        if self._mode == "plucker":
            moments = self.rays[..., 3:]
        elif self._mode in ["ray", "segment"]:
            moments = self.to_plucker().get_moments()
        return moments

    def get_segments(self):
        """Return segment end points: ``unprojected`` if stored, otherwise the
        packed end points expanded with ``rows_to_patches``."""
        assert self._mode == "segment"
        if self.unprojected is not None:
            return self.unprojected
        else:
            return Rays.rows_to_patches(
                self.rays[..., 3:],
                num_patches_x=self.num_patches_x,
                num_patches_y=self.num_patches_y,
                depth_resolution=self.depth_resolution,
            )

    def get_ndc_coordinates(self):
        return self.ndc_coordinates

    @property
    def mode(self):
        return self._mode

    @mode.setter
    def mode(self, mode):
        self._mode = mode

    @property
    def device(self):
        return self.rays.device

    def __repr__(self, *args, **kwargs):
        ray_str = self.rays.__repr__(*args, **kwargs)[6:]  # remove "tensor"
        if self._mode == "plucker":
            return "PluRay" + ray_str
        elif self._mode == "ray":
            return "DirRay" + ray_str
        else:
            return "SegRay" + ray_str

    def to(self, device):
        """Move the rays and all stored companion tensors to ``device`` in place.

        Returns ``self`` so the call can be chained like ``torch.Tensor.to``.
        """
        self.rays = self.rays.to(device)
        for name in (
            "depths",
            "ndc_coordinates",
            "unprojected",
            "camera_coordinate_ray_directions",
        ):
            value = getattr(self, name, None)
            if value is not None:
                setattr(self, name, value.to(device))
        return self

    def clone(self):
        """Deep-copy the rays together with their metadata and companion tensors."""

        def _clone(tensor):
            return tensor.clone() if tensor is not None else None

        return Rays(
            rays=self.rays.clone(),
            mode=self._mode,
            ndc_coordinates=_clone(self.ndc_coordinates),
            num_patches_x=self.num_patches_x,
            num_patches_y=self.num_patches_y,
            depths=_clone(self.depths),
            unprojected=_clone(self.unprojected),
            depth_resolution=self.depth_resolution,
            camera_coordinate_rays=(
                _clone(self.camera_coordinate_ray_directions)
                if self.camera_ray_coordinates
                else None
            ),
        )

    @property
    def shape(self):
        return self.rays.shape

    def visualize(self):
        """Map unit directions and moments from [-1, 1] to [0, 1] for display."""
        directions = torch.nn.functional.normalize(self.get_directions(), dim=-1).cpu()
        moments = torch.nn.functional.normalize(self.get_moments(), dim=-1).cpu()
        return (directions + 1) / 2, (moments + 1) / 2

    def to_ray_bundle(self, length=0.3, recompute_origin=False, true_length=False):
        """
        Args:
            length (float): Length of the rays for visualization.
            recompute_origin (bool): If True, origin is set to the intersection point of
                all rays. If False, origins are the per-ray points ``p_closest``
                returned by ``intersect_skew_line_groups``.
            true_length (bool): If True, the length of each ray is the distance
                from its center to its segment end point (segment mode only).
        """
        origins = self.get_origins(high_res=self.depth_resolution > 1)
        lengths = torch.ones_like(origins[..., :2]) * length
        lengths[..., 0] = 0
        p_intersect, p_closest, _, _ = intersect_skew_line_groups(
            origins.float(), self.get_directions().float()
        )
        if recompute_origin:
            centers = p_intersect
            centers = centers.unsqueeze(1).repeat(1, lengths.shape[1], 1)
        else:
            centers = p_closest
        if true_length:
            length = torch.norm(self.get_segments() - centers, dim=-1).unsqueeze(-1)
            lengths = torch.ones_like(origins[..., :2]) * length
            lengths[..., 0] = 0
        return RayBundle(
            origins=centers,
            directions=self.get_directions(),
            lengths=lengths,
            xys=self.get_directions(),
        )
def cameras_to_rays(
    cameras,
    crop_parameters,
    use_half_pix=True,
    use_plucker=True,
    num_patches_x=16,
    num_patches_y=16,
    no_crop_param_device="cpu",
    distortion_coeffs=None,
    depths=None,
    visualize=False,
    mode=None,
    depth_resolution=1,
    nearest_neighbor=True,
    distortion_coefficients=None,
):
    """
    Unprojects rays from camera center to grid on image plane.
    To match Moneish's code, set use_half_pix=False, use_plucker=True. Also, the
    arguments to meshgrid should be swapped (x first, then y). I'm following Pytorch3d
    convention to have y first.
    distortion_coeffs refers to Amy's distortion experiments
    distortion_coefficients refers to the fisheye parameters from colmap
    Args:
        cameras: Pytorch3D cameras to unproject. Can be batched.
        crop_parameters: Crop parameters in NDC (cc_x, cc_y, crop_width, scale).
            Shape is (B, 4). If None, the full image is used for every camera.
        use_half_pix: If True, use half pixel offset (Default: True).
        use_plucker: If True, return rays in plucker coordinates (Default: True).
            Only consulted when ``depths`` is None.
        num_patches_x: Number of patches in x direction (Default: 16).
        num_patches_y: Number of patches in y direction (Default: 16).
        no_crop_param_device: Device of the NDC grid when crop_parameters is None.
        distortion_coeffs: Optional per-camera coefficients passed to
            ``compute_ndc_coordinates`` (selects the second unprojection branch).
        depths: Optional per-camera depths used as the z value of the
            unprojected grid. When given, ``mode`` selects the output type.
        visualize: If True, also return the unprojected points and stacked z values.
        mode: "segment", "plucker" or "ray"; only used when ``depths`` is given.
        depth_resolution: Forwarded to ``compute_ndc_coordinates`` (segment samples
            per patch side) and to the segment ``Rays``.
        nearest_neighbor: Forwarded to ``compute_ndc_coordinates``.
        distortion_coefficients: Optional per-camera fisheye coefficients; cameras
            with any non-zero entry are undistorted with
            ``undistort_ndc_coordinates`` before unprojection.
    Returns:
        Rays, or ``(rays, unprojected, stacked_zs)`` when ``visualize`` is True.
    """
    unprojected = []
    unprojected_ones = []
    crop_parameters_list = (
        crop_parameters if crop_parameters is not None else [None for _ in cameras]
    )
    depths_list = depths if depths is not None else [None for _ in cameras]
    if distortion_coeffs is None:
        # Per-camera z values of the grid (ones when no depths are given).
        zs = []
        for i, (camera, crop_param, depth) in enumerate(
            zip(cameras, crop_parameters_list, depths_list)
        ):
            xyd_grid = compute_ndc_coordinates(
                crop_parameters=crop_param,
                use_half_pix=use_half_pix,
                num_patches_x=num_patches_x,
                num_patches_y=num_patches_y,
                no_crop_param_device=no_crop_param_device,
                depths=depth,
                return_zs=True,
                depth_resolution=depth_resolution,
                nearest_neighbor=nearest_neighbor,
            )
            # (grid with depth as z, z values, same grid with z = 1)
            xyd_grid, z, ones_grid = xyd_grid
            zs.append(z)
            if (
                distortion_coefficients is not None
                and (distortion_coefficients[i] != 0).any()
            ):
                xyd_grid = undistort_ndc_coordinates(
                    ndc_coordinates=xyd_grid,
                    principal_point=camera.principal_point[0],
                    focal_length=camera.focal_length[0],
                    distortion_coefficients=distortion_coefficients[i],
                )
            unprojected.append(
                camera.unproject_points(
                    xyd_grid.reshape(-1, 3), world_coordinates=True, from_ndc=True
                )
            )
            # Unit-depth points give direction targets for plucker output with depths.
            # NOTE(review): only collected for mode == "plucker", yet the
            # mode == "ray" branch below also stacks unprojected_ones.
            if depths is not None and mode == "plucker":
                unprojected_ones.append(
                    camera.unproject_points(
                        ones_grid.reshape(-1, 3), world_coordinates=True, from_ndc=True
                    )
                )
    else:
        # NOTE(review): `zs` is never defined in this branch, so the
        # `depths is not None` paths and `visualize=True` below raise NameError.
        # `depths` is also passed whole rather than per camera.
        for camera, crop_param, distort_coeff in zip(
            cameras, crop_parameters_list, distortion_coeffs
        ):
            xyd_grid = compute_ndc_coordinates(
                crop_parameters=crop_param,
                use_half_pix=use_half_pix,
                num_patches_x=num_patches_x,
                num_patches_y=num_patches_y,
                no_crop_param_device=no_crop_param_device,
                distortion_coeffs=distort_coeff,
                depths=depths,
                nearest_neighbor=nearest_neighbor,
            )
            unprojected.append(
                camera.unproject_points(
                    xyd_grid.reshape(-1, 3), world_coordinates=True, from_ndc=True
                )
            )
    unprojected = torch.stack(unprojected, dim=0)  # (N, P, 3)
    origins = cameras.get_camera_center().unsqueeze(1)  # (N, 1, 3)
    origins = origins.repeat(1, num_patches_x * num_patches_y, 1)  # (N, P, 3)
    if depths is None:
        # Directions point from each camera center to its unprojected grid points.
        directions = unprojected - origins
        rays = Rays(
            origins=origins,
            directions=directions,
            crop_parameters=crop_parameters,
            num_patches_x=num_patches_x,
            num_patches_y=num_patches_y,
            distortion_coeffs=distortion_coeffs,
            mode="ray",
            unprojected=unprojected,
        )
        if use_plucker:
            return rays.to_plucker()
    elif mode == "segment":
        # Segments run from the camera center to the depth-scaled points.
        rays = Rays(
            origins=origins,
            segments=unprojected,
            crop_parameters=crop_parameters,
            num_patches_x=num_patches_x,
            num_patches_y=num_patches_y,
            distortion_coeffs=distortion_coeffs,
            depths=torch.stack(zs, dim=0),
            mode=mode,
            unprojected=unprojected,
            depth_resolution=depth_resolution,
        )
    elif mode == "plucker" or mode == "ray":
        unprojected_ones = torch.stack(unprojected_ones)
        directions = unprojected_ones - origins
        rays = Rays(
            origins=origins,
            directions=directions,
            crop_parameters=crop_parameters,
            num_patches_x=num_patches_x,
            num_patches_y=num_patches_y,
            distortion_coeffs=distortion_coeffs,
            depths=torch.stack(zs, dim=0),
            mode="ray",
            unprojected=unprojected,
        )
        if mode == "plucker":
            rays = rays.to_plucker()
    else:
        assert False
    if visualize:
        return rays, unprojected, torch.stack(zs, dim=0)
    return rays
def rays_to_cameras(
    rays,
    crop_parameters,
    num_patches_x=16,
    num_patches_y=16,
    use_half_pix=True,
    no_crop_param_device="cpu",
    sampled_ray_idx=None,
    cameras=None,
    focal_length=(3.453,),
    distortion_coeffs=None,
    calculate_distortion=False,
    depth_resolution=1,
    average_centers=False,
):
    """
    Recover camera extrinsics from rays, given known intrinsics.

    Intrinsics come from ``cameras`` when provided (their R/T are reset to
    identity); otherwise a PerspectiveCameras is built from ``focal_length``.
    A single focal length is repeated for every camera. The default here is
    3.453.

    Args:
        rays (Rays): (N, P, 6)
        crop_parameters (torch.Tensor): (N, 4)
        distortion_coeffs: forwarded when generating reference rays, unless
            ``calculate_distortion`` is True.
        average_centers: if True the camera center is the mean ray origin,
            otherwise the least-squares intersection of the rays.
    Returns:
        Cameras with the identity cameras' intrinsics, rotations aligning
        identity-camera rays to ``rays``, and translations placing the center.
    """
    device = rays.device
    ray_origins = rays.get_origins(high_res=True)
    ray_dirs = rays.get_directions()

    if average_centers:
        centers = torch.mean(ray_origins, dim=1)
    else:
        centers, _ = intersect_skew_lines_high_dim(ray_origins, ray_dirs)

    # Cameras that keep the target intrinsics but sit at the identity pose.
    if cameras is not None:
        identity_cams = cameras.clone()
        identity_cams.R[:] = torch.eye(3, device=device)
        identity_cams.T[:] = torch.zeros(3, device=device)
    else:
        if len(focal_length) == 1:
            focal_length = focal_length * rays.shape[0]
        identity_cams = PerspectiveCameras(focal_length=focal_length, device=device)

    forward_coeffs = distortion_coeffs is not None and not calculate_distortion
    reference_dirs = cameras_to_rays(
        cameras=identity_cams,
        num_patches_x=num_patches_x * depth_resolution,
        num_patches_y=num_patches_y * depth_resolution,
        use_half_pix=use_half_pix,
        crop_parameters=crop_parameters,
        no_crop_param_device=no_crop_param_device,
        distortion_coeffs=distortion_coeffs if forward_coeffs else None,
        mode="plucker",
        depth_resolution=depth_resolution,
    ).get_directions()
    if sampled_ray_idx is not None:
        reference_dirs = reference_dirs[:, sampled_ray_idx]

    # Per camera: rotation that best maps the reference rays onto the predicted ones.
    rotations = torch.zeros_like(identity_cams.R)
    for cam_idx in range(len(identity_cams)):
        rotations[cam_idx] = compute_optimal_rotation_alignment(
            reference_dirs[cam_idx],
            ray_dirs[cam_idx],
        )

    recovered = identity_cams.clone()
    recovered.R = rotations
    # PyTorch3D convention: T = -R^T C for camera center C.
    recovered.T = -torch.matmul(
        rotations.transpose(1, 2), centers.unsqueeze(2)
    ).squeeze(2)
    return recovered
# https://www.reddit.com/r/learnmath/comments/v1crd7/linear_algebra_qr_to_ql_decomposition/
def ql_decomposition(A):
    """
    QL-decompose a 3x3 matrix: ``A = Q @ L``.

    Q is orthogonal and L is lower triangular with a non-negative diagonal.
    The decomposition is derived from a QR decomposition of ``A @ P``, where P
    is the 3x3 exchange (anti-identity) matrix and ``P @ P = I``:
    ``A @ P = Q~ R~  =>  A = (Q~ P) (P R~ P)``.

    Args:
        A (torch.Tensor): (3, 3) floating-point matrix.
    Returns:
        Q (torch.Tensor): (3, 3) orthogonal matrix.
        L (torch.Tensor): (3, 3) lower-triangular matrix.
    """
    # Built in A's dtype/device so float64 inputs work as well as float32.
    P = torch.flip(torch.eye(3, device=A.device, dtype=A.dtype), dims=[0])
    A_tilde = torch.matmul(A, P)
    Q_tilde, R_tilde = torch.linalg.qr(A_tilde)
    Q = torch.matmul(Q_tilde, P)
    L = torch.matmul(torch.matmul(P, R_tilde), P)
    # Flip signs so diag(L) >= 0; Q D D L == Q L because D @ D = I.
    signs = torch.sign(torch.diagonal(L))
    Q = Q * signs  # scales column j of Q by signs[j]
    L = L * signs.unsqueeze(-1)  # scales row i of L by signs[i]
    return Q, L
def rays_to_cameras_homography(
    rays,
    crop_parameters,
    num_patches_x=16,
    num_patches_y=16,
    use_half_pix=True,
    sampled_ray_idx=None,
    reproj_threshold=0.2,
    camera_coordinate_rays=False,
    average_centers=False,
    depth_resolution=1,
    directions_from_averaged_center=False,
):
    """
    Recover full cameras (rotation, translation, focal length, principal point)
    from rays.

    Camera centers come from the ray origins. Rotation and intrinsics come from
    ``compute_optimal_rotation_intrinsics`` applied to rays of unit-focal,
    identity-pose cameras and the target directions.

    Args:
        rays (Rays): (N, P, 6)
        crop_parameters (torch.Tensor): (N, 4)
        camera_coordinate_rays: if True, use ``rays.get_camera_coordinate_rays()``
            as target directions.
        average_centers: if True the center is the mean origin, otherwise the
            least-squares intersection of the rays.
        directions_from_averaged_center: segment rays only; recompute directions
            as end point minus the estimated center.
    Returns:
        PerspectiveCameras holding one camera per batch element.
    """
    device = rays.device
    ray_origins = rays.get_origins(high_res=True)
    ray_dirs = rays.get_directions()

    if average_centers:
        centers = torch.mean(ray_origins, dim=1)
    else:
        centers, _ = intersect_skew_lines_high_dim(ray_origins, ray_dirs)

    if directions_from_averaged_center:
        assert rays.mode == "segment"
        end_points = rays.get_segments()
        tiled_centers = centers.unsqueeze(1).repeat(
            (1, num_patches_x * num_patches_y, 1)
        )
        ray_dirs = end_points - tiled_centers

    # Reference rays from unit-focal cameras at the identity pose.
    unit_cameras = PerspectiveCameras(focal_length=[1] * rays.shape[0], device=device)
    reference_dirs = cameras_to_rays(
        cameras=unit_cameras,
        num_patches_x=num_patches_x * depth_resolution,
        num_patches_y=num_patches_y * depth_resolution,
        use_half_pix=use_half_pix,
        crop_parameters=crop_parameters,
        no_crop_param_device=device,
        mode="plucker",
    ).get_directions()
    if sampled_ray_idx is not None:
        reference_dirs = reference_dirs[:, sampled_ray_idx]

    target_dirs = rays.get_camera_coordinate_rays() if camera_coordinate_rays else ray_dirs

    estimates = [
        compute_optimal_rotation_intrinsics(
            reference_dirs[cam_idx],
            target_dirs[cam_idx],
            reproj_threshold=reproj_threshold,
        )
        for cam_idx in range(rays.shape[-3])
    ]
    R = torch.stack([rot for rot, _, _ in estimates])
    focal_lengths = torch.stack([focal for _, focal, _ in estimates])
    principal_points = torch.stack([pp for _, _, pp in estimates])

    # PyTorch3D convention: T = -R^T C for camera center C.
    T = -torch.matmul(R.transpose(1, 2), centers.unsqueeze(2)).squeeze(2)
    return PerspectiveCameras(
        R=R,
        T=T,
        focal_length=focal_lengths,
        principal_point=principal_points,
        device=device,
    )
def compute_optimal_rotation_alignment(A, B):
    """
    Compute optimal R that minimizes: || A - B @ R ||_F
    (orthogonal Procrustes restricted to proper rotations, det(R) = +1).
    Args:
        A (torch.Tensor): (N, 3)
        B (torch.Tensor): (N, 3)
    Returns:
        R (torch.tensor): (3, 3), same dtype/device as the inputs.
    """
    # normally with R @ B, this would be A @ B.T
    H = B.T @ A
    U, _, Vh = torch.linalg.svd(H, full_matrices=True)
    s = torch.linalg.det(U @ Vh)
    # Flip the last singular direction when U @ Vh would be a reflection. The
    # correction is built in H's dtype/device; the previous
    # torch.tensor([1, 1, sign]) construction could disagree with U/Vh (e.g.
    # for float64 inputs) and forced a host sync on CUDA.
    correction = torch.ones(3, dtype=H.dtype, device=H.device)
    correction[2] = torch.sign(s)
    return U @ torch.diag(correction) @ Vh
def compute_optimal_rotation_intrinsics(
    rays_origin, rays_target, z_threshold=1e-4, reproj_threshold=0.2
):
    """
    Estimate a rotation and pinhole intrinsics mapping ``rays_origin`` onto
    ``rays_target``. A RANSAC homography between the perspective-divided rays
    is QL-decomposed into R and an intrinsics-like lower-triangular factor.
    Note: for some reason, f seems to be 1/f.
    Args:
        rays_origin (torch.Tensor): (N, 3)
        rays_target (torch.Tensor): (N, 3)
        z_threshold (float): Threshold for z value to be considered valid.
        reproj_threshold (float): RANSAC reprojection threshold for
            ``cv2.findHomography``.
    Returns:
        R (torch.tensor): (3, 3)
        focal_length (torch.tensor): (2,)
        principal_point (torch.tensor): (2,)
    Raises:
        RuntimeError: if OpenCV cannot estimate a homography.
    """
    device = rays_origin.device
    # Keep rays whose z component is clear of zero in BOTH sets so the
    # perspective division below is well conditioned.
    z_mask = torch.logical_and(
        torch.abs(rays_target[:, 2]) > z_threshold,
        torch.abs(rays_origin[:, 2]) > z_threshold,
    )
    rays_target = rays_target[z_mask]
    rays_origin = rays_origin[z_mask]
    src = (rays_origin[:, :2] / rays_origin[:, -1:]).cpu().numpy()
    dst = (rays_target[:, :2] / rays_target[:, -1:]).cpu().numpy()
    try:
        A, _ = cv2.findHomography(src, dst, cv2.RANSAC, reproj_threshold)
    except cv2.error:
        # Retry once on an OpenCV error (previously a bare `except:`, which also
        # swallowed KeyboardInterrupt/SystemExit). A second failure propagates.
        A, _ = cv2.findHomography(src, dst, cv2.RANSAC, reproj_threshold)
    if A is None:
        raise RuntimeError(
            "cv2.findHomography could not estimate a homography from the given rays"
        )
    A = torch.from_numpy(A).float().to(device)
    if torch.linalg.det(A) < 0:
        # TODO: Find a better fix for this. This gives the correct R but incorrect
        # intrinsics.
        A = -A
    R, L = ql_decomposition(A)
    L = L / L[2][2]
    f = torch.stack((L[0][0], L[1][1]))
    pp = torch.stack((L[2][0], L[2][1]))
    return R, f, pp
def compute_ndc_coordinates(
    crop_parameters=None,
    use_half_pix=True,
    num_patches_x=16,
    num_patches_y=16,
    no_crop_param_device="cpu",
    distortion_coeffs=None,
    depths=None,
    return_zs=False,
    depth_resolution=1,
    nearest_neighbor=True,
):
    """
    Computes NDC Grid using crop_parameters. If crop_parameters is not provided,
    then it assumes that the crop is the entire image (corresponding to an NDC grid
    where top left corner is (1, 1) and bottom right corner is (-1, -1)).

    Args:
        crop_parameters: (4,) tensor (cc_x, cc_y, width, _) or a batch (B, 4);
            a batch is processed element-wise and stacked.
        use_half_pix: sample patch centers (True) or patch corners (False).
        num_patches_x, num_patches_y: grid size.
        no_crop_param_device: device used when crop_parameters is None.
        distortion_coeffs: optional pair passed to ``apply_distortion_tensor``.
        depths: optional z values. With depth_resolution > 1 the x/y grid is
            recomputed at (num_patches * depth_resolution) resolution.
        return_zs: also return the z values and the grid with z = 1.
        depth_resolution: upsampling factor applied when depths are given.
        nearest_neighbor: accepted and forwarded in the batched case, otherwise
            unused.
    Returns:
        (num_patches_y, num_patches_x, 3) grid of (x, y, z), with a leading
        batch dimension for batched crop_parameters; z is ``depths`` or ones.
        With return_zs: ``(grid, z, grid_with_unit_z)``.
    """
    if crop_parameters is None:
        cc_x, cc_y, width = 0, 0, 2
        device = no_crop_param_device
    else:
        if len(crop_parameters.shape) > 1:
            if distortion_coeffs is None:
                return torch.stack(
                    [
                        compute_ndc_coordinates(
                            crop_parameters=crop_param,
                            use_half_pix=use_half_pix,
                            num_patches_x=num_patches_x,
                            num_patches_y=num_patches_y,
                            nearest_neighbor=nearest_neighbor,
                            depths=depths[i] if depths is not None else None,
                            # Forwarded so batched high-resolution depths get a
                            # matching high-resolution x/y grid.
                            depth_resolution=depth_resolution,
                        )
                        for i, crop_param in enumerate(crop_parameters)
                    ],
                    dim=0,
                )
            else:
                patch_params = zip(crop_parameters, distortion_coeffs)
                return torch.stack(
                    [
                        compute_ndc_coordinates(
                            crop_parameters=crop_param,
                            use_half_pix=use_half_pix,
                            num_patches_x=num_patches_x,
                            num_patches_y=num_patches_y,
                            distortion_coeffs=distortion_coeff,
                            nearest_neighbor=nearest_neighbor,
                        )
                        for crop_param, distortion_coeff in patch_params
                    ],
                    dim=0,
                )
        device = crop_parameters.device
        cc_x, cc_y, width, _ = crop_parameters
    dx = 1 / num_patches_x
    dy = 1 / num_patches_y
    if use_half_pix:
        min_y = 1 - dy
        max_y = -min_y
        min_x = 1 - dx
        max_x = -min_x
    else:
        min_y = min_x = 1
        max_y = -1 + 2 * dy
        max_x = -1 + 2 * dx
    # y first (PyTorch3D convention): x and y are (num_patches_y, num_patches_x).
    y, x = torch.meshgrid(
        torch.linspace(min_y, max_y, num_patches_y, dtype=torch.float32, device=device),
        torch.linspace(min_x, max_x, num_patches_x, dtype=torch.float32, device=device),
        indexing="ij",
    )
    x_prime = x * width / 2 - cc_x
    y_prime = y * width / 2 - cc_y
    if distortion_coeffs is not None:
        points = torch.cat(
            (x_prime.flatten().unsqueeze(-1), y_prime.flatten().unsqueeze(-1)),
            dim=-1,
        )
        new_points = apply_distortion_tensor(
            points, distortion_coeffs[0], distortion_coeffs[1]
        )
        # Restore the meshgrid layout (num_patches_y, num_patches_x); the old
        # reshape to (num_patches_x, num_patches_y) scrambled non-square grids.
        x_prime = new_points[:, 0].reshape(x.shape)
        y_prime = new_points[:, 1].reshape(y.shape)
    if depths is not None:
        if depth_resolution > 1:
            high_res_grid = compute_ndc_coordinates(
                crop_parameters=crop_parameters,
                use_half_pix=use_half_pix,
                num_patches_x=num_patches_x * depth_resolution,
                num_patches_y=num_patches_y * depth_resolution,
                no_crop_param_device=no_crop_param_device,
            )
            x_prime = high_res_grid[..., 0]
            y_prime = high_res_grid[..., 1]
        z = depths
        xyd_grid = torch.stack([x_prime, y_prime, z], dim=-1)
    else:
        z = torch.ones_like(x)
        xyd_grid = torch.stack([x_prime, y_prime, z], dim=-1)
    xyd_grid_ones = torch.stack([x_prime, y_prime, torch.ones_like(x_prime)], dim=-1)
    if return_zs:
        return xyd_grid, z, xyd_grid_ones
    return xyd_grid
def undistort_ndc_coordinates(
    ndc_coordinates, principal_point, focal_length, distortion_coefficients
):
    """
    Map NDC coordinates observed by a fisheye camera to where a pinhole camera
    would have observed them, using ``cv2.fisheye.undistortPoints``.

    Args:
        ndc_coordinates (torch.Tensor): (H, W, 3); the third channel is copied
            through unchanged.
        principal_point (torch.Tensor): (2,)
        focal_length (torch.Tensor): (2,)
        distortion_coefficients (torch.Tensor): (4,)
    Returns:
        torch.Tensor: (H, W, 3)
    """
    device = ndc_coordinates.device
    px, py = principal_point[0], principal_point[1]
    fx, fy = focal_length[0], focal_length[1]
    depth_channel = ndc_coordinates[..., 2]

    # To normalized image coordinates; the sign flip follows the original note
    # about OpenCV's convention that negative is top-left.
    x_norm = -(ndc_coordinates[..., 0] - px) / fx
    y_norm = -(ndc_coordinates[..., 1] - py) / fy

    cv_points = (
        torch.stack((x_norm.flatten(), y_norm.flatten()), 1).unsqueeze(1).cpu().numpy()
    )
    cv_result = cv2.fisheye.undistortPoints(
        cv_points, np.eye(3), distortion_coefficients.cpu().numpy(), np.eye(3)
    )
    u = torch.tensor(cv_result[:, 0, 0], device=device)
    v = torch.tensor(cv_result[:, 0, 1], device=device)

    # Back to NDC, undoing the normalization above.
    out_x = (-u * fx + px).reshape(x_norm.shape)
    out_y = (-v * fy + py).reshape(y_norm.shape)
    return torch.stack((out_x, out_y, depth_channel), -1)
def get_identity_cameras_with_intrinsics(cameras):
    """Return a clone of ``cameras`` with identity rotations and zero translations.

    Everything else carried by ``clone()`` (e.g. intrinsics) is kept.
    """
    batch_size = len(cameras)
    dev = cameras.R.device
    identity_cams = cameras.clone()
    identity_cams.R = torch.eye(3, device=dev)[None].repeat(batch_size, 1, 1)
    identity_cams.T = torch.zeros((batch_size, 3), device=dev)
    return identity_cams
def normalize_cameras_batch(
    cameras,
    scale=1.0,
    normalize_first_camera=False,
    depths=None,
    crop_parameters=None,
    num_patches_x=16,
    num_patches_y=16,
    distortion_coeffs=None,
    first_cam_mediod=False,
    return_scales=False,
):
    """
    Normalize each camera batch in ``cameras``.

    Two modes:
    * ``normalize_first_camera``: express each batch relative to its first
      camera (``first_camera_transform``). The scale is 1, or, with
      ``first_cam_mediod``, the factor from ``scale_first_cam_mediod``.
      No undo transform is produced (None).
    * otherwise: ``normalize_cameras`` (optical-axis intersection at the origin).

    Args:
        cameras: iterable of camera batches.
        depths: optional per-batch depth tensor. NOTE: it is scaled IN PLACE by
            the scale factor of each batch.
        crop_parameters: per-batch crop parameters; required with first_cam_mediod.
        distortion_coeffs: optional per-batch coefficients (None = no distortion;
            previously a mutable ``[None]`` default that raised IndexError for
            batch index > 0).
        return_scales: also return the per-batch scale factors.
    Returns:
        (new_cameras, undo_transforms[, scales]).
    """
    new_cameras = []
    undo_transforms = []
    scales = []
    for i, cam in enumerate(cameras):
        if normalize_first_camera:
            # Normalize cameras such that first camera is identity and origin is at
            # first camera center.
            s = 1
            if first_cam_mediod:
                coeffs_i = None if distortion_coeffs is None else distortion_coeffs[i]
                s = scale_first_cam_mediod(
                    cam[0],
                    depths=depths[i][0].unsqueeze(0) if depths is not None else None,
                    crop_parameters=crop_parameters[i][0].unsqueeze(0),
                    num_patches_x=num_patches_x,
                    num_patches_y=num_patches_y,
                    distortion_coeffs=(
                        coeffs_i[0].unsqueeze(0) if coeffs_i is not None else None
                    ),
                )
            normalized_cameras = first_camera_transform(cam, s, rotation_only=False)
            undo_transform = None
        else:
            # Always request the scale: with return_scale=False normalize_cameras
            # returns a 2-tuple, which made this 3-way unpacking fail whenever
            # depths was None.
            normalized_cameras, undo_transform, s = normalize_cameras(
                cam, scale=scale, return_scale=True
            )
        # One scale per batch element, whichever branch produced it.
        scales.append(s)
        if depths is not None:
            depths[i] *= s
            assert not depths.isnan().any(), "NaN depths after camera normalization"
        new_cameras.append(normalized_cameras)
        undo_transforms.append(undo_transform)
    if return_scales:
        return new_cameras, undo_transforms, scales
    return new_cameras, undo_transforms
def scale_first_cam_mediod(
    cameras,
    scale=1.0,
    return_scale=False,
    depths=None,
    crop_parameters=None,
    num_patches_x=16,
    num_patches_y=16,
    distortion_coeffs=None,
):
    """
    Compute the factor that puts the first camera at unit distance from the
    coordinate-wise median of its unprojected grid points.

    The NDC grid (with ``depths`` as z when given) is unprojected to world
    space. The per-coordinate median of the first
    ``num_patches_x * num_patches_y`` points is taken, and ``1 / d`` is
    returned, where ``d`` is the distance from that point to the camera center.
    If ``d < 0.001``, 1 is returned instead. ``scale`` and ``return_scale``
    are accepted but unused.

    Returns:
        int 1 or a tensor ``1 / d``.
    """
    # Fall back to the cameras' device when no depths are given; the previous
    # unconditional `depths.device` raised AttributeError for depths=None.
    device = depths.device if depths is not None else cameras.R.device
    xy_grid = (
        compute_ndc_coordinates(
            depths=depths,
            crop_parameters=crop_parameters,
            num_patches_x=num_patches_x,
            num_patches_y=num_patches_y,
            distortion_coeffs=distortion_coeffs,
        )
        .reshape((-1, 3))
        .to(device)
    )
    verts = cameras.unproject_points(xy_grid, from_ndc=True, world_coordinates=True)
    p_intersect = torch.median(
        verts.reshape((-1, 3))[: num_patches_x * num_patches_y].float(), dim=0
    ).values.unsqueeze(0)
    d = torch.norm(p_intersect - cameras.get_camera_center())
    if d < 0.001:
        return 1
    return 1 / d
def normalize_cameras(cameras, scale=1.0, return_scale=False):
    """
    Re-center and re-scale a batch of cameras around their optical-axis intersection.

    Every camera is composed with ``Rotate(R[0])^-1`` followed by
    ``Translate(p_intersect)``, so the first camera's rotation becomes the
    identity and the intersection point returned by
    ``compute_optical_axis_intersection`` is moved to the origin. Translations
    are then multiplied by ``scale / d``. Here ``d`` is the first entry of the
    ``dist`` output of that helper (presumably the first camera's distance to
    the intersection; confirm against its implementation).

    Args:
        cameras (pytorch3d.renderer.cameras.CamerasBase).
        scale (float): Target norm for the translation (see above).
        return_scale (bool): also return ``scale / d``.
    Returns:
        new_cameras (pytorch3d.renderer.cameras.CamerasBase): Normalized cameras.
        undo_transform (function): Function that undoes the normalization.
        scale factor (only when ``return_scale``). If the intersection is
        undefined, identity cameras, an identity undo function and ``1 / scale``
        are returned regardless of ``return_scale``.
    """
    normalized = cameras.clone()
    # R may not be a valid rotation matrix at this point.
    world_to_view = normalized.get_world_to_view_transform()
    p_intersect, dist, _, _, _ = compute_optical_axis_intersection(cameras)

    if p_intersect is None:
        print("Warning: optical axes code has a nan. Returning identity cameras.")
        normalized.R[:] = torch.eye(3, device=cameras.R.device, dtype=cameras.R.dtype)
        normalized.T[:] = torch.tensor(
            [0, 0, 1], device=cameras.T.device, dtype=cameras.T.dtype
        )
        return normalized, lambda x: x, 1 / scale

    d = dist.squeeze(dim=1).squeeze(dim=0)[0]
    if d == 0:
        # Degenerate case: dump the offending translations, then abort.
        print(cameras.T)
        print(world_to_view.get_matrix()[:, 3, :3])
        assert False

    # The scale cannot be folded into a single PyTorch3D transform without
    # disturbing R, so it is applied to T separately.
    recenter = Rotate(normalized.R[0].unsqueeze(0)).inverse().compose(
        Translate(p_intersect)
    )
    composed = recenter.compose(world_to_view).get_matrix()
    normalized.R = composed[:, :3, :3]
    normalized.T = composed[:, 3, :3] / d * scale

    def undo_transform(cams):
        # Undo the scale first, then the recentering transform.
        restored = cams.clone()
        restored.T *= d / scale
        matrix = (
            recenter.inverse()
            .compose(restored.get_world_to_view_transform())
            .get_matrix()
        )
        restored.R = matrix[:, :3, :3]
        restored.T = matrix[:, 3, :3]
        return restored

    if return_scale:
        return normalized, undo_transform, scale / d
    return normalized, undo_transform
def first_camera_transform(cameras, s, rotation_only=True):
    """
    Re-express every camera relative to the first camera.

    The inverse of the first camera's rotation (and, unless ``rotation_only``,
    of rotation-then-translation) is composed with each camera's
    world-to-view transform. The resulting translations are multiplied by
    ``s``. Returns a new cameras object; the input is not modified.
    """
    relative = cameras.clone()
    world_to_view = relative.get_world_to_view_transform()
    first_pose = Rotate(relative.R[0].unsqueeze(0))
    if not rotation_only:
        first_pose = first_pose.compose(Translate(relative.T[0].unsqueeze(0)))
    matrix = first_pose.inverse().compose(world_to_view).get_matrix()
    relative.R = matrix[:, :3, :3]
    relative.T = matrix[:, 3, :3] * s
    return relative