Update modeling_phi2_clex.py
modeling_phi2_clex.py  (+4 -5)
@@ -59,7 +59,10 @@ logger = logging.get_logger(__name__)
 _CHECKPOINT_FOR_DOC = "microsoft/phi-2"
 _CONFIG_FOR_DOC = "CLEXPhiConfig"
 
-
+PHI_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "microsoft/phi-2",
+    # See all Phi models at https://huggingface.co/models?filter=phi
+]
 
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
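
The added PHI_PRETRAINED_MODEL_ARCHIVE_LIST follows the usual transformers convention of listing reference checkpoints at module level. As a minimal, illustrative sketch (not part of this commit), the listed checkpoint loads like any other causal LM on the Hub:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # "microsoft/phi-2" is the checkpoint named in the archive list above.
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
    model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")
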
@@ -373,10 +376,6 @@ class PhiAttention(nn.Module):
         # [batch_size, seq_length, num_heads, head_dim]
         query_states = torch.cat((query_rot, query_pass), dim=-1)
         key_states = torch.cat((key_rot, key_pass), dim=-1)
-        rotary_dim = int(self.partial_rotary_factor * self.head_dim)
-        if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": rotary_dim}
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
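
For reference, the four deleted lines implemented the standard transformers KV-cache write (Cache.update with a partial_rotation_size hint), so after this change the cache is no longer written at this point in PhiAttention.forward; presumably the CLEX code path handles caching elsewhere. A minimal sketch of that upstream pattern, using transformers' DynamicCache with arbitrary stand-in shapes (illustrative only, not code from this repo):

    import torch
    from transformers.cache_utils import DynamicCache

    # Stand-in tensors shaped [batch, num_heads, seq_len, head_dim];
    # the sizes here are arbitrary.
    key_states = torch.randn(1, 2, 4, 8)
    value_states = torch.randn(1, 2, 4, 8)
    sin = torch.randn(4, 8)
    cos = torch.randn(4, 8)

    past_key_value = DynamicCache()
    # Mirrors the removed lines: update() appends the new keys/values for
    # this layer and returns the full cached sequence so far.
    cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": 8}
    key_states, value_states = past_key_value.update(
        key_states, value_states, layer_idx=0, cache_kwargs=cache_kwargs
    )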