File size: 4,060 Bytes
2f85de4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""Contain the functions to sample point features from the triplane
   representation."""

import torch

__all__ = ['TriplaneSampler']


class TriplaneSampler(torch.nn.Module):
    """Helper for sampling point features from a triplane representation.

    The class is stateless; every piece of functionality is exposed as a
    static method:

    1. `generate_planes()`: build the (3, 3, 3) tensor of plane axes.
    2. `project_onto_planes()`: project 3D points onto each plane.
    3. `sample_from_planes()`: interpolate plane features at 3D points.
    4. `sample_from_3dgrid()`: interpolate a dense 3D feature grid at
       3D points.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def generate_planes():
        """Return the axes that define the three feature planes.

        Each plane is described by a 3x3 matrix whose rows are the three
        vectors forming the "axes" of the plane. The representation is meant
        to work with an arbitrary number of planes of arbitrary orientation.

        Returns:
            A float32 tensor of shape (3, 3, 3).

        NOTE(review): when fed through `project_onto_planes()` (which
        multiplies by the matrix inverse and keeps the first two
        components), these matrices yield the coordinate pairs (x, y),
        (x, z) and (z, x), i.e. the second and third planes cover the same
        x-z plane. The values are deliberately left untouched because
        changing them would alter what existing callers sample.
        """
        return torch.tensor([[[1, 0, 0],
                              [0, 1, 0],
                              [0, 0, 1]],
                             [[1, 0, 0],
                              [0, 0, 1],
                              [0, 1, 0]],
                             [[0, 0, 1],
                              [1, 0, 0],
                              [0, 1, 0]]], dtype=torch.float32)

    @staticmethod
    def project_onto_planes(planes, coordinates):
        """Project 3D points onto a batch of 2D planes.

        Args:
            planes: Plane axes of shape (n_planes, 3, 3).
            coordinates: Coordinates of shape (N, M, 3).

        Returns:
            projections: Projections of shape (N*n_planes, M, 2).
        """
        N, M, _ = coordinates.shape
        n_planes = planes.shape[0]

        # Repeat every batch of points once per plane, then flatten the
        # (batch, plane) pair into a single leading dimension for `bmm`.
        tiled_coordinates = coordinates.unsqueeze(1).expand(
            -1, n_planes, -1, -1).reshape(N * n_planes, M, 3)

        # Likewise repeat the inverted plane matrices once per batch item.
        inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(
            N, -1, -1, -1).reshape(N * n_planes, 3, 3)

        projections = torch.bmm(tiled_coordinates, inv_planes)
        # Only the first two components are in-plane coordinates.
        return projections[..., :2]

    @staticmethod
    def sample_from_planes(plane_axes,
                           plane_features,
                           coordinates,
                           mode='bilinear',
                           padding_mode='zeros',
                           box_warp=None):
        """Sample per-plane features at the given 3D points.

        Args:
            plane_axes: Plane axes of shape (n_planes, 3, 3), e.g. the
                output of `generate_planes()`.
            plane_features: Feature planes of shape (N, n_planes, C, H, W).
            coordinates: Query points of shape (N, M, 3).
            mode: Interpolation mode forwarded to `grid_sample`.
            padding_mode: Padding mode forwarded to `grid_sample`. Only
                `'zeros'` is accepted.
            box_warp: If given, coordinates are multiplied by
                `2 / box_warp` before projection, which maps the range
                [-box_warp / 2, box_warp / 2] onto the [-1, 1] range used by
                `grid_sample`. If `None`, coordinates are used unchanged.

        Returns:
            Sampled features of shape (N, n_planes, M, C).
        """
        assert padding_mode == 'zeros'
        N, n_planes, C, H, W = plane_features.shape
        _, M, _ = coordinates.shape

        # `reshape` (rather than `view`) also accepts non-contiguous inputs.
        flat_features = plane_features.reshape(N * n_planes, C, H, W)

        # The signature defaults `box_warp` to None; dividing by it would
        # raise a TypeError, so scaling is applied only when it is provided.
        if box_warp is not None:
            coordinates = (2 / box_warp) * coordinates

        # Shape (N*n_planes, 1, M, 2): one row of M sampling locations.
        projected_coordinates = TriplaneSampler.project_onto_planes(
            plane_axes, coordinates).unsqueeze(1)

        # `grid_sample` returns (N*n_planes, C, 1, M).
        sampled = torch.nn.functional.grid_sample(
            flat_features,
            projected_coordinates.float(),
            mode=mode,
            padding_mode=padding_mode,
            align_corners=False)

        # Move channels last, then split the fused (batch, plane) dimension.
        output_features = sampled.permute(0, 3, 2, 1).reshape(
            N, n_planes, M, C)
        return output_features

    @staticmethod
    def sample_from_3dgrid(grid, coordinates):
        """Sample features from a dense 3D grid at the given points.

        Args:
            grid: Feature volume of shape (1, channels, H, W, D). A grid
                that already has the batch size of `coordinates` also works.
            coordinates: Query points of shape
                (batch_size, num_points_per_batch, 3).

        Returns:
            Sampled features of shape
            (batch_size, num_points_per_batch, feature_channels).
        """
        batch_size, _, n_dims = coordinates.shape

        # Lay the points out as a (1, 1, num_points) "volume" of sampling
        # locations so that the 5-D variant of `grid_sample` can be used.
        sampled_features = torch.nn.functional.grid_sample(
            grid.expand(batch_size, -1, -1, -1, -1),
            coordinates.reshape(batch_size, 1, 1, -1, n_dims),
            mode='bilinear',
            padding_mode='zeros',
            align_corners=False)

        # Result is (N, C, 1, 1, num_points); put channels last and flatten
        # the three spatial dimensions into a single point dimension.
        N, C, H, W, D = sampled_features.shape
        sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(
            N, H * W * D, C)
        return sampled_features