|
|
#include <torch/library.h> |
|
|
|
|
|
#include "registration.h" |
|
|
#include "torch_binding.h" |
|
|
|
|
|
// Pointer-based dequantize entry point. It has C linkage and is only declared
// here; the definition lives in another translation unit.
//
// Every buffer is passed as an integer-encoded address rather than a typed
// pointer. The wrapper in this file supplies the tensors' data_ptr() values
// and x.size(0) as out_dim.
extern "C" {
void higgs_dequantize_2_256_ptr_cuda_portable(
    uint64_t x_ptr,     // address of the `x` buffer
    uint64_t grid_ptr,  // address of the `grid` buffer
    uint64_t out_ptr,   // address of the output buffer
    int64_t out_dim);   // size of the leading dimension of `x`
}
|
|
|
|
|
// Pointer-based quantize entry point, f16 variant (per its name). It has C
// linkage and is only declared here; the definition lives in another
// translation unit.
//
// Every buffer is passed as an integer-encoded address rather than a typed
// pointer. The wrapper in this file supplies the tensors' data_ptr() values
// and x.size(0) as out_dim.
extern "C" {
void higgs_quantize_2_256_ptr_f16_cuda_portable(
    uint64_t x_ptr,           // address of the `x` buffer
    uint64_t grid_ptr,        // address of the `grid` buffer
    uint64_t grid_norms_ptr,  // address of the `grid_norms` buffer
    uint64_t out_ptr,         // address of the output buffer
    int64_t out_dim);         // size of the leading dimension of `x`
}
|
|
|
|
|
// Pointer-based quantize entry point, bf16 variant (per its name). It has C
// linkage and is only declared here; the definition lives in another
// translation unit.
//
// Every buffer is passed as an integer-encoded address rather than a typed
// pointer. The wrapper in this file supplies the tensors' data_ptr() values
// and x.size(0) as out_dim.
extern "C" {
void higgs_quantize_2_256_ptr_bf16_cuda_portable(
    uint64_t x_ptr,           // address of the `x` buffer
    uint64_t grid_ptr,        // address of the `grid` buffer
    uint64_t grid_norms_ptr,  // address of the `grid_norms` buffer
    uint64_t out_ptr,         // address of the output buffer
    int64_t out_dim);         // size of the leading dimension of `x`
}
|
|
|
|
|
// Torch-facing wrapper around higgs_dequantize_2_256_ptr_cuda_portable.
//
// Forwards the data addresses of `x`, `grid` and `out`, plus x.size(0) as
// `out_dim`, to the pointer-based entry point. Results are written through
// `out`; nothing is returned.
//
// Raises (via TORCH_CHECK) if any tensor is not a CUDA tensor or if the
// tensors live on different devices. x.size(0) itself raises for a 0-dim `x`.
//
// NOTE(review): dtype, shape and memory layout are not validated here because
// the callee's exact expectations are not visible from this file -- confirm
// against the kernel before tightening the checks.
void higgs_dequantize_2_256(torch::Tensor x, torch::Tensor grid,
                            torch::Tensor out) {
  // The callee only receives bare addresses, so it cannot detect a host or
  // wrong-device buffer itself. The dispatcher picks the CUDA implementation
  // as soon as any argument is a CUDA tensor, so mixed inputs can reach here.
  TORCH_CHECK(x.is_cuda() && grid.is_cuda() && out.is_cuda(),
              "higgs_dequantize_2_256: x, grid and out must all be CUDA "
              "tensors");
  TORCH_CHECK(grid.device() == x.device() && out.device() == x.device(),
              "higgs_dequantize_2_256: x, grid and out must be on the same "
              "device");

  const int64_t out_dim = x.size(0);

  higgs_dequantize_2_256_ptr_cuda_portable(
      reinterpret_cast<uint64_t>(x.data_ptr()),
      reinterpret_cast<uint64_t>(grid.data_ptr()),
      reinterpret_cast<uint64_t>(out.data_ptr()), out_dim);
}
|
|
|
|
|
// Torch-facing wrapper around higgs_quantize_2_256_ptr_f16_cuda_portable.
//
// Forwards the data addresses of `x`, `grid`, `grid_norms` and `out`, plus
// x.size(0) as `out_dim`, to the pointer-based entry point. Results are
// written through `out`; nothing is returned.
//
// Raises (via TORCH_CHECK) if any tensor is not a CUDA tensor or if the
// tensors live on different devices. x.size(0) itself raises for a 0-dim `x`.
//
// NOTE(review): dtype (presumably half precision for `x`, given the name),
// shape and memory layout are not validated here because the callee's exact
// expectations are not visible from this file -- confirm against the kernel.
void higgs_quantize_2_256_f16(torch::Tensor x, torch::Tensor grid,
                              torch::Tensor grid_norms, torch::Tensor out) {
  // The callee only receives bare addresses, so it cannot detect a host or
  // wrong-device buffer itself. The dispatcher picks the CUDA implementation
  // as soon as any argument is a CUDA tensor, so mixed inputs can reach here.
  TORCH_CHECK(
      x.is_cuda() && grid.is_cuda() && grid_norms.is_cuda() && out.is_cuda(),
      "higgs_quantize_2_256_f16: x, grid, grid_norms and out must all be "
      "CUDA tensors");
  TORCH_CHECK(grid.device() == x.device() &&
                  grid_norms.device() == x.device() &&
                  out.device() == x.device(),
              "higgs_quantize_2_256_f16: x, grid, grid_norms and out must be "
              "on the same device");

  const int64_t out_dim = x.size(0);

  higgs_quantize_2_256_ptr_f16_cuda_portable(
      reinterpret_cast<uint64_t>(x.data_ptr()),
      reinterpret_cast<uint64_t>(grid.data_ptr()),
      reinterpret_cast<uint64_t>(grid_norms.data_ptr()),
      reinterpret_cast<uint64_t>(out.data_ptr()), out_dim);
}
|
|
|
|
|
// Torch-facing wrapper around higgs_quantize_2_256_ptr_bf16_cuda_portable.
//
// Forwards the data addresses of `x`, `grid`, `grid_norms` and `out`, plus
// x.size(0) as `out_dim`, to the pointer-based entry point. Results are
// written through `out`; nothing is returned.
//
// Raises (via TORCH_CHECK) if any tensor is not a CUDA tensor or if the
// tensors live on different devices. x.size(0) itself raises for a 0-dim `x`.
//
// NOTE(review): dtype (presumably bfloat16 for `x`, given the name), shape
// and memory layout are not validated here because the callee's exact
// expectations are not visible from this file -- confirm against the kernel.
void higgs_quantize_2_256_bf16(torch::Tensor x, torch::Tensor grid,
                               torch::Tensor grid_norms, torch::Tensor out) {
  // The callee only receives bare addresses, so it cannot detect a host or
  // wrong-device buffer itself. The dispatcher picks the CUDA implementation
  // as soon as any argument is a CUDA tensor, so mixed inputs can reach here.
  TORCH_CHECK(
      x.is_cuda() && grid.is_cuda() && grid_norms.is_cuda() && out.is_cuda(),
      "higgs_quantize_2_256_bf16: x, grid, grid_norms and out must all be "
      "CUDA tensors");
  TORCH_CHECK(grid.device() == x.device() &&
                  grid_norms.device() == x.device() &&
                  out.device() == x.device(),
              "higgs_quantize_2_256_bf16: x, grid, grid_norms and out must be "
              "on the same device");

  const int64_t out_dim = x.size(0);

  higgs_quantize_2_256_ptr_bf16_cuda_portable(
      reinterpret_cast<uint64_t>(x.data_ptr()),
      reinterpret_cast<uint64_t>(grid.data_ptr()),
      reinterpret_cast<uint64_t>(grid_norms.data_ptr()),
      reinterpret_cast<uint64_t>(out.data_ptr()), out_dim);
}
|
|
|
|
|
// Operator registration: for each op, declare its schema and then bind the
// CUDA implementation defined in this file. Every op returns nothing; in the
// schemas `Tensor! out` is the argument annotated as mutated.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // --- dequantize ---------------------------------------------------------
  ops.def(
      "higgs_dequantize_2_256("
      "Tensor x, Tensor grid, Tensor! out"
      ") -> ()");
  ops.impl("higgs_dequantize_2_256", torch::kCUDA, &higgs_dequantize_2_256);

  // --- quantize, f16 variant ----------------------------------------------
  ops.def(
      "higgs_quantize_2_256_f16("
      "Tensor x, Tensor grid, Tensor grid_norms, Tensor! out"
      ") -> ()");
  ops.impl("higgs_quantize_2_256_f16", torch::kCUDA,
           &higgs_quantize_2_256_f16);

  // --- quantize, bf16 variant ---------------------------------------------
  ops.def(
      "higgs_quantize_2_256_bf16("
      "Tensor x, Tensor grid, Tensor grid_norms, Tensor! out"
      ") -> ()");
  ops.impl("higgs_quantize_2_256_bf16", torch::kCUDA,
           &higgs_quantize_2_256_bf16);
}
|
|
|
|
|
// Expands the REGISTER_EXTENSION macro for this extension's name.
// NOTE(review): the macro is not defined in this file (presumably it comes
// from "registration.h" and emits the module entry point that makes the
// compiled library importable) -- confirm against that header.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
|
|