// Source: Tess-M-34B-2bit / quip-sharp / quiptools / quiptools_wrapper.cpp
// Hugging Face upload by KnutJaegersberg ("Upload 91 files"), commit b3c0032.
// File size as listed: 2.29 kB.
#include <torch/extension.h>
#include <iostream>
#include <cassert>
// Forward declaration only: the definition is not in this file (presumably it
// lives in the accompanying kernel source — confirm). It is exposed to Python
// under the same name by the PYBIND11_MODULE block at the bottom of this file.
//
// The shape annotations use the dimension names k, n and m; they document the
// expected tensor layouts but nothing in this file enforces them.
// Returns void, so Z is presumably the output buffer filled by the kernel —
// TODO confirm against the definition.
// NOTE(review): the _k8 suffix presumably relates to the k dimension the
// kernel is specialized for — verify before choosing between k8/k16/k32.
void lookupmatmul_d4_k8(
torch::Tensor X, // k x n
torch::Tensor YIs, // m x (n/4)
torch::Tensor CB, // 256 x 4
torch::Tensor Z // k x m
);
// Forward declaration only: the definition is not in this file (presumably it
// lives in the accompanying kernel source — confirm). It is exposed to Python
// under the same name by the PYBIND11_MODULE block at the bottom of this file.
//
// Same parameter list and shape annotations as lookupmatmul_d4_k8; only the
// name differs. Returns void, so Z is presumably the output buffer — TODO
// confirm against the definition.
// NOTE(review): the _k16 suffix presumably relates to the k dimension the
// kernel is specialized for — verify.
void lookupmatmul_d4_k16(
torch::Tensor X, // k x n
torch::Tensor YIs, // m x (n/4)
torch::Tensor CB, // 256 x 4
torch::Tensor Z // k x m
);
// Forward declaration only: the definition is not in this file (presumably it
// lives in the accompanying kernel source — confirm). It is exposed to Python
// under the same name by the PYBIND11_MODULE block at the bottom of this file.
//
// Same parameter list and shape annotations as lookupmatmul_d4_k8; only the
// name differs. Returns void, so Z is presumably the output buffer — TODO
// confirm against the definition.
// NOTE(review): the _k32 suffix presumably relates to the k dimension the
// kernel is specialized for — verify.
void lookupmatmul_d4_k32(
torch::Tensor X, // k x n
torch::Tensor YIs, // m x (n/4)
torch::Tensor CB, // 256 x 4
torch::Tensor Z // k x m
);
// Forward declaration only: the definition is not in this file. It is exposed
// to Python under the same name by the PYBIND11_MODULE block at the bottom of
// this file.
//
// Per the shape annotations, YIs has n/4 columns where Y has n, and CB has 4
// columns; presumably each YIs entry selects one CB row that expands to 4
// values of Y — TODO confirm against the definition.
// Returns void and takes Y by value, so the result is presumably written into
// Y's existing storage — confirm.
void decompress_d4(
torch::Tensor YIs, // m x (n/4)
torch::Tensor CB, // 256 x 4
torch::Tensor Y // m x n
);
// Forward declaration only: the definition is not in this file. It is exposed
// to Python under the same name by the PYBIND11_MODULE block at the bottom of
// this file.
//
// Same parameter list and shape annotations as decompress_d4.
// NOTE(review): "origorder" presumably means the output keeps the original
// (un-permuted) element order, in contrast to decompress_d4 — verify against
// the definition before relying on it.
void decompress_d4_origorder(
torch::Tensor YIs, // m x (n/4)
torch::Tensor CB, // 256 x 4
torch::Tensor Y // m x n
);
// Forward declaration only: the definition is not in this file. It is exposed
// to Python under the same name by the PYBIND11_MODULE block at the bottom of
// this file.
//
// Per the shape annotations, YIs has n/8 columns where Y has n, and CB has 8
// columns; CB_even_flips is annotated as a 256-entry tensor whose role is not
// visible here (presumably per-codeword sign/parity information — TODO
// confirm).
// Unlike the d4 decompress functions, Y is taken by non-const reference. The
// declaration must match the out-of-file definition exactly, so do not
// "normalize" this to by-value here without changing the definition too.
void decompress_e8p_origorder(
torch::Tensor YIs, // m x (n/8)
torch::Tensor CB, // 256 x 8
torch::Tensor CB_even_flips, // 256
torch::Tensor &Y // m x n
);
// Forward declaration only: the definition is not in this file. It is exposed
// to Python under the same name by the PYBIND11_MODULE block at the bottom of
// this file.
//
// Unlike the void decompress_* functions, this one returns a torch::Tensor
// instead of taking an output argument. The shape of the returned tensor is
// not stated here; presumably m x n given the m x (n/8) input — TODO confirm.
torch::Tensor decompress_packed_e8p(
torch::Tensor weights_compressed, // m x (n/8)
torch::Tensor codebook_abs // 256 x 8
);
// Forward declaration only: the definition is not in this file. It is exposed
// to Python under the same name by the PYBIND11_MODULE block at the bottom of
// this file.
//
// Takes an input tensor x plus the same (weights_compressed, codebook_abs)
// pair as decompress_packed_e8p and returns a torch::Tensor.
// NOTE(review): no shapes are annotated for this function. Presumably
// weights_compressed is m x (n/8), codebook_abs is 256 x 8, and x is a vector
// multiplied against the decoded weights — confirm against the definition.
torch::Tensor decode_matvec_e8p(
torch::Tensor x,
torch::Tensor weights_compressed,
torch::Tensor codebook_abs
);
// Forward declaration only: the definition is not in this file. It is exposed
// to Python under the same name by the PYBIND11_MODULE block at the bottom of
// this file.
//
// Per the shape annotations, YIs has n/8 columns where Y has n, and CB is a
// 16 x 1 codebook; presumably each YIs element packs eight 4-bit indices into
// CB — TODO confirm against the definition.
// Y is taken by non-const reference and the function returns void, so Y is
// presumably the output. The declaration must match the out-of-file
// definition exactly, so keep the reference as is.
void decompress_hi4b1c_packed(
torch::Tensor YIs, // m x (n/8)
torch::Tensor CB, // 16 x 1
torch::Tensor &Y // m x n
);
// Python bindings for the functions declared above. Every entry registers one
// C++ function under an identical Python name; the third argument is the
// Python-visible docstring, which here is simply the function name again.
// The module name comes from the TORCH_EXTENSION_NAME macro, which this file
// does not define itself.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // lookupmatmul_d4_* family: void functions taking (X, YIs, CB, Z).
    m.def("lookupmatmul_d4_k8",
          &lookupmatmul_d4_k8,
          "lookupmatmul_d4_k8");
    m.def("lookupmatmul_d4_k16",
          &lookupmatmul_d4_k16,
          "lookupmatmul_d4_k16");
    m.def("lookupmatmul_d4_k32",
          &lookupmatmul_d4_k32,
          "lookupmatmul_d4_k32");

    // d4 decompress functions: void, with the Y tensor passed by value.
    m.def("decompress_d4",
          &decompress_d4,
          "decompress_d4");
    m.def("decompress_d4_origorder",
          &decompress_d4_origorder,
          "decompress_d4_origorder");

    // e8p entry points: the first takes Y by reference, the other two return
    // a torch::Tensor.
    m.def("decompress_e8p_origorder",
          &decompress_e8p_origorder,
          "decompress_e8p_origorder");
    m.def("decompress_packed_e8p",
          &decompress_packed_e8p,
          "decompress_packed_e8p");
    m.def("decode_matvec_e8p",
          &decode_matvec_e8p,
          "decode_matvec_e8p");

    // hi4b1c decompress: void, with the Y tensor passed by reference.
    m.def("decompress_hi4b1c_packed",
          &decompress_hi4b1c_packed,
          "decompress_hi4b1c_packed");
}