File size: 3,398 Bytes
57746f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import torch
from torch.autograd import Function

from pointops._C import knn_query_cuda, random_ball_query_cuda, ball_query_cuda


class KNNQuery(Function):
    """K-nearest-neighbour query backed by the ``knn_query_cuda`` extension.

    Only ``forward`` is implemented; no ``backward`` is defined.
    """

    @staticmethod
    def forward(ctx, nsample, xyz, offset, new_xyz=None, new_offset=None):
        """Find the ``nsample`` nearest points of ``xyz`` for each query point.

        Args:
            ctx: autograd context (unused).
            nsample: number of neighbours returned per query point.
            xyz: (n, 3) contiguous support coordinates.
            offset: (b,) per-batch offsets into ``xyz`` (cast to int32 before
                the kernel call).
            new_xyz: (m, 3) contiguous query coordinates; defaults to ``xyz``.
            new_offset: (b,) per-batch offsets into ``new_xyz``; defaults to
                ``offset``.

        Note: if EITHER ``new_xyz`` or ``new_offset`` is None, BOTH fall back
        to ``xyz`` / ``offset`` -- a ``new_xyz`` passed without
        ``new_offset`` is ignored.

        Returns:
            idx: (m, nsample) int32 neighbour indices. Per the original
                documentation, -1 marks an unfilled slot (that value would be
                written by the kernel; the buffer is zero-initialised here).
            dist: (m, nsample) float32, the element-wise square root of the
                ``dist2`` buffer filled by the kernel (presumably squared
                distances, so these are Euclidean distances -- confirm).

        Raises:
            AssertionError: if ``xyz`` or ``new_xyz`` is not contiguous.
        """
        if new_xyz is None or new_offset is None:
            new_xyz = xyz
            new_offset = offset
        assert xyz.is_contiguous() and new_xyz.is_contiguous()
        m = new_xyz.shape[0]
        # Allocate on the inputs' device. The legacy torch.cuda.*Tensor
        # constructors always use the *current* CUDA device, which breaks
        # when the inputs live on a different GPU.
        idx = torch.zeros((m, nsample), dtype=torch.int32, device=xyz.device)
        dist2 = torch.zeros((m, nsample), dtype=torch.float32, device=xyz.device)
        # The kernel fills idx and dist2 in place. Offsets are cast to int32,
        # presumably the index type the extension expects.
        knn_query_cuda(
            m, nsample, xyz, new_xyz, offset.int(), new_offset.int(), idx, dist2
        )
        return idx, torch.sqrt(dist2)


class RandomBallQuery(Function):
    """Random Ball Query.

    Find nearby points in spherical space. The kernel additionally receives a
    random permutation of the point indices within each batch (presumably the
    order in which candidate points are visited -- confirm against the
    kernel). Only ``forward`` is implemented; no ``backward`` is defined.
    """

    @staticmethod
    def _random_order(offset):
        """Return a per-batch random permutation of point indices.

        ``offset`` holds cumulative end indices: batch k spans
        ``[offset[k - 1], offset[k])`` and batch 0 starts at 0. The result is
        an int32 tensor on ``offset.device`` whose slice for each batch is a
        shuffle of that batch's own indices. An empty ``offset`` yields an
        empty tensor.
        """
        # A single host transfer, instead of indexing a (possibly CUDA)
        # tensor element by element inside the loop.
        ends = offset.tolist()
        starts = [0] + ends[:-1]
        chunks = [
            torch.randperm(end - start, dtype=torch.int32, device=offset.device)
            + start
            for start, end in zip(starts, ends)
        ]
        if not chunks:
            # torch.cat rejects an empty list; no batches means no points.
            return torch.empty(0, dtype=torch.int32, device=offset.device)
        return torch.cat(chunks, dim=0)

    @staticmethod
    def forward(
        ctx, nsample, max_radius, min_radius, xyz, offset, new_xyz=None, new_offset=None
    ):
        """Collect up to ``nsample`` neighbours of each query point.

        Args:
            ctx: autograd context (unused).
            nsample: number of neighbour slots per query point.
            max_radius: upper radius bound; must exceed ``min_radius``.
            min_radius: lower radius bound. Note that the kernel receives
                ``min_radius`` before ``max_radius``, the reverse of this
                signature.
            xyz: (n, 3) contiguous support coordinates.
            offset: (b,) cumulative end index of each batch in ``xyz``.
            new_xyz: (m, 3) contiguous query coordinates; defaults to ``xyz``.
            new_offset: (b,) per-batch offsets into ``new_xyz``; defaults to
                ``offset``.

        Note: if EITHER ``new_xyz`` or ``new_offset`` is None, BOTH fall back
        to ``xyz`` / ``offset``.

        Returns:
            idx: (m, nsample) int32 neighbour indices (zero-initialised, then
                filled by the kernel).
            dist: (m, nsample) float32, the element-wise square root of the
                ``dist2`` buffer filled by the kernel (presumably squared
                distances -- confirm).

        Raises:
            AssertionError: on non-contiguous coordinates, or when
                ``min_radius >= max_radius``.
        """
        if new_xyz is None or new_offset is None:
            new_xyz = xyz
            new_offset = offset
        assert xyz.is_contiguous() and new_xyz.is_contiguous()
        assert min_radius < max_radius

        m = new_xyz.shape[0]
        order = RandomBallQuery._random_order(offset)
        # Allocate on the inputs' device rather than the current CUDA device
        # (the legacy torch.cuda.*Tensor constructors ignore xyz.device).
        idx = torch.zeros((m, nsample), dtype=torch.int32, device=xyz.device)
        dist2 = torch.zeros((m, nsample), dtype=torch.float32, device=xyz.device)
        random_ball_query_cuda(
            m,
            nsample,
            min_radius,
            max_radius,
            order,
            xyz,
            new_xyz,
            offset.int(),
            new_offset.int(),
            idx,
            dist2,
        )
        return idx, torch.sqrt(dist2)


class BallQuery(Function):
    """Ball Query.

    Find nearby points in spherical space via the ``ball_query_cuda``
    extension. Only ``forward`` is implemented; no ``backward`` is defined.
    """

    @staticmethod
    def forward(
        ctx, nsample, max_radius, min_radius, xyz, offset, new_xyz=None, new_offset=None
    ):
        """Collect up to ``nsample`` neighbours of each query point.

        Args:
            ctx: autograd context (unused).
            nsample: number of neighbour slots per query point.
            max_radius: upper radius bound; must exceed ``min_radius``.
            min_radius: lower radius bound. Note that the kernel receives
                ``min_radius`` before ``max_radius``, the reverse of this
                signature.
            xyz: (n, 3) contiguous support coordinates.
            offset: (b,) per-batch offsets into ``xyz`` (cast to int32).
            new_xyz: (m, 3) contiguous query coordinates; defaults to ``xyz``.
            new_offset: (b,) per-batch offsets into ``new_xyz``; defaults to
                ``offset``.

        Note: if EITHER ``new_xyz`` or ``new_offset`` is None, BOTH fall back
        to ``xyz`` / ``offset``.

        Returns:
            idx: (m, nsample) int32 neighbour indices (zero-initialised, then
                filled by the kernel).
            dist: (m, nsample) float32, the element-wise square root of the
                ``dist2`` buffer filled by the kernel (presumably squared
                distances -- confirm).

        Raises:
            AssertionError: on non-contiguous coordinates, or when
                ``min_radius >= max_radius``.
        """
        if new_xyz is None or new_offset is None:
            new_xyz = xyz
            new_offset = offset
        assert xyz.is_contiguous() and new_xyz.is_contiguous()
        assert min_radius < max_radius

        m = new_xyz.shape[0]
        # Allocate on the inputs' device rather than the current CUDA device
        # (the legacy torch.cuda.*Tensor constructors ignore xyz.device).
        idx = torch.zeros((m, nsample), dtype=torch.int32, device=xyz.device)
        dist2 = torch.zeros((m, nsample), dtype=torch.float32, device=xyz.device)
        ball_query_cuda(
            m,
            nsample,
            min_radius,
            max_radius,
            xyz,
            new_xyz,
            offset.int(),
            new_offset.int(),
            idx,
            dist2,
        )
        return idx, torch.sqrt(dist2)


# Functional entry points: each ``.apply`` runs the corresponding class's
# ``forward`` through the autograd Function machinery.
knn_query, ball_query, random_ball_query = (
    KNNQuery.apply,
    BallQuery.apply,
    RandomBallQuery.apply,
)