Satandon1999 committed
Commit: 351b4fd
1 Parent(s): ed7de9a

Update triton_flash_blocksparse_attn.py


Adding ```with torch.cuda.device(q.device.index)``` around all applicable kernel launches to support multi-GPU.
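For context: Triton compiles and launches kernels on the *current* CUDA device, so a launch such as `_fwd_kernel_batch_inference[grid](...)` can fail or run on the wrong GPU when the input tensors live on a device other than `cuda:0`. A minimal sketch of the pattern this commit applies, using a hypothetical `_scale_kernel` (not from this file):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _scale_kernel(x_ptr, n, scale, BLOCK: tl.constexpr):
    # Elementwise in-place x *= scale over a 1-D tensor; one program per block.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(x_ptr + offs, x * scale, mask=mask)

def scale_(x: torch.Tensor, scale: float) -> None:
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    # Triton launches on the current CUDA device; pin it to the tensor's
    # device so the call also works for tensors on cuda:1, cuda:2, ...
    with torch.cuda.device(x.device.index):
        _scale_kernel[grid](x, n, scale, BLOCK=1024)

# Usage on a multi-GPU machine:
#   x = torch.ones(4096, device="cuda:1")
#   scale_(x, 2.0)  # launches on cuda:1, not the default device
```

The commit wraps both launch sites in this file (`blocksparse_flash_attn_padded_fwd` and `blocksparse_flash_attn_varlen_fwd`) in the same way.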

Files changed (1)
  1. triton_flash_blocksparse_attn.py  +64 −62
triton_flash_blocksparse_attn.py CHANGED
@@ -992,37 +992,38 @@ def blocksparse_flash_attn_padded_fwd(
 
     grid = (len(q_start_sids), n_heads)
 
-    _fwd_kernel_batch_inference[grid](
-        q, k, v, out,
-        sm_scale,
-        q_batch_starts,
-        q_batch_ends,
-        k_batch_starts,
-        k_batch_ends,
-        q_batch_ids,
-        q_start_sids,
-
-        *q.stride(),
-        *k.stride(),
-        *v.stride(),
-        *out.stride(),
-
-        layout_crow_indices,
-        layout_col_indices,
-        *layout_crow_indices.stride(),
-        *layout_col_indices.stride(),
-
-        q_k_ratio,
-        HAS_BATCH_DIM = True,
-        D_HEAD = head_size,
-        BLOCK_M = block_size,
-        BLOCK_N = block_size,
-        BLOCK_D = block_d,
-        BLOCK_M_LOADING = 16 if q_len == 1 else block_size, # smaller for decoding
-        EVEN_D = block_d == head_size,
-        num_warps = 1 if q_len == 1 else 4,
-        num_stages = 3
-    )
+    with torch.cuda.device(q.device.index):
+        _fwd_kernel_batch_inference[grid](
+            q, k, v, out,
+            sm_scale,
+            q_batch_starts,
+            q_batch_ends,
+            k_batch_starts,
+            k_batch_ends,
+            q_batch_ids,
+            q_start_sids,
+
+            *q.stride(),
+            *k.stride(),
+            *v.stride(),
+            *out.stride(),
+
+            layout_crow_indices,
+            layout_col_indices,
+            *layout_crow_indices.stride(),
+            *layout_col_indices.stride(),
+
+            q_k_ratio,
+            HAS_BATCH_DIM = True,
+            D_HEAD = head_size,
+            BLOCK_M = block_size,
+            BLOCK_N = block_size,
+            BLOCK_D = block_d,
+            BLOCK_M_LOADING = 16 if q_len == 1 else block_size, # smaller for decoding
+            EVEN_D = block_d == head_size,
+            num_warps = 1 if q_len == 1 else 4,
+            num_stages = 3
+        )
 
     return out
 
@@ -1094,37 +1095,38 @@ def blocksparse_flash_attn_varlen_fwd(
 
     grid = (len(q_start_sids), n_heads)
 
-    _fwd_kernel_batch_inference[grid](
-        q, k, v, out,
-        sm_scale,
-        cu_seqlens_q[:-1],
-        cu_seqlens_q[1:],
-        cu_seqlens_k[:-1],
-        cu_seqlens_k[1:],
-        q_batch_ids,
-        q_start_sids,
-
-        0, *q.stride(),
-        0, *k.stride(),
-        0, *v.stride(),
-        0, *out.stride(),
-
-        layout_crow_indices,
-        layout_col_indices,
-        *layout_crow_indices.stride(),
-        *layout_col_indices.stride(),
-
-        q_k_ratio,
-        HAS_BATCH_DIM = False,
-        D_HEAD = head_size,
-        BLOCK_M = block_size,
-        BLOCK_N = block_size,
-        BLOCK_D = block_d,
-        BLOCK_M_LOADING = 16 if decoding_only else block_size, # smaller for decoding
-        EVEN_D = block_d == head_size,
-        num_warps = 1 if decoding_only else 4,
-        num_stages = 3
-    )
+    with torch.cuda.device(q.device.index):
+        _fwd_kernel_batch_inference[grid](
+            q, k, v, out,
+            sm_scale,
+            cu_seqlens_q[:-1],
+            cu_seqlens_q[1:],
+            cu_seqlens_k[:-1],
+            cu_seqlens_k[1:],
+            q_batch_ids,
+            q_start_sids,
+
+            0, *q.stride(),
+            0, *k.stride(),
+            0, *v.stride(),
+            0, *out.stride(),
+
+            layout_crow_indices,
+            layout_col_indices,
+            *layout_crow_indices.stride(),
+            *layout_col_indices.stride(),
+
+            q_k_ratio,
+            HAS_BATCH_DIM = False,
+            D_HEAD = head_size,
+            BLOCK_M = block_size,
+            BLOCK_N = block_size,
+            BLOCK_D = block_d,
+            BLOCK_M_LOADING = 16 if decoding_only else block_size, # smaller for decoding
+            EVEN_D = block_d == head_size,
+            num_warps = 1 if decoding_only else 4,
+            num_stages = 3
+        )
 
     return out