# Page-header residue captured together with this file (not part of the module):
# "ziqima's picture" / "initial commit" / 4893ce0
import torch
from torch.autograd import Function
import pointgroup_ops_cuda
class BallQueryBatchP(Function):
    """Autograd wrapper around ``pointgroup_ops_cuda.ballquery_batch_p``.

    Both outputs are integer index tensors, so ``backward`` only reports
    "no gradient" for every input.
    """

    @staticmethod
    def forward(ctx, coords, batch_idxs, batch_offsets, radius, meanActive):
        """
        :param ctx: autograd context (unused)
        :param coords: (n, 3) float
        :param batch_idxs: (n) int
        :param batch_offsets: (B+1) int
        :param radius: float
        :param meanActive: int, initial per-point capacity of the ``idx`` buffer
        :return: idx (nActive), int
        :return: start_len (n, 2), int
        """
        n = coords.size(0)

        assert coords.is_contiguous() and coords.is_cuda
        assert batch_idxs.is_contiguous() and batch_idxs.is_cuda
        assert batch_offsets.is_contiguous() and batch_offsets.is_cuda

        device = coords.device
        while True:
            # Allocate the output buffers next to the inputs instead of on
            # whatever CUDA device happens to be current.
            idx = torch.zeros(n * meanActive, dtype=torch.int32, device=device)
            start_len = torch.zeros((n, 2), dtype=torch.int32, device=device)
            nActive = pointgroup_ops_cuda.ballquery_batch_p(
                coords, batch_idxs, batch_offsets, idx, start_len, n, meanActive, radius
            )
            if nActive <= n * meanActive:
                break
            # The reported count exceeds the buffer capacity: retry with a
            # per-point capacity large enough to hold nActive entries.
            meanActive = int(nActive // n + 1)

        idx = idx[:nActive]
        return idx, start_len

    @staticmethod
    def backward(ctx, a=None, b=None):
        """Return one ``None`` gradient per ``forward`` input (five in total)."""
        return None, None, None, None, None


ballquery_batch_p = BallQueryBatchP.apply
class Clustering:
    """Turn per-point class scores into instance proposals.

    Every point is labelled with its arg-max class, points whose mapped class
    is listed in ``ignored_labels`` are dropped, and the remaining points are
    grouped with the ``ballquery_batch_p`` and ``bfs_cluster`` ops.
    """

    def __init__(
        self,
        ignored_labels,
        class_mapping,
        thresh=0.03,
        closed_points=300,
        min_points=50,
        propose_points=100,
        score_func=torch.max,
    ) -> None:
        """
        :param ignored_labels: iterable of mapped class ids to exclude
        :param class_mapping: tensor indexed by predicted label; its values are
            compared against ``ignored_labels`` and reported as ``label_id``
        :param thresh: passed to ``ballquery_batch_p`` as ``radius``
        :param closed_points: passed to ``ballquery_batch_p`` as ``meanActive``
        :param min_points: passed to ``bfs_cluster`` as ``threshold``
        :param propose_points: a proposal is kept only if it has strictly more
            points than this
        :param score_func: reduces the per-point scores of one proposal to its
            confidence
        """
        self.ignored_labels = ignored_labels
        self.thresh = thresh
        self.closed_points = closed_points
        self.min_points = min_points
        self.class_mapping = class_mapping
        self.propose_points = propose_points
        self.score_func = score_func

    def cluster(self, vertices, scores):
        """Cluster the points and return per-proposal masks and labels.

        :param vertices: (N, 3) point coordinates
        :param scores: (N, C) per-point class scores
        :return: proposals_pred, (nKept, N) int mask with 1 where a point
            belongs to the proposal
        :return: labels, (nKept) predicted label of each kept proposal
        """
        labels = torch.max(scores, 1)[1]  # (N) long
        proposals_idx, proposals_offset = self.cluster_(vertices, labels)

        # Scatter the (cluster_id, point_idx) pairs into a dense mask.
        # No device is given, so the mask lives on the default device.
        num_proposals = proposals_offset.shape[0] - 1
        proposals_pred = torch.zeros(
            (num_proposals, vertices.shape[0]), dtype=torch.int
        )  # (nProposal, N), int
        cluster_ids = proposals_idx[:, 0].long()
        point_ids = proposals_idx[:, 1].long()
        proposals_pred[cluster_ids, point_ids] = 1

        # Each proposal takes the label of the first point listed for it.
        first_points = proposals_idx[:, 1][proposals_offset[:-1].long()].long()
        labels = labels[first_points]

        # Drop proposals that do not exceed the point-count requirement.
        npoint_mask = proposals_pred.sum(1) > self.propose_points
        return proposals_pred[npoint_mask], labels[npoint_mask]

    def cluster_(self, vertices, labels):
        """Run ball query + BFS clustering on the non-ignored points.

        :param vertices: (N, 3) point coordinates
        :param labels: (N) predicted label per point, used to index
            ``class_mapping``
        :return: proposals_idx, (sumNPoint, 2) int; column 0 is the cluster id,
            column 1 the point index into the original N points
        :return: proposals_offset, (nCluster + 1) int
        """
        # The whole input is treated as a single batch (batch index 0).
        batch_idxs = torch.zeros_like(labels)

        keep = torch.ones_like(labels, dtype=torch.bool)
        mapped_labels = None
        for ignored_label in self.ignored_labels:
            if mapped_labels is None:
                # Loop-invariant lookup, evaluated once and only when needed.
                mapped_labels = self.class_mapping[labels]
            keep = keep & (mapped_labels != ignored_label)

        object_idxs = keep.nonzero().view(-1)
        vertices_ = vertices[object_idxs].float()
        labels_ = labels[object_idxs].int()

        if vertices_.numel() == 0:
            # Nothing to cluster: zero pairs and a single offset of 0.
            return (
                torch.zeros((0, 2), dtype=torch.int32),
                torch.zeros(1, dtype=torch.int32),
            )

        batch_idxs_ = batch_idxs[object_idxs].int()
        # Build the offsets directly as int32: a float32 round-trip cannot
        # represent every count above 2**24 exactly.
        batch_offsets_ = torch.tensor(
            [0, object_idxs.shape[0]], dtype=torch.int32, device=vertices_.device
        )

        idx, start_len = ballquery_batch_p(
            vertices_, batch_idxs_, batch_offsets_, self.thresh, self.closed_points
        )
        proposals_idx, proposals_offset = bfs_cluster(
            labels_.cpu(), idx.cpu(), start_len.cpu(), self.min_points
        )
        # Translate indices into the filtered subset back to indices into the
        # original N points.
        proposals_idx[:, 1] = object_idxs[proposals_idx[:, 1].long()].int()
        return proposals_idx, proposals_offset

    def get_instances(self, vertices, scores):
        """Cluster the points and describe every kept proposal.

        :param vertices: (N, 3) point coordinates
        :param scores: (N, C) per-point class scores
        :return: dict mapping proposal index to a dict with keys ``conf``
            (``score_func`` of the member points' scores for the proposal's
            label), ``label_id`` (``class_mapping`` entry of that label) and
            ``pred_mask`` (the (N) membership mask as a NumPy array)
        """
        proposals_pred, labels = self.cluster(vertices, scores)

        instances = {}
        if len(proposals_pred) == 0:
            return instances

        class_mapping_cpu = self.class_mapping.cpu()
        for proposal_id, clusters_i in enumerate(proposals_pred):
            label = labels[proposal_id]
            score = self.score_func(scores[clusters_i.bool(), label])
            instances[proposal_id] = {
                "conf": score.cpu().numpy(),
                "label_id": class_mapping_cpu[label],
                "pred_mask": clusters_i.cpu().numpy(),
            }
        return instances
class BFSCluster(Function):
    """Autograd wrapper around ``pointgroup_ops_cuda.bfs_cluster``.

    Both outputs are integer index tensors, so ``backward`` only reports
    "no gradient" for every input.
    """

    @staticmethod
    def forward(ctx, semantic_label, ball_query_idxs, start_len, threshold):
        """
        :param ctx: autograd context (unused)
        :param semantic_label: (N), int
        :param ball_query_idxs: (nActive), int
        :param start_len: (N, 2), int
        :param threshold: forwarded unchanged to the extension op
        :return: cluster_idxs: int (sumNPoint, 2), dim 0 for cluster_id, dim 1 for corresponding point idxs in N
        :return: cluster_offsets: int (nCluster + 1)
        """
        N = start_len.size(0)

        assert semantic_label.is_contiguous()
        assert ball_query_idxs.is_contiguous()
        assert start_len.is_contiguous()

        # Empty tensors with the dtype/device of semantic_label; they are
        # handed to the op as output arguments (presumably resized and filled
        # there -- confirm against the extension source).
        cluster_idxs = semantic_label.new()
        cluster_offsets = semantic_label.new()

        pointgroup_ops_cuda.bfs_cluster(
            semantic_label,
            ball_query_idxs,
            start_len,
            cluster_idxs,
            cluster_offsets,
            N,
            threshold,
        )

        return cluster_idxs, cluster_offsets

    @staticmethod
    def backward(ctx, a=None, b=None):
        """Return one ``None`` gradient per ``forward`` input (four in total).

        ``a`` and ``b`` are the upstream gradients of the two outputs.
        """
        return None, None, None, None


bfs_cluster = BFSCluster.apply