// PyTorch C++ extension header: provides torch::Tensor and the pybind11
// binding API (PYBIND11_MODULE) used further down in this file.
#include <torch/extension.h>

// NOTE(review): nothing in the visible part of this file uses <iostream> or
// <cassert>; they are kept untouched in case other code depends on them.
#include <iostream>
#include <cassert>

// Forward declaration only: the definition is not in this file (presumably in
// a companion CUDA/C++ source of the extension -- TODO confirm). The parameter
// types must match that definition exactly, otherwise the symbol bound in
// PYBIND11_MODULE will not resolve.
// NOTE(review): the names suggest a codebook-lookup matmul where X is the
// input, YIs holds indices into codebook CB, and Z receives the result --
// confirm against the kernel before relying on this.
void lookupmatmul_d4_k8(
    torch::Tensor X,
    torch::Tensor YIs,
    torch::Tensor CB,
    torch::Tensor Z
);

// Forward declaration only: the definition is not in this file (presumably in
// a companion CUDA/C++ source of the extension -- TODO confirm). The parameter
// types must match that definition exactly, otherwise the symbol bound in
// PYBIND11_MODULE will not resolve.
// NOTE(review): appears to be a sibling variant of lookupmatmul_d4_k8 with the
// same parameter list; what "k16" selects is not visible here -- confirm.
void lookupmatmul_d4_k16(
    torch::Tensor X,
    torch::Tensor YIs,
    torch::Tensor CB,
    torch::Tensor Z
);

// Forward declaration only: the definition is not in this file (presumably in
// a companion CUDA/C++ source of the extension -- TODO confirm). The parameter
// types must match that definition exactly, otherwise the symbol bound in
// PYBIND11_MODULE will not resolve.
// NOTE(review): appears to be a sibling variant of lookupmatmul_d4_k8 with the
// same parameter list; what "k32" selects is not visible here -- confirm.
void lookupmatmul_d4_k32(
    torch::Tensor X,
    torch::Tensor YIs,
    torch::Tensor CB,
    torch::Tensor Z
);

// Forward declaration only: the definition is not in this file (presumably in
// a companion CUDA/C++ source of the extension -- TODO confirm). The parameter
// types must match that definition exactly, otherwise the symbol bound in
// PYBIND11_MODULE will not resolve.
// NOTE(review): returns void, so the result is presumably written into Y from
// the indices YIs and codebook CB -- confirm against the definition.
void decompress_d4(
    torch::Tensor YIs,
    torch::Tensor CB,
    torch::Tensor Y
);

// Forward declaration only: the definition is not in this file (presumably in
// a companion CUDA/C++ source of the extension -- TODO confirm). The parameter
// types must match that definition exactly, otherwise the symbol bound in
// PYBIND11_MODULE will not resolve.
// NOTE(review): same parameter list as decompress_d4; how the "origorder"
// layout differs is not visible here -- confirm against the definition.
void decompress_d4_origorder(
    torch::Tensor YIs,
    torch::Tensor CB,
    torch::Tensor Y
);

// Forward declaration only: the definition is not in this file (presumably in
// a companion CUDA/C++ source of the extension -- TODO confirm). The parameter
// types must match that definition exactly, otherwise the symbol bound in
// PYBIND11_MODULE will not resolve.
// Unlike the d4 variants, Y is taken by non-const reference here while the
// other tensors are taken by value.
// NOTE(review): Y is presumably the output buffer; the roles of CB and
// CB_even_flips are not visible here -- confirm against the definition.
void decompress_e8p_origorder(
    torch::Tensor YIs,
    torch::Tensor CB,
    torch::Tensor CB_even_flips,
    torch::Tensor &Y
);

// Forward declaration only: the definition is not in this file (presumably in
// a companion CUDA/C++ source of the extension -- TODO confirm). The parameter
// types must match that definition exactly, otherwise the symbol bound in
// PYBIND11_MODULE will not resolve.
// Returns a torch::Tensor rather than filling a caller-supplied output.
// NOTE(review): presumably the decompressed weights; shape and dtype are not
// visible here -- confirm against the definition.
torch::Tensor decompress_packed_e8p(
    torch::Tensor weights_compressed,
    torch::Tensor codebook_abs
);

// Forward declaration only: the definition is not in this file (presumably in
// a companion CUDA/C++ source of the extension -- TODO confirm). The parameter
// types must match that definition exactly, otherwise the symbol bound in
// PYBIND11_MODULE will not resolve.
// Returns a torch::Tensor rather than filling a caller-supplied output.
// NOTE(review): the name suggests a fused decode + matrix-vector product of
// the compressed weights with x -- confirm against the definition.
torch::Tensor decode_matvec_e8p(
    torch::Tensor x,
    torch::Tensor weights_compressed,
    torch::Tensor codebook_abs
);

// Forward declaration only: the definition is not in this file (presumably in
// a companion CUDA/C++ source of the extension -- TODO confirm). The parameter
// types must match that definition exactly, otherwise the symbol bound in
// PYBIND11_MODULE will not resolve.
// Y is taken by non-const reference while YIs and CB are taken by value.
// NOTE(review): Y is presumably the output buffer filled from YIs and CB --
// confirm against the definition.
void decompress_hi4b1c_packed(
    torch::Tensor YIs,
    torch::Tensor CB,
    torch::Tensor &Y
);

// Python bindings for the extension module. Each function declared earlier in
// this file is registered under a Python name identical to its C++ symbol; the
// third argument to def() is the Python-side docstring, which here is simply
// the function name again.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Codebook-lookup matmul variants.
    m.def("lookupmatmul_d4_k8", &lookupmatmul_d4_k8, "lookupmatmul_d4_k8");
    m.def("lookupmatmul_d4_k16", &lookupmatmul_d4_k16, "lookupmatmul_d4_k16");
    m.def("lookupmatmul_d4_k32", &lookupmatmul_d4_k32, "lookupmatmul_d4_k32");

    // Decompression entry points.
    m.def("decompress_d4", &decompress_d4, "decompress_d4");
    m.def("decompress_d4_origorder", &decompress_d4_origorder,
          "decompress_d4_origorder");
    m.def("decompress_e8p_origorder", &decompress_e8p_origorder,
          "decompress_e8p_origorder");
    m.def("decompress_packed_e8p", &decompress_packed_e8p,
          "decompress_packed_e8p");
    m.def("decode_matvec_e8p", &decode_matvec_e8p, "decode_matvec_e8p");
    m.def("decompress_hi4b1c_packed", &decompress_hi4b1c_packed,
          "decompress_hi4b1c_packed");
}