|
import torch |
|
from einops import repeat |
|
|
|
|
|
def sample_farthest_points(pts, k, return_index=False):
    """Select ``k`` points per batch with greedy farthest-point sampling.

    Args:
        pts: tensor of shape (b, c, n): b batches of n points, each with c
            coordinates stored along dim 1.
        k: number of points to select per batch. ``k == 0`` yields empty
            results.
        return_index: if True, also return the indexes of the selected points.

    Returns:
        ``farthest_pts`` of shape (b, c, k), with the same dtype/device as
        ``pts``, where ``farthest_pts[j, :, i] == pts[j, :, indexes[j, i]]``.
        When ``return_index`` is True, also ``indexes`` (shape (b, k), int64).

    The first point of each batch is drawn with ``torch.randint``, so the
    result depends on the global torch RNG state. Each following point is the
    one whose Euclidean distance to the already-selected set is largest. When
    ``k > n``, indexes necessarily repeat.
    """
    b, c, n = pts.shape
    indexes = torch.zeros((b, k), device=pts.device, dtype=torch.int64)

    def gather_one(index):
        # pts[j, :, index[j]] for every batch j -> shape (b, c)
        return torch.gather(pts, 2, index.reshape(b, 1, 1).expand(b, c, 1))[:, :, 0]

    if k > 0:
        index = torch.randint(n, (b,), device=pts.device)
        indexes[:, 0] = index
        # distances[j, m]: distance from point m to its nearest selected point
        distances = torch.norm(gather_one(index)[:, :, None] - pts, dim=1)
        for i in range(1, k):
            _, index = torch.max(distances, dim=1)
            indexes[:, i] = index
            new_distances = torch.norm(gather_one(index)[:, :, None] - pts, dim=1)
            distances = torch.min(distances, new_distances)

    # Gather every selected point in one go. The channel count follows the
    # input (not a fixed 3), and k == 0 naturally gives a (b, c, 0) tensor.
    farthest_pts = torch.gather(pts, 2, indexes[:, None, :].expand(b, c, k))

    if return_index:
        return farthest_pts, indexes
    return farthest_pts
|
|
|
|
|
def line_segment_distance(a, b, points, sqrt=True, *, eps=1e-6):
    """Compute the distance from a set of points to a line segment [a, b].

    An axis is inserted into ``a`` and ``b`` just before their last dimension,
    so each segment is measured against a set of points:

        a, b:    ... x D        segment endpoints
        points:  ... x P x D    points to measure; must broadcast against
                                ... x 1 x D

    Returns:
        Tensor with the broadcast shape of (``... x 1 x D``, ``points``)
        minus the last axis, i.e. ``... x P`` in the usual case. It holds the
        Euclidean distance from each point to the closest point of the segment,
        or the squared distance if ``sqrt`` is False.

    Args:
        sqrt: return the distance (True) or the squared distance (False).
        eps: added to the squared distance before the square root (only used
            when ``sqrt`` is True). It keeps the gradient of sqrt finite at
            zero distance, at the cost of biasing small distances up
            (a zero distance becomes sqrt(eps)). Defaults to 1e-6.

    A degenerate segment (a == b) is supported. The squared segment length is
    floored at 1e-6, so the projection parameter becomes 0 and the distance to
    ``a`` is returned.
    """

    def sumprod(x, y, keepdim=True):
        # dot product over the last (coordinate) axis
        return torch.sum(x * y, dim=-1, keepdim=keepdim)

    a, b = a[..., None, :], b[..., None, :]
    ab = b - a

    # Projection parameter of each point onto the infinite line through a and b
    # (t = 0 at a, t = 1 at b). The floor on the denominator avoids 0 / 0 when a == b.
    t = sumprod(points - a, ab) / sumprod(ab, ab).clamp_min(1e-6)

    # Clamp to [0, 1] so the closest point stays on the segment itself.
    closest = a + torch.clamp(t, 0.0, 1.0) * ab

    diff = closest - points
    distance = sumprod(diff, diff, keepdim=False)

    if sqrt:
        distance = torch.sqrt(distance + eps)

    return distance
|
|