drbh committed
Commit · 59bdff8 · 1 parent: f475609
fix: adjust sig types

Files changed:
- flash_mla/flash_mla_api.cu   +5 -19
- torch-ext/torch_binding.cpp  +1 -1
- torch-ext/torch_binding.h    +3 -11
flash_mla/flash_mla_api.cu CHANGED

@@ -53,40 +53,26 @@ get_mla_metadata(
     return {tile_scheduler_metadata, num_splits};
 }

+// note doubles and longs are used in place of floats and ints
+// https://github.com/pytorch/pytorch/blob/338ed67a1e7aa98dd849f297533c5a71bea4b661/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h#L211
 std::vector<at::Tensor>
 mha_fwd_kvcache_mla(
     at::Tensor &q,                                // batch_size x seqlen_q x num_heads x head_size
     const at::Tensor &kcache,                     // num_blocks x page_block_size x num_heads_k x head_size
-
-    // TODO: fix for optional
-    // std::optional<const at::Tensor> &vcache_,  // num_blocks x page_block_size x num_heads_k x head_size_v
-    const at::Tensor &vcache_,                    // num_blocks x page_block_size x num_heads_k x head_size_v
-
+    const c10::optional<torch::Tensor> &vcache_,  // num_blocks x page_block_size x num_heads_k x head_size_v
     const int64_t head_size_v,
     const at::Tensor &seqlens_k,                  // batch_size
     const at::Tensor &block_table,                // batch_size x max_num_blocks_per_seq
-
-    // TODO: should be float
     const double softmax_scale,
-
+    bool is_causal,
     const at::Tensor &tile_scheduler_metadata,    // num_sm_parts x TileSchedulerMetaDataSize
     const at::Tensor &num_splits                  // batch_size + 1
-
-    // TODO: remove this once determined why build is adding this parameter
-    // const int64_t unknown_param
 ) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
     TORCH_CHECK(is_sm90);

-
-    bool is_causal = is_causal_;
-
-
-    // TODO: fix for optional
-    // at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
-    at::Tensor vcache = vcache_;
-
+    at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
     auto q_dtype = q.dtype();
     TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");

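The comment added above refers to PyTorch's unboxing rules: for an op registered through TORCH_LIBRARY, a schema `float` binds to C++ `double`, a schema `int` binds to `int64_t`, and `Tensor?` binds to `c10::optional<at::Tensor>`, which is why the signature moves to `double`, `int64_t`, and `c10::optional`. Below is a minimal sketch of those mappings on a toy op; the `example::scaled_add` name and body are illustrative only and not part of this repo:

#include <ATen/ATen.h>
#include <torch/library.h>

// Toy op: each parameter type mirrors one used by mha_fwd_kvcache_mla.
at::Tensor scaled_add(
    const at::Tensor &a,
    const c10::optional<at::Tensor> &b_,  // schema type: Tensor?
    double scale,                         // schema type: float
    int64_t offset) {                     // schema type: int
  // Fall back to `a` when the optional tensor is absent, mirroring the
  // vcache_ -> kcache fallback in the diff above.
  at::Tensor b = b_.has_value() ? b_.value() : a;
  return a + scale * b + offset;
}

TORCH_LIBRARY_FRAGMENT(example, m) {
  m.def("scaled_add(Tensor a, Tensor? b_, float scale, int offset) -> Tensor");
  m.impl("scaled_add", torch::kCPU, &scaled_add);
}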
torch-ext/torch_binding.cpp CHANGED

@@ -8,7 +8,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
     ops.impl("get_mla_metadata", torch::kCUDA, &get_mla_metadata);

     // TOOD: remove last unknown_param when resolved
-    ops.def("mha_fwd_kvcache_mla(Tensor! q, Tensor! kcache, Tensor
+    ops.def("mha_fwd_kvcache_mla(Tensor! q, Tensor! kcache, Tensor? vcache_, int head_size_v, Tensor! seqlens_k, Tensor! block_table, float softmax_scale, bool is_causal_, Tensor! tile_scheduler_metadata, Tensor! num_splits) -> Tensor[]");
     ops.impl("mha_fwd_kvcache_mla", torch::kCUDA, &mha_fwd_kvcache_mla);
 }

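In the schema string, `Tensor!` marks an argument the kernel may mutate in place, and `Tensor?` marks one that may be passed as None, arriving on the C++ side as an empty `c10::optional`. A hedged sketch of invoking such an op from C++ through the dispatcher, reusing the illustrative `example::scaled_add` fragment above:

#include <ATen/core/dispatch/Dispatcher.h>

at::Tensor call_scaled_add(const at::Tensor &a) {
  // Look up the registered schema once and cache the typed handle.
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("example::scaled_add", "")
      .typed<at::Tensor(const at::Tensor &,
                        const c10::optional<at::Tensor> &,
                        double, int64_t)>();
  // Passing c10::nullopt exercises the `Tensor?` argument, like calling
  // the op from Python with the optional tensor set to None.
  return op.call(a, c10::nullopt, /*scale=*/2.0, /*offset=*/1);
}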
torch-ext/torch_binding.h CHANGED

@@ -13,21 +13,13 @@ std::vector<torch::Tensor>
 mha_fwd_kvcache_mla(
     torch::Tensor &q,
     const torch::Tensor &kcache,
-
-    // TODO: fix for optional
-    // std::optional<torch::Tensor> &vcache_,
-
-    const torch::Tensor &vcache_,
+    const c10::optional<torch::Tensor> &vcache_,
     const int64_t head_size_v,
     const torch::Tensor &seqlens_k,
     const torch::Tensor &block_table,
-
     // TODO:should be float
-    const
-
-    // TODO: fix for mutable bool
-    const bool is_causal_,
-
+    const torch::kFloat softmax_scale,
+    bool is_causal,
     const torch::Tensor &tile_scheduler_metadata,
     const torch::Tensor &num_splits
 );
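Note that the definition in flash_mla/flash_mla_api.cu earlier in this commit declares `const double softmax_scale`, while `torch::kFloat` is a dtype value rather than a C++ type. A sketch of a declaration that would match that definition, assuming nothing beyond the signatures shown in this diff:

#include <vector>
#include <c10/util/Optional.h>
#include <torch/torch.h>

// Declaration matching the mha_fwd_kvcache_mla definition in
// flash_mla/flash_mla_api.cu from this commit; softmax_scale is a double
// because a schema `float` unboxes to C++ double.
std::vector<torch::Tensor>
mha_fwd_kvcache_mla(
    torch::Tensor &q,
    const torch::Tensor &kcache,
    const c10::optional<torch::Tensor> &vcache_,
    const int64_t head_size_v,
    const torch::Tensor &seqlens_k,
    const torch::Tensor &block_table,
    const double softmax_scale,
    bool is_causal,
    const torch::Tensor &tile_scheduler_metadata,
    const torch::Tensor &num_splits);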