import torch
import torch.nn.functional as F
import torch.nn as nn
from icecream import ic


# - neus: use sphere-tracing to speed up depth maps extraction
# This code snippet is heavily borrowed from IDR.
class FastRenderer(nn.Module):
    """Fast surface-intersection finder for extracting depth maps from an SDF.

    Sphere-traces each ray from both its near and far bound; rays that do not
    converge fall back to dense sampling along the ray plus secant refinement.
    Heavily borrowed from IDR's ray tracer, adapted to SDF networks that are
    conditioned on a feature volume (``conditional_volume``).
    """

    def __init__(self):
        super(FastRenderer, self).__init__()

        # |sdf| at or below this counts as "on the surface" (converged).
        self.sdf_threshold = 5e-5
        # Fraction of the SDF step kept when backtracking after overshooting
        # the surface (see the inner line-search loop in sphere_tracing).
        self.line_search_step = 0.5
        # Max backtracking iterations per sphere-tracing step.
        self.line_step_iters = 1
        # Max sphere-tracing iterations (advancing from both ray ends).
        self.sphere_tracing_iters = 10
        # Number of dense samples per ray used by the fallback ray_sampler.
        self.n_steps = 100
        # Iterations of the secant root refinement.
        self.n_secant_steps = 8

        # - use sdf_network to inference sdf value or directly interpolate sdf value from precomputed sdf_volume
        self.network_inference = False

    def extract_depth_maps(self, rays_o, rays_d, near, far, sdf_network, conditional_volume):
        """Return ``(hit_mask, hit_distance)`` for a batch of rays.

        rays_o, rays_d: (num_pixels, 3) ray origins / directions.
        near, far: per-ray distance bounds — presumably shaped (num_pixels, 1)
            so that ``rays_o + rays_d * near`` broadcasts; TODO confirm
            against the caller.
        Inference-only: the whole intersection search runs under no_grad.
        """
        with torch.no_grad():
            curr_start_points, network_object_mask, acc_start_dis = self.get_intersection(
                rays_o, rays_d, near, far,
                sdf_network, conditional_volume)

        network_object_mask = network_object_mask.reshape(-1)

        return network_object_mask, acc_start_dis

    def get_intersection(self, rays_o, rays_d, near, far, sdf_network, conditional_volume):
        """Locate the first ray/surface intersection for every ray.

        Runs sphere tracing first; rays whose front trace did not converge
        (``unfinished_mask_start``) are re-solved by dense sampling + secant.
        Returns ``(surface_points, hit_mask, hit_distance)``.
        """
        device = rays_o.device
        num_pixels, _ = rays_d.shape

        curr_start_points, unfinished_mask_start, acc_start_dis, acc_end_dis, min_dis, max_dis = \
            self.sphere_tracing(rays_o, rays_d, near, far, sdf_network, conditional_volume)

        # A ray hit the surface iff its front trace stopped before meeting
        # the back trace coming from the far bound.
        network_object_mask = (acc_start_dis < acc_end_dis)

        # The non convergent rays should be handled by the sampler
        sampler_mask = unfinished_mask_start
        sampler_net_obj_mask = torch.zeros_like(sampler_mask).bool().to(device)
        if sampler_mask.sum() > 0:
            # sampler_min_max = torch.zeros((num_pixels, 2)).to(device)
            # sampler_min_max[sampler_mask, 0] = acc_start_dis[sampler_mask]
            # sampler_min_max[sampler_mask, 1] = acc_end_dis[sampler_mask]

            # ray_sampler(self, rays_o, rays_d, near, far, sampler_mask):
            # Note: the sphere-tracing distances (not the original near/far)
            # bound the dense search interval.
            sampler_pts, sampler_net_obj_mask, sampler_dists = self.ray_sampler(rays_o,
                                                                                rays_d,
                                                                                acc_start_dis,
                                                                                acc_end_dis,
                                                                                sampler_mask,
                                                                                sdf_network,
                                                                                conditional_volume
                                                                                )

            # Overwrite the unconverged rays with the sampler's results.
            # NOTE(review): the trailing [:, None] assumes acc_start_dis and
            # network_object_mask keep a singleton last dim (shape (N, 1),
            # inherited from near/far) — confirm against the caller's shapes.
            curr_start_points[sampler_mask] = sampler_pts[sampler_mask]
            acc_start_dis[sampler_mask] = sampler_dists[sampler_mask][:, None]
            network_object_mask[sampler_mask] = sampler_net_obj_mask[sampler_mask][:, None]

        # print('----------------------------------------------------------------')
        # print('RayTracing: object = {0}/{1}, secant on {2}/{3}.'
        #       .format(network_object_mask.sum(), len(network_object_mask), sampler_net_obj_mask.sum(),
        #               sampler_mask.sum()))
        # print('----------------------------------------------------------------')

        return curr_start_points, network_object_mask, acc_start_dis

    def sphere_tracing(self, rays_o, rays_d, near, far, sdf_network, conditional_volume):
        ''' Run sphere tracing algorithm for max iterations from both sides of unit sphere intersection.

        Advances a front pointer from `near` (stepping +sdf) and a back
        pointer from `far` (stepping -sdf) until the SDF magnitude drops
        below `sdf_threshold`, the iteration budget runs out, or the two
        pointers cross. Returns
        (start_points, unfinished_mask_start, acc_start_dis, acc_end_dis, min_dis, max_dis).
        '''

        device = rays_o.device

        # Only rays with a non-empty [near, far) interval participate.
        unfinished_mask_start = (near < far).reshape(-1).clone()
        unfinished_mask_end = (near < far).reshape(-1).clone()

        # Initialize start current points
        curr_start_points = rays_o + rays_d * near
        acc_start_dis = near.clone()

        # Initialize end current points
        curr_end_points = rays_o + rays_d * far
        acc_end_dis = far.clone()

        # Initizlize min and max depth
        min_dis = acc_start_dis.clone()
        max_dis = acc_end_dis.clone()

        # Iterate on the rays (from both sides) till finding a surface
        iters = 0

        next_sdf_start = torch.zeros_like(acc_start_dis).to(device)

        # Either query the network directly or interpolate a precomputed
        # SDF volume, depending on the flag set in __init__.
        if self.network_inference:
            sdf_func = sdf_network.sdf
        else:
            sdf_func = sdf_network.sdf_from_sdfvolume

        next_sdf_start[unfinished_mask_start] = sdf_func(
            curr_start_points[unfinished_mask_start],
            conditional_volume, lod=0, gru_fusion=False)['sdf_pts_scale%d' % 0]

        next_sdf_end = torch.zeros_like(acc_end_dis).to(device)
        next_sdf_end[unfinished_mask_end] = sdf_func(curr_end_points[unfinished_mask_end],
                                                     conditional_volume, lod=0, gru_fusion=False)[
            'sdf_pts_scale%d' % 0]

        while True:
            # Update sdf
            curr_sdf_start = torch.zeros_like(acc_start_dis).to(device)
            curr_sdf_start[unfinished_mask_start] = next_sdf_start[unfinished_mask_start]
            # Zero out converged values so they no longer advance the ray.
            curr_sdf_start[curr_sdf_start <= self.sdf_threshold] = 0

            curr_sdf_end = torch.zeros_like(acc_end_dis).to(device)
            curr_sdf_end[unfinished_mask_end] = next_sdf_end[unfinished_mask_end]
            curr_sdf_end[curr_sdf_end <= self.sdf_threshold] = 0

            # Update masks
            # A ray is finished once its SDF fell to/below the threshold.
            unfinished_mask_start = unfinished_mask_start & (curr_sdf_start > self.sdf_threshold).reshape(-1)
            unfinished_mask_end = unfinished_mask_end & (curr_sdf_end > self.sdf_threshold).reshape(-1)

            if (
                    unfinished_mask_start.sum() == 0 and unfinished_mask_end.sum() == 0) or iters == self.sphere_tracing_iters:
                break
            iters += 1

            # Make step
            # Update distance
            # Classic sphere tracing: the SDF value is a safe step size.
            acc_start_dis = acc_start_dis + curr_sdf_start
            acc_end_dis = acc_end_dis - curr_sdf_end

            # Update points
            curr_start_points = rays_o + acc_start_dis * rays_d
            curr_end_points = rays_o + acc_end_dis * rays_d

            # Fix points which wrongly crossed the surface
            next_sdf_start = torch.zeros_like(acc_start_dis).to(device)
            if unfinished_mask_start.sum() > 0:
                next_sdf_start[unfinished_mask_start] = sdf_func(curr_start_points[unfinished_mask_start],
                                                                 conditional_volume, lod=0, gru_fusion=False)[
                    'sdf_pts_scale%d' % 0]

            next_sdf_end = torch.zeros_like(acc_end_dis).to(device)
            if unfinished_mask_end.sum() > 0:
                next_sdf_end[unfinished_mask_end] = sdf_func(curr_end_points[unfinished_mask_end],
                                                             conditional_volume, lod=0, gru_fusion=False)[
                    'sdf_pts_scale%d' % 0]

            # Negative SDF after a step means the ray overshot through the
            # surface; back it off with a geometrically shrinking line search.
            not_projected_start = (next_sdf_start < 0).reshape(-1)
            not_projected_end = (next_sdf_end < 0).reshape(-1)
            not_proj_iters = 0

            while (
                    not_projected_start.sum() > 0 or not_projected_end.sum() > 0) and not_proj_iters < self.line_step_iters:
                # Step backwards
                if not_projected_start.sum() > 0:
                    acc_start_dis[not_projected_start] -= ((1 - self.line_search_step) / (2 ** not_proj_iters)) * \
                                                          curr_sdf_start[not_projected_start]
                    curr_start_points[not_projected_start] = (rays_o + acc_start_dis * rays_d)[not_projected_start]

                    next_sdf_start[not_projected_start] = sdf_func(
                        curr_start_points[not_projected_start],
                        conditional_volume, lod=0, gru_fusion=False)['sdf_pts_scale%d' % 0]

                if not_projected_end.sum() > 0:
                    acc_end_dis[not_projected_end] += ((1 - self.line_search_step) / (2 ** not_proj_iters)) * \
                                                      curr_sdf_end[
                                                          not_projected_end]
                    curr_end_points[not_projected_end] = (rays_o + acc_end_dis * rays_d)[not_projected_end]

                    # Calc sdf

                    next_sdf_end[not_projected_end] = sdf_func(
                        curr_end_points[not_projected_end],
                        conditional_volume, lod=0, gru_fusion=False)['sdf_pts_scale%d' % 0]

                # Update mask
                not_projected_start = (next_sdf_start < 0).reshape(-1)
                not_projected_end = (next_sdf_end < 0).reshape(-1)
                not_proj_iters += 1

            # Stop tracing rays whose front and back pointers crossed.
            unfinished_mask_start = unfinished_mask_start & (acc_start_dis < acc_end_dis).reshape(-1)
            unfinished_mask_end = unfinished_mask_end & (acc_start_dis < acc_end_dis).reshape(-1)

        return curr_start_points, unfinished_mask_start, acc_start_dis, acc_end_dis, min_dis, max_dis

    def ray_sampler(self, rays_o, rays_d, near, far, sampler_mask, sdf_network, conditional_volume):
        ''' Sample the ray in a given range and run secant on rays which have sign transition.

        Note: despite the names, `near`/`far` receive the sphere-tracing
        distances (acc_start_dis / acc_end_dis) from get_intersection.
        Returns (sampler_pts, sampler_net_obj_mask, sampler_dists), each
        indexed over all num_pixels rays (only sampler_mask entries filled).
        '''
        device = rays_o.device
        num_pixels, _ = rays_d.shape
        sampler_pts = torch.zeros(num_pixels, 3).to(device).float()
        sampler_dists = torch.zeros(num_pixels).to(device).float()

        # n_steps evenly spaced fractions of each ray's [near, far] interval.
        intervals_dist = torch.linspace(0, 1, steps=self.n_steps).to(device).view(1, -1)

        pts_intervals = near + intervals_dist * (far - near)
        points = rays_o[:, None, :] + pts_intervals[:, :, None] * rays_d[:, None, :]

        # Get the non convergent rays
        mask_intersect_idx = torch.nonzero(sampler_mask).flatten()
        points = points.reshape((-1, self.n_steps, 3))[sampler_mask, :, :]
        pts_intervals = pts_intervals.reshape((-1, self.n_steps))[sampler_mask]

        if self.network_inference:
            sdf_func = sdf_network.sdf
        else:
            sdf_func = sdf_network.sdf_from_sdfvolume

        # Chunked SDF evaluation to bound peak memory.
        sdf_val_all = []
        for pnts in torch.split(points.reshape(-1, 3), 100000, dim=0):
            sdf_val_all.append(sdf_func(pnts,
                                        conditional_volume, lod=0, gru_fusion=False)['sdf_pts_scale%d' % 0])
        sdf_val = torch.cat(sdf_val_all).reshape(-1, self.n_steps)

        # sign(sdf) weighted by a descending ramp: argmin picks the FIRST
        # sample whose sdf is negative (first surface crossing), or the
        # first sample at all if no crossing exists.
        tmp = torch.sign(sdf_val) * torch.arange(self.n_steps, 0, -1).to(device).float().reshape(
            (1, self.n_steps))  # Force argmin to return the first min value
        sampler_pts_ind = torch.argmin(tmp, -1)
        sampler_pts[mask_intersect_idx] = points[torch.arange(points.shape[0]), sampler_pts_ind, :]
        sampler_dists[mask_intersect_idx] = pts_intervals[torch.arange(pts_intervals.shape[0]), sampler_pts_ind]

        # True where the selected sample is actually inside the surface.
        net_surface_pts = (sdf_val[torch.arange(sdf_val.shape[0]), sampler_pts_ind] < 0)

        # take points with minimal SDF value for P_out pixels
        p_out_mask = ~net_surface_pts
        n_p_out = p_out_mask.sum()
        if n_p_out > 0:
            out_pts_idx = torch.argmin(sdf_val[p_out_mask, :], -1)
            sampler_pts[mask_intersect_idx[p_out_mask]] = points[p_out_mask, :, :][torch.arange(n_p_out), out_pts_idx,
                                                          :]
            sampler_dists[mask_intersect_idx[p_out_mask]] = pts_intervals[p_out_mask, :][
                torch.arange(n_p_out), out_pts_idx]

        # Get Network object mask
        sampler_net_obj_mask = sampler_mask.clone()
        sampler_net_obj_mask[mask_intersect_idx[~net_surface_pts]] = False

        # Run Secant method
        # Refine the crossing between the sample just before (positive sdf)
        # and at (negative sdf) the detected sign transition.
        secant_pts = net_surface_pts
        n_secant_pts = secant_pts.sum()
        if n_secant_pts > 0:
            # Get secant z predictions
            # NOTE(review): sampler_pts_ind - 1 assumes the crossing is never
            # at index 0 (the first sample would wrap to index -1) — the
            # first sample sits at the sphere-tracing start distance, which
            # presumably has positive sdf; confirm.
            z_high = pts_intervals[torch.arange(pts_intervals.shape[0]), sampler_pts_ind][secant_pts]
            sdf_high = sdf_val[torch.arange(sdf_val.shape[0]), sampler_pts_ind][secant_pts]
            z_low = pts_intervals[secant_pts][torch.arange(n_secant_pts), sampler_pts_ind[secant_pts] - 1]
            sdf_low = sdf_val[secant_pts][torch.arange(n_secant_pts), sampler_pts_ind[secant_pts] - 1]

            cam_loc_secant = rays_o[mask_intersect_idx[secant_pts]]
            ray_directions_secant = rays_d[mask_intersect_idx[secant_pts]]
            z_pred_secant = self.secant(sdf_low, sdf_high, z_low, z_high, cam_loc_secant, ray_directions_secant,
                                        sdf_network, conditional_volume)

            # Get points
            sampler_pts[mask_intersect_idx[secant_pts]] = cam_loc_secant + z_pred_secant[:,
                                                                           None] * ray_directions_secant
            sampler_dists[mask_intersect_idx[secant_pts]] = z_pred_secant

        return sampler_pts, sampler_net_obj_mask, sampler_dists

    def secant(self, sdf_low, sdf_high, z_low, z_high, rays_o, rays_d, sdf_network, conditional_volume):
        ''' Runs the secant method for interval [z_low, z_high] for n_secant_steps.

        sdf_low is expected positive (outside), sdf_high negative (inside).
        NOTE(review): divides by (sdf_high - sdf_low) with no epsilon; if the
        two SDF values ever coincide this produces inf/nan — worth guarding.
        Mutates z_low/z_high/sdf_low/sdf_high in place (callers pass fresh
        slices, so this is contained).
        '''

        if self.network_inference:
            sdf_func = sdf_network.sdf
        else:
            sdf_func = sdf_network.sdf_from_sdfvolume

        # Secant-rule interpolation for the zero crossing.
        z_pred = -sdf_low * (z_high - z_low) / (sdf_high - sdf_low) + z_low
        for i in range(self.n_secant_steps):
            p_mid = rays_o + z_pred[:, None] * rays_d
            sdf_mid = sdf_func(p_mid,
                               conditional_volume, lod=0, gru_fusion=False)['sdf_pts_scale%d' % 0].reshape(-1)
            # Shrink the bracket: positive mid-sdf moves the low end up,
            # negative moves the high end down. sdf_mid == 0 keeps both.
            ind_low = (sdf_mid > 0).reshape(-1)
            if ind_low.sum() > 0:
                z_low[ind_low] = z_pred[ind_low]
                sdf_low[ind_low] = sdf_mid[ind_low]
            ind_high = sdf_mid < 0
            if ind_high.sum() > 0:
                z_high[ind_high] = z_pred[ind_high]
                sdf_high[ind_high] = sdf_mid[ind_high]

            z_pred = - sdf_low * (z_high - z_low) / (sdf_high - sdf_low) + z_low

        return z_pred  # 1D tensor

    def minimal_sdf_points(self, num_pixels, sdf, cam_loc, ray_directions, mask, min_dis, max_dis):
        ''' Find points with minimal SDF value on rays for P_out pixels.

        NOTE(review): not called anywhere in this class (inherited from the
        IDR code path); `sdf` here is a plain callable, unlike the other
        methods which take sdf_network + conditional_volume. The uniform_
        step sampling makes the result nondeterministic.
        '''
        device = sdf.device
        n_mask_points = mask.sum()

        n = self.n_steps
        # steps = torch.linspace(0.0, 1.0,n).to(device)
        # Random (not stratified) depths in [min_dis, max_dis] per masked ray.
        steps = torch.empty(n).uniform_(0.0, 1.0).to(device)
        mask_max_dis = max_dis[mask].unsqueeze(-1)
        mask_min_dis = min_dis[mask].unsqueeze(-1)
        steps = steps.unsqueeze(0).repeat(n_mask_points, 1) * (mask_max_dis - mask_min_dis) + mask_min_dis

        mask_points = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3)[mask]
        mask_rays = ray_directions[mask, :]

        mask_points_all = mask_points.unsqueeze(1).repeat(1, n, 1) + steps.unsqueeze(-1) * mask_rays.unsqueeze(
            1).repeat(1, n, 1)
        points = mask_points_all.reshape(-1, 3)

        # Chunked SDF evaluation to bound peak memory.
        mask_sdf_all = []
        for pnts in torch.split(points, 100000, dim=0):
            mask_sdf_all.append(sdf(pnts))

        mask_sdf_all = torch.cat(mask_sdf_all).reshape(-1, n)
        min_vals, min_idx = mask_sdf_all.min(-1)
        min_mask_points = mask_points_all.reshape(-1, n, 3)[torch.arange(0, n_mask_points), min_idx]
        min_mask_dist = steps.reshape(-1, n)[torch.arange(0, n_mask_points), min_idx]

        return min_mask_points, min_mask_dist