llama2 forward pass seemingly not working with padded inputs, unless one element in batch is not padded

#13
by joehakim - opened

Moved here from this discussion thread: https://github.com/huggingface/transformers/issues/26601. Basically, this seems to be an issue with padding that only occurs when trust_remote_code=True, so maybe it is related to FlashAttention?

Here's a script to reproduce:

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("togethercomputer/Llama-2-7B-32K-Instruct")

model = AutoModelForCausalLM.from_pretrained(
    "togethercomputer/Llama-2-7B-32K-Instruct",
    trust_remote_code=True,  # the failing call below works when this is False
    torch_dtype=torch.float16,
).cuda()

# For comparison, MT5 works in both cases (full batch and first element only):
# model = MT5ForConditionalGeneration.from_pretrained("google/mt5-xl")

encoded = tokenizer(
    [
        "[INST]\nWrite a poem about cats\n[/INST]\n\n",
        "[INST]\nWrite " + "a poem about" * 400 + " cats\n[/INST]\n\n",
    ],
    return_tensors="pt",
    padding="longest",
).to(model.device)

encoded_firstelem = {
    "input_ids": encoded["input_ids"][:1, :],
    "attention_mask": encoded["attention_mask"][:1, :],
}

print(encoded_firstelem)
# {'input_ids': tensor([[    0,     0,     0,  ..., 29962,    13,    13]], device='cuda:0'), 'attention_mask': tensor([[0, 0, 0,  ..., 1, 1, 1]], device='cuda:0')}

# works
print(model(**encoded))

# breaks
print(model(**encoded_firstelem))
```

Together org

Hi @joehakim and thanks for reporting this!

I think the error you see when feeding only the first element comes from a mismatch between `q_len` and `max_seqlen_q`, caused by the unnecessary padding of the first element.

For your specific example, this is caused by the following steps in `modelling_flash_llama.py`:

  1. `bsz, q_len, h_size = hidden_states.size()` (L311) -- this reads the sequence length from the padded input, which is 1215.
  2. `unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask[:, -q.size(1):])` (L371) -- here the padding gets removed and your max_seqlen_q becomes 18.
  3. `attn_output = pad_input(attn_output, indices_q, bsz, max_seqlen_q).reshape(bsz, q_len, h_size)` (L380-382) -- this is where the error happens, due to the mismatch between q_len and max_seqlen_q.

So that means you can't process a batch in which every sequence carries padding, i.e. where the padded length is longer than the longest actual (unpadded) sequence in the batch; a possible workaround is sketched below.
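
One way around this, based purely on the explanation above (a sketch, not code from `modelling_flash_llama.py`; the helper `trim_excess_padding` is a name made up for illustration), is to drop the excess padding columns before the forward pass, so that `q_len` again equals `max_seqlen_q`. The tokenizer in the reproduction script pads on the left (see the printed `input_ids`), so the slice keeps the rightmost columns:

```python
def trim_excess_padding(input_ids, attention_mask, padding_side="left"):
    """Drop padding columns beyond the longest unpadded sequence in the batch."""
    max_real_len = int(attention_mask.sum(dim=1).max().item())
    if padding_side == "left":
        return input_ids[:, -max_real_len:], attention_mask[:, -max_real_len:]
    return input_ids[:, :max_real_len], attention_mask[:, :max_real_len]

# Reusing `model` and `encoded_firstelem` from the reproduction script above:
input_ids, attention_mask = trim_excess_padding(
    encoded_firstelem["input_ids"], encoded_firstelem["attention_mask"]
)
print(model(input_ids=input_ids, attention_mask=attention_mask))
```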

I am encountering the same error, i.e. a mismatch between `q_len` and `max_seqlen_q` gives:
`RuntimeError: shape '[4, 6400, 4096]' is invalid for input of size 14811136`

Is there a solution to this issue?
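
For what it's worth, the numbers in that error look consistent with the mismatch described above; a quick check (the shape and element count come from the error message, nothing else is assumed):

```python
numel = 14811136                    # elements in the unpadded attention output
bsz, q_len, h_size = 4, 6400, 4096  # target shape from the RuntimeError
print(numel // (bsz * h_size))      # 904 -> the actual max_seqlen_q after unpadding
print(bsz * q_len * h_size)         # 104857600 -> what reshape(bsz, q_len, h_size) expects
```

So the unpadded sequences in that batch appear to top out at 904 tokens while the padded length is 6400.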

Hi @mauriceweber - is there support for batches whose unpadded sequences have different lengths?
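
Based on the explanation above (a reading of the thread, not an official answer), different unpadded lengths within a batch should be fine as long as the batch is padded only to its own longest sequence, so that at least one element carries no padding. A minimal sketch, reusing `tokenizer` and `model` from the reproduction script:

```python
texts = [
    "[INST]\nWrite a poem about cats\n[/INST]\n\n",
    "[INST]\nWrite a much longer story about dogs, with several chapters\n[/INST]\n\n",
]
# padding="longest" pads only up to the longest sequence in *this* batch, so the
# longest element is unpadded and q_len matches max_seqlen_q inside the model.
enc = tokenizer(texts, return_tensors="pt", padding="longest").to(model.device)
out = model(**enc)
```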
