
Fix RuntimeError: pad attn output back to the original query sequence length instead of the unpadded sequence length (which was a no-op).

#17
by Birchlabs - opened

Prevents a RuntimeError from the pad_input(…).reshape() call on line 382:
shape '[1, 4096, 4096]' is invalid for input of size 9400320
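The error is an element-count mismatch. Here is a quick arithmetic check using the attn_output shape [2295, 32, 128] reported further down in this PR. Treating h_size as num_heads * head_dim (32 * 128 = 4096) is an assumption that is consistent with those shapes.

```python
# Shapes taken from this PR; h_size = num_heads * head_dim is an assumption
# consistent with the reported [1, 4096, 4096] target shape.
bsz, q_len = 1, 4096
num_heads, head_dim = 32, 128
h_size = num_heads * head_dim
unpadded_tokens = 2295  # first dimension of attn_output after unpadding

# What pad_input(..., max_seqlen_q) returned: [1, 2295, 32, 128]
elements_available = bsz * unpadded_tokens * num_heads * head_dim
# What .reshape(bsz, q_len, h_size) needs: [1, 4096, 4096]
elements_required = bsz * q_len * h_size

print(elements_available)  # 9400320
print(elements_required)   # 16777216
```

9400320 is exactly the "input of size" figure in the error, so the reshape is being asked to stretch 2295 tokens' worth of data into 4096 token positions.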

Before this change, pad_input() was effectively just doing an .unsqueeze(0), because it padded to max_seqlen_q, which with bsz=1 is just the unpadded length:
attn_output.shape
torch.Size([2295, 32, 128])
pad_input(attn_output, indices_q, bsz, max_seqlen_q).shape
torch.Size([1, 2295, 32, 128])

After this change, pad_input() actually pads the output back to the original query sequence length, q_len:
pad_input(attn_output, indices_q, bsz, q_len).shape
torch.Size([1, 4096, 32, 128])
and the reshape succeeds:
pad_input(attn_output, indices_q, bsz, q_len).reshape(bsz, q_len, h_size).shape
torch.Size([1, 4096, 4096])
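To make the before/after difference concrete, here is a toy, pure-Python stand-in for pad_input. It assumes pad_input behaves like flash-attn's bert_padding helper (scatter the unpadded rows into a zero-filled buffer of batch * seqlen slots, then split it per batch entry). Each token's 32×128 features are collapsed to one number, and the sizes (q_len = 8, 5 real tokens followed by 3 padding positions) are hypothetical.

```python
def pad_input_sketch(rows, indices, batch, seqlen):
    """Toy stand-in for pad_input: scatter unpadded rows into a zero-filled
    buffer of batch * seqlen slots, then split it into per-batch sequences."""
    flat = [0.0] * (batch * seqlen)
    for row, idx in zip(rows, indices):
        flat[idx] = row
    return [flat[b * seqlen:(b + 1) * seqlen] for b in range(batch)]


# Hypothetical toy values: one sequence of q_len = 8 positions, where the
# 5 real tokens come first (right padding), so their indices are 0..4.
bsz, q_len = 1, 8
indices_q = [0, 1, 2, 3, 4]
attn_output = [1.0, 2.0, 3.0, 4.0, 5.0]  # one value per real token
max_seqlen_q = len(indices_q)  # with bsz=1, the unpadded length: 5

before = pad_input_sketch(attn_output, indices_q, bsz, max_seqlen_q)
after = pad_input_sketch(attn_output, indices_q, bsz, q_len)

print(before)  # [[1.0, 2.0, 3.0, 4.0, 5.0]]
print(after)   # [[1.0, 2.0, 3.0, 4.0, 5.0, 0.0, 0.0, 0.0]]
```

Under these assumptions, padding to max_seqlen_q only adds a batch dimension, mirroring the [1, 2295, 32, 128] result above, while padding to q_len restores the full-length layout that the reshape to (bsz, q_len, h_size) expects.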

I was getting a similar error. Thanks for the change! When will this get merged?

