File size: 2,676 Bytes
4b532c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
from torch.autograd import Function

from ..utils import ext_loader

# Load the `_ext` extension through ext_loader, requesting the `knn_forward`
# op that KNN.forward calls below.
# NOTE(review): presumably this is the compiled CUDA extension — confirm
# against ext_loader.load_ext.
ext_module = ext_loader.load_ext('_ext', ['knn_forward'])


class KNN(Function):
    r"""KNN (CUDA) based on heap data structure.

    Modified from `PAConv <https://github.com/CVMI-Lab/PAConv/tree/main/
    scene_seg/lib/pointops/src/knnquery_heap>`_.

    Find k-nearest points.
    """

    @staticmethod
    def forward(ctx,
                k: int,
                xyz: torch.Tensor,
                center_xyz: torch.Tensor = None,
                transposed: bool = False) -> torch.Tensor:
        """
        Args:
            k (int): number of nearest neighbors. Must satisfy 0 < k < 100.
            xyz (Tensor): (B, N, 3) if transposed == False, else (B, 3, N).
                xyz coordinates of the features.
            center_xyz (Tensor, optional): (B, npoint, 3) if transposed ==
                False, else (B, 3, npoint). centers of the knn query.
                When None, ``xyz`` itself is used as the query centers.
                Default: None.
            transposed (bool, optional): whether the input tensors are
                transposed. Should not explicitly use this keyword when
                calling knn (=KNN.apply), just add the fourth param.
                Default: False.

        Returns:
            Tensor: (B, k, npoint) int tensor with the indices of
                the features that form k-nearest neighbours.

        Raises:
            AssertionError: if ``k`` is outside (0, 100), if an input is not
                contiguous, or if ``xyz`` and ``center_xyz`` report different
                devices.
        """
        assert 0 < k < 100, 'k should be in range(0, 100)'

        # Without explicit centers, every point queries its own neighbours.
        if center_xyz is None:
            center_xyz = xyz

        # Bring channel-first inputs to the channel-last layout used below.
        if transposed:
            xyz = xyz.transpose(2, 1).contiguous()
            center_xyz = center_xyz.transpose(2, 1).contiguous()

        assert xyz.is_contiguous()  # [B, N, 3]
        assert center_xyz.is_contiguous()  # [B, npoint, 3]

        center_xyz_device = center_xyz.get_device()
        assert center_xyz_device == xyz.get_device(), \
            'center_xyz and xyz should be put on the same device'
        # Make the inputs' device the current CUDA device before calling
        # into the extension.
        if torch.cuda.current_device() != center_xyz_device:
            torch.cuda.set_device(center_xyz_device)

        B, npoint, _ = center_xyz.shape
        N = xyz.shape[1]

        # Output buffers passed to the extension op.
        idx = center_xyz.new_zeros((B, npoint, k)).int()
        dist2 = center_xyz.new_zeros((B, npoint, k)).float()

        ext_module.knn_forward(
            xyz, center_xyz, idx, dist2, b=B, n=N, m=npoint, nsample=k)
        # idx shape to [B, k, npoint]
        idx = idx.transpose(2, 1).contiguous()
        # The marking is skipped when the reported version is 'parrots'.
        if torch.__version__ != 'parrots':
            ctx.mark_non_differentiable(idx)
        return idx

    @staticmethod
    def backward(ctx, a=None):
        """Return no gradients; the index output is not differentiable.

        One ``None`` is returned per ``forward`` input after ``ctx``
        (``k``, ``xyz``, ``center_xyz``, ``transposed``), so the number of
        gradient slots matches the number of forward arguments.
        """
        return None, None, None, None


# Functional alias for the op: call as knn(k, xyz, center_xyz, transposed),
# passing `transposed` positionally rather than by keyword (see the
# KNN.forward docstring).
knn = KNN.apply