// Declares this extension's operator schemas and binds each one to its CUDA
// kernel.
//
// TORCH_LIBRARY_EXPAND is defined outside this file section; presumably it
// forwards to TORCH_LIBRARY after macro-expanding TORCH_EXTENSION_NAME.
// Confirm this against its header.
//
// Schema notation used below (PyTorch function-schema syntax):
//   Tensor!  - the argument is mutated in place (the op writes its result there)
//   Tensor?  - optional tensor argument (may be None)
//   -> ()    - the op returns nothing; results are delivered via the Tensor! args
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops)
{
    // Schemas. Every op writes into a caller-provided output tensor.
    ops.def("matmul_persistent(Tensor a, Tensor b, Tensor! c, Tensor? bias) -> ()");
    ops.def("log_softmax(Tensor input, Tensor! output) -> ()");
    ops.def("mean_dim(Tensor input, Tensor! output, int dim) -> ()");

    // Kernels, registered for the CUDA dispatch key only. No CPU or other
    // backend implementation is registered here.
    ops.impl("matmul_persistent", torch::kCUDA, &matmul_persistent_cuda);
    ops.impl("log_softmax", torch::kCUDA, &log_softmax_cuda);
    ops.impl("mean_dim", torch::kCUDA, &mean_dim_cuda);
}
// REGISTER_EXTENSION is defined outside this file section. Presumably it emits
// the Python module-init entry point for TORCH_EXTENSION_NAME, so the compiled
// library can be imported from Python (which also runs the operator
// registrations). Confirm this against the macro's definition.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)