Upload modeling_mistral.py
modeling_mistral.py (CHANGED: +41, -100)
@@ -414,7 +414,7 @@ class MistralStarConfig(PretrainedConfig):
     >>> configuration = model.config
     ```"""

-    model_type = "
+    model_type = "mistralstar"
     keys_to_ignore_at_inference = ["past_key_values"]

     def __init__(
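With `model_type` now set to `"mistralstar"`, the config only resolves through the Auto classes if that type string is registered. A minimal sketch of the wiring, assuming the class names exported by this file (`MistralStarConfig`, `MistralModel`); the `register` calls are standard `transformers` APIs, shown here as an assumption about how the repo is meant to be used:

```python
from transformers import AutoConfig, AutoModel

# Assumed import path: both classes live in this repo's modeling_mistral.py.
from modeling_mistral import MistralStarConfig, MistralModel

# Map the custom model_type string to the config, and the config to the model,
# so AutoConfig/AutoModel.from_pretrained can resolve "mistralstar" checkpoints.
AutoConfig.register("mistralstar", MistralStarConfig)
AutoModel.register(MistralStarConfig, MistralModel)
```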
@@ -661,6 +661,7 @@ class MistralPreTrainedModel(PreTrainedModel):
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()

+
 @add_start_docstrings(
     "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
     MISTRAL_START_DOCSTRING,
@@ -673,7 +674,7 @@ class MistralModel(MistralPreTrainedModel):
         config: MistralConfig
     """

-    def __init__(self, config):
+    def __init__(self, config: MistralConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
@@ -694,8 +695,6 @@ class MistralModel(MistralPreTrainedModel):

     def set_input_embeddings(self, value):
         self.embed_tokens = value
-
-

     @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
     def forward(
@@ -703,12 +702,13 @@ class MistralModel(MistralPreTrainedModel):
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
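The updated signature accepts either a `Cache` object or the legacy tuple format for `past_key_values`, plus an explicit `cache_position`. A minimal usage sketch of the same forward contract, using the upstream `transformers` Mistral classes (v4.41+) and a tiny random config purely for illustration; the repo's class is assumed to behave the same way:

```python
import torch
from transformers import MistralConfig, MistralModel, DynamicCache

# Tiny randomly initialised model, only to exercise the forward contract above.
config = MistralConfig(
    vocab_size=128, hidden_size=64, intermediate_size=128,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2,
)
model = MistralModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 5))
past_key_values = DynamicCache()                   # Cache object instead of legacy tuples
cache_position = torch.arange(input_ids.shape[1])  # absolute positions of the new tokens

with torch.no_grad():
    out = model(input_ids=input_ids, past_key_values=past_key_values,
                cache_position=cache_position, use_cache=True)

# Decode one more token: feed only the new id and advance cache_position by one.
next_ids = torch.randint(0, config.vocab_size, (1, 1))
cache_position = cache_position[-1:] + 1
with torch.no_grad():
    out = model(input_ids=next_ids, past_key_values=out.past_key_values,
                cache_position=cache_position, use_cache=True)
```

`cache_position` holds the absolute indices of the tokens fed in the current call, so it advances by one per decoded token while the cache keeps growing.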
@@ -719,73 +719,42 @@ class MistralModel(MistralPreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # retrieve input_ids and inputs_embeds
-        if input_ids is
-            raise ValueError(
-
-
-        elif inputs_embeds is not None:
-            batch_size, seq_length, _ = inputs_embeds.shape
-        else:
-            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
-        past_key_values_length = 0
-
-        if use_cache:
-            use_legacy_cache = not isinstance(past_key_values, Cache)
-            if use_legacy_cache:
-                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_usable_length(seq_length)
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+            )

-        if
-
-
-                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
             )
-
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
+            use_cache = False

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-
-
-
-
-
-
-
-            )
-
-        if self._attn_implementation == "flash_attention_2":
-            # 2d mask is passed through the layers
-            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._attn_implementation == "sdpa" and not output_attentions and attention_mask.dim() == 2 and False:
-            # output_attentions=True can not be supported when using SDPA, and we fall back on
-            # the manual implementation that requires a 4D causal mask in all cases.
-            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                attention_mask,
-                (batch_size, seq_length),
-                inputs_embeds,
-                past_key_values_length,
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            return_legacy_cache = True
+            logger.warning_once(
+                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
+                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
             )
-
-
-
-
-
-                inputs_embeds,
-                past_key_values_length,
-                sliding_window=self.config.sliding_window,
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
             )

+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions
+        )
+
         hidden_states = inputs_embeds

         # decoder layers
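The rewritten prologue drops the old `seq_length` / `past_key_values_length` bookkeeping: legacy tuple caches are converted to `DynamicCache`, `cache_position` is derived from how many tokens the cache already holds, `position_ids` defaults to `cache_position`, and mask construction is delegated to `self._update_causal_mask`. A standalone sketch of that bookkeeping (an illustrative helper, not a function defined in this file):

```python
import torch
from transformers.cache_utils import Cache, DynamicCache

def prepare_cache_inputs(inputs_embeds, past_key_values=None, cache_position=None, position_ids=None):
    """Restate the bookkeeping above: convert a legacy tuple cache to DynamicCache,
    then derive cache_position and position_ids from the tokens already cached."""
    return_legacy_cache = False
    if past_key_values is not None and not isinstance(past_key_values, Cache):
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        return_legacy_cache = True

    if cache_position is None:
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )
    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)
    return past_key_values, cache_position, position_ids, return_legacy_cache


# A 4-token prompt with no cache is placed at positions 0..3; with 6 tokens already
# cached, the same prompt would be placed at positions 6..9.
embeds = torch.zeros(1, 4, 64)
_, cache_position, position_ids, _ = prepare_cache_inputs(embeds)
print(cache_position.tolist())  # [0, 1, 2, 3]
print(position_ids.shape)       # torch.Size([1, 4])
```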
@@ -801,20 +770,22 @@ class MistralModel(MistralPreTrainedModel):
                 layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
-
+                    causal_mask,
                     position_ids,
                     past_key_values,
                     output_attentions,
                     use_cache,
+                    cache_position,
                 )
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
-                    attention_mask=
+                    attention_mask=causal_mask,
                     position_ids=position_ids,
                     past_key_value=past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
+                    cache_position=cache_position,
                 )

             hidden_states = layer_outputs[0]
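In the checkpointed branch the new `causal_mask` and `cache_position` are threaded through positionally, because `_gradient_checkpointing_func` in recent `transformers` releases is essentially a partial of `torch.utils.checkpoint.checkpoint`, which replays `fn(*args)` during backward; the non-checkpointed call keeps keyword arguments. A minimal sketch of that mechanism (the `layer` function is a stand-in, not a class from this file):

```python
import torch
from torch.utils.checkpoint import checkpoint

# Stand-in for a decoder layer; not a class from this file.
def layer(hidden_states, attention_mask, position_ids):
    return hidden_states * 2

x = torch.randn(1, 4, 8, requires_grad=True)

# checkpoint(fn, *args) recomputes fn(*args) during backward, so every layer input
# is forwarded as a plain positional argument -- mirroring the checkpointed branch above.
out = checkpoint(layer, x, None, None, use_reentrant=False)
out.sum().backward()
print(x.grad.shape)  # torch.Size([1, 4, 8])
```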
@@ -831,9 +802,9 @@ class MistralModel(MistralPreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)

-        next_cache = None
-        if
-            next_cache =
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()

         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
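Because the forward converts legacy tuples to a `DynamicCache` on the way in, it converts back with `to_legacy_cache()` on the way out whenever `return_legacy_cache` was set, so callers receive the same format they supplied. A small round-trip sketch with made-up shapes:

```python
import torch
from transformers.cache_utils import DynamicCache

# Legacy format: one (key, value) tuple per layer, each (batch, heads, seq, head_dim).
batch, heads, seq, head_dim = 1, 2, 3, 4
legacy = tuple(
    (torch.zeros(batch, heads, seq, head_dim), torch.zeros(batch, heads, seq, head_dim))
    for _ in range(2)
)

cache = DynamicCache.from_legacy_cache(legacy)  # what forward() does on the way in
print(cache.get_seq_length())                   # 3

restored = cache.to_legacy_cache()              # what forward() does on the way out
print(len(restored), restored[0][0].shape)      # 2 torch.Size([1, 2, 3, 4])
```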
@@ -853,7 +824,7 @@ class MistralModel(MistralPreTrainedModel):
         use_cache: bool,
         output_attentions: bool,
     ):
-
+        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
         # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
         # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
@@ -952,43 +923,13 @@ class MistralModel(MistralPreTrainedModel):

         return causal_mask

-############################## LM Heads #################################
-

-
-    if not is_torch_available():
-        raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
-    pass
-else:
-    _import_structure["_SpydazWebAI_Mistral_Transformer_"] = [
-        "MistralForCausalLM",
-        "MistralModel",
-        "MistralPreTrainedModel",
-        "MistralForSequenceClassification",
-        "MistralForTokenClassification",
-    ]
+############################## LM Heads #################################

-    pass



-if TYPE_CHECKING:
-    from .configuration_mistral import MistralConfig

-    try:
-        if not is_torch_available():
-            raise OptionalDependencyNotAvailable()
-    except OptionalDependencyNotAvailable:
-        pass
-    else:
-        from .modeling_mistral import (
-            MistralForCausalLM,
-            MistralForSequenceClassification,
-            MistralForTokenClassification,
-            MistralModel,
-            MistralPreTrainedModel,
-        )

 ################################ Tokenizer ##############################
 class MistralTokenizer(PreTrainedTokenizer):