File size: 576 Bytes
e6010fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#include <torch/extension.h>
#include "torch_binding.h"

// Declares this extension's operator schemas and binds each one to a CUDA
// kernel. Every op is defined and then immediately bound to its
// implementation.
//
// None of the ops return a value (`-> ()`). Each writes into the argument
// marked `Tensor!`, an alias annotation telling the dispatcher that the op
// mutates that tensor in place.
//
// Only a CUDA kernel is registered, so calls with tensors on other devices
// will find no kernel. The `*_cuda` functions are presumably declared in
// torch_binding.h (NOTE(review): confirm).
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops)
{
    // matmul_persistent: reads a and b and writes into c.
    // `Tensor? bias` is optional and may be passed as None from Python.
    // NOTE(review): "persistent" presumably refers to a persistent-kernel
    // matmul strategy; confirm in matmul_persistent_cuda.
    ops.def("matmul_persistent(Tensor a, Tensor b, Tensor! c, Tensor? bias) -> ()");
    ops.impl("matmul_persistent", torch::kCUDA, &matmul_persistent_cuda);

    // log_softmax: reads input and writes into output.
    // NOTE(review): the reduction axis is not part of the schema, so it is
    // presumably fixed inside the kernel (e.g. the last dim); confirm.
    ops.def("log_softmax(Tensor input, Tensor! output) -> ()");
    ops.impl("log_softmax", torch::kCUDA, &log_softmax_cuda);

    // mean_dim: reads input and writes into output.
    // The integer `dim` is presumably the dimension averaged over (by name).
    ops.def("mean_dim(Tensor input, Tensor! output, int dim) -> ()");
    ops.impl("mean_dim", torch::kCUDA, &mean_dim_cuda);
}

// Module-level registration for the extension named TORCH_EXTENSION_NAME.
// NOTE(review): REGISTER_EXTENSION is defined outside this file (presumably
// reached via torch_binding.h). In kernel-builder/vLLM-style projects it
// typically emits the Python module init entry point so that importing the
// shared library runs the TORCH_LIBRARY registrations above. Confirm against
// the macro definition.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)