// Copyright (c) Facebook, Inc. and its affiliates.All Rights Reserved // Please refer to original code: https://github.com/NVlabs/instant-ngp // and the pytorch wrapper from https://github.com/ashawkey/torch-ngp #include #include #include #include "hash_sample.h" #include "utils.h" void hash_encode_forward(at::Tensor inputs, at::Tensor embeddings, at::Tensor offsets, at::Tensor outputs, const float beta, const uint32_t B, const uint32_t N, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t H, const bool calc_grad_inputs, at::Tensor dy_dx, const uint32_t mode) { CHECK_CUDA(inputs); CHECK_CUDA(embeddings); CHECK_CUDA(offsets); CHECK_CUDA(outputs); CHECK_CUDA(dy_dx); CHECK_CONTIGUOUS(inputs); CHECK_CONTIGUOUS(embeddings); CHECK_CONTIGUOUS(offsets); CHECK_CONTIGUOUS(outputs); CHECK_CONTIGUOUS(dy_dx); CHECK_IS_FLOAT(inputs); CHECK_IS_FLOAT(embeddings); CHECK_IS_INT(offsets); CHECK_IS_FLOAT(outputs); CHECK_IS_FLOAT(dy_dx); hash_encode_forward_cuda(inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), outputs.data_ptr(), beta, B, N, D, C, L, H, calc_grad_inputs, dy_dx.data_ptr(), mode); } void hash_encode_backward(at::Tensor grad, at::Tensor inputs, at::Tensor embeddings, at::Tensor offsets, at::Tensor grad_embeddings, const float beta, const uint32_t B, const uint32_t N, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t H, const bool calc_grad_inputs, at::Tensor dy_dx, at::Tensor grad_inputs, const uint32_t mode) { CHECK_CUDA(grad); CHECK_CUDA(inputs); CHECK_CUDA(embeddings); CHECK_CUDA(offsets); CHECK_CUDA(grad_embeddings); CHECK_CUDA(dy_dx); CHECK_CUDA(grad_inputs); CHECK_CONTIGUOUS(grad); CHECK_CONTIGUOUS(inputs); CHECK_CONTIGUOUS(embeddings); CHECK_CONTIGUOUS(offsets); CHECK_CONTIGUOUS(grad_embeddings); CHECK_CONTIGUOUS(dy_dx); CHECK_CONTIGUOUS(grad_inputs); CHECK_IS_FLOAT(grad); CHECK_IS_FLOAT(inputs); CHECK_IS_FLOAT(embeddings); CHECK_IS_INT(offsets); CHECK_IS_FLOAT(grad_embeddings); CHECK_IS_FLOAT(dy_dx); CHECK_IS_FLOAT(grad_inputs); hash_encode_backward_cuda(grad.data_ptr(), inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), grad_embeddings.data_ptr(), beta, B, N, D, C, L, H, calc_grad_inputs, dy_dx.data_ptr(), grad_inputs.data_ptr(), mode); } //------------------------------------------------------------------------ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("hash_encode_forward", &hash_encode_forward, "hash encode forward (CUDA)"); m.def("hash_encode_backward", &hash_encode_backward, "hash encode backward (CUDA)"); }