File size: 11,975 Bytes
79df973
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
// Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, check out LICENSE.md
//
// The ray marching algorithm used in this file is a variant of the modified
// Bresenham method:
// http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.42.3443&rep=rep1&type=pdf
// Search for "voxel traversal algorithm" for related information.

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda.h>
#include <cuda_runtime.h>
#include <curand.h>
#include <curand_kernel.h>
#include <time.h>

//#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <limits>
#include <vector>

#include "voxlib_common.h"

// Parameter bundle for ray_voxel_intersection_perspective_kernel. It is filled
// in on the host and handed to the kernel by value as a single argument.
struct RVIP_Params {
  // Extent of the voxel grid along each of its three axes; the kernel uses
  // these as exclusive upper bounds for the integer voxel coordinates.
  int voxel_dims[3];
  // Per-axis strides, in elements, used to turn a 3D voxel coordinate into a
  // flat offset into the voxel buffer.
  int voxel_strides[3];
  // Number of hit slots written per pixel (inner-most extent of the outputs).
  int max_samples;
  // Image size: [0] bounds the row index (blockIdx.y / threadIdx.y), [1] bounds
  // the column index (blockIdx.x / threadIdx.x).
  int img_dims[2];
  // Camera parameters
  // Ray origin shared by every pixel; its floor gives the starting voxel.
  float cam_ori[3];
  // Camera basis vectors. The per-pixel ray direction is built as
  // cam_up * (cam_c[0] - row) + cam_side * (col - cam_c[1]) + cam_fwd * cam_f.
  float cam_fwd[3];
  float cam_side[3];
  float cam_up[3];
  // Principal point: [0] pairs with the row index, [1] with the column index.
  float cam_c[2];
  // Weight of the forward axis relative to the pixel offsets above
  // (presumably the focal length in pixels -- confirm against the caller).
  float cam_f;
  // unsigned long seed;
};

/*
    Casts one perspective-camera ray per pixel through an integer voxel grid
    and records up to p.max_samples non-empty voxels hit along each ray.

    Outputs (all written in flat, row-major order):
      out_voxel_id: int32 [   img_dims[0], img_dims[1], max_samples, 1]
                    id of the voxel that was hit, 0 when the slot is unused.
      out_depth:    float [2, img_dims[0], img_dims[1], max_samples, 1]
                    plane 0 holds the ray parameter t at which the voxel is
                    entered, plane 1 the t at which it is left; NaN when the
                    slot is unused.
      out_raydirs:  float [   img_dims[0], img_dims[1],           1, 3]
                    normalized ray direction of the pixel.
    Input:
      in_voxel:     int32 voxel ids addressed through p.voxel_strides; the
                    value 0 is treated as empty space.

    Conventions:
      - Image coordinates refer to the center of the pixel.
      - Voxel coordinate [0, 0, 0] is the corner of the corner block (not its
        center), so voxel i spans [i, i + 1) along each axis.
      - The voxel that contains the ray origin is never tested; the first
        candidate is the voxel entered at the first boundary crossing.

    Launch layout: 2D grid of TILE_DIM x TILE_DIM thread blocks, one thread per
    pixel; x indexes image columns (img_dims[1]), y indexes image rows
    (img_dims[0]). No shared memory is used.
*/

// Moves the traversal one voxel along a single axis.
//   cell   : integer voxel coordinate on that axis (updated in place)
//   t_next : ray parameter of the next boundary crossing on that axis
//            (updated in place)
//   ori/dir: ray origin / direction component on that axis
//   dim    : grid extent on that axis
// Returns true when the step carried the ray past the face of the grid it is
// travelling towards, i.e. the traversal is finished. Takes scalars rather
// than arrays so every array access at the call site stays statically indexed.
static __device__ __forceinline__ bool
rvip_advance_axis(int &cell, float &t_next, const float ori, const float dir,
                  const int dim) {
  bool left_grid;
  if (dir > 0) {
    cell += 1;
    left_grid = cell >= dim;
    t_next = ((float)(cell + 1) - ori) / dir;
  } else {
    cell -= 1;
    left_grid = cell < 0;
    t_next = ((float)cell - ori) / dir;
  }
  return left_grid;
}

template <int TILE_DIM>
static __global__ void ray_voxel_intersection_perspective_kernel(
    int32_t *__restrict__ out_voxel_id, float *__restrict__ out_depth,
    float *__restrict__ out_raydirs, const int32_t *__restrict__ in_voxel,
    const RVIP_Params p) {

  // img_coords[0] is the row, img_coords[1] the column of this thread's pixel.
  int img_coords[2];
  img_coords[1] = blockIdx.x * TILE_DIM + threadIdx.x;
  img_coords[0] = blockIdx.y * TILE_DIM + threadIdx.y;
  if (img_coords[0] >= p.img_dims[0] || img_coords[1] >= p.img_dims[1]) {
    return;
  }

  // All output offsets are computed in 64 bits: the exit plane of out_depth
  // alone reaches 2 * H * W * max_samples elements, which exceeds the range of
  // a 32-bit int for large images / sample counts.
  const int64_t pix_index =
      (int64_t)img_coords[0] * p.img_dims[1] + img_coords[1];
  const int64_t sample_base = pix_index * p.max_samples;
  const int64_t exit_plane_offset =
      (int64_t)p.img_dims[0] * p.img_dims[1] * p.max_samples;

  // Calculate ray origin and direction
  float rayori[3], raydir[3];
  rayori[0] = p.cam_ori[0];
  rayori[1] = p.cam_ori[1];
  rayori[2] = p.cam_ori[2];

  // Camera intrinsics: pixel offset from the principal point
  float ndc_imcoords[2];
  ndc_imcoords[0] = p.cam_c[0] - (float)img_coords[0]; // Flip height
  ndc_imcoords[1] = (float)img_coords[1] - p.cam_c[1];

#pragma unroll
  for (int i = 0; i < 3; i++) {
    raydir[i] = p.cam_up[i] * ndc_imcoords[0] +
                p.cam_side[i] * ndc_imcoords[1] + p.cam_fwd[i] * p.cam_f;
  }
  normalize<float, 3>(raydir);

  // Save out_raydirs
  out_raydirs[pix_index * 3] = raydir[0];
  out_raydirs[pix_index * 3 + 1] = raydir[1];
  out_raydirs[pix_index * 3 + 2] = raydir[2];

  // axis_int: voxel currently containing the ray.
  // axis_t:   ray parameter at which the ray crosses the next voxel boundary
  //           on each axis.
  float axis_t[3];
  int axis_int[3];

  // Current voxel
  axis_int[0] = (int)floorf(rayori[0]);
  axis_int[1] = (int)floorf(rayori[1]);
  axis_int[2] = (int)floorf(rayori[2]);

#pragma unroll
  for (int i = 0; i < 3; i++) {
    if (raydir[i] > 0) {
      // Initial t value. When rayori[i] is a whole number the next boundary is
      // the far face of the current voxel (always round up).
      axis_t[i] = ((float)(axis_int[i] + 1) - rayori[i]) / raydir[i];
    } else if (raydir[i] < 0) {
      axis_t[i] = ((float)axis_int[i] - rayori[i]) / raydir[i];
    } else {
      // Ray is parallel to this axis: it never crosses a boundary on it.
      axis_t[i] = HUGE_VALF;
    }
  }

  // Fused raymarching and sampling. `quit` is sticky: once the ray has left
  // the grid every remaining slot is filled with the "no hit" values.
  bool quit = false;
  for (int cur_plane = 0; cur_plane < p.max_samples; cur_plane++) {
    float t = nanf("0");
    float t2 = nanf("0");
    int32_t blk_id = 0;
    // Find the next intersection
    while (!quit) {
      // Step across the nearest boundary; tnow is the t at which the newly
      // entered voxel starts. tnow must be read before the helper overwrites
      // the corresponding axis_t entry.
      float tnow;
      if (axis_t[0] <= axis_t[1] && axis_t[0] <= axis_t[2]) {
        tnow = axis_t[0];
        quit = rvip_advance_axis(axis_int[0], axis_t[0], rayori[0], raydir[0],
                                 p.voxel_dims[0]);
      } else if (axis_t[1] <= axis_t[2]) {
        tnow = axis_t[1];
        quit = rvip_advance_axis(axis_int[1], axis_t[1], rayori[1], raydir[1],
                                 p.voxel_dims[1]);
      } else {
        tnow = axis_t[2];
        quit = rvip_advance_axis(axis_int[2], axis_t[2], rayori[2], raydir[2],
                                 p.voxel_dims[2]);
      }

      if (quit) {
        break;
      }

      // Skip empty space outside the grid (ray origin may lie outside it).
      // This cannot loop forever: every step moves one axis monotonically
      // towards the face that triggers `quit`.
      if (axis_int[0] < 0 || axis_int[0] >= p.voxel_dims[0] ||
          axis_int[1] < 0 || axis_int[1] >= p.voxel_dims[1] ||
          axis_int[2] < 0 || axis_int[2] >= p.voxel_dims[2]) {
        continue;
      }

      // Test intersection using voxel grid
      blk_id = in_voxel[(int64_t)axis_int[0] * p.voxel_strides[0] +
                        (int64_t)axis_int[1] * p.voxel_strides[1] +
                        (int64_t)axis_int[2] * p.voxel_strides[2]];
      if (blk_id == 0) {
        continue;
      }

      // Now that there is an intersection: entry is the boundary just
      // crossed, exit is the nearest upcoming boundary on any axis.
      t = tnow;
      if (axis_t[0] <= axis_t[1] && axis_t[0] <= axis_t[2]) {
        t2 = axis_t[0];
      } else if (axis_t[1] <= axis_t[2]) {
        t2 = axis_t[1];
      } else {
        t2 = axis_t[2];
      }
      break;
    } // while !quit (ray marching loop)

    out_depth[sample_base + cur_plane] = t;
    out_depth[exit_plane_offset + sample_base + cur_plane] = t2;
    out_voxel_id[sample_base + cur_plane] = blk_id;
  } // cur_plane
}

/*
    Host entry point: builds the camera frame, allocates the outputs on the
    device of in_voxel and launches the perspective ray / voxel intersection
    kernel on the current CUDA stream of that device.

    in:
        in_voxel:     torch CUDA int32  [X, Y, Z] (e.g. [40, 512, 512]);
                      0 marks an empty voxel
        cam_ori:      torch      float  [3]  ray origin in voxel coordinates
        cam_dir:      torch      float  [3]  viewing direction
        cam_up:       torch      float  [3]  approximate up direction
        cam_f:                   float       weight of the forward axis relative
                                             to pixel offsets (focal length)
        cam_c:                   float  [2]  principal point; [0] pairs with
                                             the image row, [1] with the column
        img_dims:                int    [2]  image size as (rows, columns)
        max_samples:             int         hit slots recorded per pixel
    out (returned in this order):
        out_voxel_id: torch CUDA int32  [   img_dims[0], img_dims[1],
                                            max_samples, 1]
        out_depth:    torch CUDA float  [2, img_dims[0], img_dims[1],
                                            max_samples, 1]
                      index 0: entry depth, index 1: exit depth
        out_raydirs:  torch CUDA float  [   img_dims[0], img_dims[1], 1, 3]

    Throws c10::Error (via TORCH_CHECK) when an argument has the wrong dtype,
    rank or length, or when the kernel launch is rejected.
*/
std::vector<torch::Tensor> ray_voxel_intersection_perspective_cuda(
    const torch::Tensor &in_voxel, const torch::Tensor &cam_ori,
    const torch::Tensor &cam_dir, const torch::Tensor &cam_up, float cam_f,
    const std::vector<float> &cam_c, const std::vector<int> &img_dims,
    int max_samples) {
  // Project macro from voxlib_common.h (presumably rejects non-CUDA tensors).
  CHECK_CUDA(in_voxel);

  // Argument validation. TORCH_CHECK is used instead of assert() because
  // assert() is compiled out under NDEBUG, which would let malformed inputs
  // reach the raw pointer reads below.
  TORCH_CHECK(in_voxel.scalar_type() == torch::kInt32,
              "in_voxel must be an int32 tensor"); // Minecraft compatibility
  TORCH_CHECK(in_voxel.dim() == 3, "in_voxel must be a 3D tensor, got ",
              in_voxel.dim(), " dimensions");
  TORCH_CHECK(cam_ori.scalar_type() == torch::kFloat32 && cam_ori.numel() == 3,
              "cam_ori must be a float32 tensor with 3 elements");
  TORCH_CHECK(cam_dir.scalar_type() == torch::kFloat32 && cam_dir.numel() == 3,
              "cam_dir must be a float32 tensor with 3 elements");
  TORCH_CHECK(cam_up.scalar_type() == torch::kFloat32 && cam_up.numel() == 3,
              "cam_up must be a float32 tensor with 3 elements");
  TORCH_CHECK(cam_c.size() >= 2, "cam_c must hold 2 values, got ",
              cam_c.size());
  TORCH_CHECK(img_dims.size() == 2, "img_dims must hold 2 values, got ",
              img_dims.size());
  TORCH_CHECK(img_dims[0] >= 0 && img_dims[1] >= 0,
              "img_dims must be non-negative");
  TORCH_CHECK(max_samples >= 0, "max_samples must be non-negative, got ",
              max_samples);

  // Make the device that owns in_voxel current for the allocations and the
  // launch, so the call also works when another GPU is the active device.
  const c10::cuda::CUDAGuard device_guard(in_voxel.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  torch::Device device = in_voxel.device();

  RVIP_Params p;

  // Calculate camera rays. The vectors are read through raw pointers on the
  // host, so they must live in contiguous CPU memory.
  const torch::Tensor cam_ori_c = cam_ori.cpu().contiguous();
  const torch::Tensor cam_dir_c = cam_dir.cpu().contiguous();
  const torch::Tensor cam_up_c = cam_up.cpu().contiguous();

  // Get the coordinate frame of camera space in world space:
  // forward from cam_dir, side from (forward, up), then a re-derived up from
  // (side, forward). The exact handedness depends on the project's cross()
  // helper.
  normalize<float, 3>(p.cam_fwd, cam_dir_c.data_ptr<float>());
  cross<float>(p.cam_side, p.cam_fwd, cam_up_c.data_ptr<float>());
  normalize<float, 3>(p.cam_side);
  cross<float>(p.cam_up, p.cam_side, p.cam_fwd);
  normalize<float, 3>(p.cam_up); // Not absolutely necessary as both vectors are
                                 // normalized. But just in case...

  copyarr<float, 3>(p.cam_ori, cam_ori_c.data_ptr<float>());

  p.cam_f = cam_f;
  p.cam_c[0] = cam_c[0];
  p.cam_c[1] = cam_c[1];
  p.max_samples = max_samples;

  // The kernel parameters store sizes and strides as int; reject grids whose
  // geometry would be silently truncated by the narrowing conversion.
  for (int i = 0; i < 3; i++) {
    TORCH_CHECK(in_voxel.size(i) <= std::numeric_limits<int>::max() &&
                    in_voxel.stride(i) <= std::numeric_limits<int>::max(),
                "in_voxel size/stride along dim ", i,
                " does not fit in a 32-bit int");
    p.voxel_dims[i] = static_cast<int>(in_voxel.size(i));
    p.voxel_strides[i] = static_cast<int>(in_voxel.stride(i));
  }

  p.img_dims[0] = img_dims[0];
  p.img_dims[1] = img_dims[1];

  // Create output tensors
  // For Minecraft Seg Mask
  torch::Tensor out_voxel_id =
      torch::empty({p.img_dims[0], p.img_dims[1], p.max_samples, 1},
                   torch::TensorOptions().dtype(torch::kInt32).device(device));

  // Two depth planes per sample: index 0 is the entry point of the ray into
  // the voxel, index 1 the exit point. Both refer to the same out_voxel_id.
  torch::Tensor out_depth = torch::empty(
      {2, p.img_dims[0], p.img_dims[1], p.max_samples, 1},
      torch::TensorOptions().dtype(torch::kFloat32).device(device));

  torch::Tensor out_raydirs = torch::empty({p.img_dims[0], p.img_dims[1], 1, 3},
                                           torch::TensorOptions()
                                               .dtype(torch::kFloat32)
                                               .device(device)
                                               .requires_grad(false));

  // An empty image has nothing to compute, and a grid with a zero dimension is
  // an invalid launch configuration, so skip the launch in that case.
  if (p.img_dims[0] > 0 && p.img_dims[1] > 0) {
    constexpr int TILE_DIM = 8;
    // One thread per pixel: x covers image columns, y covers image rows.
    dim3 dimGrid((p.img_dims[1] + TILE_DIM - 1) / TILE_DIM,
                 (p.img_dims[0] + TILE_DIM - 1) / TILE_DIM, 1);
    dim3 dimBlock(TILE_DIM, TILE_DIM, 1);

    ray_voxel_intersection_perspective_kernel<TILE_DIM>
        <<<dimGrid, dimBlock, 0, stream>>>(
            out_voxel_id.data_ptr<int32_t>(), out_depth.data_ptr<float>(),
            out_raydirs.data_ptr<float>(), in_voxel.data_ptr<int32_t>(), p);
    // Kernel launches do not report errors by themselves; surface a rejected
    // launch configuration here instead of at some unrelated later call.
    AT_CUDA_CHECK(cudaGetLastError());
  }

  return {out_voxel_id, out_depth, out_raydirs};
}