#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

/*
   Points sampling helper functions.
 */
template <typename scalar_t>
__global__ void infer_t_minmax_cuda_kernel(
        scalar_t* __restrict__ rays_o,
        scalar_t* __restrict__ rays_d,
        scalar_t* __restrict__ xyz_min,
        scalar_t* __restrict__ xyz_max,
        const float near, const float far, const int n_rays,
        scalar_t* __restrict__ t_min,
        scalar_t* __restrict__ t_max) {
  const int i_ray = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_ray<n_rays) {
    const int offset = i_ray * 3;
    // Slab test: intersect the ray with each pair of bbox planes, then clamp to [near, far].
    float vx = ((rays_d[offset  ]==0) ? 1e-6 : rays_d[offset  ]);
    float vy = ((rays_d[offset+1]==0) ? 1e-6 : rays_d[offset+1]);
    float vz = ((rays_d[offset+2]==0) ? 1e-6 : rays_d[offset+2]);
    float ax = (xyz_max[0] - rays_o[offset  ]) / vx;
    float ay = (xyz_max[1] - rays_o[offset+1]) / vy;
    float az = (xyz_max[2] - rays_o[offset+2]) / vz;
    float bx = (xyz_min[0] - rays_o[offset  ]) / vx;
    float by = (xyz_min[1] - rays_o[offset+1]) / vy;
    float bz = (xyz_min[2] - rays_o[offset+2]) / vz;
    t_min[i_ray] = max(min(max(max(min(ax, bx), min(ay, by)), min(az, bz)), far), near);
    t_max[i_ray] = max(min(min(min(max(ax, bx), max(ay, by)), max(az, bz)), far), near);
  }
}

template <typename scalar_t>
__global__ void infer_n_samples_cuda_kernel(
        scalar_t* __restrict__ rays_d,
        scalar_t* __restrict__ t_min,
        scalar_t* __restrict__ t_max,
        const float stepdist,
        const int n_rays,
        int64_t* __restrict__ n_samples) {
  const int i_ray = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_ray<n_rays) {
    const int offset = i_ray * 3;
    const float rnorm = sqrt(
            rays_d[offset  ]*rays_d[offset  ] +
            rays_d[offset+1]*rays_d[offset+1] +
            rays_d[offset+2]*rays_d[offset+2]);
    // At least one point per ray, which keeps the later sample_pts_on_rays_cuda simple.
    n_samples[i_ray] = max(ceil((t_max[i_ray]-t_min[i_ray]) * rnorm / stepdist), 1.);
  }
}

template <typename scalar_t>
__global__ void infer_ray_start_dir_cuda_kernel(
        scalar_t* __restrict__ rays_o,
        scalar_t* __restrict__ rays_d,
        scalar_t* __restrict__ t_min,
        const int n_rays,
        scalar_t* __restrict__ rays_start,
        scalar_t* __restrict__ rays_dir) {
  const int i_ray = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_ray<n_rays) {
    const int offset = i_ray * 3;
    const float rnorm = sqrt(
            rays_d[offset  ]*rays_d[offset  ] +
            rays_d[offset+1]*rays_d[offset+1] +
            rays_d[offset+2]*rays_d[offset+2]);
    // Entry point into the bbox and the unit-length marching direction.
    rays_start[offset  ] = rays_o[offset  ] + rays_d[offset  ] * t_min[i_ray];
    rays_start[offset+1] = rays_o[offset+1] + rays_d[offset+1] * t_min[i_ray];
    rays_start[offset+2] = rays_o[offset+2] + rays_d[offset+2] * t_min[i_ray];
    rays_dir  [offset  ] = rays_d[offset  ] / rnorm;
    rays_dir  [offset+1] = rays_d[offset+1] / rnorm;
    rays_dir  [offset+2] = rays_d[offset+2] / rnorm;
  }
}

std::vector<torch::Tensor> infer_t_minmax_cuda(
        torch::Tensor rays_o, torch::Tensor rays_d,
        torch::Tensor xyz_min, torch::Tensor xyz_max,
        const float near, const float far) {
  const int n_rays = rays_o.size(0);
  auto t_min = torch::empty({n_rays}, rays_o.options());
  auto t_max = torch::empty({n_rays}, rays_o.options());

  const int threads = 256;
  const int blocks = (n_rays + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(rays_o.scalar_type(), "infer_t_minmax_cuda", ([&] {
    infer_t_minmax_cuda_kernel<scalar_t><<<blocks, threads>>>(
        rays_o.data_ptr<scalar_t>(),
        rays_d.data_ptr<scalar_t>(),
        xyz_min.data_ptr<scalar_t>(),
        xyz_max.data_ptr<scalar_t>(),
        near, far, n_rays,
        t_min.data_ptr<scalar_t>(),
        t_max.data_ptr<scalar_t>());
  }));

  return {t_min, t_max};
}

torch::Tensor infer_n_samples_cuda(torch::Tensor rays_d, torch::Tensor t_min, torch::Tensor t_max, const float stepdist) {
  const int n_rays = t_min.size(0);
  auto n_samples = torch::empty({n_rays}, torch::dtype(torch::kInt64).device(torch::kCUDA));
  const int threads = 256;
  const int blocks = (n_rays + threads - 1) / threads;
  AT_DISPATCH_FLOATING_TYPES(t_min.scalar_type(), "infer_n_samples_cuda", ([&] {
    infer_n_samples_cuda_kernel<scalar_t><<<blocks, threads>>>(
        rays_d.data_ptr<scalar_t>(),
        t_min.data_ptr<scalar_t>(),
        t_max.data_ptr<scalar_t>(),
        stepdist,
        n_rays,
        n_samples.data_ptr<int64_t>());
  }));
  return n_samples;
}

std::vector<torch::Tensor> infer_ray_start_dir_cuda(torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor t_min) {
  const int n_rays = rays_o.size(0);
  const int threads = 256;
  const int blocks = (n_rays + threads - 1) / threads;
  auto rays_start = torch::empty_like(rays_o);
  auto rays_dir = torch::empty_like(rays_o);
  AT_DISPATCH_FLOATING_TYPES(rays_o.scalar_type(), "infer_ray_start_dir_cuda", ([&] {
    infer_ray_start_dir_cuda_kernel<scalar_t><<<blocks, threads>>>(
        rays_o.data_ptr<scalar_t>(),
        rays_d.data_ptr<scalar_t>(),
        t_min.data_ptr<scalar_t>(),
        n_rays,
        rays_start.data_ptr<scalar_t>(),
        rays_dir.data_ptr<scalar_t>());
  }));
  return {rays_start, rays_dir};
}
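/*
   Illustrative sketch, not part of the original source: a plain host-side version of the
   slab test performed by infer_t_minmax_cuda_kernel above, handy as a reference when
   validating the kernel. The function name and signature are hypothetical.
 */
static inline void reference_ray_aabb_tminmax(
        const float o[3], const float d[3],
        const float bmin[3], const float bmax[3],
        const float near, const float far,
        float* t_min, float* t_max) {
  float lo = -INFINITY, hi = INFINITY;
  for(int c=0; c<3; ++c) {
    const float v = (d[c]==0.f) ? 1e-6f : d[c];  // avoid division by zero, as the kernel does
    const float a = (bmax[c] - o[c]) / v;
    const float b = (bmin[c] - o[c]) / v;
    lo = fmaxf(lo, fminf(a, b));                 // farthest entry plane
    hi = fminf(hi, fmaxf(a, b));                 // nearest exit plane
  }
  // Clamp to [near, far] exactly as the kernel does; a missed box yields t_min >= t_max.
  *t_min = fmaxf(fminf(lo, far), near);
  *t_max = fmaxf(fminf(hi, far), near);
}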
/*
   Sampling query points on rays.
 */
__global__ void __set_1_at_ray_seg_start(
        int64_t* __restrict__ ray_id,
        int64_t* __restrict__ N_steps_cumsum,
        const int n_rays) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if(0<idx && idx<n_rays) {
    // Mark the first point of every ray segment; a later cumsum turns the marks into ray indices.
    ray_id[N_steps_cumsum[idx-1]] = 1;
  }
}

__global__ void __set_step_id(
        int64_t* __restrict__ step_id,
        int64_t* __restrict__ ray_id,
        int64_t* __restrict__ N_steps_cumsum,
        const int total_len) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if(idx<total_len) {
    const int rid = ray_id[idx];
    // Step index of this point within its own ray segment.
    step_id[idx] = idx - ((rid!=0) ? N_steps_cumsum[rid-1] : 0);
  }
}

template <typename scalar_t>
__global__ void sample_pts_on_rays_cuda_kernel(
        scalar_t* __restrict__ rays_start,
        scalar_t* __restrict__ rays_dir,
        scalar_t* __restrict__ xyz_min,
        scalar_t* __restrict__ xyz_max,
        int64_t* __restrict__ ray_id,
        int64_t* __restrict__ step_id,
        const float stepdist, const int total_len,
        scalar_t* __restrict__ rays_pts,
        bool* __restrict__ mask_outbbox) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if(idx<total_len) {
    const int i_ray = ray_id[idx];
    const int i_step = step_id[idx];

    const int offset_p = idx * 3;
    const int offset_r = i_ray * 3;
    const float dist = stepdist * i_step;
    const float px = rays_start[offset_r  ] + rays_dir[offset_r  ] * dist;
    const float py = rays_start[offset_r+1] + rays_dir[offset_r+1] * dist;
    const float pz = rays_start[offset_r+2] + rays_dir[offset_r+2] * dist;
    rays_pts[offset_p  ] = px;
    rays_pts[offset_p+1] = py;
    rays_pts[offset_p+2] = pz;
    mask_outbbox[idx] = (xyz_min[0]>px) | (xyz_min[1]>py) | (xyz_min[2]>pz) | \
                        (xyz_max[0]<px) | (xyz_max[1]<py) | (xyz_max[2]<pz);
  }
}

std::vector<torch::Tensor> sample_pts_on_rays_cuda(
        torch::Tensor rays_o, torch::Tensor rays_d,
        torch::Tensor xyz_min, torch::Tensor xyz_max,
        const float near, const float far, const float stepdist) {
  const int threads = 256;
  const int n_rays = rays_o.size(0);

  // Compute ray-bbox intersection
  auto t_minmax = infer_t_minmax_cuda(rays_o, rays_d, xyz_min, xyz_max, near, far);
  auto t_min = t_minmax[0];
  auto t_max = t_minmax[1];

  // Compute the number of points required.
  // Assign ray index and step index to each.
  auto N_steps = infer_n_samples_cuda(rays_d, t_min, t_max, stepdist);
  auto N_steps_cumsum = N_steps.cumsum(0);
  const int total_len = N_steps.sum().item<int>();
  auto ray_id = torch::zeros({total_len}, torch::dtype(torch::kInt64).device(torch::kCUDA));
  __set_1_at_ray_seg_start<<<(n_rays+threads-1)/threads, threads>>>(
      ray_id.data_ptr<int64_t>(), N_steps_cumsum.data_ptr<int64_t>(), n_rays);
  ray_id.cumsum_(0);
  auto step_id = torch::empty({total_len}, ray_id.options());
  __set_step_id<<<(total_len+threads-1)/threads, threads>>>(
      step_id.data_ptr<int64_t>(), ray_id.data_ptr<int64_t>(), N_steps_cumsum.data_ptr<int64_t>(), total_len);

  // Compute the global xyz of each point
  auto rays_start_dir = infer_ray_start_dir_cuda(rays_o, rays_d, t_min);
  auto rays_start = rays_start_dir[0];
  auto rays_dir = rays_start_dir[1];

  auto rays_pts = torch::empty({total_len, 3}, torch::dtype(rays_o.dtype()).device(torch::kCUDA));
  auto mask_outbbox = torch::empty({total_len}, torch::dtype(torch::kBool).device(torch::kCUDA));
  AT_DISPATCH_FLOATING_TYPES(rays_o.scalar_type(), "sample_pts_on_rays_cuda", ([&] {
    sample_pts_on_rays_cuda_kernel<scalar_t><<<(total_len+threads-1)/threads, threads>>>(
        rays_start.data_ptr<scalar_t>(),
        rays_dir.data_ptr<scalar_t>(),
        xyz_min.data_ptr<scalar_t>(),
        xyz_max.data_ptr<scalar_t>(),
        ray_id.data_ptr<int64_t>(),
        step_id.data_ptr<int64_t>(),
        stepdist, total_len,
        rays_pts.data_ptr<scalar_t>(),
        mask_outbbox.data_ptr<bool>());
  }));
  return {rays_pts, mask_outbbox, ray_id, step_id, N_steps, t_min, t_max};
}
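/*
   Worked example (illustration only, not part of the original source) of how ray_id and
   step_id are built in sample_pts_on_rays_cuda above:
     N_steps        = [2, 3, 1]                      -> total_len = 6
     N_steps_cumsum = [2, 5, 6]
     ray_id after zeros + __set_1_at_ray_seg_start   = [0, 0, 1, 0, 0, 1]
     ray_id after ray_id.cumsum_(0)                  = [0, 0, 1, 1, 1, 2]
     step_id from __set_step_id                      = [0, 1, 0, 1, 2, 0]
   so a flat point index idx maps back to (ray_id[idx], step_id[idx]) with no per-ray loops.
 */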
template <typename scalar_t>
__global__ void sample_ndc_pts_on_rays_cuda_kernel(
        const scalar_t* __restrict__ rays_o,
        const scalar_t* __restrict__ rays_d,
        const scalar_t* __restrict__ xyz_min,
        const scalar_t* __restrict__ xyz_max,
        const int N_samples, const int n_rays,
        scalar_t* __restrict__ rays_pts,
        bool* __restrict__ mask_outbbox) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if(idx<n_rays*N_samples) {
    const int i_ray = idx / N_samples;
    const int i_step = idx % N_samples;

    const int offset_p = idx * 3;
    const int offset_r = i_ray * 3;
    // In NDC the rays are sampled uniformly in t over [0, 1].
    const float dist = ((float)i_step) / (N_samples-1);
    const float px = rays_o[offset_r  ] + rays_d[offset_r  ] * dist;
    const float py = rays_o[offset_r+1] + rays_d[offset_r+1] * dist;
    const float pz = rays_o[offset_r+2] + rays_d[offset_r+2] * dist;
    rays_pts[offset_p  ] = px;
    rays_pts[offset_p+1] = py;
    rays_pts[offset_p+2] = pz;
    mask_outbbox[idx] = (xyz_min[0]>px) | (xyz_min[1]>py) | (xyz_min[2]>pz) | \
                        (xyz_max[0]<px) | (xyz_max[1]<py) | (xyz_max[2]<pz);
  }
}

std::vector<torch::Tensor> sample_ndc_pts_on_rays_cuda(
        torch::Tensor rays_o, torch::Tensor rays_d,
        torch::Tensor xyz_min, torch::Tensor xyz_max,
        const int N_samples) {
  const int threads = 256;
  const int n_rays = rays_o.size(0);

  auto rays_pts = torch::empty({n_rays, N_samples, 3}, torch::dtype(rays_o.dtype()).device(torch::kCUDA));
  auto mask_outbbox = torch::empty({n_rays, N_samples}, torch::dtype(torch::kBool).device(torch::kCUDA));
  AT_DISPATCH_FLOATING_TYPES(rays_o.scalar_type(), "sample_ndc_pts_on_rays_cuda", ([&] {
    sample_ndc_pts_on_rays_cuda_kernel<scalar_t><<<(n_rays*N_samples+threads-1)/threads, threads>>>(
        rays_o.data_ptr<scalar_t>(),
        rays_d.data_ptr<scalar_t>(),
        xyz_min.data_ptr<scalar_t>(),
        xyz_max.data_ptr<scalar_t>(),
        N_samples, n_rays,
        rays_pts.data_ptr<scalar_t>(),
        mask_outbbox.data_ptr<bool>());
  }));
  return {rays_pts, mask_outbbox};
}

template <typename scalar_t>
__device__ __forceinline__ scalar_t norm3(const scalar_t x, const scalar_t y, const scalar_t z) {
  return sqrt(x*x + y*y + z*z);
}

template <typename scalar_t>
__global__ void sample_bg_pts_on_rays_cuda_kernel(
        const scalar_t* __restrict__ rays_o,
        const scalar_t* __restrict__ rays_d,
        const scalar_t* __restrict__ t_max,
        const float bg_preserve,
        const int N_samples, const int n_rays,
        scalar_t* __restrict__ rays_pts) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if(idx<n_rays*N_samples) {
    const int i_ray = idx / N_samples;
    const int i_step = idx % N_samples;

    const int offset_p = idx * 3;
    const int offset_r = i_ray * 3;
    // NOTE: the original body of this kernel is not recoverable from the source at hand.
    // The schedule below is an assumption: background points are placed beyond t_max with
    // inverse-distance spacing, and bg_preserve controls how far the last sample reaches.
    const float t = t_max[i_ray] / max(1. - (i_step+1.) / N_samples * (1.-bg_preserve), 1e-6);
    rays_pts[offset_p  ] = rays_o[offset_r  ] + rays_d[offset_r  ] * t;
    rays_pts[offset_p+1] = rays_o[offset_r+1] + rays_d[offset_r+1] * t;
    rays_pts[offset_p+2] = rays_o[offset_r+2] + rays_d[offset_r+2] * t;
  }
}

torch::Tensor sample_bg_pts_on_rays_cuda(
        torch::Tensor rays_o, torch::Tensor rays_d, torch::Tensor t_max,
        const float bg_preserve, const int N_samples) {
  // Host wrapper reconstructed from the visible kernel launch; the argument order is assumed.
  const int threads = 256;
  const int n_rays = rays_o.size(0);
  auto rays_pts = torch::empty({n_rays, N_samples, 3}, torch::dtype(rays_o.dtype()).device(torch::kCUDA));
  AT_DISPATCH_FLOATING_TYPES(rays_o.scalar_type(), "sample_bg_pts_on_rays_cuda", ([&] {
    sample_bg_pts_on_rays_cuda_kernel<scalar_t><<<(n_rays*N_samples+threads-1)/threads, threads>>>(
        rays_o.data_ptr<scalar_t>(),
        rays_d.data_ptr<scalar_t>(),
        t_max.data_ptr<scalar_t>(),
        bg_preserve,
        N_samples, n_rays,
        rays_pts.data_ptr<scalar_t>());
  }));
  return rays_pts;
}

/*
   MaskCache lookup to skip known freespace.
 */
static __forceinline__ __device__
bool check_xyz(int i, int j, int k, int sz_i, int sz_j, int sz_k) {
  return (0 <= i) && (i < sz_i) && (0 <= j) && (j < sz_j) && (0 <= k) && (k < sz_k);
}

template <typename scalar_t>
__global__ void maskcache_lookup_cuda_kernel(
        bool* __restrict__ world,
        scalar_t* __restrict__ xyz,
        bool* __restrict__ out,
        scalar_t* __restrict__ xyz2ijk_scale,
        scalar_t* __restrict__ xyz2ijk_shift,
        const int sz_i, const int sz_j, const int sz_k, const int n_pts) {
  const int i_pt = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_pt<n_pts) {
    const int offset = i_pt * 3;
    // Nearest-voxel lookup: map world xyz into grid ijk and read the occupancy mask.
    const int i = round(xyz[offset  ] * xyz2ijk_scale[0] + xyz2ijk_shift[0]);
    const int j = round(xyz[offset+1] * xyz2ijk_scale[1] + xyz2ijk_shift[1]);
    const int k = round(xyz[offset+2] * xyz2ijk_scale[2] + xyz2ijk_shift[2]);
    if(check_xyz(i, j, k, sz_i, sz_j, sz_k)) {
      out[i_pt] = world[i*sz_j*sz_k + j*sz_k + k];
    }
  }
}

torch::Tensor maskcache_lookup_cuda(
        torch::Tensor world,
        torch::Tensor xyz,
        torch::Tensor xyz2ijk_scale,
        torch::Tensor xyz2ijk_shift) {
  const int sz_i = world.size(0);
  const int sz_j = world.size(1);
  const int sz_k = world.size(2);
  const int n_pts = xyz.size(0);
  auto out = torch::zeros({n_pts}, torch::dtype(torch::kBool).device(torch::kCUDA));
  if(n_pts==0) {
    return out;
  }
  const int threads = 256;
  const int blocks = (n_pts + threads - 1) / threads;
  AT_DISPATCH_FLOATING_TYPES(xyz.scalar_type(), "maskcache_lookup_cuda", ([&] {
    maskcache_lookup_cuda_kernel<scalar_t><<<blocks, threads>>>(
        world.data_ptr<bool>(),
        xyz.data_ptr<scalar_t>(),
        out.data_ptr<bool>(),
        xyz2ijk_scale.data_ptr<scalar_t>(),
        xyz2ijk_shift.data_ptr<scalar_t>(),
        sz_i, sz_j, sz_k,
        n_pts);
  }));
  return out;
}

/*
   Ray marching helper function.
 */
template <typename scalar_t>
__global__ void raw2alpha_cuda_kernel(
        scalar_t* __restrict__ density,
        const float shift, const float interval, const int n_pts,
        scalar_t* __restrict__ exp_d,
        scalar_t* __restrict__ alpha) {
  const int i_pt = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_pt<n_pts) {
    const scalar_t e = exp(density[i_pt] + shift); // can be inf
    exp_d[i_pt] = e;
    alpha[i_pt] = 1 - pow(1 + e, -interval);
  }
}

template <typename scalar_t>
__global__ void raw2alpha_nonuni_cuda_kernel(
        scalar_t* __restrict__ density,
        const float shift, scalar_t* __restrict__ interval, const int n_pts,
        scalar_t* __restrict__ exp_d,
        scalar_t* __restrict__ alpha) {
  const int i_pt = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_pt<n_pts) {
    const scalar_t e = exp(density[i_pt] + shift); // can be inf
    exp_d[i_pt] = e;
    alpha[i_pt] = 1 - pow(1 + e, -interval[i_pt]);
  }
}

std::vector<torch::Tensor> raw2alpha_cuda(torch::Tensor density, const float shift, const float interval) {
  const int n_pts = density.size(0);
  auto exp_d = torch::empty_like(density);
  auto alpha = torch::empty_like(density);
  if(n_pts==0) {
    return {exp_d, alpha};
  }

  const int threads = 256;
  const int blocks = (n_pts + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(density.scalar_type(), "raw2alpha_cuda", ([&] {
    raw2alpha_cuda_kernel<scalar_t><<<blocks, threads>>>(
        density.data_ptr<scalar_t>(),
        shift, interval, n_pts,
        exp_d.data_ptr<scalar_t>(),
        alpha.data_ptr<scalar_t>());
  }));

  return {exp_d, alpha};
}

std::vector<torch::Tensor> raw2alpha_nonuni_cuda(torch::Tensor density, const float shift, torch::Tensor interval) {
  const int n_pts = density.size(0);
  auto exp_d = torch::empty_like(density);
  auto alpha = torch::empty_like(density);
  if(n_pts==0) {
    return {exp_d, alpha};
  }

  const int threads = 256;
  const int blocks = (n_pts + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(density.scalar_type(), "raw2alpha_nonuni_cuda", ([&] {
    raw2alpha_nonuni_cuda_kernel<scalar_t><<<blocks, threads>>>(
        density.data_ptr<scalar_t>(),
        shift, interval.data_ptr<scalar_t>(), n_pts,
        exp_d.data_ptr<scalar_t>(),
        alpha.data_ptr<scalar_t>());
  }));

  return {exp_d, alpha};
}
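/*
   Derivation note (added for clarity, not in the original source): with
     e     = exp(density + shift)
     alpha = 1 - (1 + e)^(-interval)
   the gradient with respect to the raw density is
     d(alpha)/d(density) = interval * e * (1 + e)^(-interval - 1),
   which is exactly what the backward kernels below compute, with e clamped to 1e10
   to avoid inf * 0 when the density saturates.
 */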
template <typename scalar_t>
__global__ void raw2alpha_backward_cuda_kernel(
        scalar_t* __restrict__ exp_d,
        scalar_t* __restrict__ grad_back,
        const float interval, const int n_pts,
        scalar_t* __restrict__ grad) {
  const int i_pt = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_pt<n_pts) {
    grad[i_pt] = min(exp_d[i_pt], 1e10) * pow(1+exp_d[i_pt], -interval-1) * interval * grad_back[i_pt];
  }
}

template <typename scalar_t>
__global__ void raw2alpha_nonuni_backward_cuda_kernel(
        scalar_t* __restrict__ exp_d,
        scalar_t* __restrict__ grad_back,
        scalar_t* __restrict__ interval, const int n_pts,
        scalar_t* __restrict__ grad) {
  const int i_pt = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_pt<n_pts) {
    grad[i_pt] = min(exp_d[i_pt], 1e10) * pow(1+exp_d[i_pt], -interval[i_pt]-1) * interval[i_pt] * grad_back[i_pt];
  }
}

torch::Tensor raw2alpha_backward_cuda(torch::Tensor exp_d, torch::Tensor grad_back, const float interval) {
  const int n_pts = exp_d.size(0);
  auto grad = torch::empty_like(exp_d);
  if(n_pts==0) {
    return grad;
  }

  const int threads = 256;
  const int blocks = (n_pts + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(exp_d.scalar_type(), "raw2alpha_backward_cuda", ([&] {
    raw2alpha_backward_cuda_kernel<scalar_t><<<blocks, threads>>>(
        exp_d.data_ptr<scalar_t>(),
        grad_back.data_ptr<scalar_t>(),
        interval, n_pts,
        grad.data_ptr<scalar_t>());
  }));

  return grad;
}

torch::Tensor raw2alpha_nonuni_backward_cuda(torch::Tensor exp_d, torch::Tensor grad_back, torch::Tensor interval) {
  const int n_pts = exp_d.size(0);
  auto grad = torch::empty_like(exp_d);
  if(n_pts==0) {
    return grad;
  }

  const int threads = 256;
  const int blocks = (n_pts + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(exp_d.scalar_type(), "raw2alpha_nonuni_backward_cuda", ([&] {
    raw2alpha_nonuni_backward_cuda_kernel<scalar_t><<<blocks, threads>>>(
        exp_d.data_ptr<scalar_t>(),
        grad_back.data_ptr<scalar_t>(),
        interval.data_ptr<scalar_t>(), n_pts,
        grad.data_ptr<scalar_t>());
  }));

  return grad;
}

template <typename scalar_t>
__global__ void alpha2weight_cuda_kernel(
        scalar_t* __restrict__ alpha,
        const int n_rays,
        scalar_t* __restrict__ weight,
        scalar_t* __restrict__ T,
        scalar_t* __restrict__ alphainv_last,
        int64_t* __restrict__ i_start,
        int64_t* __restrict__ i_end) {
  const int i_ray = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_ray<n_rays) {
    const int i_s = i_start[i_ray];
    const int i_e_max = i_end[i_ray];

    // Cumulative product of (1 - alpha) along the ray, with early termination
    // once the transmittance becomes negligible.
    float T_cum = 1.;
    int i;
    for(i=i_s; i<i_e_max; ++i) {
      T[i] = T_cum;
      weight[i] = T_cum * alpha[i];
      T_cum *= (1. - alpha[i]);
      if(T_cum<1e-3) {
        i+=1;
        break;
      }
    }
    i_end[i_ray] = i;
    alphainv_last[i_ray] = T_cum;
  }
}

__global__ void __set_i_for_segment_start_end(
        int64_t* __restrict__ ray_id,
        const int n_pts,
        int64_t* __restrict__ i_start,
        int64_t* __restrict__ i_end) {
  const int index = blockIdx.x * blockDim.x + threadIdx.x;
  if(0<index && index<n_pts && ray_id[index]!=ray_id[index-1]) {
    i_start[ray_id[index]] = index;
    i_end[ray_id[index-1]] = index;
  }
}

std::vector<torch::Tensor> alpha2weight_cuda(torch::Tensor alpha, torch::Tensor ray_id, const int n_rays) {
  const int n_pts = alpha.size(0);
  const int threads = 256;

  auto weight = torch::zeros_like(alpha);
  auto T = torch::ones_like(alpha);
  auto alphainv_last = torch::ones({n_rays}, alpha.options());
  auto i_start = torch::zeros({n_rays}, torch::dtype(torch::kInt64).device(torch::kCUDA));
  auto i_end = torch::zeros({n_rays}, torch::dtype(torch::kInt64).device(torch::kCUDA));
  if(n_pts==0) {
    return {weight, T, alphainv_last, i_start, i_end};
  }

  __set_i_for_segment_start_end<<<(n_pts+threads-1)/threads, threads>>>(
      ray_id.data_ptr<int64_t>(), n_pts, i_start.data_ptr<int64_t>(), i_end.data_ptr<int64_t>());
  i_end[ray_id[n_pts-1]] = n_pts;

  const int blocks = (n_rays + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(alpha.scalar_type(), "alpha2weight_cuda", ([&] {
    alpha2weight_cuda_kernel<scalar_t><<<blocks, threads>>>(
        alpha.data_ptr<scalar_t>(),
        n_rays,
        weight.data_ptr<scalar_t>(),
        T.data_ptr<scalar_t>(),
        alphainv_last.data_ptr<scalar_t>(),
        i_start.data_ptr<int64_t>(),
        i_end.data_ptr<int64_t>());
  }));

  return {weight, T, alphainv_last, i_start, i_end};
}

template <typename scalar_t>
__global__ void alpha2weight_backward_cuda_kernel(
        scalar_t* __restrict__ alpha,
        scalar_t* __restrict__ weight,
        scalar_t* __restrict__ T,
        scalar_t* __restrict__ alphainv_last,
        int64_t* __restrict__ i_start,
        int64_t* __restrict__ i_end,
        const int n_rays,
        scalar_t* __restrict__ grad_weights,
        scalar_t* __restrict__ grad_last,
        scalar_t* __restrict__ grad) {
  const int i_ray = blockIdx.x * blockDim.x + threadIdx.x;
  if(i_ray<n_rays) {
    const int i_s = i_start[i_ray];
    const int i_e = i_end[i_ray];

    // Backward pass over the ray segment in reverse order, accumulating the
    // gradient contribution of all later weights.
    float back_cum = grad_last[i_ray] * alphainv_last[i_ray];
    for(int i=i_e-1; i>=i_s; --i) {
      grad[i] = grad_weights[i] * T[i] - back_cum / (1-alpha[i] + 1e-10);
      back_cum += grad_weights[i] * weight[i];
    }
  }
}

torch::Tensor alpha2weight_backward_cuda(
        torch::Tensor alpha, torch::Tensor weight, torch::Tensor T, torch::Tensor alphainv_last,
        torch::Tensor i_start, torch::Tensor i_end, const int n_rays,
        torch::Tensor grad_weights, torch::Tensor grad_last) {
  auto grad = torch::zeros_like(alpha);
  if(n_rays==0) {
    return grad;
  }

  const int threads = 256;
  const int blocks = (n_rays + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(alpha.scalar_type(), "alpha2weight_backward_cuda", ([&] {
    alpha2weight_backward_cuda_kernel<scalar_t><<<blocks, threads>>>(
        alpha.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        T.data_ptr<scalar_t>(),
        alphainv_last.data_ptr<scalar_t>(),
        i_start.data_ptr<int64_t>(),
        i_end.data_ptr<int64_t>(),
        n_rays,
        grad_weights.data_ptr<scalar_t>(),
        grad_last.data_ptr<scalar_t>(),
        grad.data_ptr<scalar_t>());
  }));

  return grad;
}
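/*
   Sketch (not part of the original file): how these entry points would typically be exposed
   to Python with pybind11. In the actual project the bindings may live in a separate .cpp,
   so the block is left commented out; the exposed names are assumptions.
 */
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//   m.def("infer_t_minmax", &infer_t_minmax_cuda, "ray-bbox t_min/t_max (CUDA)");
//   m.def("infer_n_samples", &infer_n_samples_cuda, "number of samples per ray (CUDA)");
//   m.def("sample_pts_on_rays", &sample_pts_on_rays_cuda, "sample points on rays (CUDA)");
//   m.def("sample_ndc_pts_on_rays", &sample_ndc_pts_on_rays_cuda, "sample points on NDC rays (CUDA)");
//   m.def("sample_bg_pts_on_rays", &sample_bg_pts_on_rays_cuda, "sample background points (CUDA)");
//   m.def("maskcache_lookup", &maskcache_lookup_cuda, "mask cache lookup (CUDA)");
//   m.def("raw2alpha", &raw2alpha_cuda, "raw density to alpha (CUDA)");
//   m.def("raw2alpha_backward", &raw2alpha_backward_cuda, "raw2alpha backward (CUDA)");
//   m.def("raw2alpha_nonuni", &raw2alpha_nonuni_cuda, "raw2alpha, non-uniform intervals (CUDA)");
//   m.def("raw2alpha_nonuni_backward", &raw2alpha_nonuni_backward_cuda, "raw2alpha backward, non-uniform intervals (CUDA)");
//   m.def("alpha2weight", &alpha2weight_cuda, "alpha to accumulated weights (CUDA)");
//   m.def("alpha2weight_backward", &alpha2weight_backward_cuda, "alpha2weight backward (CUDA)");
// }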