@add_start_docstrings(
    "The bare Gemma2 Model outputting raw hidden-states without any specific head on top.",
    GEMMA2_START_DOCSTRING,
)
class Gemma2Model(Gemma2PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Gemma2DecoderLayer`]

    Args:
        config: Gemma2Config
    """

    def __init__(self, config: Gemma2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[HybridCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

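        # `(a is None) ^ (b is not None)` is True exactly when both or neither of `input_ids` / `inputs_embeds`
        # are provided, so this check enforces that exactly one of them is passed.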
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

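        # At inference time, lazily build a `HybridCache` sized to the current sequence when caching is requested
        # but no cache object was passed in.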
        if use_cache and past_key_values is None and not self.training:
            batch_size, seq_len, _ = inputs_embeds.shape
            past_key_values = HybridCache(
                self.config,
                batch_size=batch_size,
                max_cache_len=seq_len,
                device=self.device,
                dtype=inputs_embeds.dtype,
            )

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        # embed positions
        hidden_states = inputs_embeds

        # normalized
        # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = past_key_values if use_cache else None

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: HybridCache,
        output_attentions: bool,
    ):
        # Flash Attention currently doesn't support static cache, but Gemma2 works only with a static cache.
        # So we will pass in the attention mask as-is in any case, not only when there's padding. Then we'll use its
        # shape to cut out the trailing zeros that pad keys/values in the static cache. This workaround should be
        # compile compatible as it doesn't cause dynamic control-flow issues.
        if self.config._attn_implementation == "flash_attention_2":
            return attention_mask

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if isinstance(past_key_values, HybridCache):
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]

        # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )
        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with a static cache, the mask should be as long as the static
                cache to account for the zero padding (the part of the cache that is not yet filled).
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
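            # Net effect of the line below: key positions strictly beyond each token's cache position stay masked,
            # while positions up to and including it are zeroed out, i.e. left attendable.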
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask


class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = Gemma2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[HybridCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only the last token's logits are needed for generation, and computing them
                only for that token can save memory, which becomes significant for long sequences or large vocabulary
                sizes.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Gemma2ForCausalLM

        >>> model = Gemma2ForCausalLM.from_pretrained("google/gemma-2-9b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""

        if self.training and self.config._attn_implementation != "eager":
            logger.warning_once(
                "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
                f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
            )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
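        # When `num_logits_to_keep == 0`, the slice below is `[:, -0:, :]`, i.e. `[:, 0:, :]`, so logits are
        # computed for every position.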
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
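        # Final logit soft-capping: equivalent to `softcap * tanh(logits / softcap)`, which smoothly bounds the
        # logits to the interval (-softcap, +softcap).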
        if self.config.final_logit_softcapping is not None:
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        num_logits_to_keep=None,
        **kwargs,
    ):
        # Overwritten: has a special cache type, `HybridCache`

        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        if past_key_values is not None:
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]
                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
                # `mode="reduce-overhead"`, as otherwise the input `position_ids` would have varying strides
                # during decoding. Here, simply using `.contiguous()` is not sufficient as in the
                # batch size = 1 case, `position_ids` is already contiguous but with varying stride
                # which retriggers a capture.
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            # The clone here is for the same reason as for `position_ids`.
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

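        # With a `HybridCache` (static shape) and a 2D attention mask, pre-expand the mask to 4D here so that its
        # key/value dimension matches the full static cache length expected by the attention layers.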
        if (
            isinstance(past_key_values, HybridCache)
            and attention_mask.ndim == 2
            and not self.config._attn_implementation == "flash_attention_2"
        ):
            if model_inputs["inputs_embeds"] is not None:
                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
                device = model_inputs["inputs_embeds"].device
            else:
                batch_size, sequence_length = model_inputs["input_ids"].shape
                device = model_inputs["input_ids"].device

            attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
                attention_mask,
                sequence_length=sequence_length,
                target_length=past_key_values.get_max_cache_shape(),
                dtype=self.lm_head.weight.dtype,
                device=device,
                cache_position=cache_position,
                batch_size=batch_size,
            )

        if num_logits_to_keep is not None:
            model_inputs["num_logits_to_keep"] = num_logits_to_keep

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs


@add_start_docstrings(
    """
    The Gemma2 Model transformer with a sequence classification head on top (linear layer).

    [`Gemma2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it needs to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (takes the last value in
    each row of the batch).
    """,
    GEMMA2_START_DOCSTRING,
)
class Gemma2ForSequenceClassification(Gemma2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Gemma2Model(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), if
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
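                # `argmax` returns the first pad position; subtracting 1 gives the last non-pad token. If a row
                # contains no padding, `argmax` returns 0 and the -1 wraps (via the modulo) to the last position.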
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@add_start_docstrings(
    """
    The Gemma2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
    output) e.g. for Named-Entity-Recognition (NER) tasks.
    """,
    GEMMA2_START_DOCSTRING,
)
class Gemma2ForTokenClassification(Gemma2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Gemma2Model(config)
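        # Dropout for the classification head: prefer `classifier_dropout`, fall back to `hidden_dropout`,
        # and default to 0.1 if neither is set in the config.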
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.config)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )