wissamantoun committed
Commit 3fed02b
1 parent: 0266e41

Update backend/modeling_gpt2.py

Files changed (1)
  1. backend/modeling_gpt2.py +650 -436
backend/modeling_gpt2.py CHANGED
@@ -13,51 +13,54 @@
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
  # See the License for the specific language governing permissions and
15
  # limitations under the License.
 
16
 
17
- """
18
- PyTorch OpenAI GPT-2 model.
19
- Adapted from https://github.com/huggingface/transformers/blob/v4.0.1/src/transformers/models/gpt2/modeling_gpt2.py
20
- and https://github.com/ghosthamlet/gpt2-ml-torch/blob/master/gpt2_ml_torch/modeling_gpt2.py
21
- """
22
-
23
-
24
- import logging
25
  import os
26
  from dataclasses import dataclass
27
- from typing import List, Optional, Tuple
28
 
29
  import torch
30
- import torch.nn as nn
31
- from torch.nn import CrossEntropyLoss, MSELoss
32
- from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model
33
  from transformers.activations import ACT2FN
34
- from transformers.file_utils import (
35
- ModelOutput,
36
- add_code_sample_docstrings,
37
- add_start_docstrings,
38
- add_start_docstrings_to_model_forward,
39
- replace_return_docstrings,
40
- )
41
  from transformers.modeling_outputs import (
42
  BaseModelOutputWithPastAndCrossAttentions,
43
  CausalLMOutputWithCrossAttentions,
44
  SequenceClassifierOutputWithPast,
45
  TokenClassifierOutput,
46
  )
47
- from transformers.modeling_utils import (
 
48
  Conv1D,
49
- PreTrainedModel,
50
- SequenceSummary,
51
  find_pruneable_heads_and_indices,
52
  prune_conv1d_layer,
53
  )
54
  from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
 
55
 
56
- # THe Difference from Transformers is code under _USE_GROVER
57
- _USE_GROVER = True
58
 
59
- logger = logging.getLogger(__name__)
60
 
 
61
  _CONFIG_FOR_DOC = "GPT2Config"
62
  _TOKENIZER_FOR_DOC = "GPT2Tokenizer"
63
 
@@ -70,11 +73,6 @@ GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
70
  # See all GPT-2 models at https://huggingface.co/models?filter=gpt2
71
  ]
72
 
73
- logger.setLevel(logging.INFO)
74
- console = logging.StreamHandler()
75
- console.setLevel(logging.INFO)
76
- logger.addHandler(console)
77
-
78
  _GPT2_ML_TF_TO_TORCH = {
79
  "LayerNorm_embed_norm": "emb_norm",
80
  "pos_embed": "wpe.weight",
@@ -126,7 +124,6 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
126
  """Load tf checkpoints in a pytorch model"""
127
  try:
128
  import re
129
-
130
  import tensorflow as tf
131
  except ImportError:
132
  logger.error(
@@ -206,11 +203,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
206
  d = torch.from_numpy(array)
207
  is_bias = len(shape) == 1
208
  end = int(shape[0 if is_bias else 1] / 3)
209
- m = dict(
210
- query_layer=0,
211
- key_layer=end,
212
- value_layer=end * 2,
213
- )
214
  start = m[attn_layer]
215
  end = start + end
216
  if is_bias:
@@ -232,39 +225,54 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
232
  return model
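The slicing logic kept by this hunk splits the fused attention tensor from the gpt2-ml TF checkpoint into query/key/value blocks. A minimal standalone sketch of that split, with made-up shapes (nothing below is taken from an actual checkpoint):

import torch

# Hypothetical fused c_attn weight: the last dimension holds [query | key | value]
# blocks of equal width, as in the gpt2-ml checkpoints handled above.
n_embd = 8
fused_weight = torch.randn(n_embd, 3 * n_embd)

end = fused_weight.shape[1] // 3                      # width of one block
offsets = dict(query_layer=0, key_layer=end, value_layer=end * 2)

start = offsets["key_layer"]                          # e.g. pull out the key block
key_weight = fused_weight[:, start:start + end]
assert key_weight.shape == (n_embd, n_embd)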
233
 
234
 
235
- class Attention(nn.Module):
236
- def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False):
237
  super().__init__()
238
 
239
- n_state = nx # in Attention: n_state=768 (nx=n_embd)
240
- # [switch nx => n_state from Block to Attention to keep identical to TF implem]
241
- assert n_state % config.n_head == 0
242
  self.register_buffer(
243
  "bias",
244
- torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(
245
- 1, 1, n_ctx, n_ctx
246
- ),
247
  )
248
  self.register_buffer("masked_bias", torch.tensor(-1e4))
249
- self.n_head = config.n_head
250
- self.split_size = n_state
251
- self.scale = scale
252
  self.is_cross_attention = is_cross_attention
253
  if self.is_cross_attention:
254
- self.c_attn = Conv1D(2 * n_state, nx)
255
- self.q_attn = Conv1D(n_state, nx)
256
  else:
257
- self.c_attn = Conv1D(3 * n_state, nx)
258
- self.c_proj = Conv1D(n_state, nx)
 
259
  self.attn_dropout = nn.Dropout(config.attn_pdrop)
260
  self.resid_dropout = nn.Dropout(config.resid_pdrop)
 
261
  self.pruned_heads = set()
262
 
263
  def prune_heads(self, heads):
264
  if len(heads) == 0:
265
  return
266
  heads, index = find_pruneable_heads_and_indices(
267
- heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
268
  )
269
  index_attn = torch.cat(
270
  [index, index + self.split_size, index + (2 * self.split_size)]
@@ -275,67 +283,163 @@ class Attention(nn.Module):
275
  self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
276
 
277
  # Update hyper params
278
- self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
279
- self.n_head = self.n_head - len(heads)
 
 
280
  self.pruned_heads = self.pruned_heads.union(heads)
281
 
282
- def _attn(
283
- self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False
284
- ):
285
- w = torch.matmul(q, k)
286
- if self.scale:
287
- w = w / (float(v.size(-1)) ** 0.5)
288
- nd, ns = w.size(-2), w.size(-1)
 
 
289
 
290
  if not self.is_cross_attention:
291
  # if only "normal" attention layer implements causal mask
292
- mask = self.bias[:, :, ns - nd : ns, :ns]
293
- w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
294
 
295
  if attention_mask is not None:
296
  # Apply the attention mask
297
- w = w + attention_mask
298
 
299
- w = nn.Softmax(dim=-1)(w)
300
- w = self.attn_dropout(w)
 
 
 
301
 
302
  # Mask heads if we want to
303
  if head_mask is not None:
304
- w = w * head_mask
305
 
306
- outputs = [torch.matmul(w, v)]
307
- if output_attentions:
308
- outputs.append(w)
309
- return outputs
310
-
311
- def merge_heads(self, x):
312
- x = x.permute(0, 2, 1, 3).contiguous()
313
- new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
314
- return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
315
-
316
- def split_heads(self, x, k=False):
317
- new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
318
- x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
319
- if k:
320
- return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length)
321
  else:
322
- return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
323
 
324
  def forward(
325
  self,
326
- hidden_states,
327
- layer_past=None,
328
- attention_mask=None,
329
- head_mask=None,
330
- encoder_hidden_states=None,
331
- encoder_attention_mask=None,
332
- use_cache=False,
333
- output_attentions=False,
334
- ):
335
  if encoder_hidden_states is not None:
336
- assert hasattr(
337
- self, "q_attn"
338
- ), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."
 
 
 
339
  query = self.q_attn(hidden_states)
340
  key, value = self.c_attn(encoder_hidden_states).split(
341
  self.split_size, dim=2
@@ -344,80 +448,97 @@ class Attention(nn.Module):
344
  else:
345
  query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
346
 
347
- query = self.split_heads(query)
348
- key = self.split_heads(key, k=True)
349
- value = self.split_heads(value)
 
350
  if layer_past is not None:
351
- past_key, past_value = (
352
- layer_past[0].transpose(-2, -1),
353
- layer_past[1],
354
- ) # transpose back cf below
355
- key = torch.cat((past_key, key), dim=-1)
356
  value = torch.cat((past_value, value), dim=-2)
357
 
358
  if use_cache is True:
359
- present = torch.stack(
360
- (key.transpose(-2, -1), value)
361
- ) # transpose to have same shapes for stacking
362
  else:
363
- present = (None,)
364
 
365
- attn_outputs = self._attn(
366
- query, key, value, attention_mask, head_mask, output_attentions
367
- )
368
- a = attn_outputs[0]
369
 
370
- a = self.merge_heads(a)
371
- a = self.c_proj(a)
372
- a = self.resid_dropout(a)
373
 
374
- outputs = [a, present] + attn_outputs[1:]
375
  return outputs # a, present, (attentions)
376
 
377
 
378
- class MLP(nn.Module):
379
- def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
380
  super().__init__()
381
- nx = config.n_embd
382
- self.c_fc = Conv1D(n_state, nx)
383
- self.c_proj = Conv1D(nx, n_state)
384
  self.act = ACT2FN[config.activation_function]
385
  self.dropout = nn.Dropout(config.resid_pdrop)
386
 
387
- def forward(self, x):
388
- h = self.act(self.c_fc(x))
389
- h2 = self.c_proj(h)
390
- return self.dropout(h2)
391
 
392
 
393
- class Block(nn.Module):
394
- def __init__(self, n_ctx, config, scale=False):
395
  super().__init__()
396
- hidden_size = config.n_embd
397
  inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
 
398
  self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
399
- self.attn = Attention(hidden_size, n_ctx, config, scale)
400
  self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
 
401
  if config.add_cross_attention:
402
- self.crossattention = Attention(
403
- hidden_size, n_ctx, config, scale, is_cross_attention=True
404
  )
405
  self.ln_cross_attn = nn.LayerNorm(
406
  hidden_size, eps=config.layer_norm_epsilon
407
  )
408
- self.mlp = MLP(inner_dim, config)
 
409
 
410
  def forward(
411
  self,
412
- hidden_states,
413
- layer_past=None,
414
- attention_mask=None,
415
- head_mask=None,
416
- encoder_hidden_states=None,
417
- encoder_attention_mask=None,
418
- use_cache=False,
419
- output_attentions=False,
420
- ):
421
  attn_outputs = self.attn(
422
  hidden_states,
423
  layer_past=layer_past,
@@ -433,11 +554,16 @@ class Block(nn.Module):
433
 
434
  if encoder_hidden_states is not None:
435
  # add one self-attention block for cross-attention
436
- assert hasattr(
437
- self, "crossattention"
438
- ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
439
  cross_attn_outputs = self.crossattention(
440
- self.ln_cross_attn(hidden_states),
441
  attention_mask=attention_mask,
442
  head_mask=head_mask,
443
  encoder_hidden_states=encoder_hidden_states,
@@ -446,18 +572,24 @@ class Block(nn.Module):
446
  )
447
  attn_output = cross_attn_outputs[0]
448
  # residual connection
449
- hidden_states = hidden_states + attn_output
450
  outputs = (
451
  outputs + cross_attn_outputs[2:]
452
  ) # add cross attentions if we output attention weights
453
 
454
- feed_forward_hidden_states = self.mlp(self.ln_1(hidden_states))
 
 
455
  # residual connection
456
- hidden_states = hidden_states + feed_forward_hidden_states
457
 
458
- hidden_states = self.ln_2(hidden_states)
459
 
460
- outputs = [hidden_states] + outputs
461
  return outputs # hidden_states, present, (attentions, cross_attentions)
462
 
463
 
@@ -471,22 +603,48 @@ class GPT2PreTrainedModel(PreTrainedModel):
471
  load_tf_weights = load_tf_weights_in_gpt2
472
  base_model_prefix = "transformer"
473
  is_parallelizable = True
 
474
 
475
  def __init__(self, *inputs, **kwargs):
476
  super().__init__(*inputs, **kwargs)
477
 
478
  def _init_weights(self, module):
479
  """Initialize the weights."""
480
- if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
481
  # Slightly different from the TF version which uses truncated_normal for initialization
482
  # cf https://github.com/pytorch/pytorch/pull/5617
483
  module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
484
- if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
485
  module.bias.data.zero_()
486
  elif isinstance(module, nn.LayerNorm):
487
  module.bias.data.zero_()
488
  module.weight.data.fill_(1.0)
489
 
490
 
491
  @dataclass
492
  class GPT2DoubleHeadsModelOutput(ModelOutput):
@@ -494,125 +652,125 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
494
  Base class for outputs of models predicting if two sentences are consecutive or not.
495
 
496
  Args:
497
- loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
498
  Language modeling loss.
499
- mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
500
  Multiple choice classification loss.
501
- logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
502
  Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
503
- mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
504
  Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
505
- past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
506
- List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2,
507
- batch_size, num_heads, sequence_length, embed_size_per_head)`).
508
 
509
  Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
510
- :obj:`past_key_values` input) to speed up sequential decoding.
511
- hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
512
- Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
513
- of shape :obj:`(batch_size, sequence_length, hidden_size)`.
514
 
515
  Hidden-states of the model at the output of each layer plus the initial embedding outputs.
516
- attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
517
- Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
518
- sequence_length, sequence_length)`.
519
 
520
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
521
- heads.
522
  """
523
 
524
  loss: Optional[torch.FloatTensor] = None
525
  mc_loss: Optional[torch.FloatTensor] = None
526
  logits: torch.FloatTensor = None
527
  mc_logits: torch.FloatTensor = None
528
- past_key_values: Optional[List[torch.FloatTensor]] = None
529
  hidden_states: Optional[Tuple[torch.FloatTensor]] = None
530
  attentions: Optional[Tuple[torch.FloatTensor]] = None
531
 
532
 
533
  GPT2_START_DOCSTRING = r"""
534
 
535
- This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
536
- methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
537
- pruning heads etc.)
538
 
539
- This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
540
- subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
541
- general usage and behavior.
542
 
543
  Parameters:
544
- config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
545
  Initializing with a config file does not load the weights associated with the model, only the
546
- configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
547
- weights.
548
  """
549
 
550
  GPT2_INPUTS_DOCSTRING = r"""
551
  Args:
552
- input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
553
- :obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else
554
- ``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input
555
  sequence tokens in the vocabulary.
556
 
557
- If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be
558
- passed as ``input_ids``.
559
 
560
- Indices can be obtained using :class:`~transformers.GPT2Tokenizer`. See
561
- :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
562
- details.
563
 
564
- `What are input IDs? <../glossary.html#input-ids>`__
565
- past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
566
  Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
567
- :obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
568
- have their past given to this model should not be passed as ``input_ids`` as they have already been
569
- computed.
570
- attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
571
- Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
572
 
573
  - 1 for tokens that are **not masked**,
574
  - 0 for tokens that are **masked**.
575
 
576
- `What are attention masks? <../glossary.html#attention-mask>`__
577
- token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`, `optional`):
578
- Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
579
- 1]``:
580
 
581
- - 0 corresponds to a `sentence A` token,
582
- - 1 corresponds to a `sentence B` token.
 
 
583
 
584
- `What are token type IDs? <../glossary.html#token-type-ids>`_
585
- position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
586
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
587
- config.max_position_embeddings - 1]``.
588
 
589
- `What are position IDs? <../glossary.html#position-ids>`_
590
- head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
591
- Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
592
 
593
  - 1 indicates the head is **not masked**,
594
  - 0 indicates the head is **masked**.
595
 
596
- inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
597
- Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
598
- This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
599
- vectors than the model's internal embedding lookup matrix.
600
-
601
- If :obj:`past_key_values` is used, optionally only the last :obj:`inputs_embeds` have to be input (see
602
- :obj:`past_key_values`).
603
- use_cache (:obj:`bool`, `optional`):
604
- If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
605
- decoding (see :obj:`past_key_values`).
606
- output_attentions (:obj:`bool`, `optional`):
607
- Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
608
  tensors for more detail.
609
- output_hidden_states (:obj:`bool`, `optional`):
610
- Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
611
  more detail.
612
- return_dict (:obj:`bool`, `optional`):
613
- Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
614
  """
615
-
616
  PARALLELIZE_DOCSTRING = r"""
617
  This is an experimental feature and is a subject to change at a moment's notice.
618
 
@@ -620,7 +778,7 @@ PARALLELIZE_DOCSTRING = r"""
620
  it will evenly distribute blocks across all devices.
621
 
622
  Args:
623
- device_map (:obj:`Dict[int, list]`, optional, defaults to None):
624
  A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
625
  automatically mapped to the first device (for esoteric reasons). That means that the first device should
626
  have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
@@ -631,31 +789,37 @@ PARALLELIZE_DOCSTRING = r"""
631
  - gpt2-large: 36
632
  - gpt2-xl: 48
633
 
634
- Example::
635
-
636
- # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
637
- model = GPT2LMHeadModel.from_pretrained('gpt2-xl')
638
- device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
639
-
640
- 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
641
- 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
642
- 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]}
643
- model.parallelize(device_map)
 
 
 
644
  """
645
  DEPARALLELIZE_DOCSTRING = r"""
646
  Moves the model to cpu from a model parallel state.
647
 
648
- Example::
649
-
650
- # On a 4 GPU machine with gpt2-large:
651
- model = GPT2LMHeadModel.from_pretrained('gpt2-large')
652
- device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7],
653
-
654
- 1: [8, 9, 10, 11, 12, 13, 14, 15],
655
- 2: [16, 17, 18, 19, 20, 21, 22, 23],
656
- 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35]}
657
- model.parallelize(device_map) # Splits the model across several devices
658
- model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
 
 
 
659
  """
660
 
661
 
@@ -664,26 +828,32 @@ DEPARALLELIZE_DOCSTRING = r"""
664
  GPT2_START_DOCSTRING,
665
  )
666
  class GPT2Model(GPT2PreTrainedModel):
 
 
667
  def __init__(self, config):
668
  super().__init__(config)
669
 
670
- self.wte = nn.Embedding(config.vocab_size, config.n_embd)
671
- self.wpe = nn.Embedding(config.n_positions, config.n_embd)
672
- if _USE_GROVER:
673
- self.emb_norm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
674
675
  self.drop = nn.Dropout(config.embd_pdrop)
676
  self.h = nn.ModuleList(
677
- [Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)]
678
  )
679
- if not _USE_GROVER:
680
- self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
681
-
682
- self.init_weights()
683
 
684
  # Model parallel
685
  self.model_parallel = False
686
  self.device_map = None
687
 
688
  @add_start_docstrings(PARALLELIZE_DOCSTRING)
689
  def parallelize(self, device_map=None):
@@ -703,13 +873,22 @@ class GPT2Model(GPT2PreTrainedModel):
703
  self.last_device = "cuda:" + str(max(self.device_map.keys()))
704
  self.wte = self.wte.to(self.first_device)
705
  self.wpe = self.wpe.to(self.first_device)
706
  # Load onto devices
707
  for k, v in self.device_map.items():
708
  for block in v:
709
  cuda_device = "cuda:" + str(k)
710
  self.h[block] = self.h[block].to(cuda_device)
711
  # ln_f to last
712
- self.ln_f = self.ln_f.to(self.last_device)
 
713
 
714
  @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
715
  def deparallelize(self):
@@ -719,9 +898,12 @@ class GPT2Model(GPT2PreTrainedModel):
719
  self.last_device = "cpu"
720
  self.wte = self.wte.to("cpu")
721
  self.wpe = self.wpe.to("cpu")
 
 
722
  for index in range(len(self.h)):
723
  self.h[index] = self.h[index].to("cpu")
724
- self.ln_f = self.ln_f.to("cpu")
 
725
  torch.cuda.empty_cache()
726
 
727
  def get_input_embeddings(self):
@@ -739,27 +921,27 @@ class GPT2Model(GPT2PreTrainedModel):
739
 
740
  @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
741
  @add_code_sample_docstrings(
742
- tokenizer_class=_TOKENIZER_FOR_DOC,
743
- checkpoint="gpt2",
744
  output_type=BaseModelOutputWithPastAndCrossAttentions,
745
  config_class=_CONFIG_FOR_DOC,
746
  )
747
  def forward(
748
  self,
749
- input_ids=None,
750
- past_key_values=None,
751
- attention_mask=None,
752
- token_type_ids=None,
753
- position_ids=None,
754
- head_mask=None,
755
- inputs_embeds=None,
756
- encoder_hidden_states=None,
757
- encoder_attention_mask=None,
758
- use_cache=None,
759
- output_attentions=None,
760
- output_hidden_states=None,
761
- return_dict=None,
762
- ):
763
  output_attentions = (
764
  output_attentions
765
  if output_attentions is not None
@@ -789,6 +971,8 @@ class GPT2Model(GPT2PreTrainedModel):
789
  else:
790
  raise ValueError("You have to specify either input_ids or inputs_embeds")
791
 
 
 
792
  if token_type_ids is not None:
793
  token_type_ids = token_type_ids.view(-1, input_shape[-1])
794
  if position_ids is not None:
@@ -796,11 +980,10 @@ class GPT2Model(GPT2PreTrainedModel):
796
 
797
  if past_key_values is None:
798
  past_length = 0
799
- past_key_values = [None] * len(self.h)
800
  else:
801
  past_length = past_key_values[0][0].size(-2)
802
  if position_ids is None:
803
- device = input_ids.device if input_ids is not None else inputs_embeds.device
804
  position_ids = torch.arange(
805
  past_length,
806
  input_shape[-1] + past_length,
@@ -809,7 +992,7 @@ class GPT2Model(GPT2PreTrainedModel):
809
  )
810
  position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
811
 
812
- # Attention mask.
813
  if attention_mask is not None:
814
  if batch_size <= 0:
815
  raise ValueError("batch_size has to be defined and > 0")
@@ -829,7 +1012,7 @@ class GPT2Model(GPT2PreTrainedModel):
829
  attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
830
  attention_mask = (1.0 - attention_mask) * -10000.0
831
 
832
- # If a 2D ou 3D attention mask is provided for the cross-attention
833
  # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
834
  if self.config.add_cross_attention and encoder_hidden_states is not None:
835
  (
@@ -860,8 +1043,9 @@ class GPT2Model(GPT2PreTrainedModel):
860
  hidden_states = hidden_states + token_type_embeds
861
 
862
  hidden_states = self.drop(hidden_states)
863
- if _USE_GROVER:
864
- hidden_states = self.emb_norm(hidden_states)
 
865
  output_shape = input_shape + (hidden_states.size(-1),)
866
 
867
  presents = () if use_cache else None
@@ -885,28 +1069,28 @@ class GPT2Model(GPT2PreTrainedModel):
885
  attention_mask = attention_mask.to(hidden_states.device)
886
  if isinstance(head_mask, torch.Tensor):
887
  head_mask = head_mask.to(hidden_states.device)
888
-
889
  if output_hidden_states:
890
- all_hidden_states = all_hidden_states + (
891
- hidden_states.view(*output_shape),
892
- )
893
 
894
- if getattr(self.config, "gradient_checkpointing", False):
 
 
 
 
 
 
895
 
896
  def create_custom_forward(module):
897
  def custom_forward(*inputs):
898
- # checkpointing only works with tuple returns, not with lists
899
- return tuple(
900
- output
901
- for output in module(*inputs, use_cache, output_attentions)
902
- )
903
 
904
  return custom_forward
905
 
906
  outputs = torch.utils.checkpoint.checkpoint(
907
  create_custom_forward(block),
908
  hidden_states,
909
- layer_past,
910
  attention_mask,
911
  head_mask[i],
912
  encoder_hidden_states,
@@ -924,9 +1108,9 @@ class GPT2Model(GPT2PreTrainedModel):
924
  output_attentions=output_attentions,
925
  )
926
 
927
- hidden_states, present = outputs[:2]
928
  if use_cache is True:
929
- presents = presents + (present,)
930
 
931
  if output_attentions:
932
  all_self_attentions = all_self_attentions + (
@@ -943,10 +1127,10 @@ class GPT2Model(GPT2PreTrainedModel):
943
  if i == v[-1] and "cuda:" + str(k) != self.last_device:
944
  hidden_states = hidden_states.to("cuda:" + str(k + 1))
945
 
946
- if not _USE_GROVER:
947
- hidden_states = self.ln_f(hidden_states)
948
 
949
- hidden_states = hidden_states.view(*output_shape)
950
  # Add last hidden state
951
  if output_hidden_states:
952
  all_hidden_states = all_hidden_states + (hidden_states,)
@@ -981,19 +1165,24 @@ class GPT2Model(GPT2PreTrainedModel):
981
  GPT2_START_DOCSTRING,
982
  )
983
  class GPT2LMHeadModel(GPT2PreTrainedModel):
984
- _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
985
 
986
  def __init__(self, config):
987
  super().__init__(config)
988
  self.transformer = GPT2Model(config)
989
  self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
990
 
991
- self.init_weights()
992
-
993
  # Model parallel
994
  self.model_parallel = False
995
  self.device_map = None
996
 
 
 
 
997
  @add_start_docstrings(PARALLELIZE_DOCSTRING)
998
  def parallelize(self, device_map=None):
999
  self.device_map = (
@@ -1017,6 +1206,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
1017
  def get_output_embeddings(self):
1018
  return self.lm_head
1019
 
 
 
 
1020
  def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
1021
  token_type_ids = kwargs.get("token_type_ids", None)
1022
  # only last token for inputs_ids if past is defined in kwargs
@@ -1047,33 +1239,33 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
1047
 
1048
  @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1049
  @add_code_sample_docstrings(
1050
- tokenizer_class=_TOKENIZER_FOR_DOC,
1051
- checkpoint="gpt2",
1052
  output_type=CausalLMOutputWithCrossAttentions,
1053
  config_class=_CONFIG_FOR_DOC,
1054
  )
1055
  def forward(
1056
  self,
1057
- input_ids=None,
1058
- past_key_values=None,
1059
- attention_mask=None,
1060
- token_type_ids=None,
1061
- position_ids=None,
1062
- head_mask=None,
1063
- inputs_embeds=None,
1064
- encoder_hidden_states=None,
1065
- encoder_attention_mask=None,
1066
- labels=None,
1067
- use_cache=None,
1068
- output_attentions=None,
1069
- output_hidden_states=None,
1070
- return_dict=None,
1071
- ):
1072
  r"""
1073
- labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1074
  Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1075
- ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
1076
- ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
1077
  """
1078
  return_dict = (
1079
  return_dict if return_dict is not None else self.config.use_return_dict
@@ -1132,9 +1324,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
1132
  past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1133
  ) -> Tuple[Tuple[torch.Tensor]]:
1134
  """
1135
- This function is used to re-order the :obj:`past_key_values` cache if
1136
- :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
1137
- called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
1138
  """
1139
  return tuple(
1140
  tuple(
@@ -1155,6 +1347,12 @@ input sequence).
1155
  GPT2_START_DOCSTRING,
1156
  )
1157
  class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
1158
  def __init__(self, config):
1159
  super().__init__(config)
1160
  config.num_labels = 1
@@ -1162,12 +1360,13 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
1162
  self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1163
  self.multiple_choice_head = SequenceSummary(config)
1164
 
1165
- self.init_weights()
1166
-
1167
  # Model parallel
1168
  self.model_parallel = False
1169
  self.device_map = None
1170
 
 
 
 
1171
  @add_start_docstrings(PARALLELIZE_DOCSTRING)
1172
  def parallelize(self, device_map=None):
1173
  self.device_map = (
@@ -1195,6 +1394,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
1195
  def get_output_embeddings(self):
1196
  return self.lm_head
1197
 
 
 
 
1198
  def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
1199
  token_type_ids = kwargs.get("token_type_ids", None)
1200
  # only last token for inputs_ids if past is defined in kwargs
@@ -1230,62 +1432,61 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
1230
  )
1231
  def forward(
1232
  self,
1233
- input_ids=None,
1234
- past_key_values=None,
1235
- attention_mask=None,
1236
- token_type_ids=None,
1237
- position_ids=None,
1238
- head_mask=None,
1239
- inputs_embeds=None,
1240
- mc_token_ids=None,
1241
- labels=None,
1242
- mc_labels=None,
1243
- use_cache=None,
1244
- output_attentions=None,
1245
- output_hidden_states=None,
1246
- return_dict=None,
1247
  **kwargs,
1248
- ):
1249
  r"""
1250
- mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input):
1251
- Index of the classification token in each input sequence. Selected in the range ``[0, input_ids.size(-1) -
1252
- 1[``.
1253
- labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1254
  Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1255
- ``labels = input_ids`` Indices are selected in ``[-1, 0, ..., config.vocab_size]`` All labels set to
1256
- ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
1257
- mc_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size)`, `optional`):
1258
- Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
1259
- num_choices]`` where `num_choices` is the size of the second dimension of the input tensors. (see
1260
- `input_ids` above)
1261
 
1262
  Return:
1263
 
1264
- Example::
1265
-
1266
- >>> import torch
1267
- >>> from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel
1268
-
1269
- >>> tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
1270
- >>> model = GPT2DoubleHeadsModel.from_pretrained('gpt2')
1271
 
1272
- >>> # Add a [CLS] to the vocabulary (we should train it also!)
1273
- >>> num_added_tokens = tokenizer.add_special_tokens({'cls_token': '[CLS]'})
 
1274
 
1275
- >>> embedding_layer = model.resize_token_embeddings(len(tokenizer)) # Update the model embeddings with the new vocabulary size
 
1276
 
1277
- >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
1278
- >>> encoded_choices = [tokenizer.encode(s) for s in choices]
1279
- >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]
 
1280
 
1281
- >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2
1282
- >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1
 
1283
 
1284
- >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
1285
- >>> lm_logits = outputs.lm_logits
1286
- >>> mc_logits = outputs.mc_logits
1287
 
1288
- """
 
 
 
1289
  return_dict = (
1290
  return_dict if return_dict is not None else self.config.use_return_dict
1291
  )
@@ -1350,9 +1551,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
1350
  past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1351
  ) -> Tuple[Tuple[torch.Tensor]]:
1352
  """
1353
- This function is used to re-order the :obj:`past_key_values` cache if
1354
- :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
1355
- called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
1356
  """
1357
  return tuple(
1358
  tuple(
@@ -1367,14 +1568,14 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
1367
  """
1368
  The GPT2 Model transformer with a sequence classification head on top (linear layer).
1369
 
1370
- :class:`~transformers.GPT2ForSequenceClassification` uses the last token in order to do the classification, as
1371
- other causal models (e.g. GPT-1) do.
1372
 
1373
  Since it does classification on the last token, it requires to know the position of the last token. If a
1374
- :obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each
1375
- row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot
1376
- guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it does the same (take
1377
- the last value in each row of the batch).
1378
  """,
1379
  GPT2_START_DOCSTRING,
1380
  )
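To make the last-token pooling rule in the sequence-classification docstring above concrete, here is a minimal sketch of selecting the last non-padding token per row, mirroring the upstream approach; the shapes and `pad_token_id` are invented for illustration:

import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 7, 9, 0, 0],
                          [3, 4, 0, 0, 0]])         # right-padded batch
logits = torch.randn(2, 5, 2)                       # (batch, seq_len, num_labels)

sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1   # tensor([2, 1])
pooled_logits = logits[torch.arange(logits.size(0)), sequence_lengths]
print(pooled_logits.shape)                          # torch.Size([2, 2])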
@@ -1387,39 +1588,42 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
1387
  self.transformer = GPT2Model(config)
1388
  self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
1389
 
1390
- self.init_weights()
1391
-
1392
  # Model parallel
1393
  self.model_parallel = False
1394
  self.device_map = None
1395
 
 
 
 
1396
  @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1397
  @add_code_sample_docstrings(
1398
- tokenizer_class=_TOKENIZER_FOR_DOC,
1399
- checkpoint="microsoft/dialogrpt",
1400
  output_type=SequenceClassifierOutputWithPast,
1401
  config_class=_CONFIG_FOR_DOC,
 
 
1402
  )
1403
  def forward(
1404
  self,
1405
- input_ids=None,
1406
- past_key_values=None,
1407
- attention_mask=None,
1408
- token_type_ids=None,
1409
- position_ids=None,
1410
- head_mask=None,
1411
- inputs_embeds=None,
1412
- labels=None,
1413
- use_cache=None,
1414
- output_attentions=None,
1415
- output_hidden_states=None,
1416
- return_dict=None,
1417
- ):
1418
  r"""
1419
- labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1420
- Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
1421
- config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
1422
- If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1423
  """
1424
  return_dict = (
1425
  return_dict if return_dict is not None else self.config.use_return_dict
@@ -1460,23 +1664,39 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
1460
  sequence_lengths = -1
1461
  logger.warning(
1462
  f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1463
- f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1464
  )
1465
 
1466
- pooled_logits = logits[range(batch_size), sequence_lengths]
 
 
1467
 
1468
  loss = None
1469
  if labels is not None:
1470
- if self.num_labels == 1:
1471
- # We are doing regression
1472
  loss_fct = MSELoss()
1473
- loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
1474
- else:
 
 
 
1475
  loss_fct = CrossEntropyLoss()
1476
  loss = loss_fct(
1477
  pooled_logits.view(-1, self.num_labels), labels.view(-1)
1478
  )
1479
-
 
 
1480
  if not return_dict:
1481
  output = (pooled_logits,) + transformer_outputs[1:]
1482
  return ((loss,) + output) if loss is not None else output
@@ -1515,39 +1735,44 @@ class GPT2ForTokenClassification(GPT2PreTrainedModel):
1515
  self.dropout = nn.Dropout(classifier_dropout)
1516
  self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1517
 
1518
- self.init_weights()
1519
-
1520
  # Model parallel
1521
  self.model_parallel = False
1522
  self.device_map = None
1523
 
 
 
 
1524
  @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
 
1525
  @add_code_sample_docstrings(
1526
- tokenizer_class=_TOKENIZER_FOR_DOC,
1527
- checkpoint="microsoft/DialogRPT-updown",
1528
  output_type=TokenClassifierOutput,
1529
  config_class=_CONFIG_FOR_DOC,
 
 
1530
  )
 
1531
  def forward(
1532
  self,
1533
- input_ids=None,
1534
- past_key_values=None,
1535
- attention_mask=None,
1536
- token_type_ids=None,
1537
- position_ids=None,
1538
- head_mask=None,
1539
- inputs_embeds=None,
1540
- labels=None,
1541
- use_cache=None,
1542
- output_attentions=None,
1543
- output_hidden_states=None,
1544
- return_dict=None,
1545
- ):
1546
  r"""
1547
- labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1548
- Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
1549
- config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
1550
- If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1551
  """
1552
  return_dict = (
1553
  return_dict if return_dict is not None else self.config.use_return_dict
@@ -1574,18 +1799,7 @@ class GPT2ForTokenClassification(GPT2PreTrainedModel):
1574
  loss = None
1575
  if labels is not None:
1576
  loss_fct = CrossEntropyLoss()
1577
- # Only keep active parts of the loss
1578
- if attention_mask is not None:
1579
- active_loss = attention_mask.view(-1) == 1
1580
- active_logits = logits.view(-1, self.num_labels)
1581
- active_labels = torch.where(
1582
- active_loss,
1583
- labels.view(-1),
1584
- torch.tensor(loss_fct.ignore_index).type_as(labels),
1585
- )
1586
- loss = loss_fct(active_logits, active_labels)
1587
- else:
1588
- loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1589
 
1590
  if not return_dict:
1591
  output = (logits,) + transformer_outputs[2:]
@@ -1596,4 +1810,4 @@ class GPT2ForTokenClassification(GPT2PreTrainedModel):
1596
  logits=logits,
1597
  hidden_states=transformer_outputs.hidden_states,
1598
  attentions=transformer_outputs.attentions,
1599
- )
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
  # See the License for the specific language governing permissions and
15
  # limitations under the License.
16
+ """PyTorch GROVER model."""
17
 
18
+ import math
19
  import os
20
  from dataclasses import dataclass
21
+ from typing import Optional, Tuple, Union
22
 
23
  import torch
24
+ import torch.utils.checkpoint
25
+ from packaging import version
26
+ from torch import nn
27
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28
+
29
+
30
+ if version.parse(torch.__version__) >= version.parse("1.6"):
31
+ is_amp_available = True
32
+ from torch.cuda.amp import autocast
33
+ else:
34
+ is_amp_available = False
35
+
36
  from transformers.activations import ACT2FN
37
  from transformers.modeling_outputs import (
38
  BaseModelOutputWithPastAndCrossAttentions,
39
  CausalLMOutputWithCrossAttentions,
40
  SequenceClassifierOutputWithPast,
41
  TokenClassifierOutput,
42
  )
43
+ from transformers.modeling_utils import PreTrainedModel, SequenceSummary
44
+ from transformers.pytorch_utils import (
45
  Conv1D,
 
 
46
  find_pruneable_heads_and_indices,
47
  prune_conv1d_layer,
48
  )
49
+ from transformers.utils import (
50
+ ModelOutput,
51
+ add_code_sample_docstrings,
52
+ add_start_docstrings,
53
+ add_start_docstrings_to_model_forward,
54
+ logging,
55
+ replace_return_docstrings,
56
+ )
57
  from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
58
+ from transformers import GPT2Config
59
 
 
 
60
 
61
+ logger = logging.get_logger(__name__)
62
 
63
+ _CHECKPOINT_FOR_DOC = "gpt2"
64
  _CONFIG_FOR_DOC = "GPT2Config"
65
  _TOKENIZER_FOR_DOC = "GPT2Tokenizer"
66
 
73
  # See all GPT-2 models at https://huggingface.co/models?filter=gpt2
74
  ]
75
 
76
  _GPT2_ML_TF_TO_TORCH = {
77
  "LayerNorm_embed_norm": "emb_norm",
78
  "pos_embed": "wpe.weight",
124
  """Load tf checkpoints in a pytorch model"""
125
  try:
126
  import re
 
127
  import tensorflow as tf
128
  except ImportError:
129
  logger.error(
203
  d = torch.from_numpy(array)
204
  is_bias = len(shape) == 1
205
  end = int(shape[0 if is_bias else 1] / 3)
206
+ m = dict(query_layer=0, key_layer=end, value_layer=end * 2,)
207
  start = m[attn_layer]
208
  end = start + end
209
  if is_bias:
225
  return model
226
 
227
 
228
+ class GPT2Attention(nn.Module):
229
+ def __init__(self, config, is_cross_attention=False, layer_idx=None):
230
  super().__init__()
231
 
232
+ max_positions = config.max_position_embeddings
 
 
233
  self.register_buffer(
234
  "bias",
235
+ torch.tril(
236
+ torch.ones((max_positions, max_positions), dtype=torch.uint8)
237
+ ).view(1, 1, max_positions, max_positions),
238
  )
239
  self.register_buffer("masked_bias", torch.tensor(-1e4))
240
+
241
+ self.embed_dim = config.hidden_size
242
+ self.num_heads = config.num_attention_heads
243
+ self.head_dim = self.embed_dim // self.num_heads
244
+ self.split_size = self.embed_dim
245
+ if self.head_dim * self.num_heads != self.embed_dim:
246
+ raise ValueError(
247
+ f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
248
+ f" {self.num_heads})."
249
+ )
250
+
251
+ self.scale_attn_weights = config.scale_attn_weights
252
  self.is_cross_attention = is_cross_attention
253
+
254
+ # Layer-wise attention scaling, reordering, and upcasting
255
+ self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
256
+ self.layer_idx = layer_idx
257
+ self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
258
+
259
  if self.is_cross_attention:
260
+ self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
261
+ self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
262
  else:
263
+ self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
264
+ self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
265
+
266
  self.attn_dropout = nn.Dropout(config.attn_pdrop)
267
  self.resid_dropout = nn.Dropout(config.resid_pdrop)
268
+
269
  self.pruned_heads = set()
270
 
271
  def prune_heads(self, heads):
272
  if len(heads) == 0:
273
  return
274
  heads, index = find_pruneable_heads_and_indices(
275
+ heads, self.num_heads, self.head_dim, self.pruned_heads
276
  )
277
  index_attn = torch.cat(
278
  [index, index + self.split_size, index + (2 * self.split_size)]
283
  self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
284
 
285
  # Update hyper params
286
+ self.split_size = (self.split_size // self.num_heads) * (
287
+ self.num_heads - len(heads)
288
+ )
289
+ self.num_heads = self.num_heads - len(heads)
290
  self.pruned_heads = self.pruned_heads.union(heads)
291
 
292
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
293
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
294
+
295
+ if self.scale_attn_weights:
296
+ attn_weights = attn_weights / (value.size(-1) ** 0.5)
297
+
298
+ # Layer-wise attention scaling
299
+ if self.scale_attn_by_inverse_layer_idx:
300
+ attn_weights = attn_weights / float(self.layer_idx + 1)
301
 
302
  if not self.is_cross_attention:
303
  # if only "normal" attention layer implements causal mask
304
+ query_length, key_length = query.size(-2), key.size(-2)
305
+ causal_mask = self.bias[
306
+ :, :, key_length - query_length : key_length, :key_length
307
+ ].bool()
308
+ attn_weights = torch.where(
309
+ causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)
310
+ )
311
 
312
  if attention_mask is not None:
313
  # Apply the attention mask
314
+ attn_weights = attn_weights + attention_mask
315
 
316
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
317
+
318
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
319
+ attn_weights = attn_weights.type(value.dtype)
320
+ attn_weights = self.attn_dropout(attn_weights)
321
 
322
  # Mask heads if we want to
323
  if head_mask is not None:
324
+ attn_weights = attn_weights * head_mask
325
 
326
+ attn_output = torch.matmul(attn_weights, value)
327
+
328
+ return attn_output, attn_weights
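As a standalone illustration of the causal-mask slice used in `_attn` above (toy sizes, not the model's real dimensions):

import torch

max_positions = 6
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
    1, 1, max_positions, max_positions
)

query_length, key_length = 2, 5      # two new tokens attending over five keys (cache + new)
causal_mask = bias[:, :, key_length - query_length : key_length, :key_length].bool()

attn_weights = torch.randn(1, 1, query_length, key_length)
masked_bias = torch.tensor(-1e4)
attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))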
329
+
330
+ def _upcast_and_reordered_attn(
331
+ self, query, key, value, attention_mask=None, head_mask=None
332
+ ):
333
+ # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
334
+ bsz, num_heads, q_seq_len, dk = query.size()
335
+ _, _, k_seq_len, _ = key.size()
336
+
337
+ # Preallocate attn_weights for `baddbmm`
338
+ attn_weights = torch.empty(
339
+ bsz * num_heads,
340
+ q_seq_len,
341
+ k_seq_len,
342
+ dtype=torch.float32,
343
+ device=query.device,
344
+ )
345
+
346
+ # Compute Scale Factor
347
+ scale_factor = 1.0
348
+ if self.scale_attn_weights:
349
+ scale_factor /= float(value.size(-1)) ** 0.5
350
+
351
+ if self.scale_attn_by_inverse_layer_idx:
352
+ scale_factor /= float(self.layer_idx + 1)
353
+
354
+ # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
355
+ if is_amp_available:
356
+ with autocast(enabled=False):
357
+ q, k = (
358
+ query.reshape(-1, q_seq_len, dk),
359
+ key.transpose(-1, -2).reshape(-1, dk, k_seq_len),
360
+ )
361
+ attn_weights = torch.baddbmm(
362
+ attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor
363
+ )
364
+ attn_weights = attn_weights.reshape(
365
+ bsz, num_heads, q_seq_len, k_seq_len
366
+ )
367
  else:
368
+ q, k = (
369
+ query.reshape(-1, q_seq_len, dk),
370
+ key.transpose(-1, -2).reshape(-1, dk, k_seq_len),
371
+ )
372
+ attn_weights = torch.baddbmm(
373
+ attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor
374
+ )
375
+ attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
376
+
377
+ if not self.is_cross_attention:
378
+ # if only "normal" attention layer implements causal mask
379
+ query_length, key_length = query.size(-2), key.size(-2)
380
+ causal_mask = self.bias[
381
+ :, :, key_length - query_length : key_length, :key_length
382
+ ].bool()
383
+ attn_weights = torch.where(
384
+ causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)
385
+ )
386
+
387
+ if attention_mask is not None:
388
+ # Apply the attention mask
389
+ attn_weights = attn_weights + attention_mask
390
+
391
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
392
+
393
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
394
+ if attn_weights.dtype != torch.float32:
395
+ raise RuntimeError(
396
+ "Error with upcasting, attn_weights does not have dtype torch.float32"
397
+ )
398
+ attn_weights = attn_weights.type(value.dtype)
399
+ attn_weights = self.attn_dropout(attn_weights)
400
+
401
+ # Mask heads if we want to
402
+ if head_mask is not None:
403
+ attn_weights = attn_weights * head_mask
404
+
405
+ attn_output = torch.matmul(attn_weights, value)
406
+
407
+ return attn_output, attn_weights
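For readers unfamiliar with `torch.baddbmm`, a compact sketch of the upcast-and-scale pattern used above; the shapes are toy values, and the real method additionally applies the causal mask, softmax, and dropout:

import torch

bsz, num_heads, q_seq_len, k_seq_len, dk = 2, 4, 3, 3, 8
query = torch.randn(bsz, num_heads, q_seq_len, dk, dtype=torch.float16)
key = torch.randn(bsz, num_heads, k_seq_len, dk, dtype=torch.float16)

scale_factor = 1.0 / float(dk) ** 0.5     # 1 / sqrt(d_k)

attn_weights = torch.empty(
    bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device
)
q = query.reshape(-1, q_seq_len, dk)
k = key.transpose(-1, -2).reshape(-1, dk, k_seq_len)

# beta=0 ignores the uninitialized `attn_weights` input; alpha folds the scaling
# into the matmul, which runs in float32 regardless of the inputs' dtype.
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)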
408
+
409
+ def _split_heads(self, tensor, num_heads, attn_head_size):
410
+ """
411
+ Splits hidden_size dim into attn_head_size and num_heads
412
+ """
413
+ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
414
+ tensor = tensor.view(new_shape)
415
+ return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
416
+
417
+ def _merge_heads(self, tensor, num_heads, attn_head_size):
418
+ """
419
+ Merges attn_head_size dim and num_attn_heads dim into hidden_size
420
+ """
421
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
422
+ new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
423
+ return tensor.view(new_shape)
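Shape-wise, `_split_heads` and `_merge_heads` above are exact inverses; a quick check with toy dimensions:

import torch

batch, seq_len, num_heads, head_dim = 2, 5, 4, 16
hidden = torch.randn(batch, seq_len, num_heads * head_dim)

# split: (batch, seq, hidden) -> (batch, heads, seq, head_dim)
split = hidden.view(batch, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)

# merge: (batch, heads, seq, head_dim) -> (batch, seq, hidden)
merged = split.permute(0, 2, 1, 3).contiguous().view(batch, seq_len, num_heads * head_dim)
assert torch.equal(hidden, merged)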
424
 
425
  def forward(
426
  self,
427
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
428
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
429
+ attention_mask: Optional[torch.FloatTensor] = None,
430
+ head_mask: Optional[torch.FloatTensor] = None,
431
+ encoder_hidden_states: Optional[torch.Tensor] = None,
432
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
433
+ use_cache: Optional[bool] = False,
434
+ output_attentions: Optional[bool] = False,
435
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
436
  if encoder_hidden_states is not None:
437
+ if not hasattr(self, "q_attn"):
438
+ raise ValueError(
439
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
440
+ "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
441
+ )
442
+
443
  query = self.q_attn(hidden_states)
444
  key, value = self.c_attn(encoder_hidden_states).split(
445
  self.split_size, dim=2
448
  else:
449
  query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
450
 
451
+ query = self._split_heads(query, self.num_heads, self.head_dim)
452
+ key = self._split_heads(key, self.num_heads, self.head_dim)
453
+ value = self._split_heads(value, self.num_heads, self.head_dim)
454
+
455
  if layer_past is not None:
456
+ past_key, past_value = layer_past
457
+ key = torch.cat((past_key, key), dim=-2)
 
 
 
458
  value = torch.cat((past_value, value), dim=-2)
459
 
460
  if use_cache is True:
461
+ present = (key, value)
 
 
462
  else:
463
+ present = None
464
 
465
+ if self.reorder_and_upcast_attn:
466
+ attn_output, attn_weights = self._upcast_and_reordered_attn(
467
+ query, key, value, attention_mask, head_mask
468
+ )
469
+ else:
470
+ attn_output, attn_weights = self._attn(
471
+ query, key, value, attention_mask, head_mask
472
+ )
473
+
474
+ attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
475
+ attn_output = self.c_proj(attn_output)
476
+ attn_output = self.resid_dropout(attn_output)
477
 
478
+ outputs = (attn_output, present)
479
+ if output_attentions:
480
+ outputs += (attn_weights,)
481
 
 
482
  return outputs # a, present, (attentions)
483
 
484
 
485
+ class GPT2MLP(nn.Module):
486
+ def __init__(self, intermediate_size, config):
487
  super().__init__()
488
+ embed_dim = config.hidden_size
489
+ self.c_fc = Conv1D(intermediate_size, embed_dim)
490
+ self.c_proj = Conv1D(embed_dim, intermediate_size)
491
  self.act = ACT2FN[config.activation_function]
492
  self.dropout = nn.Dropout(config.resid_pdrop)
493
 
494
+ def forward(
495
+ self, hidden_states: Optional[Tuple[torch.FloatTensor]]
496
+ ) -> torch.FloatTensor:
497
+ hidden_states = self.c_fc(hidden_states)
498
+ hidden_states = self.act(hidden_states)
499
+ hidden_states = self.c_proj(hidden_states)
500
+ hidden_states = self.dropout(hidden_states)
501
+ return hidden_states
502
 
503
 
504
+ class GPT2Block(nn.Module):
505
+ def __init__(self, config, layer_idx=None):
506
  super().__init__()
507
+ hidden_size = config.hidden_size
508
  inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
509
+
510
  self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
511
+ self.attn = GPT2Attention(config, layer_idx=layer_idx)
512
  self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
513
+
514
  if config.add_cross_attention:
515
+ self.crossattention = GPT2Attention(
516
+ config, is_cross_attention=True, layer_idx=layer_idx
517
  )
518
  self.ln_cross_attn = nn.LayerNorm(
519
  hidden_size, eps=config.layer_norm_epsilon
520
  )
521
+
522
+ self.mlp = GPT2MLP(inner_dim, config)
523
 
524
  def forward(
525
  self,
526
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
527
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
528
+ attention_mask: Optional[torch.FloatTensor] = None,
529
+ head_mask: Optional[torch.FloatTensor] = None,
530
+ encoder_hidden_states: Optional[torch.Tensor] = None,
531
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
532
+ use_cache: Optional[bool] = False,
533
+ output_attentions: Optional[bool] = False,
534
+ ) -> Union[
535
+ Tuple[torch.Tensor],
536
+ Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]],
537
+ ]:
538
+
539
+ # removed in GROVER
540
+ # residual = hidden_states
541
+ # hidden_states = self.ln_1(hidden_states)
542
  attn_outputs = self.attn(
543
  hidden_states,
544
  layer_past=layer_past,
554
 
555
  if encoder_hidden_states is not None:
556
  # add one self-attention block for cross-attention
557
+ if not hasattr(self, "crossattention"):
558
+ raise ValueError(
559
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
560
+ "cross-attention layers by setting `config.add_cross_attention=True`"
561
+ )
562
+ # removed in GROVER
563
+ # residual = hidden_states
564
+ # hidden_states = self.ln_cross_attn(hidden_states)
565
  cross_attn_outputs = self.crossattention(
566
+ hidden_states,
567
  attention_mask=attention_mask,
568
  head_mask=head_mask,
569
  encoder_hidden_states=encoder_hidden_states,
572
  )
573
  attn_output = cross_attn_outputs[0]
574
  # residual connection
575
+ hidden_states = attn_output + hidden_states
576
  outputs = (
577
  outputs + cross_attn_outputs[2:]
578
  ) # add cross attentions if we output attention weights
579
 
580
+ residual = hidden_states
581
+ hidden_states = self.ln_1(hidden_states)
582
+ feed_forward_hidden_states = self.mlp(hidden_states)
583
  # residual connection
584
+ hidden_states = residual + feed_forward_hidden_states
585
 
586
+ hidden_states = self.ln_2(hidden_states) # Added in GROVER
587
+
588
+ if use_cache:
589
+ outputs = (hidden_states,) + outputs
590
+ else:
591
+ outputs = (hidden_states,) + outputs[1:]
592
 
 
593
  return outputs # hidden_states, present, (attentions, cross_attentions)
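The "removed in GROVER" / "Added in GROVER" comments in this block mark the main architectural difference the port keeps: stock GPT-2 normalizes before each sub-layer and applies a final `ln_f`, while GROVER (gpt2-ml) normalizes the embeddings once (`emb_norm`) and places layer norms inside and after the block. A schematic comparison only, with `attn` and `mlp` standing in for the real sub-modules and the unshown residual wiring assumed from upstream GPT-2:

import torch
from torch import nn

def upstream_gpt2_block(x, attn, mlp, ln_1, ln_2):
    # stock transformers GPT-2: pre-LayerNorm; a final ln_f runs outside the blocks
    x = x + attn(ln_1(x))
    x = x + mlp(ln_2(x))
    return x

def grover_block(x, attn, mlp, ln_1, ln_2):
    # this file: no LayerNorm before attention (embeddings were normalized by emb_norm),
    # LayerNorm before the MLP, and an extra LayerNorm closing the block
    x = x + attn(x)
    residual = x
    x = residual + mlp(ln_1(x))
    return ln_2(x)

x = torch.randn(1, 4, 8)
identity = lambda t: t
print(grover_block(x, identity, identity, nn.LayerNorm(8), nn.LayerNorm(8)).shape)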
594
 
595
 
603
  load_tf_weights = load_tf_weights_in_gpt2
604
  base_model_prefix = "transformer"
605
  is_parallelizable = True
606
+ supports_gradient_checkpointing = True
607
 
608
  def __init__(self, *inputs, **kwargs):
609
  super().__init__(*inputs, **kwargs)
610
 
611
  def _init_weights(self, module):
612
  """Initialize the weights."""
613
+ if isinstance(module, (nn.Linear, Conv1D)):
614
  # Slightly different from the TF version which uses truncated_normal for initialization
615
  # cf https://github.com/pytorch/pytorch/pull/5617
616
  module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
617
+ if module.bias is not None:
618
  module.bias.data.zero_()
619
+ elif isinstance(module, nn.Embedding):
620
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
621
+ if module.padding_idx is not None:
622
+ module.weight.data[module.padding_idx].zero_()
623
  elif isinstance(module, nn.LayerNorm):
624
  module.bias.data.zero_()
625
  module.weight.data.fill_(1.0)
626
 
627
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
628
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
629
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
630
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
631
+ #
632
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
633
+ for name, p in module.named_parameters():
634
+ if "c_proj" in name and "weight" in name:
635
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
636
+ p.data.normal_(
637
+ mean=0.0,
638
+ std=(
639
+ self.config.initializer_range
640
+ / math.sqrt(2 * self.config.n_layer)
641
+ ),
642
+ )
643
+
644
+ def _set_gradient_checkpointing(self, module, value=False):
645
+ if isinstance(module, GPT2Model):
646
+ module.gradient_checkpointing = value
647
+
648
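# Small numeric sketch of the scaled residual-projection init in `_init_weights` above
# (illustrative values only): with the GPT-2 defaults initializer_range=0.02 and n_layer=24,
# every `c_proj.weight` is drawn from N(0, 0.02 / sqrt(2 * 24)) rather than N(0, 0.02).
import math
import torch

initializer_range, n_layer = 0.02, 24
scaled_std = initializer_range / math.sqrt(2 * n_layer)
w = torch.empty(768, 768).normal_(mean=0.0, std=scaled_std)
print(round(scaled_std, 5))  # ~0.00289 -- roughly 7x smaller than the default std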
 
649
  @dataclass
650
  class GPT2DoubleHeadsModelOutput(ModelOutput):
652
  Base class for outputs of models predicting if two sentences are consecutive or not.
653
 
654
  Args:
655
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
656
  Language modeling loss.
657
+ mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
658
  Multiple choice classification loss.
659
+ logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
660
  Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
661
+ mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
662
  Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
663
+ past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
664
+ Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads,
665
+ sequence_length, embed_size_per_head)`.
666
 
667
  Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
668
+ `past_key_values` input) to speed up sequential decoding.
669
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
670
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
671
+ shape `(batch_size, sequence_length, hidden_size)`.
672
 
673
  Hidden-states of the model at the output of each layer plus the initial embedding outputs.
674
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
675
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
676
+ sequence_length)`.
677
 
678
+ GPT2Attentions weights after the attention softmax, used to compute the weighted average in the
679
+ self-attention heads.
680
  """
681
 
682
  loss: Optional[torch.FloatTensor] = None
683
  mc_loss: Optional[torch.FloatTensor] = None
684
  logits: torch.FloatTensor = None
685
  mc_logits: torch.FloatTensor = None
686
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
687
  hidden_states: Optional[Tuple[torch.FloatTensor]] = None
688
  attentions: Optional[Tuple[torch.FloatTensor]] = None
689
 
690
 
691
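# Sketch of how ModelOutput subclasses such as GPT2DoubleHeadsModelOutput above behave
# (a generic toy output is used here as an assumption; the fields are not the real ones):
# they can be read as attributes, as a dict, or as a tuple of their non-None fields.
from dataclasses import dataclass
from typing import Optional

import torch
from transformers.file_utils import ModelOutput  # same import used at the top of this file


@dataclass
class TinyOutput(ModelOutput):  # hypothetical, for illustration only
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None


out = TinyOutput(logits=torch.zeros(2, 5))
print(out.logits.shape)     # attribute access
print(out["logits"].shape)  # dict-style access
print(len(out.to_tuple()))  # 1 -- fields left as None (here `loss`) are dropped
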
  GPT2_START_DOCSTRING = r"""
692
 
693
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
694
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
695
+ etc.)
696
 
697
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
698
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
699
+ and behavior.
700
 
701
  Parameters:
702
+ config ([`GPT2Config`]): Model configuration class with all the parameters of the model.
703
  Initializing with a config file does not load the weights associated with the model, only the
704
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
 
705
  """
706
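# Sketch of the config-vs-weights distinction noted above: building from a GPT2Config gives a
# randomly initialized network, while `from_pretrained` also loads trained weights. The stock
# transformers GPT2Model is used here as a stand-in for the classes defined in this file, and
# the config values are illustrative assumptions.
from transformers import GPT2Config, GPT2Model

config = GPT2Config(n_layer=2, n_head=2, n_embd=128)   # small illustrative config
random_model = GPT2Model(config)                       # weights drawn by _init_weights
pretrained_model = GPT2Model.from_pretrained("gpt2")   # weights loaded from a checkpoint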
 
707
  GPT2_INPUTS_DOCSTRING = r"""
708
  Args:
709
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
710
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
711
+ `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
712
  sequence tokens in the vocabulary.
713
 
714
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
715
+ `input_ids`.
716
 
717
+ Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
718
+ [`PreTrainedTokenizer.__call__`] for details.
 
719
 
720
+ [What are input IDs?](../glossary#input-ids)
721
+ past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
722
  Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
723
+ `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
724
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
725
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
726
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
 
727
 
728
  - 1 for tokens that are **not masked**,
729
  - 0 for tokens that are **masked**.
730
 
731
+ If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
732
+ `past_key_values`. In other words, the `attention_mask` always has to have the length:
733
+ `len(past_key_values) + len(input_ids)`
 
734
 
735
+ [What are attention masks?](../glossary#attention-mask)
736
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
737
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
738
+ 1]`:
739
 
740
+ - 0 corresponds to a *sentence A* token,
741
+ - 1 corresponds to a *sentence B* token.
 
 
742
 
743
+ [What are token type IDs?](../glossary#token-type-ids)
744
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
745
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
746
+ config.max_position_embeddings - 1]`.
747
+
748
+ [What are position IDs?](../glossary#position-ids)
749
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
750
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
751
 
752
  - 1 indicates the head is **not masked**,
753
  - 0 indicates the head is **masked**.
754
 
755
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
756
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
757
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
758
+ model's internal embedding lookup matrix.
759
+
760
+ If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
761
+ `past_key_values`).
762
+ use_cache (`bool`, *optional*):
763
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
764
+ `past_key_values`).
765
+ output_attentions (`bool`, *optional*):
766
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
767
  tensors for more detail.
768
+ output_hidden_states (`bool`, *optional*):
769
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
770
  more detail.
771
+ return_dict (`bool`, *optional*):
772
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
773
  """
 
774
  PARALLELIZE_DOCSTRING = r"""
775
  This is an experimental feature and is subject to change at a moment's notice.
776
 
778
  it will evenly distribute blocks across all devices.
779
 
780
  Args:
781
+ device_map (`Dict[int, list]`, *optional*, defaults to `None`):
782
  A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
783
  automatically mapped to the first device (for esoteric reasons). That means that the first device should
784
  have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
789
  - gpt2-large: 36
790
  - gpt2-xl: 48
791
 
792
+ Example:
793
+
794
+ ```python
795
+ # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
796
+ model = GPT2LMHeadModel.from_pretrained("gpt2-xl")
797
+ device_map = {
798
+ 0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
799
+ 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
800
+ 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
801
+ 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
802
+ }
803
+ model.parallelize(device_map)
804
+ ```
805
  """
806
  DEPARALLELIZE_DOCSTRING = r"""
807
  Moves the model to cpu from a model parallel state.
808
 
809
+ Example:
810
+
811
+ ```python
812
+ # On a 4 GPU machine with gpt2-large:
813
+ model = GPT2LMHeadModel.from_pretrained("gpt2-large")
814
+ device_map = {
815
+ 0: [0, 1, 2, 3, 4, 5, 6, 7],
816
+ 1: [8, 9, 10, 11, 12, 13, 14, 15],
817
+ 2: [16, 17, 18, 19, 20, 21, 22, 23],
818
+ 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
819
+ }
820
+ model.parallelize(device_map) # Splits the model across several devices
821
+ model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
822
+ ```
823
  """
824
 
825
 
828
  GPT2_START_DOCSTRING,
829
  )
830
  class GPT2Model(GPT2PreTrainedModel):
831
+ _keys_to_ignore_on_load_missing = ["attn.masked_bias"]
832
+
833
  def __init__(self, config):
834
  super().__init__(config)
835
 
836
+ self.embed_dim = config.hidden_size
 
 
 
837
 
838
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
839
+ self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
840
+ self.emb_norm = nn.LayerNorm(
841
+ config.n_embd, eps=config.layer_norm_epsilon
842
+ ) # Added in GROVER
843
  self.drop = nn.Dropout(config.embd_pdrop)
844
  self.h = nn.ModuleList(
845
+ [GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
846
  )
847
+ # Removed in GROVER
848
+ # self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
 
849
 
850
  # Model parallel
851
  self.model_parallel = False
852
  self.device_map = None
853
+ self.gradient_checkpointing = False
854
+
855
+ # Initialize weights and apply final processing
856
+ self.post_init()
857
 
858
  @add_start_docstrings(PARALLELIZE_DOCSTRING)
859
  def parallelize(self, device_map=None):
873
  self.last_device = "cuda:" + str(max(self.device_map.keys()))
874
  self.wte = self.wte.to(self.first_device)
875
  self.wpe = self.wpe.to(self.first_device)
876
+
877
+ # Added in GROVER
878
+ # Wissam: not sure whether this is fine on CPU or better on the GPU
879
+ self.emb_norm = self.emb_norm.to(
880
+ "cuda:" + str(min(self.device_map.keys()))
881
+ ) # GPU
882
+ # self.emb_norm = self.emb_norm.to(self.first_device) # CPU
883
+
884
  # Load onto devices
885
  for k, v in self.device_map.items():
886
  for block in v:
887
  cuda_device = "cuda:" + str(k)
888
  self.h[block] = self.h[block].to(cuda_device)
889
  # ln_f to last
890
+ # Removed in GROVER
891
+ # self.ln_f = self.ln_f.to(self.last_device)
892
 
893
  @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
894
  def deparallelize(self):
898
  self.last_device = "cpu"
899
  self.wte = self.wte.to("cpu")
900
  self.wpe = self.wpe.to("cpu")
901
+ # Added in GROVER
902
+ self.emb_norm = self.emb_norm.to("cpu")
903
  for index in range(len(self.h)):
904
  self.h[index] = self.h[index].to("cpu")
905
+ # Removed in GROVER
906
+ # self.ln_f = self.ln_f.to("cpu")
907
  torch.cuda.empty_cache()
908
 
909
  def get_input_embeddings(self):
921
 
922
  @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
923
  @add_code_sample_docstrings(
924
+ processor_class=_TOKENIZER_FOR_DOC,
925
+ checkpoint=_CHECKPOINT_FOR_DOC,
926
  output_type=BaseModelOutputWithPastAndCrossAttentions,
927
  config_class=_CONFIG_FOR_DOC,
928
  )
929
  def forward(
930
  self,
931
+ input_ids: Optional[torch.LongTensor] = None,
932
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
933
+ attention_mask: Optional[torch.FloatTensor] = None,
934
+ token_type_ids: Optional[torch.LongTensor] = None,
935
+ position_ids: Optional[torch.LongTensor] = None,
936
+ head_mask: Optional[torch.FloatTensor] = None,
937
+ inputs_embeds: Optional[torch.FloatTensor] = None,
938
+ encoder_hidden_states: Optional[torch.Tensor] = None,
939
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
940
+ use_cache: Optional[bool] = None,
941
+ output_attentions: Optional[bool] = None,
942
+ output_hidden_states: Optional[bool] = None,
943
+ return_dict: Optional[bool] = None,
944
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
945
  output_attentions = (
946
  output_attentions
947
  if output_attentions is not None
971
  else:
972
  raise ValueError("You have to specify either input_ids or inputs_embeds")
973
 
974
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
975
+
976
  if token_type_ids is not None:
977
  token_type_ids = token_type_ids.view(-1, input_shape[-1])
978
  if position_ids is not None:
980
 
981
  if past_key_values is None:
982
  past_length = 0
983
+ past_key_values = tuple([None] * len(self.h))
984
  else:
985
  past_length = past_key_values[0][0].size(-2)
986
  if position_ids is None:
 
987
  position_ids = torch.arange(
988
  past_length,
989
  input_shape[-1] + past_length,
992
  )
993
  position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
994
 
995
+ # GPT2Attention mask.
996
  if attention_mask is not None:
997
  if batch_size <= 0:
998
  raise ValueError("batch_size has to be defined and > 0")
1012
  attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
1013
  attention_mask = (1.0 - attention_mask) * -10000.0
1014
 
1015
+ # If a 2D or 3D attention mask is provided for the cross-attention
1016
  # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1017
  if self.config.add_cross_attention and encoder_hidden_states is not None:
1018
  (
1043
  hidden_states = hidden_states + token_type_embeds
1044
 
1045
  hidden_states = self.drop(hidden_states)
1046
+ # Added in GROVER
1047
+ hidden_states = self.emb_norm(hidden_states)
1048
+
1049
  output_shape = input_shape + (hidden_states.size(-1),)
1050
 
1051
  presents = () if use_cache else None
1069
  attention_mask = attention_mask.to(hidden_states.device)
1070
  if isinstance(head_mask, torch.Tensor):
1071
  head_mask = head_mask.to(hidden_states.device)
 
1072
  if output_hidden_states:
1073
+ all_hidden_states = all_hidden_states + (hidden_states,)
 
 
1074
 
1075
+ if self.gradient_checkpointing and self.training:
1076
+
1077
+ if use_cache:
1078
+ logger.warning(
1079
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1080
+ )
1081
+ use_cache = False
1082
 
1083
  def create_custom_forward(module):
1084
  def custom_forward(*inputs):
1085
+ # None for past_key_value
1086
+ return module(*inputs, use_cache, output_attentions)
 
 
 
1087
 
1088
  return custom_forward
1089
 
1090
  outputs = torch.utils.checkpoint.checkpoint(
1091
  create_custom_forward(block),
1092
  hidden_states,
1093
+ None,
1094
  attention_mask,
1095
  head_mask[i],
1096
  encoder_hidden_states,
1108
  output_attentions=output_attentions,
1109
  )
1110
 
1111
+ hidden_states = outputs[0]
1112
  if use_cache is True:
1113
+ presents = presents + (outputs[1],)
1114
 
1115
  if output_attentions:
1116
  all_self_attentions = all_self_attentions + (
1127
  if i == v[-1] and "cuda:" + str(k) != self.last_device:
1128
  hidden_states = hidden_states.to("cuda:" + str(k + 1))
1129
 
1130
+ # Removed in GROVER
1131
+ # hidden_states = self.ln_f(hidden_states)
1132
 
1133
+ hidden_states = hidden_states.view(output_shape)
1134
  # Add last hidden state
1135
  if output_hidden_states:
1136
  all_hidden_states = all_hidden_states + (hidden_states,)
1165
  GPT2_START_DOCSTRING,
1166
  )
1167
  class GPT2LMHeadModel(GPT2PreTrainedModel):
1168
+ _keys_to_ignore_on_load_missing = [
1169
+ r"attn.masked_bias",
1170
+ r"attn.bias",
1171
+ r"lm_head.weight",
1172
+ ]
1173
 
1174
  def __init__(self, config):
1175
  super().__init__(config)
1176
  self.transformer = GPT2Model(config)
1177
  self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1178
 
 
 
1179
  # Model parallel
1180
  self.model_parallel = False
1181
  self.device_map = None
1182
 
1183
+ # Initialize weights and apply final processing
1184
+ self.post_init()
1185
+
1186
  @add_start_docstrings(PARALLELIZE_DOCSTRING)
1187
  def parallelize(self, device_map=None):
1188
  self.device_map = (
1206
  def get_output_embeddings(self):
1207
  return self.lm_head
1208
 
1209
+ def set_output_embeddings(self, new_embeddings):
1210
+ self.lm_head = new_embeddings
1211
+
1212
  def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
1213
  token_type_ids = kwargs.get("token_type_ids", None)
1214
  # only last token for inputs_ids if past is defined in kwargs
1239
 
1240
  @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1241
  @add_code_sample_docstrings(
1242
+ processor_class=_TOKENIZER_FOR_DOC,
1243
+ checkpoint=_CHECKPOINT_FOR_DOC,
1244
  output_type=CausalLMOutputWithCrossAttentions,
1245
  config_class=_CONFIG_FOR_DOC,
1246
  )
1247
  def forward(
1248
  self,
1249
+ input_ids: Optional[torch.LongTensor] = None,
1250
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1251
+ attention_mask: Optional[torch.FloatTensor] = None,
1252
+ token_type_ids: Optional[torch.LongTensor] = None,
1253
+ position_ids: Optional[torch.LongTensor] = None,
1254
+ head_mask: Optional[torch.FloatTensor] = None,
1255
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1256
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1257
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1258
+ labels: Optional[torch.LongTensor] = None,
1259
+ use_cache: Optional[bool] = None,
1260
+ output_attentions: Optional[bool] = None,
1261
+ output_hidden_states: Optional[bool] = None,
1262
+ return_dict: Optional[bool] = None,
1263
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
1264
  r"""
1265
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1266
  Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1267
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
1268
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
1269
  """
1270
  return_dict = (
1271
  return_dict if return_dict is not None else self.config.use_return_dict
1324
  past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1325
  ) -> Tuple[Tuple[torch.Tensor]]:
1326
  """
1327
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1328
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1329
+ beam_idx at every generation step.
1330
  """
1331
  return tuple(
1332
  tuple(
1347
  GPT2_START_DOCSTRING,
1348
  )
1349
  class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
1350
+ _keys_to_ignore_on_load_missing = [
1351
+ r"attn.masked_bias",
1352
+ r"attn.bias",
1353
+ r"lm_head.weight",
1354
+ ]
1355
+
1356
  def __init__(self, config):
1357
  super().__init__(config)
1358
  config.num_labels = 1
1360
  self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1361
  self.multiple_choice_head = SequenceSummary(config)
1362
 
 
 
1363
  # Model parallel
1364
  self.model_parallel = False
1365
  self.device_map = None
1366
 
1367
+ # Initialize weights and apply final processing
1368
+ self.post_init()
1369
+
1370
  @add_start_docstrings(PARALLELIZE_DOCSTRING)
1371
  def parallelize(self, device_map=None):
1372
  self.device_map = (
1394
  def get_output_embeddings(self):
1395
  return self.lm_head
1396
 
1397
+ def set_output_embeddings(self, new_embeddings):
1398
+ self.lm_head = new_embeddings
1399
+
1400
  def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
1401
  token_type_ids = kwargs.get("token_type_ids", None)
1402
  # only last token for inputs_ids if past is defined in kwargs
1432
  )
1433
  def forward(
1434
  self,
1435
+ input_ids: Optional[torch.LongTensor] = None,
1436
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1437
+ attention_mask: Optional[torch.FloatTensor] = None,
1438
+ token_type_ids: Optional[torch.LongTensor] = None,
1439
+ position_ids: Optional[torch.LongTensor] = None,
1440
+ head_mask: Optional[torch.FloatTensor] = None,
1441
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1442
+ mc_token_ids: Optional[torch.LongTensor] = None,
1443
+ labels: Optional[torch.LongTensor] = None,
1444
+ mc_labels: Optional[torch.LongTensor] = None,
1445
+ use_cache: Optional[bool] = None,
1446
+ output_attentions: Optional[bool] = None,
1447
+ output_hidden_states: Optional[bool] = None,
1448
+ return_dict: Optional[bool] = None,
1449
  **kwargs,
1450
+ ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]:
1451
  r"""
1452
+ mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, defaults to the index of the last token of the input):
1453
+ Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
1454
+ 1]`.
1455
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1456
  Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1457
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to
1458
+ `-100` are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`.
1459
+ mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
1460
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices - 1]`
1461
+ where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above)
 
1462
 
1463
  Return:
1464
 
1465
+ Example:
1466
 
1467
+ ```python
1468
+ >>> import torch
1469
+ >>> from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel
1470
 
1471
+ >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
1472
+ >>> model = GPT2DoubleHeadsModel.from_pretrained("gpt2")
1473
 
1474
+ >>> # Add a [CLS] to the vocabulary (we should train it also!)
1475
+ >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"})
1476
+ >>> # Update the model embeddings with the new vocabulary size
1477
+ >>> embedding_layer = model.resize_token_embeddings(len(tokenizer))
1478
 
1479
+ >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
1480
+ >>> encoded_choices = [tokenizer.encode(s) for s in choices]
1481
+ >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]
1482
 
1483
+ >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2
1484
+ >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1
 
1485
 
1486
+ >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
1487
+ >>> lm_logits = outputs.logits
1488
+ >>> mc_logits = outputs.mc_logits
1489
+ ```"""
1490
  return_dict = (
1491
  return_dict if return_dict is not None else self.config.use_return_dict
1492
  )
1551
  past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1552
  ) -> Tuple[Tuple[torch.Tensor]]:
1553
  """
1554
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1555
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1556
+ beam_idx at every generation step.
1557
  """
1558
  return tuple(
1559
  tuple(
1568
  """
1569
  The GPT2 Model transformer with a sequence classification head on top (linear layer).
1570
 
1571
+ [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1572
+ (e.g. GPT-1) do.
1573
 
1574
  Since it does classification on the last token, it needs to know the position of the last token. If a
1575
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1576
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1577
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1578
+ each row of the batch).
1579
  """,
1580
  GPT2_START_DOCSTRING,
1581
  )
1588
  self.transformer = GPT2Model(config)
1589
  self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
1590
 
 
 
1591
  # Model parallel
1592
  self.model_parallel = False
1593
  self.device_map = None
1594
 
1595
+ # Initialize weights and apply final processing
1596
+ self.post_init()
1597
+
1598
  @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1599
  @add_code_sample_docstrings(
1600
+ processor_class=_TOKENIZER_FOR_DOC,
1601
+ checkpoint="microsoft/DialogRPT-updown",
1602
  output_type=SequenceClassifierOutputWithPast,
1603
  config_class=_CONFIG_FOR_DOC,
1604
+ expected_output="'LABEL_0'",
1605
+ expected_loss=5.28,
1606
  )
1607
  def forward(
1608
  self,
1609
+ input_ids: Optional[torch.LongTensor] = None,
1610
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1611
+ attention_mask: Optional[torch.FloatTensor] = None,
1612
+ token_type_ids: Optional[torch.LongTensor] = None,
1613
+ position_ids: Optional[torch.LongTensor] = None,
1614
+ head_mask: Optional[torch.FloatTensor] = None,
1615
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1616
+ labels: Optional[torch.LongTensor] = None,
1617
+ use_cache: Optional[bool] = None,
1618
+ output_attentions: Optional[bool] = None,
1619
+ output_hidden_states: Optional[bool] = None,
1620
+ return_dict: Optional[bool] = None,
1621
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1622
  r"""
1623
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1624
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1625
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1626
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1627
  """
1628
  return_dict = (
1629
  return_dict if return_dict is not None else self.config.use_return_dict
1664
  sequence_lengths = -1
1665
  logger.warning(
1666
  f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1667
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1668
  )
1669
 
1670
+ pooled_logits = logits[
1671
+ torch.arange(batch_size, device=self.device), sequence_lengths
1672
+ ]
1673
 
1674
  loss = None
1675
  if labels is not None:
1676
+ if self.config.problem_type is None:
1677
+ if self.num_labels == 1:
1678
+ self.config.problem_type = "regression"
1679
+ elif self.num_labels > 1 and (
1680
+ labels.dtype == torch.long or labels.dtype == torch.int
1681
+ ):
1682
+ self.config.problem_type = "single_label_classification"
1683
+ else:
1684
+ self.config.problem_type = "multi_label_classification"
1685
+
1686
+ if self.config.problem_type == "regression":
1687
  loss_fct = MSELoss()
1688
+ if self.num_labels == 1:
1689
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1690
+ else:
1691
+ loss = loss_fct(pooled_logits, labels)
1692
+ elif self.config.problem_type == "single_label_classification":
1693
  loss_fct = CrossEntropyLoss()
1694
  loss = loss_fct(
1695
  pooled_logits.view(-1, self.num_labels), labels.view(-1)
1696
  )
1697
+ elif self.config.problem_type == "multi_label_classification":
1698
+ loss_fct = BCEWithLogitsLoss()
1699
+ loss = loss_fct(pooled_logits, labels)
1700
  if not return_dict:
1701
  output = (pooled_logits,) + transformer_outputs[1:]
1702
  return ((loss,) + output) if loss is not None else output
1735
  self.dropout = nn.Dropout(classifier_dropout)
1736
  self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1737
 
 
 
1738
  # Model parallel
1739
  self.model_parallel = False
1740
  self.device_map = None
1741
 
1742
+ # Initialize weights and apply final processing
1743
+ self.post_init()
1744
+
1745
  @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1746
+ # fmt: off
1747
  @add_code_sample_docstrings(
1748
+ processor_class=_TOKENIZER_FOR_DOC,
1749
+ checkpoint="brad1141/gpt2-finetuned-comp2",
1750
  output_type=TokenClassifierOutput,
1751
  config_class=_CONFIG_FOR_DOC,
1752
+ expected_loss=0.25,
1753
+ expected_output=["Lead", "Lead", "Lead", "Position", "Lead", "Lead", "Lead", "Lead", "Lead", "Lead", "Lead", "Lead"],
1754
  )
1755
+ # fmt: on
1756
  def forward(
1757
  self,
1758
+ input_ids: Optional[torch.LongTensor] = None,
1759
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1760
+ attention_mask: Optional[torch.FloatTensor] = None,
1761
+ token_type_ids: Optional[torch.LongTensor] = None,
1762
+ position_ids: Optional[torch.LongTensor] = None,
1763
+ head_mask: Optional[torch.FloatTensor] = None,
1764
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1765
+ labels: Optional[torch.LongTensor] = None,
1766
+ use_cache: Optional[bool] = None,
1767
+ output_attentions: Optional[bool] = None,
1768
+ output_hidden_states: Optional[bool] = None,
1769
+ return_dict: Optional[bool] = None,
1770
+ ) -> Union[Tuple, TokenClassifierOutput]:
1771
  r"""
1772
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1773
+ Labels for computing the token classification loss. Indices should be in `[0, ...,
1774
+ config.num_labels - 1]`.
1776
  """
1777
  return_dict = (
1778
  return_dict if return_dict is not None else self.config.use_return_dict
1799
  loss = None
1800
  if labels is not None:
1801
  loss_fct = CrossEntropyLoss()
1802
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1803
 
1804
  if not return_dict:
1805
  output = (logits,) + transformer_outputs[2:]
1810
  logits=logits,
1811
  hidden_states=transformer_outputs.hidden_states,
1812
  attentions=transformer_outputs.attentions,
1813
+ )
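
# Closing sketch of the loss computation used by GPT2ForTokenClassification above: logits of
# shape (batch, seq_len, num_labels) and labels of shape (batch, seq_len) are flattened so that
# CrossEntropyLoss scores one prediction per token; positions labelled -100 are ignored by
# default. Shapes below are illustrative assumptions.
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len, num_labels = 2, 6, 5
logits = torch.randn(batch_size, seq_len, num_labels)
labels = torch.randint(0, num_labels, (batch_size, seq_len))
labels[0, :2] = -100                                   # e.g. padding / special tokens to skip

loss_fct = CrossEntropyLoss()                          # ignore_index defaults to -100
loss = loss_fct(logits.view(-1, num_labels), labels.view(-1))
print(loss.item())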