| import random |
|
|
| import numpy as np |
| import torch |
| from sklearn.neighbors import NearestNeighbors |
| from torch import Tensor |
| import torch.nn.functional as F |
| import matplotlib.pyplot as plt |
| from matplotlib import colormaps |
|
|
|
|
class CameraOptModule(torch.nn.Module):
    """Camera pose optimization module.

    Holds one learnable 9-vector per camera index: the first 3 entries are a
    translation offset, the remaining 6 are an offset added to the identity
    rotation expressed in the 6D representation used by
    ``rotation_6d_to_matrix``.
    """

    def __init__(self, n: int):
        super().__init__()
        # 3 translation + 6 rotation (6D form) parameters for each of n cameras.
        self.embeds = torch.nn.Embedding(n, 9)
        # Identity rotation in 6D form: the first two rows of the 3x3 identity.
        self.register_buffer("identity", torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]))

    def zero_init(self):
        """Set every pose delta to zero, making ``forward`` an identity map."""
        torch.nn.init.zeros_(self.embeds.weight)

    def random_init(self, std: float):
        """Fill the pose deltas from a zero-mean normal with standard deviation ``std``."""
        torch.nn.init.normal_(self.embeds.weight, std=std)

    def forward(self, camtoworlds: Tensor, embed_ids: Tensor) -> Tensor:
        """Right-multiply each camera-to-world matrix by its learned correction.

        Args:
            camtoworlds: (..., 4, 4)
            embed_ids: (...,) camera indices; leading shape must match
                ``camtoworlds.shape[:-2]``.

        Returns:
            updated camtoworlds: (..., 4, 4)
        """
        assert camtoworlds.shape[:-2] == embed_ids.shape
        leading = camtoworlds.shape[:-2]

        deltas = self.embeds(embed_ids)
        translation = deltas[..., :3]
        # The stored rotation delta is relative to the identity rotation.
        rotation_6d = deltas[..., 3:] + self.identity.expand(*leading, -1)
        rotation = rotation_6d_to_matrix(rotation_6d)

        # Assemble a homogeneous [R | t] transform for every batch element.
        correction = torch.eye(4, device=deltas.device).repeat((*leading, 1, 1))
        correction[..., :3, :3] = rotation
        correction[..., :3, 3] = translation
        return camtoworlds @ correction
|
|
|
|
class AppearanceOptModule(torch.nn.Module):
    """Appearance optimization module.

    Predicts a 3-channel color per (camera, point) pair from a per-index
    embedding, a per-point feature vector, and spherical-harmonic bases of
    the viewing direction, using a small ReLU MLP.
    """

    def __init__(
        self,
        n: int,
        feature_dim: int,
        embed_dim: int = 16,
        sh_degree: int = 3,
        mlp_width: int = 64,
        mlp_depth: int = 2,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.sh_degree = sh_degree
        # One embed_dim-sized row for each index in [0, n).
        self.embeds = torch.nn.Embedding(n, embed_dim)

        num_sh_bases = (sh_degree + 1) ** 2
        in_features = embed_dim + feature_dim + num_sh_bases
        hidden = [torch.nn.Linear(in_features, mlp_width), torch.nn.ReLU(inplace=True)]
        for _ in range(mlp_depth - 1):
            hidden += [torch.nn.Linear(mlp_width, mlp_width), torch.nn.ReLU(inplace=True)]
        self.color_head = torch.nn.Sequential(*hidden, torch.nn.Linear(mlp_width, 3))

    def forward(
        self, features: Tensor, embed_ids: Tensor, dirs: Tensor, sh_degree: int
    ) -> Tensor:
        """Predict per-camera, per-point colors.

        Args:
            features: (N, feature_dim)
            embed_ids: (C,) indices into ``self.embeds``, or None to use an
                all-zero embedding for every camera.
            dirs: (C, N, 3) viewing directions; normalized internally.
            sh_degree: SH degree to evaluate. Must not exceed the degree the
                module was built with; the remaining basis slots are zero.

        Returns:
            colors: (C, N, 3)
        """
        from gsplat.cuda._torch_impl import _eval_sh_bases_fast

        n_cams, n_pts = dirs.shape[:2]
        device = features.device

        if embed_ids is not None:
            cam_embeds = self.embeds(embed_ids)
        else:
            cam_embeds = torch.zeros(n_cams, self.embed_dim, device=device)
        cam_embeds = cam_embeds[:, None, :].expand(-1, n_pts, -1)
        point_feats = features[None, :, :].expand(n_cams, -1, -1)

        unit_dirs = F.normalize(dirs, dim=-1)
        active_bases = (sh_degree + 1) ** 2
        total_bases = (self.sh_degree + 1) ** 2
        # Fixed-width SH input so the MLP input size does not depend on sh_degree.
        sh_bases = torch.zeros(n_cams, n_pts, total_bases, device=device)
        sh_bases[:, :, :active_bases] = _eval_sh_bases_fast(active_bases, unit_dirs)

        parts = [point_feats, sh_bases]
        if self.embed_dim > 0:
            parts.insert(0, cam_embeds)
        return self.color_head(torch.cat(parts, dim=-1))
|
|
|
|
def rotation_6d_to_matrix(d6: Tensor) -> Tensor:
    """
    Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
    using Gram--Schmidt orthogonalization per Section B of [1]. Adapted from pytorch3d.

    The first three entries give the direction of the first row; the last
    three are orthogonalized against it to give the second row; the third row
    is their cross product.

    Args:
        d6: 6D rotation representation, of size (*, 6)

    Returns:
        batch of rotation matrices of size (*, 3, 3)

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    first, second = d6[..., :3], d6[..., 3:]
    row0 = F.normalize(first, dim=-1)
    # Remove the component of `second` along row0, then renormalize.
    projection = torch.sum(row0 * second, dim=-1, keepdim=True)
    row1 = F.normalize(second - projection * row0, dim=-1)
    row2 = torch.cross(row0, row1, dim=-1)
    return torch.stack([row0, row1, row2], dim=-2)
|
|
|
|
def knn(x: Tensor, K: int = 4) -> Tensor:
    """Euclidean distances from each point in ``x`` to its ``K`` nearest neighbors.

    The same points are used both to build the index and as queries, so each
    point is returned as its own nearest neighbor: the first column holds the
    distance to itself (0).

    Args:
        x: (N, D) point coordinates.
        K: number of neighbors to return per point, including the point itself.

    Returns:
        (N, K) tensor of distances in ascending order, with the dtype and
        device of ``x``.
    """
    # detach() lets tensors that require grad be converted to NumPy
    # (.numpy() raises on them); the sklearn search is not differentiable anyway.
    x_np = x.detach().cpu().numpy()
    model = NearestNeighbors(n_neighbors=K, metric="euclidean").fit(x_np)
    distances, _ = model.kneighbors(x_np)
    return torch.from_numpy(distances).to(x)
|
|
|
|
def rgb_to_sh(rgb: Tensor) -> Tensor:
    """Convert a color to its degree-0 (DC) spherical-harmonic coefficient.

    This is the inverse of ``rgb = sh * C0 + 0.5``, where C0 = 1 / (2 * sqrt(pi))
    is the constant degree-0 SH basis value.
    """
    sh_c0 = 0.28209479177387814
    centered = rgb - 0.5
    return centered / sh_c0
|
|
|
|
def set_random_seed(seed: int):
    """Seed the global RNGs of Python's ``random``, NumPy and PyTorch with ``seed``."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
|
|
|
|
| |
def colormap(img, cmap="jet"):
    """Render ``img`` through matplotlib with a colorbar and return the pixels.

    Args:
        img: 2-D array accepted by ``Axes.imshow``. Presumably a NumPy array
            of shape (H, W); TODO confirm with callers.
        cmap: matplotlib colormap name passed to ``imshow``.

    Returns:
        torch.Tensor of shape (3, H', W'), float32 with values in [0, 255],
        holding the rasterized RGB figure. The figure is (W / dpi, H / dpi)
        inches at ``dpi``, so H' x W' is approximately H x W.
    """
    height, width = img.shape[:2]
    dpi = 300
    fig, ax = plt.subplots(1, figsize=(width / dpi, height / dpi), dpi=dpi)
    try:
        im = ax.imshow(img, cmap=cmap)
        ax.set_axis_off()
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        fig.canvas.draw()
        # `tostring_rgb` was deprecated in Matplotlib 3.8 and later removed;
        # `buffer_rgba` is the supported accessor, and its (rows, cols, 4)
        # shape comes from the renderer itself, so no manual reshape is
        # needed. Copy before the figure is closed, because the buffer is
        # owned by the renderer.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        data = rgba[..., :3].copy()
    finally:
        # Close this specific figure even if drawing failed, so figures don't leak.
        plt.close(fig)
    return torch.from_numpy(data).float().permute(2, 0, 1)
|
|
|
|
def apply_float_colormap(img: torch.Tensor, colormap: str = "turbo") -> torch.Tensor:
    """Convert single channel to a color img.

    Args:
        img (torch.Tensor): (..., 1) float single-channel image. Values are
            expected in [0, 1]; NaNs are replaced with 0.
        colormap (str): Name of a matplotlib colormap, or ``"gray"`` to
            replicate the channel three times.

    Returns:
        (..., 3) colored img. For matplotlib colormaps this is a float32
        tensor on ``img.device`` with colors in [0, 1].

    Raises:
        AssertionError: if any value, after scaling by 255 and truncating to
            an integer, falls outside [0, 255] (non-gray colormaps only).
    """
    img = torch.nan_to_num(img, 0)
    if colormap == "gray":
        # Repeat only the trailing channel axis, so any number of leading
        # dimensions works (not just (H, W, 1)).
        return img.repeat(*([1] * (img.dim() - 1)), 3)
    img_long = (img * 255).long()
    img_long_min = torch.min(img_long)
    img_long_max = torch.max(img_long)
    assert img_long_min >= 0, f"the min value is {img_long_min}"
    assert img_long_max <= 255, f"the max value is {img_long_max}"
    # Build a 256-entry RGB lookup table by sampling the colormap. For
    # 256-entry ListedColormaps such as "turbo" this yields the same entries
    # as `cmap.colors`. Unlike `.colors`, it also works for
    # LinearSegmentedColormaps (e.g. "jet") and for listed maps with a
    # different number of entries.
    lut = colormaps[colormap](np.linspace(0.0, 1.0, 256))[:, :3]
    lut = torch.tensor(lut, dtype=torch.float32, device=img.device)
    return lut[img_long[..., 0]]
|
|
|
|
def apply_depth_colormap(
    depth: torch.Tensor,
    acc: torch.Tensor = None,
    near_plane: float = None,
    far_plane: float = None,
) -> torch.Tensor:
    """Converts a depth image to color for easier analysis.

    Depth is linearly rescaled so ``near_plane`` maps to 0 and ``far_plane``
    to 1, clipped to [0, 1], and colored with the "turbo" colormap.

    Args:
        depth (torch.Tensor): (..., 1) float32 depth.
        acc (torch.Tensor | None): (..., 1) optional accumulation mask. When
            given, colors are blended toward white (1.0) where ``acc`` is low.
        near_plane: Closest depth to consider. If None, use min image value.
        far_plane: Furthest depth to consider. If None, use max image value.

    Returns:
        (..., 3) colored depth image with colors in [0, 1].
    """
    # Compare against None explicitly: `near_plane or ...` would silently
    # replace a legitimate explicit bound of 0.0 with the image minimum.
    if near_plane is None:
        near_plane = float(torch.min(depth))
    if far_plane is None:
        far_plane = float(torch.max(depth))
    # The epsilon avoids division by zero for constant-depth images.
    depth = (depth - near_plane) / (far_plane - near_plane + 1e-10)
    depth = torch.clip(depth, 0.0, 1.0)
    img = apply_float_colormap(depth, colormap="turbo")
    if acc is not None:
        img = img * acc + (1.0 - acc)
    return img
|
|