#pragma once

#include "cute/tensor.hpp"

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "cutlass/arch/barrier.h"

#include "seqlen.h"
#include "utils.h"

namespace flash {

using namespace cute;
|
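// Post-processing kernel for the FlashAttention backward pass: converts the fp32
// dQ accumulator (dQaccum) into the final dQ tensor in the output element type.
// Each thread block loads one (kBlockM, kHeadDim) tile of dQaccum from gmem,
// scales it by softmax_scale, converts it to Element via shared memory, and
// writes the tile back to gmem, respecting (possibly varlen) sequence bounds.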
template <class TileShape_MK_, class Element, class ElementAccum, class ArchTag_, int kNThreads, class TiledMma, bool dQ_swapAB>
class FlashAttnBwdPostprocessConvertdQ {

public:

    using TileShape_MK = TileShape_MK_;
    using ArchTag = ArchTag_;

    static_assert(ArchTag::kMinComputeCapability >= 75);
    static constexpr bool IsSm90 = ArchTag::kMinComputeCapability >= 90;

    static constexpr uint32_t MaxThreadsPerBlock = kNThreads;
    static constexpr uint32_t MinBlocksPerMultiprocessor = 2;

    static constexpr int kBlockM = get<0>(TileShape_MK{});
    static constexpr int kHeadDim = get<1>(TileShape_MK{});
    static_assert(!IsSm90 || kNThreads % cutlass::NumThreadsPerWarpGroup == 0, "kNThreads must be a multiple of NumThreadsPerWarpGroup");
    static constexpr int NumdQWarpGroups = kNThreads / cutlass::NumThreadsPerWarpGroup;
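
    // Tiled copies that stage dQaccum: G2S moves fp32 tiles from gmem to smem
    // (128-bit UniversalCopy, 4 floats per thread); the R2S layout is reused for
    // the matching smem/register transfers in operator() below.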
    using R2SLayoutAtomdQaccum = std::conditional_t<
        IsSm90,
        Layout<Shape<Int<cutlass::NumThreadsPerWarpGroup>, Int<NumdQWarpGroups>>>,
        Layout<Shape<Int<kNThreads>>>
    >;
    using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdQaccum{},
                                                         Layout<Shape<Int<IsSm90 ? 4 : 1>>>{}));
    using G2SLayoutAtomdQaccum = Layout<Shape<Int<kNThreads>>>;
    using G2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, ElementAccum>{}, G2SLayoutAtomdQaccum{},
                                                         Layout<Shape<_4>>{}));

    static_assert(IsSm90 || (kBlockM * kHeadDim) % (kNThreads * 4) == 0);
    static constexpr int SmemdQaccumSize = size(TileShape_MK{});
    using SmemLayoutdQaccumFlat = Layout<Shape<Int<SmemdQaccumSize>>>;
    using SmemLayoutdQaccum = std::conditional_t<
        IsSm90,
        Layout<Shape<Int<kBlockM * kHeadDim / NumdQWarpGroups>, Int<NumdQWarpGroups>>>,
        Layout<Shape<Int<kBlockM * kHeadDim>>>
    >;
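
    // Swizzled smem layout for the converted dQ tile. The atom's inner extent is
    // sized from the TiledMma's N shape so the STSM stores below tile it evenly
    // while avoiding bank conflicts.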
    static constexpr int MmaShapeN = get<1>(typename TiledMma::AtomShape_MNK{});
    static constexpr int kBlockKSmem = MmaShapeN % 64 == 0 ? 64 : (MmaShapeN % 32 == 0 ? 32 : 16);
    static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1);
    using SmemLayoutAtomdQ =
        decltype(composition(Swizzle<kSwizzle, 3, 3>{},
                             Layout<Shape<Int<8>, Int<kBlockKSmem>>,
                                    Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutdQ = decltype(tile_to_shape(SmemLayoutAtomdQ{}, TileShape_MK{}));
    using SmemLayoutdQt =
        decltype(cute::composition(SmemLayoutdQ{},
                                   make_layout(make_shape(get<1>(TileShape_MK{}), get<0>(TileShape_MK{})),
                                               make_stride(Int<get<0>(TileShape_MK{})>{}, _1{}))));
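
    // On SM90, store dQ to smem with STSM (the transposed variant when dQ_swapAB);
    // otherwise fall back to plain 128-bit vectorized stores.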
    using SmemCopyAtomdQ = Copy_Atom<
        std::conditional_t<
            IsSm90,
            std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
            AutoVectorizingCopyWithAssumedAlignment<128>
        >,
        Element>;

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
    static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, int(MaxThreadsPerBlock));
    static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
    using GmemTiledCopy = decltype(
        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));

    struct SharedStorage : cute::aligned_struct<128> {
        cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdQaccum>> smem_dqacc;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
        alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_dQaccum;
    };

    static constexpr int SharedStorageSize = sizeof(SharedStorage);

    using ShapedQ = cute::Shape<int32_t, int32_t, int32_t, int32_t>;  // (seqlen_q, d, head, batch)
    using StridedQ = cute::Stride<int64_t, _1, int64_t, int64_t>;
    using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen_q * d, head, batch); seqlen may be padded for varlen
    using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;
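
    // Host-side arguments supplied by the caller; converted to device-side
    // Params by to_underlying_arguments().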
    struct Arguments {
        ElementAccum const* ptr_dQaccum;
        ShapedQaccum const shape_dQaccum;
        StridedQaccum const stride_dQaccum;
        Element* ptr_dQ;
        ShapedQ const shape_dQ;
        StridedQ const stride_dQ;
        float const softmax_scale;
        int const* cu_seqlens = nullptr;
        int const* seqused = nullptr;
    };
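
    // Device-side params; currently a field-for-field copy of Arguments.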
    struct Params {
        ElementAccum const* ptr_dQaccum;
        ShapedQaccum const shape_dQaccum;
        StridedQaccum const stride_dQaccum;
        Element* ptr_dQ;
        ShapedQ const shape_dQ;
        StridedQ const stride_dQ;
        float const softmax_scale;
        int const* cu_seqlens = nullptr;
        int const* seqused = nullptr;
    };

    static
    Params
    to_underlying_arguments(Arguments const& args) {
        return {
            args.ptr_dQaccum,
            args.shape_dQaccum,
            args.stride_dQaccum,
            args.ptr_dQ,
            args.shape_dQ,
            args.stride_dQ,
            args.softmax_scale,
            args.cu_seqlens,
            args.seqused
        };
    }

    CUTLASS_DEVICE
    void
    operator()(Params const& params, char* smem_buf) {

        static constexpr int kBlockM = get<0>(TileShape_MK{});
        SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);

        Tensor sdQaccum = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccum{});
        Tensor sdQaccum_flat = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccumFlat{});
        Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQ{});
        Tensor sdQt = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQt{});

        int const thread_idx = threadIdx.x;
        int const m_block = blockIdx.x;
        int const bidh = blockIdx.y;
        int const bidb = blockIdx.z;

        flash::SeqlenInfo</*Varlen=*/true, kBlockM> seqlen_info(bidb, size<0>(params.shape_dQ), params.cu_seqlens, params.seqused);
        bool const is_varlen = params.cu_seqlens != nullptr;
        if (is_varlen && m_block * kBlockM >= seqlen_info.seqlen) { return; }
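
        // Step 1: load the dQaccum tile from gmem to smem: a bulk (TMA) copy
        // through a transaction barrier on SM90, an ordinary vectorized
        // gmem->smem copy otherwise.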
        Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum const*>(params.ptr_dQaccum)),
                                      params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
        Tensor gdQaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(m_block));
        if constexpr (IsSm90) {
            static constexpr uint32_t TmaTransactionBytesdQaccum = static_cast<uint32_t>(size(SmemLayoutdQaccumFlat{}) * cute::sizeof_bits_v<ElementAccum> / 8);
            auto bulk_copy = Copy_Traits<SM90_BULK_COPY_AUTO>{};
            if (thread_idx == 0) {
                shared_storage.barrier_dQaccum.init(/*numThreads=*/1);
                shared_storage.barrier_dQaccum.arrive_and_expect_tx(TmaTransactionBytesdQaccum);
                copy(bulk_copy.with(*reinterpret_cast<uint64_t*>(&shared_storage.barrier_dQaccum)), gdQaccum, sdQaccum_flat);
            }
            __syncthreads();
            shared_storage.barrier_dQaccum.wait(/*phase=*/0);
        } else {
            G2STiledCopydQaccum g2s_tiled_copy_dQaccum;
            auto g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_thread_slice(thread_idx);
            Tensor tdQgdQaccumg2s = g2s_thr_copy_dQaccum.partition_S(gdQaccum);
            Tensor tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum);
            cute::copy(g2s_tiled_copy_dQaccum, tdQgdQaccumg2s, tdQsdQaccumg2s);
            __syncthreads();
        }
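
        // Step 2: read dQaccum from smem into registers in the TiledMma's
        // accumulator layout, then apply the softmax scale.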
        R2STiledCopydQaccum s2r_tiled_copy_dQaccum;
        auto s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_thread_slice(thread_idx);
        Tensor tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum);
        TiledMma tiled_mma_dQ;
        Tensor taccdQrdQaccum = partition_fragment_C(tiled_mma_dQ, select<!dQ_swapAB ? 0 : 1, !dQ_swapAB ? 1 : 0>(TileShape_MK{}));
        CUTE_STATIC_ASSERT_V(size(taccdQrdQaccum) == size(tdQsdQaccum));
        Tensor tdQrdQaccum = s2r_thr_copy_dQaccum.retile_D(taccdQrdQaccum);
        cute::copy(s2r_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum);
        #pragma unroll
        for (int i = 0; i < size(taccdQrdQaccum); ++i) { taccdQrdQaccum(i) *= params.softmax_scale; }
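
        // Convert the scaled fp32 fragment to the output element type (e.g. fp16/bf16).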
        Tensor rdQ = make_tensor_like<Element>(taccdQrdQaccum);
        flash::convert_type_out(taccdQrdQaccum, rdQ);
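
        // Step 3: store the converted dQ fragment to smem, using the transposed
        // view (sdQt) when dQ_swapAB is set.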
        auto smem_tiled_copy_dQ = make_tiled_copy_C(SmemCopyAtomdQ{}, tiled_mma_dQ);
        auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(thread_idx);
        Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);
        Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(cute::conditional_return<!dQ_swapAB>(sdQ, sdQt));
        cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
        __syncthreads();
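
        // Step 4: copy dQ from smem to gmem with 128-bit vectorized accesses,
        // predicated on the head-dim bound and the (possibly varlen) seqlen bound.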
        Tensor mdQ = make_tensor(make_gmem_ptr(params.ptr_dQ), params.shape_dQ, params.stride_dQ)(_, _, bidh, !is_varlen ? bidb : 0);
        Tensor gdQ = local_tile(domain_offset(make_coord(seqlen_info.offset, _0{}), mdQ), TileShape_MK{}, make_coord(m_block, _0{}));
        GmemTiledCopy gmem_tiled_copy_dQ;
        auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(thread_idx);
        Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);
        Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);

        Tensor tdQrdQ = make_fragment_like(tdQsdQ);
        Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cute::make_identity_tensor(TileShape_MK{}));
        Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
        #pragma unroll
        for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(_0{}, _0{}, k)) < get<1>(params.shape_dQ); }

        static constexpr bool EvenM = kBlockM % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0;
        flash::copy</*Is_even_MN=*/EvenM, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(
            gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ, tdQcdQ, tdQpdQ, kBlockM);
        flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
            gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, std::min(seqlen_info.seqlen - m_block * kBlockM, kBlockM)
        );
    }

};
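
// Illustrative instantiation (a sketch; `TiledMmaDQ` and the tile sizes below
// are assumptions for the example, not defined in this header):
//
//   using TileShape_MK = cute::Shape<cute::Int<64>, cute::Int<128>>;
//   using Postprocess = flash::FlashAttnBwdPostprocessConvertdQ<
//       TileShape_MK, cutlass::bfloat16_t, float, cutlass::arch::Sm90,
//       /*kNThreads=*/256, TiledMmaDQ, /*dQ_swapAB=*/false>;
//   // Launch with Postprocess::MaxThreadsPerBlock threads per block and
//   // Postprocess::SharedStorageSize bytes of dynamic shared memory.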

} // namespace flash