Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
from torch.autograd import Function | |
import pointgroup_ops_cuda | |
class BallQueryBatchP(Function):
    """Autograd wrapper around ``pointgroup_ops_cuda.ballquery_batch_p``.

    The op only produces integer index tensors, so ``backward`` returns no
    gradient for any input.
    """

    @staticmethod
    def forward(ctx, coords, batch_idxs, batch_offsets, radius, meanActive):
        """
        :param ctx: autograd context (unused)
        :param coords: (n, 3) float
        :param batch_idxs: (n) int
        :param batch_offsets: (B+1) int
        :param radius: float
        :param meanActive: int, initial per-point capacity of the ``idx`` buffer
        :return: idx (nActive), int
        :return: start_len (n, 2), int
        """
        n = coords.size(0)

        assert coords.is_contiguous() and coords.is_cuda
        assert batch_idxs.is_contiguous() and batch_idxs.is_cuda
        assert batch_offsets.is_contiguous() and batch_offsets.is_cuda

        while True:
            # Output buffers live on the same device as the input points.
            idx = torch.zeros(
                n * meanActive, dtype=torch.int32, device=coords.device
            )
            start_len = torch.zeros(
                (n, 2), dtype=torch.int32, device=coords.device
            )
            nActive = pointgroup_ops_cuda.ballquery_batch_p(
                coords, batch_idxs, batch_offsets, idx, start_len, n, meanActive, radius
            )
            if nActive <= n * meanActive:
                break
            # The op reported more entries than the buffer can hold: enlarge
            # the per-point capacity and run the query again.
            meanActive = int(nActive // n + 1)

        # Drop the unused tail of the buffer.
        idx = idx[:nActive]
        return idx, start_len

    @staticmethod
    def backward(ctx, a=None, b=None):
        # One entry per forward input (coords, batch_idxs, batch_offsets,
        # radius, meanActive); none of them receives a gradient.
        return None, None, None, None, None


ballquery_batch_p = BallQueryBatchP.apply
class Clustering:
    """Turn per-point class scores into instance proposals.

    Points whose mapped label is not ignored are grouped with
    ``ballquery_batch_p`` followed by ``bfs_cluster``; proposals that are too
    small are discarded.
    """

    def __init__(
        self,
        ignored_labels,
        class_mapping,
        thresh=0.03,
        closed_points=300,
        min_points=50,
        propose_points=100,
        score_func=torch.max,
    ) -> None:
        """
        :param ignored_labels: iterable of mapped label ids to exclude
        :param class_mapping: tensor indexed by predicted label; its values are
            compared against ``ignored_labels`` and reported as ``label_id``
        :param thresh: passed to ``ballquery_batch_p`` as ``radius``
        :param closed_points: passed to ``ballquery_batch_p`` as ``meanActive``
        :param min_points: passed to ``bfs_cluster`` as ``threshold``
        :param propose_points: a proposal is kept only if it has more points
            than this
        :param score_func: reduces the per-point scores of a proposal to its
            confidence
        """
        self.ignored_labels = ignored_labels
        self.thresh = thresh
        self.closed_points = closed_points
        self.min_points = min_points
        self.class_mapping = class_mapping
        self.propose_points = propose_points
        self.score_func = score_func

    def cluster(self, vertices, scores):
        """Return ``(proposals_pred, labels)`` for the kept proposals.

        :param vertices: per-point coordinates; ``vertices.shape[0]`` is N
        :param scores: (N, C) per-point class scores
        :return: proposals_pred (nKept, N) int 0/1 membership mask
        :return: labels (nKept) predicted label of each kept proposal
        """
        labels = torch.max(scores, 1)[1]  # (N) long
        proposals_idx, proposals_offset = self.cluster_(vertices, labels)

        # Scatter the (proposal_id, point_idx) pairs into a dense mask.
        # NOTE: allocated without an explicit device, i.e. on the default one.
        num_proposals = proposals_offset.shape[0] - 1
        proposals_pred = torch.zeros(
            (num_proposals, vertices.shape[0]), dtype=torch.int
        )  # (nProposal, N), int
        proposals_pred[proposals_idx[:, 0].long(), proposals_idx[:, 1].long()] = 1

        # A proposal takes the label of the first point listed for it.
        first_point = proposals_idx[:, 1][proposals_offset[:-1].long()].long()
        labels = labels[first_point]

        npoint_mask = proposals_pred.sum(1) > self.propose_points
        return proposals_pred[npoint_mask], labels[npoint_mask]

    def cluster_(self, vertices, labels):
        """Cluster the non-ignored points.

        :param vertices: per-point coordinates, indexed along dim 0
        :param labels: (N) predicted label per point
        :return: proposals_idx (sumNPoint, 2) int; column 0 is the proposal id,
            column 1 the point index into the original N points
        :return: proposals_offset (nProposal + 1) int
        """
        # Everything is treated as a single batch (batch index 0).
        batch_idxs = torch.zeros_like(labels)

        mask_non_ignored = torch.ones_like(labels).bool()
        for ignored_label in self.ignored_labels:
            mask_non_ignored = mask_non_ignored & (
                self.class_mapping[labels] != ignored_label
            )
        object_idxs = mask_non_ignored.nonzero().view(-1)

        vertices_ = vertices[object_idxs].float()
        labels_ = labels[object_idxs].int()

        if vertices_.numel() == 0:
            # Nothing to cluster: zero proposals, offsets == [0].
            return torch.zeros((0, 2)).int(), torch.zeros(1).int()

        batch_idxs_ = batch_idxs[object_idxs].int()
        # Build the offsets as integers directly; a float32 round-trip is not
        # exact for counts above 2**24.
        batch_offsets_ = torch.tensor(
            [0, object_idxs.shape[0]], dtype=torch.int32
        ).cuda()

        idx, start_len = ballquery_batch_p(
            vertices_, batch_idxs_, batch_offsets_, self.thresh, self.closed_points
        )
        proposals_idx, proposals_offset = bfs_cluster(
            labels_.cpu(), idx.cpu(), start_len.cpu(), self.min_points
        )
        # Map indices into the filtered subset back to indices into all points.
        proposals_idx[:, 1] = object_idxs[proposals_idx[:, 1].long()].int()
        return proposals_idx, proposals_offset

    def get_instances(self, vertices, scores):
        """Return a dict ``proposal_id -> {"conf", "label_id", "pred_mask"}``.

        :param vertices: per-point coordinates, forwarded to :meth:`cluster`
        :param scores: (N, C) per-point class scores
        """
        proposals_pred, labels = self.cluster(vertices, scores)
        # Loop-invariant: copy the mapping to the CPU once.
        class_mapping_cpu = self.class_mapping.cpu()

        instances = {}
        for proposal_id, (clusters_i, label) in enumerate(
            zip(proposals_pred, labels)
        ):
            # Scores of the proposal's points for the proposal's own class.
            score = self.score_func(scores[clusters_i.bool(), label])
            instances[proposal_id] = {
                "conf": score.cpu().numpy(),
                "label_id": class_mapping_cpu[label],
                "pred_mask": clusters_i.cpu().numpy(),
            }
        return instances
class BFSCluster(Function):
    """Autograd wrapper around ``pointgroup_ops_cuda.bfs_cluster``.

    ``backward`` returns no gradient for any input.
    """

    @staticmethod
    def forward(ctx, semantic_label, ball_query_idxs, start_len, threshold):
        """
        :param ctx: autograd context (unused)
        :param semantic_label: (N), int
        :param ball_query_idxs: (nActive), int
        :param start_len: (N, 2), int
        :param threshold: forwarded unchanged to the op
        :return: cluster_idxs: int (sumNPoint, 2), dim 0 for cluster_id, dim 1 for corresponding point idxs in N
        :return: cluster_offsets: int (nCluster + 1)
        """
        N = start_len.size(0)

        assert semantic_label.is_contiguous()
        assert ball_query_idxs.is_contiguous()
        assert start_len.is_contiguous()

        # Empty tensors sharing semantic_label's dtype and device, handed to
        # the op as output arguments (presumably resized and filled in place).
        cluster_idxs = semantic_label.new()
        cluster_offsets = semantic_label.new()

        pointgroup_ops_cuda.bfs_cluster(
            semantic_label,
            ball_query_idxs,
            start_len,
            cluster_idxs,
            cluster_offsets,
            N,
            threshold,
        )
        return cluster_idxs, cluster_offsets

    @staticmethod
    def backward(ctx, a=None, b=None):
        # One entry per forward input (semantic_label, ball_query_idxs,
        # start_len, threshold); none of them receives a gradient.
        return None, None, None, None


bfs_cluster = BFSCluster.apply