// File from https://github.com/mlpen/YOSO/blob/main/encoders/backbones/efficient_attentions/yoso/yoso_v1/cuda/fast_lsh_cumulation_cuda.cu

#include "fast_lsh_cumulation_cuda.h"
#include "common_cuda_device.h"
#include "common_cuda.h"
#include "common.h"
#include <stdio.h>

//////////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void fast_hadamard_transform(float *vector_buffer, int vector_dim, int dim_idx) {
  int stride = vector_dim / 2;
  while (stride > (WARP_SIZE / 2)) {
    __syncthreads();
    int sign = 1 - ((dim_idx / stride) % 2) * 2;
    float val1 = vector_buffer[dim_idx];
    float val2 = vector_buffer[dim_idx + sign * stride];
    __syncthreads();
    vector_buffer[dim_idx] = float(sign) * val1 + val2;
    stride = stride / 2;
  }

  float val = vector_buffer[dim_idx];
  #pragma unroll
  for (stride = (WARP_SIZE / 2); stride > 0; stride = stride / 2) {
    int sign = 1 - ((dim_idx / stride) % 2) * 2;
    val = float(sign) * val + __shfl_xor_sync(FULL_MASK, val, stride);
  }
  vector_buffer[dim_idx] = val;
}

__global__ void fast_hash_ver1_cuda_kernel(
  int *mask,        // [batch_size, num_vector]
  float *vector,    // [batch_size, num_vector, vector_dim]
  int *Dmat,        // [batch_size, 3, num_part, vector_dim]
  int *hash_code,   // [batch_size, num_vector, num_hash_f]
  int batch_size,
  int num_vector,
  int vector_dim,
  int num_part,
  int num_hash_f,
  int hash_code_len
) {

  int batch_idx = blockIdx.z;
  int vector_idx = blockIdx.y;
  int part_idx = blockIdx.x;

  int dim_idx = threadIdx.x;

  int batch_idx__vector_idx = batch_idx * num_vector + vector_idx;
  if (mask[batch_idx__vector_idx] == 0) {
    return;
  }

  extern __shared__ float buffer[];
  float *vector_buffer = buffer;

  vector_buffer[dim_idx] = vector[batch_idx__vector_idx * vector_dim + dim_idx];

  vector_buffer[dim_idx] = vector_buffer[dim_idx] * (float)Dmat[((batch_idx * 3 + 0) * num_part + part_idx) * vector_dim + dim_idx];
  fast_hadamard_transform(vector_buffer, vector_dim, dim_idx);
  vector_buffer[dim_idx] = vector_buffer[dim_idx] * (float)Dmat[((batch_idx * 3 + 1) * num_part + part_idx) * vector_dim + dim_idx];
  fast_hadamard_transform(vector_buffer, vector_dim, dim_idx);
  vector_buffer[dim_idx] = vector_buffer[dim_idx] * (float)Dmat[((batch_idx * 3 + 2) * num_part + part_idx) * vector_dim + dim_idx];
  fast_hadamard_transform(vector_buffer, vector_dim, dim_idx);

  int num_hash_per_part = vector_dim / hash_code_len;
  if (hash_code_len == 8 || hash_code_len == 16) {
    int code = select(vector_buffer[dim_idx] > 0, 1 << (dim_idx % hash_code_len), 0);
    for (int offset = 1; offset < hash_code_len; offset = offset * 2) {
      code += __shfl_xor_sync(FULL_MASK, code, offset);
    }
    if (dim_idx % hash_code_len == 0) {
      int hash_f_idx = part_idx * num_hash_per_part + dim_idx / hash_code_len;
      if (hash_f_idx < num_hash_f) {
        hash_code[batch_idx__vector_idx * num_hash_f + hash_f_idx] = code;
      }
    }
  } else {
    vector_buffer[dim_idx] = select(vector_buffer[dim_idx] > 0, 1 << (dim_idx % hash_code_len), 0);
    __syncthreads();
    if (dim_idx < num_hash_per_part) {
      int code = 0;
      for (int i = 0; i < hash_code_len; i++) {
        code += vector_buffer[dim_idx * hash_code_len + i];
      }
      int hash_f_idx = part_idx * num_hash_per_part + dim_idx;
      if (hash_f_idx < num_hash_f) {
        hash_code[batch_idx__vector_idx * num_hash_f + hash_f_idx] = code;
      }
    }
  }
}
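// Note: fast_hash_ver1_cuda_kernel computes sign-based LSH codes. Each block loads one vector into
// shared memory, applies three rounds of random sign flips (Dmat) interleaved with fast Hadamard
// transforms, and packs the signs of the result into hash_code_len-bit codes. A minimal launch
// sketch follows, assuming one block per (part, vector, batch) triple, vector_dim threads per
// block, and vector_dim floats of dynamic shared memory; the pointer names are placeholders and
// the real host-side launcher (in the companion source files) may differ:
//
//   fast_hash_ver1_cuda_kernel<<<dim3(num_part, num_vector, batch_size), vector_dim,
//                                vector_dim * sizeof(float)>>>(
//       mask_ptr, vector_ptr, Dmat_ptr, hash_code_ptr,
//       batch_size, num_vector, vector_dim, num_part, num_hash_f, hash_code_len);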
__global__ void lsh_cumulation_ver1_step1_cuda_kernel(
  int *key_mask,           // [batch_size, num_key]
  int *key_hash_code,      // [batch_size, num_key, num_hash_f]
  float *value,            // [batch_size, num_key, value_dim]
  float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_key,
  int value_dim,
  int offset_warp
) {

  int warp_thread_idx = threadIdx.x;

  int batch_idx = blockIdx.y;
  int key_idx = blockIdx.x * blockDim.y + threadIdx.y;

  int batch_idx__key_idx = batch_idx * num_key + key_idx;
  if (key_mask[batch_idx__key_idx] == 0) {
    return;
  }

  if (num_hash_f > WARP_SIZE) {
    float warp_value = value[batch_idx__key_idx * value_dim + offset_warp + warp_thread_idx];
    for (int hash_f_start = 0; hash_f_start < num_hash_f; hash_f_start = hash_f_start + WARP_SIZE) {
      int warp_hashcode = key_hash_code[batch_idx__key_idx * num_hash_f + hash_f_start + warp_thread_idx];
      #pragma unroll
      for (int hash_f_offset = 0; hash_f_offset < WARP_SIZE; hash_f_offset++) {
        int current_hashcode = warp_hashcode;
        current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_offset);
        int hashtable_idx = (batch_idx * num_hash_f + (hash_f_start + hash_f_offset)) * hashtable_capacity + current_hashcode;
        atomicAdd(&hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx], warp_value);
      }
    }
  } else {
    float warp_value = value[batch_idx__key_idx * value_dim + offset_warp + warp_thread_idx];
    int warp_hashcode = 0;
    if (warp_thread_idx < num_hash_f) {
      warp_hashcode = key_hash_code[batch_idx__key_idx * num_hash_f + warp_thread_idx];
    }
    for (int hash_f_idx = 0; hash_f_idx < num_hash_f; hash_f_idx++) {
      int current_hashcode = warp_hashcode;
      current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_idx);
      int hashtable_idx = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + current_hashcode;
      atomicAdd(&hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx], warp_value);
    }
  }
}

__global__ void lsh_cumulation_ver1_step2_cuda_kernel(
  int *query_mask,         // [batch_size, num_query]
  int *query_hash_code,    // [batch_size, num_query, num_hash_f]
  float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_query,
  int value_dim,
  int offset_warp
) {

  int warp_thread_idx = threadIdx.x;

  int batch_idx = blockIdx.y;
  int query_idx = blockIdx.x * blockDim.y + threadIdx.y;

  int batch_idx__query_idx = batch_idx * num_query + query_idx;
  if (query_mask[batch_idx__query_idx] == 0) {
    return;
  }

  if (num_hash_f > WARP_SIZE) {
    float warp_value = 0;
    for (int hash_f_start = 0; hash_f_start < num_hash_f; hash_f_start = hash_f_start + WARP_SIZE) {
      int warp_hashcode = query_hash_code[batch_idx__query_idx * num_hash_f + hash_f_start + warp_thread_idx];
      #pragma unroll
      for (int hash_f_offset = 0; hash_f_offset < WARP_SIZE; hash_f_offset++) {
        int current_hashcode = warp_hashcode;
        current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_offset);
        int hashtable_idx = (batch_idx * num_hash_f + (hash_f_start + hash_f_offset)) * hashtable_capacity + current_hashcode;
        warp_value = warp_value + hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx];
      }
    }
    cumulation_value[batch_idx__query_idx * value_dim + offset_warp + warp_thread_idx] = warp_value / float(num_hash_f);
  } else {
    float warp_value = 0;
    int warp_hashcode = 0;
    if (warp_thread_idx < num_hash_f) {
      warp_hashcode = query_hash_code[batch_idx__query_idx * num_hash_f + warp_thread_idx];
    }
    for (int hash_f_idx = 0; hash_f_idx < num_hash_f; hash_f_idx++) {
      int current_hashcode = warp_hashcode;
      current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_idx);
      int hashtable_idx = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + current_hashcode;
      warp_value = warp_value + hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx];
    }
    cumulation_value[batch_idx__query_idx * value_dim + offset_warp + warp_thread_idx] = warp_value / float(num_hash_f);
  }
}
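// Note: the two kernels above implement the unweighted LSH cumulation. Step 1 scatters each
// unmasked key's value slice into one bucket per hash function (atomicAdd into hashtable_value at
// the key's hash code); step 2 gathers the buckets addressed by each query's hash codes and
// averages them, i.e. roughly
//
//   cumulation_value[b, q, d] = (1 / num_hash_f) * sum_f hashtable_value[b, f, code_f(q), d],
//
// which estimates sum_k 1[code_f(q) == code_f(k)] * value[b, k, d] averaged over hash functions.
// Both kernels only touch the WARP_SIZE-wide slice of value_dim starting at offset_warp, so the
// host presumably calls them repeatedly with offset_warp = 0, WARP_SIZE, 2 * WARP_SIZE, ...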
__global__ void lsh_weighted_cumulation_ver1_step1_cuda_kernel(
  int *key_mask,           // [batch_size, num_key]
  int *key_hash_code,      // [batch_size, num_key, num_hash_f]
  float *key_weight,       // [batch_size, num_key, weight_dim]
  float *value,            // [batch_size, num_key, value_dim]
  float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_key,
  int value_dim,
  int weight_dim,
  int offset_warp,
  int weight_idx
) {

  int warp_thread_idx = threadIdx.x;

  int batch_idx = blockIdx.y;
  int key_idx = blockIdx.x * blockDim.y + threadIdx.y;

  int batch_idx__key_idx = batch_idx * num_key + key_idx;
  if (key_mask[batch_idx__key_idx] == 0) {
    return;
  }

  if (num_hash_f > WARP_SIZE) {
    float warp_value = key_weight[batch_idx__key_idx * weight_dim + weight_idx] * value[batch_idx__key_idx * value_dim + offset_warp + warp_thread_idx];
    for (int hash_f_start = 0; hash_f_start < num_hash_f; hash_f_start = hash_f_start + WARP_SIZE) {
      int warp_hashcode = key_hash_code[batch_idx__key_idx * num_hash_f + hash_f_start + warp_thread_idx];
      #pragma unroll
      for (int hash_f_offset = 0; hash_f_offset < WARP_SIZE; hash_f_offset++) {
        int current_hashcode = warp_hashcode;
        current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_offset);
        int hashtable_idx = (batch_idx * num_hash_f + (hash_f_start + hash_f_offset)) * hashtable_capacity + current_hashcode;
        atomicAdd(&hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx], warp_value);
      }
    }
  } else {
    float warp_value = key_weight[batch_idx__key_idx * weight_dim + weight_idx] * value[batch_idx__key_idx * value_dim + offset_warp + warp_thread_idx];
    int warp_hashcode = 0;
    if (warp_thread_idx < num_hash_f) {
      warp_hashcode = key_hash_code[batch_idx__key_idx * num_hash_f + warp_thread_idx];
    }
    for (int hash_f_idx = 0; hash_f_idx < num_hash_f; hash_f_idx++) {
      int current_hashcode = warp_hashcode;
      current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_idx);
      int hashtable_idx = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + current_hashcode;
      atomicAdd(&hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx], warp_value);
    }
  }
}

__global__ void lsh_weighted_cumulation_ver1_step2_cuda_kernel(
  int *query_mask,         // [batch_size, num_query]
  int *query_hash_code,    // [batch_size, num_query, num_hash_f]
  float *query_weight,     // [batch_size, num_query, weight_dim]
  float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_query,
  int value_dim,
  int weight_dim,
  int offset_warp,
  int weight_idx
) {

  int warp_thread_idx = threadIdx.x;

  int batch_idx = blockIdx.y;
  int query_idx = blockIdx.x * blockDim.y + threadIdx.y;

  int batch_idx__query_idx = batch_idx * num_query + query_idx;
  if (query_mask[batch_idx__query_idx] == 0) {
    return;
  }

  if (num_hash_f > WARP_SIZE) {
    float warp_value = 0;
    for (int hash_f_start = 0; hash_f_start < num_hash_f; hash_f_start = hash_f_start + WARP_SIZE) {
      int warp_hashcode = query_hash_code[batch_idx__query_idx * num_hash_f + hash_f_start + warp_thread_idx];
      #pragma unroll
      for (int hash_f_offset = 0; hash_f_offset < WARP_SIZE; hash_f_offset++) {
        int current_hashcode = warp_hashcode;
        current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_offset);
        int hashtable_idx = (batch_idx * num_hash_f + (hash_f_start + hash_f_offset)) * hashtable_capacity + current_hashcode;
        warp_value = warp_value + hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx];
      }
    }
    float warp_weight = query_weight[batch_idx__query_idx * weight_dim + weight_idx];
    cumulation_value[batch_idx__query_idx * value_dim + offset_warp + warp_thread_idx] += warp_weight * warp_value / float(num_hash_f);
  } else {
    float warp_value = 0;
    int warp_hashcode = 0;
    if (warp_thread_idx < num_hash_f) {
      warp_hashcode = query_hash_code[batch_idx__query_idx * num_hash_f + warp_thread_idx];
    }
    for (int hash_f_idx = 0; hash_f_idx < num_hash_f; hash_f_idx++) {
      int current_hashcode = warp_hashcode;
      current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_idx);
      int hashtable_idx = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + current_hashcode;
      warp_value = warp_value + hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx];
    }
    float warp_weight = query_weight[batch_idx__query_idx * weight_dim + weight_idx];
    cumulation_value[batch_idx__query_idx * value_dim + offset_warp + warp_thread_idx] += warp_weight * warp_value / float(num_hash_f);
  }
}
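// Note: the weighted ver1 pair mirrors the unweighted kernels but folds in one rank-1 slice of the
// query/key weights. Step 1 accumulates key_weight[:, weight_idx] * value into the hash tables and
// step 2 scales the gathered sum by query_weight[:, weight_idx] before adding it (note the "+=")
// into cumulation_value. The host therefore presumably loops over weight_idx = 0 .. weight_dim - 1
// (re-zeroing hashtable_value between iterations) as well as over offset_warp, so that the final
// result sums query_weight[q, w] * key_weight[k, w] over colliding (q, k) pairs and all w,
// divided by num_hash_f.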
__global__ void count_sort_step1_cuda_kernel(
  int *key_mask,         // [batch_size, num_key]
  int *key_hash_code,    // [batch_size, num_key, num_hash_f]
  int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_key
) {

  int batch_idx = blockIdx.y;
  int key_idx = blockIdx.x * blockDim.y + threadIdx.y;
  int hash_f_idx = threadIdx.x;

  int batch_idx__key_idx = batch_idx * num_key + key_idx;
  if (key_mask[batch_idx__key_idx] == 0) {
    return;
  }

  int hash_code = key_hash_code[batch_idx__key_idx * num_hash_f + hash_f_idx];
  atomicAdd(&count_sort_table[(batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + hash_code], 1);
}

__global__ void count_sort_step2_cuda_kernel(
  int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity
) {

  int batch_idx = blockIdx.y;
  int hash_f_idx = blockIdx.x;

  int num_threads = blockDim.x;
  int thread_id = threadIdx.x;

  int batch_idx__hash_f_idx = batch_idx * num_hash_f + hash_f_idx;

  extern __shared__ float buffer[];
  int *table_buffer = (int*)buffer;

  if (thread_id == 0) {
    table_buffer[0] = 0;
  }
  copy_data(&count_sort_table[batch_idx__hash_f_idx * hashtable_capacity], &table_buffer[1], hashtable_capacity - 1, num_threads, thread_id);

  for (int table_idx_start = 0; table_idx_start < hashtable_capacity; table_idx_start = table_idx_start + num_threads) {
    int thread_value = table_buffer[table_idx_start + thread_id];
    int next_thread_value = 0;
    for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
      next_thread_value = __shfl_up_sync(FULL_MASK, thread_value, offset);
      if (thread_id % WARP_SIZE >= offset) {
        thread_value = thread_value + next_thread_value;
      }
    }
    table_buffer[table_idx_start + thread_id] = thread_value;
  }
  __syncthreads();

  if (hashtable_capacity > WARP_SIZE) {
    if (thread_id < WARP_SIZE) {
      for (int table_idx_start = WARP_SIZE; table_idx_start < hashtable_capacity; table_idx_start = table_idx_start + WARP_SIZE) {
        table_buffer[table_idx_start + thread_id] += table_buffer[table_idx_start - 1];
      }
    }
  }

  copy_data(table_buffer, &count_sort_table[batch_idx__hash_f_idx * hashtable_capacity], hashtable_capacity, num_threads, thread_id);

}
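// Note: count_sort_step1/2/3 build, for every (batch, hash function) pair, a bucket-sorted list of
// key indices. Step 1 histograms hash codes, step 2 turns each histogram into prefix sums in
// shared memory (an intra-warp scan with __shfl_up_sync plus a serial cross-warp fix-up), and
// step 3 scatters key indices into key_sorted_idxes while bumping the running offsets. A small
// worked example for one (batch, hash function) row with hashtable_capacity = 4:
//
//   step 1 counts:                                     [2, 0, 3, 1]
//   after step 2 (exclusive prefix sums):              [0, 2, 2, 5]   (bucket c starts at table[c])
//   after step 3 (each insert increments its bucket):  [2, 2, 5, 6]   (now inclusive prefix sums)
//
// and key_sorted_idxes[0..5] then holds the key indices grouped by hash code.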
__global__ void count_sort_step3_cuda_kernel(
  int *key_mask,          // [batch_size, num_key]
  int *key_hash_code,     // [batch_size, num_key, num_hash_f]
  int *count_sort_table,  // [batch_size, num_hash_f, hashtable_capacity]
  int *key_sorted_idxes,  // [batch_size, num_hash_f, num_key]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_key
) {

  int batch_idx = blockIdx.y;
  int key_idx = blockIdx.x * blockDim.y + threadIdx.y;
  int hash_f_idx = threadIdx.x;

  int batch_idx__key_idx = batch_idx * num_key + key_idx;
  if (key_mask[batch_idx__key_idx] == 0) {
    return;
  }

  int batch_idx__hash_f_idx = batch_idx * num_hash_f + hash_f_idx;

  int hash_code = key_hash_code[batch_idx__key_idx * num_hash_f + hash_f_idx];
  int sort_idx = atomicAdd(&count_sort_table[batch_idx__hash_f_idx * hashtable_capacity + hash_code], 1);
  key_sorted_idxes[batch_idx__hash_f_idx * num_key + sort_idx] = key_idx;

}

__global__ void extract_query_info_cuda_kernel(
  int *query_mask,       // [batch_size, num_query]
  int *query_hash_code,  // [batch_size, num_query, num_hash_f]
  int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
  int *query_info,       // [batch_size, num_query, 2, num_hash_f]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_query
) {

  int batch_idx = blockIdx.y;
  int query_idx = blockIdx.x * blockDim.y + threadIdx.y;
  int hash_f_idx = threadIdx.x;

  int batch_idx__query_idx = batch_idx * num_query + query_idx;
  if (query_mask[batch_idx__query_idx] == 0) {
    return;
  }

  int hash_code = query_hash_code[batch_idx__query_idx * num_hash_f + hash_f_idx];
  int batch_idx__hash_f_idx__hash_code = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + hash_code;

  int key_offset = select(hash_code == 0, 0, count_sort_table[batch_idx__hash_f_idx__hash_code - 1]);
  int key_count = count_sort_table[batch_idx__hash_f_idx__hash_code] - key_offset;

  query_info[batch_idx__query_idx * 2 * num_hash_f + hash_f_idx] = key_offset;
  query_info[(batch_idx__query_idx * 2 + 1) * num_hash_f + hash_f_idx] = key_count;

}
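// Note: extract_query_info_cuda_kernel reads the post-step-3 table (inclusive prefix sums) to
// record, per (query, hash function), where the query's collision bucket starts in
// key_sorted_idxes (key_offset) and how many keys it holds (key_count). Continuing the example
// above, a query hashing to code 2 gets key_offset = table[1] = 2 and key_count = table[2] - 2 = 3.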
__global__ void lsh_weighted_cumulation_ver2_step2_cuda_kernel(
  int *query_mask,         // [batch_size, num_query]
  int *query_info,         // [batch_size, num_query, 2, num_hash_f]
  int *key_sorted_idxes,   // [batch_size, num_hash_f, num_key]
  float *query_weight,     // [batch_size, num_query, weight_dim]
  float *key_weight,       // [batch_size, num_key, weight_dim]
  float *value,            // [batch_size, num_key, value_dim]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int num_query,
  int num_key,
  int value_dim,
  int weight_dim
) {

  int batch_idx = blockIdx.z;
  int hash_f_idx = blockIdx.y;
  int query_idx = blockIdx.x;

  int num_threads = blockDim.y * blockDim.x;
  int thread_id = threadIdx.y * blockDim.x + threadIdx.x;

  int num_warps = blockDim.y;
  int warp_idx = threadIdx.y;
  int warp_thread_idx = threadIdx.x;

  int batch_idx__query_idx = batch_idx * num_query + query_idx;
  if (query_mask[batch_idx__query_idx] == 0) {
    return;
  }

  int key_offset = query_info[batch_idx__query_idx * 2 * num_hash_f + hash_f_idx];
  int key_count = query_info[(batch_idx__query_idx * 2 + 1) * num_hash_f + hash_f_idx];

  if (key_count == 0) {
    return;
  }

  extern __shared__ float buffer[];

  if (key_count == 1) {
    if (warp_idx == 0) {
      int key_idx = key_sorted_idxes[(batch_idx * num_hash_f + hash_f_idx) * num_key + key_offset];
      int batch_idx__key_idx = batch_idx * num_key + key_idx;
      float weight = 0;
      for (int weight_offset = 0; weight_offset < weight_dim; weight_offset = weight_offset + WARP_SIZE) {
        int weight_dim_idx = weight_offset + warp_thread_idx;
        float val = query_weight[batch_idx__query_idx * weight_dim + weight_dim_idx] * key_weight[batch_idx__key_idx * weight_dim + weight_dim_idx];
        #pragma unroll
        for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
          val += __shfl_xor_sync(FULL_MASK, val, offset);
        }
        weight = weight + val;
      }
      weight = weight / float(num_hash_f);
      for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
        int value_dim_idx = value_offset + warp_thread_idx;
        float val = value[batch_idx__key_idx * value_dim + value_dim_idx];
        atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val);
      }
    }
  } else {
    float *weight_buffer = buffer;
    int *key_idxes_buffer = (int*)&buffer[weight_dim];

    copy_data_nonblocking(&query_weight[batch_idx__query_idx * weight_dim], weight_buffer, weight_dim, num_threads, thread_id);

    while (key_count > 0) {
      int work_size = min(WARP_SIZE, key_count);
      copy_data_nonblocking(&key_sorted_idxes[(batch_idx * num_hash_f + hash_f_idx) * num_key + key_offset], key_idxes_buffer, work_size, num_threads, thread_id);
      __syncthreads();
      for (int work_offset = 0; work_offset < WARP_SIZE; work_offset = work_offset + num_warps) {
        int work_idx = work_offset + warp_idx;
        if (work_idx < key_count) {
          int key_idx = key_idxes_buffer[work_idx];
          int batch_idx__key_idx = batch_idx * num_key + key_idx;

          float weight = 0;
          for (int weight_offset = 0; weight_offset < weight_dim; weight_offset = weight_offset + WARP_SIZE) {
            int weight_dim_idx = weight_offset + warp_thread_idx;
            float val = weight_buffer[weight_dim_idx] * key_weight[batch_idx__key_idx * weight_dim + weight_dim_idx];
            #pragma unroll
            for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
              val += __shfl_xor_sync(FULL_MASK, val, offset);
            }
            weight = weight + val;
          }
          weight = weight / float(num_hash_f);
          for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
            int value_dim_idx = value_offset + warp_thread_idx;
            float val = value[batch_idx__key_idx * value_dim + value_dim_idx];
            atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val);
          }
        }
      }
      key_count = key_count - work_size;
      key_offset = key_offset + work_size;
    }
  }
}
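// Note: ver2 step 2 is the query-centric gather. One block handles one (query, hash function)
// pair, caches the query weights in shared memory, and walks the query's collision bucket in
// WARP_SIZE-sized chunks: each warp computes the query-key weight dot product with a
// __shfl_xor_sync butterfly reduction, scales it by 1 / num_hash_f, and atomically accumulates
// weight * value into cumulation_value. The dynamic shared memory is presumably weight_dim floats
// for weight_buffer plus WARP_SIZE ints for key_idxes_buffer.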
__global__ void lsh_weighted_cumulation_ver3_step2_cuda_kernel(
  int *query_sorted_idxes, // [batch_size, num_hash_f, num_query]
  int *key_mask,           // [batch_size, num_key]
  int *key_info,           // [batch_size, num_key, 2, num_hash_f]
  float *query_weight,     // [batch_size, num_query, weight_dim]
  float *key_weight,       // [batch_size, num_key, weight_dim]
  float *value,            // [batch_size, num_key, value_dim]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int num_query,
  int num_key,
  int value_dim,
  int weight_dim
) {

  int batch_idx = blockIdx.z;
  int hash_f_idx = blockIdx.y;
  int key_idx = blockIdx.x;

  int num_threads = blockDim.y * blockDim.x;
  int thread_id = threadIdx.y * blockDim.x + threadIdx.x;

  int num_warps = blockDim.y;
  int warp_idx = threadIdx.y;
  int warp_thread_idx = threadIdx.x;

  int batch_idx__key_idx = batch_idx * num_key + key_idx;
  if (key_mask[batch_idx__key_idx] == 0) {
    return;
  }

  int query_offset = key_info[batch_idx__key_idx * 2 * num_hash_f + hash_f_idx];
  int query_count = key_info[(batch_idx__key_idx * 2 + 1) * num_hash_f + hash_f_idx];

  if (query_count == 0) {
    return;
  }

  extern __shared__ float buffer[];

  if (query_count == 1) {
    if (warp_idx == 0) {
      int query_idx = query_sorted_idxes[(batch_idx * num_hash_f + hash_f_idx) * num_query + query_offset];
      int batch_idx__query_idx = batch_idx * num_query + query_idx;
      float weight = 0;
      for (int weight_offset = 0; weight_offset < weight_dim; weight_offset = weight_offset + WARP_SIZE) {
        int weight_dim_idx = weight_offset + warp_thread_idx;
        float val = key_weight[batch_idx__key_idx * weight_dim + weight_dim_idx] * query_weight[batch_idx__query_idx * weight_dim + weight_dim_idx];
        #pragma unroll
        for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
          val += __shfl_xor_sync(FULL_MASK, val, offset);
        }
        weight = weight + val;
      }
      weight = weight / float(num_hash_f);
      for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
        int value_dim_idx = value_offset + warp_thread_idx;
        float val = value[batch_idx__key_idx * value_dim + value_dim_idx];
        atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val);
      }
    }
  } else {
    float *weight_buffer = buffer;
    float *value_buffer = &buffer[weight_dim];
    int *query_idxes_buffer = (int*)&buffer[weight_dim + value_dim];

    copy_data_nonblocking(&key_weight[batch_idx__key_idx * weight_dim], weight_buffer, weight_dim, num_threads, thread_id);
    copy_data_nonblocking(&value[batch_idx__key_idx * value_dim], value_buffer, value_dim, num_threads, thread_id);

    while (query_count > 0) {
      int work_size = min(WARP_SIZE, query_count);
      copy_data_nonblocking(&query_sorted_idxes[(batch_idx * num_hash_f + hash_f_idx) * num_query + query_offset], query_idxes_buffer, work_size, num_threads, thread_id);
      __syncthreads();
      for (int work_offset = 0; work_offset < WARP_SIZE; work_offset = work_offset + num_warps) {
        int work_idx = work_offset + warp_idx;
        if (work_idx < query_count) {
          int query_idx = query_idxes_buffer[work_idx];
          int batch_idx__query_idx = batch_idx * num_query + query_idx;

          float weight = 0;
          for (int weight_offset = 0; weight_offset < weight_dim; weight_offset = weight_offset + WARP_SIZE) {
            int weight_dim_idx = weight_offset + warp_thread_idx;
            float val = weight_buffer[weight_dim_idx] * query_weight[batch_idx__query_idx * weight_dim + weight_dim_idx];
            #pragma unroll
            for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
              val += __shfl_xor_sync(FULL_MASK, val, offset);
            }
            weight = weight + val;
          }
          weight = weight / float(num_hash_f);
          for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
            int value_dim_idx = value_offset + warp_thread_idx;
            float val = value_buffer[value_dim_idx];
            atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val);
          }
        }
      }
      query_count = query_count - work_size;
      query_offset = query_offset + work_size;
    }
  }
}
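// Note: ver3 step 2 is the key-centric mirror of ver2: the grid is indexed by key instead of
// query, and both the key's weight vector and its value vector are cached in shared memory because
// the same key is reused for every query in its bucket. Results still land in cumulation_value via
// atomicAdd, since many (key, hash function) blocks can target the same query.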
__global__ void lsh_weighted_cumulation_ver4_step2_cuda_kernel(
  int *query_sorted_idxes, // [batch_size, num_hash_f, num_query]
  int *key_mask,           // [batch_size, num_key]
  int *key_info,           // [batch_size, num_key, 2, num_hash_f]
  float *query_weight,     // [batch_size, num_query, weight_dim]
  float *key_weight,       // [batch_size, num_key, weight_dim]
  float *value,            // [batch_size, num_key, value_dim]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int num_query,
  int num_key,
  int value_dim,
  int weight_dim
) {

  int batch_idx = blockIdx.y;
  int key_idx = blockIdx.x;

  int num_threads = blockDim.y * blockDim.x;
  int thread_id = threadIdx.y * blockDim.x + threadIdx.x;

  int num_warps = blockDim.y;
  int warp_idx = threadIdx.y;
  int warp_thread_idx = threadIdx.x;

  int batch_idx__key_idx = batch_idx * num_key + key_idx;
  if (key_mask[batch_idx__key_idx] == 0) {
    return;
  }

  extern __shared__ float buffer[];
  float *weight_buffer = buffer;
  float *value_buffer = &buffer[weight_dim];
  int *key_info_buffer = (int*)&buffer[weight_dim + value_dim];

  copy_data_nonblocking(&key_weight[batch_idx__key_idx * weight_dim], weight_buffer, weight_dim, num_threads, thread_id);
  copy_data_nonblocking(&value[batch_idx__key_idx * value_dim], value_buffer, value_dim, num_threads, thread_id);
  copy_data_nonblocking(&key_info[batch_idx__key_idx * 2 * num_hash_f], key_info_buffer, 2 * num_hash_f, num_threads, thread_id);

  int *query_offset_buffer = key_info_buffer;
  int *query_count_buffer = &key_info_buffer[num_hash_f];

  const int hashtable_size = 1024 + OPTIMAL_THREADS_PER_BLOCK;
  __shared__ int hashtable_query[hashtable_size];
  __shared__ int hashtable_count[hashtable_size];
  __shared__ int inserted_query[hashtable_size];
  __shared__ int query_counter[1];

  int hash_f_idx_base = 0;

  while (true) {

    init_buffer_nonblocking(EMPTY_VALUE, hashtable_query, hashtable_size, num_threads, thread_id);
    init_buffer_nonblocking(0, hashtable_count, hashtable_size, num_threads, thread_id);
    init_buffer_nonblocking(EMPTY_VALUE, inserted_query, hashtable_size, num_threads, thread_id);
    init_buffer_nonblocking(0, query_counter, 1, num_threads, thread_id);
    __syncthreads();

    while (hash_f_idx_base < num_hash_f) {

      int hash_f_idx = hash_f_idx_base + warp_idx;
      int batch_idx__hash_f_idx = batch_idx * num_hash_f + hash_f_idx;

      int stop_flag = 0;

      int query_offset = query_offset_buffer[hash_f_idx];
      int query_count = query_count_buffer[hash_f_idx];

      while (query_count > 0) {

        int work_size = min(query_count, WARP_SIZE);

        // try inserting query to set and check whether the query is new
        int found_new_query = 0;
        int query_idx = -1;
        if (warp_thread_idx < work_size) {
          query_idx = query_sorted_idxes[batch_idx__hash_f_idx * num_query + query_offset + warp_thread_idx];
          int slot = set_insert(hashtable_query, hashtable_size, query_idx);
          if (slot >= 0) {
            found_new_query = atomicAdd(&hashtable_count[slot], 1) == 0;
          }
        }

        // compute cumulative offset
        int position_offset = found_new_query;
        int next_position_offset = 0;
        #pragma unroll
        for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
          next_position_offset = __shfl_up_sync(FULL_MASK, position_offset, offset);
          if (thread_id % WARP_SIZE >= offset) {
            position_offset = position_offset + next_position_offset;
          }
        }

        // get the inserted query list end index
        int inserted_query_base = 0;
        if (thread_id % WARP_SIZE == WARP_SIZE - 1) {
          inserted_query_base = atomicAdd(query_counter, position_offset);
        }
        inserted_query_base = __shfl_sync(FULL_MASK, inserted_query_base, WARP_SIZE - 1);

        // insert new queries to list
        int insert_idx = inserted_query_base + position_offset - 1;
        if (found_new_query) {
          inserted_query[insert_idx] = query_idx;
        }

        // remove inserted queries from list
        query_offset_buffer[hash_f_idx] += work_size;
        query_count_buffer[hash_f_idx] -= work_size;
        query_offset += work_size;
        query_count -= work_size;

        // if list is almost full, stop inserting
        if (inserted_query_base + OPTIMAL_THREADS_PER_BLOCK > hashtable_size) {
          stop_flag = 1;
          break;
        }

      }

      if (stop_flag) {
        break;
      }

      hash_f_idx_base = hash_f_idx_base + num_warps;

    }

    __syncthreads();

    int num_distinct_query = query_counter[0];

    if (num_distinct_query > 0) {
      for (int idx_base = 0; idx_base < num_distinct_query; idx_base = idx_base + num_warps) {
        int idx = idx_base + warp_idx;
        if (idx < num_distinct_query) {
          int query_idx = inserted_query[idx];
          int batch_idx__query_idx = batch_idx * num_query + query_idx;

          int slot = set_lookup(hashtable_query, hashtable_size, query_idx);
          int duplicate_count = hashtable_count[slot];

          float weight = 0;
          for (int weight_idx_base = 0; weight_idx_base < weight_dim; weight_idx_base = weight_idx_base + WARP_SIZE) {
            int weight_dim_idx = weight_idx_base + warp_thread_idx;
            float val = weight_buffer[weight_dim_idx] * query_weight[batch_idx__query_idx * weight_dim + weight_dim_idx];
            #pragma unroll
            for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
              val += __shfl_xor_sync(FULL_MASK, val, offset);
            }
            weight = weight + val;
          }
          weight = (float)duplicate_count * weight / float(num_hash_f);

          for (int value_idx_base = 0; value_idx_base < value_dim; value_idx_base = value_idx_base + WARP_SIZE) {
            int value_dim_idx = value_idx_base + warp_thread_idx;
            float val = value_buffer[value_dim_idx];
            atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val);
          }
        }
      }
    } else {
      // all computation is completed if num_distinct_query == 0
      break;
    }

    __syncthreads();

  }

}
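// Note: ver4 step 2 processes one key per block across all of its hash functions at once. It
// deduplicates colliding queries with an open-addressing set in shared memory (set_insert /
// set_lookup, presumably provided by the common device headers), so each distinct query is handled
// once and its collision multiplicity (duplicate_count) scales the contribution instead. The
// dynamic shared memory is presumably weight_dim + value_dim floats plus 2 * num_hash_f ints for
// the staged key_info, on top of the statically allocated hashtable_query / hashtable_count /
// inserted_query arrays of size 1024 + OPTIMAL_THREADS_PER_BLOCK.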