from typing import Tuple

import torch
from torch.autograd import Function

from ..utils import ext_loader

ext_module = ext_loader.load_ext('_ext', ['three_nn_forward'])


class ThreeNN(Function):
    """Find the top-3 nearest neighbors of the target set from the source set.

    Please refer to `Paper of PointNet++ <https://arxiv.org/abs/1706.02413>`_
    for more details.
    """

    @staticmethod
    def forward(ctx, target: torch.Tensor,
                source: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            target (Tensor): shape (B, N, 3), points set that needs to
                find the nearest neighbors.
            source (Tensor): shape (B, M, 3), points set that is used
                to find the nearest neighbors of points in target set.

        Returns:
            Tuple[Tensor, Tensor]:

            - dist (Tensor): shape (B, N, 3), float32. L2 distance of each
              point in target set to their corresponding nearest neighbors
              (square root of the values written by the extension kernel).
            - idx (Tensor): shape (B, N, 3), int32. Index buffer filled by
              the extension kernel; marked as non-differentiable.
        """
        target = target.contiguous()
        source = source.contiguous()

        B, N, _ = target.size()
        m = source.size(1)

        # Allocate the output buffers on the same device as the inputs.
        # The legacy ``torch.cuda.FloatTensor(...)`` constructors always
        # allocate on the *current* CUDA device, which is wrong when the
        # inputs live on a different GPU (or a non-CUDA backend).
        # dtypes are kept as float32 / int32, as before.
        dist2 = torch.empty((B, N, 3),
                            dtype=torch.float32,
                            device=target.device)
        idx = torch.empty((B, N, 3), dtype=torch.int32, device=target.device)

        # The kernel writes its results in-place into ``dist2`` and ``idx``.
        ext_module.three_nn_forward(target, source, dist2, idx, b=B, n=N, m=m)
        if torch.__version__ != 'parrots':
            ctx.mark_non_differentiable(idx)

        # ``dist2`` is treated as squared distance; return its square root.
        return torch.sqrt(dist2), idx

    @staticmethod
    def backward(ctx, a=None, b=None):
        """No gradient is propagated to ``target`` or ``source``.

        Args:
            a: incoming gradient for the distance output (unused).
            b: incoming gradient for the index output (unused).

        Returns:
            Tuple[None, None]: one ``None`` per input of :meth:`forward`.
        """
        return None, None


three_nn = ThreeNN.apply