Update modeling_sliding_llama.py
modeling_sliding_llama.py (CHANGED)

@@ -438,9 +438,7 @@ class LlamaFlashAttention2(LlamaAttention):
         # key_seq_len += cache_position[0]
         key_seq_len += past_key_value.get_usable_length(key_seq_len, self.layer_idx)
 
-
-        # cos, sin = self.rotary_emb(value_states, position_ids)
-        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
+        cos, sin = self.rotary_emb(value_states, position_ids)
 
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
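For context, the new call matches the `forward(x, position_ids)` signature that newer transformers releases use for the Llama rotary embedding, where cos/sin are gathered per position instead of being computed up to `seq_len` and sliced by the caller. Below is a minimal sketch of a module with that signature, assuming standard RoPE frequencies; `RotarySketch` is a hypothetical name for illustration, not the upstream class.

```python
# Minimal sketch of a rotary-embedding module with the forward(x, position_ids)
# signature the updated call assumes. "RotarySketch" is a hypothetical name;
# this is not the upstream transformers implementation.
import torch


class RotarySketch(torch.nn.Module):
    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        # One inverse frequency per channel pair.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x: torch.Tensor, position_ids: torch.LongTensor):
        # position_ids: (batch, seq_len). Broadcasting against inv_freq gives
        # per-position angles of shape (batch, seq_len, dim // 2), so a KV-cache
        # offset is handled by passing shifted position_ids rather than a larger
        # seq_len, which is what makes the seq_len= argument unnecessary.
        freqs = position_ids[:, :, None].float() * self.inv_freq[None, None, :]
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)


# Usage: when decoding one token with 5 tokens already cached, the new token
# gets position 5 explicitly instead of relying on seq_len arithmetic.
rope = RotarySketch(dim=128)
x = torch.randn(1, 1, 128)
cos, sin = rope(x, torch.tensor([[5]]))
```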