@add_start_docstrings(
    "The bare Gemma2 Model outputting raw hidden-states without any specific head on top.",
    GEMMA2_START_DOCSTRING,
)
class Gemma2Model(Gemma2PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Gemma2DecoderLayer`]

    Args:
        config: Gemma2Config
    """

    def __init__(self, config: Gemma2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[HybridCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None and not self.training:
            batch_size, seq_len, _ = inputs_embeds.shape
            past_key_values = HybridCache(
                self.config,
                batch_size=batch_size,
                max_cache_len=seq_len,
                device=self.device,
                dtype=inputs_embeds.dtype,
            )
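            # Gemma2 alternates sliding-window and global attention layers, so a `HybridCache` is
            # allocated lazily here when the caller did not pass one. Illustrative sketch of an
            # explicit pre-allocation (mirrors the call above; assumes a `model` instance and a
            # prompt length `max_len` are available):
            #
            #     cache = HybridCache(
            #         model.config, batch_size=1, max_cache_len=max_len, device=model.device, dtype=model.dtype
            #     )
            #     outputs = model(input_ids, past_key_values=cache, use_cache=True)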

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        # embed positions
        hidden_states = inputs_embeds

        # normalized
        # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer
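        # i.e. the token embeddings are scaled by sqrt(hidden_size) before entering the decoder stack
        # (for example, sqrt(3072) ≈ 55.43, as noted in the precision comment above).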

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = past_key_values if use_cache else None

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
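
    # Minimal usage sketch for the bare model (illustrative only; any Gemma2 checkpoint works the same
    # way, "google/gemma-2-9b" matches the docstring example further below):
    #
    #     from transformers import AutoTokenizer, Gemma2Model
    #
    #     tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
    #     model = Gemma2Model.from_pretrained("google/gemma-2-9b")
    #     inputs = tokenizer("Hello world", return_tensors="pt")
    #     hidden = model(**inputs).last_hidden_state  # (batch_size, seq_len, config.hidden_size)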

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: HybridCache,
        output_attentions: bool,
    ):
        # Flash Attention currently doesn't support static cache, but Gemma2 works only with static cache.
        # So we pass the attention mask as is in any case, not only when there's padding. Then we use its shape
        # to cut out the trailing zero padding of the keys/values used by the static cache. This workaround should
        # be compile-compatible as it doesn't cause dynamic control flow issues.
        if self.config._attn_implementation == "flash_attention_2":
            return attention_mask

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if isinstance(past_key_values, HybridCache):
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]

        # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )
        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or, if the input `attention_mask` is already 4D, uses it as is.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding (the part of the cache that is not filled yet).
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )
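
        # Worked example (illustrative): with `sequence_length=3`, `target_length=5`,
        # `cache_position=[0, 1, 2]` and no padding, each query row of the additive mask is
        # (`0` = attend, `m` = torch.finfo(dtype).min):
        #
        #     [[0, m, m, m, m],
        #      [0, 0, m, m, m],
        #      [0, 0, 0, m, m]]
        #
        # i.e. standard causal masking plus masking of the still-unfilled tail of the static cache.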

        return causal_mask


class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = Gemma2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[HybridCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only the last token's logits are needed for generation, and calculating
                them only for that token can save memory, which becomes quite significant for long sequences or a
                large vocabulary size.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Gemma2ForCausalLM

        >>> model = Gemma2ForCausalLM.from_pretrained("google/gemma-2-9b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""

        if self.training and self.config._attn_implementation != "eager":
            logger.warning_once(
                "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
                f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
            )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
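        # Note: when `num_logits_to_keep == 0` the slice above is `[:, -0:, :]`, i.e. the full sequence,
        # so logits are computed for every position (the "special case" described in the docstring).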
        if self.config.final_logit_softcapping is not None:
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping
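            # Equivalent to: logits = cap * tanh(logits / cap) with cap = config.final_logit_softcapping,
            # which smoothly bounds every logit to the open interval (-cap, cap).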

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        num_logits_to_keep=None,
        **kwargs,
    ):
        # Overwritten: has a special cache type, `HybridCache`

        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        if past_key_values is not None:
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]
                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
                # `mode="reduce-overhead"`, as otherwise the input `position_ids` would have a varying stride
                # during decoding. Here, simply using `.contiguous()` is not sufficient as in the
                # batch size = 1 case, `position_ids` is already contiguous but with varying stride
                # which retriggers a capture.
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            # The clone here is for the same reason as for `position_ids`.
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

        if (
            isinstance(past_key_values, HybridCache)
            and attention_mask.ndim == 2
            and not self.config._attn_implementation == "flash_attention_2"
        ):
            if model_inputs["inputs_embeds"] is not None:
                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
                device = model_inputs["inputs_embeds"].device
            else:
                batch_size, sequence_length = model_inputs["input_ids"].shape
                device = model_inputs["input_ids"].device

            attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
                attention_mask,
                sequence_length=sequence_length,
                target_length=past_key_values.get_max_cache_shape(),
                dtype=self.lm_head.weight.dtype,
                device=device,
                cache_position=cache_position,
                batch_size=batch_size,
            )

        if num_logits_to_keep is not None:
            model_inputs["num_logits_to_keep"] = num_logits_to_keep

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs
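
    # Illustrative usage sketch: `generate()` calls `prepare_inputs_for_generation` once per decoding
    # step, so with a `HybridCache` only the tokens selected by `cache_position` are fed back each step.
    # A typical call only needs the usual high-level entry point, e.g.
    #
    #     inputs = tokenizer("What is your favorite condiment?", return_tensors="pt")
    #     generate_ids = model.generate(**inputs, max_new_tokens=20)
    #
    # (assumes `model`/`tokenizer` loaded as in the docstring example of `forward` above).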


@add_start_docstrings(
    """
    The Gemma2 Model transformer with a sequence classification head on top (linear layer).

    [`Gemma2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it needs to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (takes the last value in
    each row of the batch).
    """,
    GEMMA2_START_DOCSTRING,
)
class Gemma2ForSequenceClassification(Gemma2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Gemma2Model(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token is found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
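        # Illustrative example: with `pad_token_id=0` and `input_ids = [[5, 7, 0, 0]]`, the first pad token
        # is at index 2, so `sequence_lengths == [1]` and the logits at position 1 (the last non-padding
        # token) are used as the pooled logits for that sequence.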

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@add_start_docstrings(
    """
    The Gemma2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
    output) e.g. for Named-Entity-Recognition (NER) tasks.
    """,
    GEMMA2_START_DOCSTRING,
)
class Gemma2ForTokenClassification(Gemma2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Gemma2Model(config)
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.config)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )