batch_invariant_kernel / torch-ext /torch_binding.cpp
gagan3012's picture
Upload folder using huggingface_hub
e6010fe verified
raw
history blame contribute delete
576 Bytes
#include <torch/extension.h>
#include "torch_binding.h"
// Registers this extension's custom operators under the TORCH_EXTENSION_NAME
// namespace. Each operator gets a schema (`ops.def`) immediately followed by
// its CUDA kernel binding (`ops.impl`); the *_cuda functions are declared
// outside this file (presumably in "torch_binding.h").
//
// Schema conventions used below:
//   - `Tensor!`  : an argument the op may mutate in place. Every op here
//                  returns `()`, so results are delivered through it.
//   - `Tensor?`  : an optional tensor (may be None on the Python side).
//   - `int`      : passed to the C++ implementation as a 64-bit integer.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops)
{
  // Matrix multiply of `a` and `b` written into `c`, with an optional `bias`.
  // NOTE(review): assumed c = a @ b (+ bias); confirm against the kernel.
  ops.def("matmul_persistent(Tensor a, Tensor b, Tensor! c, Tensor? bias) -> ()");
  ops.impl("matmul_persistent", torch::kCUDA, &matmul_persistent_cuda);

  // Log-softmax of `input`, written into `output`.
  // NOTE(review): the reduction axis is not part of the schema — confirm in the kernel.
  ops.def("log_softmax(Tensor input, Tensor! output) -> ()");
  ops.impl("log_softmax", torch::kCUDA, &log_softmax_cuda);

  // Mean of `input` along dimension `dim`, written into `output`.
  ops.def("mean_dim(Tensor input, Tensor! output, int dim) -> ()");
  ops.impl("mean_dim", torch::kCUDA, &mean_dim_cuda);
}
// Module-registration macro for this extension. REGISTER_EXTENSION is not
// defined in this file; presumably it comes from "torch_binding.h" or a header
// it includes.
// NOTE(review): confirm what it expands to (e.g. a Python module init entry
// point for TORCH_EXTENSION_NAME) against that definition.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)