Update modeling_llama.py
modeling_llama.py CHANGED (+1 -1)
@@ -294,7 +294,7 @@ class LlamaAttention(nn.Module):
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
             value_states = torch.cat([past_key_value[1], value_states], dim=2)
 
-        past_key_value = (
+        past_key_value = (key_states, value_states) if use_cache else None
 
         use_flashattn = self.config.use_flashattn and is_flash_attn_available()
 
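For context, here is a minimal, self-contained sketch of the KV-cache pattern that this one-line fix restores. It is an illustration built on assumptions, not the Space's actual modeling_llama.py: the helper name update_kv_cache and the tensor shapes are invented for the example, but the torch.cat calls and the use_cache gate mirror the lines shown in the diff.

# Illustrative sketch only -- not the Space's code. The helper name
# `update_kv_cache` and the shapes below are assumptions for the example.
import torch

def update_kv_cache(key_states, value_states, past_key_value=None, use_cache=True):
    # key_states / value_states: (batch, num_heads, new_tokens, head_dim)
    if past_key_value is not None:
        # Reuse the cached keys/values and append the states for the new
        # tokens along the sequence dimension (dim=2), as in the diff's
        # context lines.
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    # The restored line: only keep a cache tuple when caching is requested.
    past_key_value = (key_states, value_states) if use_cache else None
    return key_states, value_states, past_key_value

# Example: a 1-token decoding step appended to a 5-token cache.
cache = (torch.randn(1, 8, 5, 64), torch.randn(1, 8, 5, 64))
k, v, cache = update_kv_cache(torch.randn(1, 8, 1, 64), torch.randn(1, 8, 1, 64), cache)
print(k.shape)  # torch.Size([1, 8, 6, 64])

The unchanged context line use_flashattn = self.config.use_flashattn and is_flash_attn_available() is a separate capability gate, presumably so the config flag only takes effect when the flash-attention kernels can actually be imported; it is not touched by this commit.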