// Host-side launch wrappers: each function allocates zero-initialised output
// tensors, launches one CUDA kernel declared in "cuda_kernel.h", and returns
// the outputs.
//
// NOTE(review): the angle-bracket header names, the template arguments of
// std::vector / data_ptr, and the <<<...>>> launch configurations were missing
// from the text this file was recovered from. They are restored here from the
// names and locals each function uses (float values, int indices, and the
// `blocks` / `threads` / `shared_mem` variables) -- confirm against
// cuda_kernel.h and cuda_launch.h.
#include <torch/extension.h>
#include <ATen/ATen.h>
#include "cuda_launch.h"
#include "cuda_kernel.h"
#include <vector>

//////////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////////

namespace {

// True when a grid has no blocks. Launching such a grid is rejected by the
// CUDA runtime as an invalid configuration, so callers skip the launch.
inline bool grid_is_empty(const dim3& grid) {
  return grid.x == 0 || grid.y == 0 || grid.z == 0;
}

// Raises a c10::Error if the CUDA runtime has a pending error after a launch.
// cudaGetLastError() also clears the error, so it cannot leak into a later,
// unrelated CUDA call.
inline void check_launch(const char* kernel_name) {
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, kernel_name, " launch failed: ", cudaGetErrorString(err));
}

}  // namespace

// Launches index_max_cuda_kernel with one 256-thread block per batch entry and
// A_num_block * 32 floats of dynamic shared memory.
//
// Returns {max_vals, max_vals_scatter} with shapes
// [batch_size, A_num_block * 32] and [batch_size, 32, num_block]; both start
// as zeros and are returned untouched when batch_size is 0.
//
// Raises c10::Error if the launch is rejected (e.g. the shared-memory request
// exceeds the device limit).
// NOTE(review): inputs are presumably contiguous CUDA tensors -- not checked here.
std::vector<at::Tensor> index_max_kernel(
  at::Tensor index_vals,       // [batch_size, 32, num_block]
  at::Tensor indices,          // [batch_size, num_block]
  int A_num_block,
  int B_num_block
) {
  const int batch_size = indices.size(0);
  const int num_block = indices.size(1);

  at::Tensor max_vals = at::zeros({batch_size, A_num_block * 32}, index_vals.options());
  at::Tensor max_vals_scatter = at::zeros({batch_size, 32, num_block}, index_vals.options());

  const dim3 threads(256);
  const dim3 blocks(batch_size);
  // One float per (A block, lane) pair is requested as dynamic shared memory.
  const int shared_mem = A_num_block * 32 * sizeof(float);

  if (!grid_is_empty(blocks)) {
    index_max_cuda_kernel<<<blocks, threads, shared_mem>>>(
      index_vals.data_ptr<float>(),
      indices.data_ptr<int>(),
      max_vals.data_ptr<float>(),
      max_vals_scatter.data_ptr<float>(),
      batch_size,
      A_num_block,
      B_num_block,
      num_block
    );
    check_launch("index_max_cuda_kernel");
  }

  return {max_vals, max_vals_scatter};
}

// Launches mm_to_sparse_cuda_kernel on a (num_block / 4, batch_size) grid of
// 64x4-thread blocks and returns sparse_C of shape
// [batch_size, num_block, 32, 32].
//
// The grid x-dimension uses integer division, so num_block is expected to be a
// multiple of 4 (presumably threadIdx.y selects one of 4 index blocks per CUDA
// block -- confirm in cuda_kernel.h). When the grid would be empty
// (num_block < 4 or batch_size == 0) the zero tensor is returned as-is.
//
// Raises c10::Error if the launch is rejected.
at::Tensor mm_to_sparse_kernel(
  at::Tensor dense_A,   // [batch_size, A_num_block, dim, 32]
  at::Tensor dense_B,   // [batch_size, B_num_block, dim, 32]
  at::Tensor indices    // [batch_size, num_block]
) {
  const int batch_size = dense_A.size(0);
  const int A_num_block = dense_A.size(1);
  const int B_num_block = dense_B.size(1);
  const int dim = dense_A.size(2);
  const int num_block = indices.size(1);

  at::Tensor sparse_C = at::zeros({batch_size, num_block, 32, 32}, dense_A.options());

  const dim3 threads(64, 4);
  const dim3 blocks(num_block / 4, batch_size);

  if (!grid_is_empty(blocks)) {
    mm_to_sparse_cuda_kernel<<<blocks, threads>>>(
      dense_A.data_ptr<float>(),
      dense_B.data_ptr<float>(),
      indices.data_ptr<int>(),
      sparse_C.data_ptr<float>(),
      batch_size,
      A_num_block,
      B_num_block,
      dim,
      num_block
    );
    check_launch("mm_to_sparse_cuda_kernel");
  }

  return sparse_C;
}

// Launches sparse_dense_mm_cuda_kernel on a (num_block / 2, batch_size) grid of
// 128x2-thread blocks and returns dense_C of shape
// [batch_size, A_num_block, dim, 32].
//
// The grid x-dimension uses integer division, so num_block is expected to be a
// multiple of 2. When the grid would be empty (num_block < 2 or
// batch_size == 0) the zero tensor is returned as-is.
//
// Raises c10::Error if the launch is rejected.
at::Tensor sparse_dense_mm_kernel(
  at::Tensor sparse_A,  // [batch_size, num_block, 32, 32]
  at::Tensor indices,   // [batch_size, num_block]
  at::Tensor dense_B,   // [batch_size, B_num_block, dim, 32]
  int A_num_block
) {
  const int batch_size = sparse_A.size(0);
  const int num_block = sparse_A.size(1);
  const int B_num_block = dense_B.size(1);
  const int dim = dense_B.size(2);

  at::Tensor dense_C = at::zeros({batch_size, A_num_block, dim, 32}, dense_B.options());

  const dim3 threads(128, 2);
  const dim3 blocks(num_block / 2, batch_size);

  if (!grid_is_empty(blocks)) {
    sparse_dense_mm_cuda_kernel<<<blocks, threads>>>(
      sparse_A.data_ptr<float>(),
      indices.data_ptr<int>(),
      dense_B.data_ptr<float>(),
      dense_C.data_ptr<float>(),
      batch_size,
      A_num_block,
      B_num_block,
      dim,
      num_block
    );
    check_launch("sparse_dense_mm_cuda_kernel");
  }

  return dense_C;
}

// Launches reduce_sum_cuda_kernel on a (num_block / 4, batch_size) grid of
// 32x4-thread blocks and returns dense_C of shape [batch_size, A_num_block, 32].
//
// The grid x-dimension uses integer division, so num_block is expected to be a
// multiple of 4. When the grid would be empty (num_block < 4 or
// batch_size == 0) the zero tensor is returned as-is.
//
// Raises c10::Error if the launch is rejected.
at::Tensor reduce_sum_kernel(
  at::Tensor sparse_A,  // [batch_size, num_block, 32, 32]
  at::Tensor indices,   // [batch_size, num_block]
  int A_num_block,
  int B_num_block
) {
  const int batch_size = sparse_A.size(0);
  const int num_block = sparse_A.size(1);

  at::Tensor dense_C = at::zeros({batch_size, A_num_block, 32}, sparse_A.options());

  const dim3 threads(32, 4);
  const dim3 blocks(num_block / 4, batch_size);

  if (!grid_is_empty(blocks)) {
    reduce_sum_cuda_kernel<<<blocks, threads>>>(
      sparse_A.data_ptr<float>(),
      indices.data_ptr<int>(),
      dense_C.data_ptr<float>(),
      batch_size,
      A_num_block,
      B_num_block,
      num_block
    );
    check_launch("reduce_sum_cuda_kernel");
  }

  return dense_C;
}

// Launches scatter_cuda_kernel on a (num_block / 4, batch_size) grid of
// 32x4-thread blocks and returns sparse_C of shape
// [batch_size, num_block, 32, 32].
//
// The grid x-dimension uses integer division, so num_block is expected to be a
// multiple of 4. When the grid would be empty (num_block < 4 or
// batch_size == 0) the zero tensor is returned as-is.
//
// Raises c10::Error if the launch is rejected.
at::Tensor scatter_kernel(
  at::Tensor dense_A,   // [batch_size, A_num_block, 32]
  at::Tensor indices,   // [batch_size, num_block]
  int B_num_block
) {
  const int batch_size = dense_A.size(0);
  const int A_num_block = dense_A.size(1);
  const int num_block = indices.size(1);

  at::Tensor sparse_C = at::zeros({batch_size, num_block, 32, 32}, dense_A.options());

  const dim3 threads(32, 4);
  const dim3 blocks(num_block / 4, batch_size);

  if (!grid_is_empty(blocks)) {
    scatter_cuda_kernel<<<blocks, threads>>>(
      dense_A.data_ptr<float>(),
      indices.data_ptr<int>(),
      sparse_C.data_ptr<float>(),
      batch_size,
      A_num_block,
      B_num_block,
      num_block
    );
    check_launch("scatter_cuda_kernel");
  }

  return sparse_C;
}