Anonymous-sub's picture
merge (#1)
251e479
raw
history blame
No virus
4.34 kB
from torch.autograd import Function
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
'_ext', ['assign_score_withk_forward', 'assign_score_withk_backward'])
class AssignScoreWithK(Function):
r"""Perform weighted sum to generate output features according to scores.
Modified from `PAConv <https://github.com/CVMI-Lab/PAConv/tree/main/
scene_seg/lib/paconv_lib/src/gpu>`_.
This is a memory-efficient CUDA implementation of assign_scores operation,
which first transform all point features with weight bank, then assemble
neighbor features with ``knn_idx`` and perform weighted sum of ``scores``.
See the `paper <https://arxiv.org/pdf/2103.14635.pdf>`_ appendix Sec. D for
more detailed descriptions.
Note:
This implementation assumes using ``neighbor`` kernel input, which is
(point_features - center_features, point_features).
See https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/
pointnet2/paconv.py#L128 for more details.
"""
@staticmethod
def forward(ctx,
scores,
point_features,
center_features,
knn_idx,
aggregate='sum'):
"""
Args:
scores (torch.Tensor): (B, npoint, K, M), predicted scores to
aggregate weight matrices in the weight bank.
``npoint`` is the number of sampled centers.
``K`` is the number of queried neighbors.
``M`` is the number of weight matrices in the weight bank.
point_features (torch.Tensor): (B, N, M, out_dim)
Pre-computed point features to be aggregated.
center_features (torch.Tensor): (B, N, M, out_dim)
Pre-computed center features to be aggregated.
knn_idx (torch.Tensor): (B, npoint, K), index of sampled kNN.
We assume the first idx in each row is the idx of the center.
aggregate (str, optional): Aggregation method.
Can be 'sum', 'avg' or 'max'. Defaults: 'sum'.
Returns:
torch.Tensor: (B, out_dim, npoint, K), the aggregated features.
"""
agg = {'sum': 0, 'avg': 1, 'max': 2}
B, N, M, out_dim = point_features.size()
_, npoint, K, _ = scores.size()
output = point_features.new_zeros((B, out_dim, npoint, K))
ext_module.assign_score_withk_forward(
point_features.contiguous(),
center_features.contiguous(),
scores.contiguous(),
knn_idx.contiguous(),
output,
B=B,
N0=N,
N1=npoint,
M=M,
K=K,
O=out_dim,
aggregate=agg[aggregate])
ctx.save_for_backward(output, point_features, center_features, scores,
knn_idx)
ctx.agg = agg[aggregate]
return output
@staticmethod
def backward(ctx, grad_out):
"""
Args:
grad_out (torch.Tensor): (B, out_dim, npoint, K)
Returns:
grad_scores (torch.Tensor): (B, npoint, K, M)
grad_point_features (torch.Tensor): (B, N, M, out_dim)
grad_center_features (torch.Tensor): (B, N, M, out_dim)
"""
_, point_features, center_features, scores, knn_idx = ctx.saved_tensors
agg = ctx.agg
B, N, M, out_dim = point_features.size()
_, npoint, K, _ = scores.size()
grad_point_features = point_features.new_zeros(point_features.shape)
grad_center_features = center_features.new_zeros(center_features.shape)
grad_scores = scores.new_zeros(scores.shape)
ext_module.assign_score_withk_backward(
grad_out.contiguous(),
point_features.contiguous(),
center_features.contiguous(),
scores.contiguous(),
knn_idx.contiguous(),
grad_point_features,
grad_center_features,
grad_scores,
B=B,
N0=N,
N1=npoint,
M=M,
K=K,
O=out_dim,
aggregate=agg)
return grad_scores, grad_point_features, \
grad_center_features, None, None
assign_score_withk = AssignScoreWithK.apply