import torch
from pointops import knn_query, ball_query, grouping


def knn_query_and_group(
    feat,
    xyz,
    offset=None,
    new_xyz=None,
    new_offset=None,
    idx=None,
    nsample=None,
    with_xyz=False,
):
    """Group features around query points using k-nearest-neighbour indices.

    When ``idx`` is not supplied it is computed with ``pointops.knn_query``;
    the features are then gathered with ``pointops.grouping``.

    Args:
        feat: per-point features handed to ``grouping``.
        xyz: support point coordinates.
        offset: batch offsets of ``xyz`` (forwarded to ``knn_query``).
        new_xyz: query coordinates (forwarded as-is, may be ``None``).
        new_offset: batch offsets of ``new_xyz`` (forwarded as-is).
        idx: precomputed neighbour indices; skips the query when given.
        nsample: neighbours per query point; required when ``idx`` is None.
        with_xyz: forwarded to ``grouping``.
            NOTE(review): presumably prepends relative coordinates -
            confirm against the pointops implementation.

    Returns:
        Tuple ``(grouped, idx)`` where ``grouped`` is the result of
        ``grouping`` and ``idx`` the neighbour indices that were used.
    """
    if idx is None:
        assert nsample is not None
        idx, _ = knn_query(nsample, xyz, offset, new_xyz, new_offset)
    return grouping(idx, feat, xyz, new_xyz, with_xyz), idx


def ball_query_and_group(
    feat,
    xyz,
    offset=None,
    new_xyz=None,
    new_offset=None,
    idx=None,
    max_radio=None,
    min_radio=0,
    nsample=None,
    with_xyz=False,
):
    """Group features around query points using ball-query indices.

    When ``idx`` is not supplied it is computed with ``pointops.ball_query``;
    the features are then gathered with ``pointops.grouping``.

    Args:
        feat: per-point features handed to ``grouping``.
        xyz: support point coordinates.
        offset: batch offsets of ``xyz``; required when ``idx`` is None.
        new_xyz: query coordinates (forwarded as-is, may be ``None``).
        new_offset: batch offsets of ``new_xyz`` (forwarded as-is).
        idx: precomputed neighbour indices; skips the query when given.
        max_radio: outer radius passed to ``ball_query``; required when
            ``idx`` is None.
        min_radio: inner radius passed to ``ball_query``.
        nsample: neighbours per query point; required when ``idx`` is None.
        with_xyz: forwarded to ``grouping``.

    Returns:
        Tuple ``(grouped, idx)`` where ``grouped`` is the result of
        ``grouping`` and ``idx`` the neighbour indices that were used.
    """
    if idx is None:
        assert nsample is not None and offset is not None
        assert max_radio is not None and min_radio is not None
        idx, _ = ball_query(
            nsample, max_radio, min_radio, xyz, offset, new_xyz, new_offset
        )
    return grouping(idx, feat, xyz, new_xyz, with_xyz), idx


def query_and_group(
    nsample,
    xyz,
    new_xyz,
    feat,
    idx,
    offset,
    new_offset,
    dilation=0,
    with_feat=True,
    with_xyz=True,
):
    """Gather (optionally dilated) kNN neighbourhoods with plain indexing.

    input: xyz: (n, 3), new_xyz: (m, 3), feat: (n, c), idx: (m, nsample),
           offset: (b), new_offset: (b)
    output: new_feat: (m, nsample, c+3), grouped_idx: (m, nsample)

    Args:
        nsample: number of neighbours kept per query point.
        xyz: support coordinates.
        new_xyz: query coordinates; ``None`` means "query the support
            points themselves".
        feat: support features.
        idx: precomputed neighbour indices, or ``None`` to run the query.
        offset: batch offsets of ``xyz``.
        new_offset: batch offsets of ``new_xyz``; falls back to ``offset``
            when both it and ``new_xyz`` are ``None``.
        dilation: number of neighbours skipped between two kept ones.
        with_feat: when False only ``idx`` is returned.
        with_xyz: when True the relative coordinates are concatenated in
            front of the grouped features.

    Returns:
        ``idx`` alone when ``with_feat`` is False, otherwise the tuple
        ``(grouped, idx)`` with ``grouped`` of shape (m, nsample, 3 + c)
        or (m, nsample, c) depending on ``with_xyz``.
    """
    # Resolve the default BEFORE touching new_xyz: the contiguity check
    # below would otherwise fail with AttributeError on None.
    if new_xyz is None:
        new_xyz = xyz
        if new_offset is None:
            new_offset = offset
    assert xyz.is_contiguous() and new_xyz.is_contiguous() and feat.is_contiguous()

    if idx is None:
        num_samples_total = 1 + (nsample - 1) * (dilation + 1)
        # num points in a batch might < num_samples_total
        # => [n1, n2, ..., nk, ns, ns, ns, ...]
        idx_no_dilation, _ = knn_query(
            num_samples_total, xyz, offset, new_xyz, new_offset
        )  # (m, nsample * (d + 1))

        batch_end = offset.tolist()
        batch_start = [0] + batch_end[:-1]
        new_batch_end = new_offset.tolist()
        new_batch_start = [0] + new_batch_end[:-1]

        per_batch_idx = []
        for b in range(offset.shape[0]):
            num_points = batch_end[b] - batch_start[b]
            # Shrink the stride for batches too small for the requested
            # dilation. With nsample == 1 only column 0 is ever selected,
            # so skip the formula (it would divide by zero).
            if num_points < num_samples_total and nsample > 1:
                soft_dilation = (num_points - 1) / (nsample - 1) - 1
            else:
                soft_dilation = dilation
            columns = [int((soft_dilation + 1) * k) for k in range(nsample)]
            per_batch_idx.append(
                idx_no_dilation[new_batch_start[b] : new_batch_end[b], columns]
            )
        idx = torch.cat(per_batch_idx, dim=0)

    if not with_feat:
        return idx

    m, c = new_xyz.shape[0], feat.shape[1]
    # reshape (rather than view) also accepts a non-contiguous idx.
    flat_idx = idx.reshape(-1).long()
    grouped_xyz = xyz[flat_idx, :].view(m, nsample, 3)  # (m, nsample, 3)
    grouped_xyz -= new_xyz.unsqueeze(1)  # coordinates relative to the query
    grouped_feat = feat[flat_idx, :].view(m, nsample, c)  # (m, nsample, c)

    if with_xyz:
        return torch.cat((grouped_xyz, grouped_feat), -1), idx  # (m, nsample, 3+c)
    return grouped_feat, idx


def offset2batch(offset):
    """Expand cumulative batch offsets into a per-point batch index.

    Example: ``offset = [2, 5]`` -> ``[0, 0, 1, 1, 1]``.

    Args:
        offset: 1-D tensor of cumulative point counts, one entry per batch.

    Returns:
        1-D ``long`` tensor on ``offset.device`` whose length is the last
        offset value (empty when ``offset`` is empty).
    """
    # Per-batch point counts: first entry as-is, then consecutive differences.
    counts = offset.clone()
    counts[1:] = offset[1:] - offset[:-1]
    batch_ids = torch.arange(
        counts.shape[0], device=offset.device, dtype=torch.long
    )
    return batch_ids.repeat_interleave(counts.long())


def batch2offset(batch):
    """Convert a per-point batch index into cumulative batch offsets.

    Args:
        batch: 1-D tensor of non-negative integer batch ids, one per point.

    Returns:
        1-D ``int32`` tensor holding the running total of points per batch.
    """
    return torch.cumsum(batch.bincount(), dim=0).int()