# Source: MoGe / utils3d / torch / utils.py
# Commit: ec0c8fa ("first commit", by Ruicheng); file size 13.9 kB
from typing import *
import torch
import torch.nn.functional as F
from . import transforms
from . import mesh
from ._helpers import batched
# Names exported by `from <this module> import *`; each entry is a function
# defined in this file.
__all__ = [
'sliding_window_1d',
'sliding_window_2d',
'sliding_window_nd',
'image_uv',
'image_pixel_center',
'image_mesh',
'chessboard',
'depth_edge',
'depth_aliasing',
'image_mesh_from_depth',
'point_to_normal',
'depth_to_normal',
'masked_min',
'masked_max',
'bounding_rect'
]
def sliding_window_1d(x: torch.Tensor, window_size: int, stride: int = 1, dim: int = -1) -> torch.Tensor:
    """
    Return a sliding-window view of `x` along one dimension.

    This is a thin wrapper around `torch.Tensor.unfold`: the windowed dimension
    is shortened to the number of windows and a new trailing dimension of
    length `window_size` is appended to the shape.

    Args:
        x (torch.Tensor): input tensor
        window_size (int): number of elements in each window
        stride (int): step between the starts of consecutive windows. Defaults to 1.
        dim (int): dimension to slide over. Defaults to -1.

    Returns:
        torch.Tensor: the unfolded tensor, with the window axis last.
    """
    return x.unfold(dimension=dim, size=window_size, step=stride)
def sliding_window_nd(x: torch.Tensor, window_size: Tuple[int, ...], stride: Tuple[int, ...], dim: Tuple[int, ...]) -> torch.Tensor:
    """
    Return a sliding-window view of `x` over several dimensions.

    The dimensions are unfolded one after another, so one window axis per
    entry of `dim` is appended to the end of the shape, in the order given.

    Args:
        x (torch.Tensor): input tensor
        window_size (Tuple[int, ...]): window length for each dimension
        stride (Tuple[int, ...]): step for each dimension
        dim (Tuple[int, ...]): dimensions to slide over (negative values allowed)

    Returns:
        torch.Tensor: the unfolded tensor.
    """
    # Resolve negative axes up front: every unfold appends a new trailing
    # dimension, which would shift the meaning of a negative index.
    ndim = x.ndim
    axes = [axis % ndim for axis in dim]
    assert len(window_size) == len(stride) == len(axes)
    for size, step, axis in zip(window_size, stride, axes):
        x = x.unfold(axis, size, step)
    return x
def sliding_window_2d(x: torch.Tensor, window_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]], dim: Union[int, Tuple[int, int]] = (-2, -1)) -> torch.Tensor:
    """
    Return a 2D sliding-window view of `x`.

    Integer `window_size` / `stride` values are expanded to a pair that applies
    the same value to both dimensions; the work is then delegated to
    `sliding_window_nd`.

    Args:
        x (torch.Tensor): input tensor
        window_size (int or Tuple[int, int]): window size per dimension
        stride (int or Tuple[int, int]): stride per dimension
        dim (Tuple[int, int]): the two dimensions to slide over. Defaults to (-2, -1).

    Returns:
        torch.Tensor: the unfolded tensor with two trailing window axes.
    """
    window_pair = (window_size, window_size) if isinstance(window_size, int) else window_size
    stride_pair = (stride, stride) if isinstance(stride, int) else stride
    return sliding_window_nd(x, window_pair, stride_pair, dim)
def image_uv(height: int, width: int, left: int = None, top: int = None, right: int = None, bottom: int = None, device: torch.device = None, dtype: torch.dtype = None) -> torch.Tensor:
    """
    Get the image-space UV grid of pixel centers, normalized to [0, 1].

    >>> image_uv(10, 10):
    [[[0.05, 0.05], [0.15, 0.05], ..., [0.95, 0.05]],
     [[0.05, 0.15], [0.15, 0.15], ..., [0.95, 0.15]],
      ...             ...                  ...
     [[0.05, 0.95], [0.15, 0.95], ..., [0.95, 0.95]]]

    Args:
        height (int): image height
        width (int): image width
        left, top, right, bottom (int, optional): pixel bounds of the sub-window
            to generate. They default to the full image (0, 0, width, height).
        device (torch.device, optional): device of the returned tensor
        dtype (torch.dtype, optional): dtype of the returned tensor

    Returns:
        torch.Tensor: shape (bottom - top, right - left, 2); the last axis is (u, v).
    """
    left = 0 if left is None else left
    top = 0 if top is None else top
    right = width if right is None else right
    bottom = height if bottom is None else bottom

    # One sample per pixel, located at the pixel center.
    u_axis = torch.linspace((left + 0.5) / width, (right - 0.5) / width, right - left, device=device, dtype=dtype)
    v_axis = torch.linspace((top + 0.5) / height, (bottom - 0.5) / height, bottom - top, device=device, dtype=dtype)

    # 'xy' indexing yields grids of shape (len(v_axis), len(u_axis)).
    grid_u, grid_v = torch.meshgrid(u_axis, v_axis, indexing='xy')
    return torch.stack([grid_u, grid_v], dim=-1)
def image_pixel_center(
    height: int,
    width: int,
    left: int = None,
    top: int = None,
    right: int = None,
    bottom: int = None,
    dtype: torch.dtype = None,
    device: torch.device = None
) -> torch.Tensor:
    """
    Get pixel-center coordinates in pixel units, ranging in [0, width] and [0, height].
    `image[i, j]` has pixel center coordinates `(j + 0.5, i + 0.5)`.

    >>> image_pixel_center(10, 10):
    [[[0.5, 0.5], [1.5, 0.5], ..., [9.5, 0.5]],
     [[0.5, 1.5], [1.5, 1.5], ..., [9.5, 1.5]],
      ...             ...                  ...
     [[0.5, 9.5], [1.5, 9.5], ..., [9.5, 9.5]]]

    Args:
        height (int): image height
        width (int): image width
        left, top, right, bottom (int, optional): pixel bounds of the sub-window
            to generate. They default to the full image (0, 0, width, height).
        dtype (torch.dtype, optional): dtype of the returned tensor
        device (torch.device, optional): device of the returned tensor

    Returns:
        torch.Tensor: shape (bottom - top, right - left, 2); the last axis is (x, y).
    """
    left = 0 if left is None else left
    top = 0 if top is None else top
    right = width if right is None else right
    bottom = height if bottom is None else bottom

    x_axis = torch.linspace(left + 0.5, right - 0.5, right - left, dtype=dtype, device=device)
    y_axis = torch.linspace(top + 0.5, bottom - 0.5, bottom - top, dtype=dtype, device=device)

    # 'xy' indexing yields grids of shape (len(y_axis), len(x_axis)).
    grid_x, grid_y = torch.meshgrid(x_axis, y_axis, indexing='xy')
    return torch.stack([grid_x, grid_y], dim=2)
def image_mesh(height: int, width: int, mask: torch.Tensor = None, device: torch.device = None, dtype: torch.dtype = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Get a quad mesh whose vertices are the image pixels (as UV coordinates) and
    whose faces connect each 2x2 neighborhood of pixels.

    Args:
        height (int): image height
        width (int): image width
        mask (torch.Tensor, optional): boolean mask of shape (height, width). When
            given, only quads whose four corner pixels are all set are kept.
            Defaults to None.
        device (torch.device, optional): device of the outputs. Defaults to the
            mask's device when a mask is given.
        dtype (torch.dtype, optional): dtype of the uv coordinates

    Returns:
        uv (torch.Tensor): flattened (N, 2) uv coordinates as produced by image_uv()
        faces (torch.Tensor): (F, 4) int32 vertex indices of the quads
        indices (torch.Tensor): only returned when `mask` is given; the third
            output of `mesh.remove_unreferenced_vertices(..., return_indices=True)`
            (presumably the indices of the kept vertices in the full grid).
    """
    if mask is not None:
        assert mask.shape[0] == height and mask.shape[1] == width
        assert mask.dtype == torch.bool
        if device is None:
            device = mask.device

    uv = image_uv(height, width, device=device, dtype=dtype).reshape((-1, 2))

    # Quads of the first pixel row; corners are listed as
    # (top-left, bottom-left, bottom-right, top-right) in flattened pixel indices.
    top_left = torch.arange(0, width - 1, dtype=torch.int32, device=device)
    bottom_left = torch.arange(width, 2 * width - 1, dtype=torch.int32, device=device)
    bottom_right = torch.arange(1 + width, 2 * width, dtype=torch.int32, device=device)
    top_right = torch.arange(1, width, dtype=torch.int32, device=device)
    first_row_quads = torch.stack([top_left, bottom_left, bottom_right, top_right], dim=1)

    # Shift the first-row pattern down by one image row at a time.
    row_offsets = torch.arange(0, (height - 1) * width, width, device=device, dtype=torch.int32)
    faces = (row_offsets[:, None, None] + first_row_quads[None, :, :]).reshape((-1, 4))

    if mask is None:
        return uv, faces

    # Keep a quad only if all four of its corner pixels are inside the mask.
    quad_mask = (mask[:-1, :-1] & mask[1:, :-1] & mask[1:, 1:] & mask[:-1, 1:]).ravel()
    faces = faces[quad_mask]
    faces, uv, indices = mesh.remove_unreferenced_vertices(faces, uv, return_indices=True)
    return uv, faces, indices
def depth_edge(depth: torch.Tensor, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: torch.Tensor = None) -> torch.BoolTensor:
    """
    Compute the edge mask of a depth map. A pixel is an edge when the spread
    (max - min) of the depths in its neighborhood is large.

    Args:
        depth (torch.Tensor): shape (..., height, width), linear depth map
        atol (float, optional): absolute tolerance on the neighborhood spread
        rtol (float, optional): tolerance on the spread relative to the pixel's depth
        kernel_size (int): side length of the square neighborhood. Defaults to 3.
        mask (torch.Tensor, optional): boolean mask reshaped like `depth`; pixels
            where it is False are excluded from the neighborhood max/min.

    Returns:
        edge (torch.Tensor): shape (..., height, width) of dtype torch.bool.
            All False when neither `atol` nor `rtol` is given.
    """
    shape = depth.shape
    depth = depth.reshape(-1, 1, *shape[-2:])
    pad = kernel_size // 2

    if mask is None:
        pos_src, neg_src = depth, -depth
    else:
        mask = mask.reshape(-1, 1, *shape[-2:])
        # Masked-out pixels become -inf so they can never win a max-pool.
        pos_src = torch.where(mask, depth, -torch.inf)
        neg_src = torch.where(mask, -depth, -torch.inf)

    # max(depth) + max(-depth) == max(depth) - min(depth) over the window.
    local_max = F.max_pool2d(pos_src, kernel_size, stride=1, padding=pad)
    neg_local_min = F.max_pool2d(neg_src, kernel_size, stride=1, padding=pad)
    diff = local_max + neg_local_min

    edge = torch.zeros_like(depth, dtype=torch.bool)
    if atol is not None:
        edge |= diff > atol
    if rtol is not None:
        relative = (diff / depth).nan_to_num_()
        edge |= relative > rtol
    return edge.reshape(*shape)
def depth_aliasing(depth: torch.Tensor, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: torch.Tensor = None) -> torch.BoolTensor:
    """
    Compute the map that flags aliased depth pixels, i.e. pixels whose depth is
    close to neither the maximum nor the minimum of their neighborhood.

    Args:
        depth (torch.Tensor): shape (..., height, width), linear depth map
        atol (float, optional): absolute tolerance on the distance to the nearer
            of the neighborhood max / min
        rtol (float, optional): tolerance on that distance relative to the pixel's depth
        kernel_size (int): side length of the square neighborhood. Defaults to 3.
        mask (torch.Tensor, optional): boolean mask reshaped like `depth`; pixels
            where it is False are excluded from the neighborhood max/min.

    Returns:
        edge (torch.Tensor): shape (..., height, width) of dtype torch.bool.
            All False when neither `atol` nor `rtol` is given.
    """
    shape = depth.shape
    depth = depth.reshape(-1, 1, *shape[-2:])
    pad = kernel_size // 2

    if mask is None:
        pos_src, neg_src = depth, -depth
    else:
        mask = mask.reshape(-1, 1, *shape[-2:])
        # Masked-out pixels become -inf so they can never win a max-pool.
        pos_src = torch.where(mask, depth, -torch.inf)
        neg_src = torch.where(mask, -depth, -torch.inf)

    # Distance from each pixel up to the local max and down to the local min.
    diff_max = F.max_pool2d(pos_src, kernel_size, stride=1, padding=pad) - depth
    diff_min = F.max_pool2d(neg_src, kernel_size, stride=1, padding=pad) + depth
    diff = torch.minimum(diff_max, diff_min)

    edge = torch.zeros_like(depth, dtype=torch.bool)
    if atol is not None:
        edge |= diff > atol
    if rtol is not None:
        relative = (diff / depth).nan_to_num_()
        edge |= relative > rtol
    return edge.reshape(*shape)
def image_mesh_from_depth(
    depth: torch.Tensor,
    extrinsics: torch.Tensor = None,
    intrinsics: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build a triangle mesh from a single depth map by unprojecting every pixel.

    Args:
        depth (torch.Tensor): shape (height, width), depth map
        extrinsics (torch.Tensor, optional): camera extrinsics forwarded to
            `transforms.unproject_cv`
        intrinsics (torch.Tensor, optional): camera intrinsics forwarded to
            `transforms.unproject_cv`

    Returns:
        pts (torch.Tensor): unprojected points, one per pixel in row-major order
            (presumably shape (height * width, 3) -- confirm against `unproject_cv`)
        faces (torch.Tensor): faces returned by `mesh.triangulate` for the pixel quads
    """
    height, width = depth.shape
    # Build the grid on the depth map's device/dtype so the unprojection does
    # not mix devices or precisions.
    uv, faces = image_mesh(height, width, device=depth.device, dtype=depth.dtype)
    depth = depth.reshape(-1)
    # Pass the uv coordinates computed above (the original passed the `image_uv`
    # function object here by mistake). Camera matrices go by keyword so the
    # call does not depend on the positional order of `unproject_cv`.
    pts = transforms.unproject_cv(uv, depth, extrinsics=extrinsics, intrinsics=intrinsics)
    faces = mesh.triangulate(faces, vertices=pts)
    return pts, faces
@batched(3, 2, 2)
def point_to_normal(point: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
    """
    Calculate normal map from point map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system.

    For every pixel, the cross products of the edges to consecutive 4-neighbors
    (up/left, left/down, down/right, right/up) are normalized, summed over the
    pairs whose pixels are all valid, and normalized again.

    Args:
        point (torch.Tensor): shape (..., height, width, 3), point map
        mask (torch.Tensor, optional): shape (..., height, width), boolean
            validity mask of the points. Defaults to None (all points valid).

    Returns:
        normal (torch.Tensor): shape (..., height, width, 3), normal map.
        valid (torch.Tensor): only returned when `mask` is given; True where at
            least one neighbor pair contributed to the normal.
    """
    has_mask = mask is not None
    if not has_mask:
        mask = torch.ones_like(point[..., 0], dtype=torch.bool)

    # Pad by one pixel on every side so border pixels have four neighbors; the
    # padded ring is marked invalid in the mask, so its point values never count.
    mask = F.pad(mask, (1, 1, 1, 1), mode='constant', value=0)
    pts = F.pad(point.permute(0, 3, 1, 2), (1, 1, 1, 1), mode='constant', value=1).permute(0, 2, 3, 1)

    # Edge vectors from each pixel to its four neighbors.
    center = pts[:, 1:-1, 1:-1, :]
    up = pts[:, :-2, 1:-1, :] - center
    left = pts[:, 1:-1, :-2, :] - center
    down = pts[:, 2:, 1:-1, :] - center
    right = pts[:, 1:-1, 2:, :] - center
    edges = (up, left, down, right)
    next_edges = (left, down, right, up)

    normal = torch.stack([torch.cross(a, b, dim=-1) for a, b in zip(edges, next_edges)])
    normal = F.normalize(normal, dim=-1)

    # Validity of the same neighbors, in the same cyclic order as the edges.
    m_up = mask[:, :-2, 1:-1]
    m_left = mask[:, 1:-1, :-2]
    m_down = mask[:, 2:, 1:-1]
    m_right = mask[:, 1:-1, 2:]
    m_neighbors = (m_up, m_left, m_down, m_right)
    m_next = (m_left, m_down, m_right, m_up)
    valid = torch.stack([a & b for a, b in zip(m_neighbors, m_next)]) & mask[None, :, 1:-1, 1:-1]

    # Sum only the valid per-pair normals, then renormalize.
    normal = (normal * valid[..., None]).sum(dim=0)
    normal = F.normalize(normal, dim=-1)

    if has_mask:
        return normal, valid.any(dim=0)
    return normal
@batched(2, 2, 2)
def depth_to_normal(depth: torch.Tensor, intrinsics: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
    """
    Calculate normal map from depth map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system.

    The depth map is unprojected to a point map with `transforms.unproject_cv`
    and the normals are then computed by `point_to_normal`.

    Args:
        depth (torch.Tensor): shape (..., height, width), linear depth map
        intrinsics (torch.Tensor): shape (..., 3, 3), intrinsics matrix
        mask (torch.Tensor, optional): shape (..., height, width), boolean
            validity mask of the depth pixels. Defaults to None.

    Returns:
        normal (torch.Tensor): shape (..., height, width, 3), normal map
            (the layout produced by `point_to_normal`).
        valid (torch.Tensor): only returned when `mask` is given, as in
            `point_to_normal`.
    """
    height, width = depth.shape[-2:]
    uv = image_uv(height, width).to(depth)
    pts = transforms.unproject_cv(
        uv.reshape(-1, 2),
        depth.flatten(-2),
        intrinsics=intrinsics,
        extrinsics=None,
    ).unflatten(-2, (height, width))
    # Forward the mask unpadded: point_to_normal pads it itself, so padding here
    # as well (as the original did) made the mask two pixels too large on each
    # axis relative to the point map. When no mask is given, none is forwarded,
    # so only the normal map is returned.
    if mask is None:
        return point_to_normal(pts)
    return point_to_normal(pts, mask)
def masked_min(input: torch.Tensor, mask: torch.BoolTensor, dim: int = None, keepdim: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Similar to torch.min, but only elements where `mask` is True take part.

    Masked-out elements are replaced by +inf before the reduction.

    Args:
        input (torch.Tensor): values to reduce
        mask (torch.BoolTensor): mask broadcastable against `input`
        dim (int, optional): dimension to reduce. Defaults to None (reduce everything).
        keepdim (bool): keep the reduced dimension; only used when `dim` is given.

    Returns:
        The global minimum when `dim` is None, otherwise the (values, indices)
        pair returned by `Tensor.min(dim=...)`.
    """
    pos_inf = torch.tensor(torch.inf, dtype=input.dtype, device=input.device)
    candidates = torch.where(mask, input, pos_inf)
    if dim is None:
        return candidates.min()
    return candidates.min(dim=dim, keepdim=keepdim)
def masked_max(input: torch.Tensor, mask: torch.BoolTensor, dim: int = None, keepdim: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Similar to torch.max, but only elements where `mask` is True take part.

    Masked-out elements are replaced by -inf before the reduction.

    Args:
        input (torch.Tensor): values to reduce
        mask (torch.BoolTensor): mask broadcastable against `input`
        dim (int, optional): dimension to reduce. Defaults to None (reduce everything).
        keepdim (bool): keep the reduced dimension; only used when `dim` is given.

    Returns:
        The global maximum when `dim` is None, otherwise the (values, indices)
        pair returned by `Tensor.max(dim=...)`.
    """
    neg_inf = torch.tensor(-torch.inf, dtype=input.dtype, device=input.device)
    candidates = torch.where(mask, input, neg_inf)
    if dim is None:
        return candidates.max()
    return candidates.max(dim=dim, keepdim=keepdim)
def bounding_rect(mask: torch.BoolTensor):
    """Get the bounding rectangle of a mask, in uv coordinates.

    The rectangle is spanned by the minimum and maximum `image_uv` pixel-center
    coordinates over the pixels where the mask is True.

    Args:
        mask (torch.Tensor): shape (..., height, width), mask

    Returns:
        rect (torch.Tensor): shape (..., 4), bounding rectangle (left, top, right, bottom)
    """
    height, width = mask.shape[-2:]
    # (..., H*W, 1): broadcasts against the (H*W, 2) uv table below.
    flat_mask = mask.flatten(-2).unsqueeze(-1)
    uv = image_uv(height, width).to(mask.device).reshape(-1, 2)
    # The masked reductions return (values, indices); only the values are needed.
    left_top, _ = masked_min(uv, flat_mask, dim=-2)
    right_bottom, _ = masked_max(uv, flat_mask, dim=-2)
    return torch.cat([left_top, right_bottom], dim=-1)
def chessboard(width: int, height: int, grid_size: int, color_a: torch.Tensor, color_b: torch.Tensor) -> torch.Tensor:
    """Get a chessboard image.

    Args:
        width (int): image width
        height (int): image height
        grid_size (int): side length of one chessboard cell, in pixels
        color_a (torch.Tensor): shape (channels,), color of the cell at the top-left corner
        color_b (torch.Tensor): shape (channels,), color of the complementary cells

    Returns:
        image (torch.Tensor): shape (height, width, channels), chessboard image
    """
    # Cell index of every pixel column / row.
    col_cell = torch.arange(width).div(grid_size, rounding_mode='floor')
    row_cell = torch.arange(height).div(grid_size, rounding_mode='floor')
    # Parity of (column cell + row cell): 0 selects color_a, 1 selects color_b.
    parity = ((col_cell[None, :] + row_cell[:, None]) % 2).to(color_a)
    weight_b = parity[..., None]
    return (1 - weight_b) * color_a + weight_b * color_b