#include "flash_attn.h"

#include "common.h"
#include "gemm.h"

namespace {

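// Blocked flash-attention forward for CPU.
//
// Each (batch, head, query-block) task streams over K/V in BLOCK_N-sized
// chunks, maintaining an online softmax so the full [seqlen_q, seqlen_k]
// score matrix is never materialized. With scaled scores S_i per chunk:
//
//   m_i   = max(m_{i-1}, rowmax(sm_scale * S_i))        running row max
//   p_i   = exp(sm_scale * S_i - m_i)                   chunk probabilities
//   l_i   = l_{i-1} * exp(m_{i-1} - m_i) + rowsum(p_i)  running denominator
//   acc_i = acc_{i-1} * exp(m_{i-1} - m_i) + p_i @ V_i
//   out   = acc_n / l_n
//
// (the rescaling step lives inside flash_attn_softmax::apply). Both GEMMs
// (Q @ K^T and P @ V) go through at::native::cpublas::brgemm, with the K/V
// chunks repacked into the VNNI layout it consumes.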
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
void flash_attn_kernel_impl(
    scalar_t* __restrict__ out,
    const scalar_t* __restrict__ q,
    const scalar_t* __restrict__ k,
    const scalar_t* __restrict__ v,
    void* __restrict__ buffer,
    int seqlen_q,
    int seqlen_k,
    int batches,
    int num_heads,
    int num_heads_kv,
    int head_size,
    int head_size_v,
    int q_strideM,
    int q_strideH,
    int k_strideN,
    int k_strideH,
    int v_strideN,
    int v_strideH,
    float sm_scale,
    int buffer_size_per_thread,
    bool causal) {
  // the output tensor is freshly allocated and contiguous: [*, num_heads, head_size_v]
  const int o_strideM = num_heads * head_size_v;
  const int o_strideH = head_size_v;

  // leading dimension of the shared packed-B scratch, reused for both K and V
  const int ldb_tmp = std::max(head_size, head_size_v);

  // grouped-query attention: each KV head serves `num_groups` query heads
  const int num_groups = num_heads / num_heads_kv;
  TORCH_CHECK(num_groups * num_heads_kv == num_heads, "num_heads must be divisible by num_heads_kv");
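  // e.g. num_heads = 32, num_heads_kv = 8 gives num_groups = 4: query heads
  // 0..3 read KV head 0 (head_kv_id = head_id / num_groups), query heads 4..7
  // read KV head 1, and so on.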

  // number of query blocks along the M (sequence) dimension
  int MB = div_up(seqlen_q, BLOCK_M);

  // parallel over [batches, num_heads, MB]
  parallel_for(batches * num_heads * MB, [&](int begin, int end) {
    int bs{0}, head_id{0}, mb{0};
    data_index_init(begin, bs, batches, head_id, num_heads, mb, MB);

    int tid = get_thread_num();
    // per-thread scratch; s_delta aliases s_i: the softmax converts the
    // float scores to scalar_t probabilities in place
    float* __restrict__ s_i = reinterpret_cast<float*>(static_cast<char*>(buffer) + tid * buffer_size_per_thread);
    scalar_t* __restrict__ s_delta = reinterpret_cast<scalar_t*>(s_i);

    // output accumulator, kept in float
    float* __restrict__ v_prime = s_i + BLOCK_M * BLOCK_N;

    // packed-B scratch for the brgemm calls
    scalar_t* __restrict__ Btmp = reinterpret_cast<scalar_t*>(v_prime + BLOCK_M * head_size_v);

    // zero once so padded rows/columns of the packed block stay zero
    fill_stub(Btmp, 0.f, BLOCK_N * ldb_tmp);

    alignas(64) float s_prime[BLOCK_M];
    alignas(64) float m_prime[BLOCK_M];
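
    // per-thread scratch layout (sized by resize_buffer below):
    //   s_i     [BLOCK_M, BLOCK_N]      float     attention scores
    //   v_prime [BLOCK_M, head_size_v]  float     output accumulator
    //   Btmp    [BLOCK_N, ldb_tmp]      scalar_t  packed K or V block
    // s_prime (softmax denominators) and m_prime (running maxima) live on
    // the stack.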

    for (int i = begin; i < end; ++i) {
      int seq_q_start_loc = bs * seqlen_q;
      int seq_k_start_loc = bs * seqlen_k;

      // current query block: rows [m, m + m_size)
      int m = mb * BLOCK_M;
      int m_size = std::min(BLOCK_M, seqlen_q - m);

      assert(m_size > 0);

      int head_kv_id = head_id / num_groups;

      const scalar_t* __restrict__ q_ptr = q + (seq_q_start_loc + m) * q_strideM + head_id * q_strideH;

      // reset the online-softmax state for this query block
      fill_stub(v_prime, 0.f, m_size * head_size_v);
      fill_stub(s_prime, 0.f, m_size);
      fill_stub(m_prime, -std::numeric_limits<float>::infinity(), m_size);

      // with causal masking, row m + m_size - 1 attends to at most m + m_size keys
      int num_keys = causal ? std::min(m + m_size, seqlen_k) : seqlen_k;
      for (int n = 0; n < num_keys; n += BLOCK_N) {
        int n_size = std::min(BLOCK_N, num_keys - n);

        // round the reduction dimension up to a TILE_K multiple for brgemm
        const int padded_n_size = div_up(n_size, TILE_K) * TILE_K;

        // pack the K block [n_size, head_size] into the VNNI layout brgemm expects
        pack_vnni<scalar_t>(
            Btmp,
            k + (seq_k_start_loc + n) * k_strideN + head_kv_id * k_strideH,
            n_size,
            head_size,
            k_strideN,
            BLOCK_N);

        // s_i[m_size, n_size] = Q[m_size, head_size] @ K^T
        at::native::cpublas::brgemm(
            m_size,
            n_size,
            head_size,
            q_strideM,
            BLOCK_N,
            BLOCK_N,
            /* add_C */ false,
            q_ptr,
            Btmp,
            s_i);

        // apply the causal mask to the final key block: row (m + row) may
        // attend to keys [0, m + row], i.e. columns [0, m + row - n] here
        if (causal && num_keys - n <= BLOCK_N) {
          for (int row = 0; row < m_size; ++row) {
            int last_col = m + row - n;
            // fill [last_col + 1, n_size) with -inf
            float* row_ptr = s_i + row * BLOCK_N;
            fill_stub(row_ptr + last_col + 1, -std::numeric_limits<float>::infinity(), n_size - last_col - 1);
          }
        }
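
        // note: only the final key block needs masking. num_keys is clamped
        // to m + m_size, and since BLOCK_N is a multiple of BLOCK_M, the
        // diagonal band of this query block always falls inside the final
        // BLOCK_N chunk, so last_col >= 0 here; rows whose causal limit lies
        // beyond this chunk get a non-positive fill count, which fill_stub is
        // expected to treat as a no-op.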

        flash_attn_softmax<scalar_t, BLOCK_M, BLOCK_N>::apply(
            s_i, s_delta, v_prime, s_prime, m_prime, m_size, n_size, padded_n_size, head_size_v, sm_scale);
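        // the call above is the online-softmax update. With scaled scores
        // s = sm_scale * s_i and m_new = max(m_prime, rowmax(s)), it stores
        // p = exp(s - m_new) into s_delta (which aliases s_i) and updates
        //   s_prime = s_prime * exp(m_prime - m_new) + rowsum(p)
        //   v_prime = v_prime * exp(m_prime - m_new)
        //   m_prime = m_new
        // and presumably zero-fills columns [n_size, padded_n_size) of
        // s_delta so the padded brgemm reduction below adds nothing.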

        // pack the V block [n_size, head_size_v] into VNNI layout
        pack_vnni2<scalar_t>(
            Btmp,
            v + (seq_k_start_loc + n) * v_strideN + head_kv_id * v_strideH,
            n_size,
            head_size_v,
            v_strideN,
            head_size_v);

        // v_prime[m_size, head_size_v] += s_delta[m_size, padded_n_size] @ V
        at::native::cpublas::brgemm(
            m_size,
            head_size_v,
            padded_n_size,
            BLOCK_N,
            head_size_v,
            head_size_v,
            /* add_C */ true,
            s_delta,
            Btmp,
            v_prime);
      }

      // final normalization: out = v_prime / s_prime (the softmax denominator)
      scalar_t* __restrict__ out_ptr = out + (seq_q_start_loc + m) * o_strideM + head_id * o_strideH;
      for (int row = 0; row < m_size; ++row) {
        float s = 1.f / s_prime[row];
        copy_stub<scalar_t>(out_ptr + row * o_strideM, v_prime + row * head_size_v, s, head_size_v);
      }

      // move to the next (batch, head, query-block) task
      data_index_step(bs, batches, head_id, num_heads, mb, MB);
    }
    at::native::cpublas::brgemm_release();
  });
}
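
// Variable-length variant: sequences are packed back to back along dim 0 and
// delimited by the cu_seqlens_{q,k} prefix sums. Query blocks are flattened
// into an explicit (batch, block) work list up front; the inner loop is
// otherwise identical to flash_attn_kernel_impl above.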
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
void flash_attn_varlen_kernel_impl(
    scalar_t* __restrict__ out,
    const scalar_t* __restrict__ q,
    const scalar_t* __restrict__ k,
    const scalar_t* __restrict__ v,
    const int32_t* __restrict__ cu_seqlens_q,
    const int32_t* __restrict__ cu_seqlens_k,
    void* __restrict__ buffer,
    int32_t* __restrict__ indices,
    int max_seqlen_q,
    int max_seqlen_k,
    int batches,
    int num_heads,
    int num_heads_kv,
    int head_size,
    int head_size_v,
    int q_strideM,
    int q_strideH,
    int k_strideN,
    int k_strideH,
    int v_strideN,
    int v_strideH,
    float sm_scale,
    int buffer_size_per_thread,
    bool causal) {
  // the output tensor is contiguous: [num_tokens, num_heads, head_size_v]
  const int o_strideM = num_heads * head_size_v;
  const int o_strideH = head_size_v;

  // flatten (batch, query-block) pairs into a work list; sequences have
  // different lengths, so the block count per sequence varies
  int idx = 0;
  for (int32_t bs = 0; bs < batches; ++bs) {
    int32_t seqlen_q = cu_seqlens_q[bs + 1] - cu_seqlens_q[bs];
    int32_t seqlen_k = cu_seqlens_k[bs + 1] - cu_seqlens_k[bs];
    TORCH_CHECK(seqlen_q <= max_seqlen_q && seqlen_k <= max_seqlen_k);

    int32_t blocks = div_up(seqlen_q, BLOCK_M);
    for (int32_t offset = 0; offset < blocks; ++offset) {
      indices[idx * 2 + 0] = bs;
      indices[idx * 2 + 1] = offset;
      idx++;
    }
  }
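
  // e.g. with BLOCK_M = 128 and per-sequence query lengths {300, 100}:
  //   indices = {(0,0), (0,1), (0,2), (1,0)}
  // i.e. three query blocks for sequence 0 and one for sequence 1.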
  // total number of query blocks across all sequences
  int MB = idx;

  // leading dimension of the shared packed-B scratch, reused for both K and V
  const int ldb_tmp = std::max(head_size, head_size_v);

  const int num_groups = num_heads / num_heads_kv;
  TORCH_CHECK(num_groups * num_heads_kv == num_heads, "num_heads must be divisible by num_heads_kv");

  // parallel over [num_heads, MB]
  parallel_for(num_heads * MB, [&](int begin, int end) {
    int head_id{0}, mb{0};
    data_index_init(begin, head_id, num_heads, mb, MB);

    int tid = get_thread_num();
    // per-thread scratch, laid out exactly as in flash_attn_kernel_impl
    float* __restrict__ s_i = reinterpret_cast<float*>(static_cast<char*>(buffer) + tid * buffer_size_per_thread);
    scalar_t* __restrict__ s_delta = reinterpret_cast<scalar_t*>(s_i);

    // output accumulator, kept in float
    float* __restrict__ v_prime = s_i + BLOCK_M * BLOCK_N;

    // packed-B scratch for the brgemm calls
    scalar_t* __restrict__ Btmp = reinterpret_cast<scalar_t*>(v_prime + BLOCK_M * head_size_v);

    // zero once so padded rows/columns of the packed block stay zero
    fill_stub(Btmp, 0.f, BLOCK_N * ldb_tmp);

    alignas(64) float s_prime[BLOCK_M];
    alignas(64) float m_prime[BLOCK_M];

    for (int i = begin; i < end; ++i) {
      int32_t bs = indices[mb * 2 + 0];
      int32_t seq_q_start_loc = cu_seqlens_q[bs];
      int32_t seq_k_start_loc = cu_seqlens_k[bs];
      int32_t seqlen_q = cu_seqlens_q[bs + 1] - cu_seqlens_q[bs];

      // current query block: rows [m, m + m_size) within this sequence
      int m = indices[mb * 2 + 1] * BLOCK_M;
      int m_size = std::min(BLOCK_M, seqlen_q - m);

      assert(m_size > 0);

      int head_kv_id = head_id / num_groups;

      const scalar_t* __restrict__ q_ptr = q + (seq_q_start_loc + m) * q_strideM + head_id * q_strideH;

      // reset the online-softmax state for this query block
      fill_stub(v_prime, 0.f, m_size * head_size_v);
      fill_stub(s_prime, 0.f, m_size);
      fill_stub(m_prime, -std::numeric_limits<float>::infinity(), m_size);

      int seqlen_k = cu_seqlens_k[bs + 1] - cu_seqlens_k[bs];
      int num_keys = causal ? std::min(m + m_size, seqlen_k) : seqlen_k;
      for (int n = 0; n < num_keys; n += BLOCK_N) {
        int n_size = std::min(BLOCK_N, num_keys - n);

        // round the reduction dimension up to a TILE_K multiple for brgemm
        const int padded_n_size = div_up(n_size, TILE_K) * TILE_K;

        // pack the K block [n_size, head_size] into the VNNI layout brgemm expects
        pack_vnni<scalar_t>(
            Btmp,
            k + (seq_k_start_loc + n) * k_strideN + head_kv_id * k_strideH,
            n_size,
            head_size,
            k_strideN,
            BLOCK_N);

        // s_i[m_size, n_size] = Q[m_size, head_size] @ K^T
        at::native::cpublas::brgemm(
            m_size,
            n_size,
            head_size,
            q_strideM,
            BLOCK_N,
            BLOCK_N,
            /* add_C */ false,
            q_ptr,
            Btmp,
            s_i);

        // causal mask, applied to the final key block only (see the note in
        // the dense kernel above): row (m + row) keeps columns [0, m + row - n]
        if (causal && num_keys - n <= BLOCK_N) {
          for (int row = 0; row < m_size; ++row) {
            int last_col = m + row - n;
            // fill [last_col + 1, n_size) with -inf
            float* row_ptr = s_i + row * BLOCK_N;
            fill_stub(row_ptr + last_col + 1, -std::numeric_limits<float>::infinity(), n_size - last_col - 1);
          }
        }

        // online-softmax update, as in the dense kernel
        flash_attn_softmax<scalar_t, BLOCK_M, BLOCK_N>::apply(
            s_i, s_delta, v_prime, s_prime, m_prime, m_size, n_size, padded_n_size, head_size_v, sm_scale);

        // pack the V block [n_size, head_size_v] into VNNI layout
        pack_vnni2<scalar_t>(
            Btmp,
            v + (seq_k_start_loc + n) * v_strideN + head_kv_id * v_strideH,
            n_size,
            head_size_v,
            v_strideN,
            head_size_v);

        // v_prime[m_size, head_size_v] += s_delta[m_size, padded_n_size] @ V
        at::native::cpublas::brgemm(
            m_size,
            head_size_v,
            padded_n_size,
            BLOCK_N,
            head_size_v,
            head_size_v,
            /* add_C */ true,
            s_delta,
            Btmp,
            v_prime);
      }

      // final normalization: out = v_prime / s_prime
      scalar_t* __restrict__ out_ptr = out + (seq_q_start_loc + m) * o_strideM + head_id * o_strideH;
      for (int row = 0; row < m_size; ++row) {
        float s = 1.f / s_prime[row];
        copy_stub<scalar_t>(out_ptr + row * o_strideM, v_prime + row * head_size_v, s, head_size_v);
      }

      // move to the next (head, query-block) task
      data_index_step(head_id, num_heads, mb, MB);
    }
    at::native::cpublas::brgemm_release();
  });
}

}  // anonymous namespace

template <typename index_t>
inline bool has_varlen_sequences(
    const at::Tensor& cu_seqlens_q,
    const at::Tensor& cu_seqlens_k,
    int batches,
    index_t max_seqlen_q,
    index_t max_seqlen_k) {
  const index_t* cu_seqlens_q_data = cu_seqlens_q.data_ptr<index_t>();
  const index_t* cu_seqlens_k_data = cu_seqlens_k.data_ptr<index_t>();

  for (int bs = 0; bs < batches; ++bs) {
    index_t seqlen_q = cu_seqlens_q_data[bs + 1] - cu_seqlens_q_data[bs];
    index_t seqlen_k = cu_seqlens_k_data[bs + 1] - cu_seqlens_k_data[bs];
    if (seqlen_q != max_seqlen_q || seqlen_k != max_seqlen_k) {
      return true;
    }
  }
  return false;
}
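
// If every sequence is exactly (max_seqlen_q, max_seqlen_k), the caller can
// use the dense kernel, which indexes by bs * seqlen directly and skips the
// (batch, block) indirection table.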

template <int BLOCK_M, int BLOCK_N>
inline int resize_buffer(at::Tensor& buffer, int num_threads, int head_size, int head_size_v) {
  static_assert(
      BLOCK_N % BLOCK_M == 0,
      "BLOCK_N must be a multiple of BLOCK_M so that, under causal masking, a query block's "
      "diagonal band always falls inside the final key block and last_col stays within range");
  // scores (float) + output accumulator (float) + packed K/V block; scalar_t
  // is 2 bytes (bfloat16 or half, enforced by the reduced-floating dispatch)
  const int size_per_thread =
      BLOCK_M * BLOCK_N * sizeof(float) +
      BLOCK_M * head_size_v * sizeof(float) +
      BLOCK_N * std::max(head_size, head_size_v) * sizeof(uint16_t);

  buffer.resize_({num_threads, size_per_thread});
  return size_per_thread;
}
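
// e.g. with BLOCK_M = 256, BLOCK_N = 768 (the values used below) and, purely
// for illustration, head_size = head_size_v = 128:
//   scores       256 * 768 * 4 = 786,432 B
//   accumulator  256 * 128 * 4 = 131,072 B
//   packed K/V   768 * 128 * 2 = 196,608 B
//   total ~1.06 MiB per thread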

template <int BLOCK_M>
inline void resize_indices(at::Tensor& indices, int num_seqs, int max_seqlen_q) {
  // one (batch, block_offset) pair per query block; sized for the worst case
  // of every sequence reaching max_seqlen_q
  indices.resize_({num_seqs, div_up(max_seqlen_q, BLOCK_M), 2});
}

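// Flash attention over variable-length sequences packed back to back along
// the token dimension:
//   q:            [num_tokens_q, num_heads, head_size]
//   k:            [num_tokens_kv, num_heads_kv, head_size]
//   v:            [num_tokens_kv, num_heads_kv, head_size_v]
//   cu_seqlens_q: [num_seqs + 1], int32 prefix sums of query lengths
//   cu_seqlens_k: [num_seqs + 1], int32 prefix sums of key lengths
// returns out:   [num_tokens_q, num_heads, head_size_v]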
at::Tensor flash_attn_varlen_func(
    const at::Tensor& q,
    const at::Tensor& k,
    const at::Tensor& v,
    const at::Tensor& cu_seqlens_q,
    const at::Tensor& cu_seqlens_k,
    int64_t max_seqlen_q,
    int64_t max_seqlen_k,
    bool causal) {
  RECORD_FUNCTION(
      "sgl_kernel::flash_attn_varlen_func",
      std::vector<c10::IValue>({q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal}));

  CHECK_LAST_DIM_CONTIGUOUS_INPUT(q);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(k);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(v);
  CHECK_DIM(3, q);
  CHECK_DIM(3, k);
  CHECK_DIM(3, v);
  CHECK_INPUT(cu_seqlens_q);
  CHECK_INPUT(cu_seqlens_k);
  CHECK_EQ(cu_seqlens_q.scalar_type(), at::kInt);
  CHECK_EQ(cu_seqlens_k.scalar_type(), at::kInt);

  int num_seqs = cu_seqlens_q.size(0) - 1;
  int num_tokens = q.size(0);
  int num_heads = q.size(1);
  int num_heads_kv = k.size(1);
  int head_size = q.size(2);
  int head_size_v = v.size(2);

  // strides in elements; the last dim is contiguous, outer dims may be strided
  int q_strideM = q.stride(0);
  int q_strideH = q.stride(1);
  int k_strideN = k.stride(0);
  int k_strideH = k.stride(1);
  int v_strideN = v.stride(0);
  int v_strideH = v.stride(1);

  // shape consistency between q, k, v and the prefix sums
  CHECK_EQ(k.size(2), head_size);
  CHECK_EQ(v.size(1), num_heads_kv);
  CHECK_EQ(cu_seqlens_k.size(0), num_seqs + 1);

  // the VNNI packing interleaves pairs of elements, so head sizes must be even
  TORCH_CHECK(head_size % 2 == 0, "invalid head_size ", head_size);
  TORCH_CHECK(head_size_v % 2 == 0, "invalid head_size_v ", head_size_v);

  // standard 1/sqrt(d) attention scaling
  double sm_scale = 1.0 / std::sqrt(static_cast<double>(head_size));

  // check whether any sequence deviates from (max_seqlen_q, max_seqlen_k);
  // if none does, the dense kernel can be used
  const bool is_varlen =
      has_varlen_sequences<int32_t>(cu_seqlens_q, cu_seqlens_k, num_seqs, max_seqlen_q, max_seqlen_k);

  int num_threads = at::get_num_threads();
  // 0-dim placeholders; resized once the block configuration is known
  at::Tensor buffer = at::empty({}, q.options().dtype(at::kChar));
  at::Tensor indices = at::empty({}, q.options().dtype(at::kInt));
  at::Tensor out = at::empty({num_tokens, num_heads, head_size_v}, q.options());

  // blocking sizes. BLOCK_N must be a multiple of BLOCK_M (enforced in
  // resize_buffer): the causal mask is only applied to the final key block,
  // which is only valid when a query block's diagonal band cannot straddle
  // two key blocks.
  constexpr int BLOCK_M = 256;
  constexpr int BLOCK_N = 768;

  AT_DISPATCH_REDUCED_FLOATING_TYPES(q.scalar_type(), "flash_attn_varlen_func", [&] {
    int sz = resize_buffer<BLOCK_M, BLOCK_N>(buffer, num_threads, head_size, head_size_v);

    if (is_varlen) {
      resize_indices<BLOCK_M>(indices, num_seqs, max_seqlen_q);
      flash_attn_varlen_kernel_impl<scalar_t, BLOCK_M, BLOCK_N>(
          out.data_ptr<scalar_t>(),
          q.data_ptr<scalar_t>(),
          k.data_ptr<scalar_t>(),
          v.data_ptr<scalar_t>(),
          cu_seqlens_q.data_ptr<int32_t>(),
          cu_seqlens_k.data_ptr<int32_t>(),
          buffer.data_ptr(),
          indices.data_ptr<int32_t>(),
          max_seqlen_q,
          max_seqlen_k,
          num_seqs,
          num_heads,
          num_heads_kv,
          head_size,
          head_size_v,
          q_strideM,
          q_strideH,
          k_strideN,
          k_strideH,
          v_strideN,
          v_strideH,
          sm_scale,
          sz,
          causal);
    } else {
      // every sequence is exactly (max_seqlen_q, max_seqlen_k): dense layout
      flash_attn_kernel_impl<scalar_t, BLOCK_M, BLOCK_N>(
          out.data_ptr<scalar_t>(),
          q.data_ptr<scalar_t>(),
          k.data_ptr<scalar_t>(),
          v.data_ptr<scalar_t>(),
          buffer.data_ptr(),
          max_seqlen_q,
          max_seqlen_k,
          num_seqs,
          num_heads,
          num_heads_kv,
          head_size,
          head_size_v,
          q_strideM,
          q_strideH,
          k_strideN,
          k_strideH,
          v_strideN,
          v_strideH,
          sm_scale,
          sz,
          causal);
    }
  });

  return out;
}
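
// A minimal usage sketch (illustrative values; two sequences of 200 and 100
// tokens packed into 300 rows, 32 query heads sharing 8 KV heads):
//
//   at::Tensor q = at::randn({300, 32, 128}, at::kBFloat16);
//   at::Tensor k = at::randn({300, 8, 128}, at::kBFloat16);
//   at::Tensor v = at::randn({300, 8, 128}, at::kBFloat16);
//   at::Tensor cu = at::tensor({0, 200, 300}, at::kInt);
//   at::Tensor out = flash_attn_varlen_func(
//       q, k, v, cu, cu, /*max_seqlen_q=*/200, /*max_seqlen_k=*/200, /*causal=*/true);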