#pragma once

#include "cutlass/fast_math.h"
#include "cutlass/arch/barrier.h"

#include "named_barrier.hpp"
#include "utils.h"

namespace flash {

///////////////////////////////////////////////////////////////////////////////

struct TileSchedulerArguments {
    // num_head is num_head_q if not PackGQA, else num_head_k
    int const num_blocks, num_head, num_batch, num_splits;
    int const qhead_per_khead;
    int const seqlen;  // Only used if Varlen and cu_seqlens == nullptr and seqused == nullptr
    int const seqlen_k, headdim, headdim_v, element_size;  // Used to calculate the L2 swizzle
    int* const tile_count_semaphore = nullptr;
    int const* const cu_seqlens = nullptr;
    int const* const seqused = nullptr;
    int const* const num_splits_dynamic_ptr = nullptr;
};
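
///////////////////////////////////////////////////////////////////////////////

// All schedulers below share the same interface: to_underlying_arguments() and
// get_grid_shape() on the host, then a per-CTA loop on the device of the form
// sketched here (a minimal sketch, not a definitive kernel; `Scheduler`, `params`
// and `smem_scheduler` stand for whichever concrete scheduler and storage the
// kernel instantiates):
//
//     __shared__ typename Scheduler::SharedStorage smem_scheduler;
//     Scheduler scheduler(&smem_scheduler);
//     for (auto work = scheduler.template get_initial_work</*IsProducerWarp=*/true>(params);
//          work.is_valid(params);
//          work = scheduler.template get_next_work</*IsProducerWarp=*/true>(params, work)) {
//         auto [block, bidh, bidb, split_idx] = work.get_block_coord(params);
//         // ... process one tile ...
//         scheduler.prefetch_next_work(params, work);
//     }

// Trivial scheduler: the grid is sized to the whole problem, each CTA processes
// exactly one (block, head, batch[, split]) tile, and get_next_work always
// returns an invalid tile.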
template<bool Varlen=false, bool Split=false, bool PackGQA=false, int kBlock=128>
class SingleTileScheduler {

public:

    using SharedStorage = int;

    // Device side kernel params
    struct Params {
        int const num_blocks, num_head, num_batch, num_splits;
        int const qhead_per_khead;
        int const seqlen;
        cutlass::FastDivmod nsplits_divmod;
        int const* const cu_seqlens;
        int const* const seqused;
        int const* const num_splits_dynamic_ptr = nullptr;
    };

    static Params
    to_underlying_arguments(TileSchedulerArguments const& args) {
        assert(!Split || !Varlen || args.num_splits_dynamic_ptr != nullptr);
        assert(!Split || !Varlen || args.num_splits < (1 << 16));  // num_splits is stored in the top 16 bits of split_idx
        return {args.num_blocks, args.num_head, args.num_batch, !Split ? 1 : args.num_splits,
                args.qhead_per_khead, args.seqlen,
                cutlass::FastDivmod(!Split ? 1 : args.num_splits),
                !Varlen ? nullptr : args.cu_seqlens, !Varlen ? nullptr : args.seqused,
                args.num_splits_dynamic_ptr};
    }

    static dim3
    get_grid_shape(Params const& params, int num_sm) {
        return {uint32_t(params.num_blocks), uint32_t((!Split ? 1 : params.num_splits) * params.num_head), uint32_t(params.num_batch)};
    }

    struct WorkTileInfo {
        int block_idx = 0;
        int bidh = 0;
        int bidb = 0;
        int split_idx = 0;

        CUTLASS_DEVICE
        bool
        is_valid(Params const& params) const {
            return bidb >= 0;
        }

        CUTLASS_DEVICE
        cute::tuple<int32_t, int32_t, int32_t, int32_t>
        get_block_coord(Params const& params) const {
            return {block_idx, bidh, bidb, !Split ? 0 : split_idx};
        }

    };

    CUTLASS_DEVICE
    SingleTileScheduler(SharedStorage* const smem_scheduler) { }

    template<bool IsProducerWarp=false>
    CUTLASS_DEVICE
    WorkTileInfo
    get_initial_work(Params const& params) const {
        WorkTileInfo work_info {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), 0};
        if constexpr (Split) {
            int split_idx;
            work_info.bidh = params.nsplits_divmod.divmod(split_idx, work_info.bidh);
            work_info.split_idx = split_idx;
        }
        bool is_valid_tile = true;
        if constexpr (Varlen) {
            int seqlen = params.seqused
                ? params.seqused[work_info.bidb]
                : (params.cu_seqlens ? params.cu_seqlens[work_info.bidb + 1] - params.cu_seqlens[work_info.bidb] : params.seqlen);
            if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; }
            is_valid_tile = work_info.block_idx * kBlock < seqlen;
        }
        if constexpr (Varlen && Split) {
            int num_splits_dynamic = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[work_info.bidb] : params.num_splits;
            is_valid_tile &= work_info.split_idx < num_splits_dynamic;
            // Use the top 16 bits of split_idx to store num_splits
            work_info.split_idx |= (num_splits_dynamic << 16);
        }
        work_info.bidb = is_valid_tile ? work_info.bidb : -1;
        return work_info;
    }

    CUTLASS_DEVICE
    void
    init_consumer() const {}

    CUTLASS_DEVICE
    void
    prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}

    template<bool IsProducerWarp=false>
    CUTLASS_DEVICE
    WorkTileInfo
    get_next_work(Params const& params, WorkTileInfo const& current_work) const {
        return {0, 0, -1, 0};  // bidb = -1: each CTA only ever gets one tile
    }

};
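
///////////////////////////////////////////////////////////////////////////////

// Persistent scheduler with a static schedule: tile indices 0..total_blocks-1 are
// mapped to (block, head, batch[, split]) by pure divmod arithmetic, each CTA
// starts at tile blockIdx.x and strides by gridDim.x until it runs off the end.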
template<bool Split=false>
class StaticPersistentTileScheduler {

public:

    using SharedStorage = int;

    // Device side kernel params
    struct Params {
        int total_blocks;
        cutlass::FastDivmod m_block_divmod, head_divmod;
        cutlass::FastDivmod nsplits_divmod;
    };

    static Params
    to_underlying_arguments(TileSchedulerArguments const& args) {
        return {args.num_blocks * args.num_head * args.num_batch * (!Split ? 1 : args.num_splits),
                cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head * (!Split ? 1 : args.num_splits)),
                cutlass::FastDivmod(!Split ? 1 : args.num_splits)};
    }

    static dim3
    get_grid_shape(Params const& params, int num_sm) {
        return {uint32_t(num_sm)};
    }

    struct WorkTileInfo {
        int tile_idx;

        CUTLASS_DEVICE
        bool
        is_valid(Params const& params) const {
            return tile_idx < params.total_blocks;
        }

        CUTLASS_DEVICE
        cute::tuple<int32_t, int32_t, int32_t, int32_t>
        get_block_coord(Params const& params) const {
            int block, bidh, bidb;
            // tile_idx = (bidb * num_head_eff + bidh_eff) * num_blocks + block, where
            // num_head_eff = num_head * num_splits and bidh_eff packs (bidh, split_idx)
            bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(block, tile_idx));
            int split_idx = 0;
            if constexpr (Split) {
                bidh = params.nsplits_divmod.divmod(split_idx, bidh);
            }
            return {block, bidh, bidb, split_idx};
        }

    };

    CUTLASS_DEVICE
    StaticPersistentTileScheduler(SharedStorage* const smem_scheduler) {}

    template<bool IsProducerWarp=false>
    CUTLASS_DEVICE
    WorkTileInfo
    get_initial_work(Params const& params) const {
        return {int(blockIdx.x)};
    }

    CUTLASS_DEVICE
    void
    init_consumer() const {}

    CUTLASS_DEVICE
    void
    prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}

    template<bool IsProducerWarp=false>
    CUTLASS_DEVICE
    WorkTileInfo
    get_next_work(Params const& params, WorkTileInfo const& current_work) const {
        // Static round-robin: stride by the total number of CTAs in the grid
        return {current_work.tile_idx + int(gridDim.x)};
    }

};
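
///////////////////////////////////////////////////////////////////////////////

// Persistent scheduler that hands out tiles dynamically: CTAs grab the next tile
// index from a global atomic counter (tile_count_semaphore), so fast CTAs steal
// work from slow ones. Two extra tricks:
//   1. L2 swizzling: consecutive tile indices are grouped so that a "swizzle"-sized
//      chunk of (head, batch) pairs reuses the same K/V in L2.
//   2. Longest-processing-time-first: m-blocks are handed out in reverse order,
//      since with causal masking later m-blocks take longer.
// With WarpSpecialized, one producer warp fetches the next tile index and hands it
// to the consumer warps through shared memory, guarded by two named barriers.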
template<int NumMmaThreads=2 * cutlass::NumThreadsPerWarpGroup, int NumProducerThreads=cutlass::NumThreadsPerWarp,
         bool Split=false, bool PackGQA=false, bool WarpSpecialized=true>
class DynamicPersistentTileScheduler {

    // Without warp specialization, the producer threads are the same as the MMA threads
    static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads);
    static constexpr int NumThreads = WarpSpecialized ? NumMmaThreads + NumProducerThreads : NumMmaThreads;

public:
    using SharedStorage = int;

protected:
    SharedStorage* const tile_count_smem;

public:

    // Device side kernel params
    struct Params {
        int const total_blocks;
        cutlass::FastDivmod const m_block_divmod, head_divmod;
        cutlass::FastDivmod const l2_minor_divmod, l2_major_divmod;
        cutlass::FastDivmod const l2_minor_residual_divmod;
        int const num_hb_quotient;
        int* const tile_count_semaphore;
    };

    static Params
    to_underlying_arguments(TileSchedulerArguments const& args) {
        // Bytes of K + V for one head, used only to size the L2 swizzle
        int const size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size;
        int const size_l2 = 32 * 1024 * 1024;  // Budget 32 MB of L2 for K & V
        // floor(log2(n)) for n > 0
        auto find_log2_floor = [&](int n) { return 31 - cutlass::clz(n); };
        // swizzle is the number of (head, batch) pairs in each "section" that shares L2,
        // rounded down to a power of 2. If not PackGQA, all qhead_per_khead query heads
        // hit the same K/V head, so the section grows by that factor.
        int const swizzle = (size_l2 < size_one_kv_head ? 1 : (1 << find_log2_floor(size_l2 / size_one_kv_head))) * (PackGQA ? 1 : args.qhead_per_khead);
        // If num_head * num_batch is not a multiple of swizzle, the last section is smaller
        int const num_hb_remainder = (args.num_head * args.num_batch) % swizzle;
        int const num_split_blocks = args.num_blocks * (!Split ? 1 : args.num_splits);
        assert(args.tile_count_semaphore != nullptr);
        return {num_split_blocks * args.num_head * args.num_batch,
                cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head),
                cutlass::FastDivmod(swizzle), cutlass::FastDivmod(swizzle * num_split_blocks),
                // Avoid division by 0
                cutlass::FastDivmod(num_hb_remainder > 0 ? num_hb_remainder : 1),
                (args.num_head * args.num_batch) / swizzle,
                args.tile_count_semaphore};
    }
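
    // Worked example (hypothetical sizes): seqlen_k = 8192, headdim = headdim_v = 128,
    // fp16 (element_size = 2) gives size_one_kv_head = 8192 * 256 * 2 = 4 MB, so
    // size_l2 / size_one_kv_head = 8 and swizzle = 8 (times qhead_per_khead if not
    // PackGQA): 8 consecutive (head, batch) pairs share L2 before the block advances.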

    static dim3
    get_grid_shape(Params const& params, int num_sm) {
        return {uint32_t(num_sm)};
    }

    struct WorkTileInfo {
        int tile_idx;

        CUTLASS_DEVICE
        bool
        is_valid(Params const& params) const {
            return tile_idx < params.total_blocks;
        }

        CUTLASS_DEVICE
        cute::tuple<int32_t, int32_t, int32_t, int32_t>
        get_block_coord(Params const& params) const {
            int block, bidh, bidb;
            int l2_mod, bidhb, bidhb_residual;
            // Which swizzle section are we in, and where within it?
            bidhb = params.l2_major_divmod.divmod(l2_mod, tile_idx);
            // If we're in the last section (the residual), the section size can be smaller
            if (bidhb < params.num_hb_quotient) {
                block = params.l2_minor_divmod.divmod(bidhb_residual, l2_mod);
            } else {
                block = params.l2_minor_residual_divmod.divmod(bidhb_residual, l2_mod);
            }
            bidb = params.head_divmod.divmod(bidh, bidhb * params.l2_minor_divmod.divisor + bidhb_residual);
            int split_idx = 0;
            if constexpr (Split) {
                split_idx = params.m_block_divmod.divmod(block, block);
            }
            // Longest-processing-time-first: hand out m-blocks in reverse order
            block = params.m_block_divmod.divisor - 1 - block;
            return {block, bidh, bidb, split_idx};
        }

    };

    CUTLASS_DEVICE
    DynamicPersistentTileScheduler(SharedStorage* const smem_scheduler) : tile_count_smem(smem_scheduler) {}

    template<bool IsProducerWarp=false>
    CUTLASS_DEVICE
    WorkTileInfo
    get_initial_work(Params const& params) const {
        return {int(blockIdx.x)};
    }

    CUTLASS_DEVICE
    void
    init_consumer() const {
        if (WarpSpecialized || cutlass::canonical_warp_idx_sync() > 0) {
            // Mark the smem tile count as free so the producer can write to it
            flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/);  // TileCountSmemEmpty
        }
    }

    CUTLASS_DEVICE
    void
    prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {
        if (threadIdx.x % NumProducerThreads == 0) {
            // One thread grabs the next tile index for the whole CTA; the first gridDim.x
            // tiles were already taken by get_initial_work, hence the + gridDim.x offset
            current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x);
        }
    }
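
    // The producer warp and the consumer warps exchange the next tile index through
    // tile_count_smem, using two named barriers as an empty/full handshake:
    // StreamkBarrier0 means the smem slot is free to write ("empty"), StreamkBarrier1
    // means it holds a fresh value ("full").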
    template<bool IsProducerWarp=false>
    CUTLASS_DEVICE
    WorkTileInfo
    get_next_work(Params const& params, WorkTileInfo const& current_work) const {
        if constexpr (IsProducerWarp) {
            // Thread 0 already has the next tile_idx from prefetch_next_work; broadcast it to the warp
            int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/);
            flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/);  // TileCountSmemEmpty
            if (threadIdx.x % NumProducerThreads == 0) {
                *tile_count_smem = current_work.tile_idx;
            }
            flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/);  // TileCountSmemFull
            return {new_tile_idx};
        } else {
            flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/);  // TileCountSmemFull
            int tile_idx = *tile_count_smem;
            flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/);  // TileCountSmemEmpty
            return {tile_idx};
        }
    }

};
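
///////////////////////////////////////////////////////////////////////////////

// Non-persistent scheduler for the backward pass: as in SingleTileScheduler each CTA
// gets exactly one tile, but blockIdx.x is remapped through the same L2 swizzle as
// DynamicPersistentTileScheduler, so CTAs scheduled close together reuse the same
// (head, batch)'s Q, dO and dQaccum in L2. Per the LPT name, lower tile indices
// (which hardware tends to launch first) correspond to the lower k-block indices
// that do the most work in the causal backward pass.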
class SingleTileBwdLPTScheduler {

public:

    using SharedStorage = int;

    // Device side kernel params
    struct Params {
        int const total_blocks;
        cutlass::FastDivmod const m_block_divmod, head_divmod;
        cutlass::FastDivmod const l2_minor_divmod, l2_major_divmod;
        cutlass::FastDivmod const l2_minor_residual_divmod;
        int const num_hb_quotient;
    };

    static Params
    to_underlying_arguments(TileSchedulerArguments const& args) {
        // Bytes streamed per (head, batch) that we'd like to stay resident in L2:
        // Q and dO at element_size bytes each, plus dQaccum in fp32
        int const size_one_qdo_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size;
        int const size_one_dqaccum_head = args.seqlen_k * args.headdim * sizeof(float);
        int const size_one_head = size_one_qdo_head + size_one_dqaccum_head;
        int const size_l2 = 40 * 1024 * 1024;  // Budget 40 MB of L2
        // floor(log2(n)) for n > 0
        auto find_log2_floor = [&](int n) { return 31 - cutlass::clz(n); };
        // Number of (head, batch) pairs per swizzle section, rounded down to a power of 2
        int const swizzle = size_l2 < size_one_head ? 1 : (1 << find_log2_floor(size_l2 / size_one_head));
        // If num_head * num_batch is not a multiple of swizzle, the last section is smaller
        int const num_hb_remainder = (args.num_head * args.num_batch) % swizzle;
        assert(args.tile_count_semaphore != nullptr);
        return {args.num_blocks * args.num_head * args.num_batch,
                cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head),
                cutlass::FastDivmod(swizzle), cutlass::FastDivmod(swizzle * args.num_blocks),
                // Avoid division by 0
                cutlass::FastDivmod(num_hb_remainder > 0 ? num_hb_remainder : 1),
                (args.num_head * args.num_batch) / swizzle};
    }
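
    // Worked example (hypothetical sizes): seqlen_k = 8192, headdim = headdim_v = 128,
    // fp16 gives size_one_qdo_head = 4 MB and size_one_dqaccum_head = 4 MB, so
    // size_l2 / size_one_head = 40 MB / 8 MB = 5 and swizzle = 4.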

    static dim3
    get_grid_shape(Params const& params, int num_sm) {
        return {uint32_t(params.total_blocks)};
    }

    struct WorkTileInfo {
        int tile_idx;

        CUTLASS_DEVICE
        bool
        is_valid(Params const& params) const {
            return tile_idx < params.total_blocks;
        }

        CUTLASS_DEVICE
        cute::tuple<int32_t, int32_t, int32_t, int32_t>
        get_block_coord(Params const& params) const {
            int block, bidh, bidb;
            int l2_mod, bidhb, bidhb_residual;
            // Which swizzle section are we in, and where within it?
            bidhb = params.l2_major_divmod.divmod(l2_mod, tile_idx);
            // If we're in the last section (the residual), the section size can be smaller
            if (bidhb < params.num_hb_quotient) {
                block = params.l2_minor_divmod.divmod(bidhb_residual, l2_mod);
            } else {
                block = params.l2_minor_residual_divmod.divmod(bidhb_residual, l2_mod);
            }
            bidb = params.head_divmod.divmod(bidh, bidhb * params.l2_minor_divmod.divisor + bidhb_residual);
            return {block, bidh, bidb, 0 /*split_idx*/};
        }

    };

    CUTLASS_DEVICE
    SingleTileBwdLPTScheduler(SharedStorage* const smem_scheduler) { }

    template<bool IsProducerWarp=false>
    CUTLASS_DEVICE
    WorkTileInfo
    get_initial_work(Params const& params) const {
        return {int(blockIdx.x)};
    }

    CUTLASS_DEVICE
    void
    init_consumer() const {}

    CUTLASS_DEVICE
    void
    prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}

    template<bool IsProducerWarp=false>
    CUTLASS_DEVICE
    WorkTileInfo
    get_next_work(Params const& params, WorkTileInfo const& current_work) const {
        return {params.total_blocks};  // Out of range: each CTA only ever gets one tile
    }

};
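
///////////////////////////////////////////////////////////////////////////////

// Dynamic persistent scheduler for variable-length batches: the number of m-blocks
// per batch is only known at runtime (from cu_seqlens / seqused), so the mapping
// from a linear tile index to (block, bidh, bidb) is computed on the fly by a
// warp-cooperative search (tile_idx_to_work_tile below). When Split, bidh
// additionally packs the split index and the per-batch number of splits into its
// upper bits.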
template<int kBlock, int NumMmaThreads=2 * cutlass::NumThreadsPerWarpGroup, int NumProducerThreads=cutlass::NumThreadsPerWarp, bool Split=false, bool PackGQA=false, bool WarpSpecialized=true>
class VarlenDynamicPersistentTileScheduler {

    static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads);
    static constexpr int NumThreads = WarpSpecialized ? NumMmaThreads + NumProducerThreads : NumMmaThreads;

public:
    using SharedStorage = int4;

protected:
    SharedStorage* const work_info_smem;

public:

    // Device side kernel params
    struct Params {
        int num_head, num_batch;
        int const qhead_per_khead;
        int const seqlen;
        cutlass::FastDivmod head_divmod;
        cutlass::FastDivmod nsplits_divmod;
        int* const tile_count_semaphore;
        int const* const cu_seqlens;
        int const* const seqused;
        int const* const num_splits_dynamic_ptr;
    };

    static Params
    to_underlying_arguments(TileSchedulerArguments const& args) {
        assert(args.tile_count_semaphore != nullptr);
        assert(args.num_head < (1 << 16));  // bidh is packed into the lower 16 bits
        assert(!Split || args.num_splits < (1 << 8));  // num_splits is packed into 8 bits
        return {args.num_head, args.num_batch,
                args.qhead_per_khead, args.seqlen,
                cutlass::FastDivmod(args.num_head),
                cutlass::FastDivmod(!Split ? 1 : args.num_splits),
                args.tile_count_semaphore, args.cu_seqlens, args.seqused,
                args.num_splits_dynamic_ptr};
    }

    static dim3
    get_grid_shape(Params const& params, int num_sm) {
        return {uint32_t(num_sm)};
    }

    struct WorkTileInfo {
        int tile_idx, block, bidh, bidb;

        CUTLASS_DEVICE
        bool
        is_valid(Params const& params) const {
            return bidb < params.num_batch;
        }

        CUTLASS_DEVICE
        cute::tuple<int32_t, int32_t, int32_t, int32_t>
        get_block_coord(Params const& params) const {
            if constexpr (!Split) {
                return {block, bidh, bidb, 0 /*split_idx*/};
            } else {
                // Unpack bidh: bits 0-15 hold the actual head index, bits 16-23 the
                // split index, bits 24-31 the number of splits for this batch
                uint32_t bidh_packed = reinterpret_cast<uint32_t const&>(bidh);
                uint32_t bidh_actual_u = bidh_packed & 0x0000FFFF;
                int bidh_actual = reinterpret_cast<int&>(bidh_actual_u);
                // Move num_splits from bits 24-31 down to bits 16-23 of split_idx, so that
                // split_idx carries (split_idx, num_splits) in the same layout as the other schedulers
                uint32_t split_idx_u = ((bidh_packed & 0x00FF0000) >> 16) + ((bidh_packed & 0xFF000000) >> 8);
                int split_idx = reinterpret_cast<int&>(split_idx_u);
                return {block, bidh_actual, bidb, split_idx};
            }
        }
    };

    CUTLASS_DEVICE
    VarlenDynamicPersistentTileScheduler(SharedStorage* const smem_scheduler) : work_info_smem(smem_scheduler) {}
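
    // Map a linear tile index to (block, bidh, bidb). Tiles are laid out batch-major:
    // batch b contributes num_m_blocks(b) * num_head (times num_splits if Split)
    // consecutive tiles. Since num_m_blocks varies per batch, the whole warp
    // cooperates: each of lanes 0..30 computes the block count of one batch, a warp
    // prefix sum turns those into tile offsets, and a ballot finds which batch the
    // target tile falls into. If the tile is past the last batch, bidb = num_batch
    // is returned, which is_valid() rejects.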
    CUTLASS_DEVICE
    WorkTileInfo
    tile_idx_to_work_tile(Params const& params, int next_tile_idx, WorkTileInfo const& current_work) const {
        int lane = threadIdx.x % cutlass::NumThreadsPerWarp;
        // Each lane computes the number of m-blocks of one batch, starting at bidb_start
        auto get_num_m_blocks = [&] (int bidb_start) {
            int batch_idx = lane + bidb_start;
            int seqlen = params.seqlen * (!PackGQA ? 1 : params.qhead_per_khead);
            // If even this seqlen (an upper bound in the varlen case) fits in one
            // m-block, skip reading the per-batch lengths
            if (seqlen > kBlock) {
                if (params.seqused) {
                    seqlen = batch_idx < params.num_batch ? params.seqused[batch_idx] : 0;
                } else if (params.cu_seqlens) {
                    int cur_cu_seqlen = batch_idx <= params.num_batch ? params.cu_seqlens[batch_idx] : 0;
                    int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1);
                    seqlen = next_cu_seqlen - cur_cu_seqlen;
                } else {
                    seqlen = params.seqlen;
                }
                if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; }
            }
            // Only lanes 0..30 produce valid counts: each lane needs lane + 1's
            // cu_seqlens value from the shfl above
            return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1
                ? cute::ceil_div(seqlen, kBlock) : 0;
        };

        auto get_num_splits = [&] (int bidb_start) {
            int batch_idx = lane + bidb_start;
            return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1
                ? (!Split ? 1 : (params.num_splits_dynamic_ptr
                                     ? params.num_splits_dynamic_ptr[batch_idx]
                                     : params.nsplits_divmod.divisor))
                : 0;
        };

        int num_m_blocks = get_num_m_blocks(current_work.bidb);  // Different for each lane
        int num_splits = get_num_splits(current_work.bidb);
        int num_split_m_blocks = !Split ? num_m_blocks : num_m_blocks * num_splits;
        // Cumulative number of blocks for the next 31 batches
        int num_m_blocks_cumulative = warp_prefix_sum(num_split_m_blocks);
        // Total number of blocks for the next 31 batches
        int m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1);
        // Same units as tile_idx: the tile index one past the current group of batches
        int current_bidh = !Split ? current_work.bidh : (current_work.bidh & 0x0000FFFF);
        int group_end_tile = current_work.tile_idx - current_work.block - current_bidh * __shfl_sync(0xffffffff, num_split_m_blocks, 0 /*lane*/) + m_blocks_in_group * params.num_head;
        if constexpr (Split) {
            int current_split_idx = (current_work.bidh & 0x00FF0000) >> 16;
            group_end_tile -= current_split_idx * __shfl_sync(0xffffffff, num_m_blocks, 0 /*lane*/);
        }
        int bidb = current_work.bidb;
        // Advance a group of 31 batches at a time until the target tile is inside the group
        while (group_end_tile <= next_tile_idx) {
            bidb += cutlass::NumThreadsPerWarp - 1;
            if (bidb >= params.num_batch) {
                // All tiles are done: return an invalid tile (bidb == num_batch)
                return {next_tile_idx, 0, 0, params.num_batch};
            }
            num_m_blocks = get_num_m_blocks(bidb);
            num_splits = get_num_splits(bidb);
            num_split_m_blocks = !Split ? num_m_blocks : num_m_blocks * num_splits;
            num_m_blocks_cumulative = warp_prefix_sum(num_split_m_blocks);
            m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1);
            group_end_tile += m_blocks_in_group * params.num_head;
        }
        int group_start_tile = group_end_tile - m_blocks_in_group * params.num_head;
        // Ballot: count how many batches in this group end at or before next_tile_idx;
        // that count is the target batch's index within the group
        int batch_idx_in_group = __popc(__ballot_sync(0xffffffff, group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx));
        bidb += batch_idx_in_group;
        num_m_blocks = __shfl_sync(0xffffffff, num_m_blocks, batch_idx_in_group);
        if constexpr (Split) { num_splits = __shfl_sync(0xffffffff, num_splits, batch_idx_in_group); }
        // Offset within the target batch, then split into head index and m-block index
        int mh_block = next_tile_idx - group_start_tile - (batch_idx_in_group == 0 ? 0 : __shfl_sync(0xffffffff, num_m_blocks_cumulative, batch_idx_in_group - 1)) * params.num_head;
        int bidh = mh_block / num_m_blocks;
        int block = mh_block - bidh * num_m_blocks;
        if constexpr (Split) {
            int bidh_actual = bidh / num_splits;
            int split_idx = bidh - bidh_actual * num_splits;
            // Pack (bidh_actual, split_idx, num_splits) into one int:
            // bits 0-15 bidh_actual, bits 16-23 split_idx, bits 24-31 num_splits
            uint32_t bidh_packed = reinterpret_cast<uint32_t&>(bidh_actual) + (reinterpret_cast<uint32_t&>(split_idx) << 16) + (reinterpret_cast<uint32_t&>(num_splits) << 24);
            bidh = reinterpret_cast<int&>(bidh_packed);
        }
        return {next_tile_idx, block, bidh, bidb};
    }
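
    // Example (hypothetical, no Split): num_head = 2 and batches 0..2 with 3, 1, 2
    // m-blocks give per-batch tile ranges [0,6), [6,8), [8,12). For next_tile_idx = 7:
    // the ballot counts one batch ending at or before tile 7 (batch 0 ends at 6), so
    // bidb = 1, mh_block = 7 - 3*2 = 1, bidh = 1, block = 0.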

    template<bool IsProducerWarp=false>
    CUTLASS_DEVICE
    WorkTileInfo
    get_initial_work(Params const& params) const {
        if constexpr (IsProducerWarp) {
            WorkTileInfo work_info = tile_idx_to_work_tile(params, int(blockIdx.x), {0, 0, 0, 0});
            if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
                *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb);
            }
            flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/);  // TileCountSmemFull
            return work_info;
        } else {
            return get_next_work<false>(params, {0, 0, 0, 0});
        }
    }

    CUTLASS_DEVICE
    void
    init_consumer() const {
        // Nothing to do: consumers already arrive at the TileCountSmemEmpty barrier
        // inside get_next_work, which get_initial_work calls for them
    }

    CUTLASS_DEVICE
    void
    prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {
        if (threadIdx.x % NumProducerThreads == 0) {
            current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x);
        }
    }

    template<bool IsProducerWarp=false>
    CUTLASS_DEVICE
    WorkTileInfo
    get_next_work(Params const& params, WorkTileInfo const& current_work) const {
        if constexpr (IsProducerWarp) {
            // Thread 0 has the next tile_idx from prefetch_next_work; lane 1 still holds
            // the old tile_idx, which seeds the search for the new tile's coordinates
            int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/);
            WorkTileInfo work_info = {__shfl_sync(0xffffffff, current_work.tile_idx, 1 /*lane*/), current_work.block, current_work.bidh, current_work.bidb};
            work_info = tile_idx_to_work_tile(params, new_tile_idx, work_info);
            flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/);  // TileCountSmemEmpty
            if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
                *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb);
            }
            flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/);  // TileCountSmemFull
            return work_info;
        } else {
            flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/);  // TileCountSmemFull
            int4 work_info = *work_info_smem;
            flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/);  // TileCountSmemEmpty
            return WorkTileInfo{work_info.x, work_info.y, work_info.z, work_info.w};
        }
    }

};

///////////////////////////////////////////////////////////////////////////////

}  // namespace flash