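// Headers and tensor-validation macros assumed by the code below (a reconstruction of the
// usual PyTorch CUDA-extension preamble; not necessarily this file's verbatim original):
#include <stdint.h>
#include <cuda.h>
#include <cuda_fp16.h>       // __half, __half2 and the half2 atomicAdd overload
#include <torch/extension.h> // at::Tensor, at::Half, TORCH_CHECK, AT_DISPATCH_FLOATING_TYPES_AND_HALF
#include <algorithm>         // std::min
#include <stdexcept>         // std::runtime_error
#include <type_traits>       // std::is_same

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")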
// just for compatibility of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF...
static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
    // requires CUDA >= 10 and ARCH >= 70
    // this is very slow compared to float or __half2, and never used.
    //return atomicAdd(reinterpret_cast<__half*>(address), val);
}
template <typename T>
static inline __host__ __device__ T div_round_up(T val, T divisor) {
    return (val + divisor - 1) / divisor;
}
template <uint32_t D>
__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {
    static_assert(D <= 7, "fast_hash can only hash up to 7 dimensions.");

    // While 1 is technically not a good prime for hashing (or a prime at all), it helps memory coherence
    // and is sufficient for our use case of obtaining a uniformly colliding index from high-dimensional
    // coordinates.
    constexpr uint32_t primes[7] = { 1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737 };

    uint32_t result = 0;
    for (uint32_t i = 0; i < D; ++i) {
        result ^= pos_grid[i] * primes[i];
    }

    return result;
}
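// Example: for D = 3 the hash of (x, y, z) is simply
//   x * 1  ^  y * 2654435761  ^  z * 805459861   (all in uint32_t arithmetic),
// i.e. the coherence-preserving spatial hash used by instant-ngp; the caller takes the
// result modulo the per-level hashmap size.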
template <uint32_t D, uint32_t C>
__device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) {
    uint32_t stride = 1;
    uint32_t index = 0;

    for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
        index += pos_grid[d] * stride;
        stride *= align_corners ? resolution : (resolution + 1);
    }

    // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97.
    // gridtype: 0 == hash, 1 == tiled
    if (gridtype == 0 && stride > hashmap_size) {
        index = fast_hash<D>(pos_grid);
    }

    return (index % hashmap_size) * C + ch;
}
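// Forward kernel: one thread per (sample b, level). Each thread gathers the 2^D grid
// corners around its scaled input coordinate, blends them with D-linear interpolation
// weights, and writes C interpolated features to outputs[level, b, :]. If dy_dx is
// provided, it also stores the analytic Jacobian of the output w.r.t. the input.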
template <typename scalar_t, uint32_t D, uint32_t C>
__global__ void kernel_grid(
    const float * __restrict__ inputs,
    const scalar_t * __restrict__ grid,
    const int * __restrict__ offsets,
    scalar_t * __restrict__ outputs,
    const uint32_t B, const uint32_t L, const float S, const uint32_t H,
    scalar_t * __restrict__ dy_dx,
    const uint32_t gridtype,
    const bool align_corners
) {
    const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
    if (b >= B) return;

    const uint32_t level = blockIdx.y;

    // locate
    grid += (uint32_t)offsets[level] * C;
    inputs += b * D;
    outputs += level * B * C + b * C;

    // check input range (should be in [0, 1])
    bool flag_oob = false;
    for (uint32_t d = 0; d < D; d++) {
        if (inputs[d] < 0 || inputs[d] > 1) {
            flag_oob = true;
        }
    }

    // if input out of bound, just set output to 0
    if (flag_oob) {
        for (uint32_t ch = 0; ch < C; ch++) {
            outputs[ch] = 0;
        }
        if (dy_dx) {
            dy_dx += b * D * L * C + level * D * C; // B L D C
            for (uint32_t d = 0; d < D; d++) {
                for (uint32_t ch = 0; ch < C; ch++) {
                    dy_dx[d * C + ch] = 0;
                }
            }
        }
        return;
    }

    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
    const float scale = exp2f(level * S) * H - 1.0f;
    const uint32_t resolution = (uint32_t)ceil(scale) + 1;

    // calculate coordinate
    float pos[D];
    uint32_t pos_grid[D];

    for (uint32_t d = 0; d < D; d++) {
        pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
        pos_grid[d] = floorf(pos[d]);
        pos[d] -= (float)pos_grid[d];
    }
    //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);

    // interpolate
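    // Each of the 2^D surrounding corners contributes with weight prod_d w_d, where
    // w_d = (1 - pos[d]) if the corner sits at pos_grid[d] and w_d = pos[d] if it sits
    // at pos_grid[d] + 1; the bit pattern of `idx` selects the corner per dimension.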
    scalar_t results[C] = {0}; // temp results in register

    for (uint32_t idx = 0; idx < (1 << D); idx++) {
        float w = 1;
        uint32_t pos_grid_local[D];

        for (uint32_t d = 0; d < D; d++) {
            if ((idx & (1 << d)) == 0) {
                w *= 1 - pos[d];
                pos_grid_local[d] = pos_grid[d];
            } else {
                w *= pos[d];
                pos_grid_local[d] = pos_grid[d] + 1;
            }
        }

        uint32_t index = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);

        // writing to register (fast)
        for (uint32_t ch = 0; ch < C; ch++) {
            results[ch] += w * grid[index + ch];
        }

        //printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]);
    }

    // writing to global memory (slow)
    for (uint32_t ch = 0; ch < C; ch++) {
        outputs[ch] = results[ch];
    }
    // prepare dy_dx
    // differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9
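    // dy_dx stores, per input dimension gd, the analytic derivative of the interpolated
    // features: the difference of the two corners that differ only in dimension gd,
    // weighted by the interpolation weights of the remaining D-1 dimensions and by
    // `scale` (the chain-rule factor from pos[d] = inputs[d] * scale).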
    if (dy_dx) {
        dy_dx += b * D * L * C + level * D * C; // B L D C

        for (uint32_t gd = 0; gd < D; gd++) {
            scalar_t results_grad[C] = {0};

            for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
                float w = scale;
                uint32_t pos_grid_local[D];

                for (uint32_t nd = 0; nd < D - 1; nd++) {
                    const uint32_t d = (nd >= gd) ? (nd + 1) : nd;

                    if ((idx & (1 << nd)) == 0) {
                        w *= 1 - pos[d];
                        pos_grid_local[d] = pos_grid[d];
                    } else {
                        w *= pos[d];
                        pos_grid_local[d] = pos_grid[d] + 1;
                    }
                }

                pos_grid_local[gd] = pos_grid[gd];
                uint32_t index_left = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
                pos_grid_local[gd] = pos_grid[gd] + 1;
                uint32_t index_right = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);

                for (uint32_t ch = 0; ch < C; ch++) {
                    results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]);
                }
            }

            for (uint32_t ch = 0; ch < C; ch++) {
                dy_dx[gd * C + ch] = results_grad[ch];
            }
        }
    }
}
template <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C>
__global__ void kernel_grid_backward(
    const scalar_t * __restrict__ grad,
    const float * __restrict__ inputs,
    const scalar_t * __restrict__ grid,
    const int * __restrict__ offsets,
    scalar_t * __restrict__ grad_grid,
    const uint32_t B, const uint32_t L, const float S, const uint32_t H,
    const uint32_t gridtype,
    const bool align_corners
) {
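    // Each thread handles N_C consecutive channels of one sample at one level: the flat
    // thread index enumerates B * C / N_C channel groups, from which the sample index b
    // and the group's first channel ch are recovered below.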
    const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
    if (b >= B) return;

    const uint32_t level = blockIdx.y;
    const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;

    // locate
    grad_grid += offsets[level] * C;
    inputs += b * D;
    grad += level * B * C + b * C + ch; // L, B, C

    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
    const float scale = exp2f(level * S) * H - 1.0f;
    const uint32_t resolution = (uint32_t)ceil(scale) + 1;

    // check input range (should be in [0, 1])
    for (uint32_t d = 0; d < D; d++) {
        if (inputs[d] < 0 || inputs[d] > 1) {
            return; // grad_grid is initialized to 0, so we can simply return.
        }
    }
    // calculate coordinate
    float pos[D];
    uint32_t pos_grid[D];

    for (uint32_t d = 0; d < D; d++) {
        pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
        pos_grid[d] = floorf(pos[d]);
        pos[d] -= (float)pos_grid[d];
    }

    scalar_t grad_cur[N_C] = {0}; // fetch to register
    for (uint32_t c = 0; c < N_C; c++) {
        grad_cur[c] = grad[c];
    }
    // interpolate
    for (uint32_t idx = 0; idx < (1 << D); idx++) {
        float w = 1;
        uint32_t pos_grid_local[D];

        for (uint32_t d = 0; d < D; d++) {
            if ((idx & (1 << d)) == 0) {
                w *= 1 - pos[d];
                pos_grid_local[d] = pos_grid[d];
            } else {
                w *= pos[d];
                pos_grid_local[d] = pos_grid[d] + 1;
            }
        }

        uint32_t index = get_grid_index<D, C>(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local);

        // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0
        // TODO: use float which is better than __half, if N_C % 2 != 0
        if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) {
            for (uint32_t c = 0; c < N_C; c += 2) {
                // process two __half at once (by interpreting as a __half2)
                __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
                atomicAdd((__half2*)&grad_grid[index + c], v);
            }
        // float, or __half when N_C % 2 != 0 (which means C == 1)
        } else {
            for (uint32_t c = 0; c < N_C; c++) {
                atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
            }
        }
    }
}
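// Backward pass w.r.t. the inputs via the precomputed dy_dx (chain rule): one thread per
// (sample b, input dimension d) accumulates
//   grad_inputs[b, d] = sum over l, ch of grad[l, b, ch] * dy_dx[b, l, d, ch].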
| template <typename scalar_t, uint32_t D, uint32_t C> | |
| __global__ void kernel_input_backward( | |
| const scalar_t * __restrict__ grad, | |
| const scalar_t * __restrict__ dy_dx, | |
| scalar_t * __restrict__ grad_inputs, | |
| uint32_t B, uint32_t L | |
| ) { | |
| const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; | |
| if (t >= B * D) return; | |
| const uint32_t b = t / D; | |
| const uint32_t d = t - b * D; | |
| dy_dx += b * L * D * C; | |
| scalar_t result = 0; | |
| for (int l = 0; l < L; l++) { | |
| for (int ch = 0; ch < C; ch++) { | |
| result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch]; | |
| } | |
| } | |
| grad_inputs[t] = result; | |
| } | |
template <typename scalar_t, uint32_t D>
void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
    static constexpr uint32_t N_THREAD = 512;
    const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
    switch (C) {
        case 1: kernel_grid<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
        case 2: kernel_grid<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
        case 4: kernel_grid<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
        case 8: kernel_grid<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
        default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
    }
}
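// D and C are dispatched to compile-time template arguments so the kernels can keep
// per-channel accumulators in fixed-size register arrays (e.g. results[C]) and let the
// compiler unroll the fixed-trip-count loops; only the listed values are supported.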
// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], int
// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.)
// H: base resolution
// dy_dx: [B, L * D * C]
template <typename scalar_t>
void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
    switch (D) {
        case 1: kernel_grid_wrapper<scalar_t, 1>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
        case 2: kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
        case 3: kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
        case 4: kernel_grid_wrapper<scalar_t, 4>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
        case 5: kernel_grid_wrapper<scalar_t, 5>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
        default: throw std::runtime_error{"GridEncoding: D must be 1, 2, 3, 4, or 5."};
    }
}
template <typename scalar_t, uint32_t D>
void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
    static constexpr uint32_t N_THREAD = 256;
    const uint32_t N_C = std::min(2u, C); // n_features_per_thread
    const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 };
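    // N_C = min(2, C): for C >= 2 each thread owns a pair of adjacent channels, which lets
    // kernel_grid_backward fold its two atomic additions into a single __half2 atomicAdd
    // when scalar_t is half; for C == 1 the plain per-channel atomicAdd path is used.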
    switch (C) {
        case 1:
            kernel_grid_backward<scalar_t, D, 1, 1><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
            if (dy_dx) kernel_input_backward<scalar_t, D, 1><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
            break;
        case 2:
            kernel_grid_backward<scalar_t, D, 2, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
            if (dy_dx) kernel_input_backward<scalar_t, D, 2><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
            break;
        case 4:
            kernel_grid_backward<scalar_t, D, 4, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
            if (dy_dx) kernel_input_backward<scalar_t, D, 4><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
            break;
        case 8:
            kernel_grid_backward<scalar_t, D, 8, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
            if (dy_dx) kernel_input_backward<scalar_t, D, 8><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
            break;
        default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
    }
}
// grad: [L, B, C], float
// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], int
// grad_embeddings: [sO, C]
// H: base resolution
template <typename scalar_t>
void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
    switch (D) {
        case 1: kernel_grid_backward_wrapper<scalar_t, 1>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
        case 2: kernel_grid_backward_wrapper<scalar_t, 2>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
        case 3: kernel_grid_backward_wrapper<scalar_t, 3>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
        case 4: kernel_grid_backward_wrapper<scalar_t, 4>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
        case 5: kernel_grid_backward_wrapper<scalar_t, 5>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
        default: throw std::runtime_error{"GridEncoding: D must be 1, 2, 3, 4, or 5."};
    }
}
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners) {
    CHECK_CUDA(inputs);
    CHECK_CUDA(embeddings);
    CHECK_CUDA(offsets);
    CHECK_CUDA(outputs);
    // CHECK_CUDA(dy_dx);

    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(embeddings);
    CHECK_CONTIGUOUS(offsets);
    CHECK_CONTIGUOUS(outputs);
    // CHECK_CONTIGUOUS(dy_dx);

    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(embeddings);
    CHECK_IS_INT(offsets);
    CHECK_IS_FLOATING(outputs);
    // CHECK_IS_FLOATING(dy_dx);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        embeddings.scalar_type(), "grid_encode_forward", ([&] {
            grid_encode_forward_cuda<scalar_t>(inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners);
        }));
}
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners) {
    CHECK_CUDA(grad);
    CHECK_CUDA(inputs);
    CHECK_CUDA(embeddings);
    CHECK_CUDA(offsets);
    CHECK_CUDA(grad_embeddings);
    // CHECK_CUDA(dy_dx);
    // CHECK_CUDA(grad_inputs);

    CHECK_CONTIGUOUS(grad);
    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(embeddings);
    CHECK_CONTIGUOUS(offsets);
    CHECK_CONTIGUOUS(grad_embeddings);
    // CHECK_CONTIGUOUS(dy_dx);
    // CHECK_CONTIGUOUS(grad_inputs);

    CHECK_IS_FLOATING(grad);
    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(embeddings);
    CHECK_IS_INT(offsets);
    CHECK_IS_FLOATING(grad_embeddings);
    // CHECK_IS_FLOATING(dy_dx);
    // CHECK_IS_FLOATING(grad_inputs);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        grad.scalar_type(), "grid_encode_backward", ([&] {
            grid_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, grad_inputs.has_value() ? grad_inputs.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners);
        }));
}
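// A typical way to expose these entry points to Python is through a small pybind11
// binding unit compiled into the same extension. A sketch, assuming the two functions
// above are declared in a shared header (the actual binding file is not part of this
// source, so this is illustrative only):
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//       m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
//       m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
//   }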