import torch
import torch.nn.functional as F


def estimate_surface_normal(
    points: torch.Tensor, d: int = 2, mode: str = "closest"
) -> torch.Tensor:
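    """Estimate per-pixel surface normals from an organized point cloud.

    Re-implemented in PyTorch from the following code:
    https://github.com/wkentaro/morefusion/blob/master/morefusion/geometry/estimate_pointcloud_normals.py
    https://github.com/jmccormac/pySceneNetRGBD/blob/master/calculate_surface_normals.py

    Args:
        points: (B,3,H,W) xyz coordinates stored on an image grid.
        d: pixel distance between a point and its neighbors.
        mode: "closest" takes the cross product of the neighbor pair with the
            smallest summed distance to the point; "mean" averages the cross
            products of all 8 neighbor pairs.

    Returns:
        (B,3,H,W) unit normals (all zeros where the resulting vector vanishes).
    """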

    assert points.dim() == 4, f"expected (B,3,H,W), but got {points.shape}"
    B, C, H, W = points.shape
    assert C == 3, f"expected C==3, but got {C}"
    device = points.device

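    # pad the borders so every pixel has 8 neighbors at distance d:
    # rows (H) replicate the top/bottom edges, while columns (W) wrap around
    # circularly, i.e. the left and right image edges are treated as adjacent
    # points = F.pad(points, (0, 0, d, d), mode="constant", value=float("inf"))
    points = F.pad(points, (0, 0, d, d), mode="replicate")
    points = F.pad(points, (d, d, 0, 0), mode="circular")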
    points = points.permute(0, 2, 3, 1)  # (B,H,W,3)

    # 8 adjacent offsets (dh,dw), numbered clockwise from the top
    #  -----------
    # | 7 | 0 | 1 |
    #  -----------
    # | 6 |   | 2 |
    #  -----------
    # | 5 | 4 | 3 |
    #  -----------
    offsets = torch.tensor(
        [
            # (dh,dw)
            (-d, 0),  # 0
            (-d, d),  # 1
            (0, d),  # 2
            (d, d),  # 3
            (d, 0),  # 4
            (d, -d),  # 5
            (0, -d),  # 6
            (-d, -d),  # 7
        ],
        device=device,
    )

    # (B,H,W) indices
    b = torch.arange(B, device=device)[:, None, None]
    h = torch.arange(H, device=device)[None, :, None]
    w = torch.arange(W, device=device)[None, None, :]
    k = torch.arange(8, device=device)

    # anchor points (indices shifted by d to skip the padding)
    b1 = b[:, None]  # (B,1,1,1)
    h1 = h[:, None] + d  # (1,1,H,1)
    w1 = w[:, None] + d  # (1,1,1,W)
    anchors = points[b1, h1, w1]  # (B,H,W,3) -> (B,1,H,W,3)

    # neighbor points
    offset = offsets[k]  # (8,2)
    b2 = b1
    h2 = h1 + offset[None, :, 0, None, None]  # (1,8,H,1)
    w2 = w1 + offset[None, :, 1, None, None]  # (1,8,1,W)
    points1 = points[b2, h2, w2]  # (B,8,H,W,3)

    # another neighbor for each pair: offset index k+2, i.e. rotated by 90 degrees
    offset = offsets[(k + 2) % 8]
    b3 = b1
    h3 = h1 + offset[None, :, 0, None, None]
    w3 = w1 + offset[None, :, 1, None, None]
    points2 = points[b3, h3, w3]  # (B,8,H,W,3)

    if mode == "closest":
        # pick, per pixel, the neighbor pair whose summed distance
        # to the anchor is smallest
        diff = torch.norm(points1 - anchors, dim=4)
        diff = diff + torch.norm(points2 - anchors, dim=4)
        i = torch.argmin(diff, dim=1)  # (B,H,W)
        # get normals by cross product
        anchors = anchors[b, 0, h, w]  # (B,H,W,3)
        points1 = points1[b, i, h, w]  # (B,H,W,3)
        points2 = points2[b, i, h, w]  # (B,H,W,3)
        vector1 = points1 - anchors
        vector2 = points2 - anchors
        normals = torch.cross(vector1, vector2, dim=-1)  # (B,H,W,3)
    elif mode == "mean":
        # get normals by cross product
        vector1 = points1 - anchors
        vector2 = points2 - anchors
        normals = torch.cross(vector1, vector2, dim=-1)  # (B,8,H,W,3)
        normals = normals.mean(dim=1)  # (B,H,W,3)
    else:
        raise NotImplementedError(mode)

    normals = normals / (torch.norm(normals, dim=3, keepdim=True) + 1e-8)
    normals = normals.permute(0, 3, 1, 2)  # (B,3,H,W)

    return normals
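

# Sanity-check sketch (hypothetical helper, not part of the original code).
# On a noise-free tilted plane z = a*x + b*y, every estimated normal should be
# parallel to the analytic plane normal (-a, -b, 1), up to sign. The helper
# returns the worst-case 1 - |cos| between the two over the interior pixels.
# Pixels closer than d to a border are skipped because the padding changes
# their neighborhoods. Replicated rows can duplicate the anchor itself, so
# "closest" mode may produce a zero normal on the first/last row. The circular
# column padding also joins the left and right edges, which are not adjacent
# on a plane.
#
# Usage with the estimator defined above:
#     err = plane_normal_error(estimate_surface_normal, mode="closest")
#     # err is expected to be close to 0 in both "closest" and "mean" modes
def plane_normal_error(
    estimator,
    a: float = 0.3,
    b: float = -0.2,
    B: int = 2,
    H: int = 16,
    W: int = 32,
    d: int = 2,
    mode: str = "closest",
) -> float:
    assert H > 2 * d and W > 2 * d, "grid too small for the interior check"

    # organized grid: x varies along W, y varies along H
    ys, xs = torch.meshgrid(
        torch.linspace(-1.0, 1.0, H), torch.linspace(-1.0, 1.0, W), indexing="ij"
    )
    zs = a * xs + b * ys
    points = torch.stack([xs, ys, zs], dim=0)  # (3,H,W)
    points = points.unsqueeze(0).expand(B, -1, -1, -1).contiguous()  # (B,3,H,W)

    normals = estimator(points, d=d, mode=mode)  # (B,3,H,W)

    # analytic unit normal of the plane z = a*x + b*y
    true_normal = torch.tensor([-a, -b, 1.0])
    true_normal = true_normal / true_normal.norm()

    # cosine between estimated and true normals, interior pixels only
    inner = normals[:, :, d : H - d, d : W - d]
    cos = (inner * true_normal.view(1, 3, 1, 1)).sum(dim=1)  # (B,H-2d,W-2d)
    return float((1.0 - cos.abs()).max())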