Spaces:
Running
Running
"""
Convenience classes for an SE3 pose and a pinhole Camera with lens distortion.
Based on PyTorch tensors: differentiable, batched, with GPU support.
"""
import functools | |
import inspect | |
import math | |
from typing import Dict, List, NamedTuple, Optional, Tuple, Union | |
import numpy as np | |
import torch | |
from .utils import ( | |
J_distort_points, | |
distort_points, | |
skew_symmetric, | |
so3exp_map, | |
to_homogeneous, | |
) | |
def autocast(func):
    """Cast the inputs of a TensorWrapper method to PyTorch tensors
    if they are numpy arrays. Use the device and dtype of the wrapper.

    The decorated callable must receive, as its first argument, either a
    TensorWrapper instance (regular method) or a TensorWrapper subclass
    (classmethod). For a class, or for an instance whose ``_data`` is still
    None, arrays are converted on the CPU and keep their own dtype.

    Args:
        func: the method to wrap.

    Returns:
        The wrapped method; numpy positional and keyword arguments are
        converted, every other argument is forwarded untouched.

    Raises:
        ValueError: if the first argument is neither a TensorWrapper
            instance nor a TensorWrapper subclass.
    """

    @functools.wraps(func)  # keep the name and docstring of the wrapped method
    def wrap(self, *args, **kwargs):
        device = torch.device("cpu")
        dtype = None
        if isinstance(self, TensorWrapper):
            # During __init__ the instance has no data yet: keep the defaults.
            if self._data is not None:
                device = self.device
                dtype = self.dtype
        elif not inspect.isclass(self) or not issubclass(self, TensorWrapper):
            raise ValueError(self)

        def cast(value):
            # Only numpy arrays are touched; tensors keep their device/dtype.
            if isinstance(value, np.ndarray):
                value = torch.from_numpy(value)
                value = value.to(device=device, dtype=dtype)
            return value

        cast_args = [cast(arg) for arg in args]
        cast_kwargs = {key: cast(value) for key, value in kwargs.items()}
        return func(self, *cast_args, **cast_kwargs)

    return wrap
class TensorWrapper:
    """Thin wrapper around a single tensor stored in ``_data``.

    The last dimension of ``_data`` holds the parameters of one object;
    all leading dimensions are treated as batch dimensions (see ``shape``).
    Methods that mirror the tensor API return a new wrapper of the same class.
    """

    # Class-level default so that the attribute exists before __init__ runs.
    _data = None

    def __init__(self, data: torch.Tensor):
        """Store the underlying data.

        Args:
            data: a PyTorch tensor, or a numpy array which is converted with
                ``torch.from_numpy`` (CPU, dtype preserved).
        """
        if isinstance(data, np.ndarray):
            data = torch.from_numpy(data)
        self._data = data

    @property
    def shape(self):
        """Batch shape: the shape of the data without its last dimension."""
        return self._data.shape[:-1]

    @property
    def device(self):
        """Device of the underlying tensor."""
        return self._data.device

    @property
    def dtype(self):
        """Data type of the underlying tensor."""
        return self._data.dtype

    def __getitem__(self, index):
        """Index the underlying tensor and wrap the result."""
        return self.__class__(self._data[index])

    def __setitem__(self, index, item):
        """Assign into the underlying tensor.

        Args:
            index: any index accepted by a tensor.
            item: another wrapper (its ``_data`` is used) or a raw tensor
                (its ``.data`` is used, as before).
        """
        if isinstance(item, TensorWrapper):
            value = item._data
        else:
            value = item.data
        self._data[index] = value

    def to(self, *args, **kwargs):
        """Forward to ``torch.Tensor.to`` and wrap the result."""
        return self.__class__(self._data.to(*args, **kwargs))

    def cpu(self):
        """Forward to ``torch.Tensor.cpu`` and wrap the result."""
        return self.__class__(self._data.cpu())

    def cuda(self):
        """Forward to ``torch.Tensor.cuda`` and wrap the result."""
        return self.__class__(self._data.cuda())

    def pin_memory(self):
        """Forward to ``torch.Tensor.pin_memory`` and wrap the result."""
        return self.__class__(self._data.pin_memory())

    def float(self):
        """Forward to ``torch.Tensor.float`` and wrap the result."""
        return self.__class__(self._data.float())

    def double(self):
        """Forward to ``torch.Tensor.double`` and wrap the result."""
        return self.__class__(self._data.double())

    def detach(self):
        """Forward to ``torch.Tensor.detach`` and wrap the result."""
        return self.__class__(self._data.detach())

    @classmethod
    def stack(cls, objects: List, dim=0, *, out=None):
        """Stack the data of several wrappers along a new dimension.

        Args:
            objects: list of wrappers.
            dim: dimension to insert, forwarded to ``torch.stack``.
            out: optional output tensor, forwarded to ``torch.stack``.

        Returns:
            A new wrapper of class ``cls`` holding the stacked data.
        """
        data = torch.stack([obj._data for obj in objects], dim=dim, out=out)
        return cls(data)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Route ``torch.stack`` to ``stack``; decline every other function."""
        if kwargs is None:
            kwargs = {}
        if func is torch.stack:
            return cls.stack(*args, **kwargs)
        return NotImplemented
class Pose(TensorWrapper):
    """SE(3) pose stored as a tensor with shape (..., 12).

    The first 9 entries of the last dimension are the row-major flattened
    rotation matrix and the last 3 entries are the translation vector.
    """

    def __init__(self, data: torch.Tensor):
        """Wrap pose data whose last dimension has size 12."""
        assert data.shape[-1] == 12
        super().__init__(data)

    @classmethod
    @autocast
    def from_Rt(cls, R: torch.Tensor, t: torch.Tensor):
        """Pose from a rotation matrix and translation vector.
        Accepts numpy arrays or PyTorch tensors.
        Args:
            R: rotation matrix with shape (..., 3, 3).
            t: translation vector with shape (..., 3).
        """
        assert R.shape[-2:] == (3, 3)
        assert t.shape[-1] == 3
        assert R.shape[:-2] == t.shape[:-1]
        data = torch.cat([R.flatten(start_dim=-2), t], -1)
        return cls(data)

    @classmethod
    @autocast
    def from_aa(cls, aa: torch.Tensor, t: torch.Tensor):
        """Pose from an axis-angle rotation vector and translation vector.
        Accepts numpy arrays or PyTorch tensors.
        Args:
            aa: axis-angle rotation vector with shape (..., 3).
            t: translation vector with shape (..., 3).
        """
        assert aa.shape[-1] == 3
        assert t.shape[-1] == 3
        assert aa.shape[:-1] == t.shape[:-1]
        return cls.from_Rt(so3exp_map(aa), t)

    @classmethod
    def from_4x4mat(cls, T: torch.Tensor):
        """Pose from an SE(3) transformation matrix.
        Args:
            T: transformation matrix with shape (..., 4, 4).
        """
        assert T.shape[-2:] == (4, 4)
        R, t = T[..., :3, :3], T[..., :3, 3]
        return cls.from_Rt(R, t)

    @classmethod
    def from_colmap(cls, image: NamedTuple):
        """Pose from a COLMAP Image.

        Uses ``image.qvec2rotmat()`` as rotation and ``image.tvec`` as
        translation.
        """
        return cls.from_Rt(image.qvec2rotmat(), image.tvec)

    @property
    def R(self) -> torch.Tensor:
        """Underlying rotation matrix with shape (..., 3, 3)."""
        rvec = self._data[..., :9]
        return rvec.reshape(rvec.shape[:-1] + (3, 3))

    @property
    def t(self) -> torch.Tensor:
        """Underlying translation vector with shape (..., 3)."""
        return self._data[..., -3:]

    def inv(self) -> "Pose":
        """Invert an SE(3) pose: (R, t) -> (R^T, -R^T t)."""
        R = self.R.transpose(-1, -2)
        t = -(R @ self.t.unsqueeze(-1)).squeeze(-1)
        return self.__class__.from_Rt(R, t)

    def compose(self, other: "Pose") -> "Pose":
        """Chain two SE(3) poses: T_B2C.compose(T_A2B) -> T_A2C."""
        R = self.R @ other.R
        t = self.t + (self.R @ other.t.unsqueeze(-1)).squeeze(-1)
        return self.__class__.from_Rt(R, t)

    @autocast
    def transform(self, p3d: torch.Tensor) -> torch.Tensor:
        """Transform a set of 3D points.
        Args:
            p3d: 3D points, numpy array or PyTorch tensor with shape (..., 3).
        Returns:
            The transformed points ``R @ p + t``, computed row-wise.
        """
        assert p3d.shape[-1] == 3
        # The batch shapes are deliberately not checked, to allow broadcasting.
        return p3d @ self.R.transpose(-1, -2) + self.t.unsqueeze(-2)

    def __mul__(self, p3D: torch.Tensor) -> torch.Tensor:
        """Transform a set of 3D points: T_A2B * p3D_A -> p3D_B."""
        return self.transform(p3D)

    def __matmul__(
        self, other: Union["Pose", torch.Tensor]
    ) -> Union["Pose", torch.Tensor]:
        """Transform a set of 3D points: T_A2B @ p3D_A -> p3D_B,
        or chain two SE(3) poses: T_B2C @ T_A2B -> T_A2C."""
        if isinstance(other, self.__class__):
            return self.compose(other)
        return self.transform(other)

    @autocast
    def J_transform(self, p3d_out: torch.Tensor):
        """Jacobian built from the transformed points.

        For each point p the 3 x 6 block is [I | -skew_symmetric(p)]:
        # [[1,0,0,0,-pz,py],
        #  [0,1,0,pz,0,-px],
        #  [0,0,1,-py,px,0]]
        (layout as given by the original comment; depends on the convention
        of ``skew_symmetric``).

        Args:
            p3d_out: 3D points with shape (..., 3).
        """
        J_t = torch.diag_embed(torch.ones_like(p3d_out))
        J_rot = -skew_symmetric(p3d_out)
        J = torch.cat([J_t, J_rot], dim=-1)
        return J  # N x 3 x 6

    def numpy(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return the rotation and translation as numpy arrays."""
        return self.R.numpy(), self.t.numpy()

    def magnitude(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Magnitude of the SE(3) transformation.
        Returns:
            dr: rotation angle in degrees.
            dt: translation distance in meters.
        """
        trace = torch.diagonal(self.R, dim1=-1, dim2=-2).sum(-1)
        # Clamp to the domain of acos to absorb numerical noise.
        cos = torch.clamp((trace - 1) / 2, -1, 1)
        dr = torch.acos(cos).abs() / math.pi * 180
        dt = torch.norm(self.t, dim=-1)
        return dr, dt

    def __repr__(self):
        return f"Pose: {self.shape} {self.dtype} {self.device}"
class Camera(TensorWrapper):
    """Pinhole camera stored as a tensor with shape (..., {6, 8, 10}).

    Layout of the last dimension:
    (width, height, fx, fy, cx, cy, *distortion), with 0, 2 or 4
    distortion parameters.
    """

    # Minimum depth for a point to be considered in front of the camera.
    eps = 1e-4

    def __init__(self, data: torch.Tensor):
        """Wrap camera data whose last dimension has size 6, 8 or 10."""
        assert data.shape[-1] in {6, 8, 10}
        super().__init__(data)

    @classmethod
    def from_colmap(cls, camera: Union[Dict, NamedTuple]):
        """Camera from a COLMAP Camera tuple or dictionary.
        We use the corner-convention from COLMAP (center of top left pixel is (0.5, 0.5))

        Raises:
            NotImplementedError: if the camera model is not supported.
        """
        if isinstance(camera, tuple):
            camera = camera._asdict()
        model = camera["model"]
        params = camera["params"]
        # NOTE(review): "RADIAL" is parsed here with separate fx and fy, which
        # matches to_cameradict below but not COLMAP's native RADIAL layout
        # (f, cx, cy, k1, k2) - confirm against the callers.
        if model in ["OPENCV", "PINHOLE", "RADIAL"]:
            (fx, fy, cx, cy), params = np.split(params, [4])
        elif model in ["SIMPLE_PINHOLE", "SIMPLE_RADIAL"]:
            (f, cx, cy), params = np.split(params, [3])
            fx = fy = f
            if model == "SIMPLE_RADIAL":
                # Pad the single coefficient to two distortion parameters.
                params = np.r_[params, 0.0]
        else:
            raise NotImplementedError(model)
        data = np.r_[camera["width"], camera["height"], fx, fy, cx, cy, params]
        return cls(data)

    @classmethod
    @autocast
    def from_calibration_matrix(cls, K: torch.Tensor):
        """Camera from a calibration matrix with shape (..., 3, 3).

        The image size is set to twice the principal point; no distortion.
        """
        cx, cy = K[..., 0, 2], K[..., 1, 2]
        fx, fy = K[..., 0, 0], K[..., 1, 1]
        data = torch.stack([2 * cx, 2 * cy, fx, fy, cx, cy], -1)
        return cls(data)

    @autocast
    def calibration_matrix(self):
        """Calibration matrix with shape (..., 3, 3), on the data's device."""
        K = torch.zeros(
            *self._data.shape[:-1],
            3,
            3,
            device=self._data.device,
            dtype=self._data.dtype,
        )
        K[..., 0, 2] = self._data[..., 4]
        K[..., 1, 2] = self._data[..., 5]
        K[..., 0, 0] = self._data[..., 2]
        K[..., 1, 1] = self._data[..., 3]
        K[..., 2, 2] = 1.0
        return K

    @property
    def size(self) -> torch.Tensor:
        """Size (width height) of the images, with shape (..., 2)."""
        return self._data[..., :2]

    @property
    def f(self) -> torch.Tensor:
        """Focal lengths (fx, fy) with shape (..., 2)."""
        return self._data[..., 2:4]

    @property
    def c(self) -> torch.Tensor:
        """Principal points (cx, cy) with shape (..., 2)."""
        return self._data[..., 4:6]

    @property
    def dist(self) -> torch.Tensor:
        """Distortion parameters, with shape (..., {0, 2, 4})."""
        return self._data[..., 6:]

    @autocast
    def scale(self, scales: torch.Tensor):
        """Update the camera parameters after resizing an image.

        Size, focal lengths and principal point are multiplied by ``scales``;
        the distortion parameters are kept.
        """
        s = scales
        data = torch.cat([self.size * s, self.f * s, self.c * s, self.dist], -1)
        return self.__class__(data)

    def crop(self, left_top: Tuple[float], size: Tuple[int]):
        """Update the camera parameters after cropping an image.

        Args:
            left_top: offset of the crop, subtracted from the principal point.
            size: size of the crop, which replaces the image size.
        """
        left_top = self._data.new_tensor(left_top)
        size = self._data.new_tensor(size)
        data = torch.cat([size, self.f, self.c - left_top, self.dist], -1)
        return self.__class__(data)

    @autocast
    def in_image(self, p2d: torch.Tensor):
        """Check if 2D points are within the image boundaries.

        A point is valid when both coordinates lie in [0, size - 1].
        """
        assert p2d.shape[-1] == 2
        # The batch shapes are deliberately not checked, to allow broadcasting.
        size = self.size.unsqueeze(-2)
        valid = torch.all((p2d >= 0) & (p2d <= (size - 1)), -1)
        return valid

    @autocast
    def project(self, p3d: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Project 3D points into the camera plane and check for visibility.

        Returns:
            p2d: the points divided by their (clamped) depth.
            valid: True where the depth is larger than ``eps``.
        """
        z = p3d[..., -1]
        valid = z > self.eps
        # Clamp to avoid dividing by zero or by negative depths.
        z = z.clamp(min=self.eps)
        p2d = p3d[..., :-1] / z.unsqueeze(-1)
        return p2d, valid

    def J_project(self, p3d: torch.Tensor):
        """Jacobian of the projection w.r.t. the 3D points."""
        x, y, z = p3d[..., 0], p3d[..., 1], p3d[..., 2]
        zero = torch.zeros_like(z)
        z = z.clamp(min=self.eps)
        J = torch.stack([1 / z, zero, -x / z**2, zero, 1 / z, -y / z**2], dim=-1)
        J = J.reshape(p3d.shape[:-1] + (2, 3))
        return J  # N x 2 x 3

    @autocast
    def distort(self, pts: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Distort normalized 2D coordinates
        and check for validity of the distortion model.
        """
        assert pts.shape[-1] == 2
        # The batch shapes are deliberately not checked, to allow broadcasting.
        return distort_points(pts, self.dist)

    def J_distort(self, pts: torch.Tensor):
        """Jacobian of the distortion, delegated to ``J_distort_points``."""
        return J_distort_points(pts, self.dist)  # N x 2 x 2

    @autocast
    def denormalize(self, p2d: torch.Tensor) -> torch.Tensor:
        """Convert normalized 2D coordinates into pixel coordinates."""
        return p2d * self.f.unsqueeze(-2) + self.c.unsqueeze(-2)

    @autocast
    def normalize(self, p2d: torch.Tensor) -> torch.Tensor:
        """Convert pixel coordinates into normalized 2D coordinates."""
        return (p2d - self.c.unsqueeze(-2)) / self.f.unsqueeze(-2)

    def J_denormalize(self):
        """Jacobian of the denormalization: diag(fx, fy)."""
        return torch.diag_embed(self.f).unsqueeze(-3)  # 1 x 2 x 2

    @autocast
    def cam2image(self, p3d: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Transform 3D points into 2D pixel coordinates.

        Returns:
            p2d: pixel coordinates after projection, distortion and
                denormalization.
            valid: True where all three steps succeeded and the pixel falls
                inside the image.
        """
        p2d, visible = self.project(p3d)
        p2d, mask = self.distort(p2d)
        p2d = self.denormalize(p2d)
        valid = visible & mask & self.in_image(p2d)
        return p2d, valid

    def J_world2image(self, p3d: torch.Tensor):
        """Jacobian of the pixel coordinates w.r.t. the 3D points.

        Chains the Jacobians of denormalization, distortion and projection.

        Returns:
            J: the chained Jacobian.
            valid: the visibility mask returned by ``project``.
        """
        p2d, valid = self.project(p3d)
        J = self.J_denormalize() @ self.J_distort(p2d) @ self.J_project(p3d)
        return J, valid

    @autocast
    def image2cam(self, p2d: torch.Tensor) -> torch.Tensor:
        """Convert 2D pixel coordinates to 3D points with z=1.

        Only the pinhole normalization is inverted: lens distortion is not
        undone (iterative undistortion is not implemented here).
        """
        assert self._data.shape
        p2d = self.normalize(p2d)
        # iterative undistortion
        return to_homogeneous(p2d)

    def to_cameradict(
        self, camera_model: Optional[str] = None
    ) -> Union[Dict, List[Dict]]:
        """Export the camera(s) as COLMAP-style dictionaries.

        Args:
            camera_model: name of the model to write. By default it is chosen
                from the number of parameters: 6 -> PINHOLE, 8 -> RADIAL,
                10 -> OPENCV.

        Returns:
            A list of dictionaries for batched (2D) data, or a single
            dictionary for unbatched (1D) data.
        """
        data = self._data.clone()
        if data.dim() == 1:
            data = data.unsqueeze(0)
        assert data.dim() == 2
        b, d = data.shape
        if camera_model is None:
            camera_model = {6: "PINHOLE", 8: "RADIAL", 10: "OPENCV"}[d]
        cameras = []
        for i in range(b):
            if camera_model.startswith("SIMPLE_"):
                # Single focal length: skip fx and keep at most one
                # distortion parameter.
                params = [x.item() for x in data[i, 3 : min(d, 7)]]
            else:
                params = [x.item() for x in data[i, 2:]]
            cameras.append(
                {
                    "model": camera_model,
                    "width": int(data[i, 0].item()),
                    "height": int(data[i, 1].item()),
                    "params": params,
                }
            )
        return cameras if self._data.dim() == 2 else cameras[0]

    def __repr__(self):
        return f"Camera {self.shape} {self.dtype} {self.device}"