Crystalcareai committed on
Commit
292a484
·
verified ·
1 Parent(s): c28e2ee

Update modeling_gemmoe.py

Browse files
Files changed (1) hide show
  1. modeling_gemmoe.py +583 -605
modeling_gemmoe.py CHANGED
@@ -28,6 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28
  from transformers.activations import ACT2FN
29
  from transformers.cache_utils import Cache, DynamicCache, StaticCache
30
  from transformers.modeling_attn_mask_utils import (
 
31
  _prepare_4d_causal_attention_mask,
32
  )
33
  from transformers.modeling_outputs import SequenceClassifierOutputWithPast, MoeModelOutputWithPast, MoeCausalLMOutputWithPast
@@ -176,52 +177,71 @@ ALL_LAYERNORM_LAYERS.append(GemmoeRMSNorm)
176
  class GemmoeRotaryEmbedding(nn.Module):
177
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
178
  super().__init__()
 
179
  self.dim = dim
180
  self.max_position_embeddings = max_position_embeddings
181
  self.base = base
182
- self._set_cos_sin_cache(seq_len=max_position_embeddings, device=device, dtype=torch.get_default_dtype())
183
-
184
- def _set_cos_sin_cache(self, seq_len, device, dtype):
185
- self.max_seq_len_cached = seq_len
186
- freq_exponents = (2.0 / self.dim) * (
187
- torch.arange(self.dim // 2, dtype=torch.int64, device="cpu").float()
188
- )
189
- timescale = self.base ** freq_exponents
190
- positions = torch.arange(self.max_seq_len_cached, device="cpu", dtype=torch.int64).float()
191
- radians_new = positions[..., None] / timescale[None, None, :]
192
- radians_new = radians_new.squeeze(0)
193
- emb = torch.cat((radians_new, radians_new), dim=-1)
194
- cos = emb.cos().to(device=device, non_blocking=True)
195
- sin = emb.sin().to(device=device, non_blocking=True)
196
- self.register_buffer("cos_cached", cos, persistent=False)
197
- self.register_buffer("sin_cached", sin, persistent=False)
198
-
199
- def forward(self, x, position_ids=None, seq_len=None):
200
- if seq_len is None:
201
- seq_len = x.size(2)
202
- if seq_len > self.max_seq_len_cached:
203
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
204
- return (
205
- self.cos_cached[:seq_len],
206
- self.sin_cached[:seq_len],
207
- )
208
-
209
  def rotate_half(x):
210
  """Rotates half the hidden dims of the input."""
211
  x1 = x[..., : x.shape[-1] // 2]
212
  x2 = x[..., x.shape[-1] // 2 :]
213
  return torch.cat((-x2, x1), dim=-1)
214
 
215
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
    """Apply rotary position embedding to the query and key tensors.

    Args:
        q: Query tensor; its last two dimensions are read as ``(seq_len, dim)``.
        k: Key tensor, multiplied by the same cos/sin tables as ``q``.
        cos: Cosine table with one row per position and ``dim`` columns.
        sin: Sine table with the same layout as ``cos``.
        position_ids: Optional integer tensor of positions. When ``None`` the
            first ``seq_len`` rows of the tables are used (positions
            ``0 .. seq_len - 1``). When given, the rows ``cos[position_ids]`` /
            ``sin[position_ids]`` are used instead, with a broadcast axis
            inserted at dimension 1 (presumably the heads axis - confirm
            against callers).

    Returns:
        Tuple ``(q_embed, k_embed)`` of the rotated query and key tensors.
    """
    seq_len, dim = q.shape[-2], q.shape[-1]
    if position_ids is None:
        # Positions are implicit: row i of the table rotates token i.
        cos = cos[:seq_len].view(1, 1, seq_len, dim)
        sin = sin[:seq_len].view(1, 1, seq_len, dim)
    else:
        # Gather the table rows for the requested positions instead of
        # silently ignoring the argument.
        cos = cos[position_ids].unsqueeze(1)
        sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
223
 
224
- def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
 
 
225
  """
226
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
227
  num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
@@ -233,14 +253,9 @@ def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
233
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
234
 
235
  class GemmoeAttention(nn.Module):
236
- """
237
- Multi-headed attention module for Gemmoe model.
238
-
239
- Args:
240
- config (GemmoeConfig): The configuration object for the Gemmoe model.
241
- layer_idx (Optional[int]): The index of the layer. Default is None.
242
- """
243
 
 
244
  def __init__(self, config: GemmoeConfig, layer_idx: Optional[int] = None):
245
  super().__init__()
246
  self.config = config
@@ -251,6 +266,7 @@ class GemmoeAttention(nn.Module):
251
  "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
252
  "when creating this class."
253
  )
 
254
  self.attention_dropout = config.attention_dropout
255
  self.hidden_size = config.hidden_size
256
  self.num_heads = config.num_attention_heads
@@ -266,15 +282,16 @@ class GemmoeAttention(nn.Module):
266
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
267
  f" and `num_heads`: {self.num_heads})."
268
  )
 
269
  self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
270
  self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
271
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
272
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
273
  self.rotary_emb = GemmoeRotaryEmbedding(
274
- self.head_dim,
275
- max_position_embeddings=self.max_position_embeddings,
276
- base=self.rope_theta,
277
- )
278
 
279
  def forward(
280
  self,
@@ -287,25 +304,6 @@ class GemmoeAttention(nn.Module):
287
  cache_position: Optional[torch.LongTensor] = None,
288
  **kwargs,
289
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
290
- """
291
- Forward pass of the attention module.
292
-
293
- Args:
294
- hidden_states (torch.Tensor): The input hidden states.
295
- attention_mask (Optional[torch.Tensor]): The attention mask. Default is None.
296
- position_ids (Optional[torch.LongTensor]): The position IDs. Default is None.
297
- past_key_value (Optional[Cache]): The past key-value cache. Default is None.
298
- output_attentions (bool): Whether to output the attention weights. Default is False.
299
- use_cache (bool): Whether to use caching. Default is False.
300
- cache_position (Optional[torch.LongTensor]): The cache position. Default is None.
301
- **kwargs: Additional keyword arguments.
302
-
303
- Returns:
304
- Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
305
- - The output hidden states.
306
- - The attention weights (if `output_attentions=True`).
307
- - The past key-value cache (if `use_cache=True`).
308
- """
309
  bsz, q_len, _ = hidden_states.size()
310
 
311
  query_states = self.q_proj(hidden_states)
@@ -317,17 +315,16 @@ class GemmoeAttention(nn.Module):
317
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
318
 
319
  past_key_value = getattr(self, "past_key_value", past_key_value)
320
-
321
  cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
322
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
323
 
324
  if past_key_value is not None:
325
- # sin and cos are specific to RoPE models; position_ids needed for the static cache
326
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
327
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
328
 
329
- key_states = self.repeat_kv(key_states, self.num_key_value_groups)
330
- value_states = self.repeat_kv(value_states, self.num_key_value_groups)
331
 
332
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
333
 
@@ -341,7 +338,6 @@ class GemmoeAttention(nn.Module):
341
  # upcast attention to fp32
342
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
343
  attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
344
-
345
  attn_output = torch.matmul(attn_weights, value_states)
346
 
347
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -351,8 +347,8 @@ class GemmoeAttention(nn.Module):
351
  )
352
 
353
  attn_output = attn_output.transpose(1, 2).contiguous()
354
- attn_output = attn_output.view(bsz, q_len, -1)
355
 
 
356
  attn_output = self.o_proj(attn_output)
357
 
358
  if not output_attentions:
@@ -360,17 +356,24 @@ class GemmoeAttention(nn.Module):
360
 
361
  return attn_output, attn_weights, past_key_value
362
 
 
 
363
  class GemmoeFlashAttention2(GemmoeAttention):
364
  """
365
  Gemmoe flash attention module. This module inherits from `GemmoeAttention` as the weights of the module stays
366
  untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
367
  flash attention and deal with padding tokens in case the input contains any of them.
368
  """
 
369
  def __init__(self, *args, **kwargs):
370
  super().__init__(*args, **kwargs)
371
- # TODO: Remove this attribute once Flash Attention for RoCm is bumped to 2.1.
 
 
 
372
  self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
373
 
 
374
  def forward(
375
  self,
376
  hidden_states: torch.Tensor,
@@ -401,8 +404,9 @@ class GemmoeFlashAttention2(GemmoeAttention):
401
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
402
 
403
  past_key_value = getattr(self, "past_key_value", past_key_value)
 
404
  if past_key_value is not None:
405
- # sin and cos are specific to RoPE models; position_ids needed for the static cache
406
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
407
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
408
 
@@ -419,6 +423,7 @@ class GemmoeFlashAttention2(GemmoeAttention):
419
  # cast them back in the correct dtype just to be sure everything works as expected.
420
  # This might slowdown training & inference so it is recommended to not cast the LayerNorms
421
  # in fp32. (GemmoeRMSNorm handles it correctly)
 
422
  input_dtype = query_states.dtype
423
  if input_dtype == torch.float32:
424
  if torch.is_autocast_enabled():
@@ -434,6 +439,7 @@ class GemmoeFlashAttention2(GemmoeAttention):
434
  f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
435
  f" {target_dtype}."
436
  )
 
437
  query_states = query_states.to(target_dtype)
438
  key_states = key_states.to(target_dtype)
439
  value_states = value_states.to(target_dtype)
@@ -467,7 +473,7 @@ class GemmoeFlashAttention2(GemmoeAttention):
467
  attention_mask (`torch.Tensor`):
468
  The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
469
  position of padding tokens and 1 for the position of non-padding tokens.
470
- dropout (`int`, *optional*):
471
  Attention dropout
472
  softmax_scale (`float`, *optional*):
473
  The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
@@ -484,6 +490,7 @@ class GemmoeFlashAttention2(GemmoeAttention):
484
  query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
485
  query_states, key_states, value_states, attention_mask, query_length
486
  )
 
487
  cu_seqlens_q, cu_seqlens_k = cu_seq_lens
488
  max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
489
 
@@ -499,6 +506,7 @@ class GemmoeFlashAttention2(GemmoeAttention):
499
  softmax_scale=softmax_scale,
500
  causal=causal,
501
  )
 
502
  attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
503
  else:
504
  attn_output = flash_attn_func(
@@ -509,15 +517,14 @@ class GemmoeFlashAttention2(GemmoeAttention):
509
 
510
  def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
511
  indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
512
-
513
  batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
 
514
  key_layer = index_first_axis(
515
  key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
516
  )
517
  value_layer = index_first_axis(
518
  value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
519
  )
520
-
521
  if query_length == kv_seq_len:
522
  query_layer = index_first_axis(
523
  query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
@@ -546,24 +553,16 @@ class GemmoeFlashAttention2(GemmoeAttention):
546
  (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
547
  )
548
 
 
 
549
  class GemmoeSdpaAttention(GemmoeAttention):
550
  """
551
  Gemmoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
552
- GemmoeAttention as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
553
  SDPA API.
554
  """
555
 
556
- def repeat_kv(self, x, n_rep):
557
- """
558
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
559
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
560
- """
561
- batch, num_key_value_heads, slen, head_dim = x.shape
562
- if n_rep == 1:
563
- return x
564
- x = x[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
565
- return x.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
566
-
567
  def forward(
568
  self,
569
  hidden_states: torch.Tensor,
@@ -576,11 +575,10 @@ class GemmoeSdpaAttention(GemmoeAttention):
576
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
577
  if output_attentions:
578
  # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
579
- # logger.warning_once(
580
- "GemmoeModel is using GemmoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
581
- 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
582
- # )
583
-
584
  return super().forward(
585
  hidden_states=hidden_states,
586
  attention_mask=attention_mask,
@@ -590,7 +588,7 @@ class GemmoeSdpaAttention(GemmoeAttention):
590
  use_cache=use_cache,
591
  cache_position=cache_position,
592
  )
593
-
594
  bsz, q_len, _ = hidden_states.size()
595
 
596
  query_states = self.q_proj(hidden_states)
@@ -605,23 +603,19 @@ class GemmoeSdpaAttention(GemmoeAttention):
605
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
606
 
607
  past_key_value = getattr(self, "past_key_value", past_key_value)
 
608
  if past_key_value is not None:
609
- # sin and cos are specific to RoPE models; position_ids needed for the static cache
610
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
611
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
612
 
613
- key_states = self.repeat_kv(key_states, self.num_key_value_groups)
614
- value_states = self.repeat_kv(value_states, self.num_key_value_groups)
615
 
616
  causal_mask = attention_mask
617
  if attention_mask is not None and cache_position is not None:
618
  causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
619
 
620
- # Ensure query, key, and value states have the same dtype
621
- common_dtype = query_states.dtype
622
- key_states = key_states.to(dtype=common_dtype)
623
- value_states = value_states.to(dtype=common_dtype)
624
-
625
  # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
626
  # Reference: https://github.com/pytorch/pytorch/issues/112577.
627
  if query_states.device.type == "cuda" and causal_mask is not None:
@@ -629,10 +623,6 @@ class GemmoeSdpaAttention(GemmoeAttention):
629
  key_states = key_states.contiguous()
630
  value_states = value_states.contiguous()
631
 
632
- # Cast causal_mask to the same dtype as query_states
633
- if causal_mask is not None:
634
- causal_mask = causal_mask.to(dtype=query_states.dtype)
635
-
636
  attn_output = torch.nn.functional.scaled_dot_product_attention(
637
  query_states,
638
  key_states,
@@ -643,167 +633,113 @@ class GemmoeSdpaAttention(GemmoeAttention):
643
 
644
  attn_output = attn_output.transpose(1, 2).contiguous()
645
  attn_output = attn_output.view(bsz, q_len, -1)
 
646
  attn_output = self.o_proj(attn_output)
647
 
648
  return attn_output, None, past_key_value
649
 
650
- GEMMOE_ATTENTION_CLASSES = {
651
- "eager": GemmoeAttention,
652
- "flash_attention_2": GemmoeFlashAttention2,
653
- "sdpa": GemmoeSdpaAttention,
654
- }
655
-
656
class GemmoeMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``.

    Args:
        config: Model configuration; ``hidden_size``, ``intermediate_size`` and
            ``hidden_act`` are read from it. ``pretraining_tp`` is read if present.
        hidden_size: Optional override for ``config.hidden_size``.
        intermediate_size: Optional override for ``config.intermediate_size``.
    """

    def __init__(self, config, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
        self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Run the gated MLP on ``x`` (last dimension ``hidden_size``).

        When ``config.pretraining_tp > 1`` the three projections are evaluated
        shard by shard along the intermediate dimension and the partial results
        are recombined; otherwise a single fused expression is used.
        """
        # A config without `pretraining_tp` falls back to the unsharded path
        # instead of raising AttributeError.
        tp = getattr(self.config, "pretraining_tp", 1)
        if tp <= 1:
            return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        shard = self.intermediate_size // tp  # rows per shard (was named `slice`, shadowing the builtin)
        gate_shards = self.gate_proj.weight.split(shard, dim=0)
        up_shards = self.up_proj.weight.split(shard, dim=0)
        down_shards = self.down_proj.weight.split(shard, dim=1)

        gate_proj = torch.cat([F.linear(x, gate_shards[i]) for i in range(tp)], dim=-1)
        up_proj = torch.cat([F.linear(x, up_shards[i]) for i in range(tp)], dim=-1)

        # Split on the last axis so inputs of any rank work (same as dim=2 for 3-D input).
        intermediate_states = (self.act_fn(gate_proj) * up_proj).split(shard, dim=-1)
        partial_outputs = [F.linear(intermediate_states[i], down_shards[i]) for i in range(tp)]
        return sum(partial_outputs)
689
-
690
class MoEGate(nn.Module):
    """Top-k softmax router for the mixture-of-experts block.

    Reads from ``config``: ``num_experts_per_tok`` (k), ``n_routed_experts``,
    ``scoring_func`` (only ``'softmax'`` is implemented), ``aux_loss_alpha``,
    ``seq_aux``, ``norm_topk_prob`` and ``hidden_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts

        self.scoring_func = config.scoring_func
        self.alpha = config.aux_loss_alpha
        self.seq_aux = config.seq_aux

        # topk selection algorithm
        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialise the routing matrix with Kaiming-uniform values."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        """Route every token to its top-k experts.

        Args:
            hidden_states: ``(batch, seq_len, hidden)`` or an already flattened
                ``(tokens, hidden)`` tensor. A 2-D input is treated as a single
                sequence, so with ``seq_aux`` the auxiliary loss is computed over
                all tokens as one group.

        Returns:
            ``(topk_idx, topk_weight, aux_loss)``: expert indices and gate
            weights of shape ``(tokens, top_k)``, and a scalar auxiliary
            load-balancing loss (``None`` unless training with ``alpha > 0``).

        Raises:
            NotImplementedError: If ``scoring_func`` is not ``'softmax'``.
        """
        if hidden_states.dim() == 2:
            # Callers may flatten before routing; the original 3-way unpack
            # raised ValueError on such input.
            bsz = 1
            seq_len, h = hidden_states.shape
        else:
            bsz, seq_len, h = hidden_states.shape

        # --- gating scores -------------------------------------------------
        hidden_states = hidden_states.reshape(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')

        # --- top-k selection -----------------------------------------------
        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        # Renormalise the selected gates so they sum to 1 per token.
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        # --- auxiliary load-balancing loss -----------------------------------
        if not (self.training and self.alpha > 0.0):
            return topk_idx, topk_weight, None

        # Always computed from the naive greedy top-k assignment.
        topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
        if self.seq_aux:
            # Per-sequence expert usage counts, scaled so uniform routing gives 1.
            scores_for_seq_aux = scores.view(bsz, seq_len, -1)
            ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
            ce.scatter_add_(
                1,
                topk_idx_for_aux_loss,
                torch.ones(bsz, seq_len * self.top_k, device=hidden_states.device),
            ).div_(seq_len * self.top_k / self.n_routed_experts)
            aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
        else:
            mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
            ce = mask_ce.float().mean(0)
            Pi = scores.mean(0)
            fi = ce * self.n_routed_experts
            aux_loss = (Pi * fi).sum() * self.alpha
        return topk_idx, topk_weight, aux_loss
749
 
 
 
 
 
 
 
750
 
751
  class GemmoeSparseMoeBlock(nn.Module):
 
 
 
 
 
 
 
 
 
 
 
752
  def __init__(self, config):
753
  super().__init__()
754
  self.hidden_dim = config.hidden_size
755
  self.ffn_dim = config.intermediate_size
756
  self.num_experts = config.num_local_experts
757
- self.top_k = 2
758
 
759
- self.gate = MoEGate(config)
 
760
 
761
- self.experts = nn.ModuleList([GemmoeMLP(config) for _ in range(self.num_experts)])
762
 
763
- def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
 
764
  batch_size, sequence_length, hidden_dim = hidden_states.shape
765
  hidden_states = hidden_states.view(-1, hidden_dim)
 
 
766
 
767
- topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
768
-
 
769
  # we cast back to the input dtype
770
- topk_weight = topk_weight.to(hidden_states.dtype)
771
 
772
- hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
 
 
773
 
774
- y = torch.empty_like(hidden_states)
 
 
775
 
776
- flat_topk_idx = topk_idx.view(-1)
777
- for i in range(self.num_experts):
778
- expert = self.experts[i]
779
- expert_output = expert(hidden_states[flat_topk_idx == i])
780
- y[flat_topk_idx == i] = expert_output.to(y.dtype) # Cast expert_output to the same dtype as y
781
 
782
- y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
 
783
 
784
- final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
785
- return final_hidden_states, aux_loss
786
-
787
class AddAuxiliaryLoss(torch.autograd.Function):
    """Identity on ``x`` that injects the gradient of an auxiliary loss.

    The forward pass returns ``x`` unchanged. In the backward pass the incoming
    gradient is passed through to ``x`` and, if ``loss`` requires grad, a
    gradient of ones is sent to ``loss`` - as if ``loss`` had been added to the
    final objective.
    """

    @staticmethod
    def forward(ctx, x, loss):
        """Return ``x`` and remember what is needed to build ``loss``'s gradient.

        Raises:
            ValueError: If ``loss`` does not hold exactly one element.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if loss.numel() != 1:
            raise ValueError(f"auxiliary loss must hold exactly one element, got {loss.numel()}")
        ctx.dtype = loss.dtype
        ctx.shape = loss.shape
        ctx.required_aux_loss = loss.requires_grad
        return x

    @staticmethod
    def backward(ctx, grad_output):
        """Pass ``grad_output`` through and emit a unit gradient for the loss."""
        grad_loss = None
        if ctx.required_aux_loss:
            # Match the loss's own shape (0-dim or 1-element) rather than always (1,).
            grad_loss = torch.ones(ctx.shape, dtype=ctx.dtype, device=grad_output.device)
        return grad_output, grad_loss
 
 
 
 
 
805
 
806
 
 
807
  class GemmoeDecoderLayer(nn.Module):
808
  def __init__(self, config: GemmoeConfig, layer_idx: int):
809
  super().__init__()
@@ -824,10 +760,31 @@ class GemmoeDecoderLayer(nn.Module):
824
  output_attentions: Optional[bool] = False,
825
  output_router_logits: Optional[bool] = False,
826
  use_cache: Optional[bool] = False,
827
- cache_position: Optional[torch.LongTensor] = None,
828
  **kwargs,
829
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
830
  residual = hidden_states
 
831
  hidden_states = self.input_layernorm(hidden_states)
832
 
833
  # Self Attention
@@ -838,20 +795,15 @@ class GemmoeDecoderLayer(nn.Module):
838
  past_key_value=past_key_value,
839
  output_attentions=output_attentions,
840
  use_cache=use_cache,
841
- cache_position=cache_position,
842
- **kwargs,
843
  )
844
  hidden_states = residual + hidden_states
845
 
846
  # Fully Connected
847
  residual = hidden_states
848
  hidden_states = self.post_attention_layernorm(hidden_states)
849
- hidden_states, aux_loss = self.block_sparse_moe(hidden_states)
850
  hidden_states = residual + hidden_states
851
 
852
- if aux_loss is not None:
853
- hidden_states = AddAuxiliaryLoss.apply(hidden_states, aux_loss)
854
-
855
  outputs = (hidden_states,)
856
 
857
  if output_attentions:
@@ -860,331 +812,364 @@ class GemmoeDecoderLayer(nn.Module):
860
  if use_cache:
861
  outputs += (present_key_value,)
862
 
 
 
 
863
  return outputs
864
 
 
865
  GEMMOE_START_DOCSTRING = r"""
866
- This model inherits from [PreTrainedModel]. Check the superclass documentation for the generic methods the
867
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
868
- etc.)
 
 
 
 
 
 
 
 
 
 
869
  """
870
 
 
871
  @add_start_docstrings(
872
- "The bare Gemmoe Model outputting raw hidden-states without any specific head on top.",
873
- GEMMOE_START_DOCSTRING,
874
  )
875
 
876
  class GemmoePreTrainedModel(PreTrainedModel):
877
- config_class = GemmoeConfig
878
- base_model_prefix = "model"
879
- supports_gradient_checkpointing = True
880
- _keep_in_fp32_modules = ["inv_freq", "rotary_emb", "cos_cached", "sin_cached"]
881
- _no_split_modules = ["GemmoeDecoderLayer"]
882
- _skip_keys_device_placement = ["past_key_values", "causal_mask"]
883
- _supports_flash_attn_2 = True
884
- _supports_sdpa = True
885
- _supports_cache_class = True
886
-
887
- def _init_weights(self, module):
888
- std = self.config.initializer_range
889
- if isinstance(module, nn.Linear):
890
- module.weight.data.normal_(mean=0.0, std=std)
891
- if module.bias is not None:
892
- module.bias.data.zero_()
893
- elif isinstance(module, nn.Embedding):
894
- module.weight.data.normal_(mean=0.0, std=std)
895
- if module.padding_idx is not None:
896
- module.weight.data[module.padding_idx].zero_()
897
-
898
- def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
899
- if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
900
- raise ValueError(
901
- "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
902
- "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
903
- )
904
- if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
905
- causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device)
906
- self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
907
-
908
- for layer in self.model.layers:
909
- weights = layer.self_attn.o_proj.weight
910
- layer.self_attn.past_key_value = cache_cls(
911
- self.config, max_batch_size, max_cache_len, device=weights.device, dtype=weights.dtype
912
- )
913
-
914
- def _reset_cache(self):
915
- for layer in self.model.layers:
916
- layer.self_attn.past_key_value = None
 
 
917
 
918
  GEMMOE_INPUTS_DOCSTRING = r"""
919
- Args:
920
- input_ids (torch.LongTensor of shape (batch_size, sequence_length)):
921
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
922
- it.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
923
  """
924
 
 
925
  @add_start_docstrings(
926
- "The bare Gemmoe Model outputting raw hidden-states without any specific head on top.",
927
- GEMMOE_START_DOCSTRING,
928
  )
929
 
930
  class GemmoeModel(GemmoePreTrainedModel):
931
- """
932
- Transformer decoder consisting of config.num_hidden_layers layers. Each layer is a [GemmoeDecoderLayer]Args:
933
- config: GemmoeConfig
934
- """
935
-
936
-
937
- def __init__(self, config: GemmoeConfig):
938
- super().__init__(config)
939
- self.padding_idx = config.pad_token_id
940
- self.vocab_size = config.vocab_size
941
-
942
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
943
- self.layers = nn.ModuleList(
944
- [GemmoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
945
- )
946
-
947
- self.norm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
948
-
949
- self.gradient_checkpointing = False
950
-
951
- # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
952
- # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
953
- causal_mask = torch.full(
954
- (config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
955
- )
956
- self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
957
-
958
- # Initialize weights and apply final processing
959
- self.post_init()
960
-
961
- def get_input_embeddings(self):
962
- return self.embed_tokens
963
-
964
- def set_input_embeddings(self, value):
965
- self.embed_tokens = value
966
-
967
- @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
968
- @replace_return_docstrings(output_type=MoeModelOutputWithPast, config_class=_CONFIG_FOR_DOC)
969
- def forward(
970
- self,
971
- input_ids: torch.LongTensor = None,
972
- attention_mask: Optional[torch.Tensor] = None,
973
- position_ids: Optional[torch.LongTensor] = None,
974
- past_key_values: Optional[List[torch.FloatTensor]] = None,
975
- inputs_embeds: Optional[torch.FloatTensor] = None,
976
- use_cache: Optional[bool] = None,
977
- output_attentions: Optional[bool] = None,
978
- output_hidden_states: Optional[bool] = None,
979
- output_router_logits: Optional[bool] = None,
980
- return_dict: Optional[bool] = None,
981
- cache_position: Optional[torch.LongTensor] = None,
982
- ) -> Union[Tuple, MoeModelOutputWithPast]:
983
- """
984
- Forward pass of the sequence classification model.
985
-
986
- Args:
987
- input_ids: Input token IDs.
988
- attention_mask: Attention mask.
989
- position_ids: Position IDs.
990
- past_key_values: Past key-value pairs.
991
- inputs_embeds: Input embeddings.
992
- labels: Labels for sequence classification.
993
- use_cache: Whether to use cache.
994
- output_attentions: Whether to output attentions.
995
- output_hidden_states: Whether to output hidden states.
996
- return_dict: Whether to return a dictionary or tuple.
997
-
998
- Returns:
999
- Output of the sequence classification model.
1000
- """
1001
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1002
- output_hidden_states = (
1003
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1004
- )
1005
- output_router_logits = (
1006
- output_router_logits if output_router_logits is not None else self.config.output_router_logits
1007
- )
1008
- use_cache = use_cache if use_cache is not None else self.config.use_cache
1009
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1010
-
1011
- if (input_ids is None) ^ (inputs_embeds is not None):
1012
- raise ValueError(
1013
- "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
1014
- )
1015
-
1016
- if self.gradient_checkpointing and self.training and use_cache:
1017
- logger.warning_once(
1018
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
1019
- )
1020
- use_cache = False
1021
-
1022
- if inputs_embeds is None:
1023
- inputs_embeds = self.embed_tokens(input_ids)
1024
-
1025
- past_seen_tokens = 0
1026
- if use_cache: # kept for BC (cache positions)
1027
- if not isinstance(past_key_values, StaticCache):
1028
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1029
- past_seen_tokens = past_key_values.get_seq_length()
1030
-
1031
- if cache_position is None:
1032
- cache_position = torch.arange(
1033
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1034
- )
1035
-
1036
- if position_ids is None:
1037
- position_ids = cache_position.unsqueeze(0)
1038
-
1039
- causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
1040
-
1041
- hidden_states = inputs_embeds
1042
-
1043
- # Normalize
1044
- scale_factor = torch.tensor(math_sqrt(self.config.hidden_size), dtype=hidden_states.dtype)
1045
- hidden_states = hidden_states * scale_factor
1046
- # Decoder layers
1047
- all_hidden_states = () if output_hidden_states else None
1048
- all_self_attns = () if output_attentions else None
1049
- all_router_logits = () if output_router_logits else None
1050
- next_decoder_cache = None
1051
-
1052
- for decoder_layer in self.layers:
1053
- if output_hidden_states:
1054
- all_hidden_states += (hidden_states,)
1055
-
1056
- if self.gradient_checkpointing and self.training:
1057
- layer_outputs = self._gradient_checkpointing_func(
1058
- decoder_layer.__call__,
1059
- hidden_states,
1060
- causal_mask,
1061
- position_ids,
1062
- past_key_values,
1063
- output_attentions,
1064
- output_router_logits,
1065
- use_cache,
1066
- cache_position,
1067
- )
1068
- else:
1069
- layer_outputs = decoder_layer(
1070
- hidden_states,
1071
- attention_mask=causal_mask,
1072
- position_ids=position_ids,
1073
- past_key_value=past_key_values,
1074
- output_attentions=output_attentions,
1075
- output_router_logits=output_router_logits,
1076
- use_cache=use_cache,
1077
- cache_position=cache_position,
1078
- )
1079
-
1080
- hidden_states = layer_outputs[0]
1081
- if use_cache:
1082
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1083
- if output_attentions:
1084
- all_self_attns += (layer_outputs[1],)
1085
- if output_router_logits:
1086
- all_router_logits += (layer_outputs[-1],)
1087
-
1088
- hidden_states = self.norm(hidden_states)
1089
-
1090
- # Add hidden states from the last decoder layer
1091
- if output_hidden_states:
1092
- all_hidden_states += (hidden_states,)
1093
-
1094
- next_cache = None
1095
- if use_cache:
1096
- next_cache = (
1097
- next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
1098
- )
1099
-
1100
- if not return_dict:
1101
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] if v is not None)
1102
-
1103
- return MoeModelOutputWithPast(
1104
- last_hidden_state=hidden_states,
1105
- past_key_values=next_cache,
1106
- hidden_states=all_hidden_states,
1107
- attentions=all_self_attns,
1108
- router_logits=all_router_logits
1109
- )
1110
-
1111
- def _update_causal_mask(self, attention_mask, input_tensor):
1112
- """
1113
- Update the causal mask based on the attention mask and input tensor.
1114
-
1115
- Args:
1116
- attention_mask (torch.Tensor): The attention mask.
1117
- input_tensor (torch.Tensor): The input tensor.
1118
-
1119
- Returns:
1120
- torch.Tensor: The updated causal mask.
1121
- """
1122
-
1123
- if self.config._attn_implementation == "flash_attention_2":
1124
- if attention_mask is not None and 0.0 in attention_mask:
1125
- return attention_mask
1126
- return None
1127
-
1128
- batch_size, seq_length = input_tensor.shape[:2]
1129
- dtype = input_tensor.dtype
1130
- device = input_tensor.device
1131
-
1132
- # support going beyond cached `max_position_embedding`
1133
- if seq_length > self.causal_mask.shape[-1]:
1134
- logger.info(f"Resizing causal mask buffer from {self.causal_mask.shape[-1]} to {2 * self.causal_mask.shape[-1]}")
1135
- causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
1136
- self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
1137
-
1138
- # We use the current dtype to avoid any overflows
1139
- min_dtype = torch.finfo(dtype).min
1140
- causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
1141
- causal_mask = causal_mask.to(dtype=dtype, device=device)
1142
-
1143
- if attention_mask is not None and attention_mask.dim() == 2:
1144
- mask_length = attention_mask.shape[-1]
1145
- padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
1146
- causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
1147
-
1148
- if self.config._attn_implementation == "sdpa" and attention_mask is not None:
1149
- # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
1150
- is_tracing = (
1151
- torch.jit.is_tracing()
1152
- or isinstance(input_tensor, torch.fx.Proxy)
1153
- or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
1154
- )
1155
-
1156
- if not is_tracing and torch.any(attention_mask != 1):
1157
- # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
1158
- # using left padding. This is required by
1159
- # F.scaled_dot_product_attention memory-efficient attention path.
1160
- # Details: https://github.com/pytorch/pytorch/issues/110213
1161
- causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to(dtype)
1162
-
1163
- return causal_mask
1164
-
1165
- class GemmoeForCausalLM(GemmoePreTrainedModel):
1166
- r"""
1167
- The Gemmoe Model transformer with a language modeling head on top for causal language modeling (CLM).
1168
 
1169
  Args:
1170
- config (GemmoeConfig): The configuration object for the Gemmoe model.
 
 
 
 
 
 
1171
 
1172
- Example usage:
1173
- ```python
1174
- >>> from transformers import AutoTokenizer, GemmoeForCausalLM
 
 
 
1175
 
1176
- >>> model = GemmoeForCausalLM.from_pretrained("google/gemmoe-7b")
1177
- >>> tokenizer = AutoTokenizer.from_pretrained("google/gemmoe-7b")
 
 
 
 
 
 
1178
 
1179
- >>> prompt = "What is your favorite condiment?"
1180
- >>> inputs = tokenizer(prompt, return_tensors="pt")
1181
 
1182
- >>> # Generate
1183
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1184
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1185
- "What is your favorite condiment?"
1186
- ```
1187
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1188
  _tied_weights_keys = ["lm_head.weight"]
1189
 
1190
  def __init__(self, config):
@@ -1193,9 +1178,8 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1193
  self.vocab_size = config.vocab_size
1194
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1195
  self.router_aux_loss_coef = config.router_aux_loss_coef
1196
- self.num_experts = 8
1197
  self.num_experts_per_tok = config.num_experts_per_tok
1198
-
1199
  # Initialize weights and apply final processing
1200
  self.post_init()
1201
 
@@ -1219,6 +1203,7 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1219
 
1220
  @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
1221
  @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 
1222
  def forward(
1223
  self,
1224
  input_ids: torch.LongTensor = None,
@@ -1232,7 +1217,6 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1232
  output_hidden_states: Optional[bool] = None,
1233
  output_router_logits: Optional[bool] = None,
1234
  return_dict: Optional[bool] = None,
1235
- cache_position: Optional[torch.LongTensor] = None,
1236
  ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
1237
  r"""
1238
  Args:
@@ -1248,26 +1232,29 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1248
  ```python
1249
  >>> from transformers import AutoTokenizer, GemmoeForCausalLM
1250
 
1251
- >>> model = GemmoeForCausalLM.from_pretrained("google/gemmoe-7b")
1252
- >>> tokenizer = AutoTokenizer.from_pretrained("google/gemmoe-7b")
1253
 
1254
- >>> prompt = "What is your favorite condiment?"
1255
  >>> inputs = tokenizer(prompt, return_tensors="pt")
1256
 
1257
  >>> # Generate
1258
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1259
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1260
- "What is your favorite condiment?"
1261
  ```"""
 
1262
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1263
  output_router_logits = (
1264
- output_router_logits if output_router_logits is not None else getattr(self.config, "output_router_logits", False)
1265
  )
 
1266
  output_hidden_states = (
1267
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1268
  )
1269
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1270
 
 
1271
  outputs = self.model(
1272
  input_ids=input_ids,
1273
  attention_mask=attention_mask,
@@ -1279,42 +1266,39 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1279
  output_hidden_states=output_hidden_states,
1280
  output_router_logits=output_router_logits,
1281
  return_dict=return_dict,
1282
- cache_position=cache_position,
1283
  )
1284
 
1285
  hidden_states = outputs[0]
1286
-
1287
- # Ensure hidden_states and lm_head have compatible dtypes
1288
- hidden_states = hidden_states.to(dtype=self.lm_head.weight.dtype)
1289
-
1290
  logits = self.lm_head(hidden_states)
 
1291
 
1292
  loss = None
1293
  if labels is not None:
 
1294
  shift_logits = logits[..., :-1, :].contiguous()
1295
  shift_labels = labels[..., 1:].contiguous()
 
1296
  loss_fct = CrossEntropyLoss()
1297
  shift_logits = shift_logits.view(-1, self.config.vocab_size)
1298
  shift_labels = shift_labels.view(-1)
 
1299
  shift_labels = shift_labels.to(shift_logits.device)
1300
  loss = loss_fct(shift_logits, shift_labels)
1301
 
1302
  aux_loss = None
1303
  if output_router_logits:
1304
- router_logits = outputs.router_logits if return_dict else outputs[-1]
1305
- if router_logits is not None:
1306
- aux_loss = load_balancing_loss_func(
1307
- router_logits,
1308
- self.num_experts,
1309
- self.num_experts_per_tok,
1310
- attention_mask,
1311
- )
1312
- if labels is not None:
1313
- loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
1314
 
1315
  if not return_dict:
1316
  output = (logits,) + outputs[1:]
1317
- if aux_loss is not None:
1318
  output = (aux_loss,) + output
1319
  return (loss,) + output if loss is not None else output
1320
 
@@ -1329,9 +1313,15 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1329
  )
1330
 
1331
  def prepare_inputs_for_generation(
1332
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
 
 
 
 
 
 
1333
  ):
1334
- past_length = 0
1335
  if past_key_values is not None:
1336
  if isinstance(past_key_values, Cache):
1337
  cache_length = past_key_values.get_seq_length()
@@ -1341,11 +1331,19 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1341
  cache_length = past_length = past_key_values[0][0].shape[2]
1342
  max_cache_length = None
1343
 
 
 
 
 
1344
  if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1345
  input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
 
 
1346
  elif past_length < input_ids.shape[1]:
1347
  input_ids = input_ids[:, past_length:]
1348
-
 
 
1349
  if (
1350
  max_cache_length is not None
1351
  and attention_mask is not None
@@ -1355,37 +1353,27 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1355
 
1356
  position_ids = kwargs.get("position_ids", None)
1357
  if attention_mask is not None and position_ids is None:
 
1358
  position_ids = attention_mask.long().cumsum(-1) - 1
1359
  position_ids.masked_fill_(attention_mask == 0, 1)
1360
  if past_key_values:
1361
  position_ids = position_ids[:, -input_ids.shape[1] :]
1362
 
1363
- if self.generation_config.cache_implementation == "static":
1364
- cache_position = kwargs.get("cache_position", None)
1365
- if cache_position is None:
1366
- past_length = 0
1367
- else:
1368
- past_length = cache_position[-1] + 1
1369
- input_ids = input_ids[:, -1].unsqueeze(-1)
1370
- position_ids = position_ids[:, -1].unsqueeze(-1)
1371
-
1372
- cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
1373
-
1374
  if inputs_embeds is not None and past_key_values is None:
1375
  model_inputs = {"inputs_embeds": inputs_embeds}
1376
  else:
1377
- model_inputs = {"input_ids": input_ids.contiguous()}
1378
 
1379
  model_inputs.update(
1380
  {
1381
- "position_ids": position_ids.contiguous(),
1382
- "cache_position": cache_position,
1383
  "past_key_values": past_key_values,
1384
  "use_cache": kwargs.get("use_cache"),
1385
  "attention_mask": attention_mask,
 
1386
  }
1387
  )
1388
-
1389
  return model_inputs
1390
 
1391
  @staticmethod
@@ -1418,6 +1406,7 @@ class GemmoeForSequenceClassification(GemmoePreTrainedModel):
1418
  self.num_labels = config.num_labels
1419
  self.model = GemmoeModel(config)
1420
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
 
1421
  # Initialize weights and apply final processing
1422
  self.post_init()
1423
 
@@ -1428,7 +1417,6 @@ class GemmoeForSequenceClassification(GemmoePreTrainedModel):
1428
  self.model.embed_tokens = value
1429
 
1430
  @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
1431
- @replace_return_docstrings(output_type=SequenceClassifierOutputWithPast, config_class=_CONFIG_FOR_DOC)
1432
  def forward(
1433
  self,
1434
  input_ids: torch.LongTensor = None,
@@ -1442,25 +1430,14 @@ class GemmoeForSequenceClassification(GemmoePreTrainedModel):
1442
  output_hidden_states: Optional[bool] = None,
1443
  return_dict: Optional[bool] = None,
1444
  ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1445
- """
1446
- Forward pass of the sequence classification model.
1447
-
1448
- Args:
1449
- input_ids (torch.LongTensor, optional): Input token IDs.
1450
- attention_mask (torch.Tensor, optional): Attention mask.
1451
- position_ids (torch.LongTensor, optional): Position IDs.
1452
- past_key_values (List[torch.FloatTensor], optional): Past key-value pairs.
1453
- inputs_embeds (torch.FloatTensor, optional): Input embeddings.
1454
- labels (torch.LongTensor, optional): Labels for sequence classification.
1455
- use_cache (bool, optional): Whether to use cache.
1456
- output_attentions (bool, optional): Whether to output attentions.
1457
- output_hidden_states (bool, optional): Whether to output hidden states.
1458
- return_dict (bool, optional): Whether to return a dictionary or tuple.
1459
-
1460
- Returns:
1461
- Union[Tuple, SequenceClassifierOutputWithPast]: Output of the sequence classification model.
1462
  """
1463
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
1464
  transformer_outputs = self.model(
1465
  input_ids,
1466
  attention_mask=attention_mask,
@@ -1486,8 +1463,10 @@ class GemmoeForSequenceClassification(GemmoePreTrainedModel):
1486
  sequence_lengths = -1
1487
  else:
1488
  if input_ids is not None:
1489
- sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
1490
- sequence_lengths = sequence_lengths.clamp(min=0).to(logits.device)
 
 
1491
  else:
1492
  sequence_lengths = -1
1493
 
@@ -1516,7 +1495,6 @@ class GemmoeForSequenceClassification(GemmoePreTrainedModel):
1516
  elif self.config.problem_type == "multi_label_classification":
1517
  loss_fct = BCEWithLogitsLoss()
1518
  loss = loss_fct(pooled_logits, labels)
1519
-
1520
  if not return_dict:
1521
  output = (pooled_logits,) + transformer_outputs[1:]
1522
  return ((loss,) + output) if loss is not None else output
@@ -1527,4 +1505,4 @@ class GemmoeForSequenceClassification(GemmoePreTrainedModel):
1527
  past_key_values=transformer_outputs.past_key_values,
1528
  hidden_states=transformer_outputs.hidden_states,
1529
  attentions=transformer_outputs.attentions,
1530
- )
 
28
  from transformers.activations import ACT2FN
29
  from transformers.cache_utils import Cache, DynamicCache, StaticCache
30
  from transformers.modeling_attn_mask_utils import (
31
+ AttentionMaskConverter,
32
  _prepare_4d_causal_attention_mask,
33
  )
34
  from transformers.modeling_outputs import SequenceClassifierOutputWithPast, MoeModelOutputWithPast, MoeCausalLMOutputWithPast
 
177
  class GemmoeRotaryEmbedding(nn.Module):
178
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
179
  super().__init__()
180
+
181
  self.dim = dim
182
  self.max_position_embeddings = max_position_embeddings
183
  self.base = base
184
+ self.register_buffer("inv_freq", None, persistent=False)
185
+
186
+ @torch.no_grad()
187
+ def forward(self, x, position_ids, seq_len=None):
188
+ # x: [bs, num_attention_heads, seq_len, head_size]
189
+ if self.inv_freq is None:
190
+ self.inv_freq = 1.0 / (
191
+ self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
192
+ )
193
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
194
+ position_ids_expanded = position_ids[:, None, :].float()
195
+ # Force float32 since bfloat16 loses precision on long contexts
196
+ # See https://github.com/huggingface/transformers/pull/29285
197
+ device_type = x.device.type
198
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
199
+ with torch.autocast(device_type=device_type, enabled=False):
200
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
201
+ emb = torch.cat((freqs, freqs), dim=-1)
202
+ cos = emb.cos()
203
+ sin = emb.sin()
204
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
205
+
206
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
 
 
 
 
207
  def rotate_half(x):
208
  """Rotates half the hidden dims of the input."""
209
  x1 = x[..., : x.shape[-1] // 2]
210
  x2 = x[..., x.shape[-1] // 2 :]
211
  return torch.cat((-x2, x1), dim=-1)
212
 
213
+
214
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
215
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
216
+ """Applies Rotary Position Embedding to the query and key tensors.
217
+
218
+ Args:
219
+ q (`torch.Tensor`): The query tensor.
220
+ k (`torch.Tensor`): The key tensor.
221
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
222
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
223
+ position_ids (`torch.Tensor`, *optional*):
224
+ Deprecated and unused.
225
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
226
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
227
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
228
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
229
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
230
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
231
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
232
+ Returns:
233
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
234
+ """
235
+ cos = cos.unsqueeze(unsqueeze_dim)
236
+ sin = sin.unsqueeze(unsqueeze_dim)
237
  q_embed = (q * cos) + (rotate_half(q) * sin)
238
  k_embed = (k * cos) + (rotate_half(k) * sin)
239
  return q_embed, k_embed
240
 
241
+
242
+
243
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
244
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
245
  """
246
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
247
  num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
 
253
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
254
 
255
  class GemmoeAttention(nn.Module):
256
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
 
 
 
 
 
 
257
 
258
+ # Ignore copy
259
  def __init__(self, config: GemmoeConfig, layer_idx: Optional[int] = None):
260
  super().__init__()
261
  self.config = config
 
266
  "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
267
  "when creating this class."
268
  )
269
+
270
  self.attention_dropout = config.attention_dropout
271
  self.hidden_size = config.hidden_size
272
  self.num_heads = config.num_attention_heads
 
282
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
283
  f" and `num_heads`: {self.num_heads})."
284
  )
285
+
286
  self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
287
  self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
288
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
289
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
290
  self.rotary_emb = GemmoeRotaryEmbedding(
291
+ self.head_dim,
292
+ max_position_embeddings=self.max_position_embeddings,
293
+ base=self.rope_theta,
294
+ )
295
 
296
  def forward(
297
  self,
 
304
  cache_position: Optional[torch.LongTensor] = None,
305
  **kwargs,
306
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  bsz, q_len, _ = hidden_states.size()
308
 
309
  query_states = self.q_proj(hidden_states)
 
315
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
316
 
317
  past_key_value = getattr(self, "past_key_value", past_key_value)
 
318
  cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
319
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
320
 
321
  if past_key_value is not None:
322
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
323
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
324
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
325
 
326
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
327
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
328
 
329
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
330
 
 
338
  # upcast attention to fp32
339
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
340
  attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
 
341
  attn_output = torch.matmul(attn_weights, value_states)
342
 
343
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
 
347
  )
348
 
349
  attn_output = attn_output.transpose(1, 2).contiguous()
 
350
 
351
+ attn_output = attn_output.view(bsz, q_len, -1)
352
  attn_output = self.o_proj(attn_output)
353
 
354
  if not output_attentions:
 
356
 
357
  return attn_output, attn_weights, past_key_value
358
 
359
+
360
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Gemmoe
361
  class GemmoeFlashAttention2(GemmoeAttention):
362
  """
363
  Gemmoe flash attention module. This module inherits from `GemmoeAttention` as the weights of the module stays
364
  untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
365
  flash attention and deal with padding tokens in case the input contains any of them.
366
  """
367
+
368
  def __init__(self, *args, **kwargs):
369
  super().__init__(*args, **kwargs)
370
+
371
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
372
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
373
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
374
  self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
375
 
376
+ # Ignore copy
377
  def forward(
378
  self,
379
  hidden_states: torch.Tensor,
 
404
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
405
 
406
  past_key_value = getattr(self, "past_key_value", past_key_value)
407
+
408
  if past_key_value is not None:
409
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
410
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
411
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
412
 
 
423
  # cast them back in the correct dtype just to be sure everything works as expected.
424
  # This might slowdown training & inference so it is recommended to not cast the LayerNorms
425
  # in fp32. (GemmoeRMSNorm handles it correctly)
426
+
427
  input_dtype = query_states.dtype
428
  if input_dtype == torch.float32:
429
  if torch.is_autocast_enabled():
 
439
  f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
440
  f" {target_dtype}."
441
  )
442
+
443
  query_states = query_states.to(target_dtype)
444
  key_states = key_states.to(target_dtype)
445
  value_states = value_states.to(target_dtype)
 
473
  attention_mask (`torch.Tensor`):
474
  The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
475
  position of padding tokens and 1 for the position of non-padding tokens.
476
+ dropout (`float`):
477
  Attention dropout
478
  softmax_scale (`float`, *optional*):
479
  The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
 
490
  query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
491
  query_states, key_states, value_states, attention_mask, query_length
492
  )
493
+
494
  cu_seqlens_q, cu_seqlens_k = cu_seq_lens
495
  max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
496
 
 
506
  softmax_scale=softmax_scale,
507
  causal=causal,
508
  )
509
+
510
  attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
511
  else:
512
  attn_output = flash_attn_func(
 
517
 
518
  def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
519
  indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
 
520
  batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
521
+
522
  key_layer = index_first_axis(
523
  key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
524
  )
525
  value_layer = index_first_axis(
526
  value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
527
  )
 
528
  if query_length == kv_seq_len:
529
  query_layer = index_first_axis(
530
  query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
 
553
  (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
554
  )
555
 
556
+
557
+ # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Gemmoe
558
  class GemmoeSdpaAttention(GemmoeAttention):
559
  """
560
  Gemmoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
561
+ `GemmoeAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
562
  SDPA API.
563
  """
564
 
565
+ # Ignore copy
 
 
 
 
 
 
 
 
 
 
566
  def forward(
567
  self,
568
  hidden_states: torch.Tensor,
 
575
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
576
  if output_attentions:
577
  # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
578
+ logger.warning_once(
579
+ "GemmoeModel is using GemmoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
580
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
581
+ )
 
582
  return super().forward(
583
  hidden_states=hidden_states,
584
  attention_mask=attention_mask,
 
588
  use_cache=use_cache,
589
  cache_position=cache_position,
590
  )
591
+
592
  bsz, q_len, _ = hidden_states.size()
593
 
594
  query_states = self.q_proj(hidden_states)
 
603
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
604
 
605
  past_key_value = getattr(self, "past_key_value", past_key_value)
606
+
607
  if past_key_value is not None:
608
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
609
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
610
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
611
 
612
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
613
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
614
 
615
  causal_mask = attention_mask
616
  if attention_mask is not None and cache_position is not None:
617
  causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
618
 
 
 
 
 
 
619
  # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
620
  # Reference: https://github.com/pytorch/pytorch/issues/112577.
621
  if query_states.device.type == "cuda" and causal_mask is not None:
 
623
  key_states = key_states.contiguous()
624
  value_states = value_states.contiguous()
625
 
 
 
 
 
626
  attn_output = torch.nn.functional.scaled_dot_product_attention(
627
  query_states,
628
  key_states,
 
633
 
634
  attn_output = attn_output.transpose(1, 2).contiguous()
635
  attn_output = attn_output.view(bsz, q_len, -1)
636
+
637
  attn_output = self.o_proj(attn_output)
638
 
639
  return attn_output, None, past_key_value
640
 
 
 
 
 
 
 
 
 
 
 
 
 
641
 
642
+ GEMMOE_ATTENTION_CLASSES = {
643
+ "eager": GemmoeAttention,
644
+ "flash_attention_2": GemmoeFlashAttention2,
645
+ "sdpa": GemmoeSdpaAttention,
646
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
647
 
648
+ class GemmoeBlockSparseTop2MLP(nn.Module):
649
+ def __init__(self, config: GemmoeConfig):
 
 
650
  super().__init__()
651
+ self.ffn_dim = config.intermediate_size
652
+ self.hidden_dim = config.hidden_size
 
 
 
 
 
653
 
654
+ self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
655
+ self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
656
+ self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
 
 
657
 
658
+ self.act_fn = approx_gelu
 
 
659
 
660
  def forward(self, hidden_states):
661
+ current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
662
+ current_hidden_states = self.w2(current_hidden_states)
663
+ return current_hidden_states
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
664
 
665
class GemmoeBLockSparseTop2MLP(GemmoeBlockSparseTop2MLP):
    """Deprecated spelling of `GemmoeBlockSparseTop2MLP`.

    The alias must keep its own (mis-capitalised) name. Declaring it under the
    same name as its base class rebinds `GemmoeBlockSparseTop2MLP` to this
    subclass, so every expert built by the MoE block would emit the
    deprecation warning below even though no deprecated name was used.
    """

    def __init__(self, *args, **kwargs):
        logger.warning_once(
            "GemmoeBLockSparseTop2MLP is deprecated by GemmoeBlockSparseTop2MLP and will be removed in v4.40."
        )
        super().__init__(*args, **kwargs)
671
 
672
class GemmoeSparseMoeBlock(nn.Module):
    """
    Sparse mixture-of-experts feed-forward block with top-k routing.

    This implementation is
    strictly equivalent to standard MoE with full capacity (no
    dropped tokens). It's faster since it formulates MoE operations
    in terms of block-sparse operations to accommodate imbalanced
    assignments of tokens to experts, whereas standard MoE either
    (1) drop tokens at the cost of reduced performance or (2) set
    capacity factor to number of experts and thus waste computation
    and memory on padding.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok

        # gating
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

        self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Route every token to its `top_k` experts and mix the expert outputs.

        Args:
            hidden_states (`torch.Tensor`): input of shape `(batch, sequence_length, hidden_dim)`.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`: the mixed expert output, reshaped to the input's
            `(batch, sequence_length, hidden_dim)`, and the raw router logits of shape
            `(batch * sequence_length, n_experts)`.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.reshape(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        # Softmax in float32, keep the top-k probabilities and renormalise them to sum to 1.
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # One hot encode the selected experts to create an expert mask of shape
        # (n_experts, top_k, tokens); it tells, per expert, which tokens selected it
        # and in which top-k slot.
        expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx, expert_layer in enumerate(self.experts):
            # idx: top-k slot, top_x: flat token position, for every token routed to this expert
            idx, top_x = torch.where(expert_mask[expert_idx])

            if top_x.shape[0] == 0:
                continue

            # Index directly with the tensors: converting them to Python lists first
            # forces a device-to-host copy on every iteration and gives the same result.
            current_state = hidden_states[top_x]
            # Scale each expert output by the routing weight of the matching token / slot.
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

            # Tokens can be routed to several experts, so accumulate instead of assigning.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits
740
 
741
 
742
+ # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->GEMMOE,Llama->Gemmoe
743
  class GemmoeDecoderLayer(nn.Module):
744
  def __init__(self, config: GemmoeConfig, layer_idx: int):
745
  super().__init__()
 
760
  output_attentions: Optional[bool] = False,
761
  output_router_logits: Optional[bool] = False,
762
  use_cache: Optional[bool] = False,
 
763
  **kwargs,
764
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
765
+ if "padding_mask" in kwargs:
766
+ warnings.warn(
767
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
768
+ )
769
+ """
770
+ Args:
771
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
772
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
773
+ `(batch, sequence_length)` where padding elements are indicated by 0.
774
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
775
+ output_attentions (`bool`, *optional*):
776
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
777
+ returned tensors for more detail.
778
+ output_router_logits (`bool`, *optional*):
779
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
780
+ should not be returned during inference.
781
+ use_cache (`bool`, *optional*):
782
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
783
+ (see `past_key_values`).
784
+ """
785
+
786
  residual = hidden_states
787
+
788
  hidden_states = self.input_layernorm(hidden_states)
789
 
790
  # Self Attention
 
795
  past_key_value=past_key_value,
796
  output_attentions=output_attentions,
797
  use_cache=use_cache,
 
 
798
  )
799
  hidden_states = residual + hidden_states
800
 
801
  # Fully Connected
802
  residual = hidden_states
803
  hidden_states = self.post_attention_layernorm(hidden_states)
804
+ hidden_states, router_logits = self.block_sparse_moe(hidden_states)
805
  hidden_states = residual + hidden_states
806
 
 
 
 
807
  outputs = (hidden_states,)
808
 
809
  if output_attentions:
 
812
  if use_cache:
813
  outputs += (present_key_value,)
814
 
815
+ if output_router_logits:
816
+ outputs += (router_logits,)
817
+
818
  return outputs
819
 
820
+
821
  GEMMOE_START_DOCSTRING = r"""
822
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
823
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
824
+ etc.)
825
+
826
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
827
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
828
+ and behavior.
829
+
830
+ Parameters:
831
+ config ([`GemmoeConfig`]):
832
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
833
+ load the weights associated with the model, only the configuration. Check out the
834
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
835
  """
836
 
837
+
838
@add_start_docstrings(
    "The bare Gemmoe Model outputting raw hidden-states without any specific head on top.",
    GEMMOE_START_DOCSTRING,
)
class GemmoePreTrainedModel(PreTrainedModel):
    # Shared base for the Gemmoe models: declares the class-level capability flags
    # and provides weight initialisation plus the per-layer cache setup / reset hooks.
    config_class = GemmoeConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _keep_in_fp32_modules = ["inv_freq", "rotary_emb", "cos_cached", "sin_cached"]
    _no_split_modules = ["GemmoeDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values", "causal_mask"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def _init_weights(self, module):
        """Initialise `nn.Linear` / `nn.Embedding` weights from N(0, `config.initializer_range`)."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # keep the padding embedding at exactly zero
                module.weight.data[module.padding_idx].zero_()

    def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
        """
        Attach a `cache_cls` instance to every attention layer of `self.model`.

        Args:
            cache_cls: cache class to instantiate for each layer.
            max_batch_size: forwarded to `cache_cls`.
            max_cache_len (`int`, *optional*): forwarded to `cache_cls`. When given, the causal mask
                buffer of `self.model` is rebuilt if it is shorter than this length.

        Raises:
            ValueError: if `StaticCache` is requested together with `flash_attention_2`.
        """
        if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
            raise ValueError(
                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
            )

        current_mask = self.model.causal_mask
        # `max_cache_len` is optional: without it keep the current mask length instead of
        # comparing `None` with an int (which raises TypeError).
        mask_len = current_mask.shape[-1] if max_cache_len is None else max_cache_len
        if mask_len > current_mask.shape[-1] or self.device != current_mask.device:
            causal_mask = torch.full((mask_len, mask_len), fill_value=True, dtype=torch.bool, device=self.device)
            # The mask is read from `self.model`, so it has to be re-registered there;
            # registering it on `self` leaves the mask that is actually used untouched.
            self.model.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)

        for layer in self.model.layers:
            # allocate each layer's cache on the device / dtype of that layer's weights
            weights = layer.self_attn.o_proj.weight
            layer.self_attn.past_key_value = cache_cls(
                self.config, max_batch_size, max_cache_len, device=weights.device, dtype=weights.dtype
            )

    def _reset_cache(self):
        """Drop the cache object attached to every attention layer of `self.model`."""
        for layer in self.model.layers:
            layer.self_attn.past_key_value = None
885
+
886
 
887
  GEMMOE_INPUTS_DOCSTRING = r"""
888
+ Args:
889
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
890
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
891
+ it.
892
+
893
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
894
+ [`PreTrainedTokenizer.__call__`] for details.
895
+
896
+ [What are input IDs?](../glossary#input-ids)
897
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
898
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
899
+
900
+ - 1 for tokens that are **not masked**,
901
+ - 0 for tokens that are **masked**.
902
+
903
+ [What are attention masks?](../glossary#attention-mask)
904
+
905
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
906
+ [`PreTrainedTokenizer.__call__`] for details.
907
+
908
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
909
+ `past_key_values`).
910
+
911
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
912
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
913
+ information on the default strategy.
914
+
915
+ - 1 indicates the head is **not masked**,
916
+ - 0 indicates the head is **masked**.
917
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
918
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
919
+ config.n_positions - 1]`.
920
+
921
+ [What are position IDs?](../glossary#position-ids)
922
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
923
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
924
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
925
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
926
+
927
+ Two formats are allowed:
928
+ - a [`~cache_utils.Cache`] instance;
929
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
930
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
931
+ cache format.
932
+
933
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
934
+ legacy cache format will be returned.
935
+
936
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
937
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
938
+ of shape `(batch_size, sequence_length)`.
939
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
940
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
941
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
942
+ model's internal embedding lookup matrix.
943
+ use_cache (`bool`, *optional*):
944
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
945
+ `past_key_values`).
946
+ output_attentions (`bool`, *optional*):
947
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
948
+ tensors for more detail.
949
+ output_hidden_states (`bool`, *optional*):
950
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
951
+ more detail.
952
+ return_dict (`bool`, *optional*):
953
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
954
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
955
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
956
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
957
+ the complete sequence length.
958
  """
959
 
960
+
961
@add_start_docstrings(
    "The bare Gemmoe Model outputting raw hidden-states without any specific head on top.",
    GEMMOE_START_DOCSTRING,
)
class GemmoeModel(GemmoePreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmoeDecoderLayer`]

    Args:
        config: GemmoeConfig
    """

    def __init__(self, config: GemmoeConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [GemmoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
        # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
        causal_mask = torch.full(
            (config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
        )
        self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
    # Ignore copy
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        output_router_logits: Optional[bool] = None,
    ) -> Union[Tuple, MoeModelOutputWithPast]:
        # `output_router_logits` is appended last so existing positional callers keep working.
        # `GemmoeForCausalLM.forward` passes it by keyword and reads the router logits from the
        # output, so it has to be accepted, forwarded to the layers and returned here.
        from functools import partial

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        past_seen_tokens = 0
        if use_cache:  # kept for BC (cache positions)
            if not isinstance(past_key_values, StaticCache):
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                past_seen_tokens = past_key_values.get_seq_length()

        if cache_position is None:
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)

        # embed positions
        hidden_states = inputs_embeds

        # normalized
        hidden_states = hidden_states * (self.config.hidden_size**0.5)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_logits = () if output_router_logits else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Positional order follows the decoder layer signature, which ends with
                # (..., output_attentions, output_router_logits, use_cache, **kwargs).
                # NOTE(review): the first four positional parameters are assumed to be
                # (hidden_states, attention_mask, position_ids, past_key_value), matching the
                # keyword call below - confirm against GemmoeDecoderLayer.forward.
                # `cache_position` is bound by keyword, exactly as in the non-checkpointed call.
                layer_outputs = self._gradient_checkpointing_func(
                    partial(decoder_layer.__call__, cache_position=cache_position),
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    output_router_logits,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    output_router_logits=output_router_logits,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

            if output_router_logits:
                # the decoder layer appends its router logits as the last output element
                all_router_logits += (layer_outputs[-1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = (
                next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
            )
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
                if v is not None
            )
        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            router_logits=all_router_logits,
        )

    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
    def _update_causal_mask(self, attention_mask, input_tensor):
        """
        Build the additive attention mask handed to the decoder layers.

        For `flash_attention_2` the 2D `attention_mask` is returned unchanged when it contains a
        zero, otherwise `None`. For the other implementations the registered boolean causal mask
        is converted to `input_tensor`'s dtype (blocked positions hold the dtype's minimum value),
        expanded over the batch and merged with `attention_mask` when one is given.
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        batch_size, seq_length = input_tensor.shape[:2]
        dtype = input_tensor.dtype
        device = input_tensor.device

        # support going beyond cached `max_position_embedding`
        cached_length = self.causal_mask.shape[-1]
        if seq_length > cached_length:
            # Doubling once may still be shorter than the input, so grow to whichever is larger.
            # Keep the buffer boolean and on its current device, as in `__init__`.
            new_length = max(2 * cached_length, seq_length)
            causal_mask = torch.full(
                (new_length, new_length), fill_value=True, dtype=torch.bool, device=self.causal_mask.device
            )
            self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)

        # We use the current dtype to avoid any overflows
        min_dtype = torch.finfo(dtype).min

        causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
        causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            if attention_mask.dim() == 2:
                mask_length = attention_mask.shape[-1]
                # positions that are causally visible (0.0) but belong to padding get masked out
                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
            elif attention_mask.dim() == 4:
                # a user-provided 4D mask overwrites the matching corner of the causal mask
                mask_shape = attention_mask.shape
                mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
                causal_mask[: mask_shape[0], : mask_shape[1], : mask_shape[2], : mask_shape[3]] = mask_slice

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
        ):
            # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
            is_tracing = (
                torch.jit.is_tracing()
                or isinstance(input_tensor, torch.fx.Proxy)
                or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
            )
            if not is_tracing and torch.any(attention_mask != 1):
                # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
                # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
                # Details: https://github.com/pytorch/pytorch/issues/110213
                causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask
1171
+
1172
+ class GemmoeForCausalLM(GemmoePreTrainedModel):
1173
  _tied_weights_keys = ["lm_head.weight"]
1174
 
1175
  def __init__(self, config):
 
1178
  self.vocab_size = config.vocab_size
1179
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1180
  self.router_aux_loss_coef = config.router_aux_loss_coef
1181
+ self.num_experts = config.num_local_experts
1182
  self.num_experts_per_tok = config.num_experts_per_tok
 
1183
  # Initialize weights and apply final processing
1184
  self.post_init()
1185
 
 
1203
 
1204
  @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
1205
  @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1206
+ # Ignore copy
1207
  def forward(
1208
  self,
1209
  input_ids: torch.LongTensor = None,
 
1217
  output_hidden_states: Optional[bool] = None,
1218
  output_router_logits: Optional[bool] = None,
1219
  return_dict: Optional[bool] = None,
 
1220
  ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
1221
  r"""
1222
  Args:
 
1232
  ```python
1233
  >>> from transformers import AutoTokenizer, GemmoeForCausalLM
1234
 
1235
+ >>> model = GemmoeForCausalLM.from_pretrained("mistralai/Gemmoe-8x7B-v0.1")
1236
+ >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Gemmoe-8x7B-v0.1")
1237
 
1238
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1239
  >>> inputs = tokenizer(prompt, return_tensors="pt")
1240
 
1241
  >>> # Generate
1242
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1243
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1244
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1245
  ```"""
1246
+
1247
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1248
  output_router_logits = (
1249
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
1250
  )
1251
+
1252
  output_hidden_states = (
1253
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1254
  )
1255
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1256
 
1257
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1258
  outputs = self.model(
1259
  input_ids=input_ids,
1260
  attention_mask=attention_mask,
 
1266
  output_hidden_states=output_hidden_states,
1267
  output_router_logits=output_router_logits,
1268
  return_dict=return_dict,
 
1269
  )
1270
 
1271
  hidden_states = outputs[0]
 
 
 
 
1272
  logits = self.lm_head(hidden_states)
1273
+ logits = logits.float()
1274
 
1275
  loss = None
1276
  if labels is not None:
1277
+ # Shift so that tokens < n predict n
1278
  shift_logits = logits[..., :-1, :].contiguous()
1279
  shift_labels = labels[..., 1:].contiguous()
1280
+ # Flatten the tokens
1281
  loss_fct = CrossEntropyLoss()
1282
  shift_logits = shift_logits.view(-1, self.config.vocab_size)
1283
  shift_labels = shift_labels.view(-1)
1284
+ # Enable model parallelism
1285
  shift_labels = shift_labels.to(shift_logits.device)
1286
  loss = loss_fct(shift_logits, shift_labels)
1287
 
1288
  aux_loss = None
1289
  if output_router_logits:
1290
+ aux_loss = load_balancing_loss_func(
1291
+ outputs.router_logits if return_dict else outputs[-1],
1292
+ self.num_experts,
1293
+ self.num_experts_per_tok,
1294
+ attention_mask,
1295
+ )
1296
+ if labels is not None:
1297
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
 
 
1298
 
1299
  if not return_dict:
1300
  output = (logits,) + outputs[1:]
1301
+ if output_router_logits:
1302
  output = (aux_loss,) + output
1303
  return (loss,) + output if loss is not None else output
1304
 
 
1313
  )
1314
 
1315
  def prepare_inputs_for_generation(
1316
+ self,
1317
+ input_ids,
1318
+ past_key_values=None,
1319
+ attention_mask=None,
1320
+ inputs_embeds=None,
1321
+ output_router_logits=False,
1322
+ **kwargs,
1323
  ):
1324
+ # Omit tokens covered by past_key_values
1325
  if past_key_values is not None:
1326
  if isinstance(past_key_values, Cache):
1327
  cache_length = past_key_values.get_seq_length()
 
1331
  cache_length = past_length = past_key_values[0][0].shape[2]
1332
  max_cache_length = None
1333
 
1334
+ # Keep only the unprocessed tokens:
1335
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1336
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1337
+ # input)
1338
  if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1339
  input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1340
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1341
+ # input_ids based on the past_length.
1342
  elif past_length < input_ids.shape[1]:
1343
  input_ids = input_ids[:, past_length:]
1344
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1345
+
1346
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1347
  if (
1348
  max_cache_length is not None
1349
  and attention_mask is not None
 
1353
 
1354
  position_ids = kwargs.get("position_ids", None)
1355
  if attention_mask is not None and position_ids is None:
1356
+ # create position_ids on the fly for batch generation
1357
  position_ids = attention_mask.long().cumsum(-1) - 1
1358
  position_ids.masked_fill_(attention_mask == 0, 1)
1359
  if past_key_values:
1360
  position_ids = position_ids[:, -input_ids.shape[1] :]
1361
 
1362
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
 
 
 
 
 
 
 
 
 
 
1363
  if inputs_embeds is not None and past_key_values is None:
1364
  model_inputs = {"inputs_embeds": inputs_embeds}
1365
  else:
1366
+ model_inputs = {"input_ids": input_ids}
1367
 
1368
  model_inputs.update(
1369
  {
1370
+ "position_ids": position_ids,
 
1371
  "past_key_values": past_key_values,
1372
  "use_cache": kwargs.get("use_cache"),
1373
  "attention_mask": attention_mask,
1374
+ "output_router_logits": output_router_logits,
1375
  }
1376
  )
 
1377
  return model_inputs
1378
 
1379
  @staticmethod
 
1406
  self.num_labels = config.num_labels
1407
  self.model = GemmoeModel(config)
1408
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1409
+
1410
  # Initialize weights and apply final processing
1411
  self.post_init()
1412
 
 
1417
  self.model.embed_tokens = value
1418
 
1419
  @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
 
1420
  def forward(
1421
  self,
1422
  input_ids: torch.LongTensor = None,
 
1430
  output_hidden_states: Optional[bool] = None,
1431
  return_dict: Optional[bool] = None,
1432
  ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1433
+ r"""
1434
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1435
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1436
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1437
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
 
 
 
 
 
 
 
 
 
 
 
 
1438
  """
1439
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1440
+
1441
  transformer_outputs = self.model(
1442
  input_ids,
1443
  attention_mask=attention_mask,
 
1463
  sequence_lengths = -1
1464
  else:
1465
  if input_ids is not None:
1466
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1467
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1468
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1469
+ sequence_lengths = sequence_lengths.to(logits.device)
1470
  else:
1471
  sequence_lengths = -1
1472
 
 
1495
  elif self.config.problem_type == "multi_label_classification":
1496
  loss_fct = BCEWithLogitsLoss()
1497
  loss = loss_fct(pooled_logits, labels)
 
1498
  if not return_dict:
1499
  output = (pooled_logits,) + transformer_outputs[1:]
1500
  return ((loss,) + output) if loss is not None else output
 
1505
  past_key_values=transformer_outputs.past_key_values,
1506
  hidden_states=transformer_outputs.hidden_states,
1507
  attentions=transformer_outputs.attentions,
1508
+ )