Update modeling_llama.py
updating LlamaRotaryEmbedding based on transformers==4.49
- modeling_llama.py +1 -21
modeling_llama.py CHANGED

```diff
@@ -189,27 +189,7 @@ class LlamaAttention(nn.Module):
         self._init_rope()
 
     def _init_rope(self):
-        if self.config.rope_scaling is None:
-            self.rotary_emb = LlamaRotaryEmbedding(self.config)
-        else:
-            scaling_type = self.config.rope_scaling["type"]
-            scaling_factor = self.config.rope_scaling["factor"]
-            if scaling_type == "linear":
-                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
-                    self.head_dim,
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                    base=self.rope_theta,
-                )
-            elif scaling_type == "dynamic":
-                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
-                    self.head_dim,
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                    base=self.rope_theta,
-                )
-            else:
-                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+        self.rotary_emb = LlamaRotaryEmbedding(self.config)
 
     def forward(
         self,
```
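For context, a minimal sketch of where the removed branches went (assuming the transformers==4.49-style RoPE handling this commit targets, where `LlamaRotaryEmbedding` is config-driven): the per-variant classes `LlamaLinearScalingRotaryEmbedding` / `LlamaDynamicNTKScalingRotaryEmbedding` are replaced by the `rope_scaling` dict on the config, which the single `LlamaRotaryEmbedding` dispatches on internally. This is an illustration of the upstream API, not part of the commit itself.

```python
# Sketch (not part of this commit): with config-driven RoPE as in
# transformers==4.49, the scaling variant is declared in config.rope_scaling
# and a single LlamaRotaryEmbedding handles all variants.
import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

# Equivalent of the removed "linear" branch; "dynamic" works the same way
# via {"rope_type": "dynamic", "factor": ...}.
config = LlamaConfig(rope_scaling={"rope_type": "linear", "factor": 2.0})
rotary_emb = LlamaRotaryEmbedding(config=config)

# The module returns cos/sin tables for the given positions; x is only
# used as a device/dtype reference.
x = torch.randn(1, 8, config.hidden_size)
position_ids = torch.arange(8).unsqueeze(0)
cos, sin = rotary_emb(x, position_ids)
```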
|