# r2dm / rendering.py
# Hosting-page residue (not part of the module): history blame · No virus · 3.25 kB
import torch
import torch.nn.functional as F
def estimate_surface_normal(
    points: torch.Tensor, d: int = 2, mode: str = "closest"
) -> torch.Tensor:
    """Estimate per-pixel surface normals from an image-organized point cloud.

    Re-implemented in PyTorch from:
    https://github.com/wkentaro/morefusion/blob/master/morefusion/geometry/estimate_pointcloud_normals.py
    https://github.com/jmccormac/pySceneNetRGBD/blob/master/calculate_surface_normals.py

    For every pixel, 8 neighbours at pixel distance ``d`` are sampled in a
    ring. Each neighbour k is paired with neighbour (k + 2) % 8, i.e. the one
    90 degrees further round the ring. The normal is the cross product
    (neighbour_k - anchor) x (neighbour_{k+2} - anchor).

    Args:
        points: (B,3,H,W) tensor of 3D coordinates laid out on a pixel grid.
        d: pixel distance between the anchor and its sampled neighbours.
        mode: ``"closest"`` keeps, per pixel, only the pair whose summed
            distance to the anchor is smallest. ``"mean"`` averages the
            cross products of all 8 pairs.

    Returns:
        (B,3,H,W) tensor of unit-length normals. Pixels whose cross product
        vanishes (e.g. coincident neighbours) get a zero vector, because of
        the ``+ 1e-8`` in the normalisation.

    Raises:
        AssertionError: if ``points`` is not shaped (B,3,H,W).
        NotImplementedError: if ``mode`` is neither "closest" nor "mean".
    """
    assert points.dim() == 4, f"expected (B,3,H,W), but got {points.shape}"
    B, C, H, W = points.shape
    assert C == 3, f"expected C==3, but got {C}"
    # Reject unknown modes up front instead of after all the gathering work.
    if mode not in ("closest", "mean"):
        raise NotImplementedError(mode)
    device = points.device

    # Pad by d so every neighbour index stays in range. H is edge-replicated,
    # while W wraps around. Presumably W spans a full 360-degree scan — TODO
    # confirm.
    points = F.pad(points, (0, 0, d, d), mode="replicate")
    points = F.pad(points, (d, d, 0, 0), mode="circular")
    points = points.permute(0, 2, 3, 1)  # (B,H+2d,W+2d,3)

    # Neighbour offsets as (dh, dw). Column 0 is added to the row index and
    # column 1 to the column index. The pattern actually sampled (up = -h) is:
    #  -----------
    # | 7 | 0 | 1 |
    #  -----------
    # | 6 |   | 2 |
    #  -----------
    # | 5 | 4 | 3 |
    #  -----------
    # NOTE(review): the upstream sketch drew this transposed (0 on the left).
    # Transposing mirrors the ring direction and therefore the normal sign.
    # Behaviour is kept as-is; confirm the intended orientation before
    # swapping the columns.
    offsets = torch.tensor(
        [
            (-d, 0),  # 0
            (-d, d),  # 1
            (0, d),  # 2
            (d, d),  # 3
            (d, 0),  # 4
            (d, -d),  # 5
            (0, -d),  # 6
            (-d, -d),  # 7
        ],
        device=device,
    )  # (8,2) int64

    b = torch.arange(B, device=device)[:, None, None]  # (B,1,1)
    h = torch.arange(H, device=device)[None, :, None]  # (1,H,1)
    w = torch.arange(W, device=device)[None, None, :]  # (1,1,W)

    # Anchor (centre) pixels; "+ d" skips the padded border.
    b1 = b[:, None]  # (B,1,1,1)
    h1 = h[:, None] + d  # (1,1,H,1)
    w1 = w[:, None] + d  # (1,1,1,W)
    anchors = points[b1, h1, w1]  # (B,1,H,W,3)

    def _neighbors(table: torch.Tensor) -> torch.Tensor:
        # Gather the coordinates at anchor + table[k] for all 8 k -> (B,8,H,W,3).
        dh = table[:, 0].view(1, 8, 1, 1)
        dw = table[:, 1].view(1, 8, 1, 1)
        return points[b1, h1 + dh, w1 + dw]

    points1 = _neighbors(offsets)
    # roll(-2) gives row k == offsets[(k + 2) % 8]: the partner 90 degrees on.
    points2 = _neighbors(torch.roll(offsets, shifts=-2, dims=0))

    if mode == "closest":
        # Pick, per pixel, the pair nearest to the anchor (ties -> first k).
        dist = torch.norm(points1 - anchors, dim=4)
        dist = dist + torch.norm(points2 - anchors, dim=4)  # (B,8,H,W)
        i = torch.argmin(dist, dim=1)  # (B,H,W)
        anchor = anchors[b, 0, h, w]  # (B,H,W,3)
        nearest1 = points1[b, i, h, w]  # (B,H,W,3)
        nearest2 = points2[b, i, h, w]  # (B,H,W,3)
        normals = torch.cross(nearest1 - anchor, nearest2 - anchor, dim=-1)
    else:  # mode == "mean" (validated above)
        normals = torch.cross(points1 - anchors, points2 - anchors, dim=-1)
        normals = normals.mean(dim=1)  # (B,8,H,W,3) -> (B,H,W,3)

    normals = normals / (torch.norm(normals, dim=3, keepdim=True) + 1e-8)
    return normals.permute(0, 3, 1, 2)  # (B,3,H,W)