gmastrapas committed on
Commit
71aae6d
1 Parent(s): 8771224

fix: handle window_size passed as list

Browse files
Files changed (1) hide show
  1. mha.py +4 -0
mha.py CHANGED
@@ -514,6 +514,10 @@ class MHA(nn.Module):
514
  alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
515
  else:
516
  alibi_slopes = None
 
 
 
 
517
  if window_size != (-1, -1):
518
  assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
519
 
 
514
  alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
515
  else:
516
  alibi_slopes = None
517
+
518
+ if isinstance(window_size, list):
519
+ window_size = tuple(window_size)
520
+
521
  if window_size != (-1, -1):
522
  assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
523