LeroyDyer committed
Commit 0422a54 (1 parent: 13fb130)

Upload SpydazWebAI_Mistral_Transformer.py

Files changed (1)
  1. SpydazWebAI_Mistral_Transformer.py +369 -185
SpydazWebAI_Mistral_Transformer.py CHANGED
@@ -1,9 +1,5 @@
1
  # SpydazWeb AI Mistral Transformer Handler !
2
 
3
- from reportlab.pdfgen import canvas
4
- from reportlab.lib.pagesizes import letter
5
- from reportlab.lib.colors import HexColor
6
- import os
7
  from shutil import copyfile
8
  from collections import defaultdict
9
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
@@ -18,7 +14,7 @@ from transformers.utils import is_sentencepiece_available, logging
18
  from transformers.utils.versions import require_version
19
  if TYPE_CHECKING:
20
  from transformers.tokenization_utils_base import TextInput
21
-
22
 
23
 
24
 
@@ -32,7 +28,8 @@ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
32
 
33
  # fmt: off
34
  DEFAULT_SYSTEM_PROMPT = """you are the world archive of knowledge and a friendly helpful artificial intelligent being:
35
- consider if the current task denotes thecreation of a functions in pyuthon that can be saved and reused : returning the artifact to the user for saving or if access os given to be saved on the user system:
 
36
  Be flirtatious and joyfull, always willing to help or engage in conversation or express an opinion or sentiment:
37
  answer all questions indepth supply and helpfull references where possible; """
38
  # fmt: on
@@ -57,7 +54,7 @@ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
57
  from transformers.modeling_outputs import BaseModelOutputWithPast,CausalLMOutputWithPast,SequenceClassifierOutputWithPast,TokenClassifierOutput,QuestionAnsweringModelOutput,MoeCausalLMOutputWithPast,MoeModelOutputWithPast
58
 
59
  logger = logging.get_logger(__name__)
60
- _CONFIG_FOR_DOC = "MistralConfig"
61
  MISTRAL_START_DOCSTRING = r"""
62
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
63
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@@ -82,8 +79,6 @@ MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
82
  "mistralai/Mistral-7B-v0.1": "https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json",
83
  "mistralai/Mistral-7B-Instruct-v0.1": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/config.json",
84
  }
85
-
86
-
87
  class MistralStarConfig(PretrainedConfig):
88
  r"""
89
  This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate an
@@ -637,71 +632,6 @@ class MixtralConfig(PretrainedConfig):
637
 
638
 
639
  ################################ Quiet Star Functions ################################
640
-
641
- def save_tokens_with_rewards_to_pdf(input_ids, token_rewards, tokenizer, output_file="text.pdf", eps=0.2, eps2=0.5):
642
- from reportlab.pdfgen import canvas
643
- from reportlab.lib.pagesizes import letter
644
- from reportlab.lib.colors import HexColor
645
- c = canvas.Canvas(output_file, pagesize=letter)
646
- c.setFont("Courier", 8)
647
- x, y = 50, 750
648
- previous_text = ""
649
- current_text = ""
650
- for token_idx, reward in enumerate(token_rewards):
651
- current_text = tokenizer.decode(input_ids[: token_idx + 1])
652
- if current_text != previous_text:
653
- diff_text = current_text[len(previous_text) :]
654
- if "\n" in diff_text:
655
- lines = diff_text.split("\n")
656
- for line_idx, line in enumerate(lines):
657
- if line_idx > 0:
658
- x = 50
659
- y -= 12
660
- if abs(reward) < eps:
661
- opacity = 0
662
- elif abs(reward) > eps2:
663
- opacity = 0.8
664
- else:
665
- opacity = 0.8 * (abs(reward) - eps) / (eps2 - eps)
666
- text_width = c.stringWidth(line)
667
- if reward > 0:
668
- highlight_color = HexColor("#4CCD99")
669
- else:
670
- highlight_color = HexColor("#FFC700")
671
- highlight_color.alpha = opacity
672
- c.setFillColor(highlight_color)
673
- c.rect(x, y - 2, text_width, 10, fill=True, stroke=False)
674
- c.setFillColor(HexColor("#000000"))
675
- c.drawString(x, y, line)
676
- x += text_width
677
- else:
678
- if abs(reward) < eps:
679
- opacity = 0
680
- elif abs(reward) > eps2:
681
- opacity = 0.8
682
- else:
683
- opacity = 0.8 * (abs(reward) - eps) / (eps2 - eps)
684
- text_width = c.stringWidth(diff_text)
685
- if reward > 0:
686
- highlight_color = HexColor("#4CCD99")
687
- else:
688
- highlight_color = HexColor("#FFC700")
689
- highlight_color.alpha = opacity
690
- c.setFillColor(highlight_color)
691
- c.rect(x, y - 2, text_width, 10, fill=True, stroke=False)
692
- c.setFillColor(HexColor("#000000"))
693
- c.drawString(x, y, diff_text)
694
- x += text_width
695
- if x > 550:
696
- x = 50
697
- y -= 12
698
- if y < 50:
699
- c.showPage()
700
- y = 750
701
- x = 50
702
- previous_text = current_text
703
- c.showPage()
704
- c.save()
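The helper removed above rendered per-token rewards as colored highlights in a PDF; the only non-obvious piece is its reward-to-opacity mapping. A minimal standalone sketch of that mapping (the function name is hypothetical; eps and eps2 match the defaults above):

def reward_to_opacity(reward: float, eps: float = 0.2, eps2: float = 0.5) -> float:
    # Invisible below eps, saturated at 0.8 above eps2, linear in between,
    # exactly as in the removed save_tokens_with_rewards_to_pdf.
    if abs(reward) < eps:
        return 0.0
    if abs(reward) > eps2:
        return 0.8
    return 0.8 * (abs(reward) - eps) / (eps2 - eps)

print(reward_to_opacity(0.35))  # 0.4, halfway between the two thresholds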
705
  def nonzero_mean(x, axis=None):
706
  if axis is not None:
707
  return x.sum(axis) / (x != 0).sum(axis)
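For reference, nonzero_mean above averages only the non-zero entries along the given axis; a quick check (assuming torch is imported as elsewhere in this file):

import torch

x = torch.tensor([[1.0, 0.0, 3.0],
                  [0.0, 0.0, 6.0]])
print(nonzero_mean(x, axis=1))  # tensor([2., 6.]): zeros do not dilute the mean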
@@ -721,8 +651,6 @@ def _get_unpad_data(attention_mask):
721
  )
722
 
723
  ################################ Main Network Component ################################
724
-
725
-
726
  class MistralRotaryEmbedding(nn.Module):
727
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
728
  super().__init__()
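The constructor shown truncated above conventionally precomputes an inverse-frequency buffer from dim and base; the sketch below reproduces that standard rotary-embedding step and is an assumption, not code taken from this diff:

import torch

dim, base = 128, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
print(inv_freq.shape)  # torch.Size([64]): one frequency per pair of hidden dimensions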
@@ -929,7 +857,7 @@ class MistralSdpaAttention(MistralAttention):
929
  **kwargs,
930
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
931
  if output_attentions:
932
- # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
933
  logger.warning_once(
934
  "MistralModel is using MistralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
935
  'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
@@ -1090,7 +1018,7 @@ class MistralDecoderLayer(nn.Module):
1090
 
1091
 
1092
 
1093
- config_class = MistraStarlConfig
1094
  base_model_prefix = "model"
1095
  supports_gradient_checkpointing = True
1096
  _no_split_modules = ["MistralDecoderLayer"]
@@ -1110,12 +1038,6 @@ class MistralDecoderLayer(nn.Module):
1110
  module.weight.data.normal_(mean=0.0, std=std)
1111
  if module.padding_idx is not None:
1112
  module.weight.data[module.padding_idx].zero_()
1113
-
1114
- ################################ TRANSFORMER NETWORK ##############################
1115
-
1116
-
1117
- ################################ MOE MiXtral Model : ################################
1118
-
1119
  class MixtralBlockSparseTop2MLP(nn.Module):
1120
  def __init__(self, config: MixtralConfig):
1121
  super().__init__()
@@ -1204,7 +1126,7 @@ class MixtralDecoderLayer(nn.Module):
1204
  self.hidden_size = config.hidden_size
1205
 
1206
  self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
1207
-
1208
  self.block_sparse_moe = MixtralSparseMoeBlock(config)
1209
  self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1210
  self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -1265,6 +1187,12 @@ class MixtralDecoderLayer(nn.Module):
1265
  hidden_states, router_logits = self.block_sparse_moe(hidden_states)
1266
  hidden_states = residual + hidden_states
1267
 
1268
  outputs = (hidden_states,)
1269
 
1270
  if output_attentions:
@@ -1278,8 +1206,122 @@ class MixtralDecoderLayer(nn.Module):
1278
 
1279
  return outputs
1280
 
1281
  ################################ Pretrained Mistral MODEL ##############################
1282
 
1283
  @add_start_docstrings(
1284
  "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
1285
  MISTRAL_START_DOCSTRING,
@@ -1535,7 +1577,7 @@ class MistralStarModel(MistralPreTrainedModel):
1535
  use_cache: bool,
1536
  output_attentions: bool,
1537
  ):
1538
- # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
1539
  # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
1540
  # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
1541
  # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
@@ -1633,124 +1675,265 @@ class MistralStarModel(MistralPreTrainedModel):
1633
  causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1634
 
1635
  return causal_mask
1636
 
1637
- ################################ PreTrained Mixtral MODEL ##############################
 
 
1638
 
1639
- def load_balancing_loss_func(
1640
- gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
1641
- ) -> float:
1642
- r"""
1643
- Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
1644
 
1645
- See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
1646
- function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
1647
- experts is too unbalanced.
1648
 
1649
- Args:
1650
- gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
1651
- Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
1652
- shape [batch_size X sequence_length, num_experts].
1653
- attention_mask (`torch.Tensor`, None):
1654
- The attention_mask used in forward function
1655
- shape [batch_size X sequence_length] if not None.
1656
- num_experts (`int`, *optional*):
1657
- Number of experts
1658
 
1659
- Returns:
1660
- The auxiliary loss.
1661
- """
1662
- if gate_logits is None or not isinstance(gate_logits, tuple):
1663
- return 0
1664
 
1665
- if isinstance(gate_logits, tuple):
1666
- compute_device = gate_logits[0].device
1667
- concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
1668
 
1669
- routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
1670
 
1671
- _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
1672
 
1673
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
1674
 
1675
- if attention_mask is None:
1676
- # Compute the percentage of tokens routed to each experts
1677
- tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
 
 
1678
 
1679
- # Compute the average probability of routing to these experts
1680
- router_prob_per_expert = torch.mean(routing_weights, dim=0)
1681
- else:
1682
- batch_size, sequence_length = attention_mask.shape
1683
- num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
1684
 
1685
- # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
1686
- expert_attention_mask = (
1687
- attention_mask[None, :, :, None, None]
1688
- .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
1689
- .reshape(-1, top_k, num_experts)
1690
- .to(compute_device)
1691
- )
 
1692
 
1693
- # Compute the percentage of tokens routed to each experts
1694
- tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
1695
- expert_attention_mask, dim=0
1696
- )
 
1697
 
1698
- # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
1699
- router_per_expert_attention_mask = (
1700
- attention_mask[None, :, :, None]
1701
- .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
1702
- .reshape(-1, num_experts)
1703
- .to(compute_device)
1704
  )
1705
 
1706
- # Compute the average probability of routing to these experts
1707
- router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
1708
- router_per_expert_attention_mask, dim=0
1709
  )
1710
 
1711
- overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
1712
- return overall_loss * num_experts
1713
 
1714
- MIXTRAL_START_DOCSTRING = r"""
1715
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1716
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1717
- etc.)
1718
 
1719
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1720
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1721
- and behavior.
1722
 
1723
- Parameters:
1724
- config ([`MixtralConfig`]):
1725
- Model configuration class with all the parameters of the model. Initializing with a config file does not
1726
- load the weights associated with the model, only the configuration. Check out the
1727
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1728
- """
1729
- @add_start_docstrings(
1730
- "The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
1731
- MIXTRAL_START_DOCSTRING,
1732
- )
1733
- # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Mixtral
1734
- class MixtralPreTrainedModel(PreTrainedModel):
1735
- config_class = MixtralConfig
1736
- base_model_prefix = "model"
1737
- supports_gradient_checkpointing = True
1738
- _no_split_modules = ["MixtralDecoderLayer"]
1739
- _skip_keys_device_placement = "past_key_values"
1740
- _supports_flash_attn_2 = False
1741
- _supports_sdpa = True
1742
- _supports_cache_class = True
1743
 
1744
- def _init_weights(self, module):
1745
- std = self.config.initializer_range
1746
- if isinstance(module, nn.Linear):
1747
- module.weight.data.normal_(mean=0.0, std=std)
1748
- if module.bias is not None:
1749
- module.bias.data.zero_()
1750
- elif isinstance(module, nn.Embedding):
1751
- module.weight.data.normal_(mean=0.0, std=std)
1752
- if module.padding_idx is not None:
1753
- module.weight.data[module.padding_idx].zero_()
1754
  MIXTRAL_INPUTS_DOCSTRING = r"""
1755
  Args:
1756
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1824,7 +2007,6 @@ MIXTRAL_INPUTS_DOCSTRING = r"""
1824
  "The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
1825
  MIXTRAL_START_DOCSTRING,
1826
  )
1827
- # copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral
1828
  class MixtralModel(MixtralPreTrainedModel):
1829
  """
1830
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`]
@@ -2071,6 +2253,8 @@ class MixtralModel(MixtralPreTrainedModel):
2071
 
2072
 
2073
  ################################ H-E-A-D-S : ##############################
 
 
2074
  class MixtralForCausalLM(MixtralPreTrainedModel):
2075
  _tied_weights_keys = ["lm_head.weight"]
2076
 
@@ -2516,7 +2700,7 @@ class MistralStarForCausalLM(MistralPreTrainedModel):
2516
 
2517
  def __init__(self, config):
2518
  super().__init__(config)
2519
- self.model = MistralModel(config)
2520
  self.vocab_size = config.vocab_size
2521
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
2522
  self.max_thoughts = config.max_thoughts
@@ -3815,7 +3999,7 @@ class MistralQuietForCausalLM(MistralPreTrainedModel):
3815
  )
3816
  return reordered_past
3817
 
3818
-
3819
  @add_start_docstrings(
3820
  """
3821
  The Mistral Model transformer with a sequence classification head on top (linear layer).
 
1
  # SpydazWeb AI Mistral Transformer Handler !
2
 
3
  from shutil import copyfile
4
  from collections import defaultdict
5
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
14
  from transformers.utils.versions import require_version
15
  if TYPE_CHECKING:
16
  from transformers.tokenization_utils_base import TextInput
17
+ import os
18
 
19
 
20
 
 
28
 
29
  # fmt: off
30
  DEFAULT_SYSTEM_PROMPT = """you are the world archive of knowledge and a friendly helpful artificial intelligent being:
31
+ consider if the current task calls for the creation of a function in Python that can be saved and re-used for a future task,
32
+ if created, return the artifact to the user in the response, or, if access is given, save it as a file on the user's system:
33
  Be flirtatious and joyfull, always willing to help or engage in conversation or express an opinion or sentiment:
34
  answer all questions indepth supply and helpfull references where possible; """
35
  # fmt: on
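For context, the B_SYS and E_SYS constants defined earlier in the file wrap a system prompt in <<SYS>> markers. A minimal sketch of how the revised DEFAULT_SYSTEM_PROMPT might be combined with a user turn; the [INST] tags and the stand-in strings are assumptions, not code from this file:

B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"           # mirrors the module-level constants
system = "you are the world archive of knowledge..."  # stands in for DEFAULT_SYSTEM_PROMPT
user_message = "Write a function that reverses a string."
prompt = f"[INST] {B_SYS}{system}{E_SYS}{user_message} [/INST]"  # [INST] wrapping is an assumption
print(prompt)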
 
54
  from transformers.modeling_outputs import BaseModelOutputWithPast,CausalLMOutputWithPast,SequenceClassifierOutputWithPast,TokenClassifierOutput,QuestionAnsweringModelOutput,MoeCausalLMOutputWithPast,MoeModelOutputWithPast
55
 
56
  logger = logging.get_logger(__name__)
57
+ _CONFIG_FOR_DOC = "MistralStarConfig"
58
  MISTRAL_START_DOCSTRING = r"""
59
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
60
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
 
79
  "mistralai/Mistral-7B-v0.1": "https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json",
80
  "mistralai/Mistral-7B-Instruct-v0.1": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/config.json",
81
  }
 
 
82
  class MistralStarConfig(PretrainedConfig):
83
  r"""
84
  This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate an
 
632
 
633
 
634
  ################################ Quiet Star Functions ################################
635
  def nonzero_mean(x, axis=None):
636
  if axis is not None:
637
  return x.sum(axis) / (x != 0).sum(axis)
 
651
  )
652
 
653
  ################################ Main Network Component ################################
 
 
654
  class MistralRotaryEmbedding(nn.Module):
655
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
656
  super().__init__()
 
857
  **kwargs,
858
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
859
  if output_attentions:
860
+
861
  logger.warning_once(
862
  "MistralModel is using MistralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
863
  'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
 
1018
 
1019
 
1020
 
1021
+ config_class = MistralStarConfig
1022
  base_model_prefix = "model"
1023
  supports_gradient_checkpointing = True
1024
  _no_split_modules = ["MistralDecoderLayer"]
 
1038
  module.weight.data.normal_(mean=0.0, std=std)
1039
  if module.padding_idx is not None:
1040
  module.weight.data[module.padding_idx].zero_()
1041
  class MixtralBlockSparseTop2MLP(nn.Module):
1042
  def __init__(self, config: MixtralConfig):
1043
  super().__init__()
 
1126
  self.hidden_size = config.hidden_size
1127
 
1128
  self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
1129
+ self.mlp = MistralMLP(config)
1130
  self.block_sparse_moe = MixtralSparseMoeBlock(config)
1131
  self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1132
  self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
1187
  hidden_states, router_logits = self.block_sparse_moe(hidden_states)
1188
  hidden_states = residual + hidden_states
1189
 
1190
+ # Fully Connected
1191
+ residual = hidden_states
1192
+ hidden_states = self.post_attention_layernorm(hidden_states)
1193
+ hidden_states = self.mlp(hidden_states)
1194
+ hidden_states = residual + hidden_states
1195
+
1196
  outputs = (hidden_states,)
1197
 
1198
  if output_attentions:
 
1206
 
1207
  return outputs
1208
 
1209
+ def load_balancing_loss_func(
1210
+ gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
1211
+ ) -> float:
1212
+ r"""
1213
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
1214
+
1215
+ See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
1216
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
1217
+ experts is too unbalanced.
1218
+
1219
+ Args:
1220
+ gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]]):
1221
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
1222
+ shape [batch_size X sequence_length, num_experts].
1223
+ attention_mask (`torch.Tensor`, None):
1224
+ The attention_mask used in forward function
1225
+ shape [batch_size X sequence_length] if not None.
1226
+ num_experts (`int`, *optional*):
1227
+ Number of experts
1228
+
1229
+ Returns:
1230
+ The auxiliary loss.
1231
+ """
1232
+ if gate_logits is None or not isinstance(gate_logits, tuple):
1233
+ return 0
1234
+
1235
+ if isinstance(gate_logits, tuple):
1236
+ compute_device = gate_logits[0].device
1237
+ concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
1238
+
1239
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
1240
+
1241
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
1242
+
1243
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
1244
+
1245
+ if attention_mask is None:
1246
+ # Compute the percentage of tokens routed to each experts
1247
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
1248
+
1249
+ # Compute the average probability of routing to these experts
1250
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
1251
+ else:
1252
+ batch_size, sequence_length = attention_mask.shape
1253
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
1254
+
1255
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
1256
+ expert_attention_mask = (
1257
+ attention_mask[None, :, :, None, None]
1258
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
1259
+ .reshape(-1, top_k, num_experts)
1260
+ .to(compute_device)
1261
+ )
1262
+
1263
+ # Compute the percentage of tokens routed to each experts
1264
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
1265
+ expert_attention_mask, dim=0
1266
+ )
1267
+
1268
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
1269
+ router_per_expert_attention_mask = (
1270
+ attention_mask[None, :, :, None]
1271
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
1272
+ .reshape(-1, num_experts)
1273
+ .to(compute_device)
1274
+ )
1275
+
1276
+ # Compute the average probability of routing to these experts
1277
+ router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
1278
+ router_per_expert_attention_mask, dim=0
1279
+ )
1280
+
1281
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
1282
+ return overall_loss * num_experts
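The docstring above describes the Switch Transformer auxiliary loss (equations 4-6 of the paper); a minimal sketch of calling the function with random router logits, shapes as in the docstring:

import torch

num_layers, tokens, num_experts, top_k = 2, 8, 4, 2
gate_logits = tuple(torch.randn(tokens, num_experts) for _ in range(num_layers))
aux_loss = load_balancing_loss_func(gate_logits, num_experts=num_experts, top_k=top_k)
print(aux_loss)  # scalar tensor, already scaled by num_experts as in the return statement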
1283
+
1284
  ################################ Pretrained Mistral MODEL ##############################
1285
+ MIXTRAL_START_DOCSTRING = r"""
1286
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1287
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1288
+ etc.)
1289
 
1290
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1291
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1292
+ and behavior.
1293
+
1294
+ Parameters:
1295
+ config ([`MixtralConfig`]):
1296
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
1297
+ load the weights associated with the model, only the configuration. Check out the
1298
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1299
+ """
1300
+ @add_start_docstrings(
1301
+ "The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
1302
+ MIXTRAL_START_DOCSTRING,
1303
+ )
1304
+ # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Mixtral
1305
+ class MixtralPreTrainedModel(PreTrainedModel):
1306
+ config_class = MixtralConfig
1307
+ base_model_prefix = "model"
1308
+ supports_gradient_checkpointing = True
1309
+ _no_split_modules = ["MixtralDecoderLayer"]
1310
+ _skip_keys_device_placement = "past_key_values"
1311
+ _supports_flash_attn_2 = False
1312
+ _supports_sdpa = True
1313
+ _supports_cache_class = True
1314
+
1315
+ def _init_weights(self, module):
1316
+ std = self.config.initializer_range
1317
+ if isinstance(module, nn.Linear):
1318
+ module.weight.data.normal_(mean=0.0, std=std)
1319
+ if module.bias is not None:
1320
+ module.bias.data.zero_()
1321
+ elif isinstance(module, nn.Embedding):
1322
+ module.weight.data.normal_(mean=0.0, std=std)
1323
+ if module.padding_idx is not None:
1324
+ module.weight.data[module.padding_idx].zero_()
1325
  @add_start_docstrings(
1326
  "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
1327
  MISTRAL_START_DOCSTRING,
 
1577
  use_cache: bool,
1578
  output_attentions: bool,
1579
  ):
1580
+
1581
  # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
1582
  # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
1583
  # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
 
1675
  causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1676
 
1677
  return causal_mask
1678
+ @add_start_docstrings(
1679
+ "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
1680
+ MISTRAL_START_DOCSTRING,
1681
+ )
1682
+ class MistralModel(MistralPreTrainedModel):
1683
+ """
1684
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`]
1685
 
1686
+ Args:
1687
+ config: MistralConfig
1688
+ """
1689
 
1690
+ def __init__(self, config: MistralConfig):
1691
+ super().__init__(config)
1692
+ self.padding_idx = config.pad_token_id
1693
+ self.vocab_size = config.vocab_size
 
1694
 
1695
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1696
+ self.layers = nn.ModuleList(
1697
+ [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1698
+ )
1699
+ self._attn_implementation = config._attn_implementation
1700
+ self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1701
 
1702
+ self.gradient_checkpointing = False
1703
+ # Initialize weights and apply final processing
1704
+ self.post_init()
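A minimal sketch of instantiating the decoder stack whose constructor appears above, with a deliberately tiny configuration; the hyperparameter values are illustrative only:

from transformers import MistralConfig

tiny_cfg = MistralConfig(
    vocab_size=1000, hidden_size=64, intermediate_size=128,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2,
)
decoder = MistralModel(tiny_cfg)  # the class added back in this hunk
print(sum(p.numel() for p in decoder.parameters()))  # parameter count of the toy model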
1705
 
1706
+ def get_input_embeddings(self):
1707
+ return self.embed_tokens
1708
 
1709
+ def set_input_embeddings(self, value):
1710
+ self.embed_tokens = value
 
1711
 
1712
+ @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
1713
+ def forward(
1714
+ self,
1715
+ input_ids: torch.LongTensor = None,
1716
+ attention_mask: Optional[torch.Tensor] = None,
1717
+ position_ids: Optional[torch.LongTensor] = None,
1718
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1719
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1720
+ use_cache: Optional[bool] = None,
1721
+ output_attentions: Optional[bool] = None,
1722
+ output_hidden_states: Optional[bool] = None,
1723
+ return_dict: Optional[bool] = None,
1724
+ cache_position: Optional[torch.LongTensor] = None,
1725
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1726
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1727
+ output_hidden_states = (
1728
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1729
+ )
1730
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1731
 
1732
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1733
 
1734
+ # retrieve input_ids and inputs_embeds
1735
+ if (input_ids is None) ^ (inputs_embeds is not None):
1736
+ raise ValueError(
1737
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
1738
+ )
1739
 
1740
+ if self.gradient_checkpointing and self.training and use_cache:
1741
+ logger.warning_once(
1742
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1743
+ )
1744
+ use_cache = False
1745
 
1746
+ if inputs_embeds is None:
1747
+ inputs_embeds = self.embed_tokens(input_ids)
1748
 
1749
+ return_legacy_cache = False
1750
+ if use_cache and not isinstance(past_key_values, Cache):
1751
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1752
+ return_legacy_cache = True
1753
+ logger.warning_once(
1754
+ "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
1755
+ "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
1756
+ )
1757
 
1758
+ if cache_position is None:
1759
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1760
+ cache_position = torch.arange(
1761
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1762
+ )
1763
 
1764
+ if position_ids is None:
1765
+ position_ids = cache_position.unsqueeze(0)
1766
+
1767
+ causal_mask = self._update_causal_mask(
1768
+ attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions
 
1769
  )
1770
 
1771
+ hidden_states = inputs_embeds
1772
+
1773
+ # decoder layers
1774
+ all_hidden_states = () if output_hidden_states else None
1775
+ all_self_attns = () if output_attentions else None
1776
+ next_decoder_cache = None
1777
+
1778
+ for decoder_layer in self.layers:
1779
+ if output_hidden_states:
1780
+ all_hidden_states += (hidden_states,)
1781
+
1782
+ if self.gradient_checkpointing and self.training:
1783
+ layer_outputs = self._gradient_checkpointing_func(
1784
+ decoder_layer.__call__,
1785
+ hidden_states,
1786
+ causal_mask,
1787
+ position_ids,
1788
+ past_key_values,
1789
+ output_attentions,
1790
+ use_cache,
1791
+ cache_position,
1792
+ )
1793
+ else:
1794
+ layer_outputs = decoder_layer(
1795
+ hidden_states,
1796
+ attention_mask=causal_mask,
1797
+ position_ids=position_ids,
1798
+ past_key_value=past_key_values,
1799
+ output_attentions=output_attentions,
1800
+ use_cache=use_cache,
1801
+ cache_position=cache_position,
1802
+ )
1803
+
1804
+ hidden_states = layer_outputs[0]
1805
+
1806
+ if use_cache:
1807
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1808
+
1809
+ if output_attentions:
1810
+ all_self_attns += (layer_outputs[1],)
1811
+
1812
+ hidden_states = self.norm(hidden_states)
1813
+
1814
+ # add hidden states from the last decoder layer
1815
+ if output_hidden_states:
1816
+ all_hidden_states += (hidden_states,)
1817
+
1818
+ next_cache = next_decoder_cache if use_cache else None
1819
+ if return_legacy_cache:
1820
+ next_cache = next_cache.to_legacy_cache()
1821
+
1822
+ if not return_dict:
1823
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1824
+ return BaseModelOutputWithPast(
1825
+ last_hidden_state=hidden_states,
1826
+ past_key_values=next_cache,
1827
+ hidden_states=all_hidden_states,
1828
+ attentions=all_self_attns,
1829
  )
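The forward pass above accepts either a Cache object or a legacy tuple of past key/values and converts the latter via DynamicCache.from_legacy_cache. A short sketch of that conversion in isolation:

from transformers import DynamicCache

legacy = None                                   # or a per-layer tuple of (key, value) tensors
cache = DynamicCache.from_legacy_cache(legacy)  # what the forward pass does internally
print(cache.get_seq_length())                   # 0 for an empty cache
legacy_again = cache.to_legacy_cache()          # reverse step used when return_legacy_cache is True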
1830
 
1831
+ def _update_causal_mask(
1832
+ self,
1833
+ attention_mask: torch.Tensor,
1834
+ input_tensor: torch.Tensor,
1835
+ cache_position: torch.Tensor,
1836
+ past_key_values: Cache,
1837
+ use_cache: bool,
1838
+ output_attentions: bool,
1839
+ ):
1840
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
1841
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
1842
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
1843
 
1844
+ if self._attn_implementation == "flash_attention_2":
1845
+ if attention_mask is not None and use_cache:
1846
+ is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
1847
+ if is_padding_right:
1848
+ raise ValueError(
1849
+ "You are attempting to perform batched generation with padding_side='right'"
1850
+ " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
1851
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1852
+ )
1853
+ if attention_mask is not None and 0.0 in attention_mask:
1854
+ return attention_mask
1855
+ return None
1856
 
1857
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1858
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1859
+ # to infer the attention mask.
1860
 
1861
+ # cache_position must be valid here no matter which cache we use
1862
+ past_seen_tokens = cache_position[0] if past_key_values is not None else 0
1863
+ using_static_cache = isinstance(past_key_values, StaticCache)
1864
+ using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
1865
 
1866
+ if (
1867
+ self.config._attn_implementation == "sdpa"
1868
+ and not (using_static_cache or using_sliding_window_cache)
1869
+ and not output_attentions
1870
+ ):
1871
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1872
+ attention_mask,
1873
+ inputs_embeds=input_tensor,
1874
+ past_key_values_length=past_seen_tokens,
1875
+ sliding_window=self.config.sliding_window,
1876
+ is_training=self.training,
1877
+ ):
1878
+ return None
1879
+
1880
+ dtype, device = input_tensor.dtype, input_tensor.device
1881
+ min_dtype = torch.finfo(dtype).min
1882
+ sequence_length = input_tensor.shape[1]
1883
+ # SlidingWindowCache
1884
+ if using_sliding_window_cache:
1885
+ target_length = max(sequence_length, self.config.sliding_window)
1886
+ # StaticCache
1887
+ elif using_static_cache:
1888
+ target_length = past_key_values.get_max_length()
1889
+ # DynamicCache or no cache
1890
+ else:
1891
+ target_length = (
1892
+ attention_mask.shape[-1]
1893
+ if isinstance(attention_mask, torch.Tensor)
1894
+ else past_seen_tokens + sequence_length + 1
1895
+ )
1896
+
1897
+ if attention_mask is not None and attention_mask.dim() == 4:
1898
+ # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
1899
+ if attention_mask.max() != 0:
1900
+ raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
1901
+ causal_mask = attention_mask
1902
+ else:
1903
+ causal_mask = torch.full(
1904
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
1905
+ )
1906
+ exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1907
+ if self.config.sliding_window is not None:
1908
+ if not using_sliding_window_cache or sequence_length > self.config.sliding_window:
1909
+ exclude_mask.bitwise_or_(
1910
+ torch.arange(target_length, device=device)
1911
+ <= (cache_position.reshape(-1, 1) - self.config.sliding_window)
1912
+ )
1913
+ causal_mask *= exclude_mask
1914
+ causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
1915
+ if attention_mask is not None:
1916
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1917
+ if attention_mask.dim() == 2:
1918
+ mask_length = attention_mask.shape[-1]
1919
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
1920
+ padding_mask = padding_mask == 0
1921
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
1922
+ padding_mask, min_dtype
1923
+ )
1924
+
1925
+ if (
1926
+ self.config._attn_implementation == "sdpa"
1927
+ and attention_mask is not None
1928
+ and attention_mask.device.type == "cuda"
1929
+ and not output_attentions
1930
+ ):
1931
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1932
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1933
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1934
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1935
+
1936
+ return causal_mask
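A standalone sketch of the sliding-window causal mask built in _update_causal_mask above, with a short sequence so the banded pattern is visible; the window size and length are illustrative:

import torch

seq_len, window = 6, 3
min_dtype = torch.finfo(torch.float32).min
cache_position = torch.arange(seq_len)
causal_mask = torch.full((seq_len, seq_len), fill_value=min_dtype)
exclude = torch.arange(seq_len) > cache_position.reshape(-1, 1)  # block future positions
exclude.bitwise_or_(torch.arange(seq_len) <= (cache_position.reshape(-1, 1) - window))  # block positions outside the window
causal_mask *= exclude
print((causal_mask == 0).int())  # 1 where attention is allowed: a lower-triangular band of width 3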
1937
  MIXTRAL_INPUTS_DOCSTRING = r"""
1938
  Args:
1939
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
 
2007
  "The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
2008
  MIXTRAL_START_DOCSTRING,
2009
  )
 
2010
  class MixtralModel(MixtralPreTrainedModel):
2011
  """
2012
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`]
 
2253
 
2254
 
2255
  ################################ H-E-A-D-S : ##############################
2256
+
2257
+ ################################ CausalLM ##############################
2258
  class MixtralForCausalLM(MixtralPreTrainedModel):
2259
  _tied_weights_keys = ["lm_head.weight"]
2260
 
 
2700
 
2701
  def __init__(self, config):
2702
  super().__init__(config)
2703
+ self.model = MistralStarModel(config)
2704
  self.vocab_size = config.vocab_size
2705
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
2706
  self.max_thoughts = config.max_thoughts
 
3999
  )
4000
  return reordered_past
4001
 
4002
+ ############################## Extra Heads #################################
4003
  @add_start_docstrings(
4004
  """
4005
  The Mistral Model transformer with a sequence classification head on top (linear layer).