wissamantoun committed on
Commit
ce026c5
1 Parent(s): 5be9630

updated gpt2 to transformers 4.10


I hope it works (I didn't test the parallelize method)

Files changed (1)
  1. backend/modeling_gpt2.py +502 -99
backend/modeling_gpt2.py CHANGED
@@ -23,42 +23,35 @@ and https://github.com/ghosthamlet/gpt2-ml-torch/blob/master/gpt2_ml_torch/model
23
 
24
  import logging
25
  import os
26
-
27
  from dataclasses import dataclass
28
  from typing import List, Optional, Tuple
29
 
30
  import torch
31
  import torch.nn as nn
32
  from torch.nn import CrossEntropyLoss, MSELoss
33
-
34
-
35
-
36
  from transformers.activations import ACT2FN
37
- from transformers import GPT2Config
38
-
39
- from transformers.modeling_utils import (
40
- Conv1D,
41
- PreTrainedModel,
42
- SequenceSummary,
43
- prune_conv1d_layer,
44
- find_pruneable_heads_and_indices
45
  )
46
-
47
- from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model
48
-
49
  from transformers.modeling_outputs import (
50
  BaseModelOutputWithPastAndCrossAttentions,
51
  CausalLMOutputWithCrossAttentions,
52
- SequenceClassifierOutputWithPast
 
53
  )
54
-
55
- from transformers.file_utils import (
56
- ModelOutput,
57
- add_start_docstrings,
58
- add_start_docstrings_to_model_forward,
59
- add_code_sample_docstrings,
60
- replace_return_docstrings
61
  )
 
62
 
63
  # The difference from Transformers is the code under _USE_GROVER
64
  _USE_GROVER = True
@@ -83,30 +76,30 @@ console.setLevel(logging.INFO)
83
  logger.addHandler(console)
84
 
85
  _GPT2_ML_TF_TO_TORCH = {
86
- 'LayerNorm_embed_norm': 'emb_norm',
87
- 'pos_embed': 'wpe.weight',
88
- 'word_embed': 'wte.weight',
89
-
90
- 'layer': 'h',
91
- # Most importently This two layer norm must be put on the same position as gpt2-ml
92
- # or generated data is bad, just repeat the last token
93
- 'LayerNorm_mlp_ln0': 'ln_1',
94
- 'LayerNorm_mlp_ln1': 'ln_2',
95
- 'intermediate': 'mlp.c_fc',
96
- 'output': 'mlp.c_proj',
97
- 'query_layer': 'attn.c_attn',
98
- 'key_layer': 'attn.c_attn',
99
- 'value_layer': 'attn.c_attn',
100
- 'context_projection_layer': 'attn.c_proj',
101
-
102
- 'gamma': 'weight',
103
- 'kernel': 'weight',
104
- 'beta': 'bias',
105
- 'bias': 'bias',
106
  }
107
 
108
 
109
- def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
 
 
110
  # Construct model
111
  if gpt2_config_file == "":
112
  config = GPT2Config()
@@ -130,10 +123,10 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p
130
  # XXX: MUST do like: convert_gpt2_checkpoint_to_pytorch('./model.ckpt-100000', './mega.json', './')
131
  # https://github.com/tensorflow/models/issues/2675#issuecomment-516595597
132
  def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
133
- """ Load tf checkpoints in a pytorch model
134
- """
135
  try:
136
  import re
 
137
  import tensorflow as tf
138
  except ImportError:
139
  logger.error(
@@ -154,6 +147,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
154
  arrays.append(array.squeeze())
155
 
156
  import copy
 
157
  orig_model = copy.deepcopy(model)
158
 
159
  for name, array in zip(names, arrays):
@@ -161,7 +155,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
161
  name = name.split("/")
162
  pointer = model
163
 
164
- attn_layer = ''
165
  for m_name in name:
166
  if re.fullmatch(r"[A-Za-z]+\d+", m_name):
167
  scope_names = re.split(r"(\d+)", m_name)
@@ -169,23 +163,23 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
169
  scope_names = [m_name]
170
  sname = scope_names[0]
171
 
172
- if sname == '' or sname == 'embeddings':
173
  continue
174
  elif sname not in _GPT2_ML_TF_TO_TORCH:
175
- print('=========================================================')
176
- logger.info('Skip var name {}'.format(scope_names))
177
  pointer = None
178
  break
179
  else:
180
  tname = _GPT2_ML_TF_TO_TORCH[sname]
181
- if '.' in tname:
182
- parent, child = tname.split('.')
183
  pointer = getattr(pointer, parent)
184
  pointer = getattr(pointer, child)
185
  else:
186
  pointer = getattr(pointer, tname)
187
 
188
- if tname == 'attn.c_attn':
189
  attn_layer = sname
190
 
191
  if len(scope_names) >= 2:
@@ -194,39 +188,47 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
194
 
195
  if pointer is None:
196
  continue
197
- if attn_layer == '':
198
  try:
199
  assert pointer.shape == array.shape
200
  except AssertionError as e:
201
  e.args += (pointer.shape, array.shape)
202
  raise
203
- logger.info("Initialize PyTorch weight {}, {}, {}".format(name, array.mean(), pointer.mean()))
204
- if attn_layer == '':
 
 
 
 
205
  pointer.data = torch.from_numpy(array)
206
  else:
207
  shape = pointer.shape
208
  d = torch.from_numpy(array)
209
  is_bias = len(shape) == 1
210
- end = int(shape[0 if is_bias else 1]/3)
211
  m = dict(
212
- query_layer=0,
213
- key_layer=end,
214
- value_layer=end*2,
215
- )
216
  start = m[attn_layer]
217
  end = start + end
218
  if is_bias:
219
  pointer.data[start:end] = d
220
  else:
221
  pointer.data[:, start:end] = d
222
- logger.info("Initialize PyTorch weight {}, {}, {}".format(name, array.mean(), pointer.mean()))
 
 
 
 
223
 
224
  for name, params in orig_model.named_parameters():
225
  for n, p in model.named_parameters():
226
  if name == n:
227
  if params.equal(p):
228
- print('--------------------------')
229
- print(' %s not changed!' % n)
230
  return model
231
 
232
 
@@ -238,7 +240,10 @@ class Attention(nn.Module):
238
  # [switch nx => n_state from Block to Attention to keep identical to TF implem]
239
  assert n_state % config.n_head == 0
240
  self.register_buffer(
241
- "bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx)
 
 
 
242
  )
243
  self.register_buffer("masked_bias", torch.tensor(-1e4))
244
  self.n_head = config.n_head
@@ -261,7 +266,9 @@ class Attention(nn.Module):
261
  heads, index = find_pruneable_heads_and_indices(
262
  heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
263
  )
264
- index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
 
 
265
 
266
  # Prune conv1d layers
267
  self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
@@ -272,7 +279,9 @@ class Attention(nn.Module):
272
  self.n_head = self.n_head - len(heads)
273
  self.pruned_heads = self.pruned_heads.union(heads)
274
 
275
- def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
 
 
276
  w = torch.matmul(q, k)
277
  if self.scale:
278
  w = w / (float(v.size(-1)) ** 0.5)
@@ -328,7 +337,9 @@ class Attention(nn.Module):
328
  self, "q_attn"
329
  ), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."
330
  query = self.q_attn(hidden_states)
331
- key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
 
 
332
  attention_mask = encoder_attention_mask
333
  else:
334
  query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
@@ -337,16 +348,23 @@ class Attention(nn.Module):
337
  key = self.split_heads(key, k=True)
338
  value = self.split_heads(value)
339
  if layer_past is not None:
340
- past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
 
 
 
341
  key = torch.cat((past_key, key), dim=-1)
342
  value = torch.cat((past_value, value), dim=-2)
343
 
344
  if use_cache is True:
345
- present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
 
 
346
  else:
347
  present = (None,)
348
 
349
- attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
 
 
350
  a = attn_outputs[0]
351
 
352
  a = self.merge_heads(a)
@@ -381,8 +399,12 @@ class Block(nn.Module):
381
  self.attn = Attention(hidden_size, n_ctx, config, scale)
382
  self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
383
  if config.add_cross_attention:
384
- self.crossattention = Attention(hidden_size, n_ctx, config, scale, is_cross_attention=True)
385
- self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
 
 
 
 
386
  self.mlp = MLP(inner_dim, config)
387
 
388
  def forward(
@@ -425,7 +447,9 @@ class Block(nn.Module):
425
  attn_output = cross_attn_outputs[0]
426
  # residual connection
427
  hidden_states = hidden_states + attn_output
428
- outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
 
 
429
 
430
  feed_forward_hidden_states = self.mlp(self.ln_1(hidden_states))
431
  # residual connection
@@ -446,6 +470,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
446
  config_class = GPT2Config
447
  load_tf_weights = load_tf_weights_in_gpt2
448
  base_model_prefix = "transformer"
 
449
 
450
  def __init__(self, *inputs, **kwargs):
451
  super().__init__(*inputs, **kwargs)
@@ -588,6 +613,51 @@ GPT2_INPUTS_DOCSTRING = r"""
588
  Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
589
  """
590

591
 
592
  @add_start_docstrings(
593
  "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
@@ -603,12 +673,57 @@ class GPT2Model(GPT2PreTrainedModel):
603
  self.emb_norm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
604
 
605
  self.drop = nn.Dropout(config.embd_pdrop)
606
- self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
 
 
607
  if not _USE_GROVER:
608
  self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
609
 
610
  self.init_weights()
611

612
  def get_input_embeddings(self):
613
  return self.wte
614
 
@@ -645,15 +760,25 @@ class GPT2Model(GPT2PreTrainedModel):
645
  output_hidden_states=None,
646
  return_dict=None,
647
  ):
648
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
649
  output_hidden_states = (
650
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
651
  )
652
  use_cache = use_cache if use_cache is not None else self.config.use_cache
653
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
654
 
655
  if input_ids is not None and inputs_embeds is not None:
656
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
 
 
657
  elif input_ids is not None:
658
  input_shape = input_ids.size()
659
  input_ids = input_ids.view(-1, input_shape[-1])
@@ -676,12 +801,18 @@ class GPT2Model(GPT2PreTrainedModel):
676
  past_length = past_key_values[0][0].size(-2)
677
  if position_ids is None:
678
  device = input_ids.device if input_ids is not None else inputs_embeds.device
679
- position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
 
 
 
 
 
680
  position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
681
 
682
  # Attention mask.
683
  if attention_mask is not None:
684
- assert batch_size > 0, "batch_size has to be defined and > 0"
 
685
  attention_mask = attention_mask.view(batch_size, -1)
686
  # We create a 3D attention mask from a 2D tensor mask.
687
  # Sizes are [batch_size, 1, 1, to_seq_length]
@@ -701,7 +832,11 @@ class GPT2Model(GPT2PreTrainedModel):
701
  # If a 2D or 3D attention mask is provided for the cross-attention
702
  # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
703
  if self.config.add_cross_attention and encoder_hidden_states is not None:
704
- encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
 
 
 
 
705
  encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
706
  if encoder_attention_mask is None:
707
  encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
@@ -731,18 +866,40 @@ class GPT2Model(GPT2PreTrainedModel):
731
 
732
  presents = () if use_cache else None
733
  all_self_attentions = () if output_attentions else None
734
- all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
 
 
735
  all_hidden_states = () if output_hidden_states else None
736
  for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
737
  if output_hidden_states:
738
- all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
 
 
739
 
740
  if getattr(self.config, "gradient_checkpointing", False):
741
 
742
  def create_custom_forward(module):
743
  def custom_forward(*inputs):
744
  # checkpointing only works with tuple returns, not with lists
745
- return tuple(output for output in module(*inputs, use_cache, output_attentions))
 
 
 
746
 
747
  return custom_forward
748
 
@@ -772,9 +929,19 @@ class GPT2Model(GPT2PreTrainedModel):
772
  presents = presents + (present,)
773
 
774
  if output_attentions:
775
- all_self_attentions = all_self_attentions + (outputs[2],)
 
 
776
  if self.config.add_cross_attention:
777
- all_cross_attentions = all_cross_attentions + (outputs[3],)
778
 
779
  if not _USE_GROVER:
780
  hidden_states = self.ln_f(hidden_states)
@@ -785,7 +952,17 @@ class GPT2Model(GPT2PreTrainedModel):
785
  all_hidden_states = all_hidden_states + (hidden_states,)
786
 
787
  if not return_dict:
788
- return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
789
 
790
  return BaseModelOutputWithPastAndCrossAttentions(
791
  last_hidden_state=hidden_states,
@@ -813,6 +990,30 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
813
 
814
  self.init_weights()
815

816
  def get_output_embeddings(self):
817
  return self.lm_head
818
 
@@ -848,7 +1049,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
848
  @add_code_sample_docstrings(
849
  tokenizer_class=_TOKENIZER_FOR_DOC,
850
  checkpoint="gpt2",
851
- output_type= CausalLMOutputWithCrossAttentions,
852
  config_class=_CONFIG_FOR_DOC,
853
  )
854
  def forward(
@@ -874,7 +1075,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
874
  ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
875
  ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
876
  """
877
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
878
 
879
  transformer_outputs = self.transformer(
880
  input_ids,
@@ -893,6 +1096,11 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
893
  )
894
  hidden_states = transformer_outputs[0]
895

896
  lm_logits = self.lm_head(hidden_states)
897
 
898
  loss = None
@@ -902,13 +1110,15 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
902
  shift_labels = labels[..., 1:].contiguous()
903
  # Flatten the tokens
904
  loss_fct = CrossEntropyLoss()
905
- loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
 
 
906
 
907
  if not return_dict:
908
  output = (lm_logits,) + transformer_outputs[1:]
909
  return ((loss,) + output) if loss is not None else output
910
 
911
- return CausalLMOutputWithCrossAttentions(
912
  loss=loss,
913
  logits=lm_logits,
914
  past_key_values=transformer_outputs.past_key_values,
@@ -917,6 +1127,23 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
917
  cross_attentions=transformer_outputs.cross_attentions,
918
  )
919

920
 
921
  @add_start_docstrings(
922
  """
@@ -937,6 +1164,34 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
937
 
938
  self.init_weights()
939

940
  def get_output_embeddings(self):
941
  return self.lm_head
942
 
@@ -970,7 +1225,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
970
  }
971
 
972
  @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
973
- @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
 
 
974
  def forward(
975
  self,
976
  input_ids=None,
@@ -1029,7 +1286,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
1029
  >>> mc_logits = outputs.mc_logits
1030
 
1031
  """
1032
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
1033
 
1034
  transformer_outputs = self.transformer(
1035
  input_ids,
@@ -1047,19 +1306,28 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
1047
 
1048
  hidden_states = transformer_outputs[0]
1049

1050
  lm_logits = self.lm_head(hidden_states)
1051
  mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
1052
 
1053
  mc_loss = None
1054
  if mc_labels is not None:
1055
  loss_fct = CrossEntropyLoss()
1056
- mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
 
 
1057
  lm_loss = None
1058
  if labels is not None:
1059
  shift_logits = lm_logits[..., :-1, :].contiguous()
1060
  shift_labels = labels[..., 1:].contiguous()
1061
  loss_fct = CrossEntropyLoss()
1062
- lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
 
 
1063
 
1064
  if not return_dict:
1065
  output = (lm_logits, mc_logits) + transformer_outputs[1:]
@@ -1077,6 +1345,23 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
1077
  attentions=transformer_outputs.attentions,
1078
  )
1079

1080
 
1081
  @add_start_docstrings(
1082
  """
@@ -1104,6 +1389,10 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
1104
 
1105
  self.init_weights()
1106

1107
  @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1108
  @add_code_sample_docstrings(
1109
  tokenizer_class=_TOKENIZER_FOR_DOC,
@@ -1132,7 +1421,9 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
1132
  config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
1133
  If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1134
  """
1135
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
1136
 
1137
  transformer_outputs = self.transformer(
1138
  input_ids,
@@ -1162,7 +1453,9 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
1162
  sequence_lengths = -1
1163
  else:
1164
  if input_ids is not None:
1165
- sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
 
 
1166
  else:
1167
  sequence_lengths = -1
1168
  logger.warning(
@@ -1180,7 +1473,9 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
1180
  loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
1181
  else:
1182
  loss_fct = CrossEntropyLoss()
1183
- loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
 
 
1184
 
1185
  if not return_dict:
1186
  output = (pooled_logits,) + transformer_outputs[1:]
@@ -1194,3 +1489,111 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
1194
  attentions=transformer_outputs.attentions,
1195
  )
1196

23
 
24
  import logging
25
  import os
 
26
  from dataclasses import dataclass
27
  from typing import List, Optional, Tuple
28
 
29
  import torch
30
  import torch.nn as nn
31
  from torch.nn import CrossEntropyLoss, MSELoss
32
+ from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model
 
 
33
  from transformers.activations import ACT2FN
34
+ from transformers.file_utils import (
35
+ ModelOutput,
36
+ add_code_sample_docstrings,
37
+ add_start_docstrings,
38
+ add_start_docstrings_to_model_forward,
39
+ replace_return_docstrings,
 
 
40
  )
 
 
 
41
  from transformers.modeling_outputs import (
42
  BaseModelOutputWithPastAndCrossAttentions,
43
  CausalLMOutputWithCrossAttentions,
44
+ SequenceClassifierOutputWithPast,
45
+ TokenClassifierOutput,
46
  )
47
+ from transformers.modeling_utils import (
48
+ Conv1D,
49
+ PreTrainedModel,
50
+ SequenceSummary,
51
+ find_pruneable_heads_and_indices,
52
+ prune_conv1d_layer,
 
53
  )
54
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
55
 
56
  # The difference from Transformers is the code under _USE_GROVER
57
  _USE_GROVER = True
 
76
  logger.addHandler(console)
77
 
78
  _GPT2_ML_TF_TO_TORCH = {
79
+ "LayerNorm_embed_norm": "emb_norm",
80
+ "pos_embed": "wpe.weight",
81
+ "word_embed": "wte.weight",
82
+ "layer": "h",
83
+ # Most importantly, these two layer norms must be put in the same position as in gpt2-ml
84
+ # or the generated data is bad and just repeats the last token
85
+ "LayerNorm_mlp_ln0": "ln_1",
86
+ "LayerNorm_mlp_ln1": "ln_2",
87
+ "intermediate": "mlp.c_fc",
88
+ "output": "mlp.c_proj",
89
+ "query_layer": "attn.c_attn",
90
+ "key_layer": "attn.c_attn",
91
+ "value_layer": "attn.c_attn",
92
+ "context_projection_layer": "attn.c_proj",
93
+ "gamma": "weight",
94
+ "kernel": "weight",
95
+ "beta": "bias",
96
+ "bias": "bias",
 
 
97
  }
98
 
99
 
100
+ def convert_gpt2_checkpoint_to_pytorch(
101
+ gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path
102
+ ):
103
  # Construct model
104
  if gpt2_config_file == "":
105
  config = GPT2Config()
 
123
  # XXX: MUST do like: convert_gpt2_checkpoint_to_pytorch('./model.ckpt-100000', './mega.json', './')
124
  # https://github.com/tensorflow/models/issues/2675#issuecomment-516595597
125
  def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
126
+ """Load tf checkpoints in a pytorch model"""
 
127
  try:
128
  import re
129
+
130
  import tensorflow as tf
131
  except ImportError:
132
  logger.error(
 
147
  arrays.append(array.squeeze())
148
 
149
  import copy
150
+
151
  orig_model = copy.deepcopy(model)
152
 
153
  for name, array in zip(names, arrays):
 
155
  name = name.split("/")
156
  pointer = model
157
 
158
+ attn_layer = ""
159
  for m_name in name:
160
  if re.fullmatch(r"[A-Za-z]+\d+", m_name):
161
  scope_names = re.split(r"(\d+)", m_name)
 
163
  scope_names = [m_name]
164
  sname = scope_names[0]
165
 
166
+ if sname == "" or sname == "embeddings":
167
  continue
168
  elif sname not in _GPT2_ML_TF_TO_TORCH:
169
+ print("=========================================================")
170
+ logger.info("Skip var name {}".format(scope_names))
171
  pointer = None
172
  break
173
  else:
174
  tname = _GPT2_ML_TF_TO_TORCH[sname]
175
+ if "." in tname:
176
+ parent, child = tname.split(".")
177
  pointer = getattr(pointer, parent)
178
  pointer = getattr(pointer, child)
179
  else:
180
  pointer = getattr(pointer, tname)
181
 
182
+ if tname == "attn.c_attn":
183
  attn_layer = sname
184
 
185
  if len(scope_names) >= 2:
 
188
 
189
  if pointer is None:
190
  continue
191
+ if attn_layer == "":
192
  try:
193
  assert pointer.shape == array.shape
194
  except AssertionError as e:
195
  e.args += (pointer.shape, array.shape)
196
  raise
197
+ logger.info(
198
+ "Initialize PyTorch weight {}, {}, {}".format(
199
+ name, array.mean(), pointer.mean()
200
+ )
201
+ )
202
+ if attn_layer == "":
203
  pointer.data = torch.from_numpy(array)
204
  else:
205
  shape = pointer.shape
206
  d = torch.from_numpy(array)
207
  is_bias = len(shape) == 1
208
+ end = int(shape[0 if is_bias else 1] / 3)
209
  m = dict(
210
+ query_layer=0,
211
+ key_layer=end,
212
+ value_layer=end * 2,
213
+ )
214
  start = m[attn_layer]
215
  end = start + end
216
  if is_bias:
217
  pointer.data[start:end] = d
218
  else:
219
  pointer.data[:, start:end] = d
220
+ logger.info(
221
+ "Initialize PyTorch weight {}, {}, {}".format(
222
+ name, array.mean(), pointer.mean()
223
+ )
224
+ )
225
 
226
  for name, params in orig_model.named_parameters():
227
  for n, p in model.named_parameters():
228
  if name == n:
229
  if params.equal(p):
230
+ print("--------------------------")
231
+ print(" %s not changed!" % n)
232
  return model
233
 
234
 
 
240
  # [switch nx => n_state from Block to Attention to keep identical to TF implem]
241
  assert n_state % config.n_head == 0
242
  self.register_buffer(
243
+ "bias",
244
+ torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(
245
+ 1, 1, n_ctx, n_ctx
246
+ ),
247
  )
248
  self.register_buffer("masked_bias", torch.tensor(-1e4))
249
  self.n_head = config.n_head
 
266
  heads, index = find_pruneable_heads_and_indices(
267
  heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
268
  )
269
+ index_attn = torch.cat(
270
+ [index, index + self.split_size, index + (2 * self.split_size)]
271
+ )
272
 
273
  # Prune conv1d layers
274
  self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
 
279
  self.n_head = self.n_head - len(heads)
280
  self.pruned_heads = self.pruned_heads.union(heads)
281
 
282
+ def _attn(
283
+ self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False
284
+ ):
285
  w = torch.matmul(q, k)
286
  if self.scale:
287
  w = w / (float(v.size(-1)) ** 0.5)
 
337
  self, "q_attn"
338
  ), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."
339
  query = self.q_attn(hidden_states)
340
+ key, value = self.c_attn(encoder_hidden_states).split(
341
+ self.split_size, dim=2
342
+ )
343
  attention_mask = encoder_attention_mask
344
  else:
345
  query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
 
348
  key = self.split_heads(key, k=True)
349
  value = self.split_heads(value)
350
  if layer_past is not None:
351
+ past_key, past_value = (
352
+ layer_past[0].transpose(-2, -1),
353
+ layer_past[1],
354
+ ) # transpose back cf below
355
  key = torch.cat((past_key, key), dim=-1)
356
  value = torch.cat((past_value, value), dim=-2)
357
 
358
  if use_cache is True:
359
+ present = torch.stack(
360
+ (key.transpose(-2, -1), value)
361
+ ) # transpose to have same shapes for stacking
362
  else:
363
  present = (None,)
364
 
365
+ attn_outputs = self._attn(
366
+ query, key, value, attention_mask, head_mask, output_attentions
367
+ )
368
  a = attn_outputs[0]
369
 
370
  a = self.merge_heads(a)
 
399
  self.attn = Attention(hidden_size, n_ctx, config, scale)
400
  self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
401
  if config.add_cross_attention:
402
+ self.crossattention = Attention(
403
+ hidden_size, n_ctx, config, scale, is_cross_attention=True
404
+ )
405
+ self.ln_cross_attn = nn.LayerNorm(
406
+ hidden_size, eps=config.layer_norm_epsilon
407
+ )
408
  self.mlp = MLP(inner_dim, config)
409
 
410
  def forward(
 
447
  attn_output = cross_attn_outputs[0]
448
  # residual connection
449
  hidden_states = hidden_states + attn_output
450
+ outputs = (
451
+ outputs + cross_attn_outputs[2:]
452
+ ) # add cross attentions if we output attention weights
453
 
454
  feed_forward_hidden_states = self.mlp(self.ln_1(hidden_states))
455
  # residual connection
 
470
  config_class = GPT2Config
471
  load_tf_weights = load_tf_weights_in_gpt2
472
  base_model_prefix = "transformer"
473
+ is_parallelizable = True
474
 
475
  def __init__(self, *inputs, **kwargs):
476
  super().__init__(*inputs, **kwargs)
 
613
  Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
614
  """
615
 
616
+ PARALLELIZE_DOCSTRING = r"""
617
+ This is an experimental feature and is subject to change at a moment's notice.
618
+
619
+ Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
620
+ it will evenly distribute blocks across all devices.
621
+
622
+ Args:
623
+ device_map (:obj:`Dict[int, list]`, optional, defaults to None):
624
+ A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
625
+ automatically mapped to the first device (for esoteric reasons). That means that the first device should
626
+ have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
627
+ following number of attention modules:
628
+
629
+ - gpt2: 12
630
+ - gpt2-medium: 24
631
+ - gpt2-large: 36
632
+ - gpt2-xl: 48
633
+
634
+ Example::
635
+
636
+ # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
637
+ model = GPT2LMHeadModel.from_pretrained('gpt2-xl')
638
+ device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
639
+
640
+ 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
641
+ 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
642
+ 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]}
643
+ model.parallelize(device_map)
644
+ """
645
+ DEPARALLELIZE_DOCSTRING = r"""
646
+ Moves the model to cpu from a model parallel state.
647
+
648
+ Example::
649
+
650
+ # On a 4 GPU machine with gpt2-large:
651
+ model = GPT2LMHeadModel.from_pretrained('gpt2-large')
652
+ device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7],
653
+
654
+ 1: [8, 9, 10, 11, 12, 13, 14, 15],
655
+ 2: [16, 17, 18, 19, 20, 21, 22, 23],
656
+ 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35]}
657
+ model.parallelize(device_map) # Splits the model across several devices
658
+ model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
659
+ """
660
+
661
 
662
  @add_start_docstrings(
663
  "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
 
673
  self.emb_norm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
674
 
675
  self.drop = nn.Dropout(config.embd_pdrop)
676
+ self.h = nn.ModuleList(
677
+ [Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)]
678
+ )
679
  if not _USE_GROVER:
680
  self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
681
 
682
  self.init_weights()
683
 
684
+ # Model parallel
685
+ self.model_parallel = False
686
+ self.device_map = None
687
+
688
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
689
+ def parallelize(self, device_map=None):
690
+ # Check validity of device_map
691
+ self.device_map = (
692
+ get_device_map(len(self.h), range(torch.cuda.device_count()))
693
+ if device_map is None
694
+ else device_map
695
+ )
696
+ assert_device_map(self.device_map, len(self.h))
697
+ self.model_parallel = True
698
+ self.first_device = (
699
+ "cpu"
700
+ if "cpu" in self.device_map.keys()
701
+ else "cuda:" + str(min(self.device_map.keys()))
702
+ )
703
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
704
+ self.wte = self.wte.to(self.first_device)
705
+ self.wpe = self.wpe.to(self.first_device)
706
+ # Load onto devices
707
+ for k, v in self.device_map.items():
708
+ for block in v:
709
+ cuda_device = "cuda:" + str(k)
710
+ self.h[block] = self.h[block].to(cuda_device)
711
+ # ln_f to last (it only exists when _USE_GROVER is False)
712
+ if not _USE_GROVER:
+ self.ln_f = self.ln_f.to(self.last_device)
713
+
714
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
715
+ def deparallelize(self):
716
+ self.model_parallel = False
717
+ self.device_map = None
718
+ self.first_device = "cpu"
719
+ self.last_device = "cpu"
720
+ self.wte = self.wte.to("cpu")
721
+ self.wpe = self.wpe.to("cpu")
722
+ for index in range(len(self.h)):
723
+ self.h[index] = self.h[index].to("cpu")
724
+ if not _USE_GROVER:  # ln_f only exists when _USE_GROVER is False
+ self.ln_f = self.ln_f.to("cpu")
725
+ torch.cuda.empty_cache()
726
+
727
  def get_input_embeddings(self):
728
  return self.wte
729
 
 
760
  output_hidden_states=None,
761
  return_dict=None,
762
  ):
763
+ output_attentions = (
764
+ output_attentions
765
+ if output_attentions is not None
766
+ else self.config.output_attentions
767
+ )
768
  output_hidden_states = (
769
+ output_hidden_states
770
+ if output_hidden_states is not None
771
+ else self.config.output_hidden_states
772
  )
773
  use_cache = use_cache if use_cache is not None else self.config.use_cache
774
+ return_dict = (
775
+ return_dict if return_dict is not None else self.config.use_return_dict
776
+ )
777
 
778
  if input_ids is not None and inputs_embeds is not None:
779
+ raise ValueError(
780
+ "You cannot specify both input_ids and inputs_embeds at the same time"
781
+ )
782
  elif input_ids is not None:
783
  input_shape = input_ids.size()
784
  input_ids = input_ids.view(-1, input_shape[-1])
 
801
  past_length = past_key_values[0][0].size(-2)
802
  if position_ids is None:
803
  device = input_ids.device if input_ids is not None else inputs_embeds.device
804
+ position_ids = torch.arange(
805
+ past_length,
806
+ input_shape[-1] + past_length,
807
+ dtype=torch.long,
808
+ device=device,
809
+ )
810
  position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
811
 
812
  # Attention mask.
813
  if attention_mask is not None:
814
+ if batch_size <= 0:
815
+ raise ValueError("batch_size has to be defined and > 0")
816
  attention_mask = attention_mask.view(batch_size, -1)
817
  # We create a 3D attention mask from a 2D tensor mask.
818
  # Sizes are [batch_size, 1, 1, to_seq_length]
 
832
  # If a 2D or 3D attention mask is provided for the cross-attention
833
  # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
834
  if self.config.add_cross_attention and encoder_hidden_states is not None:
835
+ (
836
+ encoder_batch_size,
837
+ encoder_sequence_length,
838
+ _,
839
+ ) = encoder_hidden_states.size()
840
  encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
841
  if encoder_attention_mask is None:
842
  encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
 
866
 
867
  presents = () if use_cache else None
868
  all_self_attentions = () if output_attentions else None
869
+ all_cross_attentions = (
870
+ () if output_attentions and self.config.add_cross_attention else None
871
+ )
872
  all_hidden_states = () if output_hidden_states else None
873
  for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
874
+
875
+ # Model parallel
876
+ if self.model_parallel:
877
+ torch.cuda.set_device(hidden_states.device)
878
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
879
+ if layer_past is not None:
880
+ layer_past = tuple(
881
+ past_state.to(hidden_states.device) for past_state in layer_past
882
+ )
883
+ # Ensure that attention_mask is always on the same device as hidden_states
884
+ if attention_mask is not None:
885
+ attention_mask = attention_mask.to(hidden_states.device)
886
+ if isinstance(head_mask, torch.Tensor):
887
+ head_mask = head_mask.to(hidden_states.device)
888
+
889
  if output_hidden_states:
890
+ all_hidden_states = all_hidden_states + (
891
+ hidden_states.view(*output_shape),
892
+ )
893
 
894
  if getattr(self.config, "gradient_checkpointing", False):
895
 
896
  def create_custom_forward(module):
897
  def custom_forward(*inputs):
898
  # checkpointing only works with tuple returns, not with lists
899
+ return tuple(
900
+ output
901
+ for output in module(*inputs, use_cache, output_attentions)
902
+ )
903
 
904
  return custom_forward
905
 
 
929
  presents = presents + (present,)
930
 
931
  if output_attentions:
932
+ all_self_attentions = all_self_attentions + (
933
+ outputs[2 if use_cache else 1],
934
+ )
935
  if self.config.add_cross_attention:
936
+ all_cross_attentions = all_cross_attentions + (
937
+ outputs[3 if use_cache else 2],
938
+ )
939
+
940
+ # Model Parallel: If it's the last layer for that device, put things on the next device
941
+ if self.model_parallel:
942
+ for k, v in self.device_map.items():
943
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
944
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
945
 
946
  if not _USE_GROVER:
947
  hidden_states = self.ln_f(hidden_states)
 
952
  all_hidden_states = all_hidden_states + (hidden_states,)
953
 
954
  if not return_dict:
955
+ return tuple(
956
+ v
957
+ for v in [
958
+ hidden_states,
959
+ presents,
960
+ all_hidden_states,
961
+ all_self_attentions,
962
+ all_cross_attentions,
963
+ ]
964
+ if v is not None
965
+ )
966
 
967
  return BaseModelOutputWithPastAndCrossAttentions(
968
  last_hidden_state=hidden_states,
 
990
 
991
  self.init_weights()
992
 
993
+ # Model parallel
994
+ self.model_parallel = False
995
+ self.device_map = None
996
+
997
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
998
+ def parallelize(self, device_map=None):
999
+ self.device_map = (
1000
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
1001
+ if device_map is None
1002
+ else device_map
1003
+ )
1004
+ assert_device_map(self.device_map, len(self.transformer.h))
1005
+ self.transformer.parallelize(self.device_map)
1006
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
1007
+ self.model_parallel = True
1008
+
1009
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1010
+ def deparallelize(self):
1011
+ self.transformer.deparallelize()
1012
+ self.transformer = self.transformer.to("cpu")
1013
+ self.lm_head = self.lm_head.to("cpu")
1014
+ self.model_parallel = False
1015
+ torch.cuda.empty_cache()
1016
+
1017
  def get_output_embeddings(self):
1018
  return self.lm_head
1019
 
 
1049
  @add_code_sample_docstrings(
1050
  tokenizer_class=_TOKENIZER_FOR_DOC,
1051
  checkpoint="gpt2",
1052
+ output_type=CausalLMOutputWithCrossAttentions,
1053
  config_class=_CONFIG_FOR_DOC,
1054
  )
1055
  def forward(
 
1075
  ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
1076
  ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
1077
  """
1078
+ return_dict = (
1079
+ return_dict if return_dict is not None else self.config.use_return_dict
1080
+ )
1081
 
1082
  transformer_outputs = self.transformer(
1083
  input_ids,
 
1096
  )
1097
  hidden_states = transformer_outputs[0]
1098
 
1099
+ # Set device for model parallelism
1100
+ if self.model_parallel:
1101
+ torch.cuda.set_device(self.transformer.first_device)
1102
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
1103
+
1104
  lm_logits = self.lm_head(hidden_states)
1105
 
1106
  loss = None
 
1110
  shift_labels = labels[..., 1:].contiguous()
1111
  # Flatten the tokens
1112
  loss_fct = CrossEntropyLoss()
1113
+ loss = loss_fct(
1114
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
1115
+ )
1116
 
1117
  if not return_dict:
1118
  output = (lm_logits,) + transformer_outputs[1:]
1119
  return ((loss,) + output) if loss is not None else output
1120
 
1121
+ return CausalLMOutputWithCrossAttentions(
1122
  loss=loss,
1123
  logits=lm_logits,
1124
  past_key_values=transformer_outputs.past_key_values,
 
1127
  cross_attentions=transformer_outputs.cross_attentions,
1128
  )
1129
 
1130
+ @staticmethod
1131
+ def _reorder_cache(
1132
+ past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1133
+ ) -> Tuple[Tuple[torch.Tensor]]:
1134
+ """
1135
+ This function is used to re-order the :obj:`past_key_values` cache if
1136
+ :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
1137
+ called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
1138
+ """
1139
+ return tuple(
1140
+ tuple(
1141
+ past_state.index_select(0, beam_idx.to(past_state.device))
1142
+ for past_state in layer_past
1143
+ )
1144
+ for layer_past in past
1145
+ )
1146
+
1147
 
1148
  @add_start_docstrings(
1149
  """
 
1164
 
1165
  self.init_weights()
1166
 
1167
+ # Model parallel
1168
+ self.model_parallel = False
1169
+ self.device_map = None
1170
+
1171
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1172
+ def parallelize(self, device_map=None):
1173
+ self.device_map = (
1174
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
1175
+ if device_map is None
1176
+ else device_map
1177
+ )
1178
+ assert_device_map(self.device_map, len(self.transformer.h))
1179
+ self.transformer.parallelize(self.device_map)
1180
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
1181
+ self.multiple_choice_head = self.multiple_choice_head.to(
1182
+ self.transformer.first_device
1183
+ )
1184
+ self.model_parallel = True
1185
+
1186
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1187
+ def deparallelize(self):
1188
+ self.transformer.deparallelize()
1189
+ self.transformer = self.transformer.to("cpu")
1190
+ self.lm_head = self.lm_head.to("cpu")
1191
+ self.multiple_choice_head = self.multiple_choice_head.to("cpu")
1192
+ self.model_parallel = False
1193
+ torch.cuda.empty_cache()
1194
+
1195
  def get_output_embeddings(self):
1196
  return self.lm_head
1197
 
 
1225
  }
1226
 
1227
  @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1228
+ @replace_return_docstrings(
1229
+ output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC
1230
+ )
1231
  def forward(
1232
  self,
1233
  input_ids=None,
 
1286
  >>> mc_logits = outputs.mc_logits
1287
 
1288
  """
1289
+ return_dict = (
1290
+ return_dict if return_dict is not None else self.config.use_return_dict
1291
+ )
1292
 
1293
  transformer_outputs = self.transformer(
1294
  input_ids,
 
1306
 
1307
  hidden_states = transformer_outputs[0]
1308
 
1309
+ # Set device for model parallelism
1310
+ if self.model_parallel:
1311
+ torch.cuda.set_device(self.transformer.first_device)
1312
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
1313
+
1314
  lm_logits = self.lm_head(hidden_states)
1315
  mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
1316
 
1317
  mc_loss = None
1318
  if mc_labels is not None:
1319
  loss_fct = CrossEntropyLoss()
1320
+ mc_loss = loss_fct(
1321
+ mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)
1322
+ )
1323
  lm_loss = None
1324
  if labels is not None:
1325
  shift_logits = lm_logits[..., :-1, :].contiguous()
1326
  shift_labels = labels[..., 1:].contiguous()
1327
  loss_fct = CrossEntropyLoss()
1328
+ lm_loss = loss_fct(
1329
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
1330
+ )
1331
 
1332
  if not return_dict:
1333
  output = (lm_logits, mc_logits) + transformer_outputs[1:]
 
1345
  attentions=transformer_outputs.attentions,
1346
  )
1347
 
1348
+ @staticmethod
1349
+ def _reorder_cache(
1350
+ past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1351
+ ) -> Tuple[Tuple[torch.Tensor]]:
1352
+ """
1353
+ This function is used to re-order the :obj:`past_key_values` cache if
1354
+ :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
1355
+ called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
1356
+ """
1357
+ return tuple(
1358
+ tuple(
1359
+ past_state.index_select(0, beam_idx.to(past_state.device))
1360
+ for past_state in layer_past
1361
+ )
1362
+ for layer_past in past
1363
+ )
1364
+
1365
 
1366
  @add_start_docstrings(
1367
  """
 
1389
 
1390
  self.init_weights()
1391
 
1392
+ # Model parallel
1393
+ self.model_parallel = False
1394
+ self.device_map = None
1395
+
1396
  @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1397
  @add_code_sample_docstrings(
1398
  tokenizer_class=_TOKENIZER_FOR_DOC,
 
1421
  config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
1422
  If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1423
  """
1424
+ return_dict = (
1425
+ return_dict if return_dict is not None else self.config.use_return_dict
1426
+ )
1427
 
1428
  transformer_outputs = self.transformer(
1429
  input_ids,
 
1453
  sequence_lengths = -1
1454
  else:
1455
  if input_ids is not None:
1456
+ sequence_lengths = (
1457
+ torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
1458
+ )
1459
  else:
1460
  sequence_lengths = -1
1461
  logger.warning(
 
1473
  loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
1474
  else:
1475
  loss_fct = CrossEntropyLoss()
1476
+ loss = loss_fct(
1477
+ pooled_logits.view(-1, self.num_labels), labels.view(-1)
1478
+ )
1479
 
1480
  if not return_dict:
1481
  output = (pooled_logits,) + transformer_outputs[1:]
 
1489
  attentions=transformer_outputs.attentions,
1490
  )
1491
 
1492
+
1493
+ @add_start_docstrings(
1494
+ """
1495
+ GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1496
+ Named-Entity-Recognition (NER) tasks.
1497
+ """,
1498
+ GPT2_START_DOCSTRING,
1499
+ )
1500
+ class GPT2ForTokenClassification(GPT2PreTrainedModel):
1501
+ def __init__(self, config):
1502
+ super().__init__(config)
1503
+ self.num_labels = config.num_labels
1504
+
1505
+ self.transformer = GPT2Model(config)
1506
+ if (
1507
+ hasattr(config, "classifier_dropout")
1508
+ and config.classifier_dropout is not None
1509
+ ):
1510
+ classifier_dropout = config.classifier_dropout
1511
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1512
+ classifier_dropout = config.hidden_dropout
1513
+ else:
1514
+ classifier_dropout = 0.1
1515
+ self.dropout = nn.Dropout(classifier_dropout)
1516
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1517
+
1518
+ self.init_weights()
1519
+
1520
+ # Model parallel
1521
+ self.model_parallel = False
1522
+ self.device_map = None
1523
+
1524
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1525
+ @add_code_sample_docstrings(
1526
+ tokenizer_class=_TOKENIZER_FOR_DOC,
1527
+ checkpoint="microsoft/DialogRPT-updown",
1528
+ output_type=TokenClassifierOutput,
1529
+ config_class=_CONFIG_FOR_DOC,
1530
+ )
1531
+ def forward(
1532
+ self,
1533
+ input_ids=None,
1534
+ past_key_values=None,
1535
+ attention_mask=None,
1536
+ token_type_ids=None,
1537
+ position_ids=None,
1538
+ head_mask=None,
1539
+ inputs_embeds=None,
1540
+ labels=None,
1541
+ use_cache=None,
1542
+ output_attentions=None,
1543
+ output_hidden_states=None,
1544
+ return_dict=None,
1545
+ ):
1546
+ r"""
1547
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1548
+ Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
1549
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
1550
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1551
+ """
1552
+ return_dict = (
1553
+ return_dict if return_dict is not None else self.config.use_return_dict
1554
+ )
1555
+
1556
+ transformer_outputs = self.transformer(
1557
+ input_ids,
1558
+ past_key_values=past_key_values,
1559
+ attention_mask=attention_mask,
1560
+ token_type_ids=token_type_ids,
1561
+ position_ids=position_ids,
1562
+ head_mask=head_mask,
1563
+ inputs_embeds=inputs_embeds,
1564
+ use_cache=use_cache,
1565
+ output_attentions=output_attentions,
1566
+ output_hidden_states=output_hidden_states,
1567
+ return_dict=return_dict,
1568
+ )
1569
+
1570
+ hidden_states = transformer_outputs[0]
1571
+ hidden_states = self.dropout(hidden_states)
1572
+ logits = self.classifier(hidden_states)
1573
+
1574
+ loss = None
1575
+ if labels is not None:
1576
+ loss_fct = CrossEntropyLoss()
1577
+ # Only keep active parts of the loss
1578
+ if attention_mask is not None:
1579
+ active_loss = attention_mask.view(-1) == 1
1580
+ active_logits = logits.view(-1, self.num_labels)
1581
+ active_labels = torch.where(
1582
+ active_loss,
1583
+ labels.view(-1),
1584
+ torch.tensor(loss_fct.ignore_index).type_as(labels),
1585
+ )
1586
+ loss = loss_fct(active_logits, active_labels)
1587
+ else:
1588
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1589
+
1590
+ if not return_dict:
1591
+ output = (logits,) + transformer_outputs[2:]
1592
+ return ((loss,) + output) if loss is not None else output
1593
+
1594
+ return TokenClassifierOutput(
1595
+ loss=loss,
1596
+ logits=logits,
1597
+ hidden_states=transformer_outputs.hidden_states,
1598
+ attentions=transformer_outputs.attentions,
1599
+ )