import torch
import torch.nn.functional as F


def estimate_surface_normal(
    points: torch.Tensor, d: int = 2, mode: str = "closest"
) -> torch.Tensor:
    """Estimate per-pixel surface normals from an organized (image-shaped) point map.

    PyTorch re-implementation of:
    - https://github.com/wkentaro/morefusion/blob/master/morefusion/geometry/estimate_pointcloud_normals.py
    - https://github.com/jmccormac/pySceneNetRGBD/blob/master/calculate_surface_normals.py

    For every pixel (the "anchor"), the 8 neighbors at a pixel distance of ``d``
    are visited in a fixed clockwise order. Neighbor ``k`` and neighbor
    ``(k + 2) % 8`` are 90 degrees apart; the cross product of the two vectors
    from the anchor to them gives a candidate normal.

    Before the lookup, rows (H) are padded by replicating the border row, and
    columns (W) are padded circularly, i.e. the left and right image edges are
    treated as adjacent.

    Args:
        points: Tensor of shape (B, 3, H, W) holding one 3D point per pixel.
        d: Neighbor distance in pixels. Presumably meant to be >= 1: with
            ``d == 0`` every neighbor is the anchor itself and all normals are 0.
        mode: ``"closest"`` keeps, per pixel, only the neighbor pair whose
            summed distance to the anchor is smallest; ``"mean"`` averages the
            cross products of all 8 pairs.

    Returns:
        Tensor of shape (B, 3, H, W): normals divided by (their norm + 1e-8), so
        degenerate pixels come out as (near-)zero vectors.

    Raises:
        AssertionError: if ``points`` is not 4-D or its channel count is not 3.
        NotImplementedError: if ``mode`` is not ``"closest"`` or ``"mean"``.
    """
    assert points.dim() == 4, f"expected (B,3,H,W), but got {points.shape}"
    B, C, H, W = points.shape
    assert C == 3, f"expected C==3, but got {C}"
    if mode not in ("closest", "mean"):
        # Fail fast, before the (B, 8, H, W, 3) neighbor gathers are allocated.
        raise NotImplementedError(mode)
    device = points.device

    # Pad so that every neighbor lookup stays in bounds: replicate along H,
    # wrap around along W.
    # NOTE: with replicate padding, a pixel in the first/last row has a
    # neighbor identical to itself. In "closest" mode that zero-length pair can
    # win the argmin and then yields a zero normal for that pixel.
    points = F.pad(points, (0, 0, d, d), mode="replicate")
    points = F.pad(points, (d, d, 0, 0), mode="circular")
    points = points.permute(0, 2, 3, 1)  # (B, H+2d, W+2d, 3)

    # (dh, dw) offsets of the 8 neighbors around the anchor A, walked clockwise
    # in image space (h grows downward, w grows rightward):
    #  -----------
    # | 7 | 0 | 1 |
    #  -----------
    # | 6 | A | 2 |
    #  -----------
    # | 5 | 4 | 3 |
    #  -----------
    offsets = torch.tensor(
        [
            (-d, 0),  # 0: up
            (-d, d),  # 1: up-right
            (0, d),  # 2: right
            (d, d),  # 3: down-right
            (d, 0),  # 4: down
            (d, -d),  # 5: down-left
            (0, -d),  # 6: left
            (-d, -d),  # 7: up-left
        ],
        device=device,
    )

    # Broadcastable index grids laid out as (B, 8, H, W); h/w are shifted by d
    # into the padded coordinate frame.
    bi = torch.arange(B, device=device)[:, None, None, None]  # (B,1,1,1)
    hi = torch.arange(H, device=device)[None, None, :, None] + d  # (1,1,H,1)
    wi = torch.arange(W, device=device)[None, None, None, :] + d  # (1,1,1,W)

    def _gather(offs: torch.Tensor) -> torch.Tensor:
        # Look up the 8 neighbors given by `offs` (8, 2) -> (B, 8, H, W, 3).
        dh = offs[:, 0].reshape(1, 8, 1, 1)
        dw = offs[:, 1].reshape(1, 8, 1, 1)
        return points[bi, hi + dh, wi + dw]

    k = torch.arange(8, device=device)
    anchors = points[bi, hi, wi]  # (B, 1, H, W, 3)
    points1 = _gather(offsets)  # neighbor k
    points2 = _gather(offsets[(k + 2) % 8])  # neighbor (k + 2) % 8

    if mode == "closest":
        # Per pixel, choose the pair whose two neighbors lie nearest the anchor.
        dist = torch.norm(points1 - anchors, dim=4)
        dist = dist + torch.norm(points2 - anchors, dim=4)  # (B, 8, H, W)
        idx = torch.argmin(dist, dim=1, keepdim=True)  # (B, 1, H, W)
        sel = idx[..., None].expand(-1, -1, -1, -1, 3)  # (B, 1, H, W, 3)
        points1 = torch.gather(points1, 1, sel)  # (B, 1, H, W, 3)
        points2 = torch.gather(points2, 1, sel)
        normals = torch.cross(points1 - anchors, points2 - anchors, dim=-1)
        normals = normals[:, 0]  # (B, H, W, 3)
    else:  # mode == "mean" (validated above)
        normals = torch.cross(points1 - anchors, points2 - anchors, dim=-1)
        normals = normals.mean(dim=1)  # (B, H, W, 3)

    # The epsilon keeps degenerate (zero) normals finite instead of NaN.
    normals = normals / (torch.norm(normals, dim=3, keepdim=True) + 1e-8)
    return normals.permute(0, 3, 1, 2)  # (B, 3, H, W)