Spaces:
Running
Running
inline constexpr __device__ float SQRT3() { return 1.7320508075688772f; } | |
inline constexpr __device__ float RSQRT3() { return 0.5773502691896258f; } | |
inline constexpr __device__ float PI() { return 3.141592653589793f; } | |
inline constexpr __device__ float RPI() { return 0.3183098861837907f; } | |
template <typename T> | |
inline __host__ __device__ T div_round_up(T val, T divisor) { | |
return (val + divisor - 1) / divisor; | |
} | |
inline __host__ __device__ float signf(const float x) { | |
return copysignf(1.0, x); | |
} | |
inline __host__ __device__ float clamp(const float x, const float min, const float max) { | |
return fminf(max, fmaxf(min, x)); | |
} | |
inline __host__ __device__ void swapf(float& a, float& b) { | |
float c = a; a = b; b = c; | |
} | |
inline __device__ int mip_from_pos(const float x, const float y, const float z, const float max_cascade) { | |
const float mx = fmaxf(fabsf(x), fmaxf(fabs(y), fabs(z))); | |
int exponent; | |
frexpf(mx, &exponent); // [0, 0.5) --> -1, [0.5, 1) --> 0, [1, 2) --> 1, [2, 4) --> 2, ... | |
return fminf(max_cascade - 1, fmaxf(0, exponent)); | |
} | |
inline __device__ int mip_from_dt(const float dt, const float H, const float max_cascade) { | |
const float mx = dt * H * 0.5; | |
int exponent; | |
frexpf(mx, &exponent); | |
return fminf(max_cascade - 1, fmaxf(0, exponent)); | |
} | |
inline __host__ __device__ uint32_t __expand_bits(uint32_t v) | |
{ | |
v = (v * 0x00010001u) & 0xFF0000FFu; | |
v = (v * 0x00000101u) & 0x0F00F00Fu; | |
v = (v * 0x00000011u) & 0xC30C30C3u; | |
v = (v * 0x00000005u) & 0x49249249u; | |
return v; | |
} | |
inline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z) | |
{ | |
uint32_t xx = __expand_bits(x); | |
uint32_t yy = __expand_bits(y); | |
uint32_t zz = __expand_bits(z); | |
return xx | (yy << 1) | (zz << 2); | |
} | |
inline __host__ __device__ uint32_t __morton3D_invert(uint32_t x) | |
{ | |
x = x & 0x49249249; | |
x = (x | (x >> 2)) & 0xc30c30c3; | |
x = (x | (x >> 4)) & 0x0f00f00f; | |
x = (x | (x >> 8)) & 0xff0000ff; | |
x = (x | (x >> 16)) & 0x0000ffff; | |
return x; | |
} | |
//////////////////////////////////////////////////// | |
///////////// utils ///////////// | |
//////////////////////////////////////////////////// | |
// rays_o/d: [N, 3] | |
// nears/fars: [N] | |
// scalar_t should always be float in use. | |
template <typename scalar_t> | |
__global__ void kernel_near_far_from_aabb( | |
const scalar_t * __restrict__ rays_o, | |
const scalar_t * __restrict__ rays_d, | |
const scalar_t * __restrict__ aabb, | |
const uint32_t N, | |
const float min_near, | |
scalar_t * nears, scalar_t * fars | |
) { | |
// parallel per ray | |
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; | |
if (n >= N) return; | |
// locate | |
rays_o += n * 3; | |
rays_d += n * 3; | |
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; | |
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; | |
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; | |
// get near far (assume cube scene) | |
float near = (aabb[0] - ox) * rdx; | |
float far = (aabb[3] - ox) * rdx; | |
if (near > far) swapf(near, far); | |
float near_y = (aabb[1] - oy) * rdy; | |
float far_y = (aabb[4] - oy) * rdy; | |
if (near_y > far_y) swapf(near_y, far_y); | |
if (near > far_y || near_y > far) { | |
nears[n] = fars[n] = std::numeric_limits<scalar_t>::max(); | |
return; | |
} | |
if (near_y > near) near = near_y; | |
if (far_y < far) far = far_y; | |
float near_z = (aabb[2] - oz) * rdz; | |
float far_z = (aabb[5] - oz) * rdz; | |
if (near_z > far_z) swapf(near_z, far_z); | |
if (near > far_z || near_z > far) { | |
nears[n] = fars[n] = std::numeric_limits<scalar_t>::max(); | |
return; | |
} | |
if (near_z > near) near = near_z; | |
if (far_z < far) far = far_z; | |
if (near < min_near) near = min_near; | |
nears[n] = near; | |
fars[n] = far; | |
} | |
void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars) { | |
static constexpr uint32_t N_THREAD = 128; | |
AT_DISPATCH_FLOATING_TYPES_AND_HALF( | |
rays_o.scalar_type(), "near_far_from_aabb", ([&] { | |
kernel_near_far_from_aabb<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), aabb.data_ptr<scalar_t>(), N, min_near, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>()); | |
})); | |
} | |
// rays_o/d: [N, 3] | |
// radius: float | |
// coords: [N, 2] | |
template <typename scalar_t> | |
__global__ void kernel_sph_from_ray( | |
const scalar_t * __restrict__ rays_o, | |
const scalar_t * __restrict__ rays_d, | |
const float radius, | |
const uint32_t N, | |
scalar_t * coords | |
) { | |
// parallel per ray | |
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; | |
if (n >= N) return; | |
// locate | |
rays_o += n * 3; | |
rays_d += n * 3; | |
coords += n * 2; | |
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; | |
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; | |
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; | |
// solve t from || o + td || = radius | |
const float A = dx * dx + dy * dy + dz * dz; | |
const float B = ox * dx + oy * dy + oz * dz; // in fact B / 2 | |
const float C = ox * ox + oy * oy + oz * oz - radius * radius; | |
const float t = (- B + sqrtf(B * B - A * C)) / A; // always use the larger solution (positive) | |
// solve theta, phi (assume y is the up axis) | |
const float x = ox + t * dx, y = oy + t * dy, z = oz + t * dz; | |
const float theta = atan2(sqrtf(x * x + z * z), y); // [0, PI) | |
const float phi = atan2(z, x); // [-PI, PI) | |
// normalize to [-1, 1] | |
coords[0] = 2 * theta * RPI() - 1; | |
coords[1] = phi * RPI(); | |
} | |
void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords) { | |
static constexpr uint32_t N_THREAD = 128; | |
AT_DISPATCH_FLOATING_TYPES_AND_HALF( | |
rays_o.scalar_type(), "sph_from_ray", ([&] { | |
kernel_sph_from_ray<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), radius, N, coords.data_ptr<scalar_t>()); | |
})); | |
} | |
// coords: int32, [N, 3] | |
// indices: int32, [N] | |
__global__ void kernel_morton3D( | |
const int * __restrict__ coords, | |
const uint32_t N, | |
int * indices | |
) { | |
// parallel | |
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; | |
if (n >= N) return; | |
// locate | |
coords += n * 3; | |
indices[n] = __morton3D(coords[0], coords[1], coords[2]); | |
} | |
void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices) { | |
static constexpr uint32_t N_THREAD = 128; | |
kernel_morton3D<<<div_round_up(N, N_THREAD), N_THREAD>>>(coords.data_ptr<int>(), N, indices.data_ptr<int>()); | |
} | |
// indices: int32, [N] | |
// coords: int32, [N, 3] | |
__global__ void kernel_morton3D_invert( | |
const int * __restrict__ indices, | |
const uint32_t N, | |
int * coords | |
) { | |
// parallel | |
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; | |
if (n >= N) return; | |
// locate | |
coords += n * 3; | |
const int ind = indices[n]; | |
coords[0] = __morton3D_invert(ind >> 0); | |
coords[1] = __morton3D_invert(ind >> 1); | |
coords[2] = __morton3D_invert(ind >> 2); | |
} | |
void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords) { | |
static constexpr uint32_t N_THREAD = 128; | |
kernel_morton3D_invert<<<div_round_up(N, N_THREAD), N_THREAD>>>(indices.data_ptr<int>(), N, coords.data_ptr<int>()); | |
} | |
// grid: float, [C, H, H, H] | |
// N: int, C * H * H * H / 8 | |
// density_thresh: float | |
// bitfield: uint8, [N] | |
template <typename scalar_t> | |
__global__ void kernel_packbits( | |
const scalar_t * __restrict__ grid, | |
const uint32_t N, | |
const float density_thresh, | |
uint8_t * bitfield | |
) { | |
// parallel per byte | |
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; | |
if (n >= N) return; | |
// locate | |
grid += n * 8; | |
uint8_t bits = 0; | |
for (uint8_t i = 0; i < 8; i++) { | |
bits |= (grid[i] > density_thresh) ? ((uint8_t)1 << i) : 0; | |
} | |
bitfield[n] = bits; | |
} | |
void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield) { | |
static constexpr uint32_t N_THREAD = 128; | |
AT_DISPATCH_FLOATING_TYPES_AND_HALF( | |
grid.scalar_type(), "packbits", ([&] { | |
kernel_packbits<<<div_round_up(N, N_THREAD), N_THREAD>>>(grid.data_ptr<scalar_t>(), N, density_thresh, bitfield.data_ptr<uint8_t>()); | |
})); | |
} | |
// grid: float, [C, H, H, H] | |
__global__ void kernel_morton3D_dilation( | |
const float * __restrict__ grid, | |
const uint32_t C, | |
const uint32_t H, | |
float * __restrict__ grid_dilation | |
) { | |
// parallel per byte | |
const uint32_t H3 = H * H * H; | |
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; | |
if (n >= C * H3) return; | |
// locate | |
const uint32_t c = n / H3; | |
const uint32_t ind = n - c * H3; | |
const uint32_t x = __morton3D_invert(ind >> 0); | |
const uint32_t y = __morton3D_invert(ind >> 1); | |
const uint32_t z = __morton3D_invert(ind >> 2); | |
// manual max pool | |
float res = grid[n]; | |
if (x + 1 < H) res = fmaxf(res, grid[c * H3 + __morton3D(x + 1, y, z)]); | |
if (x > 0) res = fmaxf(res, grid[c * H3 + __morton3D(x - 1, y, z)]); | |
if (y + 1 < H) res = fmaxf(res, grid[c * H3 + __morton3D(x, y + 1, z)]); | |
if (y > 0) res = fmaxf(res, grid[c * H3 + __morton3D(x, y - 1, z)]); | |
if (z + 1 < H) res = fmaxf(res, grid[c * H3 + __morton3D(x, y, z + 1)]); | |
if (z > 0) res = fmaxf(res, grid[c * H3 + __morton3D(x, y, z - 1)]); | |
// write | |
grid_dilation[n] = res; | |
} | |
void morton3D_dilation(const at::Tensor grid, const uint32_t C, const uint32_t H, at::Tensor grid_dilation) { | |
static constexpr uint32_t N_THREAD = 128; | |
kernel_morton3D_dilation<<<div_round_up(C * H * H * H, N_THREAD), N_THREAD>>>(grid.data_ptr<float>(), C, H, grid_dilation.data_ptr<float>()); | |
} | |
//////////////////////////////////////////////////// | |
///////////// training ///////////// | |
//////////////////////////////////////////////////// | |
// rays_o/d: [N, 3] | |
// grid: [CHHH / 8] | |
// xyzs, dirs, deltas: [M, 3], [M, 3], [M, 2] | |
// dirs: [M, 3] | |
// rays: [N, 3], idx, offset, num_steps | |
template <typename scalar_t> | |
__global__ void kernel_march_rays_train( | |
const scalar_t * __restrict__ rays_o, | |
const scalar_t * __restrict__ rays_d, | |
const uint8_t * __restrict__ grid, | |
const float bound, | |
const float dt_gamma, const uint32_t max_steps, | |
const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, | |
const scalar_t* __restrict__ nears, | |
const scalar_t* __restrict__ fars, | |
scalar_t * xyzs, scalar_t * dirs, scalar_t * deltas, | |
int * rays, | |
int * counter, | |
const scalar_t* __restrict__ noises | |
) { | |
// parallel per ray | |
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; | |
if (n >= N) return; | |
// locate | |
rays_o += n * 3; | |
rays_d += n * 3; | |
// ray marching | |
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; | |
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; | |
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; | |
const float rH = 1 / (float)H; | |
const float H3 = H * H * H; | |
const float near = nears[n]; | |
const float far = fars[n]; | |
const float noise = noises[n]; | |
const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H; | |
const float dt_min = fminf(dt_max, 2 * SQRT3() / max_steps); | |
float t0 = near; | |
// perturb | |
t0 += clamp(t0 * dt_gamma, dt_min, dt_max) * noise; | |
// first pass: estimation of num_steps | |
float t = t0; | |
uint32_t num_steps = 0; | |
//if (t < far) printf("valid ray %d t=%f near=%f far=%f \n", n, t, near, far); | |
while (t < far && num_steps < max_steps) { | |
// current point | |
const float x = clamp(ox + t * dx, -bound, bound); | |
const float y = clamp(oy + t * dy, -bound, bound); | |
const float z = clamp(oz + t * dz, -bound, bound); | |
const float dt = clamp(t * dt_gamma, dt_min, dt_max); | |
// get mip level | |
const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1] | |
const float mip_bound = fminf(scalbnf(1.0f, level), bound); | |
const float mip_rbound = 1 / mip_bound; | |
// convert to nearest grid position | |
const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); | |
const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); | |
const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); | |
const uint32_t index = level * H3 + __morton3D(nx, ny, nz); | |
const bool occ = grid[index / 8] & (1 << (index % 8)); | |
// if occpuied, advance a small step, and write to output | |
//if (n == 0) printf("t=%f density=%f vs thresh=%f step=%d\n", t, density, density_thresh, num_steps); | |
if (occ) { | |
num_steps++; | |
t += dt; | |
// else, skip a large step (basically skip a voxel grid) | |
} else { | |
// calc distance to next voxel | |
const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx; | |
const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy; | |
const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz; | |
const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz))); | |
// step until next voxel | |
do { | |
t += clamp(t * dt_gamma, dt_min, dt_max); | |
} while (t < tt); | |
} | |
} | |
//printf("[n=%d] num_steps=%d, near=%f, far=%f, dt=%f, max_steps=%f\n", n, num_steps, near, far, dt_min, (far - near) / dt_min); | |
// second pass: really locate and write points & dirs | |
uint32_t point_index = atomicAdd(counter, num_steps); | |
uint32_t ray_index = atomicAdd(counter + 1, 1); | |
//printf("[n=%d] num_steps=%d, point_index=%d, ray_index=%d\n", n, num_steps, point_index, ray_index); | |
// write rays | |
rays[ray_index * 3] = n; | |
rays[ray_index * 3 + 1] = point_index; | |
rays[ray_index * 3 + 2] = num_steps; | |
if (num_steps == 0) return; | |
if (point_index + num_steps > M) return; | |
xyzs += point_index * 3; | |
dirs += point_index * 3; | |
deltas += point_index * 2; | |
t = t0; | |
uint32_t step = 0; | |
while (t < far && step < num_steps) { | |
// current point | |
const float x = clamp(ox + t * dx, -bound, bound); | |
const float y = clamp(oy + t * dy, -bound, bound); | |
const float z = clamp(oz + t * dz, -bound, bound); | |
const float dt = clamp(t * dt_gamma, dt_min, dt_max); | |
// get mip level | |
const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1] | |
const float mip_bound = fminf(scalbnf(1.0f, level), bound); | |
const float mip_rbound = 1 / mip_bound; | |
// convert to nearest grid position | |
const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); | |
const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); | |
const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); | |
// query grid | |
const uint32_t index = level * H3 + __morton3D(nx, ny, nz); | |
const bool occ = grid[index / 8] & (1 << (index % 8)); | |
// if occpuied, advance a small step, and write to output | |
if (occ) { | |
// write step | |
xyzs[0] = x; | |
xyzs[1] = y; | |
xyzs[2] = z; | |
dirs[0] = dx; | |
dirs[1] = dy; | |
dirs[2] = dz; | |
t += dt; | |
deltas[0] = dt; | |
deltas[1] = t; // used to calc depth | |
xyzs += 3; | |
dirs += 3; | |
deltas += 2; | |
step++; | |
// else, skip a large step (basically skip a voxel grid) | |
} else { | |
// calc distance to next voxel | |
const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx; | |
const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy; | |
const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz; | |
const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz))); | |
// step until next voxel | |
do { | |
t += clamp(t * dt_gamma, dt_min, dt_max); | |
} while (t < tt); | |
} | |
} | |
} | |
void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises) { | |
static constexpr uint32_t N_THREAD = 128; | |
AT_DISPATCH_FLOATING_TYPES_AND_HALF( | |
rays_o.scalar_type(), "march_rays_train", ([&] { | |
kernel_march_rays_train<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), grid.data_ptr<uint8_t>(), bound, dt_gamma, max_steps, N, C, H, M, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), counter.data_ptr<int>(), noises.data_ptr<scalar_t>()); | |
})); | |
} | |
// grad_xyzs/dirs: [M, 3] | |
// rays: [N, 3] | |
// deltas: [M, 2] | |
// grad_rays_o/d: [N, 3] | |
template <typename scalar_t> | |
__global__ void kernel_march_rays_train_backward( | |
const scalar_t * __restrict__ grad_xyzs, | |
const scalar_t * __restrict__ grad_dirs, | |
const int * __restrict__ rays, | |
const scalar_t * __restrict__ deltas, | |
const uint32_t N, const uint32_t M, | |
scalar_t * grad_rays_o, | |
scalar_t * grad_rays_d | |
) { | |
// parallel per ray | |
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; | |
if (n >= N) return; | |
// locate | |
grad_rays_o += n * 3; | |
grad_rays_d += n * 3; | |
uint32_t index = rays[n * 3]; | |
uint32_t offset = rays[n * 3 + 1]; | |
uint32_t num_steps = rays[n * 3 + 2]; | |
// empty ray, or ray that exceed max step count. | |
if (num_steps == 0 || offset + num_steps > M) return; | |
grad_xyzs += offset * 3; | |
grad_dirs += offset * 3; | |
deltas += offset * 2; | |
// accumulate | |
uint32_t step = 0; | |
while (step < num_steps) { | |
grad_rays_o[0] += grad_xyzs[0]; | |
grad_rays_o[1] += grad_xyzs[1]; | |
grad_rays_o[2] += grad_xyzs[2]; | |
grad_rays_d[0] += grad_xyzs[0] * deltas[1] + grad_dirs[0]; | |
grad_rays_d[1] += grad_xyzs[1] * deltas[1] + grad_dirs[1]; | |
grad_rays_d[2] += grad_xyzs[2] * deltas[1] + grad_dirs[2]; | |
// locate | |
grad_xyzs += 3; | |
grad_dirs += 3; | |
deltas += 2; | |
step++; | |
} | |
} | |
void march_rays_train_backward(const at::Tensor grad_xyzs, const at::Tensor grad_dirs, const at::Tensor rays, const at::Tensor deltas, const uint32_t N, const uint32_t M, at::Tensor grad_rays_o, at::Tensor grad_rays_d) { | |
static constexpr uint32_t N_THREAD = 128; | |
AT_DISPATCH_FLOATING_TYPES_AND_HALF( | |
grad_xyzs.scalar_type(), "march_rays_train_backward", ([&] { | |
kernel_march_rays_train_backward<<<div_round_up(N, N_THREAD), N_THREAD>>>(grad_xyzs.data_ptr<scalar_t>(), grad_dirs.data_ptr<scalar_t>(), rays.data_ptr<int>(), deltas.data_ptr<scalar_t>(), N, M, grad_rays_o.data_ptr<scalar_t>(), grad_rays_d.data_ptr<scalar_t>()); | |
})); | |
} | |
// sigmas: [M] | |
// rgbs: [M, 3] | |
// deltas: [M, 2] | |
// rays: [N, 3], idx, offset, num_steps | |
// weights_sum: [N], final pixel alpha | |
// depth: [N,] | |
// image: [N, 3] | |
template <typename scalar_t> | |
__global__ void kernel_composite_rays_train_forward( | |
const scalar_t * __restrict__ sigmas, | |
const scalar_t * __restrict__ rgbs, | |
const scalar_t * __restrict__ ambient, | |
const scalar_t * __restrict__ deltas, | |
const int * __restrict__ rays, | |
const uint32_t M, const uint32_t N, const float T_thresh, | |
scalar_t * weights_sum, | |
scalar_t * ambient_sum, | |
scalar_t * depth, | |
scalar_t * image | |
) { | |
// parallel per ray | |
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; | |
if (n >= N) return; | |
// locate | |
uint32_t index = rays[n * 3]; | |
uint32_t offset = rays[n * 3 + 1]; | |
uint32_t num_steps = rays[n * 3 + 2]; | |
// empty ray, or ray that exceed max step count. | |
if (num_steps == 0 || offset + num_steps > M) { | |
weights_sum[index] = 0; | |
ambient_sum[index] = 0; | |
depth[index] = 0; | |
image[index * 3] = 0; | |
image[index * 3 + 1] = 0; | |
image[index * 3 + 2] = 0; | |
return; | |
} | |
sigmas += offset; | |
rgbs += offset * 3; | |
ambient += offset; | |
deltas += offset * 2; | |
// accumulate | |
uint32_t step = 0; | |
scalar_t T = 1.0f; | |
scalar_t r = 0, g = 0, b = 0, ws = 0, d = 0, amb = 0; | |
while (step < num_steps) { | |
const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); | |
const scalar_t weight = alpha * T; | |
r += weight * rgbs[0]; | |
g += weight * rgbs[1]; | |
b += weight * rgbs[2]; | |
d += weight * deltas[1]; | |
ws += weight; | |
amb += ambient[0]; | |
T *= 1.0f - alpha; | |
// minimal remained transmittence | |
if (T < T_thresh) break; | |
//printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); | |
// locate | |
sigmas++; | |
rgbs += 3; | |
ambient++; | |
deltas += 2; | |
step++; | |
} | |
//printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); | |
// write | |
weights_sum[index] = ws; // weights_sum | |
ambient_sum[index] = amb; | |
depth[index] = d; | |
image[index * 3] = r; | |
image[index * 3 + 1] = g; | |
image[index * 3 + 2] = b; | |
} | |
void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor ambient_sum, at::Tensor depth, at::Tensor image) { | |
static constexpr uint32_t N_THREAD = 128; | |
AT_DISPATCH_FLOATING_TYPES_AND_HALF( | |
sigmas.scalar_type(), "composite_rays_train_forward", ([&] { | |
kernel_composite_rays_train_forward<<<div_round_up(N, N_THREAD), N_THREAD>>>(sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), ambient.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), M, N, T_thresh, weights_sum.data_ptr<scalar_t>(), ambient_sum.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>()); | |
})); | |
} | |
// grad_weights_sum: [N,] | |
// grad: [N, 3] | |
// sigmas: [M] | |
// rgbs: [M, 3] | |
// deltas: [M, 2] | |
// rays: [N, 3], idx, offset, num_steps | |
// weights_sum: [N,], weights_sum here | |
// image: [N, 3] | |
// grad_sigmas: [M] | |
// grad_rgbs: [M, 3] | |
template <typename scalar_t> | |
__global__ void kernel_composite_rays_train_backward( | |
const scalar_t * __restrict__ grad_weights_sum, | |
const scalar_t * __restrict__ grad_ambient_sum, | |
const scalar_t * __restrict__ grad_image, | |
const scalar_t * __restrict__ sigmas, | |
const scalar_t * __restrict__ rgbs, | |
const scalar_t * __restrict__ ambient, | |
const scalar_t * __restrict__ deltas, | |
const int * __restrict__ rays, | |
const scalar_t * __restrict__ weights_sum, | |
const scalar_t * __restrict__ ambient_sum, | |
const scalar_t * __restrict__ image, | |
const uint32_t M, const uint32_t N, const float T_thresh, | |
scalar_t * grad_sigmas, | |
scalar_t * grad_rgbs, | |
scalar_t * grad_ambient | |
) { | |
// parallel per ray | |
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; | |
if (n >= N) return; | |
// locate | |
uint32_t index = rays[n * 3]; | |
uint32_t offset = rays[n * 3 + 1]; | |
uint32_t num_steps = rays[n * 3 + 2]; | |
if (num_steps == 0 || offset + num_steps > M) return; | |
grad_weights_sum += index; | |
grad_ambient_sum += index; | |
grad_image += index * 3; | |
weights_sum += index; | |
ambient_sum += index; | |
image += index * 3; | |
sigmas += offset; | |
rgbs += offset * 3; | |
ambient += offset; | |
deltas += offset * 2; | |
grad_sigmas += offset; | |
grad_rgbs += offset * 3; | |
grad_ambient += offset; | |
// accumulate | |
uint32_t step = 0; | |
scalar_t T = 1.0f; | |
const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0]; | |
scalar_t r = 0, g = 0, b = 0, ws = 0; | |
while (step < num_steps) { | |
const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); | |
const scalar_t weight = alpha * T; | |
r += weight * rgbs[0]; | |
g += weight * rgbs[1]; | |
b += weight * rgbs[2]; | |
// amb += weight * ambient[0]; | |
ws += weight; | |
T *= 1.0f - alpha; | |
// check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation. | |
// write grad_rgbs | |
grad_rgbs[0] = grad_image[0] * weight; | |
grad_rgbs[1] = grad_image[1] * weight; | |
grad_rgbs[2] = grad_image[2] * weight; | |
// write grad_ambient | |
grad_ambient[0] = grad_ambient_sum[0]; | |
// write grad_sigmas | |
grad_sigmas[0] = deltas[0] * ( | |
grad_image[0] * (T * rgbs[0] - (r_final - r)) + | |
grad_image[1] * (T * rgbs[1] - (g_final - g)) + | |
grad_image[2] * (T * rgbs[2] - (b_final - b)) + | |
// grad_ambient_sum[0] * (T * ambient[0] - (amb_final - amb)) + | |
grad_weights_sum[0] * (1 - ws_final) | |
); | |
//printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r); | |
// minimal remained transmittence | |
if (T < T_thresh) break; | |
// locate | |
sigmas++; | |
rgbs += 3; | |
// ambient++; | |
deltas += 2; | |
grad_sigmas++; | |
grad_rgbs += 3; | |
grad_ambient++; | |
step++; | |
} | |
} | |
void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_ambient_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor ambient_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_ambient) { | |
static constexpr uint32_t N_THREAD = 128; | |
AT_DISPATCH_FLOATING_TYPES_AND_HALF( | |
grad_image.scalar_type(), "composite_rays_train_backward", ([&] { | |
kernel_composite_rays_train_backward<<<div_round_up(N, N_THREAD), N_THREAD>>>(grad_weights_sum.data_ptr<scalar_t>(), grad_ambient_sum.data_ptr<scalar_t>(), grad_image.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), ambient.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), weights_sum.data_ptr<scalar_t>(), ambient_sum.data_ptr<scalar_t>(), image.data_ptr<scalar_t>(), M, N, T_thresh, grad_sigmas.data_ptr<scalar_t>(), grad_rgbs.data_ptr<scalar_t>(), grad_ambient.data_ptr<scalar_t>()); | |
})); | |
} | |
//////////////////////////////////////////////////// | |
///////////// infernce ///////////// | |
//////////////////////////////////////////////////// | |
template <typename scalar_t> | |
__global__ void kernel_march_rays( | |
const uint32_t n_alive, | |
const uint32_t n_step, | |
const int* __restrict__ rays_alive, | |
const scalar_t* __restrict__ rays_t, | |
const scalar_t* __restrict__ rays_o, | |
const scalar_t* __restrict__ rays_d, | |
const float bound, | |
const float dt_gamma, const uint32_t max_steps, | |
const uint32_t C, const uint32_t H, | |
const uint8_t * __restrict__ grid, | |
const scalar_t* __restrict__ nears, | |
const scalar_t* __restrict__ fars, | |
scalar_t* xyzs, scalar_t* dirs, scalar_t* deltas, | |
const scalar_t* __restrict__ noises | |
) { | |
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; | |
if (n >= n_alive) return; | |
const int index = rays_alive[n]; // ray id | |
const float noise = noises[n]; | |
// locate | |
rays_o += index * 3; | |
rays_d += index * 3; | |
xyzs += n * n_step * 3; | |
dirs += n * n_step * 3; | |
deltas += n * n_step * 2; | |
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; | |
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; | |
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; | |
const float rH = 1 / (float)H; | |
const float H3 = H * H * H; | |
float t = rays_t[index]; // current ray's t | |
const float near = nears[index], far = fars[index]; | |
const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H; | |
const float dt_min = fminf(dt_max, 2 * SQRT3() / max_steps); | |
// march for n_step steps, record points | |
uint32_t step = 0; | |
// introduce some randomness | |
t += clamp(t * dt_gamma, dt_min, dt_max) * noise; | |
while (t < far && step < n_step) { | |
// current point | |
const float x = clamp(ox + t * dx, -bound, bound); | |
const float y = clamp(oy + t * dy, -bound, bound); | |
const float z = clamp(oz + t * dz, -bound, bound); | |
const float dt = clamp(t * dt_gamma, dt_min, dt_max); | |
// get mip level | |
const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1] | |
const float mip_bound = fminf(scalbnf(1, level), bound); | |
const float mip_rbound = 1 / mip_bound; | |
// convert to nearest grid position | |
const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); | |
const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); | |
const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); | |
const uint32_t index = level * H3 + __morton3D(nx, ny, nz); | |
const bool occ = grid[index / 8] & (1 << (index % 8)); | |
// if occpuied, advance a small step, and write to output | |
if (occ) { | |
// write step | |
xyzs[0] = x; | |
xyzs[1] = y; | |
xyzs[2] = z; | |
dirs[0] = dx; | |
dirs[1] = dy; | |
dirs[2] = dz; | |
// calc dt | |
t += dt; | |
deltas[0] = dt; | |
deltas[1] = t; // used to calc depth | |
// step | |
xyzs += 3; | |
dirs += 3; | |
deltas += 2; | |
step++; | |
// else, skip a large step (basically skip a voxel grid) | |
} else { | |
// calc distance to next voxel | |
const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx; | |
const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy; | |
const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz; | |
const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz))); | |
// step until next voxel | |
do { | |
t += clamp(t * dt_gamma, dt_min, dt_max); | |
} while (t < tt); | |
} | |
} | |
} | |
void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor near, const at::Tensor far, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises) { | |
static constexpr uint32_t N_THREAD = 128; | |
AT_DISPATCH_FLOATING_TYPES_AND_HALF( | |
rays_o.scalar_type(), "march_rays", ([&] { | |
kernel_march_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), bound, dt_gamma, max_steps, C, H, grid.data_ptr<uint8_t>(), near.data_ptr<scalar_t>(), far.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), noises.data_ptr<scalar_t>()); | |
})); | |
} | |
template <typename scalar_t> | |
__global__ void kernel_composite_rays( | |
const uint32_t n_alive, | |
const uint32_t n_step, | |
const float T_thresh, | |
int* rays_alive, | |
scalar_t* rays_t, | |
const scalar_t* __restrict__ sigmas, | |
const scalar_t* __restrict__ rgbs, | |
const scalar_t* __restrict__ deltas, | |
scalar_t* weights_sum, scalar_t* depth, scalar_t* image | |
) { | |
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; | |
if (n >= n_alive) return; | |
const int index = rays_alive[n]; // ray id | |
// locate | |
sigmas += n * n_step; | |
rgbs += n * n_step * 3; | |
deltas += n * n_step * 2; | |
rays_t += index; | |
weights_sum += index; | |
depth += index; | |
image += index * 3; | |
scalar_t t = rays_t[0]; // current ray's t | |
scalar_t weight_sum = weights_sum[0]; | |
scalar_t d = depth[0]; | |
scalar_t r = image[0]; | |
scalar_t g = image[1]; | |
scalar_t b = image[2]; | |
// accumulate | |
uint32_t step = 0; | |
while (step < n_step) { | |
// ray is terminated if delta == 0 | |
if (deltas[0] == 0) break; | |
const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); | |
/* | |
T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j) | |
w_i = alpha_i * T_i | |
--> | |
T_i = 1 - \sum_{j=0}^{i-1} w_j | |
*/ | |
const scalar_t T = 1 - weight_sum; | |
const scalar_t weight = alpha * T; | |
weight_sum += weight; | |
t = deltas[1]; | |
d += weight * t; | |
r += weight * rgbs[0]; | |
g += weight * rgbs[1]; | |
b += weight * rgbs[2]; | |
//printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); | |
// ray is terminated if T is too small | |
// use a larger bound to further accelerate inference | |
if (T < T_thresh) break; | |
// locate | |
sigmas++; | |
rgbs += 3; | |
deltas += 2; | |
step++; | |
} | |
//printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); | |
// rays_alive = -1 means ray is terminated early. | |
if (step < n_step) { | |
rays_alive[n] = -1; | |
} else { | |
rays_t[0] = t; | |
} | |
weights_sum[0] = weight_sum; // this is the thing I needed! | |
depth[0] = d; | |
image[0] = r; | |
image[1] = g; | |
image[2] = b; | |
} | |
void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights, at::Tensor depth, at::Tensor image) { | |
static constexpr uint32_t N_THREAD = 128; | |
AT_DISPATCH_FLOATING_TYPES_AND_HALF( | |
image.scalar_type(), "composite_rays", ([&] { | |
kernel_composite_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, T_thresh, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), weights.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>()); | |
})); | |
} | |
template <typename scalar_t> | |
__global__ void kernel_composite_rays_ambient( | |
const uint32_t n_alive, | |
const uint32_t n_step, | |
const float T_thresh, | |
int* rays_alive, | |
scalar_t* rays_t, | |
const scalar_t* __restrict__ sigmas, | |
const scalar_t* __restrict__ rgbs, | |
const scalar_t* __restrict__ deltas, | |
const scalar_t* __restrict__ ambients, | |
scalar_t* weights_sum, scalar_t* depth, scalar_t* image, scalar_t* ambient_sum | |
) { | |
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; | |
if (n >= n_alive) return; | |
const int index = rays_alive[n]; // ray id | |
// locate | |
sigmas += n * n_step; | |
rgbs += n * n_step * 3; | |
deltas += n * n_step * 2; | |
ambients += n * n_step; | |
rays_t += index; | |
weights_sum += index; | |
depth += index; | |
image += index * 3; | |
ambient_sum += index; | |
scalar_t t = rays_t[0]; // current ray's t | |
scalar_t weight_sum = weights_sum[0]; | |
scalar_t d = depth[0]; | |
scalar_t r = image[0]; | |
scalar_t g = image[1]; | |
scalar_t b = image[2]; | |
scalar_t a = ambient_sum[0]; | |
// accumulate | |
uint32_t step = 0; | |
while (step < n_step) { | |
// ray is terminated if delta == 0 | |
if (deltas[0] == 0) break; | |
const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); | |
/* | |
T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j) | |
w_i = alpha_i * T_i | |
--> | |
T_i = 1 - \sum_{j=0}^{i-1} w_j | |
*/ | |
const scalar_t T = 1 - weight_sum; | |
const scalar_t weight = alpha * T; | |
weight_sum += weight; | |
t = deltas[1]; | |
d += weight * t; | |
r += weight * rgbs[0]; | |
g += weight * rgbs[1]; | |
b += weight * rgbs[2]; | |
a += ambients[0]; | |
//printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); | |
// ray is terminated if T is too small | |
// use a larger bound to further accelerate inference | |
if (T < T_thresh) break; | |
// locate | |
sigmas++; | |
rgbs += 3; | |
deltas += 2; | |
step++; | |
ambients++; | |
} | |
//printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); | |
// rays_alive = -1 means ray is terminated early. | |
if (step < n_step) { | |
rays_alive[n] = -1; | |
} else { | |
rays_t[0] = t; | |
} | |
weights_sum[0] = weight_sum; // this is the thing I needed! | |
depth[0] = d; | |
image[0] = r; | |
image[1] = g; | |
image[2] = b; | |
ambient_sum[0] = a; | |
} | |
void composite_rays_ambient(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor ambients, at::Tensor weights, at::Tensor depth, at::Tensor image, at::Tensor ambient_sum) { | |
static constexpr uint32_t N_THREAD = 128; | |
AT_DISPATCH_FLOATING_TYPES_AND_HALF( | |
image.scalar_type(), "composite_rays_ambient", ([&] { | |
kernel_composite_rays_ambient<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, T_thresh, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), ambients.data_ptr<scalar_t>(), weights.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>(), ambient_sum.data_ptr<scalar_t>()); | |
})); | |
} | |
// -------------------------------- sigma ambient ----------------------------- | |
// sigmas: [M] | |
// rgbs: [M, 3] | |
// deltas: [M, 2] | |
// rays: [N, 3], idx, offset, num_steps | |
// weights_sum: [N], final pixel alpha | |
// depth: [N,] | |
// image: [N, 3] | |
template <typename scalar_t> | |
__global__ void kernel_composite_rays_train_sigma_forward( | |
const scalar_t * __restrict__ sigmas, | |
const scalar_t * __restrict__ rgbs, | |
const scalar_t * __restrict__ ambient, | |
const scalar_t * __restrict__ deltas, | |
const int * __restrict__ rays, | |
const uint32_t M, const uint32_t N, const float T_thresh, | |
scalar_t * weights_sum, | |
scalar_t * ambient_sum, | |
scalar_t * depth, | |
scalar_t * image | |
) { | |
// parallel per ray | |
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; | |
if (n >= N) return; | |
// locate | |
uint32_t index = rays[n * 3]; | |
uint32_t offset = rays[n * 3 + 1]; | |
uint32_t num_steps = rays[n * 3 + 2]; | |
// empty ray, or ray that exceed max step count. | |
if (num_steps == 0 || offset + num_steps > M) { | |
weights_sum[index] = 0; | |
ambient_sum[index] = 0; | |
depth[index] = 0; | |
image[index * 3] = 0; | |
image[index * 3 + 1] = 0; | |
image[index * 3 + 2] = 0; | |
return; | |
} | |
sigmas += offset; | |
rgbs += offset * 3; | |
ambient += offset; | |
deltas += offset * 2; | |
// accumulate | |
uint32_t step = 0; | |
scalar_t T = 1.0f; | |
scalar_t r = 0, g = 0, b = 0, ws = 0, d = 0, amb = 0; | |
while (step < num_steps) { | |
const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); | |
const scalar_t weight = alpha * T; | |
r += weight * rgbs[0]; | |
g += weight * rgbs[1]; | |
b += weight * rgbs[2]; | |
d += weight * deltas[1]; | |
ws += weight; | |
amb += weight * ambient[0]; | |
T *= 1.0f - alpha; | |
// minimal remained transmittence | |
if (T < T_thresh) break; | |
//printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); | |
// locate | |
sigmas++; | |
rgbs += 3; | |
ambient++; | |
deltas += 2; | |
step++; | |
} | |
//printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); | |
// write | |
weights_sum[index] = ws; // weights_sum | |
ambient_sum[index] = amb; | |
depth[index] = d; | |
image[index * 3] = r; | |
image[index * 3 + 1] = g; | |
image[index * 3 + 2] = b; | |
} | |
void composite_rays_train_sigma_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor ambient_sum, at::Tensor depth, at::Tensor image) { | |
static constexpr uint32_t N_THREAD = 128; | |
AT_DISPATCH_FLOATING_TYPES_AND_HALF( | |
sigmas.scalar_type(), "composite_rays_train_sigma_forward", ([&] { | |
kernel_composite_rays_train_sigma_forward<<<div_round_up(N, N_THREAD), N_THREAD>>>(sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), ambient.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), M, N, T_thresh, weights_sum.data_ptr<scalar_t>(), ambient_sum.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>()); | |
})); | |
} | |
// grad_weights_sum: [N,] | |
// grad: [N, 3] | |
// sigmas: [M] | |
// rgbs: [M, 3] | |
// deltas: [M, 2] | |
// rays: [N, 3], idx, offset, num_steps | |
// weights_sum: [N,], weights_sum here | |
// image: [N, 3] | |
// grad_sigmas: [M] | |
// grad_rgbs: [M, 3] | |
template <typename scalar_t> | |
__global__ void kernel_composite_rays_train_sigma_backward( | |
const scalar_t * __restrict__ grad_weights_sum, | |
const scalar_t * __restrict__ grad_ambient_sum, | |
const scalar_t * __restrict__ grad_image, | |
const scalar_t * __restrict__ sigmas, | |
const scalar_t * __restrict__ rgbs, | |
const scalar_t * __restrict__ ambient, | |
const scalar_t * __restrict__ deltas, | |
const int * __restrict__ rays, | |
const scalar_t * __restrict__ weights_sum, | |
const scalar_t * __restrict__ ambient_sum, | |
const scalar_t * __restrict__ image, | |
const uint32_t M, const uint32_t N, const float T_thresh, | |
scalar_t * grad_sigmas, | |
scalar_t * grad_rgbs, | |
scalar_t * grad_ambient | |
) { | |
// parallel per ray | |
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; | |
if (n >= N) return; | |
// locate | |
uint32_t index = rays[n * 3]; | |
uint32_t offset = rays[n * 3 + 1]; | |
uint32_t num_steps = rays[n * 3 + 2]; | |
if (num_steps == 0 || offset + num_steps > M) return; | |
grad_weights_sum += index; | |
grad_ambient_sum += index; | |
grad_image += index * 3; | |
weights_sum += index; | |
ambient_sum += index; | |
image += index * 3; | |
sigmas += offset; | |
rgbs += offset * 3; | |
ambient += offset; | |
deltas += offset * 2; | |
grad_sigmas += offset; | |
grad_rgbs += offset * 3; | |
grad_ambient += offset; | |
// accumulate | |
uint32_t step = 0; | |
scalar_t T = 1.0f; | |
const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0], amb_final = ambient_sum[0]; | |
scalar_t r = 0, g = 0, b = 0, ws = 0, amb = 0; | |
while (step < num_steps) { | |
const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); | |
const scalar_t weight = alpha * T; | |
r += weight * rgbs[0]; | |
g += weight * rgbs[1]; | |
b += weight * rgbs[2]; | |
amb += weight * ambient[0]; | |
ws += weight; | |
T *= 1.0f - alpha; | |
// check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation. | |
// write grad_rgbs | |
grad_rgbs[0] = grad_image[0] * weight; | |
grad_rgbs[1] = grad_image[1] * weight; | |
grad_rgbs[2] = grad_image[2] * weight; | |
// write grad_ambient | |
grad_ambient[0] = grad_ambient_sum[0] * weight; | |
// write grad_sigmas | |
grad_sigmas[0] = deltas[0] * ( | |
grad_image[0] * (T * rgbs[0] - (r_final - r)) + | |
grad_image[1] * (T * rgbs[1] - (g_final - g)) + | |
grad_image[2] * (T * rgbs[2] - (b_final - b)) + | |
grad_ambient_sum[0] * (T * ambient[0] - (amb_final - amb)) + | |
grad_weights_sum[0] * (1 - ws_final) | |
); | |
//printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r); | |
// minimal remained transmittence | |
if (T < T_thresh) break; | |
// locate | |
sigmas++; | |
rgbs += 3; | |
ambient++; | |
deltas += 2; | |
grad_sigmas++; | |
grad_rgbs += 3; | |
grad_ambient++; | |
step++; | |
} | |
} | |
void composite_rays_train_sigma_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_ambient_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor ambient_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_ambient) { | |
static constexpr uint32_t N_THREAD = 128; | |
AT_DISPATCH_FLOATING_TYPES_AND_HALF( | |
grad_image.scalar_type(), "composite_rays_train_sigma_backward", ([&] { | |
kernel_composite_rays_train_sigma_backward<<<div_round_up(N, N_THREAD), N_THREAD>>>(grad_weights_sum.data_ptr<scalar_t>(), grad_ambient_sum.data_ptr<scalar_t>(), grad_image.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), ambient.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), weights_sum.data_ptr<scalar_t>(), ambient_sum.data_ptr<scalar_t>(), image.data_ptr<scalar_t>(), M, N, T_thresh, grad_sigmas.data_ptr<scalar_t>(), grad_rgbs.data_ptr<scalar_t>(), grad_ambient.data_ptr<scalar_t>()); | |
})); | |
} | |
//////////////////////////////////////////////////// | |
///////////// infernce ///////////// | |
//////////////////////////////////////////////////// | |
template <typename scalar_t> | |
__global__ void kernel_composite_rays_ambient_sigma( | |
const uint32_t n_alive, | |
const uint32_t n_step, | |
const float T_thresh, | |
int* rays_alive, | |
scalar_t* rays_t, | |
const scalar_t* __restrict__ sigmas, | |
const scalar_t* __restrict__ rgbs, | |
const scalar_t* __restrict__ deltas, | |
const scalar_t* __restrict__ ambients, | |
scalar_t* weights_sum, scalar_t* depth, scalar_t* image, scalar_t* ambient_sum | |
) { | |
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; | |
if (n >= n_alive) return; | |
const int index = rays_alive[n]; // ray id | |
// locate | |
sigmas += n * n_step; | |
rgbs += n * n_step * 3; | |
deltas += n * n_step * 2; | |
ambients += n * n_step; | |
rays_t += index; | |
weights_sum += index; | |
depth += index; | |
image += index * 3; | |
ambient_sum += index; | |
scalar_t t = rays_t[0]; // current ray's t | |
scalar_t weight_sum = weights_sum[0]; | |
scalar_t d = depth[0]; | |
scalar_t r = image[0]; | |
scalar_t g = image[1]; | |
scalar_t b = image[2]; | |
scalar_t a = ambient_sum[0]; | |
// accumulate | |
uint32_t step = 0; | |
while (step < n_step) { | |
// ray is terminated if delta == 0 | |
if (deltas[0] == 0) break; | |
const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); | |
/* | |
T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j) | |
w_i = alpha_i * T_i | |
--> | |
T_i = 1 - \sum_{j=0}^{i-1} w_j | |
*/ | |
const scalar_t T = 1 - weight_sum; | |
const scalar_t weight = alpha * T; | |
weight_sum += weight; | |
t = deltas[1]; | |
d += weight * t; | |
r += weight * rgbs[0]; | |
g += weight * rgbs[1]; | |
b += weight * rgbs[2]; | |
a += weight * ambients[0]; | |
//printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); | |
// ray is terminated if T is too small | |
// use a larger bound to further accelerate inference | |
if (T < T_thresh) break; | |
// locate | |
sigmas++; | |
rgbs += 3; | |
deltas += 2; | |
step++; | |
ambients++; | |
} | |
//printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); | |
// rays_alive = -1 means ray is terminated early. | |
if (step < n_step) { | |
rays_alive[n] = -1; | |
} else { | |
rays_t[0] = t; | |
} | |
weights_sum[0] = weight_sum; // this is the thing I needed! | |
depth[0] = d; | |
image[0] = r; | |
image[1] = g; | |
image[2] = b; | |
ambient_sum[0] = a; | |
} | |
void composite_rays_ambient_sigma(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor ambients, at::Tensor weights, at::Tensor depth, at::Tensor image, at::Tensor ambient_sum) { | |
static constexpr uint32_t N_THREAD = 128; | |
AT_DISPATCH_FLOATING_TYPES_AND_HALF( | |
image.scalar_type(), "composite_rays_ambient_sigma", ([&] { | |
kernel_composite_rays_ambient_sigma<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, T_thresh, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), ambients.data_ptr<scalar_t>(), weights.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>(), ambient_sum.data_ptr<scalar_t>()); | |
})); | |
} | |
// -------------------------------- uncertainty ----------------------------- | |
// sigmas: [M] | |
// rgbs: [M, 3] | |
// deltas: [M, 2] | |
// rays: [N, 3], idx, offset, num_steps | |
// weights_sum: [N], final pixel alpha | |
// depth: [N,] | |
// image: [N, 3] | |
template <typename scalar_t> | |
__global__ void kernel_composite_rays_train_uncertainty_forward( | |
const scalar_t * __restrict__ sigmas, | |
const scalar_t * __restrict__ rgbs, | |
const scalar_t * __restrict__ ambient, | |
const scalar_t * __restrict__ uncertainty, | |
const scalar_t * __restrict__ deltas, | |
const int * __restrict__ rays, | |
const uint32_t M, const uint32_t N, const float T_thresh, | |
scalar_t * weights_sum, | |
scalar_t * ambient_sum, | |
scalar_t * uncertainty_sum, | |
scalar_t * depth, | |
scalar_t * image | |
) { | |
// parallel per ray | |
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; | |
if (n >= N) return; | |
// locate | |
uint32_t index = rays[n * 3]; | |
uint32_t offset = rays[n * 3 + 1]; | |
uint32_t num_steps = rays[n * 3 + 2]; | |
// empty ray, or ray that exceed max step count. | |
if (num_steps == 0 || offset + num_steps > M) { | |
weights_sum[index] = 0; | |
ambient_sum[index] = 0; | |
uncertainty_sum[index] = 0; | |
depth[index] = 0; | |
image[index * 3] = 0; | |
image[index * 3 + 1] = 0; | |
image[index * 3 + 2] = 0; | |
return; | |
} | |
sigmas += offset; | |
rgbs += offset * 3; | |
ambient += offset; | |
uncertainty += offset; | |
deltas += offset * 2; | |
// accumulate | |
uint32_t step = 0; | |
scalar_t T = 1.0f; | |
scalar_t r = 0, g = 0, b = 0, ws = 0, d = 0, amb = 0, unc = 0; | |
while (step < num_steps) { | |
const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); | |
const scalar_t weight = alpha * T; | |
r += weight * rgbs[0]; | |
g += weight * rgbs[1]; | |
b += weight * rgbs[2]; | |
d += weight * deltas[1]; | |
ws += weight; | |
amb += ambient[0]; | |
unc += weight * uncertainty[0]; | |
T *= 1.0f - alpha; | |
// minimal remained transmittence | |
if (T < T_thresh) break; | |
//printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); | |
// locate | |
sigmas++; | |
rgbs += 3; | |
ambient++; | |
uncertainty++; | |
deltas += 2; | |
step++; | |
} | |
//printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); | |
// write | |
weights_sum[index] = ws; // weights_sum | |
ambient_sum[index] = amb; | |
uncertainty_sum[index] = unc; | |
depth[index] = d; | |
image[index * 3] = r; | |
image[index * 3 + 1] = g; | |
image[index * 3 + 2] = b; | |
} | |
void composite_rays_train_uncertainty_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor uncertainty, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor ambient_sum, at::Tensor uncertainty_sum, at::Tensor depth, at::Tensor image) { | |
static constexpr uint32_t N_THREAD = 128; | |
AT_DISPATCH_FLOATING_TYPES_AND_HALF( | |
sigmas.scalar_type(), "composite_rays_train_uncertainty_forward", ([&] { | |
kernel_composite_rays_train_uncertainty_forward<<<div_round_up(N, N_THREAD), N_THREAD>>>(sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), ambient.data_ptr<scalar_t>(), uncertainty.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), M, N, T_thresh, weights_sum.data_ptr<scalar_t>(), ambient_sum.data_ptr<scalar_t>(), uncertainty_sum.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>()); | |
})); | |
} | |
// grad_weights_sum: [N,] | |
// grad: [N, 3] | |
// sigmas: [M] | |
// rgbs: [M, 3] | |
// deltas: [M, 2] | |
// rays: [N, 3], idx, offset, num_steps | |
// weights_sum: [N,], weights_sum here | |
// image: [N, 3] | |
// grad_sigmas: [M] | |
// grad_rgbs: [M, 3] | |
template <typename scalar_t> | |
__global__ void kernel_composite_rays_train_uncertainty_backward( | |
const scalar_t * __restrict__ grad_weights_sum, | |
const scalar_t * __restrict__ grad_ambient_sum, | |
const scalar_t * __restrict__ grad_uncertainty_sum, | |
const scalar_t * __restrict__ grad_image, | |
const scalar_t * __restrict__ sigmas, | |
const scalar_t * __restrict__ rgbs, | |
const scalar_t * __restrict__ ambient, | |
const scalar_t * __restrict__ uncertainty, | |
const scalar_t * __restrict__ deltas, | |
const int * __restrict__ rays, | |
const scalar_t * __restrict__ weights_sum, | |
const scalar_t * __restrict__ ambient_sum, | |
const scalar_t * __restrict__ uncertainty_sum, | |
const scalar_t * __restrict__ image, | |
const uint32_t M, const uint32_t N, const float T_thresh, | |
scalar_t * grad_sigmas, | |
scalar_t * grad_rgbs, | |
scalar_t * grad_ambient, | |
scalar_t * grad_uncertainty | |
) { | |
// parallel per ray | |
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; | |
if (n >= N) return; | |
// locate | |
uint32_t index = rays[n * 3]; | |
uint32_t offset = rays[n * 3 + 1]; | |
uint32_t num_steps = rays[n * 3 + 2]; | |
if (num_steps == 0 || offset + num_steps > M) return; | |
grad_weights_sum += index; | |
grad_ambient_sum += index; | |
grad_uncertainty_sum += index; | |
grad_image += index * 3; | |
weights_sum += index; | |
ambient_sum += index; | |
uncertainty_sum += index; | |
image += index * 3; | |
sigmas += offset; | |
rgbs += offset * 3; | |
ambient += offset; | |
uncertainty += offset; | |
deltas += offset * 2; | |
grad_sigmas += offset; | |
grad_rgbs += offset * 3; | |
grad_ambient += offset; | |
grad_uncertainty += offset; | |
// accumulate | |
uint32_t step = 0; | |
scalar_t T = 1.0f; | |
const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0], amb_final = ambient_sum[0], unc_final = uncertainty_sum[0]; | |
scalar_t r = 0, g = 0, b = 0, ws = 0, amb = 0, unc = 0; | |
while (step < num_steps) { | |
const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); | |
const scalar_t weight = alpha * T; | |
r += weight * rgbs[0]; | |
g += weight * rgbs[1]; | |
b += weight * rgbs[2]; | |
// amb += ambient[0]; | |
unc += weight * uncertainty[0]; | |
ws += weight; | |
T *= 1.0f - alpha; | |
// check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation. | |
// write grad_rgbs | |
grad_rgbs[0] = grad_image[0] * weight; | |
grad_rgbs[1] = grad_image[1] * weight; | |
grad_rgbs[2] = grad_image[2] * weight; | |
// write grad_ambient | |
grad_ambient[0] = grad_ambient_sum[0]; | |
// write grad_unc | |
grad_uncertainty[0] = grad_uncertainty_sum[0] * weight; | |
// write grad_sigmas | |
grad_sigmas[0] = deltas[0] * ( | |
grad_image[0] * (T * rgbs[0] - (r_final - r)) + | |
grad_image[1] * (T * rgbs[1] - (g_final - g)) + | |
grad_image[2] * (T * rgbs[2] - (b_final - b)) + | |
// grad_ambient_sum[0] * (T * ambient[0] - (amb_final - amb)) + | |
grad_uncertainty_sum[0] * (T * uncertainty[0] - (unc_final - unc)) + | |
grad_weights_sum[0] * (1 - ws_final) | |
); | |
//printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r); | |
// minimal remained transmittence | |
if (T < T_thresh) break; | |
// locate | |
sigmas++; | |
rgbs += 3; | |
// ambient++; | |
uncertainty++; | |
deltas += 2; | |
grad_sigmas++; | |
grad_rgbs += 3; | |
grad_ambient++; | |
grad_uncertainty++; | |
step++; | |
} | |
} | |
void composite_rays_train_uncertainty_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_ambient_sum, const at::Tensor grad_uncertainty_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor uncertainty, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor ambient_sum, const at::Tensor uncertainty_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_ambient, at::Tensor grad_uncertainty) { | |
static constexpr uint32_t N_THREAD = 128; | |
AT_DISPATCH_FLOATING_TYPES_AND_HALF( | |
grad_image.scalar_type(), "composite_rays_train_uncertainty_backward", ([&] { | |
kernel_composite_rays_train_uncertainty_backward<<<div_round_up(N, N_THREAD), N_THREAD>>>(grad_weights_sum.data_ptr<scalar_t>(), grad_ambient_sum.data_ptr<scalar_t>(), grad_uncertainty_sum.data_ptr<scalar_t>(), grad_image.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), ambient.data_ptr<scalar_t>(), uncertainty.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), weights_sum.data_ptr<scalar_t>(), ambient_sum.data_ptr<scalar_t>(), uncertainty_sum.data_ptr<scalar_t>(), image.data_ptr<scalar_t>(), M, N, T_thresh, grad_sigmas.data_ptr<scalar_t>(), grad_rgbs.data_ptr<scalar_t>(), grad_ambient.data_ptr<scalar_t>(), grad_uncertainty.data_ptr<scalar_t>()); | |
})); | |
} | |
//////////////////////////////////////////////////// | |
///////////// infernce ///////////// | |
//////////////////////////////////////////////////// | |
template <typename scalar_t> | |
__global__ void kernel_composite_rays_uncertainty( | |
const uint32_t n_alive, | |
const uint32_t n_step, | |
const float T_thresh, | |
int* rays_alive, | |
scalar_t* rays_t, | |
const scalar_t* __restrict__ sigmas, | |
const scalar_t* __restrict__ rgbs, | |
const scalar_t* __restrict__ deltas, | |
const scalar_t* __restrict__ ambients, | |
const scalar_t* __restrict__ uncertainties, | |
scalar_t* weights_sum, scalar_t* depth, scalar_t* image, scalar_t* ambient_sum, scalar_t* uncertainty_sum | |
) { | |
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; | |
if (n >= n_alive) return; | |
const int index = rays_alive[n]; // ray id | |
// locate | |
sigmas += n * n_step; | |
rgbs += n * n_step * 3; | |
deltas += n * n_step * 2; | |
ambients += n * n_step; | |
uncertainties += n * n_step; | |
rays_t += index; | |
weights_sum += index; | |
depth += index; | |
image += index * 3; | |
ambient_sum += index; | |
uncertainty_sum += index; | |
scalar_t t = rays_t[0]; // current ray's t | |
scalar_t weight_sum = weights_sum[0]; | |
scalar_t d = depth[0]; | |
scalar_t r = image[0]; | |
scalar_t g = image[1]; | |
scalar_t b = image[2]; | |
scalar_t a = ambient_sum[0]; | |
scalar_t u = uncertainty_sum[0]; | |
// accumulate | |
uint32_t step = 0; | |
while (step < n_step) { | |
// ray is terminated if delta == 0 | |
if (deltas[0] == 0) break; | |
const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); | |
/* | |
T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j) | |
w_i = alpha_i * T_i | |
--> | |
T_i = 1 - \sum_{j=0}^{i-1} w_j | |
*/ | |
const scalar_t T = 1 - weight_sum; | |
const scalar_t weight = alpha * T; | |
weight_sum += weight; | |
t = deltas[1]; | |
d += weight * t; | |
r += weight * rgbs[0]; | |
g += weight * rgbs[1]; | |
b += weight * rgbs[2]; | |
a += ambients[0]; | |
u += weight * uncertainties[0]; | |
//printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); | |
// ray is terminated if T is too small | |
// use a larger bound to further accelerate inference | |
if (T < T_thresh) break; | |
// locate | |
sigmas++; | |
rgbs += 3; | |
deltas += 2; | |
step++; | |
ambients++; | |
uncertainties++; | |
} | |
//printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); | |
// rays_alive = -1 means ray is terminated early. | |
if (step < n_step) { | |
rays_alive[n] = -1; | |
} else { | |
rays_t[0] = t; | |
} | |
weights_sum[0] = weight_sum; // this is the thing I needed! | |
depth[0] = d; | |
image[0] = r; | |
image[1] = g; | |
image[2] = b; | |
ambient_sum[0] = a; | |
uncertainty_sum[0] = u; | |
} | |
void composite_rays_uncertainty(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor ambients, at::Tensor uncertainties, at::Tensor weights, at::Tensor depth, at::Tensor image, at::Tensor ambient_sum, at::Tensor uncertainty_sum) { | |
static constexpr uint32_t N_THREAD = 128; | |
AT_DISPATCH_FLOATING_TYPES_AND_HALF( | |
image.scalar_type(), "composite_rays_uncertainty", ([&] { | |
kernel_composite_rays_uncertainty<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, T_thresh, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), ambients.data_ptr<scalar_t>(), uncertainties.data_ptr<scalar_t>(), weights.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>(), ambient_sum.data_ptr<scalar_t>(), uncertainty_sum.data_ptr<scalar_t>()); | |
})); | |
} | |
// -------------------------------- triplane ----------------------------- | |
// sigmas: [M] | |
// rgbs: [M, 3] | |
// deltas: [M, 2] | |
// rays: [N, 3], idx, offset, num_steps | |
// weights_sum: [N], final pixel alpha | |
// depth: [N,] | |
// image: [N, 3] | |
template <typename scalar_t> | |
__global__ void kernel_composite_rays_train_triplane_forward( | |
const scalar_t * __restrict__ sigmas, | |
const scalar_t * __restrict__ rgbs, | |
const scalar_t * __restrict__ amb_aud, | |
const scalar_t * __restrict__ amb_eye, | |
const scalar_t * __restrict__ uncertainty, | |
const scalar_t * __restrict__ deltas, | |
const int * __restrict__ rays, | |
const uint32_t M, const uint32_t N, const float T_thresh, | |
scalar_t * weights_sum, | |
scalar_t * amb_aud_sum, | |
scalar_t * amb_eye_sum, | |
scalar_t * uncertainty_sum, | |
scalar_t * depth, | |
scalar_t * image | |
) { | |
// parallel per ray | |
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; | |
if (n >= N) return; | |
// locate | |
uint32_t index = rays[n * 3]; | |
uint32_t offset = rays[n * 3 + 1]; | |
uint32_t num_steps = rays[n * 3 + 2]; | |
// empty ray, or ray that exceed max step count. | |
if (num_steps == 0 || offset + num_steps > M) { | |
weights_sum[index] = 0; | |
amb_aud_sum[index] = 0; | |
amb_eye_sum[index] = 0; | |
uncertainty_sum[index] = 0; | |
depth[index] = 0; | |
image[index * 3] = 0; | |
image[index * 3 + 1] = 0; | |
image[index * 3 + 2] = 0; | |
return; | |
} | |
sigmas += offset; | |
rgbs += offset * 3; | |
amb_aud += offset; | |
amb_eye += offset; | |
uncertainty += offset; | |
deltas += offset * 2; | |
// accumulate | |
uint32_t step = 0; | |
scalar_t T = 1.0f; | |
scalar_t r = 0, g = 0, b = 0, ws = 0, d = 0, a_aud = 0, a_eye=0, unc = 0; | |
while (step < num_steps) { | |
const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); | |
const scalar_t weight = alpha * T; | |
r += weight * rgbs[0]; | |
g += weight * rgbs[1]; | |
b += weight * rgbs[2]; | |
d += weight * deltas[1]; | |
ws += weight; | |
a_aud += amb_aud[0]; | |
a_eye += amb_eye[0]; | |
unc += weight * uncertainty[0]; | |
T *= 1.0f - alpha; | |
// minimal remained transmittence | |
if (T < T_thresh) break; | |
//printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); | |
// locate | |
sigmas++; | |
rgbs += 3; | |
amb_aud++; | |
amb_eye++; | |
uncertainty++; | |
deltas += 2; | |
step++; | |
} | |
//printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); | |
// write | |
weights_sum[index] = ws; // weights_sum | |
amb_aud_sum[index] = a_aud; | |
amb_eye_sum[index] = a_eye; | |
uncertainty_sum[index] = unc; | |
depth[index] = d; | |
image[index * 3] = r; | |
image[index * 3 + 1] = g; | |
image[index * 3 + 2] = b; | |
} | |
void composite_rays_train_triplane_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor amb_aud, const at::Tensor amb_eye, const at::Tensor uncertainty, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor amb_aud_sum, at::Tensor amb_eye_sum, at::Tensor uncertainty_sum, at::Tensor depth, at::Tensor image) { | |
static constexpr uint32_t N_THREAD = 128; | |
AT_DISPATCH_FLOATING_TYPES_AND_HALF( | |
sigmas.scalar_type(), "composite_rays_train_triplane_forward", ([&] { | |
kernel_composite_rays_train_triplane_forward<<<div_round_up(N, N_THREAD), N_THREAD>>>(sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), amb_aud.data_ptr<scalar_t>(), amb_eye.data_ptr<scalar_t>(), uncertainty.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), M, N, T_thresh, weights_sum.data_ptr<scalar_t>(), amb_aud_sum.data_ptr<scalar_t>(), amb_eye_sum.data_ptr<scalar_t>(), uncertainty_sum.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>()); | |
})); | |
} | |
// grad_weights_sum: [N,] | |
// grad: [N, 3] | |
// sigmas: [M] | |
// rgbs: [M, 3] | |
// deltas: [M, 2] | |
// rays: [N, 3], idx, offset, num_steps | |
// weights_sum: [N,], weights_sum here | |
// image: [N, 3] | |
// grad_sigmas: [M] | |
// grad_rgbs: [M, 3] | |
template <typename scalar_t> | |
__global__ void kernel_composite_rays_train_triplane_backward( | |
const scalar_t * __restrict__ grad_weights_sum, | |
const scalar_t * __restrict__ grad_amb_aud_sum, | |
const scalar_t * __restrict__ grad_amb_eye_sum, | |
const scalar_t * __restrict__ grad_uncertainty_sum, | |
const scalar_t * __restrict__ grad_image, | |
const scalar_t * __restrict__ sigmas, | |
const scalar_t * __restrict__ rgbs, | |
const scalar_t * __restrict__ amb_aud, | |
const scalar_t * __restrict__ amb_eye, | |
const scalar_t * __restrict__ uncertainty, | |
const scalar_t * __restrict__ deltas, | |
const int * __restrict__ rays, | |
const scalar_t * __restrict__ weights_sum, | |
const scalar_t * __restrict__ amb_aud_sum, | |
const scalar_t * __restrict__ amb_eye_sum, | |
const scalar_t * __restrict__ uncertainty_sum, | |
const scalar_t * __restrict__ image, | |
const uint32_t M, const uint32_t N, const float T_thresh, | |
scalar_t * grad_sigmas, | |
scalar_t * grad_rgbs, | |
scalar_t * grad_amb_aud, | |
scalar_t * grad_amb_eye, | |
scalar_t * grad_uncertainty | |
) { | |
// parallel per ray | |
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; | |
if (n >= N) return; | |
// locate | |
uint32_t index = rays[n * 3]; | |
uint32_t offset = rays[n * 3 + 1]; | |
uint32_t num_steps = rays[n * 3 + 2]; | |
if (num_steps == 0 || offset + num_steps > M) return; | |
grad_weights_sum += index; | |
grad_amb_aud_sum += index; | |
grad_amb_eye_sum += index; | |
grad_uncertainty_sum += index; | |
grad_image += index * 3; | |
weights_sum += index; | |
amb_aud_sum += index; | |
amb_eye_sum += index; | |
uncertainty_sum += index; | |
image += index * 3; | |
sigmas += offset; | |
rgbs += offset * 3; | |
amb_aud += offset; | |
amb_eye += offset; | |
uncertainty += offset; | |
deltas += offset * 2; | |
grad_sigmas += offset; | |
grad_rgbs += offset * 3; | |
grad_amb_aud += offset; | |
grad_amb_eye += offset; | |
grad_uncertainty += offset; | |
// accumulate | |
uint32_t step = 0; | |
scalar_t T = 1.0f; | |
const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0], unc_final = uncertainty_sum[0]; | |
scalar_t r = 0, g = 0, b = 0, ws = 0, amb = 0, unc = 0; | |
while (step < num_steps) { | |
const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); | |
const scalar_t weight = alpha * T; | |
r += weight * rgbs[0]; | |
g += weight * rgbs[1]; | |
b += weight * rgbs[2]; | |
// amb += ambient[0]; | |
unc += weight * uncertainty[0]; | |
ws += weight; | |
T *= 1.0f - alpha; | |
// check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation. | |
// write grad_rgbs | |
grad_rgbs[0] = grad_image[0] * weight; | |
grad_rgbs[1] = grad_image[1] * weight; | |
grad_rgbs[2] = grad_image[2] * weight; | |
// write grad_ambient | |
grad_amb_aud[0] = grad_amb_aud_sum[0]; | |
grad_amb_eye[0] = grad_amb_eye_sum[0]; | |
// write grad_unc | |
grad_uncertainty[0] = grad_uncertainty_sum[0] * weight; | |
// write grad_sigmas | |
grad_sigmas[0] = deltas[0] * ( | |
grad_image[0] * (T * rgbs[0] - (r_final - r)) + | |
grad_image[1] * (T * rgbs[1] - (g_final - g)) + | |
grad_image[2] * (T * rgbs[2] - (b_final - b)) + | |
// grad_ambient_sum[0] * (T * ambient[0] - (amb_final - amb)) + | |
grad_uncertainty_sum[0] * (T * uncertainty[0] - (unc_final - unc)) + | |
grad_weights_sum[0] * (1 - ws_final) | |
); | |
//printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r); | |
// minimal remained transmittence | |
if (T < T_thresh) break; | |
// locate | |
sigmas++; | |
rgbs += 3; | |
// ambient++; | |
uncertainty++; | |
deltas += 2; | |
grad_sigmas++; | |
grad_rgbs += 3; | |
grad_amb_aud++; | |
grad_amb_eye++; | |
grad_uncertainty++; | |
step++; | |
} | |
} | |
void composite_rays_train_triplane_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_amb_aud_sum, const at::Tensor grad_amb_eye_sum, const at::Tensor grad_uncertainty_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor amb_aud, const at::Tensor amb_eye, const at::Tensor uncertainty, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor amb_aud_sum, const at::Tensor amb_eye_sum, const at::Tensor uncertainty_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_amb_aud, at::Tensor grad_amb_eye, at::Tensor grad_uncertainty) { | |
static constexpr uint32_t N_THREAD = 128; | |
AT_DISPATCH_FLOATING_TYPES_AND_HALF( | |
grad_image.scalar_type(), "composite_rays_train_triplane_backward", ([&] { | |
kernel_composite_rays_train_triplane_backward<<<div_round_up(N, N_THREAD), N_THREAD>>>(grad_weights_sum.data_ptr<scalar_t>(), grad_amb_aud_sum.data_ptr<scalar_t>(), grad_amb_eye_sum.data_ptr<scalar_t>(), grad_uncertainty_sum.data_ptr<scalar_t>(), grad_image.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), amb_aud.data_ptr<scalar_t>(), amb_eye.data_ptr<scalar_t>(), uncertainty.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), weights_sum.data_ptr<scalar_t>(), amb_aud_sum.data_ptr<scalar_t>(), amb_eye_sum.data_ptr<scalar_t>(), uncertainty_sum.data_ptr<scalar_t>(), image.data_ptr<scalar_t>(), M, N, T_thresh, grad_sigmas.data_ptr<scalar_t>(), grad_rgbs.data_ptr<scalar_t>(), grad_amb_aud.data_ptr<scalar_t>(), grad_amb_eye.data_ptr<scalar_t>(), grad_uncertainty.data_ptr<scalar_t>()); | |
})); | |
} | |
//////////////////////////////////////////////////// | |
///////////// infernce ///////////// | |
//////////////////////////////////////////////////// | |
template <typename scalar_t> | |
__global__ void kernel_composite_rays_triplane( | |
const uint32_t n_alive, | |
const uint32_t n_step, | |
const float T_thresh, | |
int* rays_alive, | |
scalar_t* rays_t, | |
const scalar_t* __restrict__ sigmas, | |
const scalar_t* __restrict__ rgbs, | |
const scalar_t* __restrict__ deltas, | |
const scalar_t* __restrict__ ambs_aud, | |
const scalar_t* __restrict__ ambs_eye, | |
const scalar_t* __restrict__ uncertainties, | |
scalar_t* weights_sum, scalar_t* depth, scalar_t* image, scalar_t* amb_aud_sum, scalar_t* amb_eye_sum, scalar_t* uncertainty_sum | |
) { | |
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; | |
if (n >= n_alive) return; | |
const int index = rays_alive[n]; // ray id | |
// locate | |
sigmas += n * n_step; | |
rgbs += n * n_step * 3; | |
deltas += n * n_step * 2; | |
ambs_aud += n * n_step; | |
ambs_eye += n * n_step; | |
uncertainties += n * n_step; | |
rays_t += index; | |
weights_sum += index; | |
depth += index; | |
image += index * 3; | |
amb_aud_sum += index; | |
amb_eye_sum += index; | |
uncertainty_sum += index; | |
scalar_t t = rays_t[0]; // current ray's t | |
scalar_t weight_sum = weights_sum[0]; | |
scalar_t d = depth[0]; | |
scalar_t r = image[0]; | |
scalar_t g = image[1]; | |
scalar_t b = image[2]; | |
scalar_t a_aud = amb_aud_sum[0]; | |
scalar_t a_eye = amb_eye_sum[0]; | |
scalar_t u = uncertainty_sum[0]; | |
// accumulate | |
uint32_t step = 0; | |
while (step < n_step) { | |
// ray is terminated if delta == 0 | |
if (deltas[0] == 0) break; | |
const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); | |
/* | |
T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j) | |
w_i = alpha_i * T_i | |
--> | |
T_i = 1 - \sum_{j=0}^{i-1} w_j | |
*/ | |
const scalar_t T = 1 - weight_sum; | |
const scalar_t weight = alpha * T; | |
weight_sum += weight; | |
t = deltas[1]; | |
d += weight * t; | |
r += weight * rgbs[0]; | |
g += weight * rgbs[1]; | |
b += weight * rgbs[2]; | |
a_aud += ambs_aud[0]; | |
a_eye += ambs_eye[0]; | |
u += weight * uncertainties[0]; | |
//printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); | |
// ray is terminated if T is too small | |
// use a larger bound to further accelerate inference | |
if (T < T_thresh) break; | |
// locate | |
sigmas++; | |
rgbs += 3; | |
deltas += 2; | |
step++; | |
ambs_aud++; | |
ambs_eye++; | |
uncertainties++; | |
} | |
//printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); | |
// rays_alive = -1 means ray is terminated early. | |
if (step < n_step) { | |
rays_alive[n] = -1; | |
} else { | |
rays_t[0] = t; | |
} | |
weights_sum[0] = weight_sum; // this is the thing I needed! | |
depth[0] = d; | |
image[0] = r; | |
image[1] = g; | |
image[2] = b; | |
amb_aud_sum[0] = a_aud; | |
amb_eye_sum[0] = a_eye; | |
uncertainty_sum[0] = u; | |
} | |
void composite_rays_triplane(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor ambs_aud, at::Tensor ambs_eye, at::Tensor uncertainties, at::Tensor weights, at::Tensor depth, at::Tensor image, at::Tensor amb_aud_sum, at::Tensor amb_eye_sum, at::Tensor uncertainty_sum) { | |
static constexpr uint32_t N_THREAD = 128; | |
AT_DISPATCH_FLOATING_TYPES_AND_HALF( | |
image.scalar_type(), "composite_rays_triplane", ([&] { | |
kernel_composite_rays_triplane<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, T_thresh, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), ambs_aud.data_ptr<scalar_t>(), ambs_eye.data_ptr<scalar_t>(), uncertainties.data_ptr<scalar_t>(), weights.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>(), amb_aud_sum.data_ptr<scalar_t>(), amb_eye_sum.data_ptr<scalar_t>(), uncertainty_sum.data_ptr<scalar_t>()); | |
})); | |
} | |