iofu728 committed
Commit b215053
1 Parent(s): 82b1ec4

Feature(MInference): fix the func name

minference/ops/block_sparse_flash_attention.py CHANGED
@@ -444,7 +444,7 @@ def test_flash_attention(
     print('========================================\n')


-def block_sparse_flash_attention_forward(
+def block_sparse_attention(
     query: torch.Tensor,  # [BATCH, N_HEADS, N_CTX, D_HEAD]
     key: torch.Tensor,    # [BATCH, N_HEADS, N_CTX, D_HEAD]
     value: torch.Tensor,  # [BATCH, N_HEADS, N_CTX, D_HEAD]
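
For downstream callers, only the symbol changes. A minimal migration sketch, assuming the arguments after `query`/`key`/`value` are untouched by this commit (the `top_k` sparsity budget shown below is an assumption, not visible in this hunk):

```python
import torch

from minference.ops.block_sparse_flash_attention import block_sparse_attention

# Shapes follow the annotations in the diff: [BATCH, N_HEADS, N_CTX, D_HEAD].
B, H, N, D = 1, 8, 4096, 64
q = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)

# Before this commit: out = block_sparse_flash_attention_forward(q, k, v, ...)
# After: the same kernel is reached through the new name.
# top_k (how many key/value blocks to keep per query block) is an assumed
# example argument; it does not appear in this hunk.
out = block_sparse_attention(q, k, v, top_k=100)
```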
minference/ops/pit_sparse_flash_attention_v2.py CHANGED
@@ -693,7 +693,7 @@ def test_flash_attention(
     torch.testing.assert_close(output_flash, output_triton_sparse, atol=1e-2, rtol=0)


-def pit_sparse_flash_attention_forward(
+def vertical_slash_sparse_attention(
     query: torch.Tensor,  # [BATCH, N_HEADS, N_CTX, D_HEAD]
     key: torch.Tensor,    # [BATCH, N_HEADS, N_CTX, D_HEAD]
     value: torch.Tensor,  # [BATCH, N_HEADS, N_CTX, D_HEAD]
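
The same applies here: only the exported name changes, and it now matches the vertical-slash sparsity pattern the kernel implements. A hedged call-site sketch, assuming the function still takes per-head index tensors selecting which vertical columns and slash diagonals to keep (neither argument is visible in this hunk):

```python
import torch

from minference.ops.pit_sparse_flash_attention_v2 import vertical_slash_sparse_attention

B, H, N, D = 1, 8, 4096, 64
q = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)

# Assumed index arguments (not shown in this diff): int32 indices of the
# vertical columns and slash diagonals retained per head.
v_idx = torch.arange(0, N, 16, dtype=torch.int32, device="cuda")[None, None, :].expand(B, H, -1).contiguous()
s_idx = torch.arange(0, N, 64, dtype=torch.int32, device="cuda")[None, None, :].expand(B, H, -1).contiguous()

# Before this commit: out = pit_sparse_flash_attention_forward(q, k, v, v_idx, s_idx)
# After: same kernel, clearer name.
out = vertical_slash_sparse_attention(q, k, v, v_idx, s_idx)
```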