kz919 committed on
Commit
df270b8
1 Parent(s): da84cc6

Update modeling_sliding_llama.py

Browse files
Files changed (1) hide show
  1. modeling_sliding_llama.py +1 -3
modeling_sliding_llama.py CHANGED
@@ -438,9 +438,7 @@ class LlamaFlashAttention2(LlamaAttention):
438
  # key_seq_len += cache_position[0]
439
  key_seq_len += past_key_value.get_usable_length(key_seq_len, self.layer_idx)
440
 
441
- rotary_seq_len = max(key_seq_len, position_ids[:, -1].max().item()) + 1
442
- # cos, sin = self.rotary_emb(value_states, position_ids)
443
- cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
444
 
445
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
446
 
 
438
  # key_seq_len += cache_position[0]
439
  key_seq_len += past_key_value.get_usable_length(key_seq_len, self.layer_idx)
440
 
441
+ cos, sin = self.rotary_emb(value_states, position_ids)
 
 
442
 
443
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
444