File size: 4,344 Bytes
b944fa1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from torch.autograd import Function

from ..utils import ext_loader

# Look up the two ops in the ``_ext`` extension that ``AssignScoreWithK``
# dispatches to in its forward and backward passes.
ext_module = ext_loader.load_ext(
    '_ext',
    [
        'assign_score_withk_forward',
        'assign_score_withk_backward',
    ],
)


class AssignScoreWithK(Function):
    r"""Perform weighted sum to generate output features according to scores.

    Modified from `PAConv <https://github.com/CVMI-Lab/PAConv/tree/main/
    scene_seg/lib/paconv_lib/src/gpu>`_.

    This is a memory-efficient CUDA implementation of the assign_scores
    operation, which first transforms all point features with the weight bank,
    then assembles neighbor features with ``knn_idx`` and performs a weighted
    sum with ``scores``.

    See the `paper <https://arxiv.org/pdf/2103.14635.pdf>`_ appendix Sec. D for
    more detailed descriptions.

    Note:
        This implementation assumes using ``neighbor`` kernel input, which is
        (point_features - center_features, point_features).
        See https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/
        pointnet2/paconv.py#L128 for more details.
    """

    @staticmethod
    def forward(ctx,
                scores,
                point_features,
                center_features,
                knn_idx,
                aggregate='sum'):
        """Run the forward extension op and stash inputs for backward.

        Args:
            scores (torch.Tensor): (B, npoint, K, M), predicted scores to
                aggregate weight matrices in the weight bank.
                ``npoint`` is the number of sampled centers.
                ``K`` is the number of queried neighbors.
                ``M`` is the number of weight matrices in the weight bank.
            point_features (torch.Tensor): (B, N, M, out_dim)
                Pre-computed point features to be aggregated.
            center_features (torch.Tensor): (B, N, M, out_dim)
                Pre-computed center features to be aggregated.
            knn_idx (torch.Tensor): (B, npoint, K), index of sampled kNN.
                We assume the first idx in each row is the idx of the center.
            aggregate (str, optional): Aggregation method.
                Can be 'sum', 'avg' or 'max'. Defaults to 'sum'.

        Returns:
            torch.Tensor: (B, out_dim, npoint, K), the aggregated features.

        Raises:
            KeyError: If ``aggregate`` is not one of the supported modes.
        """
        # Integer codes the extension expects for each aggregation mode.
        agg_codes = {'sum': 0, 'avg': 1, 'max': 2}

        # Fail fast, before any allocation, with a message that lists the
        # valid choices. KeyError is kept so existing callers that catch the
        # plain dict-lookup error keep working.
        if aggregate not in agg_codes:
            raise KeyError(
                'Unsupported aggregate {!r}; expected one of {}'.format(
                    aggregate, list(agg_codes)))
        agg_code = agg_codes[aggregate]

        B, N, M, out_dim = point_features.size()
        _, npoint, K, _ = scores.size()

        # The extension writes its result into this pre-allocated buffer.
        output = point_features.new_zeros((B, out_dim, npoint, K))
        ext_module.assign_score_withk_forward(
            point_features.contiguous(),
            center_features.contiguous(),
            scores.contiguous(),
            knn_idx.contiguous(),
            output,
            B=B,
            N0=N,
            N1=npoint,
            M=M,
            K=K,
            O=out_dim,
            aggregate=agg_code)

        # Only the inputs are saved: the backward op never reads the forward
        # output, so saving it would just keep an extra tensor alive.
        ctx.save_for_backward(point_features, center_features, scores,
                              knn_idx)
        ctx.agg = agg_code

        return output

    @staticmethod
    def backward(ctx, grad_out):
        """Run the backward extension op to obtain the input gradients.

        Args:
            grad_out (torch.Tensor): (B, out_dim, npoint, K)

        Returns:
            tuple: One entry per ``forward`` input, in order:

            - grad_scores (torch.Tensor): (B, npoint, K, M)
            - grad_point_features (torch.Tensor): (B, N, M, out_dim)
            - grad_center_features (torch.Tensor): (B, N, M, out_dim)
            - None: no gradient for ``knn_idx``.
            - None: no gradient for ``aggregate``.
        """
        point_features, center_features, scores, knn_idx = ctx.saved_tensors

        agg = ctx.agg

        B, N, M, out_dim = point_features.size()
        _, npoint, K, _ = scores.size()

        # Zero-initialised buffers that the extension fills in place.
        grad_point_features = point_features.new_zeros(point_features.shape)
        grad_center_features = center_features.new_zeros(center_features.shape)
        grad_scores = scores.new_zeros(scores.shape)

        ext_module.assign_score_withk_backward(
            grad_out.contiguous(),
            point_features.contiguous(),
            center_features.contiguous(),
            scores.contiguous(),
            knn_idx.contiguous(),
            grad_point_features,
            grad_center_features,
            grad_scores,
            B=B,
            N0=N,
            N1=npoint,
            M=M,
            K=K,
            O=out_dim,
            aggregate=agg)

        return grad_scores, grad_point_features, \
            grad_center_features, None, None


# Functional alias for the autograd op. Call it with the arguments of
# ``AssignScoreWithK.forward`` minus ``ctx``:
# (scores, point_features, center_features, knn_idx, aggregate).
assign_score_withk = AssignScoreWithK.apply