Update modeling_llama.py
for backward compatibility
modeling_llama.py +2 -1
@@ -966,6 +966,7 @@ class LlamaModel(LlamaPreTrainedModel):
             past_key_values, StaticCache
         ):
             if not isinstance(past_key_values, DynamicCache):
+                used_legacy_cache=True
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             past_seen_tokens = past_key_values.get_seq_length()
 
@@ -1037,7 +1038,7 @@ class LlamaModel(LlamaPreTrainedModel):
 
         next_cache = None
         if use_cache:
-            next_cache = next_decoder_cache
+            next_cache = next_decoder_cache.to_legacy_cache() if used_legacy_cache else next_decoder_cache
         if not return_dict:
             return tuple(
                 v
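For context, a minimal sketch of the round trip this change preserves: a caller that still passes the legacy tuple-of-tuples KV cache gets the same format back, because the model converts it to a DynamicCache on the way in and, since used_legacy_cache was set, back to the legacy format on the way out. The batch size, head count, sequence length, and layer count below are illustrative assumptions, not values from the commit.

import torch
from transformers.cache_utils import DynamicCache

# Legacy format: one (key, value) pair per decoder layer, each tensor
# shaped (batch, num_key_value_heads, seq_len, head_dim). Shapes here are
# illustrative only.
legacy_cache = tuple(
    (torch.zeros(1, 8, 4, 64), torch.zeros(1, 8, 4, 64)) for _ in range(2)
)

# What the patched forward pass does on the way in ...
cache = DynamicCache.from_legacy_cache(legacy_cache)
print(cache.get_seq_length())  # 4

# ... and on the way out when used_legacy_cache was set: convert back so
# the caller receives the same tuple-of-tuples format it passed in.
next_cache = cache.to_legacy_cache()
assert isinstance(next_cache, tuple) and isinstance(next_cache[0], tuple)

A caller that already passes a DynamicCache skips both conversions and receives the DynamicCache back unchanged.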