Fix RuntimeError: pad attn scores back to the original query sequence length, instead of the unpadded sequence length (i.e. padding was a no-op).
#17
by Birchlabs
- modeling_flash_llama.py +1 -1
modeling_flash_llama.py CHANGED

@@ -378,7 +378,7 @@ class LlamaAttention(nn.Module):

         attn_output = attn_outputs[0] if output_attentions else attn_outputs
         attn_output = pad_input(
-            attn_output, indices_q, bsz,
+            attn_output, indices_q, bsz, q_len
         ).reshape(bsz, q_len, h_size)
         attn_weights = attn_outputs[2] if output_attentions else None

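For context on why the missing fourth argument raised a RuntimeError: pad_input scatters the unpadded (total_tokens, ...) rows back into a padded (batch, seqlen, ...) layout, so if the target length is not the original q_len, the output keeps the unpadded token count and the following .reshape(bsz, q_len, h_size) fails whenever any sequence in the batch contains padding. Below is a minimal pure-PyTorch sketch of that behavior; pad_input_sketch and the toy shapes are illustrative stand-ins for flash_attn.bert_padding.pad_input, not the actual implementation.

import torch

def pad_input_sketch(unpadded, indices, batch, seqlen):
    # Illustrative stand-in for flash_attn.bert_padding.pad_input:
    # scatter rows of the unpadded tensor (total_tokens, hidden) back
    # into their flattened (batch * seqlen) positions, zero-filling
    # the padding slots. The caller reshapes to (batch, seqlen, hidden).
    out = unpadded.new_zeros(batch * seqlen, *unpadded.shape[1:])
    out[indices] = unpadded
    return out

bsz, q_len, h_size = 2, 4, 8
mask = torch.tensor([[1, 1, 1, 1],
                     [1, 1, 0, 0]], dtype=torch.bool)  # 6 real tokens
indices_q = mask.flatten().nonzero().flatten()
attn_output = torch.randn(indices_q.numel(), h_size)   # unpadded attn output

# With q_len passed (the fix): 2 * 4 * 8 = 64 elements, reshape succeeds.
padded = pad_input_sketch(attn_output, indices_q, bsz, q_len)
padded = padded.reshape(bsz, q_len, h_size)

# Without q_len, the output would stay at 6 rows (48 elements), and
# reshaping to (2, 4, 8) raises an error along the lines of
# "RuntimeError: shape '[2, 4, 8]' is invalid for input of size 48".

Note that the error only surfaces when at least one sequence in the batch is actually padded; with no padding tokens, the unpadded and padded lengths coincide, which is why the bug could go unnoticed on uniform-length batches.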