"""
Patch Projector
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from models.render_utils import sample_ptsFeatures_from_featureMaps


class PatchProjector:
    def __init__(self, patch_size):
        self.h_patch_size = patch_size
        self.offsets = build_patch_offset(patch_size)  # pixel offsets of the warping patch, [1, Npx, 2]

        self.z_axis = torch.tensor([0, 0, 1]).float()  # used for the fronto-parallel fallback homography

        self.plane_dist_thresh = 0.001

    # * correctness checked
    def pixel_warp(self, pts, imgs, intrinsics,
                   w2cs, img_wh=None):
        """
        Sample per-point pixel colors from each view.

        :param pts: [N_rays, n_samples, 3]
        :param imgs: [N_views, 3, H, W]
        :param intrinsics: [N_views, 4, 4]
        :param w2cs: [N_views, 4, 4]
        :param img_wh: [W, H]; inferred from imgs if None
        :return: pts_color [N_rays, n_samples, N_views, 3], valid_mask [N_rays, n_samples, N_views]
        """
        if img_wh is None:
            N_views, _, sizeH, sizeW = imgs.shape
            img_wh = [sizeW, sizeH]

        pts_color, valid_mask = sample_ptsFeatures_from_featureMaps(
            pts, imgs, w2cs, intrinsics, img_wh,
            proj_matrix=None, return_mask=True)  # [N_views, c, N_rays, n_samples], [N_views, N_rays, n_samples]

        pts_color = pts_color.permute(2, 3, 0, 1)  # [N_rays, n_samples, N_views, c]
        valid_mask = valid_mask.permute(1, 2, 0)  # [N_rays, n_samples, N_views]

        return pts_color, valid_mask

    def patch_warp(self, pts, uv, normals, src_imgs,
                   ref_intrinsic, src_intrinsics,
                   ref_c2w, src_c2ws, img_wh=None
                   ):
        """
        Warp a patch around each reference pixel into the source views via
        plane-induced homographies, and sample the source colors there.

        :param pts: [N_rays, n_samples, 3]
        :param uv: [N_rays, 2]  normalized to (-1, 1)
        :param normals: [N_rays, n_samples, 3]  normal of each pt in world space
        :param src_imgs: [N_src, 3, h, w]
        :param ref_intrinsic: [4, 4]
        :param src_intrinsics: [N_src, 4, 4]
        :param ref_c2w: [4, 4]
        :param src_c2ws: [N_src, 4, 4]
        :param img_wh: [W, H]; inferred from src_imgs if None
        :return: sampled_rgb_val [N_rays, n_samples, N_src, Npx, 3],
                 warp_mask_full [N_rays, n_samples, N_src, Npx]
        """
        device = pts.device

        N_rays, n_samples, _ = pts.shape
        N_pts = N_rays * n_samples

        N_src, _, sizeH, sizeW = src_imgs.shape

        if img_wh is not None:
            sizeW, sizeH = img_wh[0], img_wh[1]

        # scale uv from (-1, 1) to pixel coordinates in [0, W-1] x [0, H-1]
        # (note: modifies uv in place)
        uv[:, 0] = (uv[:, 0] + 1) / 2. * (sizeW - 1)
        uv[:, 1] = (uv[:, 1] + 1) / 2. * (sizeH - 1)

        ref_intr = ref_intrinsic[:3, :3]
        inv_ref_intr = torch.inverse(ref_intr)
        src_intrs = src_intrinsics[:, :3, :3]
        inv_src_intrs = torch.inverse(src_intrs)

        ref_pose = ref_c2w
        inv_ref_pose = torch.inverse(ref_pose)
        src_poses = src_c2ws
        inv_src_poses = torch.inverse(src_poses)

        ref_cam_loc = ref_pose[:3, 3].unsqueeze(0)  # [1, 3]
        sampled_dists = torch.norm(pts - ref_cam_loc, dim=-1)  # [N_rays, n_samples]

        relative_proj = inv_src_poses @ ref_pose  # maps ref-camera coords to src-camera coords
        R_rel = relative_proj[:, :3, :3]
        t_rel = relative_proj[:, :3, 3:]
        R_ref = inv_ref_pose[:3, :3]
        t_ref = inv_ref_pose[:3, 3:]

        pts = pts.view(-1, 3)
        normals = normals.view(-1, 3)

        with torch.no_grad():
            rot_normals = R_ref @ normals.unsqueeze(-1)  # [N_pts, 3, 1]  normals in the ref frame
            points_in_ref = R_ref @ pts.unsqueeze(
                -1) + t_ref  # [N_pts, 3, 1]  points in the reference frame coordinate system
            d1 = torch.sum(rot_normals * points_in_ref, dim=1).unsqueeze(
                1)  # signed distance from the plane to the ref camera center

            d2 = torch.sum(rot_normals.unsqueeze(1) * (-R_rel.transpose(1, 2) @ t_rel).unsqueeze(0),
                           dim=2)  # signed distance from the plane to each src camera center
            # a homography is valid if the plane is not degenerate, no src camera
            # lies on the plane, and each src camera is on the same side of the
            # plane as the ref camera (d2 / d1 < 1)
            valid_hom = (torch.abs(d1) > self.plane_dist_thresh) & (
                    torch.abs(d1 - d2) > self.plane_dist_thresh) & ((d2 / d1) < 1)

            d1 = d1.squeeze()
            sign = torch.sign(d1)
            sign[sign == 0] = 1
            d = torch.clamp(torch.abs(d1), 1e-8) * sign  # keep |d| away from 0, preserving its sign
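
            # Standard plane-induced homography (matches the computation below):
            # for a plane with ref-frame normal n and offset d (n^T X = d for
            # points X on the plane), a reference pixel maps into source view i as
            #   H_i = K_src_i @ (R_rel_i + t_rel_i @ n^T / d) @ K_ref^{-1}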

            H = src_intrs.unsqueeze(1) @ (
                    R_rel.unsqueeze(1) + t_rel.unsqueeze(1) @ rot_normals.view(1, N_pts, 1, 3) / d.view(1,
                                                                                                        N_pts,
                                                                                                        1, 1)
            ) @ inv_ref_intr.view(1, 1, 3, 3)

            # replace invalid homographies with fronto-parallel ones (plane
            # normal = ref camera z-axis, offset = ref-camera-to-point distance)
            H_invalid = src_intrs.unsqueeze(1) @ (
                    R_rel.unsqueeze(1) + t_rel.unsqueeze(1) @ self.z_axis.to(device).view(1, 1, 1, 3).expand(-1, N_pts,
                                                                                                             -1,
                                                                                                             -1) / sampled_dists.view(
                1, N_pts, 1, 1)
            ) @ inv_ref_intr.view(1, 1, 3, 3)
            tmp_m = ~valid_hom.view(-1, N_src).t()
            H[tmp_m] = H_invalid[tmp_m]

        pixels = uv.view(N_rays, 1, 2) + self.offsets.float().to(device)  # [N_rays, Npx, 2]
        Npx = pixels.shape[1]
        grid, warp_mask_full = self.patch_homography(H, pixels)

        warp_mask_full = warp_mask_full & (grid[..., 0] < (sizeW - self.h_patch_size)) & (
                grid[..., 1] < (sizeH - self.h_patch_size)) & (grid >= self.h_patch_size).all(dim=-1)
        warp_mask_full = warp_mask_full.view(N_src, N_rays, n_samples, Npx)

        grid = torch.clamp(normalize(grid, sizeH, sizeW), -10, 10)

        sampled_rgb_val = F.grid_sample(src_imgs, grid.view(N_src, -1, 1, 2), align_corners=True).squeeze(
            -1).transpose(1, 2)
        sampled_rgb_val = sampled_rgb_val.view(N_src, N_rays, n_samples, Npx, 3)

        warp_mask_full = warp_mask_full.permute(1, 2, 0, 3).contiguous()  # (N_rays, n_samples, N_src, Npx)
        sampled_rgb_val = sampled_rgb_val.permute(1, 2, 0, 3, 4).contiguous()  # (N_rays, n_samples, N_src, Npx, 3)

        return sampled_rgb_val, warp_mask_full

    def patch_homography(self, H, uv):
        # H: [Nsrc, N * n_samples, 3, 3] homographies, uv: [N, Npx, 2] pixel coords
        N, Npx = uv.shape[:2]
        Nsrc = H.shape[0]
        H = H.view(Nsrc, N, -1, 3, 3)
        hom_uv = add_hom(uv)  # [N, Npx, 3] homogeneous pixel coordinates

        # einsum is ~30x faster than the equivalent batched matmul:
        # tmp = (H.view(Nsrc, N, -1, 1, 3, 3) @ hom_uv.view(1, N, 1, -1, 3, 1)).squeeze(-1).view(Nsrc, -1, 3)
        tmp = torch.einsum("vprik,pok->vproi", H, hom_uv).reshape(Nsrc, -1, 3)

        grid = tmp[..., :2] / torch.clamp(tmp[..., 2:], 1e-8)  # perspective divide
        mask = tmp[..., 2] > 0  # keep pixels that project in front of the camera
        return grid, mask


def add_hom(pts):
    # append a homogeneous coordinate of 1; handles torch tensors and numpy arrays
    try:
        dev = pts.device
        ones = torch.ones(pts.shape[:-1], device=dev).unsqueeze(-1)
        return torch.cat((pts, ones), dim=-1)

    except AttributeError:
        ones = np.ones((pts.shape[0], 1))
        return np.concatenate((pts, ones), axis=1)


def normalize(flow, h, w, clamp=None):
    # h and w are either plain scalars or torch tensors of shape [N] (N = batch size)
    try:
        h.device

    except AttributeError:
        h = torch.tensor(h, device=flow.device).float().unsqueeze(0)
        w = torch.tensor(w, device=flow.device).float().unsqueeze(0)

    if len(flow.shape) == 4:
        w = w.unsqueeze(1).unsqueeze(2)
        h = h.unsqueeze(1).unsqueeze(2)
    elif len(flow.shape) == 3:
        w = w.unsqueeze(1)
        h = h.unsqueeze(1)
    elif len(flow.shape) == 5:
        w = w.unsqueeze(0).unsqueeze(2).unsqueeze(2)
        h = h.unsqueeze(0).unsqueeze(2).unsqueeze(2)

    res = torch.empty_like(flow)
    if res.shape[-1] == 3:
        res[..., 2] = 1

    # for grid_sample with align_corners=True
    # https://github.com/pytorch/pytorch/blob/c371542efc31b1abfe6f388042aa3ab0cef935f2/aten/src/ATen/native/GridSampler.h#L33
    res[..., 0] = 2 * flow[..., 0] / (w - 1) - 1
    res[..., 1] = 2 * flow[..., 1] / (h - 1) - 1

    if clamp:
        return torch.clamp(res, -clamp, clamp)
    else:
        return res


def build_patch_offset(h_patch_size):
    # integer pixel offsets of a (2h+1) x (2h+1) patch, in (x, y) order
    offsets = torch.arange(-h_patch_size, h_patch_size + 1)
    return torch.stack(torch.meshgrid(offsets, offsets, indexing="ij")[::-1], dim=-1).view(1, -1, 2)  # [1, (2h+1)**2, 2]
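

if __name__ == "__main__":
    # Minimal smoke test sketching how patch_warp is typically called. All
    # shapes, intrinsics, and poses below are made-up values for illustration,
    # not from any dataset; running this still requires the repository's
    # `models` package to be importable (for the top-level import).
    torch.manual_seed(0)

    projector = PatchProjector(patch_size=3)  # 7x7 patches, Npx = 49

    N_rays, n_samples, N_src, H_img, W_img = 4, 8, 2, 64, 64
    pts = torch.randn(N_rays, n_samples, 3) * 0.1 + torch.tensor([0., 0., 2.])
    uv = torch.rand(N_rays, 2) * 2 - 1  # normalized to (-1, 1)
    normals = F.normalize(torch.randn(N_rays, n_samples, 3), dim=-1)
    src_imgs = torch.rand(N_src, 3, H_img, W_img)

    intrinsic = torch.eye(4)
    intrinsic[0, 0] = intrinsic[1, 1] = 50.0  # assumed focal length
    intrinsic[0, 2], intrinsic[1, 2] = W_img / 2., H_img / 2.
    src_intrinsics = intrinsic.unsqueeze(0).repeat(N_src, 1, 1)

    ref_c2w = torch.eye(4)  # reference camera at the origin
    src_c2ws = torch.eye(4).unsqueeze(0).repeat(N_src, 1, 1)
    src_c2ws[:, 0, 3] = torch.linspace(-0.1, 0.1, N_src)  # small x-baseline

    rgb, mask = projector.patch_warp(pts, uv, normals, src_imgs,
                                     intrinsic, src_intrinsics,
                                     ref_c2w, src_c2ws)
    print(rgb.shape)   # torch.Size([4, 8, 2, 49, 3])
    print(mask.shape)  # torch.Size([4, 8, 2, 49])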