|
|
|
|
|
|
|
|
|
#pragma once |
|
|
|
#include <cuda.h> |
|
#include <vector> |
|
|
|
|
|
|
|
// Pointers, strides and head counts describing the Q, K and V operands.
//
// NOTE(review): this struct fixes neither the element type behind the
// type-erased pointers nor the unit of the strides (elements vs. bytes);
// confirm both against the kernels that consume it.
struct Qkv_params {
    // Every stride in this struct is a signed 64-bit integer.
    using index_t = int64_t;

    // Type-erased base pointers of the Q, K and V buffers.
    void *__restrict__ q_ptr;
    void *__restrict__ k_ptr;
    void *__restrict__ v_ptr;

    // Batch-level strides of Q, K and V.
    index_t q_batch_stride;
    index_t k_batch_stride;
    index_t v_batch_stride;

    // Row-level strides of Q, K and V.
    index_t q_row_stride;
    index_t k_row_stride;
    index_t v_row_stride;

    // Head-level strides of Q, K and V.
    index_t q_head_stride;
    index_t k_head_stride;
    index_t v_head_stride;

    // Additional stride that exists for V only (there is no Q/K counterpart here).
    // NOTE(review): presumably the stride along V's head dimension — confirm.
    index_t v_dim_stride;

    // Head counts. NOTE(review): presumably h is the number of query heads and
    // h_k the number of key/value heads — confirm against the caller.
    int h, h_k;
};
|
|
|
|
|
|
|
// Parameter bundle for the forward pass; extends Qkv_params with the output,
// softmax statistics, scaling, sequence-length, split-accumulation, KV-append,
// rotary, paging, dropout, windowing and scheduling fields.
//
// NOTE(review): field meanings below are inferred from names and types only.
// The unit of the strides and the element types behind the void pointers are
// not fixed by this struct — confirm against the consuming kernels.
struct Flash_fwd_params : public Qkv_params {
    // Every stride in this struct is a signed 64-bit integer.
    using index_t = int64_t;

    // Type-erased pointers for the output and for its accumulation buffer.
    void *__restrict__ o_ptr;
    void *__restrict__ oaccum_ptr;

    // Batch / row / head strides of the output.
    index_t o_batch_stride;
    index_t o_row_stride;
    index_t o_head_stride;

    // Type-erased pointers for the softmax LSE buffer and its accumulation buffer.
    // NOTE(review): "lse" presumably stands for log-sum-exp — confirm.
    void *__restrict__ softmax_lse_ptr;
    void *__restrict__ softmax_lseaccum_ptr;

    // Float "descale" buffers for Q, K and V with their batch / head strides.
    // NOTE(review): presumably dequantisation scales for low-precision inputs — confirm.
    float *__restrict__ q_descale_ptr;
    float *__restrict__ k_descale_ptr;
    float *__restrict__ v_descale_ptr;
    index_t q_descale_batch_stride;
    index_t q_descale_head_stride;
    index_t k_descale_batch_stride;
    index_t k_descale_head_stride;
    index_t v_descale_batch_stride;
    index_t v_descale_head_stride;

    // Problem dimensions, including the "_rounded" variants and the rotary dimension.
    // NOTE(review): presumably b = batch size and d = head dimension — confirm.
    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
    int total_q, total_k, total_knew;
    int b_k;
    int dv, dv_rounded;

    // Softmax scale and soft-cap value.
    float scale_softmax;
    float softcap;

    // Integer arrays describing per-sequence layout.
    // NOTE(review): cu_seqlens_* presumably hold cumulative sequence offsets — confirm.
    int *__restrict__ cu_seqlens_q;
    int *__restrict__ cu_seqlens_k;
    int *__restrict__ cu_seqlens_knew;
    int *__restrict__ leftpad_k;

    // Integer arrays named "seqused" for Q and K.
    // NOTE(review): presumably the number of positions actually used per sequence — confirm.
    int *__restrict__ seqused_q;
    int *__restrict__ seqused_k;

    // Split / batch / row / head strides of the output accumulation buffer.
    index_t oaccum_split_stride;
    index_t oaccum_batch_stride;
    index_t oaccum_row_stride;
    index_t oaccum_head_stride;

    // Split / batch / head strides of the LSE accumulation buffer.
    index_t lseaccum_split_stride;
    index_t lseaccum_batch_stride;
    index_t lseaccum_head_stride;

    // Type-erased pointers for the "knew" / "vnew" buffers.
    // NOTE(review): presumably new K/V entries to be appended — confirm.
    void *__restrict__ knew_ptr;
    void *__restrict__ vnew_ptr;

    // Batch / row / head strides of the "knew" / "vnew" buffers.
    index_t knew_batch_stride;
    index_t vnew_batch_stride;
    index_t knew_row_stride;
    index_t vnew_row_stride;
    index_t knew_head_stride;
    index_t vnew_head_stride;

    // Type-erased pointer for the "qv" buffer and its batch / row / head strides.
    void *__restrict__ qv_ptr;
    index_t qv_batch_stride;
    index_t qv_row_stride;
    index_t qv_head_stride;

    // Type-erased pointers for the rotary cos / sin tables, and an int array of
    // rotary sequence lengths.
    void *__restrict__ rotary_cos_ptr;
    void *__restrict__ rotary_sin_ptr;
    int *__restrict__ seqlens_rotary;

    // Int array named kv_batch_idx.
    // NOTE(review): presumably maps a batch entry to its slot in a KV cache — confirm.
    int *__restrict__ kv_batch_idx;

    // Paged-KV description: page table, its batch stride, page size and page count,
    // plus a flag named pagedkv_tma.
    int *__restrict__ page_table;
    index_t page_table_batch_stride;
    int page_size;
    int num_pages;
    bool pagedkv_tma;

    // Dropout parameters in three encodings (float, uint8_t, and "rp_" float).
    // NOTE(review): whether p_dropout is a drop or a keep probability, and the exact
    // definition of rp_dropout, are not visible here — confirm in the kernels.
    float p_dropout;
    uint8_t p_dropout_in_uint8_t;
    float rp_dropout;

    // Left / right window sizes and attention chunk size.
    int window_size_left, window_size_right;
    int attention_chunk;

    // Pointer to the random-number-generator state (not __restrict__-qualified).
    uint64_t *rng_state;

    // Data-type and masking-mode flags.
    bool is_bf16;
    bool is_fp32;
    bool is_e4m3;
    bool is_causal;
    bool is_local;

    // Flag selecting the interleaved rotary layout.
    bool is_rotary_interleaved;

    // Number of splits and the pack-GQA flag.
    int num_splits;
    bool pack_gqa;

    // Int buffer named tile_count_semaphore.
    // NOTE(review): presumably a device-side counter used by the tile scheduler — confirm.
    int *__restrict__ tile_count_semaphore;

    // Int buffer for dynamically chosen split counts, and a flag to skip
    // computing the scheduler metadata.
    int *__restrict__ num_splits_dynamic_ptr;
    bool skip_scheduler_metadata_computation;

    // Target architecture identifier and SM count.
    int arch;
    int num_sm;
};
|
|
|
|
|
|
|
// Parameter bundle for the backward pass; extends Flash_fwd_params with the
// gradient buffers, their strides, accumulation buffers and semaphores.
//
// NOTE(review): field meanings below are inferred from names and types only.
// The unit of the strides and the element types behind the void pointers are
// not fixed by this struct — confirm against the consuming kernels.
struct Flash_bwd_params : public Flash_fwd_params {
    // Every stride in this struct is a signed 64-bit integer.
    using index_t = int64_t;

    // Type-erased pointers for dO, dQ, dK and dV.
    void *__restrict__ do_ptr;
    void *__restrict__ dq_ptr;
    void *__restrict__ dk_ptr;
    void *__restrict__ dv_ptr;

    // Type-erased pointers for the dQ, dK and dV accumulation buffers.
    void *__restrict__ dq_accum_ptr;
    void *__restrict__ dk_accum_ptr;
    void *__restrict__ dv_accum_ptr;

    // Batch / row / head strides of dO.
    index_t do_batch_stride;
    index_t do_row_stride;
    index_t do_head_stride;

    // Batch-level strides of dQ, dK and dV.
    index_t dq_batch_stride;
    index_t dk_batch_stride;
    index_t dv_batch_stride;

    // Row-level strides of dQ, dK and dV.
    index_t dq_row_stride;
    index_t dk_row_stride;
    index_t dv_row_stride;

    // Head-level strides of dQ, dK and dV.
    index_t dq_head_stride;
    index_t dk_head_stride;
    index_t dv_head_stride;

    // Type-erased pointers named dsoftmax_sum and softmax_lse_log2_ptr.
    // NOTE(review): presumably a per-row softmax-gradient sum and the LSE
    // expressed in base 2 — confirm in the backward kernels.
    void *__restrict__ dsoftmax_sum;
    void *__restrict__ softmax_lse_log2_ptr;

    // Int buffers named as semaphores for dQ, dK and dV.
    // NOTE(review): presumably used to order accumulation across blocks — confirm.
    int *__restrict__ dq_semaphore;
    int *__restrict__ dk_semaphore;
    int *__restrict__ dv_semaphore;

    // Flag requesting deterministic behaviour, and the split stride of the dQ
    // accumulation buffer.
    bool deterministic;
    index_t dq_accum_split_stride;
};
|
|
|
|
|
|
|
template <int Arch, typename T, int kHeadDim, int kHeadDimV, bool Split, bool PagedKVNonTMA, bool Has_softcap, bool PackGQA> |
|
void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); |
|
void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl); |
|
template <int Arch, typename T, int kHeadDim, bool Has_softcap> |
|
void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); |
|
template <typename T, typename Tpartial, int kBlockK> |
|
void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); |
|
|