Update modeling_Llamoe.py
modeling_Llamoe.py  CHANGED  (+28 -71)
@@ -162,60 +162,34 @@ ALL_LAYERNORM_LAYERS.append(LlamoeRMSNorm)
 
 
 class LlamoeRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
-        self.scaling_factor = scaling_factor
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        [... 23 deleted lines (old 171-193) not rendered in the extracted diff: the `inv_freq` buffer setup, `self.max_seq_len_cached = ...`, the eagerly built cos/sin caches, and the deprecated `sin_cached`/`cos_cached` properties ...]
-            "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
-            "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
-        )
-        return self._cos_cached
-
-    @torch.no_grad()
-    def forward(self, x, position_ids):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-        position_ids_expanded = position_ids[:, None, :].float()
-        # Force float32 since bfloat16 loses precision on long contexts
-        # See https://github.com/huggingface/transformers/pull/29285
-        device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-            emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-[... 5 deleted lines (old 214-218) not rendered in the extracted diff ...]
+        self._set_cos_sin_cache(seq_len=max_position_embeddings, device=device, dtype=torch.get_default_dtype())
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        freq_exponents = (2.0 / self.dim) * (torch.arange(self.dim // 2, dtype=torch.float32, device="cpu").float())
+        timescale = self.base ** freq_exponents
+        positions = torch.arange(self.max_seq_len_cached, device="cpu", dtype=torch.float32).float()
+        radians_new = positions[..., None] / timescale[None, None, :]
+        radians_new = radians_new.squeeze(0)
+        emb = torch.cat((radians_new, radians_new), dim=-1)
+        cos = emb.cos().to(device=device, dtype=dtype, non_blocking=True)
+        sin = emb.sin().to(device=device, dtype=dtype, non_blocking=True)
+        self.register_buffer("cos_cached", cos, persistent=False)
+        self.register_buffer("sin_cached", sin, persistent=False)
+
+    def forward(self, x, position_ids=None, seq_len=None):
+        if seq_len is None:
+            seq_len = x.size(2)
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
@@ -224,32 +198,15 @@ def rotate_half(x):
 
 
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors.
-
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-    """
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
+    seq_len, dim = q.shape[-2], q.shape[-1]
+    cos = cos[:seq_len].view(1, 1, seq_len, dim)
+    sin = sin[:seq_len].view(1, 1, seq_len, dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
 
 
+
 class LlamoeBlockSparseTop2MLP(nn.Module):
     def __init__(self, config: LlamoeConfig):
         super().__init__()
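Taken together, the two hunks replace the `position_ids`-driven RoPE forward with a precomputed `cos_cached`/`sin_cached` buffer that is sliced per call, and `apply_rotary_pos_emb` now reshapes those `[seq_len, head_dim]` tensors itself rather than unsqueezing indexed cos/sin. Below is a minimal shape-check sketch of the new path, assuming the module-level definitions from this file (`LlamoeRotaryEmbedding`, `rotate_half`, `apply_rotary_pos_emb`) are in scope; the batch, head, and sequence sizes are illustrative and not taken from the Llamoe config.

import torch

bs, heads, seq_len, head_dim = 2, 4, 16, 64

rope = LlamoeRotaryEmbedding(dim=head_dim, max_position_embeddings=2048)
q = torch.randn(bs, heads, seq_len, head_dim)
k = torch.randn(bs, heads, seq_len, head_dim)

# forward() slices the precomputed cache; seq_len defaults to x.size(2).
cos, sin = rope(q)                                   # each: [seq_len, head_dim]
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)  # broadcast over [bs, heads, seq_len, head_dim]
assert cos.shape == (seq_len, head_dim)
assert q_rot.shape == q.shape and k_rot.shape == k.shape

# Requesting more positions than were cached triggers a one-off cache rebuild.
long_q = torch.randn(bs, heads, 4096, head_dim)
cos_long, _ = rope(long_q)
assert rope.max_seq_len_cached == 4096 and cos_long.shape == (4096, head_dim)

# Sanity check: the cached angles follow the usual RoPE schedule, positions * base**(-2i/dim).
inv_freq = 1.0 / (rope.base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
ref = torch.outer(torch.arange(4096, dtype=torch.float32), inv_freq)
assert torch.allclose(rope.cos_cached, torch.cat((ref, ref), dim=-1).cos(), atol=1e-3)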