"""Contain the functions to sample point features from the triplane representation.""" import torch __all__ = ['TriplaneSampler'] class TriplaneSampler(torch.nn.Module): """Defines the class to help sample point features from the triplane representation. Basically, this class implements the following functions for sampling point features (rgb && sigma) from the triplane representation: 1. `generate_planes()`. 2. `project_onto_planes()`. 3. `sample_from_planes()`. 4. `sample_from_3dgrid()`. """ def __init__(self): super().__init__() @staticmethod def generate_planes(): """ Defines planes by the three vectors that form the "axes" of the plane. Should work with arbitrary number of planes and planes of arbitrary orientation. """ return torch.tensor([[[1, 0, 0], [0, 1, 0], [0, 0, 1]], [[1, 0, 0], [0, 0, 1], [0, 1, 0]], [[0, 0, 1], [1, 0, 0], [0, 1, 0]]], dtype=torch.float32) @staticmethod def project_onto_planes(planes, coordinates): """ Does a projection of a 3D point onto a batch of 2D planes, returning 2D plane coordinates. Args: planes: Plane axes of shape (n_planes, 3, 3) coordinates: Coordinates of shape (N, M, 3) Returns: projections: Projections of shape (N*n_planes, M, 2) """ N, M, C = coordinates.shape n_planes, _, _ = planes.shape coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape( N * n_planes, M, 3) inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand( N, -1, -1, -1).reshape(N * n_planes, 3, 3) projections = torch.bmm(coordinates, inv_planes) return projections[..., :2] @staticmethod def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None): assert padding_mode == 'zeros' N, n_planes, C, H, W = plane_features.shape _, M, _ = coordinates.shape plane_features = plane_features.view(N * n_planes, C, H, W) coordinates = (2 / box_warp) * coordinates projected_coordinates = TriplaneSampler.project_onto_planes( plane_axes, coordinates).unsqueeze(1) output_features = torch.nn.functional.grid_sample( plane_features, projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C) return output_features @staticmethod def sample_from_3dgrid(grid, coordinates): """ Expects coordinates in shape (batch_size, num_points_per_batch, 3) Expects grid in shape (1, channels, H, W, D) (Also works if grid has batch size) Returns: Sampled features with shape: (batch_size, num_points_per_batch, feature_channels). """ batch_size, n_coords, n_dims = coordinates.shape sampled_features = torch.nn.functional.grid_sample( grid.expand(batch_size, -1, -1, -1, -1), coordinates.reshape(batch_size, 1, 1, -1, n_dims), mode='bilinear', padding_mode='zeros', align_corners=False) N, C, H, W, D = sampled_features.shape sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape( N, H * W * D, C) return sampled_features