#pragma once

#include <cute/tensor.hpp>

#include "cutlass/fast_math.h"

#include "utils.h"

namespace flash {

using namespace cute;
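
// Applies out-of-bounds (seqlen_k), causal, and local (sliding-window) masking to a
// register tile of attention scores. kBlockM x kBlockN is the (query, key) tile shape.
// PackGQA packs the query heads of one KV head along the M dimension; SwapAB means the
// MMA computed the scores transposed, so tile rows index keys and columns index queries.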
template <int kBlockM, int kBlockN, bool PackGQA, typename TiledMma, bool SwapAB=false>
struct Mask {

    static_assert(!(PackGQA && SwapAB), "Cannot be both PackGQA and SwapAB");

    int const thread_idx;
    int const seqlen_q, seqlen_k;
    int const window_size_left, window_size_right, sink_token_length;
    cutlass::FastDivmod const attention_chunk_divmod;
    cutlass::FastDivmod const qhead_per_khead_divmod;

    CUTLASS_DEVICE
    Mask(const int thread_idx, const int seqlen_q, const int seqlen_k,
         const int window_size_left, const int window_size_right, const int sink_token_length,
         cutlass::FastDivmod const &attention_chunk_divmod,
         cutlass::FastDivmod const &qhead_per_khead_divmod)
        : thread_idx(thread_idx)
        , seqlen_q(seqlen_q)
        , seqlen_k(seqlen_k)
        , window_size_left(window_size_left)
        , window_size_right(window_size_right)
        , sink_token_length(sink_token_length)
        , attention_chunk_divmod(attention_chunk_divmod)
        , qhead_per_khead_divmod(qhead_per_khead_divmod)
    {
    }

    template <bool Seqlenk_mask=false, bool Causal_mask=false, bool Local_mask=false,
              typename Engine, typename Layout>
    CUTLASS_DEVICE
    void apply(Tensor<Engine, Layout> &tSrS, const int m_block, const int n_block) const {
        static_assert(!(Causal_mask && Local_mask), "Cannot be both causal and local");
        static_assert(Layout::rank == 3, "Only support 3D Tensor");
        if (!Seqlenk_mask && !Causal_mask && !Local_mask) { return; }

        auto thread_mma = TiledMma{}.get_thread_slice(thread_idx);
        auto thread0_mma = TiledMma{}.get_thread_slice(_0{});
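
        // Within the identity tensor cS, mode Row carries the query coordinate and mode
        // Col the key coordinate; with SwapAB the accumulator is transposed, so the two
        // modes swap roles.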
        static constexpr int Row = !SwapAB ? 0 : 1, Col = !SwapAB ? 1 : 0;

        Tensor cS = cute::make_identity_tensor(Shape<Int<!SwapAB ? kBlockM : kBlockN>, Int<!SwapAB ? kBlockN : kBlockM>>{});
        Tensor tScS = thread_mma.partition_C(cS);
        Tensor tSrS_rowcol = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol<SwapAB>(tSrS.layout()));
        Tensor tScS_rowcol = make_tensor(tScS.data(), flash::convert_layout_acc_rowcol<SwapAB>(tScS.layout()));
        Tensor t0ScS = thread0_mma.partition_C(cS);
        Tensor t0ScS_rowcol = make_tensor(t0ScS.data(), flash::convert_layout_acc_rowcol<SwapAB>(t0ScS.layout()));
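        // Compare against thread 0's column indices, which are known at compile time, and
        // shift each limit by this thread's first column offset to compensate.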
        int const thread_col_offset = get<Col>(tScS_rowcol(_0{}, _0{}));
        int const seqlenk_col_limit = seqlen_k - n_block * kBlockN - thread_col_offset;
        if constexpr (!Causal_mask && !Local_mask) {
            if constexpr (Seqlenk_mask) {
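                // Column-only mask: kill columns at or beyond seqlen_k.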
                #pragma unroll
                for (int n = 0; n < size<1>(tSrS_rowcol); ++n) {
                    if (int(get<Col>(t0ScS_rowcol(_0{}, n))) >= seqlenk_col_limit) {
                        #pragma unroll
                        for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { tSrS_rowcol(m, n) = -INFINITY; }
                    }
                }
            }
        } else {
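            // Causal or local: the mask depends on both the row and the column index.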
            if constexpr (!SwapAB) {
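                // With PackGQA, the divmod that maps a packed row back to its query index
                // is split across the threads sharing an MMA row, then broadcast via shuffle.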
                static constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{});
                static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0);
                static_assert(!PackGQA || CUTE_STATIC_V(size<0>(tSrS_rowcol)) <= kMmaThreadsPerRow);
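                // Lanes whose row slot exceeds the tile may compute an out-of-range index
                // here; that is harmless since only in-range rows are read back via shuffle.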
                int mma_m_idx;
                if constexpr (PackGQA) {
                    mma_m_idx = qhead_per_khead_divmod.divide(m_block * kBlockM + get<Row>(tScS_rowcol(thread_idx % kMmaThreadsPerRow, _0{})));
                }
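                // Bottom-right-aligned causal bound: query row i may see key columns up to
                // i + seqlen_k - seqlen_q; causal_row_offset folds the block and thread
                // column offsets into that bound.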
                int const causal_row_offset = 1 + seqlen_k - n_block * kBlockN - seqlen_q - thread_col_offset;
                if constexpr (Causal_mask) {
                    #pragma unroll
                    for (int m = 0; m < size<0>(tSrS_rowcol); ++m) {
                        int const row_idx = !PackGQA
                            ? get<Row>(tScS_rowcol(m, _0{})) + m_block * kBlockM
                            : __shfl_sync(0xffffffff, mma_m_idx, m % kMmaThreadsPerRow, kMmaThreadsPerRow);
                        int const col_limit_right = !Seqlenk_mask
                            ? row_idx + causal_row_offset
                            : __viaddmin_s32(row_idx, causal_row_offset, seqlenk_col_limit);
                        #pragma unroll
                        for (int n = 0; n < size<1>(tSrS_rowcol); ++n) {
                            if (int(get<Col>(t0ScS_rowcol(_0{}, n))) >= col_limit_right) { tSrS_rowcol(m, n) = -INFINITY; }
                        }
                    }
                } else {
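                    // Local (sliding-window) attention, plus optional sink tokens that stay
                    // visible at the front of the sequence and optional attention chunking.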
                    int const local_row_offset_right = causal_row_offset + window_size_right;
                    int const local_row_offset_left = causal_row_offset - 1 - window_size_left;
                    int const col_limit_sink = sink_token_length - n_block * kBlockN;
                    #pragma unroll
                    for (int m = 0; m < size<0>(tSrS_rowcol); ++m) {
                        int const row_idx = !PackGQA
                            ? get<Row>(tScS_rowcol(m, _0{})) + m_block * kBlockM
                            : __shfl_sync(0xffffffff, mma_m_idx, m % kMmaThreadsPerRow, kMmaThreadsPerRow);
                        int col_limit_right = !Seqlenk_mask
                            ? row_idx + local_row_offset_right
                            : __viaddmin_s32(row_idx, local_row_offset_right, seqlenk_col_limit);
                        int col_limit_left = row_idx + local_row_offset_left;
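                        // Optional chunked attention: clamp the window to the chunk that
                        // contains this query row.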
                        if (attention_chunk_divmod.divisor > 0) {
                            int col_limit_left_chunk = flash::round_down(attention_chunk_divmod, row_idx + seqlen_k - seqlen_q) - n_block * kBlockN - thread_col_offset;
                            col_limit_left = std::max(col_limit_left, col_limit_left_chunk);
                            col_limit_right = std::min(col_limit_right, col_limit_left_chunk + attention_chunk_divmod.divisor);
                        }
                        #pragma unroll
                        for (int n = 0; n < size<1>(tSrS_rowcol); ++n) {
                            int const col_idx = int(get<Col>(t0ScS_rowcol(_0{}, n)));
                            if (col_idx >= col_limit_right || (col_idx < col_limit_left && col_idx >= col_limit_sink)) { tSrS_rowcol(m, n) = -INFINITY; }
                        }
                    }
                }
            } else {
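                // SwapAB: tile rows index keys and columns index queries, so compute row
                // limits per column instead.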
                int const thread_row_offset = get<Row>(tScS_rowcol(_0{}, _0{}));
                int const causal_row_offset = seqlenk_col_limit - seqlen_q + m_block * kBlockM + thread_row_offset;
                if constexpr (Causal_mask) {
                    #pragma unroll
                    for (int n = 0; n < size<1>(tSrS_rowcol); ++n) {
                        int const col0 = int(get<Col>(t0ScS_rowcol(_0{}, n)));
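                        // If col0 is beyond the column limit, mask out the entire column by
                        // setting the row limit to kBlockM.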
                        int const row_limit_top = col0 >= seqlenk_col_limit ? kBlockM : col0 - causal_row_offset;
                        #pragma unroll
                        for (int m = 0; m < size<0>(tSrS_rowcol); ++m) {
                            if (int(get<Row>(t0ScS_rowcol(m, _0{}))) < row_limit_top) { tSrS_rowcol(m, n) = -INFINITY; }
                        }
                    }
                } else {
                    int const col_limit_sink = sink_token_length - n_block * kBlockN - thread_col_offset;
                    #pragma unroll
                    for (int n = 0; n < size<1>(tSrS_rowcol); ++n) {
                        int const col0 = int(get<Col>(t0ScS_rowcol(_0{}, n)));
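                        // Columns at or past seqlen_k are masked entirely (row_limit_top =
                        // kBlockM); sink columns skip the left-window mask (row_limit_bot =
                        // kBlockM). Otherwise apply the sliding-window row limits.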
                        int const row_limit_top = col0 >= seqlenk_col_limit ? kBlockM : col0 - causal_row_offset - window_size_right;
                        int const row_limit_bot = col0 < col_limit_sink ? kBlockM : col0 - causal_row_offset + window_size_left;
                        #pragma unroll
                        for (int m = 0; m < size<0>(tSrS_rowcol); ++m) {
                            int const row_idx = int(get<Row>(t0ScS_rowcol(m, _0{})));
                            if (row_idx < row_limit_top || row_idx > row_limit_bot) { tSrS_rowcol(m, n) = -INFINITY; }
                        }
                    }
                }
            }
        }
    }

};
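
// Minimal usage sketch (hypothetical kernel context: the TiledMma type, the register
// accumulator tSrS holding S = Q @ K^T, and the block coordinates are assumptions):
//
//   flash::Mask<kBlockM, kBlockN, /*PackGQA=*/false, TiledMma> mask(
//       thread_idx, seqlen_q, seqlen_k,
//       window_size_left, window_size_right, sink_token_length,
//       attention_chunk_divmod, qhead_per_khead_divmod);
//   // Mask a key block for causal attention with seqlen_k padding:
//   mask.template apply</*Seqlenk_mask=*/true, /*Causal_mask=*/true>(tSrS, m_block, n_block);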

}  // namespace flash