Fix RuntimeError: pad attn scores back to original query sequence length, instead of unpadded sequence length (i.e. no change).
#17 by Birchlabs · opened
Prevents a RuntimeError in line 382's `pad_input(…).reshape()`:
`shape '[1, 4096, 4096]' is invalid for input of size 9400320`
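The element counts in the error message line up exactly: the unpadded attention output holds 2295 × 32 × 128 elements, while a `[1, 4096, 4096]` view needs 1 × 4096 × 4096, so the reshape cannot succeed:

```python
# Element-count check for the failing reshape (numbers taken from the traceback).
unpadded_elems = 2295 * 32 * 128   # elements in attn_output before re-padding
target_elems = 1 * 4096 * 4096     # elements a [1, 4096, 4096] view requires

print(unpadded_elems)  # 9400320 — matches "input of size 9400320" in the error
print(target_elems)    # 16777216 — so .reshape(1, 4096, 4096) must fail
```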
Before this change, `pad_input()` was basically just doing a `.unsqueeze(0)`:

```python
attn_output.shape
# torch.Size([2295, 32, 128])
pad_input(attn_output, indices_q, bsz, max_seqlen_q).shape
# torch.Size([1, 2295, 32, 128])
```

After this change, `pad_input()` actually pads the input back to the original query sequence length:

```python
pad_input(attn_output, indices_q, bsz, q_len).shape
# torch.Size([1, 4096, 32, 128])
```

And the reshape succeeds:

```python
pad_input(attn_output, indices_q, bsz, q_len).reshape(bsz, q_len, h_size).shape
# torch.Size([1, 4096, 4096])
```
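For intuition, `pad_input` scatters the packed `(total_tokens, ...)` rows back into a zero-padded `(batch, seqlen, ...)` tensor at the flat token indices. A minimal sketch of that behavior (the function name `pad_input_sketch` and the toy shapes are illustrative, not the library's actual implementation):

```python
import torch


def pad_input_sketch(unpadded: torch.Tensor, indices: torch.Tensor,
                     batch: int, seqlen: int) -> torch.Tensor:
    """Scatter (total_tokens, ...) rows into a zero-padded
    (batch, seqlen, ...) tensor at the given flat token indices."""
    out = torch.zeros(batch * seqlen, *unpadded.shape[1:], dtype=unpadded.dtype)
    out[indices] = unpadded           # place real tokens at their positions
    return out.reshape(batch, seqlen, *unpadded.shape[1:])


# Toy example: 3 real tokens in a batch of 1 with seqlen=5, feature dim 2.
x = torch.arange(6, dtype=torch.float32).reshape(3, 2)
idx = torch.tensor([0, 1, 4])         # positions of the real tokens
padded = pad_input_sketch(x, idx, batch=1, seqlen=5)
print(padded.shape)                   # torch.Size([1, 5, 2])
```

This is why passing `q_len` (4096) rather than `max_seqlen_q` (2295, the packed length) matters: the second argument of the zero tensor determines the padded sequence length the rows are scattered into.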
I was getting a similar error. Thanks for the change... when will this get merged?