Jiqing committed on
Commit
9c07885
·
verified ·
1 Parent(s): ded0c00

Update modeling_kangaroo.py

Browse files
Files changed (1) hide show
  1. modeling_kangaroo.py +3 -3
modeling_kangaroo.py CHANGED
@@ -1020,7 +1020,7 @@ class LlamaModel(LlamaPreTrainedModel):
1020
  min_dtype = torch.finfo(dtype).min
1021
  sequence_length = input_tensor.shape[1]
1022
  if using_static_cache:
1023
- target_length = past_key_values.get_max_length()
1024
  else:
1025
  target_length = (
1026
  attention_mask.shape[-1]
@@ -1308,8 +1308,8 @@ class KangarooForCausalLM(LlamaPreTrainedModel):
1308
  if isinstance(past_key_values, Cache):
1309
  past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
1310
  max_cache_length = (
1311
- torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
1312
- if past_key_values.get_max_length() is not None
1313
  else None
1314
  )
1315
  cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
 
1020
  min_dtype = torch.finfo(dtype).min
1021
  sequence_length = input_tensor.shape[1]
1022
  if using_static_cache:
1023
+ target_length = past_key_values.get_seq_length()
1024
  else:
1025
  target_length = (
1026
  attention_mask.shape[-1]
 
1308
  if isinstance(past_key_values, Cache):
1309
  past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
1310
  max_cache_length = (
1311
+ torch.tensor(past_key_values.get_seq_length(), device=input_ids.device)
1312
+ if past_key_values.get_seq_length() is not None
1313
  else None
1314
  )
1315
  cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)