aslawliet committed
Commit 9cbbe5a
1 Parent(s): c8c482a

Create modeling_tcss.py

Files changed (1): modeling_tcss.py  +536 -0

modeling_tcss.py ADDED
@@ -0,0 +1,536 @@
# -*- coding: utf-8 -*-

from copy import deepcopy

from transformers.models.llama.modeling_llama import *
from transformers.modeling_outputs import TokenClassifierOutput


_CONFIG_FOR_DOC = "LlamaConfig"


# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

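# Worked example (for reference; shapes chosen for illustration): with
# input_ids_shape = (1, 3), dtype = torch.float32 and past_key_values_length = 0,
# _make_causal_mask returns a (1, 1, 3, 3) additive mask whose last two dimensions are
#     [[0, m, m],
#      [0, 0, m],
#      [0, 0, 0]]        where m = torch.finfo(torch.float32).min
# i.e. each position may only attend to itself and earlier positions.
# UnmaskingLlamaModel below skips this step and feeds an all-zero mask instead,
# which makes the self-attention bidirectional.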
class UnmaskingLlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        # causal mask (disabled)
        '''
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )
        print('unmasking attention mask:')
        print(attention_mask)
        '''
        # remove the causal mask: an all-zero additive mask lets every token
        # attend to every other token, i.e. self-attention becomes bidirectional
        attention_mask = torch.zeros(
            (batch_size, 1, seq_length, seq_length), device=inputs_embeds.device
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, past_key_value, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

@add_start_docstrings(
    """
    The LLaMa Model transformer with a sequence classification head on top (linear layer).

    [`UnmaskingLlamaForSequenceClassification`] pools the per-token logits with `'last'`, `'max'`, or `'mean'`
    pooling (the default is `'mean'`; it can be changed with `set_pooling`). With `'last'` pooling it uses the last
    token in order to do the classification, as other causal models (e.g. GPT-2) do.

    Since last-token pooling needs the position of the last token, the model has to locate it. If a `pad_token_id`
    is defined in the configuration, it finds the last token that is not a padding token in each row. If no
    `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    LLAMA_START_DOCSTRING,
)
class UnmaskingLlamaForSequenceClassification(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = UnmaskingLlamaModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        self.pooling = 'mean'
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def set_pooling(self, pooling):
        self.pooling = pooling

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
                    logits.device
                )
            else:
                sequence_lengths = -1

        if self.pooling == 'last':
            pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
        elif self.pooling == 'max':
            pooled_logits, _ = torch.max(logits, dim=1)
        elif self.pooling == 'mean':
            pooled_logits = torch.mean(logits, dim=1)
        else:
            raise NotImplementedError

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

@add_start_docstrings(
    """
    The LLaMa Model transformer with a token classification head on top (a linear layer on top of the hidden-states
    output), e.g. for Named-Entity-Recognition (NER) tasks. This variant keeps the standard causal [`LlamaModel`]
    backbone and predicts one label per input token.
    """,
    LLAMA_START_DOCSTRING,
)
class LlamaForTokenClassification(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

@add_start_docstrings(
    """
    The LLaMa Model transformer with a token classification head on top (a linear layer on top of the hidden-states
    output), e.g. for Named-Entity-Recognition (NER) tasks. This variant uses the [`UnmaskingLlamaModel`] backbone,
    whose causal mask is removed so every token attends to the full sequence (bidirectional attention).
    """,
    LLAMA_START_DOCSTRING,
)
class UnmaskingLlamaForTokenClassification(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = UnmaskingLlamaModel(config)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
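
For reference, a minimal usage sketch of the classes defined above; the base checkpoint name, label counts, and example text are illustrative placeholders, not values taken from this file:

import torch
from transformers import AutoTokenizer

from modeling_tcss import (
    UnmaskingLlamaForSequenceClassification,
    UnmaskingLlamaForTokenClassification,
)

base_model = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint name

tokenizer = AutoTokenizer.from_pretrained(base_model)
tokenizer.pad_token = tokenizer.eos_token  # LLaMA checkpoints ship without a pad token

# Sequence classification: per-token logits from the unmasked (bidirectional) backbone
# are pooled; the default pooling is 'mean' and can be switched with set_pooling().
seq_model = UnmaskingLlamaForSequenceClassification.from_pretrained(
    base_model, num_labels=2, pad_token_id=tokenizer.pad_token_id
)
seq_model.set_pooling("mean")  # or "last" / "max"

inputs = tokenizer("an example sentence", return_tensors="pt")
with torch.no_grad():
    seq_logits = seq_model(**inputs).logits  # shape: (batch_size, num_labels)

# Token classification: one label per input token.
tok_model = UnmaskingLlamaForTokenClassification.from_pretrained(
    base_model, num_labels=9, pad_token_id=tokenizer.pad_token_id
)
with torch.no_grad():
    tok_logits = tok_model(**inputs).logits  # shape: (batch_size, seq_len, num_labels)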