Spaces:
Paused
Paused
// CUDA forward declarations

// Implemented in a separate CUDA translation unit (not visible in this file).
// Per the Python binding's docstring it generates a mask for cumulative
// distances. NOTE(review): the exact shape/dtype contract of `dist` and the
// meaning of `thres` (e.g. inclusive vs. exclusive comparison) should be
// confirmed against the .cu implementation.
torch::Tensor cumdist_thres_cuda(torch::Tensor dist, float thres);
// C++ interface

/// @brief Validates the input tensor and dispatches to the CUDA implementation.
///
/// @param dist  Input tensor. It must satisfy CHECK_INPUT, a macro defined
///              outside this view. NOTE(review): in PyTorch extension
///              boilerplate this conventionally asserts CUDA residency and
///              contiguity; confirm against its definition.
/// @param thres Threshold value, forwarded unchanged to cumdist_thres_cuda.
/// @return The tensor produced by cumdist_thres_cuda (described by the Python
///         binding as a mask for cumulative dist).
torch::Tensor cumdist_thres(torch::Tensor dist, float thres) {
    // Reject unsuitable inputs on the host before launching any device work.
    CHECK_INPUT(dist);
    return cumdist_thres_cuda(dist, thres);
}
// Python bindings: registers cumdist_thres in the extension module whose name
// is supplied by the build (TORCH_EXTENSION_NAME). Named arguments allow
// keyword calls from Python; positional calls keep working as before.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("cumdist_thres", &cumdist_thres, "Generate mask for cumulative dist.",
          pybind11::arg("dist"), pybind11::arg("thres"));
}