#pragma once

#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <Python.h>
#include <torch/all.h>
#include <torch/library.h>
#include <torch/torch.h>

#include <tuple>
#include <vector>

#include "scalar_type.hpp"

#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)

#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)

#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)

#define REGISTER_EXTENSION(NAME) \
  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \
    return PyModule_Create(&module); \
  }
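
// Minimal sketch of how an implementation file might pair these macros (the library
// name, op schema and dispatch key below are illustrative placeholders, not the
// project's actual registrations):
//
//   TORCH_LIBRARY_EXPAND(example_ops, m) {
//     m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, bool enable_pdl) -> ()");
//     m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
//   }
//
//   REGISTER_EXTENSION(example_ops)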

using fptr_t = int64_t;

#ifdef USE_ROCM
fptr_t init_custom_ar(
    torch::Tensor& meta,
    torch::Tensor& rank_data,
    const std::vector<std::string>& handles,
    const std::vector<int64_t>& offsets,
    int64_t rank,
    bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, torch::Tensor& out);
void dispose(fptr_t _fa);
int64_t meta_size();
void register_buffer(
    fptr_t _fa, torch::Tensor& t, const std::vector<std::string>& handles, const std::vector<int64_t>& offsets);
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(
    fptr_t _fa, const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets);
torch::Tensor allocate_meta_buffer(int64_t size);
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp);

fptr_t init_custom_qr(int64_t rank, int64_t world_size, std::optional<int64_t> qr_max_size = std::nullopt);
void qr_destroy(fptr_t _fa);
torch::Tensor qr_get_handle(fptr_t _fa);
void qr_open_handles(fptr_t _fa, const std::vector<torch::Tensor>& handles);
void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half = false);
int64_t qr_max_size();
#else
fptr_t
init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs, torch::Tensor& rank_data, int64_t rank, bool full_nvlink);
void dispose(fptr_t _fa);
int64_t meta_size();
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes);
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs);
void register_graph_buffers(
    fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets);

torch::Tensor mscclpp_generate_unique_id();
fptr_t mscclpp_init_context(
    const torch::Tensor& unique_id,
    const int64_t rank,
    const int64_t world_size,
    torch::Tensor& scratch,
    torch::Tensor& put_buffer,
    const int64_t nranks_per_node,
    const std::vector<int64_t>& rank_to_node,
    const std::vector<int64_t>& rank_to_ib,
    const int64_t context_selection);
void mscclpp_allreduce(fptr_t _context, torch::Tensor& inp, torch::Tensor& out, int64_t nthreads, int64_t nblocks);
#endif
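
// Rough call sequence for the custom all-reduce path (non-ROCm branch), given as a
// hedged sketch: the IPC-handle exchange between ranks is process-group specific and
// omitted here, and the variable names are placeholders.
//
//   fptr_t fa = init_custom_ar(fake_ipc_ptrs, rank_data, rank, /*full_nvlink=*/true);
//   register_buffer(fa, buffer_ipc_ptrs);
//   all_reduce(fa, inp, out, reg_buffer_ptr, reg_buffer_sz_bytes);
//   dispose(fa);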

void merge_state(
    at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
void merge_state_v2(
    at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
void cutlass_mla_decode(
    torch::Tensor const& out,
    torch::Tensor const& q_nope,
    torch::Tensor const& q_pe,
    torch::Tensor const& kv_c_and_k_pe_cache,
    torch::Tensor const& seq_lens,
    torch::Tensor const& page_table,
    torch::Tensor const& workspace,
    double sm_scale,
    int64_t num_kv_splits = 1);
int64_t cutlass_mla_get_workspace_size(
    int64_t max_seq_len,
    int64_t num_batches,
    int64_t sm_count = 0,
    int64_t num_kv_splits = 1);
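
// Hedged sketch of pairing the two MLA decode entry points: size the workspace first,
// then run the decode kernel; the byte-typed workspace and the split count are
// assumptions.
//
//   int64_t ws_bytes =
//       cutlass_mla_get_workspace_size(max_seq_len, num_batches, /*sm_count=*/0, /*num_kv_splits=*/1);
//   at::Tensor workspace = at::empty({ws_bytes}, at::TensorOptions().device(at::kCUDA).dtype(at::kByte));
//   cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale);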

void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
void sgl_fused_add_rmsnorm(
    torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl);
void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, bool enable_pdl);
void silu_and_mul(at::Tensor& out, at::Tensor& input);
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input);
void gelu_and_mul(at::Tensor& out, at::Tensor& input);
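
// Usage sketch (illustrative only): fp16 CUDA activations with a typical eps; the
// convention that silu_and_mul takes concatenated gate/up halves and writes a tensor of
// half the input width is an assumption based on common usage.
//
//   at::Tensor input =
//       at::randn({num_tokens, hidden_size}, at::TensorOptions().device(at::kCUDA).dtype(at::kHalf));
//   at::Tensor weight = at::ones({hidden_size}, input.options());
//   at::Tensor output = at::empty_like(input);
//   rmsnorm(output, input, weight, /*eps=*/1e-6, /*enable_pdl=*/false);
//
//   at::Tensor gate_up = at::randn({num_tokens, 2 * intermediate_size}, input.options());
//   at::Tensor act = at::empty({num_tokens, intermediate_size}, input.options());
//   silu_and_mul(act, gate_up);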

void apply_rope_pos_ids_cos_sin_cache(
    at::Tensor q,
    at::Tensor k,
    at::Tensor q_rope,
    at::Tensor k_rope,
    at::Tensor cos_sin_cache,
    at::Tensor pos_ids,
    bool interleave,
    bool enable_pdl,
    const std::optional<at::Tensor>& v,
    const std::optional<at::Tensor>& k_buffer,
    const std::optional<at::Tensor>& v_buffer,
    const std::optional<at::Tensor>& kv_cache_loc);

void rotary_embedding(
    torch::Tensor& positions,
    torch::Tensor& query,
    std::optional<torch::Tensor> key,
    int64_t head_size,
    torch::Tensor& cos_sin_cache,
    bool is_neox);
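
// Sketch of a rotary_embedding call in the NeoX layout; head counts, dtypes and the
// cos/sin cache construction are assumptions (the cache is taken to be
// [max_position, rotary_dim] with the cos and sin halves concatenated).
//
//   at::Tensor positions = at::arange(num_tokens, at::TensorOptions().device(at::kCUDA).dtype(at::kLong));
//   at::Tensor query =
//       at::randn({num_tokens, num_q_heads * head_size}, at::TensorOptions().device(at::kCUDA).dtype(at::kHalf));
//   at::Tensor key = at::randn({num_tokens, num_kv_heads * head_size}, query.options());
//   rotary_embedding(positions, query, key, head_size, cos_sin_cache, /*is_neox=*/true);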

void downcast_fp8(
    at::Tensor& k,
    at::Tensor& v,
    at::Tensor& k_out,
    at::Tensor& v_out,
    at::Tensor& k_scale,
    at::Tensor& v_scale,
    at::Tensor& loc,
    int64_t mult,
    int64_t offset);

void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output);
void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope);
void concat_mla_absorb_q(at::Tensor a, at::Tensor b, at::Tensor out);

void fast_topk_interface(
    const at::Tensor& score,
    at::Tensor& indices,
    const at::Tensor& lengths,
    std::optional<at::Tensor> row_starts_opt = std::nullopt);
void fast_topk_transform_interface(
    const at::Tensor& score,
    const at::Tensor& lengths,
    at::Tensor& dst_page_table,
    const at::Tensor& src_page_table,
    const at::Tensor& cu_seqlens_q,
    std::optional<at::Tensor> row_starts_opt = std::nullopt);
void fast_topk_transform_ragged_interface(
    const at::Tensor& score,
    const at::Tensor& lengths,
    at::Tensor& topk_indices_ragged,
    const at::Tensor& topk_indices_offset,
    std::optional<at::Tensor> row_starts_opt = std::nullopt);

#ifdef USE_ROCM
void gelu_quick(at::Tensor& out, const at::Tensor& input);
#endif

torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros);
torch::Tensor int8_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype,
    const c10::optional<torch::Tensor>& bias);
torch::Tensor fp8_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype,
    const c10::optional<torch::Tensor>& bias);
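
// Hedged sketch of an fp8 GEMM call: both operands are pre-quantized to fp8 and carry
// float32 scale tensors; the exact scale granularity (per tensor vs. per row/column) is
// kernel specific and not assumed here.
//
//   torch::Tensor out =
//       fp8_scaled_mm(a_q, b_q, scales_a, scales_b, torch::kBFloat16, /*bias=*/c10::nullopt);
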
torch::Tensor fp8_blockwise_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype);
void sgl_per_token_group_quant_8bit(
    at::Tensor input,
    at::Tensor output_q,
    at::Tensor output_s,
    int64_t group_size,
    double eps,
    double fp8_min,
    double fp8_max,
    bool scale_ue8m0);
void sgl_per_token_group_quant_8bit_v2(
    at::Tensor input,
    at::Tensor output_q,
    at::Tensor output_s,
    int64_t group_size,
    double eps,
    double min_8bit,
    double max_8bit,
    bool scale_ue8m0,
    bool fuse_silu_and_mul,
    const std::optional<torch::Tensor>& masked_m);
void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
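
// Sketch of dynamic per-token fp8 quantization feeding the scaled GEMMs above; the fp8
// dtype name and the per-token scale shape are assumptions.
//
//   at::Tensor x =
//       at::randn({num_tokens, hidden_size}, at::TensorOptions().device(at::kCUDA).dtype(at::kHalf));
//   at::Tensor x_q = at::empty_like(x, x.options().dtype(at::kFloat8_e4m3fn));
//   at::Tensor x_s = at::empty({num_tokens}, x.options().dtype(at::kFloat));
//   sgl_per_token_quant_fp8(x, x_q, x_s);
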
void bmm_fp8(
    at::Tensor A,
    at::Tensor B,
    at::Tensor D,
    at::Tensor A_scale,
    at::Tensor B_scale,
    at::Tensor workspace_buffer,
    int64_t cublas_handle);
void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b);
void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b);

torch::Tensor gptq_gemm(
    torch::Tensor a,
    torch::Tensor b_q_weight,
    torch::Tensor b_gptq_qzeros,
    torch::Tensor b_gptq_scales,
    torch::Tensor b_g_idx,
    bool use_shuffle,
    int64_t bit);

void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);

void moe_align_block_size(
    torch::Tensor topk_ids,
    int64_t num_experts,
    int64_t block_size,
    torch::Tensor sorted_token_ids,
    torch::Tensor experts_ids,
    torch::Tensor num_tokens_post_pad,
    torch::Tensor cumsum_buffer,
    bool pad_sorted_token_ids);

void topk_softmax(
    torch::Tensor& topk_weights,
    torch::Tensor& topk_indices,
    torch::Tensor& gating_output,
    bool renormalize,
    double moe_softcapping,
    const c10::optional<torch::Tensor>& correction_bias);
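
// Sketch of expert routing with topk_softmax: gating_output is [num_tokens, num_experts]
// and the top-k weights/indices are written into the caller-provided tensors; the int32
// index dtype is an assumption.
//
//   at::Tensor gating =
//       at::randn({num_tokens, num_experts}, at::TensorOptions().device(at::kCUDA).dtype(at::kFloat));
//   at::Tensor topk_weights = at::empty({num_tokens, topk}, gating.options());
//   at::Tensor topk_indices = at::empty({num_tokens, topk}, gating.options().dtype(at::kInt));
//   topk_softmax(topk_weights, topk_indices, gating, /*renormalize=*/true,
//                /*moe_softcapping=*/0.0, /*correction_bias=*/c10::nullopt);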

void topk_sigmoid(
    torch::Tensor& topk_weights,
    torch::Tensor& topk_indices,
    torch::Tensor& gating_output,
    bool renormalize,
    const c10::optional<torch::Tensor>& correction_bias);

void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling_factor);

void moe_sum(torch::Tensor& input, torch::Tensor& output);

std::vector<at::Tensor> moe_fused_gate(
    at::Tensor& input,
    at::Tensor& bias,
    int64_t num_expert_group,
    int64_t topk_group,
    int64_t topk,
    int64_t num_fused_shared_experts,
    double routed_scaling_factor,
    bool apply_routed_scaling_factor_on_output);
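
// Hedged sketch of the fused grouped-expert gate (DeepSeek-style routing); the group and
// top-k sizes below are placeholders, and the output ordering (weights, then ids) is an
// assumption.
//
//   std::vector<at::Tensor> outs = moe_fused_gate(
//       router_logits, e_score_correction_bias,
//       /*num_expert_group=*/8, /*topk_group=*/4, /*topk=*/8,
//       /*num_fused_shared_experts=*/0, /*routed_scaling_factor=*/2.5,
//       /*apply_routed_scaling_factor_on_output=*/false);
//   at::Tensor topk_weights = outs[0];
//   at::Tensor topk_ids = outs[1];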

std::vector<at::Tensor> kimi_k2_moe_fused_gate(
    at::Tensor& input,
    at::Tensor& bias,
    int64_t topk,
    bool renormalize,
    double routed_scaling_factor,
    bool apply_routed_scaling_factor_on_output);

void fp8_blockwise_scaled_grouped_mm(
    torch::Tensor& output,
    torch::Tensor& a_ptrs,
    torch::Tensor& b_ptrs,
    torch::Tensor& out_ptrs,
    torch::Tensor& a_scales_ptrs,
    torch::Tensor& b_scales_ptrs,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Tensor& stride_a,
    const torch::Tensor& stride_b,
    const torch::Tensor& stride_c,
    const torch::Tensor& layout_sfa,
    const torch::Tensor& layout_sfb,
    const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets,
    const torch::Tensor& workspace);

void prepare_moe_input(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    const std::optional<torch::Tensor>& blockscale_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
    torch::Tensor& output_permutation,
    const int64_t num_experts,
    const int64_t n,
    const int64_t k);

void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor);

void apply_shuffle_mul_sum(
    const torch::Tensor& input,
    torch::Tensor& output,
    const torch::Tensor& permutation,
    const std::optional<torch::Tensor>& factors);

void fused_qk_norm_rope(
    torch::Tensor& qkv,
    int64_t num_heads_q,
    int64_t num_heads_k,
    int64_t num_heads_v,
    int64_t head_dim,
    double eps,
    torch::Tensor& q_weight,
    torch::Tensor& k_weight,
    double base,
    bool is_neox,
    torch::Tensor& position_ids,
    double factor,
    double low,
    double high,
    double attention_factor,
    int64_t rotary_dim);

void get_cutlass_w4a8_moe_mm_data(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
    torch::Tensor& output_permutation,
    const int64_t num_experts,
    const int64_t n,
    const int64_t k);

void cutlass_w4a8_moe_mm(
    torch::Tensor& d_tensors,
    torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors,
    torch::Tensor const& a_scales,
    torch::Tensor const& b_scales,
    torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes,
    torch::Tensor const& a_strides,
    torch::Tensor const& b_strides,
    torch::Tensor const& d_strides,
    torch::Tensor const& s_strides,
    int64_t chunk_size,
    int64_t topk);

void tree_speculative_sampling_target_only(
    at::Tensor predicts,
    at::Tensor accept_index,
    at::Tensor accept_token_num,
    at::Tensor candidates,
    at::Tensor retrive_index,
    at::Tensor retrive_next_token,
    at::Tensor retrive_next_sibling,
    at::Tensor uniform_samples,
    at::Tensor uniform_samples_for_final_sampling,
    at::Tensor target_probs,
    at::Tensor draft_probs,
    double threshold_single = 1,
    double threshold_acc = 1,
    bool deterministic = true);

void verify_tree_greedy(
    at::Tensor predicts,
    at::Tensor accept_index,
    at::Tensor accept_token_num,
    at::Tensor candidates,
    at::Tensor retrive_index,
    at::Tensor retrive_next_token,
    at::Tensor retrive_next_sibling,
    at::Tensor target_predict);
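
// Hedged sketch of greedy tree verification during speculative decoding; the
// [batch_size, num_draft_tokens] shape convention for the tree tensors and the output
// semantics noted below are assumptions.
//
//   // candidates / retrive_* : int tensors describing the draft token tree
//   // target_predict         : greedy (argmax) tokens from the target model
//   verify_tree_greedy(predicts, accept_index, accept_token_num,
//                      candidates, retrive_index, retrive_next_token, retrive_next_sibling,
//                      target_predict);
//   // accept_token_num[i] then reports how many draft tokens request i accepted.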

void reconstruct_indices_from_tree_mask(
    at::Tensor tree_mask,
    at::Tensor verified_seq_len,
    at::Tensor positions,
    at::Tensor retrive_index,
    at::Tensor retrive_next_token,
    at::Tensor retrive_next_sibling,
    int64_t batch_size,
    int64_t draft_token_num);

void build_tree_kernel_efficient(
    at::Tensor parent_list,
    at::Tensor selected_index,
    at::Tensor verified_seq_len,
    at::Tensor tree_mask,
    at::Tensor positions,
    at::Tensor retrive_index,
    at::Tensor retrive_next_token,
    at::Tensor retrive_next_sibling,
    int64_t topk,
    int64_t depth,
    int64_t draft_token_num,
    int64_t tree_mask_mode);

void segment_packbits(
    at::Tensor x,
    at::Tensor input_indptr,
    at::Tensor output_indptr,
    at::Tensor y,
    int64_t batch_size,
    int64_t cuda_stream = 0);

void transfer_kv_per_layer(
    const at::Tensor src_k,
    at::Tensor dst_k,
    const at::Tensor src_v,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t block_quota,
    int64_t num_warps_per_block);
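
// Hedged sketch of a single-layer KV page copy between two caches; item_size is taken
// here to be the number of elements per token slot, and the two tuning knobs are
// placeholders.
//
//   transfer_kv_per_layer(src_k, dst_k, src_v, dst_v, src_indices, dst_indices,
//                         /*item_size=*/head_num * head_dim,
//                         /*block_quota=*/2, /*num_warps_per_block=*/8);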

void transfer_kv_per_layer_pf_lf(
    const at::Tensor src_k,
    at::Tensor dst_k,
    const at::Tensor src_v,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t layer_id,
    int64_t item_size,
    int64_t src_layout_dim,
    int64_t block_quota,
    int64_t num_warps_per_block);

void transfer_kv_per_layer_ph_lf(
    const at::Tensor src_k,
    at::Tensor dst_k,
    const at::Tensor src_v,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t layer_id,
    int64_t item_size,
    int64_t src_layout_dim,
    int64_t page_size,
    int64_t head_num,
    int64_t block_quota,
    int64_t num_warps_per_block);

void transfer_kv_all_layer(
    const at::Tensor src_k_layers,
    const at::Tensor dst_k_layers,
    const at::Tensor src_v_layers,
    const at::Tensor dst_v_layers,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t num_layers,
    int64_t block_quota,
    int64_t num_warps_per_block);

void transfer_kv_all_layer_lf_pf(
    const at::Tensor src_k_layers,
    at::Tensor dst_k,
    const at::Tensor src_v_layers,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t dst_layout_dim,
    int64_t num_layers,
    int64_t block_quota,
    int64_t num_warps_per_block);

void transfer_kv_all_layer_lf_ph(
    const at::Tensor src_k_layers,
    at::Tensor dst_k,
    const at::Tensor src_v_layers,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t dst_layout_dim,
    int64_t num_layers,
    int64_t page_size,
    int64_t head_num,
    int64_t block_quota,
    int64_t num_warps_per_block);

void transfer_kv_per_layer_mla(
    const at::Tensor src,
    at::Tensor dst,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t block_quota,
    int64_t num_warps_per_block);

void transfer_kv_per_layer_mla_pf_lf(
    const at::Tensor src,
    at::Tensor dst,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t layer_id,
    int64_t item_size,
    int64_t src_layout_dim,
    int64_t block_quota,
    int64_t num_warps_per_block);

void transfer_kv_all_layer_mla(
    const at::Tensor src_layers,
    const at::Tensor dst_layers,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t num_layers,
    int64_t block_quota,
    int64_t num_warps_per_block);

void transfer_kv_all_layer_mla_lf_pf(
    const at::Tensor src_layers,
    at::Tensor dst,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t dst_layout_dim,
    int64_t num_layers,
    int64_t block_quota,
    int64_t num_warps_per_block);

void transfer_kv_direct(
    const std::vector<at::Tensor>& src_layers,
    std::vector<at::Tensor> dst_layers,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t page_size);

void transfer_kv_per_layer_direct_pf_lf(
    const std::vector<at::Tensor>& src_ptrs,
    std::vector<at::Tensor> dst_ptrs,
    const at::Tensor& src_indices,
    const at::Tensor& dst_indices,
    int64_t layer_id,
    int64_t page_size);

void transfer_kv_all_layer_direct_lf_pf(
    const std::vector<at::Tensor>& src_ptrs,
    std::vector<at::Tensor> dst_ptrs,
    const at::Tensor& src_indices,
    const at::Tensor& dst_indices,
    int64_t page_size);

at::Tensor weak_ref_tensor(const at::Tensor& tensor);
void store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc, at::Tensor k, at::Tensor v);

void top_k_renorm_probs(
    at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);

void top_p_renorm_probs(
    at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_p_arr, double top_p_val);

void top_k_mask_logits(
    at::Tensor logits, at::Tensor mask_logits, std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);
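
// Sketch of top-k / top-p renormalization over a probability matrix; passing std::nullopt
// for the per-request tensor applies the scalar value to every row (an assumption), and
// the cutoff values below are placeholders.
//
//   at::Tensor probs = at::softmax(logits, /*dim=*/-1);
//   at::Tensor renorm_k = at::empty_like(probs);
//   at::Tensor renorm_p = at::empty_like(probs);
//   top_k_renorm_probs(probs, renorm_k, std::nullopt, /*top_k_val=*/50);
//   top_p_renorm_probs(renorm_k, renorm_p, std::nullopt, /*top_p_val=*/0.9);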

namespace flash {

std::vector<at::Tensor> mha_fwd_sparse(
    at::Tensor& q,
    const at::Tensor& k,
    const at::Tensor& v,
    const at::Tensor& block_count,
    const at::Tensor& block_offset,
    const at::Tensor& column_count,
    const at::Tensor& column_index,
    const std::optional<at::Tensor>& out_,
    const std::optional<at::Tensor>& alibi_slopes_,
    const double p_dropout,
    const double softmax_scale,
    bool is_causal,
    const double softcap,
    const bool return_softmax,
    std::optional<at::Generator> gen_);

std::vector<at::Tensor> mha_varlen_fwd_sparse(
    at::Tensor& q,
    const at::Tensor& k,
    const at::Tensor& v,
    const at::Tensor& block_count,
    const at::Tensor& block_offset,
    const at::Tensor& column_count,
    const at::Tensor& column_index,
    const c10::optional<at::Tensor>& out_,
    const at::Tensor& cu_seqlens_q,
    const at::Tensor& cu_seqlens_k,
    const c10::optional<at::Tensor>& seqused_k,
    const c10::optional<at::Tensor>& alibi_slopes_,
    int64_t max_seqlen_q,
    const int64_t max_seqlen_k,
    const double p_dropout,
    const double softmax_scale,
    const bool zero_tensors,
    bool is_causal,
    const double softcap,
    const bool return_softmax,
    c10::optional<at::Generator> gen_);
}  // namespace flash

void convert_vertical_slash_indexes(
    torch::Tensor& block_count,
    torch::Tensor& block_offset,
    torch::Tensor& column_count,
    torch::Tensor& column_index,
    torch::Tensor q_seqlens,
    torch::Tensor kv_seqlens,
    torch::Tensor vertical_indexes,
    torch::Tensor slash_indexes,
    int64_t context_size,
    int64_t block_size_M,
    int64_t block_size_N,
    bool causal);

void convert_vertical_slash_indexes_mergehead(
    torch::Tensor& block_count,
    torch::Tensor& block_offset,
    torch::Tensor& column_count,
    torch::Tensor& column_index,
    torch::Tensor q_seqlens,
    torch::Tensor kv_seqlens,
    torch::Tensor vertical_indexes,
    torch::Tensor slash_indexes,
    torch::Tensor vertical_indices_count,
    torch::Tensor slash_indices_count,
    int64_t context_size,
    int64_t block_size_M,
    int64_t block_size_N,
    bool causal);

void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional<at::Tensor> indices = at::nullopt);
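
// Sketch of grammar-constrained decoding with the token bitmask: one bit per vocabulary
// entry packed into 32-bit words, with masked logits set to -inf in place; the exact
// packing and the int32 dtype are assumptions.
//
//   // logits : [batch_size, vocab_size], float
//   // bitmask: [batch_size, ceil(vocab_size / 32)], int32
//   ApplyTokenBitmaskInplace(logits, bitmask);  // omitting `indices` masks every row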

void qserve_w4a8_per_chn_gemm(
    const torch::Tensor& _in_feats,
    const torch::Tensor& _kernel,
    const torch::Tensor& _wscales,
    const torch::Tensor& _ascales,
    const torch::Tensor& _w_szs,
    const torch::Tensor& _a_ssums,
    torch::Tensor& _out_feats);

void qserve_w4a8_per_group_gemm(
    const torch::Tensor& _in_feats,
    const torch::Tensor& _kernel,
    const torch::Tensor& _zeros,
    const torch::Tensor& _scales_i8,
    const torch::Tensor& _wscales,
    const torch::Tensor& _ascales,
    torch::Tensor& _out_feats);

torch::Tensor
ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, int64_t n, std::optional<at::ScalarType> const& dtype);

torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row);

torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row);

torch::Tensor ggml_moe_a8(
    torch::Tensor X,
    torch::Tensor W,
    torch::Tensor sorted_token_ids,
    torch::Tensor expert_ids,
    torch::Tensor num_tokens_post_padded,
    int64_t type,
    int64_t row,
    int64_t top_k,
    int64_t tokens);

torch::Tensor ggml_moe_a8_vec(
    torch::Tensor X, torch::Tensor W, torch::Tensor topk_ids, int64_t top_k, int64_t type, int64_t row, int64_t tokens);

int64_t ggml_moe_get_block_size(int64_t type);

std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device);
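
// Hedged sketch: split a device's SMs into two green-context partitions (for example to
// overlap communication and compute kernels); the SM counts are placeholders and the
// returned values are assumed to be opaque handles that the Python layer wraps.
//
//   std::vector<int64_t> handles = create_greenctx_stream_by_value(/*smA=*/64, /*smB=*/68, /*device=*/0);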

void causal_conv1d_update(
    const at::Tensor& x,
    const at::Tensor& conv_state,
    const at::Tensor& weight,
    const std::optional<at::Tensor>& bias_,
    bool silu_activation,
    const std::optional<at::Tensor>& cache_seqlens_,
    const std::optional<at::Tensor>& conv_state_indices_,
    int64_t pad_slot_id);

void causal_conv1d_fwd(
    const at::Tensor& x,
    const at::Tensor& weight,
    const std::optional<at::Tensor>& bias_,
    const std::optional<at::Tensor>& conv_states,
    const std::optional<at::Tensor>& query_start_loc,
    const std::optional<at::Tensor>& cache_indices,
    const std::optional<at::Tensor>& has_initial_state,
    bool silu_activation,
    int64_t pad_slot_id);
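
// Sketch of the Mamba-style causal conv1d prefill call for the simple dense-batch case;
// the [batch, dim, seqlen] input layout, the [dim, width] weight layout and the
// pad_slot_id value are assumptions, and the optional state/indexing tensors are left empty.
//
//   causal_conv1d_fwd(x, weight, bias, /*conv_states=*/std::nullopt,
//                     /*query_start_loc=*/std::nullopt, /*cache_indices=*/std::nullopt,
//                     /*has_initial_state=*/std::nullopt, /*silu_activation=*/true,
//                     /*pad_slot_id=*/-1);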

void es_fp8_blockwise_scaled_grouped_mm(
    torch::Tensor& output,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Tensor& stride_a,
    const torch::Tensor& stride_b,
    const torch::Tensor& stride_d,
    const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets,
    const torch::Tensor& workspace);

void es_sm100_mxfp8_blockscaled_grouped_mm(
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& sfa,
    const torch::Tensor& sfb,
    torch::Tensor& d,
    const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets,
    const torch::Tensor& blockscale_offsets);

void es_sm100_mxfp8_blockscaled_grouped_quant(
    const torch::Tensor& input,
    const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets,
    const torch::Tensor& blockscale_offsets,
    torch::Tensor& quant_output,
    torch::Tensor& scale_factor);

std::vector<at::Tensor> get_mla_decoding_metadata(
    at::Tensor& seqlens_k,
    const int64_t num_q_tokens_per_head_k,
    const int64_t h_k,
    const std::optional<int64_t> h_q,
    const bool is_fp8_kvcache,
    const std::optional<int64_t> topk);

std::vector<at::Tensor> fwd_kvcache_mla(
    at::Tensor& q,
    const at::Tensor& kcache,
    const int64_t head_size_v,
    const at::Tensor& seqlens_k,
    const at::Tensor& block_table,
    const double softmax_scale,
    bool is_causal,
    const at::Tensor& tile_scheduler_metadata,
    const at::Tensor& num_splits,
    const bool& is_fp8,
    const std::optional<at::Tensor>& indices);

void FMHACutlassSM100FwdRun(
    at::Tensor workspace_buffer,
    at::Tensor q,
    at::Tensor k,
    at::Tensor v,
    at::Tensor cumulative_seqlen_q,
    at::Tensor cumulative_seqlen_kv,
    at::Tensor o,
    at::Tensor lse,
    int64_t mask_mode_code,
    double softmax_scale,
    int64_t max_seqlen_q,
    int64_t max_seqlen_kv,
    bool is_varlen);

std::vector<at::Tensor>
sparse_prefill_fwd(const at::Tensor& q, const at::Tensor& kv, const at::Tensor& indices, double sm_scale, int64_t d_v);

std::vector<at::Tensor> fwd_kvcache_mla_fp8(
    at::Tensor& q,
    const at::Tensor& kcache,
    const int64_t head_size_v,
    const at::Tensor& seqlens_k,
    const at::Tensor& block_table,
    const double softmax_scale,
    bool is_causal,
    const at::Tensor& tile_scheduler_metadata,
    const at::Tensor& num_splits,
    const std::optional<at::Tensor>& descale_q,
    const std::optional<at::Tensor>& descale_k);

std::vector<at::Tensor> get_mla_decoding_metadata_dense_fp8(
    at::Tensor& seqlens_k, const int64_t num_heads_per_head_k, const int64_t num_heads_k);