Spaces:
Sleeping
Sleeping
import torch | |
from einops import repeat | |
def sample_farthest_points(pts, k, return_index=False):
    """
    iterative farthest point sampling on a batch of point sets

    the first point of every batch element is drawn at random (torch.randint);
    each following point is the one whose distance to the already selected
    points is largest

    pts: b x c x n
    k: number of points to select per batch element; if k exceeds n some
       points are necessarily selected more than once
    return_index: if True, also return the positions of the selected points

    returns: farthest_pts, b x c x k, same dtype and device as pts
             indexes (only if return_index), b x k, int64 positions along the
             last axis of pts
    """
    b, c, n = pts.shape
    # allocate with the actual channel count c (not a fixed 3) so that point
    # sets of any dimensionality can be sampled
    farthest_pts = torch.zeros((b, c, k), device=pts.device, dtype=pts.dtype)
    indexes = torch.zeros((b, k), device=pts.device, dtype=torch.int64)

    # distance from every point to its nearest already selected point
    distances = None
    # with k == 0 the loop body never runs and the empty buffers are returned
    for i in range(k):
        if i == 0:
            # random seed point for each batch element
            index = torch.randint(n, [b], device=pts.device)
        else:
            # the point that is currently farthest from the selected set
            index = distances.max(dim=1).indices

        # broadcast the per-batch index over the channel axis: b -> b x c x 1
        gather_index = index[:, None, None].expand(-1, c, -1)
        selected = torch.gather(pts, 2, gather_index)  # b x c x 1
        farthest_pts[:, :, i] = selected[:, :, 0]
        indexes[:, i] = index

        # distance of every point to the newly selected one, b x n
        new_distances = torch.norm(selected - pts, dim=1)
        if distances is None:
            distances = new_distances
        else:
            distances = torch.min(distances, new_distances)

    if return_index:
        return farthest_pts, indexes
    return farthest_pts
def line_segment_distance(a, b, points, sqrt=True):
    """
    distance from every point to the line segment with end points a and b

    a and b get a singleton axis inserted in front of their last axis, so
    points has to broadcast against ... x 1 x D (typically ... x N x D)

    a, b: ... x D
    points: ... x N x D
    sqrt: if True return sqrt(squared distance + 1e-6), otherwise return the
          squared distance itself
    returns: ... x N
    """
    start = a[..., None, :]
    end = b[..., None, :]
    direction = end - start
    offset = points - start

    # squared segment length, floored at 1e-6 so that a degenerate segment
    # (a == b) does not cause a division by zero
    floor = torch.tensor(1e-6, device=start.device)
    length_sq = torch.max((direction * direction).sum(dim=-1, keepdim=True), floor)

    # position of the projection along the segment, limited to [0, 1] so the
    # closest point never leaves the segment
    t = (offset * direction).sum(dim=-1, keepdim=True) / length_sq
    t = t.clamp(0.0, 1.0)

    # closest point on the segment to every query point
    closest = start + t * direction
    residual = closest - points
    squared = (residual * residual).sum(dim=-1)

    if not sqrt:
        return squared
    return torch.sqrt(squared + 1e-6)