Fix RuntimeError: pad attn scores back to original query sequence length, instead of unpadded sequence length (i.e. no change).

#17
by Birchlabs - opened
Files changed (1) hide show
  1. modeling_flash_llama.py +1 -1
modeling_flash_llama.py CHANGED
@@ -378,7 +378,7 @@ class LlamaAttention(nn.Module):
378
 
379
  attn_output = attn_outputs[0] if output_attentions else attn_outputs
380
  attn_output = pad_input(
381
- attn_output, indices_q, bsz, max_seqlen_q
382
  ).reshape(bsz, q_len, h_size)
383
  attn_weights = attn_outputs[2] if output_attentions else None
384
 
 
378
 
379
  attn_output = attn_outputs[0] if output_attentions else attn_outputs
380
  attn_output = pad_input(
381
+ attn_output, indices_q, bsz, q_len
382
  ).reshape(bsz, q_len, h_size)
383
  attn_weights = attn_outputs[2] if output_attentions else None
384