gugarosa committed
Commit 39924aa
1 Parent(s): 534cce7

fix(modeling_phi3): Fixes inv_freq not being re-computed for extended RoPE.

Files changed (1):
  modeling_phi3.py  +14 -14
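What this fixes: the previous forward() only built inv_freq once (behind an "if self.inv_freq is None" guard) and chose between short_factor and long_factor from the shape of position_ids, so the extended-RoPE rescale factors were never refreshed once the context grew past original_max_position_embeddings. The patched forward() derives seq_len from the largest position id and recomputes inv_freq on every call. Below is a minimal, self-contained sketch of that recomputation; dim, base, original_max_position_embeddings, and the factor arrays are toy values rather than the shipped config, and the helper name compute_inv_freq is illustrative, not part of modeling_phi3.py.

import torch

dim = 8                                     # head dimension (toy value)
base = 10000.0                              # rope_theta (toy value)
original_max_position_embeddings = 4096     # toy value
short_factor = torch.ones(dim // 2)         # illustrative rescale factors
long_factor = torch.full((dim // 2,), 4.0)  # illustrative rescale factors

def compute_inv_freq(position_ids: torch.Tensor) -> torch.Tensor:
    # Recomputed on every call, mirroring the patched forward(): the factor
    # set is picked from the highest position id, not from the tensor shape.
    seq_len = torch.max(position_ids) + 1
    ext_factors = long_factor if seq_len > original_max_position_embeddings else short_factor
    exponents = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
    return 1.0 / (ext_factors * base ** exponents)

short_ids = torch.arange(128)[None, :]      # stays on short_factor
long_ids = torch.arange(8192)[None, :]      # crosses 4096, so long_factor is used
print(compute_inv_freq(short_ids))
print(compute_inv_freq(long_ids))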
modeling_phi3.py CHANGED
@@ -163,18 +163,18 @@ class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
 
     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
-        position_ids_expanded = position_ids[:, None, :].float()
-        if position_ids_expanded.shape[-1] > self.original_max_position_embeddings:
+        seq_len = torch.max(position_ids) + 1
+        if seq_len > self.original_max_position_embeddings:
             ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
         else:
             ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
 
-        if self.inv_freq is None:
-            self.inv_freq = 1.0 / (
-                ext_factors
-                * self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
-            )
+        self.inv_freq = 1.0 / (
+            ext_factors
+            * self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
+        )
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
 
         # Force float32 since bfloat16 loses precision on long contexts
         # See https://github.com/huggingface/transformers/pull/29285
@@ -215,18 +215,18 @@ class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
 
     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
-        position_ids_expanded = position_ids[:, None, :].float()
-        if position_ids_expanded.shape[-1] > self.original_max_position_embeddings:
+        seq_len = torch.max(position_ids) + 1
+        if seq_len > self.original_max_position_embeddings:
             ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
         else:
             ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
 
-        if self.inv_freq is None:
-            self.inv_freq = 1.0 / (
-                ext_factors
-                * self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
-            )
+        self.inv_freq = 1.0 / (
+            ext_factors
+            * self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
+        )
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
 
         # Force float32 since bfloat16 loses precision on long contexts
         # See https://github.com/huggingface/transformers/pull/29285
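For contrast, a toy sketch of the caching pattern this commit removes (the class below is a stand-in with illustrative values, not the shipped Phi3 rotary module): once the inv_freq-is-None guard has populated the cache, later calls with a longer context keep returning the stale short-factor frequencies.

import torch

class CachedInvFreq:
    # Toy stand-in for the pre-fix pattern; not the real Phi3 rotary class.
    def __init__(self, dim=8, base=10000.0, max_pos=4096):
        self.dim, self.base, self.max_pos = dim, base, max_pos
        self.short_factor = torch.ones(dim // 2)
        self.long_factor = torch.full((dim // 2,), 4.0)
        self.inv_freq = None

    def forward(self, position_ids):
        factors = (
            self.long_factor
            if position_ids.shape[-1] > self.max_pos
            else self.short_factor
        )
        if self.inv_freq is None:  # old guard: computed once, never refreshed
            exponents = torch.arange(0, self.dim, 2).float() / self.dim
            self.inv_freq = 1.0 / (factors * self.base ** exponents)
        return self.inv_freq

rope = CachedInvFreq()
first = rope.forward(torch.arange(128)[None, :])    # selects short_factor
second = rope.forward(torch.arange(8192)[None, :])  # long context, but...
print(torch.equal(first, second))                   # True: inv_freq is stale

Dropping the guard trades a small recomputation per forward pass for correct switching between the short and long rescale factors, which is what the +14/-14 change above does in both rotary embedding classes.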