fix(modeling_phi3): Fixes inv_freq not being re-computed for extended RoPE.
modeling_phi3.py  CHANGED  (+14 -14)
@@ -163,18 +163,18 @@ class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
 
     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
-
-        if
+        seq_len = torch.max(position_ids) + 1
+        if seq_len > self.original_max_position_embeddings:
             ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
         else:
             ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
 
-
-
-
-
-        )
+        self.inv_freq = 1.0 / (
+            ext_factors
+            * self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
+        )
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
 
         # Force float32 since bfloat16 loses precision on long contexts
         # See https://github.com/huggingface/transformers/pull/29285

@@ -215,18 +215,18 @@ class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
 
     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
-
-        if
+        seq_len = torch.max(position_ids) + 1
+        if seq_len > self.original_max_position_embeddings:
             ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
         else:
             ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
 
-
-
-
-
-        )
+        self.inv_freq = 1.0 / (
+            ext_factors
+            * self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
+        )
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
 
         # Force float32 since bfloat16 loses precision on long contexts
         # See https://github.com/huggingface/transformers/pull/29285
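
Both hunks make the same change: self.inv_freq is rebuilt from the selected rescale factors on every forward call, so the frequencies actually switch when torch.max(position_ids) + 1 crosses self.original_max_position_embeddings. Below is a minimal, self-contained sketch of that recomputation outside the model classes; scaled_inv_freq and the toy dim/base/factor values are illustrative stand-ins, not part of modeling_phi3.py and not the real Phi-3 configuration.

import torch

def scaled_inv_freq(dim, base, seq_len, original_max_pos, short_factor, long_factor, device="cpu"):
    # Hypothetical helper (not in modeling_phi3.py): mirrors the recomputation
    # that the patched forward() now performs on every call.
    factors = long_factor if seq_len > original_max_pos else short_factor
    ext_factors = torch.tensor(factors, dtype=torch.float32, device=device)
    # inv_freq[i] = 1 / (factor[i] * base^(2i / dim)); a factor > 1 shrinks the
    # frequency, i.e. stretches that rotary wavelength for the extended context.
    exponents = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim
    return 1.0 / (ext_factors * base**exponents)

# Toy values (not the real Phi-3 settings), with len(factors) == dim // 2.
dim, base, original_max_pos = 8, 10000.0, 4096
short_factors = [1.0, 1.0, 1.0, 1.0]
long_factors = [1.0, 2.0, 4.0, 8.0]

inside = scaled_inv_freq(dim, base, seq_len=1024, original_max_pos=original_max_pos,
                         short_factor=short_factors, long_factor=long_factors)
extended = scaled_inv_freq(dim, base, seq_len=8192, original_max_pos=original_max_pos,
                           short_factor=short_factors, long_factor=long_factors)
print(inside)    # frequencies built from short_factor
print(extended)  # differ, so a cached inv_freq is stale once the boundary is crossed

Since long_factor and short_factor generally differ, an inv_freq computed in one regime is wrong in the other; recomputing it inside forward, as both Phi3SuScaledRotaryEmbedding and Phi3YarnScaledRotaryEmbedding now do, is what the commit title refers to.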