drbh committed on
Commit
59bdff8
·
1 Parent(s): f475609

fix: adjust sig types

Browse files
flash_mla/flash_mla_api.cu CHANGED
@@ -53,40 +53,26 @@ get_mla_metadata(
53
  return {tile_scheduler_metadata, num_splits};
54
  }
55
 
 
 
56
  std::vector<at::Tensor>
57
  mha_fwd_kvcache_mla(
58
  at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
59
  const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size
60
-
61
- // TODO: fix for optional
62
- // std::optional<const at::Tensor> &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v
63
- const at::Tensor &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v
64
-
65
  const int64_t head_size_v,
66
  const at::Tensor &seqlens_k, // batch_size
67
  const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq
68
-
69
- // TODO: should be float
70
  const double softmax_scale,
71
- const bool is_causal_,
72
  const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
73
  const at::Tensor &num_splits // batch_size + 1
74
-
75
- // TODO: remove this once determined why build is adding this parameter
76
- // const int64_t unknown_param
77
  ) {
78
  auto dprops = at::cuda::getCurrentDeviceProperties();
79
  bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
80
  TORCH_CHECK(is_sm90);
81
 
82
- // TODO: fix for mutable bool
83
- bool is_causal = is_causal_;
84
-
85
-
86
- // TODO: fix for optional
87
- // at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
88
- at::Tensor vcache = vcache_;
89
-
90
  auto q_dtype = q.dtype();
91
  TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
92
 
 
53
  return {tile_scheduler_metadata, num_splits};
54
  }
55
 
56
+ // note doubles and longs are used in place of floats and ints
57
+ // https://github.com/pytorch/pytorch/blob/338ed67a1e7aa98dd849f297533c5a71bea4b661/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h#L211
58
  std::vector<at::Tensor>
59
  mha_fwd_kvcache_mla(
60
  at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
61
  const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size
62
+ const c10::optional<torch::Tensor> &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v
 
 
 
 
63
  const int64_t head_size_v,
64
  const at::Tensor &seqlens_k, // batch_size
65
  const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq
 
 
66
  const double softmax_scale,
67
+ bool is_causal,
68
  const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
69
  const at::Tensor &num_splits // batch_size + 1
 
 
 
70
  ) {
71
  auto dprops = at::cuda::getCurrentDeviceProperties();
72
  bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
73
  TORCH_CHECK(is_sm90);
74
 
75
+ at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
 
 
 
 
 
 
 
76
  auto q_dtype = q.dtype();
77
  TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
78
 
torch-ext/torch_binding.cpp CHANGED
@@ -8,7 +8,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
8
  ops.impl("get_mla_metadata", torch::kCUDA, &get_mla_metadata);
9
 
10
  // TODO: remove last unknown_param when resolved
11
- ops.def("mha_fwd_kvcache_mla(Tensor! q, Tensor! kcache, Tensor! vcache_, int head_size_v, Tensor! seqlens_k, Tensor! block_table, float softmax_scale, bool is_causal_, Tensor! tile_scheduler_metadata, Tensor! num_splits) -> Tensor[]");
12
  ops.impl("mha_fwd_kvcache_mla", torch::kCUDA, &mha_fwd_kvcache_mla);
13
  }
14
 
 
8
  ops.impl("get_mla_metadata", torch::kCUDA, &get_mla_metadata);
9
 
10
  // TODO: remove last unknown_param when resolved
11
+ ops.def("mha_fwd_kvcache_mla(Tensor! q, Tensor! kcache, Tensor? vcache_, int head_size_v, Tensor! seqlens_k, Tensor! block_table, float softmax_scale, bool is_causal_, Tensor! tile_scheduler_metadata, Tensor! num_splits) -> Tensor[]");
12
  ops.impl("mha_fwd_kvcache_mla", torch::kCUDA, &mha_fwd_kvcache_mla);
13
  }
14
 
torch-ext/torch_binding.h CHANGED
@@ -13,21 +13,13 @@ std::vector<torch::Tensor>
13
  mha_fwd_kvcache_mla(
14
  torch::Tensor &q,
15
  const torch::Tensor &kcache,
16
-
17
- // TODO: fix for optional
18
- // std::optional<torch::Tensor> &vcache_,
19
-
20
- const torch::Tensor &vcache_,
21
  const int64_t head_size_v,
22
  const torch::Tensor &seqlens_k,
23
  const torch::Tensor &block_table,
24
-
25
  // TODO:should be float
26
- const double softmax_scale,
27
-
28
- // TODO: fix for mutable bool
29
- const bool is_causal_,
30
-
31
  const torch::Tensor &tile_scheduler_metadata,
32
  const torch::Tensor &num_splits
33
  );
 
13
  mha_fwd_kvcache_mla(
14
  torch::Tensor &q,
15
  const torch::Tensor &kcache,
16
+ const c10::optional<torch::Tensor> &vcache_,
 
 
 
 
17
  const int64_t head_size_v,
18
  const torch::Tensor &seqlens_k,
19
  const torch::Tensor &block_table,
 
20
  // TODO:should be float
21
+ const torch::kFloat softmax_scale,
22
+ bool is_causal,
 
 
 
23
  const torch::Tensor &tile_scheduler_metadata,
24
  const torch::Tensor &num_splits
25
  );