Anonymous-sub's picture
merge (#1)
251e479
raw
history blame
No virus
2.6 kB
import torch
from torch.autograd import Function
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['knn_forward'])
class KNN(Function):
r"""KNN (CUDA) based on heap data structure.
Modified from `PAConv <https://github.com/CVMI-Lab/PAConv/tree/main/
scene_seg/lib/pointops/src/knnquery_heap>`_.
Find k-nearest points.
"""
@staticmethod
def forward(ctx,
k: int,
xyz: torch.Tensor,
center_xyz: torch.Tensor = None,
transposed: bool = False) -> torch.Tensor:
"""
Args:
k (int): number of nearest neighbors.
xyz (Tensor): (B, N, 3) if transposed == False, else (B, 3, N).
xyz coordinates of the features.
center_xyz (Tensor, optional): (B, npoint, 3) if transposed ==
False, else (B, 3, npoint). centers of the knn query.
Default: None.
transposed (bool, optional): whether the input tensors are
transposed. Should not explicitly use this keyword when
calling knn (=KNN.apply), just add the fourth param.
Default: False.
Returns:
Tensor: (B, k, npoint) tensor with the indices of
the features that form k-nearest neighbours.
"""
assert (k > 0) & (k < 100), 'k should be in range(0, 100)'
if center_xyz is None:
center_xyz = xyz
if transposed:
xyz = xyz.transpose(2, 1).contiguous()
center_xyz = center_xyz.transpose(2, 1).contiguous()
assert xyz.is_contiguous() # [B, N, 3]
assert center_xyz.is_contiguous() # [B, npoint, 3]
center_xyz_device = center_xyz.get_device()
assert center_xyz_device == xyz.get_device(), \
'center_xyz and xyz should be put on the same device'
if torch.cuda.current_device() != center_xyz_device:
torch.cuda.set_device(center_xyz_device)
B, npoint, _ = center_xyz.shape
N = xyz.shape[1]
idx = center_xyz.new_zeros((B, npoint, k)).int()
dist2 = center_xyz.new_zeros((B, npoint, k)).float()
ext_module.knn_forward(
xyz, center_xyz, idx, dist2, b=B, n=N, m=npoint, nsample=k)
# idx shape to [B, k, npoint]
idx = idx.transpose(2, 1).contiguous()
if torch.__version__ != 'parrots':
ctx.mark_non_differentiable(idx)
return idx
@staticmethod
def backward(ctx, a=None):
return None, None, None
knn = KNN.apply