// Copyright 2021 AlQuraishi Laboratory
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu
// Butterfly (xor-shuffle) all-reduce: after log2(32) rounds every lane of
// the warp holds the maximum of all 32 lanes' inputs. All lanes must be
// active (full 0xffffffff participation mask).
__inline__ __device__ float WarpAllReduceMax(float val) {
#pragma unroll
    for (int laneMask = 1; laneMask < 32; laneMask <<= 1) {
        val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, laneMask));
    }
    return val;
}
// Butterfly (xor-shuffle) all-reduce sum across a full warp. The shuffle
// order (laneMask = 1, 2, 4, 8, 16) is kept identical to the original so
// the floating-point accumulation order — and thus the result — is
// bit-for-bit unchanged. Requires all 32 lanes active.
__inline__ __device__ float WarpAllReduceSum(float val) {
#pragma unroll
    for (int laneMask = 1; laneMask < 32; laneMask <<= 1) {
        val += __shfl_xor_sync(0xffffffff, val, laneMask);
    }
    return val;
}
// In-place, numerically stable softmax over each row of `input`
// (logical shape [rows, cols], contiguous).
//
// Launch layout: blockDim.x == 128 (4 warps per block), one warp per row,
// gridDim.x == ceil(rows / 4). Each of a warp's 32 lanes owns a contiguous
// chunk of `cols_per_thread` columns buffered in registers, so the kernel
// requires cols <= 32 * 32 == 1024 (buf has 32 slots per lane).
template<typename T>
__global__ void attn_softmax_inplace_(
    T *input,
    long long rows, int cols
) {
    int warp_id = threadIdx.x / 32;   // which of the block's 4 rows
    int lane_id = threadIdx.x % 32;
    // Cast blockIdx.x to 64-bit BEFORE the multiply: `blockIdx.x * 4` is
    // evaluated in unsigned 32-bit arithmetic and silently wraps when the
    // row index exceeds ~2^32 (rows is a long long, so that is reachable).
    long long row_offset = (long long)blockIdx.x * 4 + warp_id;
    int cols_per_thread = (cols + 31) / 32;
    int cols_this_thread = cols_per_thread;
    // Lane `last_y` takes the partial tail chunk; lanes beyond it are idle
    // (cols_this_thread == 0) but still join the warp reductions below.
    int last_y = cols / cols_per_thread;
    if (lane_id == last_y) {
        cols_this_thread = cols - cols_per_thread * last_y;
    }
    else if (lane_id > last_y) {
        cols_this_thread = 0;
    }
    float buf[32];
    // `row_offset < rows` is uniform across the warp (all lanes share the
    // same row), so every lane reaches the warp-wide shuffles together.
    if (row_offset < rows) {
        T *row_input = input + row_offset * cols;
        T *row_output = row_input;   // in-place: output aliases input
        // Stage this lane's chunk in registers at float precision.
        for (int i = 0; i < cols_this_thread; i++) {
            int idx = lane_id * cols_per_thread + i;
            buf[i] = static_cast<float>(row_input[idx]);
        }
        // Row max (for numerical stability: exp(x - max) never overflows).
        float thread_max = -1 * CUDART_INF_F;
        for (int i = 0; i < cols_this_thread; i++) {
            thread_max = max(thread_max, buf[i]);
        }
        float warp_max = WarpAllReduceMax(thread_max);
        // Exponentiate and accumulate the row sum.
        float thread_sum = 0.f;
        for (int i = 0; i < cols_this_thread; i++) {
            buf[i] = __expf(buf[i] - warp_max);
            thread_sum += buf[i];
        }
        float warp_sum = WarpAllReduceSum(thread_sum);
        // Normalize and write back in the input's dtype.
        for (int i = 0; i < cols_this_thread; i++) {
            row_output[lane_id * cols_per_thread + i] =
                static_cast<T>(__fdividef(buf[i], warp_sum));
        }
    }
}
// Host launcher for attn_softmax_inplace_: softmax over each of `rows`
// rows of the contiguous CUDA tensor `input` (viewed as [rows, cols]),
// computed in place on the default stream.
//
// Preconditions (checked): contiguous CUDA tensor; dtype float32 or
// bfloat16; 0 < cols <= 1024 (kernel register-buffer limit); the 1-D
// grid of ceil(rows / 4) blocks must fit in an int.
void attn_softmax_inplace_forward_(
    at::Tensor input,
    long long rows, int cols
) {
    CHECK_INPUT(input);
    // The original `else` branch reinterpreted ANY non-float32 dtype
    // (e.g. float16) as bfloat16, producing garbage; reject it instead.
    TORCH_CHECK(
        input.dtype() == torch::kFloat32 || input.dtype() == torch::kBFloat16,
        "attn_softmax_inplace_forward_ supports float32 and bfloat16, got ",
        input.dtype()
    );
    TORCH_CHECK(cols > 0 && cols <= 1024,
        "cols must be in (0, 1024], got ", cols);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
    // One warp per row, four warps (128 threads) per block. Do the
    // ceil-div in 64-bit so a huge row count cannot overflow int before
    // the narrowing, and guard the 1-D grid limit explicitly.
    long long grid_ll = (rows + 3) / 4;
    TORCH_CHECK(grid_ll > 0 && grid_ll <= 2147483647LL,
        "row count out of range for a 1-D grid: ", rows);
    dim3 grid((unsigned int)grid_ll);
    dim3 block(128);
    // NOTE(review): launches on the default stream; if callers use
    // non-default streams this should take at::cuda::getCurrentCUDAStream().
    if (input.dtype() == torch::kFloat32) {
        attn_softmax_inplace_<float><<<grid, block>>>(
            (float *)input.data_ptr(),
            rows, cols
        );
    }
    else {
        attn_softmax_inplace_<at::BFloat16><<<grid, block>>>(
            (at::BFloat16 *)input.data_ptr(),
            rows, cols
        );
    }
}
// Backward of the in-place softmax, fused with the attention @ V matmul's
// gradient: recomputes dy = d_ov @ V^T on the fly, then overwrites
// `output` (the saved softmax result y, shape [rows, cols_output]) with
// dx = (dy - sum_k(dy_k * y_k)) * y.
//
// `values` holds leaf matrices of shape [rows_values x cols_values] with
// rows_values == cols_output; each group of cols_output consecutive rows
// of `output`/`d_ov` shares one such leaf matrix.
//
// Launch layout mirrors the forward kernel: 128 threads = 4 warps per
// block, one warp per row, register chunks of cols_per_thread columns per
// lane, so cols_output <= 1024 is required.
template<typename T>
__global__ void attn_softmax_inplace_grad_(
    T *output,
    T *d_ov,
    T *values,
    long long rows,
    int cols_output,
    int cols_values
) {
    int warp_id = threadIdx.x / 32;
    int lane_id = threadIdx.x % 32;
    // 64-bit cast BEFORE the multiply — `blockIdx.x * 4` wraps in unsigned
    // 32-bit arithmetic once the row index exceeds ~2^32.
    long long row_offset = (long long)blockIdx.x * 4 + warp_id;
    int cols_per_thread = (cols_output + 31) / 32;
    int cols_this_thread = cols_per_thread;
    int rows_values = cols_output;
    // Start of the current rows_values x cols_values leaf matrix in `values`.
    long long value_row_offset = row_offset - row_offset % rows_values;
    // Lane `last_y` takes the partial tail chunk; later lanes are idle but
    // still join the warp-wide sum below.
    int last_y = (cols_output / cols_per_thread);
    if (lane_id == last_y) {
        cols_this_thread = cols_output - cols_per_thread * last_y;
    }
    else if (lane_id > last_y) {
        cols_this_thread = 0;
    }
    float y_buf[32];
    float dy_buf[32];
    // Uniform per warp, so all lanes reach WarpAllReduceSum together.
    if (row_offset < rows) {
        T *row_output = output + row_offset * cols_output;
        T *row_d_ov = d_ov + row_offset * cols_values;
        T *row_values = values + value_row_offset * cols_values;
        // dy[i] = dot(row_d_ov, V row (lane_id * cols_per_thread + i)),
        // recomputed here instead of being stored by the forward pass.
        for (int i = 0; i < cols_this_thread; i++) {
            // NOTE(review): accumulation happens in T — for bfloat16 this
            // loses precision vs. a float accumulator; confirm intended.
            T sum = 0.;
            // Row index is invariant over j: hoisted out of the inner loop.
            int value_row_idx = lane_id * cols_per_thread + i;
            int value_base = value_row_idx * cols_values;
            for (int j = 0; j < cols_values; j++) {
                sum += row_d_ov[j] * row_values[value_base + j];
            }
            dy_buf[i] = static_cast<float>(sum);
        }
        // Load the saved softmax outputs y for this lane's chunk.
        for (int i = 0; i < cols_this_thread; i++) {
            y_buf[i] = static_cast<float>(row_output[lane_id * cols_per_thread + i]);
        }
        // Row-wide correction term sum_k(dy_k * y_k).
        float thread_sum = 0.;
        for (int i = 0; i < cols_this_thread; i++) {
            thread_sum += y_buf[i] * dy_buf[i];
        }
        float warp_sum = WarpAllReduceSum(thread_sum);
        // dx = (dy - sum) * y, written over the saved y in place.
        for (int i = 0; i < cols_this_thread; i++) {
            row_output[lane_id * cols_per_thread + i] = static_cast<T>(
                (dy_buf[i] - warp_sum) * y_buf[i]
            );
        }
    }
}
// Host launcher for attn_softmax_inplace_grad_: overwrites `output` (the
// saved softmax probabilities, [rows, cols_output]) with the softmax
// gradient, recomputing dy = d_ov @ V^T from `d_ov` and `values`
// ([.., cols_values]) inside the kernel. Runs on the default stream.
//
// Preconditions (checked): all three tensors contiguous CUDA tensors of
// the same dtype (float32 or bfloat16); 0 < cols_output <= 1024; grid of
// ceil(rows / 4) blocks fits in an int.
void attn_softmax_inplace_backward_(
    at::Tensor output,
    at::Tensor d_ov,
    at::Tensor values,
    long long rows,
    int cols_output,
    int cols_values
) {
    CHECK_INPUT(output);
    CHECK_INPUT(d_ov);
    CHECK_INPUT(values);
    // The original `else` branch reinterpreted ANY non-float32 dtype as
    // bfloat16; reject unsupported or mismatched dtypes instead.
    TORCH_CHECK(
        output.dtype() == torch::kFloat32 || output.dtype() == torch::kBFloat16,
        "attn_softmax_inplace_backward_ supports float32 and bfloat16, got ",
        output.dtype()
    );
    TORCH_CHECK(d_ov.dtype() == output.dtype() && values.dtype() == output.dtype(),
        "output, d_ov and values must share one dtype");
    TORCH_CHECK(cols_output > 0 && cols_output <= 1024,
        "cols_output must be in (0, 1024], got ", cols_output);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
    // One warp per row, four warps per block; 64-bit ceil-div so large row
    // counts cannot overflow int, with an explicit 1-D grid-limit guard.
    long long grid_ll = (rows + 3) / 4;
    TORCH_CHECK(grid_ll > 0 && grid_ll <= 2147483647LL,
        "row count out of range for a 1-D grid: ", rows);
    dim3 grid((unsigned int)grid_ll);
    dim3 block(128);
    if (output.dtype() == torch::kFloat32) {
        attn_softmax_inplace_grad_<float><<<grid, block>>>(
            (float *)output.data_ptr(),
            (float *)d_ov.data_ptr(),
            (float *)values.data_ptr(),
            rows, cols_output, cols_values
        );
    } else {
        attn_softmax_inplace_grad_<at::BFloat16><<<grid, block>>>(
            (at::BFloat16 *)output.data_ptr(),
            (at::BFloat16 *)d_ov.data_ptr(),
            (at::BFloat16 *)values.data_ptr(),
            rows, cols_output, cols_values
        );
    }
}