add support for sequence classification
Files changed:
- config.json: +44 -42
- configuration_flamingo.py: +1 -1
- modeling_flamingo.py: +128 -26
config.json
CHANGED (the 42 removed lines of the previous configuration are not rendered in the diff; the additions below are the new file contents)
@@ -1,43 +1,45 @@
 {
+  "_name_or_path": "facebook/opt-125m",
+  "_remove_final_layer_norm": false,
+  "activation_dropout": 0.0,
+  "activation_function": "relu",
+  "architectures": [
+    "FlamingoForCausalLM"
+  ],
+  "auto_map": {
+    "AutoConfig": "configuration_flamingo.FlamingoConfig",
+    "AutoModelForCausalLM": "modeling_flamingo.FlamingoForCausalLM",
+    "AutoModelForSequenceClassification": "modeling_flamingo.FlamingoForSequenceClassification"
+  },
+  "attention_dropout": 0.0,
+  "bos_token_id": 2,
+  "cross_attn_every": 2,
+  "do_layer_norm_before": true,
+  "dropout": 0.1,
+  "enable_bias": true,
+  "eos_token_id": 2,
+  "ffn_dim": 3072,
+  "finetune_LM": true,
+  "hidden_size": 768,
+  "id_perceiver": false,
+  "init_std": 0.02,
+  "inp_dim": 768,
+  "layer_norm_elementwise_affine": true,
+  "layerdrop": 0.0,
+  "max_position_embeddings": 2048,
+  "media_token_id": 32768,
+  "model_type": "opt",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "only_attend_immediate_media": true,
+  "pad_token_id": 1,
+  "perceiver_depth": 2,
+  "perceiver_num_latents": 64,
+  "prefix": "</s>",
+  "torch_dtype": "float32",
+  "transformers_version": "4.29.0",
+  "use_cache": true,
+  "vocab_size": 32778,
+  "word_embed_proj_dim": 768
+}
+
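With the new `AutoModelForSequenceClassification` entry in `auto_map`, the classification head can be loaded through the Auto classes once `trust_remote_code=True` is passed. A minimal usage sketch, assuming a placeholder repository id that contains this config together with `configuration_flamingo.py` and `modeling_flamingo.py`:

```python
# Sketch only: "your-namespace/flamingo-opt-125m" is a placeholder repo id,
# not a checkpoint named in this commit.
from transformers import AutoConfig, AutoModelForSequenceClassification

repo_id = "your-namespace/flamingo-opt-125m"

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
config.num_labels = 2  # e.g. a binary classification task

model = AutoModelForSequenceClassification.from_pretrained(
    repo_id,
    config=config,
    trust_remote_code=True,  # required so auto_map resolves to modeling_flamingo.py
)
```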
configuration_flamingo.py
CHANGED
@@ -32,4 +32,4 @@ class FlamingoConfig(configuration_opt.OPTConfig, dict):
             self, vocab_size=vocab_size, **kwargs)
         self.media_token_id = media_token_id
         self.cross_attn_every = cross_attn_every
-        dict.__init__(self, **self.__dict__)
+        dict.__init__(self, **self.__dict__)
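The re-added `dict.__init__(self, **self.__dict__)` line is what lets `FlamingoConfig` double as both an `OPTConfig` and a plain `dict`: every attribute set before that call is also copied into the mapping. A small sketch of that behaviour (run from a checkout of this repo so `configuration_flamingo` is importable; the argument values are illustrative):

```python
from configuration_flamingo import FlamingoConfig

# Attributes assigned in __init__ (media_token_id, cross_attn_every, ...) end up
# both as attributes and as dict entries because of dict.__init__(self, **self.__dict__).
config = FlamingoConfig(vocab_size=32778, media_token_id=32768, cross_attn_every=2)

assert config.cross_attn_every == 2      # attribute access, OPTConfig style
assert config["cross_attn_every"] == 2   # mapping access, dict style
assert isinstance(config, dict)
```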
modeling_flamingo.py
CHANGED
@@ -7,9 +7,9 @@ import os
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss, MSELoss
 
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 import transformers.models.opt.modeling_opt as modeling_opt
 from transformers.models.opt.modeling_opt\
     import OPTDecoderLayer, OPTPreTrainedModel, OPTConfig
@@ -46,7 +46,6 @@ class OPTLearnedPositionalEmbedding(nn.Embedding):
 class OPTDecoder(modeling_opt.OPTDecoder):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`]
-
     Args:
         config: OPTConfig
         embed_tokens (nn.Embedding): output embedding
@@ -136,35 +135,26 @@ class OPTDecoder(modeling_opt.OPTDecoder):
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                 Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                 provide it.
-
                 Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                 [`PreTrainedTokenizer.__call__`] for details.
-
                 [What are input IDs?](../glossary#input-ids)
             attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                 Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
                 - 1 for tokens that are **not masked**,
                 - 0 for tokens that are **masked**.
-
                 [What are attention masks?](../glossary#attention-mask)
             head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
                 Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
-
                 - 1 indicates the head is **not masked**,
                 - 0 indicates the head is **masked**.
-
             past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                 Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                 shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
-
                 Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                 cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
-
                 If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                 that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                 all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
-
             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                 Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                 This is useful if you want more control over how to convert `input_ids` indices into associated vectors
@@ -405,33 +395,25 @@ class FlamingoForCausalLM(modeling_opt.OPTForCausalLM):
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                 Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                 provide it.
-
                 Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                 [`PreTrainedTokenizer.__call__`] for details.
-
                 [What are input IDs?](../glossary#input-ids)
             attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                 Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
                 - 1 for tokens that are **not masked**,
                 - 0 for tokens that are **masked**.
-
                 [What are attention masks?](../glossary#attention-mask)
             head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
                 Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
-
                 - 1 indicates the head is **not masked**,
                 - 0 indicates the head is **masked**.
-
             past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                 Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                 shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
                 shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
                 tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
-
                 Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                 cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
-
                 If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                 that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                 all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
@@ -454,20 +436,14 @@ class FlamingoForCausalLM(modeling_opt.OPTForCausalLM):
                 for more detail.
             return_dict (`bool`, *optional*):
                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-
         Returns:
-
         Example:
-
         ```python
         >>> from transformers import GPT2Tokenizer, OPTForCausalLM
-
         >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
         >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
-
         >>> prompt = "Hey, are you consciours? Can you talk to me?"
         >>> inputs = tokenizer(prompt, return_tensors="pt")
-
         >>> # Generate
         >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
@@ -514,3 +490,129 @@ class FlamingoForCausalLM(modeling_opt.OPTForCausalLM):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+
+
+class FlamingoForSequenceClassification(OPTPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [
+        r"score.weight",
+    ]
+
+    def __init__(self, config: OPTConfig):
+        OPTPreTrainedModel.__init__(self, config)
+        config = setup_default_flamingo_configs(config)
+        self.num_labels = config.num_labels
+        self.model = OPTModel(config)
+
+        # the lm_head weight is automatically tied to the embed tokens weight
+        self.score = nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+        self.model.decoder.img_encoder = None
+        self.loss_fct = CrossEntropyLoss()
+        dino_model = ViTModel.from_pretrained("facebook/dino-vitb16")
+        self.setup_vis_encoder(dino_model)
+
+    def setup_vis_encoder(self, img_encoder):
+        self.model.decoder.img_encoder = img_encoder
+        freeze_all_layers_(img_encoder)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        *args, **kwargs) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model.decoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            *args, **kwargs)
+
+        hidden_states = outputs[0]
+        logits = self.score(hidden_states)
+
+        if input_ids is not None:
+            batch_size, sequence_length = input_ids.shape[:2]
+        else:
+            batch_size, sequence_length = inputs_embeds.shape[:2]
+
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
+            else:
+                sequence_lengths = -1
+                # logger.warning(
+                #     f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                #     "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+                # )
+
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
+
+        if not return_dict:
+            output = (pooled_logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def get_input_embeddings(self):
+        return self.model.decoder.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.decoder.embed_tokens = value
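Like the stock `OPTForSequenceClassification`, the new head scores every position and then keeps only the logits of the last non-padding token in each sequence; the pad search uses `argmax` plus a modulo instead of reverse indexing so the graph stays ONNX-friendly, and the loss (MSE, cross-entropy, or BCE-with-logits) is chosen from `config.problem_type`. A standalone sketch of the pooling step with toy tensors (shapes and ids are illustrative, not taken from the commit):

```python
import torch

# Toy batch: pad_token_id = 1, as in the config.json above.
pad_token_id = 1
input_ids = torch.tensor([[5, 6, 7, 1, 1],    # 3 real tokens, then padding
                          [8, 9, 2, 3, 4]])   # no padding at all
batch_size, seq_len = input_ids.shape
num_labels = 3
logits = torch.randn(batch_size, seq_len, num_labels)  # stand-in for self.score(hidden_states)

# Index of the first pad token minus one = position of the last real token.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
# Rows with no padding give argmax 0 and hence -1; the modulo maps that to
# seq_len - 1 (the final position) without using negative indexing.
sequence_lengths = sequence_lengths % seq_len

pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
print(pooled_logits.shape)  # torch.Size([2, 3]): one logit vector per sequence
```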