|
from typing import Optional, Tuple, Union
|
|
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from . import transforms |
|
from . import mesh |
|
from ._helpers import batched |
|
|
|
|
|
__all__ = [ |
|
'sliding_window_1d', |
|
'sliding_window_2d', |
|
'sliding_window_nd', |
|
'image_uv', |
|
'image_pixel_center', |
|
'image_mesh', |
|
'chessboard', |
|
'depth_edge', |
|
'depth_aliasing', |
|
'image_mesh_from_depth', |
|
'point_to_normal', |
|
'depth_to_normal', |
|
'masked_min', |
|
'masked_max', |
|
'bounding_rect' |
|
] |
|
|
|
|
|
def sliding_window_1d(x: torch.Tensor, window_size: int, stride: int = 1, dim: int = -1) -> torch.Tensor: |
|
""" |
|
Sliding window view of the input tensor. The dimension of the sliding window is appended to the end of the input tensor's shape. |
|
NOTE: Since Pytorch has `unfold` function, 1D sliding window view is just a wrapper of it. |
|
""" |
|
return x.unfold(dim, window_size, stride) |
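
# A quick shape check (hypothetical usage sketch):
#   >>> sliding_window_1d(torch.arange(10), window_size=3, stride=2).shape
#   torch.Size([4, 3])  # 4 windows of length 3, window dim appended last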
|
|
|
|
|
def sliding_window_nd(x: torch.Tensor, window_size: Tuple[int, ...], stride: Tuple[int, ...], dim: Tuple[int, ...]) -> torch.Tensor: |
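    """
    Sliding-window view over multiple dimensions. One window dimension is appended to the end of the output tensor's shape for each entry of `dim`.
    """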
|
    dim = [d % x.ndim for d in dim]
|
assert len(window_size) == len(stride) == len(dim) |
|
for i in range(len(window_size)): |
|
x = sliding_window_1d(x, window_size[i], stride[i], dim[i]) |
|
return x |
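
# A quick shape check (hypothetical usage sketch): 3x3 windows over an (8, 8) tensor.
#   >>> sliding_window_nd(torch.zeros(8, 8), (3, 3), (1, 1), (-2, -1)).shape
#   torch.Size([6, 6, 3, 3])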
|
|
|
|
|
def sliding_window_2d(x: torch.Tensor, window_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]], dim: Union[int, Tuple[int, int]] = (-2, -1)) -> torch.Tensor: |
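    """
    Sliding-window view over two dimensions (the last two by default). Scalar `window_size` and `stride` are broadcast to both dimensions.
    """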
|
if isinstance(window_size, int): |
|
window_size = (window_size, window_size) |
|
if isinstance(stride, int): |
|
stride = (stride, stride) |
|
return sliding_window_nd(x, window_size, stride, dim) |
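
# A quick shape check (hypothetical usage sketch): non-overlapping 4x4 patches of a batched image tensor.
#   >>> sliding_window_2d(torch.zeros(1, 3, 8, 8), window_size=4, stride=4).shape
#   torch.Size([1, 3, 2, 2, 4, 4])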
|
|
|
|
|
def image_uv(height: int, width: int, left: Optional[int] = None, top: Optional[int] = None, right: Optional[int] = None, bottom: Optional[int] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
|
""" |
|
Get image space UV grid, ranging in [0, 1]. |
|
|
|
>>> image_uv(10, 10): |
|
[[[0.05, 0.05], [0.15, 0.05], ..., [0.95, 0.05]], |
|
[[0.05, 0.15], [0.15, 0.15], ..., [0.95, 0.15]], |
|
... ... ... |
|
[[0.05, 0.95], [0.15, 0.95], ..., [0.95, 0.95]]] |
|
|
|
Args: |
|
width (int): image width |
|
height (int): image height |
|
|
|
Returns: |
|
np.ndarray: shape (height, width, 2) |
|
""" |
|
if left is None: left = 0 |
|
if top is None: top = 0 |
|
if right is None: right = width |
|
if bottom is None: bottom = height |
|
u = torch.linspace((left + 0.5) / width, (right - 0.5) / width, right - left, device=device, dtype=dtype) |
|
v = torch.linspace((top + 0.5) / height, (bottom - 0.5) / height, bottom - top, device=device, dtype=dtype) |
|
u, v = torch.meshgrid(u, v, indexing='xy') |
|
uv = torch.stack([u, v], dim=-1) |
|
return uv |
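
# A small worked example (hypothetical usage sketch): UV samples at pixel centers.
#   >>> image_uv(2, 2)
#   tensor([[[0.2500, 0.2500], [0.7500, 0.2500]],
#           [[0.2500, 0.7500], [0.7500, 0.7500]]])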
|
|
|
|
|
def image_pixel_center(
    height: int,
    width: int,
    left: Optional[int] = None,
    top: Optional[int] = None,
    right: Optional[int] = None,
    bottom: Optional[int] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None
) -> torch.Tensor:
|
""" |
|
Get image pixel center coordinates, ranging in [0, width] and [0, height]. |
|
`image[i, j]` has pixel center coordinates `(j + 0.5, i + 0.5)`. |
|
|
|
>>> image_pixel_center(10, 10): |
|
[[[0.5, 0.5], [1.5, 0.5], ..., [9.5, 0.5]], |
|
[[0.5, 1.5], [1.5, 1.5], ..., [9.5, 1.5]], |
|
... ... ... |
|
[[0.5, 9.5], [1.5, 9.5], ..., [9.5, 9.5]]] |
|
|
|
Args: |
|
width (int): image width |
|
height (int): image height |
|
|
|
Returns: |
|
np.ndarray: shape (height, width, 2) |
|
""" |
|
if left is None: left = 0 |
|
if top is None: top = 0 |
|
if right is None: right = width |
|
if bottom is None: bottom = height |
|
u = torch.linspace(left + 0.5, right - 0.5, right - left, dtype=dtype, device=device) |
|
v = torch.linspace(top + 0.5, bottom - 0.5, bottom - top, dtype=dtype, device=device) |
|
u, v = torch.meshgrid(u, v, indexing='xy') |
|
return torch.stack([u, v], dim=2) |
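
# A small worked example (hypothetical usage sketch): centers in pixel units rather than [0, 1].
#   >>> image_pixel_center(2, 2)
#   tensor([[[0.5000, 0.5000], [1.5000, 0.5000]],
#           [[0.5000, 1.5000], [1.5000, 1.5000]]])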
|
|
|
|
|
def image_mesh(height: int, width: int, mask: Optional[torch.Tensor] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> Tuple[torch.Tensor, torch.Tensor]:
|
""" |
|
Get a quad mesh regarding image pixel uv coordinates as vertices and image grid as faces. |
|
|
|
Args: |
|
width (int): image width |
|
height (int): image height |
|
mask (np.ndarray, optional): binary mask of shape (height, width), dtype=bool. Defaults to None. |
|
|
|
Returns: |
|
uv (np.ndarray): uv corresponding to pixels as described in image_uv() |
|
faces (np.ndarray): quad faces connecting neighboring pixels |
|
indices (np.ndarray, optional): indices of vertices in the original mesh |
|
""" |
|
if device is None and mask is not None: |
|
device = mask.device |
|
if mask is not None: |
|
assert mask.shape[0] == height and mask.shape[1] == width |
|
assert mask.dtype == torch.bool |
|
uv = image_uv(height, width, device=device, dtype=dtype).reshape((-1, 2)) |
|
row_faces = torch.stack([ |
|
torch.arange(0, width - 1, dtype=torch.int32, device=device), |
|
torch.arange(width, 2 * width - 1, dtype=torch.int32, device=device), |
|
torch.arange(1 + width, 2 * width, dtype=torch.int32, device=device), |
|
torch.arange(1, width, dtype=torch.int32, device=device) |
|
], dim=1) |
|
faces = (torch.arange(0, (height - 1) * width, width, device=device, dtype=torch.int32)[:, None, None] + row_faces[None, :, :]).reshape((-1, 4)) |
|
if mask is not None: |
|
quad_mask = (mask[:-1, :-1] & mask[1:, :-1] & mask[1:, 1:] & mask[:-1, 1:]).ravel() |
|
faces = faces[quad_mask] |
|
faces, uv, indices = mesh.remove_unreferenced_vertices(faces, uv, return_indices=True) |
|
return uv, faces, indices |
|
return uv, faces |
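
# A quick shape check (hypothetical usage sketch): a 3x3 image gives 9 uv vertices and 4 quads.
#   >>> uv, faces = image_mesh(3, 3)
#   >>> uv.shape, faces.shape
#   (torch.Size([9, 2]), torch.Size([4, 4]))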
|
|
|
|
|
def depth_edge(depth: torch.Tensor, atol: Optional[float] = None, rtol: Optional[float] = None, kernel_size: int = 3, mask: Optional[torch.Tensor] = None) -> torch.BoolTensor:
|
""" |
|
Compute the edge mask of a depth map. The edge is defined as the pixels whose neighbors have a large difference in depth. |
|
|
|
Args: |
|
depth (torch.Tensor): shape (..., height, width), linear depth map |
|
atol (float): absolute tolerance |
|
rtol (float): relative tolerance |
|
|
|
Returns: |
|
edge (torch.Tensor): shape (..., height, width) of dtype torch.bool |
|
""" |
|
shape = depth.shape |
|
depth = depth.reshape(-1, 1, *shape[-2:]) |
|
if mask is not None: |
|
mask = mask.reshape(-1, 1, *shape[-2:]) |
|
|
|
if mask is None: |
|
diff = (F.max_pool2d(depth, kernel_size, stride=1, padding=kernel_size // 2) + F.max_pool2d(-depth, kernel_size, stride=1, padding=kernel_size // 2)) |
|
else: |
|
diff = (F.max_pool2d(torch.where(mask, depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2) + F.max_pool2d(torch.where(mask, -depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2)) |
|
|
|
edge = torch.zeros_like(depth, dtype=torch.bool) |
|
if atol is not None: |
|
edge |= diff > atol |
|
if rtol is not None: |
|
edge |= (diff / depth).nan_to_num_() > rtol |
|
edge = edge.reshape(*shape) |
|
return edge |
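
# Hypothetical usage sketch, assuming `depth` is an (H, W) metric depth map:
#   edge = depth_edge(depth, rtol=0.05)            # neighborhoods spanning > 5% relative depth
#   edge = depth_edge(depth, atol=0.1, mask=mask)  # absolute jumps > 0.1 units, masked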
|
|
|
|
|
def depth_aliasing(depth: torch.Tensor, atol: Optional[float] = None, rtol: Optional[float] = None, kernel_size: int = 3, mask: Optional[torch.Tensor] = None) -> torch.BoolTensor:
|
""" |
|
Compute the map that indicates the aliasing of a depth map. The aliasing is defined as the pixels which neither close to the maximum nor the minimum of its neighbors. |
|
Args: |
|
depth (torch.Tensor): shape (..., height, width), linear depth map |
|
atol (float): absolute tolerance |
|
rtol (float): relative tolerance |
|
|
|
Returns: |
|
edge (torch.Tensor): shape (..., height, width) of dtype torch.bool |
|
""" |
|
shape = depth.shape |
|
depth = depth.reshape(-1, 1, *shape[-2:]) |
|
if mask is not None: |
|
mask = mask.reshape(-1, 1, *shape[-2:]) |
|
|
|
if mask is None: |
|
diff_max = F.max_pool2d(depth, kernel_size, stride=1, padding=kernel_size // 2) - depth |
|
diff_min = F.max_pool2d(-depth, kernel_size, stride=1, padding=kernel_size // 2) + depth |
|
else: |
|
diff_max = F.max_pool2d(torch.where(mask, depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2) - depth |
|
diff_min = F.max_pool2d(torch.where(mask, -depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2) + depth |
|
diff = torch.minimum(diff_max, diff_min) |
|
|
|
edge = torch.zeros_like(depth, dtype=torch.bool) |
|
if atol is not None: |
|
edge |= diff > atol |
|
if rtol is not None: |
|
edge |= (diff / depth).nan_to_num_() > rtol |
|
edge = edge.reshape(*shape) |
|
return edge |
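
# Hypothetical usage sketch, assuming `depth` is an (H, W) rendered depth map:
#   aliased = depth_aliasing(depth, atol=0.05)  # pixels far from both local min and max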
|
|
|
|
|
def image_mesh_from_depth( |
|
depth: torch.Tensor, |
|
    extrinsics: Optional[torch.Tensor] = None,
    intrinsics: Optional[torch.Tensor] = None
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
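    """
    Get a triangle mesh by lifting each depth-map pixel to a 3D point and connecting neighboring pixels.

    Args:
        depth (torch.Tensor): shape (height, width), linear depth map
        extrinsics (torch.Tensor, optional): shape (4, 4), extrinsics matrix. Defaults to None.
        intrinsics (torch.Tensor, optional): shape (3, 3), intrinsics matrix. Defaults to None.

    Returns:
        pts (torch.Tensor): shape (height * width, 3), unprojected vertices
        faces (torch.Tensor): triangle faces connecting neighboring pixels
    """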
|
height, width = depth.shape |
|
    uv, faces = image_mesh(height, width, device=depth.device, dtype=depth.dtype)
    faces = faces.reshape(-1, 4)
    depth = depth.reshape(-1)
    pts = transforms.unproject_cv(uv, depth, extrinsics, intrinsics)
|
faces = mesh.triangulate(faces, vertices=pts) |
|
return pts, faces |
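
# Hypothetical usage sketch, assuming `depth` is (H, W) and `intrinsics` is a (3, 3) pinhole matrix:
#   pts, faces = image_mesh_from_depth(depth, intrinsics=intrinsics)
#   # pts: (H * W, 3) unprojected vertices; faces: triangles over the pixel grid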
|
|
|
|
|
@batched(3, 2, 2) |
|
def point_to_normal(point: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
""" |
|
Calculate normal map from point map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system. |
|
|
|
Args: |
|
point (torch.Tensor): shape (..., height, width, 3), point map |
|
Returns: |
|
normal (torch.Tensor): shape (..., height, width, 3), normal map. |
|
""" |
|
has_mask = mask is not None |
|
|
|
if mask is None: |
|
mask = torch.ones_like(point[..., 0], dtype=torch.bool) |
|
mask = F.pad(mask, (1, 1, 1, 1), mode='constant', value=0) |
|
|
|
pts = F.pad(point.permute(0, 3, 1, 2), (1, 1, 1, 1), mode='constant', value=1).permute(0, 2, 3, 1) |
|
up = pts[:, :-2, 1:-1, :] - pts[:, 1:-1, 1:-1, :] |
|
left = pts[:, 1:-1, :-2, :] - pts[:, 1:-1, 1:-1, :] |
|
down = pts[:, 2:, 1:-1, :] - pts[:, 1:-1, 1:-1, :] |
|
right = pts[:, 1:-1, 2:, :] - pts[:, 1:-1, 1:-1, :] |
|
normal = torch.stack([ |
|
torch.cross(up, left, dim=-1), |
|
torch.cross(left, down, dim=-1), |
|
torch.cross(down, right, dim=-1), |
|
torch.cross(right, up, dim=-1), |
|
]) |
|
normal = F.normalize(normal, dim=-1) |
|
valid = torch.stack([ |
|
mask[:, :-2, 1:-1] & mask[:, 1:-1, :-2], |
|
mask[:, 1:-1, :-2] & mask[:, 2:, 1:-1], |
|
mask[:, 2:, 1:-1] & mask[:, 1:-1, 2:], |
|
mask[:, 1:-1, 2:] & mask[:, :-2, 1:-1], |
|
]) & mask[None, :, 1:-1, 1:-1] |
|
normal = (normal * valid[..., None]).sum(dim=0) |
|
normal = F.normalize(normal, dim=-1) |
|
|
|
if has_mask: |
|
return normal, valid.any(dim=0) |
|
else: |
|
return normal |
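
# Hypothetical usage sketch, assuming an (H, W, 3) point map and an optional (H, W) bool mask:
#   normal = point_to_normal(point)               # (H, W, 3)
#   normal, valid = point_to_normal(point, mask)  # valid marks where normals are defined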
|
|
|
|
|
@batched(2, 2, 2) |
|
def depth_to_normal(depth: torch.Tensor, intrinsics: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
""" |
|
Calculate normal map from depth map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system. |
|
|
|
Args: |
|
depth (torch.Tensor): shape (..., height, width), linear depth map |
|
intrinsics (torch.Tensor): shape (..., 3, 3), intrinsics matrix |
|
Returns: |
|
normal (torch.Tensor): shape (..., 3, height, width), normal map. |
|
""" |
|
    height, width = depth.shape[-2:]
|
|
|
uv = image_uv(*depth.shape[-2:]).unsqueeze(0).to(depth) |
|
pts = transforms.unproject_cv(uv.reshape(-1, 2), depth.flatten(-2), intrinsics=intrinsics, extrinsics=None).unflatten(-2, (height, width)) |
|
|
|
return point_to_normal(pts, mask) |
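
# Hypothetical usage sketch, assuming `depth` is (H, W) and `intrinsics` is a (3, 3) matrix
# normalized to uv space (the uv grid used here ranges in [0, 1]):
#   normal = depth_to_normal(depth, intrinsics)  # (H, W, 3), OpenGL camera convention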
|
|
|
|
|
def masked_min(input: torch.Tensor, mask: torch.BoolTensor, dim: Optional[int] = None, keepdim: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Similar to torch.min, but only considers elements where mask is True.
    """
|
if dim is None: |
|
return torch.where(mask, input, torch.tensor(torch.inf, dtype=input.dtype, device=input.device)).min() |
|
else: |
|
return torch.where(mask, input, torch.tensor(torch.inf, dtype=input.dtype, device=input.device)).min(dim=dim, keepdim=keepdim) |
|
|
|
|
|
def masked_max(input: torch.Tensor, mask: torch.BoolTensor, dim: Optional[int] = None, keepdim: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Similar to torch.max, but only considers elements where mask is True.
    """
|
if dim is None: |
|
return torch.where(mask, input, torch.tensor(-torch.inf, dtype=input.dtype, device=input.device)).max() |
|
else: |
|
return torch.where(mask, input, torch.tensor(-torch.inf, dtype=input.dtype, device=input.device)).max(dim=dim, keepdim=keepdim) |
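
# A small worked example (hypothetical usage sketch) covering both helpers:
#   >>> x = torch.tensor([1.0, 2.0, 3.0]); m = torch.tensor([False, True, True])
#   >>> masked_min(x, m), masked_max(x, m)
#   (tensor(2.), tensor(3.))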
|
|
|
|
|
def bounding_rect(mask: torch.BoolTensor) -> torch.Tensor:
    """Get the bounding rectangle of a mask.

    Args:
        mask (torch.Tensor): shape (..., height, width), mask

    Returns:
        rect (torch.Tensor): shape (..., 4), bounding rectangle (left, top, right, bottom) in normalized uv coordinates, ranging in [0, 1]
    """
|
height, width = mask.shape[-2:] |
|
mask = mask.flatten(-2).unsqueeze(-1) |
|
uv = image_uv(height, width).to(mask.device).reshape(-1, 2) |
|
left_top = masked_min(uv, mask, dim=-2)[0] |
|
right_bottom = masked_max(uv, mask, dim=-2)[0] |
|
return torch.cat([left_top, right_bottom], dim=-1) |
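
# A small worked example (hypothetical usage sketch): the rect is in [0, 1] uv coordinates.
#   >>> mask = torch.zeros(4, 4, dtype=torch.bool); mask[1:3, 1:3] = True
#   >>> bounding_rect(mask)
#   tensor([0.3750, 0.3750, 0.6250, 0.6250])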
|
|
|
|
|
def chessboard(width: int, height: int, grid_size: int, color_a: torch.Tensor, color_b: torch.Tensor) -> torch.Tensor: |
|
"""get a chessboard image |
|
|
|
Args: |
|
width (int): image width |
|
height (int): image height |
|
grid_size (int): size of chessboard grid |
|
color_a (torch.Tensor): shape (chanenls,), color of the grid at the top-left corner |
|
color_b (torch.Tensor): shape (chanenls,), color in complementary grids |
|
|
|
Returns: |
|
image (torch.Tensor): shape (height, width, channels), chessboard image |
|
""" |
|
x = torch.div(torch.arange(width), grid_size, rounding_mode='floor') |
|
y = torch.div(torch.arange(height), grid_size, rounding_mode='floor') |
|
mask = ((x[None, :] + y[:, None]) % 2).to(color_a) |
|
image = (1 - mask[..., None]) * color_a + mask[..., None] * color_b |
|
return image |
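
# Hypothetical usage sketch: a 128x128 board with 16-pixel cells, single channel.
#   board = chessboard(128, 128, grid_size=16, color_a=torch.tensor([0.0]), color_b=torch.tensor([1.0]))
#   # board: (128, 128, 1); the top-left cell takes color_a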