BerfScene / models /rendering /triplane_sampler.py
3v324v23's picture
init
2f85de4
"""Contain the functions to sample point features from the triplane
representation."""
import torch
__all__ = ['TriplaneSampler']
class TriplaneSampler(torch.nn.Module):
"""Defines the class to help sample point features from the triplane
representation.
Basically, this class implements the following functions for sampling point
features (rgb && sigma) from the triplane representation:
1. `generate_planes()`.
2. `project_onto_planes()`.
3. `sample_from_planes()`.
4. `sample_from_3dgrid()`.
"""
def __init__(self):
super().__init__()
@staticmethod
def generate_planes():
"""
Defines planes by the three vectors that form the "axes" of the
plane. Should work with arbitrary number of planes and planes of
arbitrary orientation.
"""
return torch.tensor([[[1, 0, 0],
[0, 1, 0],
[0, 0, 1]],
[[1, 0, 0],
[0, 0, 1],
[0, 1, 0]],
[[0, 0, 1],
[1, 0, 0],
[0, 1, 0]]], dtype=torch.float32)
@staticmethod
def project_onto_planes(planes, coordinates):
"""
Does a projection of a 3D point onto a batch of 2D planes,
returning 2D plane coordinates.
Args:
planes: Plane axes of shape (n_planes, 3, 3)
coordinates: Coordinates of shape (N, M, 3)
Returns:
projections: Projections of shape (N*n_planes, M, 2)
"""
N, M, C = coordinates.shape
n_planes, _, _ = planes.shape
coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1,
-1).reshape(
N * n_planes, M, 3)
inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(
N, -1, -1, -1).reshape(N * n_planes, 3, 3)
projections = torch.bmm(coordinates, inv_planes)
return projections[..., :2]
@staticmethod
def sample_from_planes(plane_axes,
plane_features,
coordinates,
mode='bilinear',
padding_mode='zeros',
box_warp=None):
assert padding_mode == 'zeros'
N, n_planes, C, H, W = plane_features.shape
_, M, _ = coordinates.shape
plane_features = plane_features.view(N * n_planes, C, H, W)
coordinates = (2 / box_warp) * coordinates
projected_coordinates = TriplaneSampler.project_onto_planes(
plane_axes, coordinates).unsqueeze(1)
output_features = torch.nn.functional.grid_sample(
plane_features,
projected_coordinates.float(),
mode=mode,
padding_mode=padding_mode,
align_corners=False).permute(0, 3, 2,
1).reshape(N, n_planes, M, C)
return output_features
@staticmethod
def sample_from_3dgrid(grid, coordinates):
"""
Expects coordinates in shape (batch_size, num_points_per_batch, 3)
Expects grid in shape (1, channels, H, W, D)
(Also works if grid has batch size)
Returns:
Sampled features
with shape: (batch_size, num_points_per_batch, feature_channels).
"""
batch_size, n_coords, n_dims = coordinates.shape
sampled_features = torch.nn.functional.grid_sample(
grid.expand(batch_size, -1, -1, -1, -1),
coordinates.reshape(batch_size, 1, 1, -1, n_dims),
mode='bilinear',
padding_mode='zeros',
align_corners=False)
N, C, H, W, D = sampled_features.shape
sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(
N, H * W * D, C)
return sampled_features