|
import torch |
|
from torch.autograd import Function |
|
|
|
from pointops._C import knn_query_cuda, random_ball_query_cuda, ball_query_cuda |
|
|
|
|
|
class KNNQuery(Function):
    """k-nearest-neighbour query over batched point clouds (forward only)."""

    @staticmethod
    def forward(ctx, nsample, xyz, offset, new_xyz=None, new_offset=None):
        """Find the ``nsample`` nearest support points for every query point.

        Args:
            ctx: autograd context; unused because this op defines no backward.
            nsample: number of neighbours returned per query point.
            xyz: (n, 3) contiguous support coordinates.
            offset: (b,) batch offsets for ``xyz`` (presumably cumulative
                per-batch end indices -- TODO confirm against the kernel).
            new_xyz: (m, 3) contiguous query coordinates. Defaults to ``xyz``.
            new_offset: (b,) batch offsets for ``new_xyz``. Defaults to
                ``offset``.

        Note:
            If *either* ``new_xyz`` or ``new_offset`` is None, *both* fall back
            to ``xyz`` / ``offset``. A lone ``new_xyz`` is therefore ignored.

        Returns:
            idx: (m, nsample) int32 neighbour indices; -1 is a placeholder.
            dist: (m, nsample) float32, the element-wise square root of the
                ``dist2`` buffer filled by the kernel (presumably squared
                distances, so this is the Euclidean distance).
        """
        if new_xyz is None or new_offset is None:
            new_xyz = xyz
            new_offset = offset
        assert xyz.is_contiguous() and new_xyz.is_contiguous()

        m = new_xyz.shape[0]
        # Allocate on the inputs' device. The legacy torch.cuda.*Tensor
        # constructors use the *current* CUDA device, which need not be the
        # device holding xyz on multi-GPU setups.
        idx = torch.zeros((m, nsample), dtype=torch.int32, device=xyz.device)
        dist2 = torch.zeros((m, nsample), dtype=torch.float32, device=xyz.device)
        # The kernel fills idx / dist2 in place.
        knn_query_cuda(
            m, nsample, xyz, new_xyz, offset.int(), new_offset.int(), idx, dist2
        )
        return idx, torch.sqrt(dist2)
|
|
|
|
|
class RandomBallQuery(Function):
    """Random Ball Query.

    Find nearby points in spherical space. Candidate support points are handed
    to the kernel in a random order that is shuffled independently within each
    batch.
    """

    @staticmethod
    def _shuffle_within_batches(offset):
        """Return a random permutation of support indices, shuffled per batch.

        ``offset`` holds cumulative per-batch end indices, so batch k spans
        ``[offset[k - 1], offset[k])`` and batch 0 starts at 0. Each batch's
        indices are permuted independently and stay inside that batch's range.
        The pieces are concatenated in batch order.

        Returns:
            An int32 tensor of length ``offset[-1]`` on ``offset.device``.
        """
        # A single device->host copy, instead of one sync per batch from
        # indexing a (possibly CUDA) tensor element by element.
        ends = offset.tolist()
        starts = [0] + ends[:-1]
        return torch.cat(
            [
                torch.randperm(end - start, dtype=torch.int32, device=offset.device)
                + start
                for start, end in zip(starts, ends)
            ],
            dim=0,
        )

    @staticmethod
    def forward(
        ctx, nsample, max_radius, min_radius, xyz, offset, new_xyz=None, new_offset=None
    ):
        """Query up to ``nsample`` support points near each query point.

        Args:
            ctx: autograd context; unused because this op defines no backward.
            nsample: maximum number of neighbours per query point.
            max_radius: upper radius bound; must be greater than ``min_radius``.
            min_radius: lower radius bound. The exact inclusion rule lives in
                the CUDA kernel -- TODO confirm.
            xyz: (n, 3) contiguous support coordinates.
            offset: (b,) cumulative per-batch end indices into ``xyz``.
            new_xyz: (m, 3) contiguous query coordinates. Defaults to ``xyz``.
            new_offset: (b,) batch offsets for ``new_xyz``. Defaults to
                ``offset``.

        Note:
            If *either* ``new_xyz`` or ``new_offset`` is None, *both* fall back
            to ``xyz`` / ``offset``.

        Returns:
            idx: (m, nsample) int32 neighbour indices.
            dist: (m, nsample) float32, the element-wise square root of the
                ``dist2`` buffer filled by the kernel.
        """
        if new_xyz is None or new_offset is None:
            new_xyz = xyz
            new_offset = offset
        assert xyz.is_contiguous() and new_xyz.is_contiguous()
        assert min_radius < max_radius

        m = new_xyz.shape[0]
        order = RandomBallQuery._shuffle_within_batches(offset)
        # Allocate on the inputs' device, not the current CUDA device (which
        # is what the legacy torch.cuda.*Tensor constructors use).
        idx = torch.zeros((m, nsample), dtype=torch.int32, device=xyz.device)
        dist2 = torch.zeros((m, nsample), dtype=torch.float32, device=xyz.device)
        # The kernel fills idx / dist2 in place; note it takes min before max.
        random_ball_query_cuda(
            m,
            nsample,
            min_radius,
            max_radius,
            order,
            xyz,
            new_xyz,
            offset.int(),
            new_offset.int(),
            idx,
            dist2,
        )
        return idx, torch.sqrt(dist2)
|
|
|
|
|
class BallQuery(Function):
    """Ball Query.

    Find nearby points in spherical space.
    """

    @staticmethod
    def forward(
        ctx, nsample, max_radius, min_radius, xyz, offset, new_xyz=None, new_offset=None
    ):
        """Query up to ``nsample`` support points near each query point.

        Args:
            ctx: autograd context; unused because this op defines no backward.
            nsample: maximum number of neighbours per query point.
            max_radius: upper radius bound; must be greater than ``min_radius``.
            min_radius: lower radius bound. The exact inclusion rule lives in
                the CUDA kernel -- TODO confirm.
            xyz: (n, 3) contiguous support coordinates.
            offset: (b,) batch offsets for ``xyz`` (presumably cumulative
                per-batch end indices -- TODO confirm).
            new_xyz: (m, 3) contiguous query coordinates. Defaults to ``xyz``.
            new_offset: (b,) batch offsets for ``new_xyz``. Defaults to
                ``offset``.

        Note:
            If *either* ``new_xyz`` or ``new_offset`` is None, *both* fall back
            to ``xyz`` / ``offset``.

        Returns:
            idx: (m, nsample) int32 neighbour indices.
            dist: (m, nsample) float32, the element-wise square root of the
                ``dist2`` buffer filled by the kernel.
        """
        if new_xyz is None or new_offset is None:
            new_xyz = xyz
            new_offset = offset
        assert xyz.is_contiguous() and new_xyz.is_contiguous()
        assert min_radius < max_radius

        m = new_xyz.shape[0]
        # Allocate on the inputs' device. The legacy torch.cuda.*Tensor
        # constructors use the *current* CUDA device, which need not be the
        # device holding xyz on multi-GPU setups.
        idx = torch.zeros((m, nsample), dtype=torch.int32, device=xyz.device)
        dist2 = torch.zeros((m, nsample), dtype=torch.float32, device=xyz.device)
        # The kernel fills idx / dist2 in place; note it takes min before max.
        ball_query_cuda(
            m,
            nsample,
            min_radius,
            max_radius,
            xyz,
            new_xyz,
            offset.int(),
            new_offset.int(),
            idx,
            dist2,
        )
        return idx, torch.sqrt(dist2)
|
|
|
|
|
# Functional entry points: calling one runs the corresponding Function's
# static ``forward`` through ``Function.apply``.
knn_query, ball_query, random_ball_query = (
    KNNQuery.apply,
    BallQuery.apply,
    RandomBallQuery.apply,
)
|
|