Resolve - 196 [rank0]: triton.runtime.autotuner.OutOfResources: out of resource: shared memory, Required: 180224, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.

#33
by moidhassan - opened
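This change addresses the shared-memory overflow quoted in the title in two ways: it lowers `num_stages` for the block-sparse Triton attention kernel from 3 to 1, and it casts the rotary-embedding outputs back to `q.dtype` before they reach the attention kernel. As a rough back-of-envelope, assuming the pipelined buffers scale roughly linearly with `num_stages`: the kernel asks for 180224 bytes (176 KiB) at `num_stages = 3`, so `num_stages = 1` should need on the order of 180224 / 3 ≈ 60 KB, well under the 101376-byte (99 KiB) hardware limit.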
positional_embedding.py CHANGED
@@ -269,10 +269,10 @@ class RotaryEmbedding(torch.nn.Module):
         return (
             apply_rotary_pos_emb(
                 q, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
-            ),
+            ).to(q.dtype),
             apply_rotary_pos_emb(
                 k, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
-            ),
+            ).to(q.dtype),
         )
 
     @classmethod
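The `.to(q.dtype)` casts above guard against `apply_rotary_pos_emb` returning float32 when the cos/sin caches are kept in float32: under PyTorch type promotion, a half-precision q multiplied by a float32 cache yields a float32 result, which can both break dtype assumptions downstream and double the size of the kernel's loads. A minimal sketch of that promotion behaviour (shapes and names here are made up for illustration):

import torch

# Made-up shapes: a float16 q times a float32 cos cache promotes to float32,
# so the rotary output is cast back to q.dtype before the attention kernel.
q = torch.randn(1, 8, 128, 64, dtype=torch.float16)
cos_cached = torch.randn(128, 64, dtype=torch.float32)

rotated = q * cos_cached
assert rotated.dtype == torch.float32               # silently upcast
assert rotated.to(q.dtype).dtype == torch.float16   # what the patch restores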
triton_flash_blocksparse_attn.py CHANGED
@@ -1020,7 +1020,7 @@ def blocksparse_flash_attn_padded_fwd(
         BLOCK_M_LOADING = 16 if q_len == 1 else block_size, # smaller for decoding
         EVEN_D = block_d == head_size,
         num_warps = 1 if q_len == 1 else 4,
-        num_stages = 3
+        num_stages = 1 # <---- instead of 3
     )
 
     return out
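Hardcoding `num_stages = 1` trades software-pipelining overlap for a smaller shared-memory footprint on every GPU, including ones that could still fit two or three stages. A hedged alternative (a sketch only, not part of this patch; the wrapper name and argument order are made up) would be to catch the `OutOfResources` error and retry with fewer stages, so GPUs with more shared memory keep the deeper pipeline:

from triton.runtime.autotuner import OutOfResources

def launch_with_stage_fallback(kernel, grid, *args, **meta):
    # Try the original pipeline depth first, then progressively shallower ones.
    # OutOfResources is the same error quoted in the title of this discussion.
    last_err = None
    for num_stages in (3, 2, 1):
        try:
            return kernel[grid](*args, num_stages=num_stages, **meta)
        except OutOfResources as err:
            last_err = err  # shared memory still too small; drop one stage
    raise last_err

The call site in `blocksparse_flash_attn_padded_fwd` would pass the same meta-parameters shown in the diff above (BLOCK_M_LOADING, EVEN_D, num_warps, ...) through `**meta`.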