File size: 3,655 Bytes
854f0d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import numpy as np
import torch


def rigid_transform(xyz, transform):
    """Transform a point cloud with a homogeneous matrix.

    Each point is extended with a trailing 1, the matrix is applied on the
    left, and the homogeneous component is dropped again.

    :param xyz: (N, 3) tensor of points.
    :param transform: matrix applied to the homogeneous points; presumably a
        4x4 camera-to-world matrix -- confirm against callers.
    :return: (N, 3) tensor of transformed points.
    """
    # The column of ones is created with torch's default dtype and then moved
    # to the device the points live on.
    ones_column = torch.ones((len(xyz), 1)).to(xyz.device)
    points_h = torch.cat((xyz, ones_column), dim=1)  # (N, 4)

    # Points are stored row-wise, so transpose around the matrix product.
    transformed_h = torch.matmul(transform, points_h.T).T  # (N, 4)
    return transformed_h[:, :3]


def get_view_frustum(min_depth, max_depth, size, cam_intr, c2w):
    """Return the eight corners of a camera's view frustum.

    The four image corners are back-projected at ``min_depth`` and again at
    ``max_depth`` using ``cam_intr``, then moved by ``c2w`` through
    ``rigid_transform``.

    :param min_depth: depth of the four near corners.
    :param max_depth: depth of the four far corners.
    :param size: ``(height, width)`` of the image; both are cast to ``int``.
    :param cam_intr: intrinsics matrix; entries ``[0, 0]``, ``[1, 1]`` are used
        as divisors and ``[0, 2]``, ``[1, 2]`` as offsets.
    :param c2w: transform handed to ``rigid_transform`` (cast to float32).
    :return: (3, 8) float32 tensor, one corner per column; the first four
        columns are at ``min_depth``, the last four at ``max_depth``.
    """
    device = cam_intr.device
    height, width = size
    height, width = int(height), int(width)

    # Per-corner pixel coordinates and depths (near quad first, then far quad).
    corner_u = torch.tensor([0, 0, width, width, 0, 0, width, width]).to(device)
    corner_v = torch.tensor([0, height, 0, height, 0, height, 0, height]).to(device)
    corner_depth = torch.tensor([min_depth] * 4 + [max_depth] * 4).to(device)

    # Back-project pixels into the camera frame.
    cam_x = (corner_u - cam_intr[0, 2]) * corner_depth / cam_intr[0, 0]
    cam_y = (corner_v - cam_intr[1, 2]) * corner_depth / cam_intr[1, 1]
    corners_cam = torch.stack([cam_x, cam_y, corner_depth]).float()  # (3, 8)

    # rigid_transform works on row-wise points, hence the transposes.
    corners_out = rigid_transform(corners_cam.T, c2w.float())
    return corners_out.T


def set_pixel_coords(h, w):
    """Build a grid of homogeneous pixel coordinates.

    :param h: number of rows.
    :param w: number of columns.
    :return: float32 tensor of shape [1, 3, H, W]; channel 0 holds the column
        index, channel 1 the row index and channel 2 is all ones.
    """
    row_idx = torch.arange(0, h).type(torch.float32)
    col_idx = torch.arange(0, w).type(torch.float32)

    # Broadcast the 1-D index vectors over the full [1, H, W] grid.
    y_grid = row_idx.view(1, h, 1).expand(1, h, w)
    x_grid = col_idx.view(1, 1, w).expand(1, h, w)
    unit = torch.ones(1, h, w).type(torch.float32)

    return torch.stack((x_grid, y_grid, unit), dim=1)  # [1, 3, H, W]


def get_boundingbox(img_hw, intrinsics, extrinsics, near_fars):
    """Compute an axis-aligned box enclosing the view frustums of all views.

    For every view the eight frustum corners from ``get_view_frustum`` are
    collected; the per-axis minimum and maximum over all corners give the
    bounds. The result lives on whatever device the frustum corners are on.

    :param img_hw: image size ``(height, width)``, passed unchanged to
        ``get_view_frustum`` for every view.
    :param intrinsics: per-view intrinsics, indexable along the first axis
        (a sequence of matrices or a stacked array/tensor).
    :param extrinsics: per-view matrices, indexed like ``intrinsics``; each is
        inverted and the inverse is passed to ``get_view_frustum`` as ``c2w``.
    :param near_fars: per-view pairs; ``near_fars[i][0]`` is the minimum depth
        and ``near_fars[i][1]`` the maximum depth of view ``i``.
    :return: ``(center, radius, bnds)`` where ``bnds`` is a (3, 2) tensor with
        the per-axis minimum in column 0 and maximum in column 1, ``center``
        is the (3,) midpoint of the bounds and ``radius`` is half of the
        longest box side (0-dim tensor).
    :raises ValueError: if no views are given.
    """
    num = len(intrinsics)
    if num == 0:
        raise ValueError("get_boundingbox requires at least one view")

    frustum_corners = []
    for i in range(num):
        # Convert each matrix on its own so tensor and array-like inputs can
        # be mixed; existing tensors are used as they are.
        cam_intr, w2c = (
            mat if isinstance(mat, torch.Tensor) else torch.tensor(mat)
            for mat in (intrinsics[i], extrinsics[i])
        )
        c2w = torch.inverse(w2c)
        min_depth, max_depth = near_fars[i][0], near_fars[i][1]
        # TODO: check that the corresponding points are matched
        frustum_corners.append(
            get_view_frustum(min_depth, max_depth, img_hw, cam_intr, c2w)
        )

    # Reduce over every corner of every view at once. Working directly on the
    # frustum tensors avoids combining them with a CPU-allocated buffer.
    all_corners = torch.cat(frustum_corners, dim=1)  # (3, 8 * num)
    lower = all_corners.min(dim=1)[0]
    upper = all_corners.max(dim=1)[0]
    bnds = torch.stack([lower, upper], dim=1)  # (3, 2)

    center = (upper + lower) / 2
    radius = (upper - lower).max() / 2
    return center, radius, bnds