zifei9 committed on
Commit f2dcb86 · verified · 1 Parent(s): ef4e33e

Update modeling_llama.py


updating LlamaRotaryEmbedding based on transformers==4.49

Files changed (1)
  1. modeling_llama.py +1 -21
modeling_llama.py CHANGED
@@ -189,27 +189,7 @@ class LlamaAttention(nn.Module):
         self._init_rope()

     def _init_rope(self):
-        if self.config.rope_scaling is None:
-            self.rotary_emb = LlamaRotaryEmbedding(self.config)
-        else:
-            scaling_type = self.config.rope_scaling["type"]
-            scaling_factor = self.config.rope_scaling["factor"]
-            if scaling_type == "linear":
-                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
-                    self.head_dim,
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                    base=self.rope_theta,
-                )
-            elif scaling_type == "dynamic":
-                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
-                    self.head_dim,
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                    base=self.rope_theta,
-                )
-            else:
-                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+        self.rotary_emb = LlamaRotaryEmbedding(self.config)

     def forward(
         self,
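
Why the per-type branching can be dropped: in transformers==4.49, LlamaRotaryEmbedding takes the whole model config and resolves the RoPE variant (default, linear, dynamic, ...) internally from config.rope_scaling, so callers no longer instantiate a scaling-specific class. The snippet below is a minimal sketch of that config-driven usage, not part of the commit; the rope_scaling values and tensor shapes are illustrative placeholders.

# Minimal sketch, assuming the transformers==4.49 config-driven API;
# rope_scaling values here are illustrative, not taken from this repository.
import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

# Linear scaling is requested via the config instead of a dedicated
# LlamaLinearScalingRotaryEmbedding class.
config = LlamaConfig(rope_scaling={"rope_type": "linear", "factor": 2.0})
rotary_emb = LlamaRotaryEmbedding(config)

hidden_states = torch.randn(1, 16, config.hidden_size)  # dummy activations (dtype/device reference)
position_ids = torch.arange(16).unsqueeze(0)             # shape (batch, seq_len)
cos, sin = rotary_emb(hidden_states, position_ids)       # RoPE cos/sin tables
print(cos.shape, sin.shape)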