LeroyDyer committed on
Commit 2384d95
1 Parent(s): cf7541a

Upload modeling_mistral.py

Files changed (1)
  1. modeling_mistral.py +41 -100
modeling_mistral.py CHANGED
@@ -414,7 +414,7 @@ class MistralStarConfig(PretrainedConfig):
     >>> configuration = model.config
     ```"""
 
-    model_type = "mistral"
+    model_type = "mistralstar"
     keys_to_ignore_at_inference = ["past_key_values"]
 
     def __init__(
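The only functional change in this hunk is the new `model_type` string. For the Auto classes to resolve a custom `model_type`, it has to be registered; the snippet below is a minimal registration sketch, assuming `MistralStarConfig`, `MistralModel`, and `MistralForCausalLM` are all defined in this file as the diff suggests (the registration itself is not part of the commit).

# Hypothetical registration sketch: maps the new "mistralstar" model_type
# onto the classes defined in this file.
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from modeling_mistral import MistralStarConfig, MistralModel, MistralForCausalLM

AutoConfig.register("mistralstar", MistralStarConfig)
AutoModel.register(MistralStarConfig, MistralModel)
AutoModelForCausalLM.register(MistralStarConfig, MistralForCausalLM)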
@@ -661,6 +661,7 @@ class MistralPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
+
 @add_start_docstrings(
     "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
     MISTRAL_START_DOCSTRING,
@@ -673,7 +674,7 @@ class MistralModel(MistralPreTrainedModel):
         config: MistralConfig
     """
 
-    def __init__(self, config):
+    def __init__(self, config: MistralConfig):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -694,8 +695,6 @@ class MistralModel(MistralPreTrainedModel):
 
     def set_input_embeddings(self, value):
         self.embed_tokens = value
-
-
 
     @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
     def forward(
@@ -703,12 +702,13 @@ class MistralModel(MistralPreTrainedModel):
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
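The new signature follows the `Cache`-based KV-cache API: `past_key_values` may now be a `Cache` object rather than a legacy tuple, and `cache_position` tells the model which absolute positions the current tokens occupy. A minimal usage sketch, assuming a causal-LM checkpoint built from this repository (the checkpoint name is a placeholder):

# Hypothetical decode call showing the Cache-based arguments.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

# Placeholder repo id; trust_remote_code=True is needed because the modeling code is custom.
repo = "some/mistralstar-checkpoint"
tok = AutoTokenizer.from_pretrained(repo)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

inputs = tok("Hello", return_tensors="pt")
past = DynamicCache()                                    # Cache object instead of a legacy tuple
cache_position = torch.arange(inputs.input_ids.shape[1])  # absolute positions of the prefill tokens

with torch.no_grad():
    out = model(
        input_ids=inputs.input_ids,
        past_key_values=past,
        cache_position=cache_position,
        use_cache=True,
    )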
@@ -719,73 +719,42 @@ class MistralModel(MistralPreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         # retrieve input_ids and inputs_embeds
-        if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-        elif input_ids is not None:
-            batch_size, seq_length = input_ids.shape
-        elif inputs_embeds is not None:
-            batch_size, seq_length, _ = inputs_embeds.shape
-        else:
-            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
-        past_key_values_length = 0
-
-        if use_cache:
-            use_legacy_cache = not isinstance(past_key_values, Cache)
-            if use_legacy_cache:
-                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_usable_length(seq_length)
-
-        if position_ids is None:
-            device = input_ids.device if input_ids is not None else inputs_embeds.device
-            position_ids = torch.arange(
-                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
-            )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+            )
+
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+            )
+            use_cache = False
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
-            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
-            if is_padding_right:
-                raise ValueError(
-                    "You are attempting to perform batched generation with padding_side='right'"
-                    " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
-                    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
-                )
-
-        if self._attn_implementation == "flash_attention_2":
-            # 2d mask is passed through the layers
-            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._attn_implementation == "sdpa" and not output_attentions and attention_mask.dim() == 2 and False:
-            # output_attentions=True can not be supported when using SDPA, and we fall back on
-            # the manual implementation that requires a 4D causal mask in all cases.
-            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                attention_mask,
-                (batch_size, seq_length),
-                inputs_embeds,
-                past_key_values_length,
-            )
-        elif attention_mask is None or attention_mask.dim() == 2:
-            # 4d mask is passed through the layers
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask,
-                (batch_size, seq_length),
-                inputs_embeds,
-                past_key_values_length,
-                sliding_window=self.config.sliding_window,
-            )
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            return_legacy_cache = True
+            logger.warning_once(
+                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
+                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
+            )
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions
+        )
+
         hidden_states = inputs_embeds
 
         # decoder layers
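The rewritten body derives `cache_position` from how many tokens the cache already holds, takes `position_ids` directly from it, and builds the attention mask in one place via `_update_causal_mask`. A standalone sketch of the position bookkeeping, assuming a `DynamicCache` (the token counts are illustrative):

# Sketch of the cache_position / position_ids bookkeeping done in the new forward body.
import torch
from transformers import DynamicCache

past = DynamicCache()
seq_len = 5                                                      # prefill with 5 tokens
past_seen = past.get_seq_length()                                # 0 on the first call
cache_position = torch.arange(past_seen, past_seen + seq_len)    # tensor([0, 1, 2, 3, 4])
position_ids = cache_position.unsqueeze(0)                       # shape (1, 5)

# after the prefill the cache holds 5 tokens; the next single-token decode step gets:
past_seen = 5
cache_position = torch.arange(past_seen, past_seen + 1)          # tensor([5])
position_ids = cache_position.unsqueeze(0)                       # shape (1, 1)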
@@ -801,20 +770,22 @@ class MistralModel(MistralPreTrainedModel):
                 layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
-                    attention_mask,
+                    causal_mask,
                     position_ids,
                     past_key_values,
                     output_attentions,
                     use_cache,
+                    cache_position,
                 )
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
-                    attention_mask=attention_mask,
+                    attention_mask=causal_mask,
                     position_ids=position_ids,
                     past_key_value=past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
+                    cache_position=cache_position,
                 )
 
             hidden_states = layer_outputs[0]
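The gradient-checkpointing branch forwards the layer arguments positionally, so the new `cache_position` must be appended in exactly the order the decoder layer's `forward` expects. A toy illustration of that constraint using `torch.utils.checkpoint` directly (the toy layer is not from the commit):

# Toy illustration: checkpoint() passes *positional* args only, so argument order
# must match the wrapped callable's signature exactly.
import torch
from torch.utils.checkpoint import checkpoint


def toy_layer(hidden_states, scale, offset):
    return hidden_states * scale + offset


h = torch.randn(2, 4, requires_grad=True)
out = checkpoint(toy_layer, h, 2.0, 1.0, use_reentrant=False)  # order: hidden_states, scale, offset
out.sum().backward()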
@@ -831,9 +802,9 @@ class MistralModel(MistralPreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        next_cache = None
-        if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
 
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
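The returned cache is converted back to the legacy tuple format only when the caller passed a tuple in. A minimal round-trip sketch of the two conversion helpers (the shapes are illustrative):

# Round-trip between the legacy tuple-of-tuples format and the Cache API.
import torch
from transformers import DynamicCache

# legacy format: one (key, value) pair per layer, each (batch, num_kv_heads, seq_len, head_dim)
legacy = tuple(
    (torch.randn(1, 8, 5, 128), torch.randn(1, 8, 5, 128)) for _ in range(2)
)

cache = DynamicCache.from_legacy_cache(legacy)   # tuple -> Cache
assert cache.get_seq_length() == 5

back = cache.to_legacy_cache()                   # Cache -> tuple
assert len(back) == 2 and back[0][0].shape == (1, 8, 5, 128)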
@@ -853,7 +824,7 @@ class MistralModel(MistralPreTrainedModel):
         use_cache: bool,
         output_attentions: bool,
     ):
-
+        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
         # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
         # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
@@ -952,43 +923,13 @@ class MistralModel(MistralPreTrainedModel):
 
         return causal_mask
 
-############################## LM Heads #################################
-
 
-try:
-    if not is_torch_available():
-        raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
-    pass
-else:
-    _import_structure["_SpydazWebAI_Mistral_Transformer_"] = [
-        "MistralForCausalLM",
-        "MistralModel",
-        "MistralPreTrainedModel",
-        "MistralForSequenceClassification",
-        "MistralForTokenClassification",
-    ]
 
-pass
+############################## LM Heads #################################
 
 
-if TYPE_CHECKING:
-    from .configuration_mistral import MistralConfig
 
-try:
-    if not is_torch_available():
-        raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
-    pass
-else:
-    from .modeling_mistral import (
-        MistralForCausalLM,
-        MistralForSequenceClassification,
-        MistralForTokenClassification,
-        MistralModel,
-        MistralPreTrainedModel,
-    )
 
 
 ################################ Tokenizer ##############################
 class MistralTokenizer(PreTrainedTokenizer):
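The deleted block is package-level lazy-import boilerplate (`_import_structure` / `TYPE_CHECKING`) that does not belong in a modeling file; by Transformers convention it lives in the package `__init__.py`. A minimal sketch of that conventional placement, assuming a local package shipping `configuration_mistral.py` and this `modeling_mistral.py` (module and symbol names are illustrative, not part of the commit):

# Hypothetical package __init__.py showing where the removed lazy-import block would normally live.
import sys

from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available

_import_structure = {"configuration_mistral": ["MistralConfig"]}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_mistral"] = [
        "MistralModel",
        "MistralPreTrainedModel",
        "MistralForCausalLM",
        "MistralForSequenceClassification",
        "MistralForTokenClassification",
    ]

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)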
 