Guanzheng committed on
Commit
9a3865f
1 Parent(s): e910345

Update modeling_llama.py

Browse files
Files changed (1) hide show
  1. modeling_llama.py +1 -1
modeling_llama.py CHANGED
@@ -294,7 +294,7 @@ class LlamaAttention(nn.Module):
294
  key_states = torch.cat([past_key_value[0], key_states], dim=2)
295
  value_states = torch.cat([past_key_value[1], value_states], dim=2)
296
 
297
- past_key_value = (cache_key_states, value_states) if use_cache else None
298
 
299
  use_flashattn = self.config.use_flashattn and is_flash_attn_available()
300
 
 
294
  key_states = torch.cat([past_key_value[0], key_states], dim=2)
295
  value_states = torch.cat([past_key_value[1], value_states], dim=2)
296
 
297
+ past_key_value = (key_states, value_states) if use_cache else None
298
 
299
  use_flashattn = self.config.use_flashattn and is_flash_attn_available()
300