Error during inference with past_key_values

#17
by dimaischenko - opened

It seems that the logic for using past_key_values in generation is either not implemented or implemented incorrectly.

I tried to write my own generation loop using past_key_values and got tensor-dimension errors in _convert_to_rw_cache(past) in modelling_RW.py, or "nonsense" output if I try to skip this method. More details:

In modelling_RW.py there is this method:

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        # only last token for input_ids if past is not None
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)

            # the cache may be in the standard format (e.g. in contrastive search), convert to our format if needed
            if past[0][0].shape[0] == input_ids.shape[0]:
                past = self._convert_to_rw_cache(past)

        return {
            "input_ids": input_ids,
            "past_key_values": past,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }

Now, if you debug the default pipeline example from the model card (https://huggingface.co/tiiuae/falcon-7b), this part of the prepare_inputs_for_generation method is never called:

...

        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            
            if past[0][0].shape[0] == input_ids.shape[0]:
                past = self._convert_to_rw_cache(past)

...

This is because past is None on every iteration of the generation loop.
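One plausible explanation, offered as an assumption because it depends on the installed transformers version: recent versions of `generate()` pass the cache to `prepare_inputs_for_generation` under the keyword `past_key_values`, while this method's parameter is named `past`. The cache then lands in `**kwargs` and `past` stays `None`. The pure-Python sketch below mimics that call with a hypothetical string standing in for the cache; `prepare_inputs_compat` is a hypothetical variant that accepts either name.

```python
from typing import Any, Optional


def prepare_inputs_for_generation(input_ids, past: Optional[Any] = None, attention_mask=None, **kwargs) -> dict:
    # Same parameter names as the remote Falcon method quoted above (body simplified).
    return {"input_ids": input_ids, "past_key_values": past, "use_cache": kwargs.get("use_cache")}


def prepare_inputs_compat(input_ids, past: Optional[Any] = None, attention_mask=None, **kwargs) -> dict:
    # Hypothetical variant: fall back to the `past_key_values` keyword when `past` is not given.
    if past is None:
        past = kwargs.pop("past_key_values", None)
    return {"input_ids": input_ids, "past_key_values": past, "use_cache": kwargs.get("use_cache")}


# Hypothetical stand-in for the keyword arguments generate() passes on a later step.
model_kwargs = {"past_key_values": "cache-from-previous-step", "use_cache": True}

print(prepare_inputs_for_generation([[1, 2, 3]], **model_kwargs)["past_key_values"])  # None
print(prepare_inputs_compat([[1, 2, 3]], **model_kwargs)["past_key_values"])  # cache-from-previous-step
```

The cache is silently dropped rather than raising an error, which would match the observation that the out-of-the-box pipeline runs fine but never uses the cache.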

I tried writing my own loop in which I pass the past argument and set past_key_values. After that, I always get a dimension error in the _convert_to_rw_cache(past) method.
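A possible reason for the dimension error, stated as an assumption rather than a confirmed diagnosis: falcon-7b uses multi-query attention (a single key/value head), and if the RW cache stores keys as `[batch * num_kv_heads, seq_len, head_dim]`, its first dimension equals the batch size. The heuristic `past[0][0].shape[0] == input_ids.shape[0]` would then mistake a cache that is already in RW format for the standard format, and a conversion that unpacks four sizes from a 3-D shape fails. The sizes are illustrative and `try_unpack_standard_shape` is a hypothetical helper mimicking the first step of such a conversion.

```python
import torch

batch, num_kv_heads, seq_len, head_dim = 1, 1, 5, 64  # illustrative; num_kv_heads=1 as in multi-query attention

# Assumed RW-format cache for one layer: keys and values as [batch * num_kv_heads, seq_len, head_dim].
rw_key = torch.zeros(batch * num_kv_heads, seq_len, head_dim)
rw_value = torch.zeros(batch * num_kv_heads, seq_len, head_dim)
past = ((rw_key, rw_value),)
input_ids = torch.zeros(batch, 1, dtype=torch.long)  # only the newest token

# The heuristic from prepare_inputs_for_generation fires although the cache is already in RW format.
looks_standard = past[0][0].shape[0] == input_ids.shape[0]


def try_unpack_standard_shape(past):
    # Hypothetical helper: a conversion expecting 4-D standard-format keys starts by unpacking four sizes.
    try:
        batch_size, num_heads, head_dim, seq_length = past[0][0].shape
    except ValueError as err:
        return str(err)
    return None


print(looks_standard)  # True
print(try_unpack_standard_shape(past))  # the unpacking error message
```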

There is no error in the generation loop if I manually edit prepare_inputs_for_generation to skip the _convert_to_rw_cache method and keep the original tensor dimensions, but then I get "nonsense" when decoding the most probable tokens.

It seems that the logic for using past_key_values is either not implemented or implemented incorrectly.

I would be very happy to hear from you, because using past_key_values speeds up inference several times over.
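A rough way to see where that speed-up comes from (a back-of-the-envelope token count, not a benchmark, and ignoring that attention cost still grows with context length): without a cache every step feeds the whole sequence through the model, while with a cache the prompt is processed once and then only one token per step. The function name is illustrative.

```python
def tokens_processed(prompt_len: int, new_tokens: int, use_cache: bool) -> int:
    """Total number of token positions fed through the model to generate `new_tokens` tokens."""
    if use_cache:
        # the prompt once (yielding the first new token), then one token for each remaining step
        return prompt_len + (new_tokens - 1)
    # step i re-processes the prompt plus the i tokens generated so far
    return sum(prompt_len + i for i in range(new_tokens))


print(tokens_processed(20, 100, use_cache=False))  # 6950
print(tokens_processed(20, 100, use_cache=True))  # 119
```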

P.S. Again, when using the out-of-the-box Hugging Face generate or pipeline methods with the Falcon model, everything works as it should, but debugging shows that past_key_values are not actually used in these cases.

P.P.S. I also tried changing the logic in the _convert_to_rw_cache(past) method; there is clearly something wrong with the expected dimensions in the code, but this also failed. At best I got "nonsense" when decoding the resulting tokens.

For more clarity, here is my generation loop:

    import random

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = torch.device("cuda")
    model_id = "tiiuae/falcon-7b"

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    ).to(device)

    tokenizer = AutoTokenizer.from_pretrained(model_id)

    text = "Hello there. How are"

    inputs = tokenizer(text, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"]

    output = None
    step = 0

    # generation loop with 10 steps
    while step < 10:
        attention_mask = input_ids.new_ones(input_ids.shape)

        # reuse the cache returned by the previous forward pass, if any
        past_key_values = None
        if output is not None:
            past_key_values = output["past_key_values"]

        # the remote Falcon code names this parameter `past`
        ids = model.prepare_inputs_for_generation(input_ids,
                                                  past=past_key_values,
                                                  use_cache=True,
                                                  attention_mask=attention_mask)
        with torch.no_grad():
            output = model(**ids)

        # pick one of the 3 most probable next tokens at random and append it to input_ids
        top_k = 3
        next_token = random.choice(torch.topk(output.logits[:, -1, :], top_k, dim=-1).indices[0])

        input_ids = torch.cat([input_ids, next_token.view(1, 1)], dim=-1)

        step += 1
    print(tokenizer.decode(input_ids[0]))
which prints:

    Hello there. How are!
    ,
    . I

    .<|endoftext|>

P.S. I commented out this check in the prepare_inputs_for_generation method in modelling_RW.py:

    # the cache may be in the standard format (e.g. in contrastive search), convert to our format if needed
    if past[0][0].shape[0] == input_ids.shape[0]:
        past = self._convert_to_rw_cache(past)

Otherwise, a tensor-dimension error is raised.
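Instead of removing the check entirely, one could decide by tensor rank. This is a sketch under an assumption not confirmed by the Falcon code: that standard-format keys are 4-D (Bloom-style `[batch, heads, head_dim, seq_len]`) while RW-format keys are 3-D. `needs_rw_conversion` is a hypothetical helper and the shapes are illustrative.

```python
import torch


def needs_rw_conversion(past) -> bool:
    # Hypothetical helper: decide by tensor rank instead of comparing the first dim with the batch size.
    # Assumes standard-format keys are 4-D and RW-format keys are 3-D.
    return past[0][0].dim() == 4


batch, heads, seq_len, head_dim = 1, 1, 5, 64  # illustrative sizes
rw_past = (
    (torch.zeros(batch * heads, seq_len, head_dim), torch.zeros(batch * heads, seq_len, head_dim)),
)
standard_past = (
    (torch.zeros(batch, heads, head_dim, seq_len), torch.zeros(batch, heads, seq_len, head_dim)),
)

print(needs_rw_conversion(rw_past))  # False
print(needs_rw_conversion(standard_past))  # True
```

Unlike the batch-size comparison, this test does not depend on the number of key/value heads, so it behaves the same for multi-query and multi-head models.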

Same problem!

@FalconLLM I would be very grateful if you could tell us whether past_key_values is supposed to be used during generation or whether this logic is simply not implemented. Could it be added, or are there fundamental limitations? After all, using it significantly speeds up inference.

@FalconLLM Or maybe you could suggest a specialist from your team who could help sort out this issue? I would be very grateful!

Same problem here; I would appreciate some suggestions from @FalconLLM!

It appears to me that RotaryEmbedding obtains "sequence_length" from the q input, which will be 1 when using the KV cache. This makes the embeddings incorrect.
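To illustrate the point, here is a minimal rotary-embedding sketch written for this post (not Falcon's own class; it assumes the usual `rotate_half` formulation with base 10000). Rotating the last token on its own with positions starting at 0 gives a different result than rotating the full sequence, while starting the positions at the cache length reproduces it exactly.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim in half and map (x1, x2) -> (-x2, x1).
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary(x: torch.Tensor, past_len: int, base: float = 10000.0) -> torch.Tensor:
    # x: [batch * heads, q_len, head_dim]; its tokens sit at positions past_len .. past_len + q_len - 1.
    _, q_len, head_dim = x.shape
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(past_len, past_len + q_len).float()
    freqs = torch.einsum("i,j->ij", positions, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos, sin = emb.cos()[None], emb.sin()[None]
    return x * cos + rotate_half(x) * sin


torch.manual_seed(0)
x = torch.randn(2, 6, 8)  # 6 tokens, head_dim 8 (illustrative)

full = apply_rotary(x, past_len=0)  # no cache: all 6 tokens at once
last_wrong = apply_rotary(x[:, -1:], past_len=0)  # cached step, position taken as 0
last_right = apply_rotary(x[:, -1:], past_len=5)  # cached step with the true position 5

print(torch.allclose(full[:, -1:], last_wrong))  # False
print(torch.allclose(full[:, -1:], last_right))  # True
```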

I resolved this by passing in the position of the token I'm currently generating, as follows. Although the embeddings now match what we see without the KV cache, our results are still garbage.

    def cos_sin(
        self,
        seq_len: int,
        device="cuda",
        dtype=torch.bfloat16,
        position=None
    ) -> torch.Tensor:
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            # t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
            t = torch.arange(position, device=device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(device)

            if dtype in [torch.float16, torch.bfloat16]:
                emb = emb.float()

            self.cos_cached = emb.cos()[None, :, :]
            self.sin_cached = emb.sin()[None, :, :]

            self.cos_cached = self.cos_cached.type(dtype)
            self.sin_cached = self.sin_cached.type(dtype)

        return (self.cos_cached[:, -1:, :], self.sin_cached[:, -1:, :]) if position != seq_len else (self.cos_cached, self.sin_cached)

    def forward(self, q, k, position):
        # q: q_new, b*nh x q_len x d
        # k: k_new, b*nh x q_len x d
        # position: true position index of these tokens
        # These aren't the true position ids of the tokens
        batch, seq_len, head_dim = q.shape

        cos, sin = self.cos_sin(seq_len, q.device, q.dtype, position)
        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

@ColmanTT There was also this hypothesis in the falcon-40b discussions (https://huggingface.co/tiiuae/falcon-40b/discussions/48); @cchudant and I discussed and tested it, but the results were also garbage.

By the way, this is how I'm attempting to run with the KV cache:

    result = model(input_ids=input_ids, past_key_values=past_key_values, attention_mask=None, position_ids=None, use_cache=True, return_dict=True)

On the first iteration, input_ids contains the prompt and past_key_values is None. On subsequent iterations, input_ids contains only the new token, and the past_key_values from the previous output are fed back into the model.
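For comparison, here is a minimal greedy loop following that pattern, as a sketch. It uses a tiny randomly initialised GPT-2 from `transformers` purely so it runs offline without downloading weights; it is a stand-in for a model with working cache handling, not Falcon. Unlike the call above, it also passes an attention mask that grows by one column per step so it always covers the cached tokens plus the new one. `greedy_generate_with_cache` is a hypothetical helper name.

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel


@torch.no_grad()
def greedy_generate_with_cache(model, input_ids: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    past_key_values = None
    attention_mask = torch.ones_like(input_ids)
    next_input = input_ids  # first step: the whole prompt
    for _ in range(max_new_tokens):
        out = model(
            input_ids=next_input,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            use_cache=True,
            return_dict=True,
        )
        past_key_values = out.past_key_values  # fed back in on the next step
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)  # [batch, 1]
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        # the mask must cover the cached tokens plus the new one
        attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
        next_input = next_token  # later steps: only the newest token
    return input_ids


torch.manual_seed(0)
config = GPT2Config(vocab_size=100, n_positions=64, n_embd=32, n_layer=2, n_head=2)
model = GPT2LMHeadModel(config).eval()
prompt = torch.randint(0, 100, (1, 5))
print(greedy_generate_with_cache(model, prompt, max_new_tokens=8).shape)  # torch.Size([1, 13])
```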

@LevanKvirkvelia Did you succeed? Are you proposing to replace the Attention class in the Falcon model with FlashRWAttention from HF?

@ColmanTT I got good output after changing the code like this:

    def forward(self, q, k, seq_len):
        # seq_len: total sequence length, i.e. cached tokens + new tokens
        _, q_len, _ = q.shape
        cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
        # keep only the rows for the q_len newest positions
        cos = cos[:, -q_len:]
        sin = sin[:, -q_len:]
        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

I also changed the attention code as follows:

    if layer_past is not None:
        # with a cache, the new query may attend to every cached key;
        # an all-ones mask is only correct when a single new token is processed (L == 1)
        L = query_layer_.shape[-2]
        S = key_layer_.shape[-2]
        attn_mask = torch.ones(L, S, dtype=torch.bool, device=query_layer_.device)
        attn_output = F.scaled_dot_product_attention(
            query_layer_, key_layer_, value_layer_, attn_mask, 0.0, is_causal=False
        )
    else:
        attn_output = F.scaled_dot_product_attention(
            query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
        )
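A small check of why `is_causal=True` is a problem once a cache is used, based on PyTorch's documented behaviour that for a non-square mask the causal mask is aligned to the upper-left corner: with one query and five keys, the query sees only the first key. The sketch also includes a hypothetical `cached_causal_mask` helper for the case of several new tokens, where an all-ones mask would let earlier new tokens see later ones; the shapes are illustrative.

```python
import math

import torch
import torch.nn.functional as F


def cached_causal_mask(q_len: int, kv_len: int) -> torch.Tensor:
    # Hypothetical helper: bottom-right aligned causal mask. New query i may attend to
    # every cached key plus the new keys up to and including its own position.
    return torch.ones(q_len, kv_len, dtype=torch.bool).tril(diagonal=kv_len - q_len)


torch.manual_seed(0)
head_dim, kv_len = 8, 5
q = torch.randn(1, 1, 1, head_dim)  # [batch, heads, q_len=1, head_dim]: one new token
k = torch.randn(1, 1, kv_len, head_dim)
v = torch.randn(1, 1, kv_len, head_dim)

upper_left = F.scaled_dot_product_attention(q, k, v, is_causal=True)
with_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=cached_causal_mask(1, kv_len))

# Reference: plain softmax attention over all keys.
reference = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(head_dim), dim=-1) @ v

print(torch.allclose(upper_left, v[:, :, :1]))  # True: only the first key is attended to
print(torch.allclose(with_mask, reference, atol=1e-6))  # True
print(cached_causal_mask(2, 5).int())
```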

@Tron2060 Could you please explain how you pass the new argument to RotaryEmbedding.forward(self, q, k, seq_len)?

The old way:

    query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)

There is no seq_len in this context. I changed it to:

    query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, fused_qkv.shape[1])

I do get something more or less readable, but it still seems very far from normal model generation. Perhaps I misused RotaryEmbedding.

@dimaischenko I pass the arguments this way:

        _, seq_len, _ = query_layer.shape
        if layer_past is not None:
            # add the number of cached tokens to get the total sequence length
            _, seq_len_past, _ = layer_past[0].shape
            seq_len = seq_len + seq_len_past

        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, seq_len)

This will all get fixed eventually in the transformers code on GitHub:
https://github.com/huggingface/transformers/issues/25151#issuecomment-1654062690
