babylm committed
Commit 4d7c080
1 parent: c6f1442

add support for sequence classification

Files changed (3)
  1. config.json +44 -42
  2. configuration_flamingo.py +1 -1
  3. modeling_flamingo.py +128 -26
config.json CHANGED
@@ -1,43 +1,45 @@
 {
-  "_name_or_path": "facebook/opt-125m",
-  "_remove_final_layer_norm": false,
-  "activation_dropout": 0.0,
-  "activation_function": "relu",
-  "architectures": [
-    "FlamingoForCausalLM"
-  ],
-  "auto_map": {
-    "AutoConfig": "configuration_flamingo.FlamingoConfig",
-    "AutoModelForCausalLM": "modeling_flamingo.FlamingoForCausalLM"
-  },
-  "attention_dropout": 0.0,
-  "bos_token_id": 2,
-  "cross_attn_every": 2,
-  "do_layer_norm_before": true,
-  "dropout": 0.1,
-  "enable_bias": true,
-  "eos_token_id": 2,
-  "ffn_dim": 3072,
-  "finetune_LM": true,
-  "hidden_size": 768,
-  "id_perceiver": false,
-  "init_std": 0.02,
-  "inp_dim": 768,
-  "layer_norm_elementwise_affine": true,
-  "layerdrop": 0.0,
-  "max_position_embeddings": 2048,
-  "media_token_id": 32768,
-  "model_type": "opt",
-  "num_attention_heads": 12,
-  "num_hidden_layers": 12,
-  "only_attend_immediate_media": true,
-  "pad_token_id": 1,
-  "perceiver_depth": 2,
-  "perceiver_num_latents": 64,
-  "prefix": "</s>",
-  "torch_dtype": "float32",
-  "transformers_version": "4.29.0",
-  "use_cache": true,
-  "vocab_size": 32778,
-  "word_embed_proj_dim": 768
-}
+  "_name_or_path": "facebook/opt-125m",
+  "_remove_final_layer_norm": false,
+  "activation_dropout": 0.0,
+  "activation_function": "relu",
+  "architectures": [
+    "FlamingoForCausalLM"
+  ],
+  "auto_map": {
+    "AutoConfig": "configuration_flamingo.FlamingoConfig",
+    "AutoModelForCausalLM": "modeling_flamingo.FlamingoForCausalLM",
+    "AutoModelForSequenceClassification": "modeling_flamingo.FlamingoForSequenceClassification"
+  },
+  "attention_dropout": 0.0,
+  "bos_token_id": 2,
+  "cross_attn_every": 2,
+  "do_layer_norm_before": true,
+  "dropout": 0.1,
+  "enable_bias": true,
+  "eos_token_id": 2,
+  "ffn_dim": 3072,
+  "finetune_LM": true,
+  "hidden_size": 768,
+  "id_perceiver": false,
+  "init_std": 0.02,
+  "inp_dim": 768,
+  "layer_norm_elementwise_affine": true,
+  "layerdrop": 0.0,
+  "max_position_embeddings": 2048,
+  "media_token_id": 32768,
+  "model_type": "opt",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "only_attend_immediate_media": true,
+  "pad_token_id": 1,
+  "perceiver_depth": 2,
+  "perceiver_num_latents": 64,
+  "prefix": "</s>",
+  "torch_dtype": "float32",
+  "transformers_version": "4.29.0",
+  "use_cache": true,
+  "vocab_size": 32778,
+  "word_embed_proj_dim": 768
+}
+
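With the `AutoModelForSequenceClassification` entry added to `auto_map`, the checkpoint can now be loaded through the Auto API. A minimal usage sketch, assuming the files are hosted in a Hub repo; the repo id below is a placeholder, and `num_labels=2` is an illustrative choice rather than something this commit sets:

```python
from transformers import AutoConfig, AutoModelForSequenceClassification

repo_id = "<user>/<flamingo-checkpoint>"  # placeholder for wherever these files live

# trust_remote_code makes transformers import configuration_flamingo.py and
# modeling_flamingo.py from the repo instead of a built-in architecture.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
config.num_labels = 2  # illustrative: a binary classification task

model = AutoModelForSequenceClassification.from_pretrained(
    repo_id, config=config, trust_remote_code=True
)
```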
configuration_flamingo.py CHANGED
@@ -32,4 +32,4 @@ class FlamingoConfig(configuration_opt.OPTConfig, dict):
             self, vocab_size=vocab_size, **kwargs)
         self.media_token_id = media_token_id
         self.cross_attn_every = cross_attn_every
-        dict.__init__(self, **self.__dict__)
+        dict.__init__(self, **self.__dict__)
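For context on the line above: `FlamingoConfig` inherits from both `OPTConfig` and `dict`, and `dict.__init__(self, **self.__dict__)` copies the attributes accumulated so far into the mapping side, so the config answers to both attribute and key access. A toy sketch of the pattern, not the repo's code:

```python
# Toy illustration of the config-that-is-also-a-dict pattern.
class DualAccessConfig(dict):
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)          # attribute access: cfg.hidden_size
        dict.__init__(self, **self.__dict__)  # key access: cfg["hidden_size"]

cfg = DualAccessConfig(hidden_size=768, cross_attn_every=2)
assert cfg.hidden_size == cfg["hidden_size"] == 768
```

One caveat of the pattern: attributes assigned after `__init__` are not mirrored into the dict view, so the two access paths can drift apart.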
modeling_flamingo.py CHANGED
@@ -7,9 +7,9 @@ import os
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss, MSELoss
 
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 import transformers.models.opt.modeling_opt as modeling_opt
 from transformers.models.opt.modeling_opt\
     import OPTDecoderLayer, OPTPreTrainedModel, OPTConfig
@@ -46,7 +46,6 @@ class OPTLearnedPositionalEmbedding(nn.Embedding):
 class OPTDecoder(modeling_opt.OPTDecoder):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`]
-
     Args:
         config: OPTConfig
         embed_tokens (nn.Embedding): output embedding
@@ -136,35 +135,26 @@ class OPTDecoder(modeling_opt.OPTDecoder):
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                 Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                 provide it.
-
                 Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                 [`PreTrainedTokenizer.__call__`] for details.
-
                 [What are input IDs?](../glossary#input-ids)
             attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                 Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
                 - 1 for tokens that are **not masked**,
                 - 0 for tokens that are **masked**.
-
                 [What are attention masks?](../glossary#attention-mask)
             head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
                 Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
-
                 - 1 indicates the head is **not masked**,
                 - 0 indicates the head is **masked**.
-
             past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                 Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                 shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
-
                 Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                 cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
-
                 If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                 that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                 all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
-
             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                 Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                 This is useful if you want more control over how to convert `input_ids` indices into associated vectors
@@ -405,33 +395,25 @@ class FlamingoForCausalLM(modeling_opt.OPTForCausalLM):
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                 Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                 provide it.
-
                 Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                 [`PreTrainedTokenizer.__call__`] for details.
-
                 [What are input IDs?](../glossary#input-ids)
             attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                 Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
                 - 1 for tokens that are **not masked**,
                 - 0 for tokens that are **masked**.
-
                 [What are attention masks?](../glossary#attention-mask)
             head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
                 Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
-
                 - 1 indicates the head is **not masked**,
                 - 0 indicates the head is **masked**.
-
             past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                 Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                 shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
                 shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
                 tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
-
                 Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                 cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
-
                 If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                 that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                 all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
@@ -454,20 +436,14 @@ class FlamingoForCausalLM(modeling_opt.OPTForCausalLM):
                 for more detail.
             return_dict (`bool`, *optional*):
                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-
         Returns:
-
         Example:
-
         ```python
         >>> from transformers import GPT2Tokenizer, OPTForCausalLM
-
         >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
         >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
-
         >>> prompt = "Hey, are you consciours? Can you talk to me?"
         >>> inputs = tokenizer(prompt, return_tensors="pt")
-
         >>> # Generate
         >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
@@ -514,3 +490,129 @@ class FlamingoForCausalLM(modeling_opt.OPTForCausalLM):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+
+
+class FlamingoForSequenceClassification(OPTPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [
+        r"score.weight",
+    ]
+
+    def __init__(self, config: OPTConfig):
+        OPTPreTrainedModel.__init__(self, config)
+        config = setup_default_flamingo_configs(config)
+        self.num_labels = config.num_labels
+        self.model = OPTModel(config)
+
+        # the lm_head weight is automatically tied to the embed tokens weight
+        self.score = nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+        self.model.decoder.img_encoder = None
+        self.loss_fct = CrossEntropyLoss()
+        dino_model = ViTModel.from_pretrained("facebook/dino-vitb16")
+        self.setup_vis_encoder(dino_model)
+
+    def setup_vis_encoder(self, img_encoder):
+        self.model.decoder.img_encoder = img_encoder
+        freeze_all_layers_(img_encoder)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        *args, **kwargs) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model.decoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            *args, **kwargs)
+
+        hidden_states = outputs[0]
+        logits = self.score(hidden_states)
+
+        if input_ids is not None:
+            batch_size, sequence_length = input_ids.shape[:2]
+        else:
+            batch_size, sequence_length = inputs_embeds.shape[:2]
+
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
+            else:
+                sequence_lengths = -1
+                # logger.warning(
+                #     f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                #     "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+                # )
+
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
+
+        if not return_dict:
+            output = (pooled_logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def get_input_embeddings(self):
+        return self.model.decoder.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.decoder.embed_tokens = value
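The new head follows the upstream `OPTForSequenceClassification` recipe: score every position, then pool the logits at each sequence's last non-padding token, located with the pad-token `argmax`-and-modulo trick commented in the code. A self-contained sketch of just that indexing, with made-up values:

```python
import torch

pad_token_id = 1  # matches "pad_token_id": 1 in config.json
input_ids = torch.tensor([
    [5, 6, 7, 1, 1],   # three real tokens, then padding
    [5, 6, 7, 8, 9],   # no padding at all
])
logits = torch.randn(2, 5, 3)  # (batch, seq_len, num_labels)

# argmax finds the first pad position; subtracting 1 gives the last real token.
# With no pad token, argmax returns 0, and (0 - 1) % 5 == 4 selects the final position.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
pooled_logits = logits[torch.arange(2), sequence_lengths]  # shape (2, num_labels)
```

The modulo keeps the indexing free of negative values, which is why the original comment mentions ONNX compatibility.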