// File from https://github.com/mlpen/YOSO/blob/main/encoders/backbones/efficient_attentions/yoso/yoso_v1/cuda/fast_lsh_cumulation.cu
#include <torch/extension.h>
#include <ATen/ATen.h>
#include "fast_lsh_cumulation.h"
#include "fast_lsh_cumulation_cuda.h"
#include "common_cuda.h"
#include "common.h"
#include <vector>

//////////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////////

std::vector<at::Tensor> fast_hash_ver1_kernel(
  at::Tensor query_mask,
  at::Tensor query_vector,
  at::Tensor key_mask,
  at::Tensor key_vector,
  int num_hash_f,
  int hash_code_len,
  bool use_cuda
) {

  int batch_size = query_vector.size(0);
  int num_query = query_vector.size(1);
  int num_key = key_vector.size(1);
  int vector_dim = query_vector.size(2);

  // Each part of the transform yields vector_dim / hash_code_len hash functions.
  int num_hash_per_part = vector_dim / hash_code_len;
  int num_part = max(1, ceil_divide(num_hash_f, num_hash_per_part));

  // Random +1/-1 diagonal matrices used by the fast hashing transform.
  at::Tensor Dmat = 2 * at::randint(0, 2, {batch_size, 3, num_part, vector_dim}, query_mask.options()) - 1;
  at::Tensor query_hash_code = at::zeros({batch_size, num_query, num_hash_f}, query_mask.options());
  at::Tensor key_hash_code = at::zeros({batch_size, num_key, num_hash_f}, key_mask.options());

  int *query_mask_ptr = query_mask.data_ptr<int>();
  float *query_vector_ptr = query_vector.data_ptr<float>();
  int *key_mask_ptr = key_mask.data_ptr<int>();
  float *key_vector_ptr = key_vector.data_ptr<float>();

  int *Dmat_ptr = Dmat.data_ptr<int>();

  int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
  int *key_hash_code_ptr = key_hash_code.data_ptr<int>();

  if (use_cuda) {
    // Hash queries and keys with the same Dmat so collisions are meaningful.
    {
      dim3 threads(vector_dim);
      dim3 blocks(num_part, num_query, batch_size);
      int shared_mem = vector_dim * sizeof(float);
      fast_hash_ver1_cuda_kernel<<<blocks, threads, shared_mem>>>(
        query_mask_ptr, query_vector_ptr, Dmat_ptr, query_hash_code_ptr,
        batch_size, num_query, vector_dim, num_part, num_hash_f, hash_code_len
      );
    }
    {
      dim3 threads(vector_dim);
      dim3 blocks(num_part, num_key, batch_size);
      int shared_mem = vector_dim * sizeof(float);
      fast_hash_ver1_cuda_kernel<<<blocks, threads, shared_mem>>>(
        key_mask_ptr, key_vector_ptr, Dmat_ptr, key_hash_code_ptr,
        batch_size, num_key, vector_dim, num_part, num_hash_f, hash_code_len
      );
    }
  }

  return {query_hash_code, key_hash_code};
}
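
// Usage sketch (illustrative only; the tensor names and parameter choices
// below are assumptions, not part of this file). Masks are int32, vectors and
// values float32, all on the same CUDA device, and hashtable_capacity is taken
// as 2^hash_code_len so every hash code indexes a bucket:
//
//   auto codes = fast_hash_ver1_kernel(q_mask, q_vec, k_mask, k_vec,
//                                      /*num_hash_f=*/32,
//                                      /*hash_code_len=*/8,
//                                      /*use_cuda=*/true);
//   // codes[0]: query hash codes, codes[1]: key hash codes.
//   at::Tensor out = lsh_cumulation_ver1_kernel(q_mask, codes[0],
//                                               k_mask, codes[1], value,
//                                               /*hashtable_capacity=*/1 << 8,
//                                               /*use_cuda=*/true);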

at::Tensor lsh_cumulation_ver1_kernel(
  at::Tensor query_mask,
  at::Tensor query_hash_code,
  at::Tensor key_mask,
  at::Tensor key_hash_code,
  at::Tensor value,
  int hashtable_capacity,
  bool use_cuda
) {

  int batch_size = query_hash_code.size(0);
  int num_hash_f = query_hash_code.size(2);
  int num_query = query_hash_code.size(1);
  int num_key = key_hash_code.size(1);
  int value_dim = value.size(2);

  // Zeroed via cudaMemset at the top of every loop iteration below.
  at::Tensor hashtable_value = at::empty({batch_size, num_hash_f, hashtable_capacity, WARP_SIZE}, value.options());
  at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options());

  if (use_cuda) {
    int threads_x = WARP_SIZE;
    int threads_y = OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE;
    int block_x_step1 = num_key / threads_y;
    int block_x_step2 = num_query / threads_y;
    int block_y = batch_size;

    dim3 threads(threads_x, threads_y);
    dim3 blocks_step1(block_x_step1, block_y);
    dim3 blocks_step2(block_x_step2, block_y);

    int *query_mask_ptr = query_mask.data_ptr<int>();
    int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
    int *key_mask_ptr = key_mask.data_ptr<int>();
    int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
    float *value_ptr = value.data_ptr<float>();
    float *hashtable_value_ptr = hashtable_value.data_ptr<float>();
    float *cumulation_value_ptr = cumulation_value.data_ptr<float>();

    // Process value_dim in WARP_SIZE-wide slices: scatter key values into the
    // hashtable (step 1), then gather each query's bucket sums (step 2).
    for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
      cudaMemset(hashtable_value_ptr, 0, (batch_size * num_hash_f * hashtable_capacity * WARP_SIZE) * sizeof(float));
      lsh_cumulation_ver1_step1_cuda_kernel<<<blocks_step1, threads>>>(
        key_mask_ptr, key_hash_code_ptr, value_ptr, hashtable_value_ptr,
        batch_size, num_hash_f, hashtable_capacity, num_key, value_dim, value_offset
      );
      lsh_cumulation_ver1_step2_cuda_kernel<<<blocks_step2, threads>>>(
        query_mask_ptr, query_hash_code_ptr, hashtable_value_ptr, cumulation_value_ptr,
        batch_size, num_hash_f, hashtable_capacity, num_query, value_dim, value_offset
      );
    }
  }

  return cumulation_value;
}

at::Tensor lsh_weighted_cumulation_ver1_kernel(
  at::Tensor query_mask,
  at::Tensor query_hash_code,
  at::Tensor query_weight,
  at::Tensor key_mask,
  at::Tensor key_hash_code,
  at::Tensor key_weight,
  at::Tensor value,
  int hashtable_capacity,
  bool use_cuda
) {

  int batch_size = query_hash_code.size(0);
  int num_hash_f = query_hash_code.size(2);
  int num_query = query_hash_code.size(1);
  int num_key = key_hash_code.size(1);
  int value_dim = value.size(2);
  int weight_dim = query_weight.size(2);

  at::Tensor hashtable_value = at::zeros({batch_size, num_hash_f, hashtable_capacity, WARP_SIZE}, value.options());
  at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options());

  if (use_cuda) {
    int threads_x = WARP_SIZE;
    int threads_y = OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE;
    int block_x_step1 = num_key / threads_y;
    int block_x_step2 = num_query / threads_y;
    int block_y = batch_size;

    dim3 threads(threads_x, threads_y);
    dim3 blocks_step1(block_x_step1, block_y);
    dim3 blocks_step2(block_x_step2, block_y);

    int *query_mask_ptr = query_mask.data_ptr<int>();
    int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
    float *query_weight_ptr = query_weight.data_ptr<float>();
    int *key_mask_ptr = key_mask.data_ptr<int>();
    int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
    float *key_weight_ptr = key_weight.data_ptr<float>();
    float *value_ptr = value.data_ptr<float>();
    float *hashtable_value_ptr = hashtable_value.data_ptr<float>();
    float *cumulation_value_ptr = cumulation_value.data_ptr<float>();

    // Same two-step scheme as the unweighted version, repeated once per weight
    // dimension so key weights fold into the scatter and query weights into
    // the gather.
    for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
      for (int weight_idx = 0; weight_idx < weight_dim; weight_idx++) {
        cudaMemset(hashtable_value_ptr, 0, (batch_size * num_hash_f * hashtable_capacity * WARP_SIZE) * sizeof(float));
        lsh_weighted_cumulation_ver1_step1_cuda_kernel<<<blocks_step1, threads>>>(
          key_mask_ptr, key_hash_code_ptr, key_weight_ptr, value_ptr, hashtable_value_ptr,
          batch_size, num_hash_f, hashtable_capacity, num_key, value_dim, weight_dim, value_offset, weight_idx
        );
        lsh_weighted_cumulation_ver1_step2_cuda_kernel<<<blocks_step2, threads>>>(
          query_mask_ptr, query_hash_code_ptr, query_weight_ptr, hashtable_value_ptr, cumulation_value_ptr,
          batch_size, num_hash_f, hashtable_capacity, num_query, value_dim, weight_dim, value_offset, weight_idx
        );
      }
    }
  }

  return cumulation_value;
}
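
// Hashtable layout note (a reading aid, inferred from the allocations above):
// hashtable_value is indexed as [batch][hash_f][bucket][lane], so each
// (hash function, bucket) slot accumulates a WARP_SIZE-wide slice of the
// summed key values; value_offset selects which slice of value_dim is live in
// a given iteration, and step 2 reads back the slot addressed by each query's
// hash code.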

at::Tensor lsh_weighted_cumulation_ver2_kernel(
  at::Tensor query_mask,
  at::Tensor query_hash_code,
  at::Tensor query_weight,
  at::Tensor key_mask,
  at::Tensor key_hash_code,
  at::Tensor key_weight,
  at::Tensor value,
  int hashtable_capacity,
  bool use_cuda
) {

  int batch_size = query_hash_code.size(0);
  int num_hash_f = query_hash_code.size(2);
  int num_query = query_hash_code.size(1);
  int num_key = key_hash_code.size(1);
  int value_dim = value.size(2);
  int weight_dim = query_weight.size(2);

  at::Tensor count_sort_table = at::zeros({batch_size, num_hash_f, hashtable_capacity}, query_hash_code.options());
  at::Tensor key_sorted_idxes = at::zeros({batch_size, num_hash_f, num_key}, query_hash_code.options());
  at::Tensor query_info = at::zeros({batch_size, num_query, 2, num_hash_f}, query_hash_code.options());
  at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options());

  if (use_cuda) {
    int *query_mask_ptr = query_mask.data_ptr<int>();
    int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
    float *query_weight_ptr = query_weight.data_ptr<float>();
    int *key_mask_ptr = key_mask.data_ptr<int>();
    int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
    float *key_weight_ptr = key_weight.data_ptr<float>();
    float *value_ptr = value.data_ptr<float>();
    int *count_sort_table_ptr = count_sort_table.data_ptr<int>();
    int *key_sorted_idxes_ptr = key_sorted_idxes.data_ptr<int>();
    int *query_info_ptr = query_info.data_ptr<int>();
    float *cumulation_value_ptr = cumulation_value.data_ptr<float>();

    // Counting sort of keys by hash bucket: histogram (step 1), prefix sums
    // (step 2), scatter of key indexes into sorted order (step 3).
    {
      dim3 threads_step13(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
      dim3 blocks_step13(num_key / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
      dim3 threads_step2(min(hashtable_capacity, OPTIMAL_THREADS_PER_BLOCK));
      dim3 blocks_step2(num_hash_f, batch_size);
      int shared_mem = hashtable_capacity * sizeof(float);
      count_sort_step1_cuda_kernel<<<blocks_step13, threads_step13>>>(
        key_mask_ptr, key_hash_code_ptr, count_sort_table_ptr,
        batch_size, num_hash_f, hashtable_capacity, num_key
      );
      count_sort_step2_cuda_kernel<<<blocks_step2, threads_step2, shared_mem>>>(
        count_sort_table_ptr, batch_size, num_hash_f, hashtable_capacity
      );
      count_sort_step3_cuda_kernel<<<blocks_step13, threads_step13>>>(
        key_mask_ptr, key_hash_code_ptr, count_sort_table_ptr, key_sorted_idxes_ptr,
        batch_size, num_hash_f, hashtable_capacity, num_key
      );
    }
    // For every query, record the offset and count of matching keys in each
    // hash function's sorted order.
    {
      dim3 threads(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
      dim3 blocks(num_query / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
      extract_query_info_cuda_kernel<<<blocks, threads>>>(
        query_mask_ptr, query_hash_code_ptr, count_sort_table_ptr, query_info_ptr,
        batch_size, num_hash_f, hashtable_capacity, num_query
      );
    }
    // Accumulate the values of colliding keys; one block per
    // (query, hash function) pair.
    {
      dim3 threads(WARP_SIZE, OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE);
      dim3 blocks(num_query, num_hash_f, batch_size);
      int shared_mem = (weight_dim + WARP_SIZE) * sizeof(float);
      lsh_weighted_cumulation_ver2_step2_cuda_kernel<<<blocks, threads, shared_mem>>>(
        query_mask_ptr, query_info_ptr, key_sorted_idxes_ptr,
        query_weight_ptr, key_weight_ptr, value_ptr, cumulation_value_ptr,
        batch_size, num_hash_f, num_query, num_key, value_dim, weight_dim
      );
    }
  }

  return cumulation_value;
}
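
// Counting-sort shape, serialized (an illustrative sketch of what steps 1-3
// compute for one (batch, hash function) pair; the exact index handling lives
// in the CUDA kernels and may differ in detail):
//
//   for (int k = 0; k < num_key; k++) table[code[k]]++;           // step 1: histogram
//   for (int c = 1; c < capacity; c++) table[c] += table[c - 1];  // step 2: prefix sum
//   for (int k = 0; k < num_key; k++)                             // step 3: scatter
//     sorted_idxes[--table[code[k]]] = k;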

at::Tensor lsh_weighted_cumulation_ver3_kernel(
  at::Tensor query_mask,
  at::Tensor query_hash_code,
  at::Tensor query_weight,
  at::Tensor key_mask,
  at::Tensor key_hash_code,
  at::Tensor key_weight,
  at::Tensor value,
  int hashtable_capacity,
  bool use_cuda
) {

  int batch_size = query_hash_code.size(0);
  int num_hash_f = query_hash_code.size(2);
  int num_query = query_hash_code.size(1);
  int num_key = key_hash_code.size(1);
  int value_dim = value.size(2);
  int weight_dim = query_weight.size(2);

  at::Tensor count_sort_table = at::zeros({batch_size, num_hash_f, hashtable_capacity}, query_hash_code.options());
  at::Tensor query_sorted_idxes = at::zeros({batch_size, num_hash_f, num_query}, query_hash_code.options());
  at::Tensor key_info = at::zeros({batch_size, num_key, 2, num_hash_f}, query_hash_code.options());
  at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options());

  if (use_cuda) {
    int *query_mask_ptr = query_mask.data_ptr<int>();
    int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
    float *query_weight_ptr = query_weight.data_ptr<float>();
    int *key_mask_ptr = key_mask.data_ptr<int>();
    int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
    float *key_weight_ptr = key_weight.data_ptr<float>();
    float *value_ptr = value.data_ptr<float>();
    int *count_sort_table_ptr = count_sort_table.data_ptr<int>();
    int *query_sorted_idxes_ptr = query_sorted_idxes.data_ptr<int>();
    int *key_info_ptr = key_info.data_ptr<int>();
    float *cumulation_value_ptr = cumulation_value.data_ptr<float>();

    // Mirror of ver2 with the roles swapped: queries are count-sorted by
    // bucket and the accumulation is driven by keys.
    {
      dim3 threads_step13(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
      dim3 blocks_step13(num_query / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
      dim3 threads_step2(min(hashtable_capacity, OPTIMAL_THREADS_PER_BLOCK));
      dim3 blocks_step2(num_hash_f, batch_size);
      int shared_mem = hashtable_capacity * sizeof(float);
      count_sort_step1_cuda_kernel<<<blocks_step13, threads_step13>>>(
        query_mask_ptr, query_hash_code_ptr, count_sort_table_ptr,
        batch_size, num_hash_f, hashtable_capacity, num_query
      );
      count_sort_step2_cuda_kernel<<<blocks_step2, threads_step2, shared_mem>>>(
        count_sort_table_ptr, batch_size, num_hash_f, hashtable_capacity
      );
      count_sort_step3_cuda_kernel<<<blocks_step13, threads_step13>>>(
        query_mask_ptr, query_hash_code_ptr, count_sort_table_ptr, query_sorted_idxes_ptr,
        batch_size, num_hash_f, hashtable_capacity, num_query
      );
    }
    // The extraction kernel is reused here to collect per-key bucket offsets
    // and counts into key_info.
    {
      dim3 threads(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
      dim3 blocks(num_key / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
      extract_query_info_cuda_kernel<<<blocks, threads>>>(
        key_mask_ptr, key_hash_code_ptr, count_sort_table_ptr, key_info_ptr,
        batch_size, num_hash_f, hashtable_capacity, num_key
      );
    }
    // One block per (key, hash function) pair; shared memory stages one key's
    // weight and value rows plus a WARP_SIZE scratch buffer.
    {
      dim3 threads(WARP_SIZE, OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE);
      dim3 blocks(num_key, num_hash_f, batch_size);
      int shared_mem = (weight_dim + value_dim + WARP_SIZE) * sizeof(float);
      lsh_weighted_cumulation_ver3_step2_cuda_kernel<<<blocks, threads, shared_mem>>>(
        query_sorted_idxes_ptr, key_mask_ptr, key_info_ptr,
        query_weight_ptr, key_weight_ptr, value_ptr, cumulation_value_ptr,
        batch_size, num_hash_f, num_query, num_key, value_dim, weight_dim
      );
    }
  }

  return cumulation_value;
}
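
// ver4 variant note (a reading aid, inferred from the launch configuration
// below): compared to ver3, ver4 drops the per-hash-function grid dimension
// and launches one block per key, so the kernel presumably iterates over hash
// functions internally; the extra 2 * num_hash_f floats of shared memory match
// the size of one key's key_info row (offset and count per hash function).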

at::Tensor lsh_weighted_cumulation_ver4_kernel(
  at::Tensor query_mask,
  at::Tensor query_hash_code,
  at::Tensor query_weight,
  at::Tensor key_mask,
  at::Tensor key_hash_code,
  at::Tensor key_weight,
  at::Tensor value,
  int hashtable_capacity,
  bool use_cuda
) {

  int batch_size = query_hash_code.size(0);
  int num_hash_f = query_hash_code.size(2);
  int num_query = query_hash_code.size(1);
  int num_key = key_hash_code.size(1);
  int value_dim = value.size(2);
  int weight_dim = query_weight.size(2);

  at::Tensor count_sort_table = at::zeros({batch_size, num_hash_f, hashtable_capacity}, query_hash_code.options());
  at::Tensor query_sorted_idxes = at::zeros({batch_size, num_hash_f, num_query}, query_hash_code.options());
  at::Tensor key_info = at::zeros({batch_size, num_key, 2, num_hash_f}, query_hash_code.options());
  at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options());

  if (use_cuda) {
    int *query_mask_ptr = query_mask.data_ptr<int>();
    int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
    float *query_weight_ptr = query_weight.data_ptr<float>();
    int *key_mask_ptr = key_mask.data_ptr<int>();
    int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
    float *key_weight_ptr = key_weight.data_ptr<float>();
    float *value_ptr = value.data_ptr<float>();
    int *count_sort_table_ptr = count_sort_table.data_ptr<int>();
    int *query_sorted_idxes_ptr = query_sorted_idxes.data_ptr<int>();
    int *key_info_ptr = key_info.data_ptr<int>();
    float *cumulation_value_ptr = cumulation_value.data_ptr<float>();

    // Count-sort queries by hash bucket, as in ver3.
    {
      dim3 threads_step13(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
      dim3 blocks_step13(num_query / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
      dim3 threads_step2(min(hashtable_capacity, OPTIMAL_THREADS_PER_BLOCK));
      dim3 blocks_step2(num_hash_f, batch_size);
      int shared_mem = hashtable_capacity * sizeof(float);
      count_sort_step1_cuda_kernel<<<blocks_step13, threads_step13>>>(
        query_mask_ptr, query_hash_code_ptr, count_sort_table_ptr,
        batch_size, num_hash_f, hashtable_capacity, num_query
      );
      count_sort_step2_cuda_kernel<<<blocks_step2, threads_step2, shared_mem>>>(
        count_sort_table_ptr, batch_size, num_hash_f, hashtable_capacity
      );
      count_sort_step3_cuda_kernel<<<blocks_step13, threads_step13>>>(
        query_mask_ptr, query_hash_code_ptr, count_sort_table_ptr, query_sorted_idxes_ptr,
        batch_size, num_hash_f, hashtable_capacity, num_query
      );
    }
    // Collect per-key bucket offsets and counts into key_info.
    {
      dim3 threads(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
      dim3 blocks(num_key / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
      extract_query_info_cuda_kernel<<<blocks, threads>>>(
        key_mask_ptr, key_hash_code_ptr, count_sort_table_ptr, key_info_ptr,
        batch_size, num_hash_f, hashtable_capacity, num_key
      );
    }
    // One block per key; all hash functions are handled within the kernel.
    {
      dim3 threads(WARP_SIZE, OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE);
      dim3 blocks(num_key, batch_size);
      int shared_mem = (weight_dim + value_dim + 2 * num_hash_f) * sizeof(float);
      lsh_weighted_cumulation_ver4_step2_cuda_kernel<<<blocks, threads, shared_mem>>>(
        query_sorted_idxes_ptr, key_mask_ptr, key_info_ptr,
        query_weight_ptr, key_weight_ptr, value_ptr, cumulation_value_ptr,
        batch_size, num_hash_f, num_query, num_key, value_dim, weight_dim
      );
    }
  }

  return cumulation_value;
}
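
// Dispatch note (hypothetical; this file does not ship a dispatcher): the four
// weighted variants share a signature and appear to be alternative
// implementations of the same cumulation, differing in parallelization
// strategy, so a caller could select one directly, e.g.:
//
//   at::Tensor out = lsh_weighted_cumulation_ver2_kernel(
//       q_mask, q_code, q_w, k_mask, k_code, k_w, value,
//       /*hashtable_capacity=*/1 << hash_code_len, /*use_cuda=*/true);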