Text Generation
Transformers
PyTorch
TensorBoard
Arabic
aragpt2
custom_code
wissamantoun committed
Commit ea715ce
1 Parent(s): b584e54

Update config.json with AraGPT2 model and auto mapping

Files changed (3)
  1. config.json +9 -4
  2. configuration_aragpt2.py +275 -0
  3. modeling_aragpt2.py +1989 -0
config.json CHANGED
@@ -1,8 +1,13 @@
 {
   "activation_function": "gelu_new",
   "architectures": [
-    "GPT2LMHeadModel"
+    "AraGPT2LMHeadModel"
   ],
+  "auto_map": {
+    "AutoConfig": "configuration_aragpt2.AraGPT2Config",
+    "AutoForCausalLM": "modeling_aragpt2.AraGPT2ForCausalLM",
+    "AutoModel": "modeling_aragpt2.AraGPT2Model"
+  },
   "attention_probs_dropout_prob": 0.1,
   "attn_pdrop": 0.1,
   "bos_token_id": 0,
@@ -13,7 +18,7 @@
   "initializer_range": 0.014142135623731,
   "intermediate_size": 6144,
   "layer_norm_epsilon": 1e-05,
-  "model_type": "gpt2",
+  "model_type": "aragpt2",
   "n_ctx": 1024,
   "n_embd": 1536,
   "n_head": 24,
@@ -32,9 +37,9 @@
       "max_length": 50,
       "num_beams": 5,
       "top_p": 0.95,
-      "repetition_penalty": 3.0,
+      "repetition_penalty": 3.0,
       "no_repeat_ngram_size": 3
     }
   },
   "vocab_size": 64000
-}
+}
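For reference, a minimal usage sketch (my addition, not part of the commit): with the auto_map entries above, the custom classes resolve straight from the Hub through the Auto API, provided trust_remote_code=True is passed so that configuration_aragpt2.py and modeling_aragpt2.py are allowed to execute. The value in the final comment reflects the mega checkpoint configured in this diff.

from transformers import AutoConfig, AutoModel, AutoTokenizer

model_id = "aubmindlab/aragpt2-mega"

config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)  # resolves to AraGPT2Config
model = AutoModel.from_pretrained(model_id, trust_remote_code=True)    # resolves to AraGPT2Model
tokenizer = AutoTokenizer.from_pretrained(model_id)                    # standard GPT-2 style tokenizer files

# attribute_map aliases from the custom config: hidden_size -> n_embd, num_hidden_layers -> n_layer
print(config.model_type, config.hidden_size)  # aragpt2 1536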
configuration_aragpt2.py ADDED
@@ -0,0 +1,275 @@
# coding=utf-8
"""AraGPT2 configuration"""
from collections import OrderedDict
from typing import Any, List, Mapping, Optional

from transformers import PreTrainedTokenizer, TensorType, is_torch_available
from transformers.configuration_utils import PretrainedConfig
from transformers.onnx import OnnxConfigWithPast, PatchingSpec
from transformers.utils import logging


logger = logging.get_logger(__name__)

AraGPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "aubmindlab/aragpt2-mega": "https://huggingface.co/aubmindlab/aragpt2-mega/resolve/main/config.json",
}


class AraGPT2Config(PretrainedConfig):
    """
    This is the configuration class to store the configuration of an [`AraGPT2Model`] or a [`TFAraGPT2Model`]. It is
    used to instantiate an AraGPT2 model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the AraGPT2
    [aubmindlab/aragpt2-mega](https://huggingface.co/aubmindlab/aragpt2-mega) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 64000):
            Vocabulary size of the AraGPT2 model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`AraGPT2Model`] or [`TFAraGPT2Model`].
        n_positions (`int`, *optional*, defaults to 1024):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        n_embd (`int`, *optional*, defaults to 768):
            Dimensionality of the embeddings and hidden states.
        n_layer (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        n_head (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        n_inner (`int`, *optional*):
            Dimensionality of the inner feed-forward layers. `None` will set it to 4 times `n_embd`.
        activation_function (`str`, *optional*, defaults to `"gelu_new"`):
            Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
        resid_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        embd_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the embeddings.
        attn_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
            The epsilon to use in the layer normalization layers.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        summary_type (`string`, *optional*, defaults to `"cls_index"`):
            Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
            [`TFGPT2DoubleHeadsModel`].

            Has to be one of the following options:

            - `"last"`: Take the last token hidden state (like XLNet).
            - `"first"`: Take the first token hidden state (like BERT).
            - `"mean"`: Take the mean of all tokens hidden states.
            - `"cls_index"`: Supply a Tensor of classification token position (like GPT/AraGPT2).
            - `"attn"`: Not implemented now, use multi-head attention.
        summary_use_proj (`bool`, *optional*, defaults to `True`):
            Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
            [`TFGPT2DoubleHeadsModel`].

            Whether or not to add a projection after the vector extraction.
        summary_activation (`str`, *optional*):
            Argument used when doing sequence summary. Used for the multiple choice head in
            [`GPT2DoubleHeadsModel`].

            Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation.
        summary_proj_to_labels (`bool`, *optional*, defaults to `True`):
            Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
            [`TFGPT2DoubleHeadsModel`].

            Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes.
        summary_first_dropout (`float`, *optional*, defaults to 0.1):
            Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
            [`TFGPT2DoubleHeadsModel`].

            The dropout ratio to be used after the projection and activation.
        scale_attn_weights (`bool`, *optional*, defaults to `True`):
            Scale attention weights by dividing by sqrt(hidden_size).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        bos_token_id (`int`, *optional*, defaults to 50256):
            Id of the beginning of sentence token in the vocabulary.
        eos_token_id (`int`, *optional*, defaults to 50256):
            Id of the end of sentence token in the vocabulary.
        scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
            Whether to additionally scale attention weights by `1 / (layer_idx + 1)`.
        reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
            Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
            dot-product/softmax to float() when training with mixed precision.

    Example:

    ```python
    >>> from transformers import AraGPT2Config, AraGPT2Model

    >>> # Initializing an AraGPT2 configuration
    >>> configuration = AraGPT2Config()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = AraGPT2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "aragpt2"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "hidden_size": "n_embd",
        "max_position_embeddings": "n_positions",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
        self,
        vocab_size=64000,
        n_positions=1024,
        n_embd=768,
        n_layer=12,
        n_head=12,
        n_inner=None,
        activation_function="gelu_new",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=0.1,
        scale_attn_weights=True,
        use_cache=True,
        bos_token_id=50256,
        eos_token_id=50256,
        scale_attn_by_inverse_layer_idx=False,
        reorder_and_upcast_attn=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.summary_type = summary_type
        self.summary_use_proj = summary_use_proj
        self.summary_activation = summary_activation
        self.summary_first_dropout = summary_first_dropout
        self.summary_proj_to_labels = summary_proj_to_labels
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.reorder_and_upcast_attn = reorder_and_upcast_attn

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)


class AraGPT2OnnxConfig(OnnxConfigWithPast):
    def __init__(
        self,
        config: PretrainedConfig,
        task: str = "default",
        patching_specs: List[PatchingSpec] = None,
        use_past: bool = False,
    ):
        super().__init__(
            config, task=task, patching_specs=patching_specs, use_past=use_past
        )
        if not getattr(self._config, "pad_token_id", None):
            # TODO: how to do that better?
            self._config.pad_token_id = 0

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
        if self.use_past:
            self.fill_with_past_key_values_(common_inputs, direction="inputs")
            common_inputs["attention_mask"] = {
                0: "batch",
                1: "past_sequence + sequence",
            }
        else:
            common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}

        return common_inputs

    @property
    def num_layers(self) -> int:
        return self._config.n_layer

    @property
    def num_attention_heads(self) -> int:
        return self._config.n_head

    def generate_dummy_inputs(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
            tokenizer,
            batch_size=batch_size,
            seq_length=seq_length,
            is_pair=is_pair,
            framework=framework,
        )

        # We need to order the inputs in the way they appear in forward()
        ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})

        # Need to add the past_key_values
        if self.use_past:
            if not is_torch_available():
                raise ValueError(
                    "Cannot generate dummy past_keys inputs without PyTorch installed."
                )
            else:
                import torch

                batch, seqlen = common_inputs["input_ids"].shape
                # Not using the same length for past_key_values
                past_key_values_length = seqlen + 2
                past_shape = (
                    batch,
                    self.num_attention_heads,
                    past_key_values_length,
                    self._config.hidden_size // self.num_attention_heads,
                )
                ordered_inputs["past_key_values"] = [
                    (torch.zeros(past_shape), torch.zeros(past_shape))
                    for _ in range(self.num_layers)
                ]

        ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
        if self.use_past:
            mask_dtype = ordered_inputs["attention_mask"].dtype
            ordered_inputs["attention_mask"] = torch.cat(
                [
                    ordered_inputs["attention_mask"],
                    torch.ones(batch, past_key_values_length, dtype=mask_dtype),
                ],
                dim=1,
            )

        return ordered_inputs

    @property
    def default_onnx_opset(self) -> int:
        return 13
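A quick illustration (my own sketch, not part of the diff): the AraGPT2OnnxConfig above can be exercised on its own to see which inputs an ONNX export of this model would be traced with. It assumes PyTorch is installed, that configuration_aragpt2.py has been downloaded next to the script (the class is not reachable through auto_map), and that the Hub repo resolves as in the config.json change above.

from transformers import AutoConfig, AutoTokenizer, TensorType
from configuration_aragpt2 import AraGPT2OnnxConfig  # local copy of the file added in this commit

model_id = "aubmindlab/aragpt2-mega"
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

onnx_config = AraGPT2OnnxConfig(config, task="default", use_past=True)
print(onnx_config.inputs)              # input_ids, past_key_values.*.key/value, attention_mask with dynamic axes
print(onnx_config.default_onnx_opset)  # 13

# Dummy inputs shaped the way the exporter would trace them (use_past=True needs torch tensors):
dummy = onnx_config.generate_dummy_inputs(
    tokenizer, batch_size=2, seq_length=8, framework=TensorType.PYTORCH
)
print({name: tuple(t.shape) for name, t in dummy.items() if hasattr(t, "shape")})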
modeling_aragpt2.py ADDED
@@ -0,0 +1,1989 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch OpenAI GPT-2 model."""
17
+
18
+ import math
19
+ import os
20
+ import warnings
21
+ from dataclasses import dataclass
22
+ from typing import Optional, Tuple, Union
23
+
24
+ import torch
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.cuda.amp import autocast
28
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
+
30
+ from transformers.activations import ACT2FN
31
+ from transformers.modeling_outputs import (
32
+ BaseModelOutputWithPastAndCrossAttentions,
33
+ CausalLMOutputWithCrossAttentions,
34
+ QuestionAnsweringModelOutput,
35
+ SequenceClassifierOutputWithPast,
36
+ TokenClassifierOutput,
37
+ )
38
+ from transformers.modeling_utils import PreTrainedModel, SequenceSummary
39
+ from transformers.pytorch_utils import (
40
+ Conv1D,
41
+ find_pruneable_heads_and_indices,
42
+ prune_conv1d_layer,
43
+ )
44
+ from transformers.utils import (
45
+ ModelOutput,
46
+ add_code_sample_docstrings,
47
+ add_start_docstrings,
48
+ add_start_docstrings_to_model_forward,
49
+ logging,
50
+ replace_return_docstrings,
51
+ )
52
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
53
+ from .configuration_aragpt2 import AraGPT2Config
54
+
55
+
56
+ logger = logging.get_logger(__name__)
57
+
58
+ _CHECKPOINT_FOR_DOC = "aubmindlab/aragpt2-mega"
59
+ _CONFIG_FOR_DOC = "AraGPT2Config"
60
+ _TOKENIZER_FOR_DOC = "GPT2Tokenizer"
61
+
62
+ GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
63
+ "aubmindlab/aragpt2-mega",
64
+ "gpt2-medium",
65
+ "aubmindlab/aragpt2-mega",
66
+ "aubmindlab/aragpt2-mega",
67
+ "distilgpt2",
68
+ # See all GPT-2 models at https://huggingface.co/models?filter=gpt2
69
+ ]
70
+
71
+ _GPT2_ML_TF_TO_TORCH = {
72
+ "LayerNorm_embed_norm": "emb_norm",
73
+ "pos_embed": "wpe.weight",
74
+ "word_embed": "wte.weight",
75
+ "layer": "h",
76
+ # Most importently This two layer norm must be put on the same position as gpt2-ml
77
+ # or generated data is bad, just repeat the last token
78
+ "LayerNorm_mlp_ln0": "ln_1",
79
+ "LayerNorm_mlp_ln1": "ln_2",
80
+ "intermediate": "mlp.c_fc",
81
+ "output": "mlp.c_proj",
82
+ "query_layer": "attn.c_attn",
83
+ "key_layer": "attn.c_attn",
84
+ "value_layer": "attn.c_attn",
85
+ "context_projection_layer": "attn.c_proj",
86
+ "gamma": "weight",
87
+ "kernel": "weight",
88
+ "beta": "bias",
89
+ "bias": "bias",
90
+ }
91
+
92
+
93
+ def convert_gpt2_checkpoint_to_pytorch(
94
+ gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path
95
+ ):
96
+ # Construct model
97
+ if gpt2_config_file == "":
98
+ config = AraGPT2Config()
99
+ else:
100
+ config = AraGPT2Config.from_json_file(gpt2_config_file)
101
+ model = AraGPT2Model(config)
102
+
103
+ # Load weights from numpy
104
+ load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path)
105
+
106
+ # Save pytorch-model
107
+ pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
108
+ pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
109
+ print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
110
+ torch.save(model.state_dict(), pytorch_weights_dump_path)
111
+ print("Save configuration file to {}".format(pytorch_config_dump_path))
112
+ with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
113
+ f.write(config.to_json_string())
114
+
115
+
116
+ # XXX: MUST do like: convert_gpt2_checkpoint_to_pytorch('./model.ckpt-100000', './mega.json', './')
117
+ # https://github.com/tensorflow/models/issues/2675#issuecomment-516595597
118
+ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
119
+ """Load tf checkpoints in a pytorch model"""
120
+ try:
121
+ import re
122
+ import tensorflow as tf
123
+ except ImportError:
124
+ logger.error(
125
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
126
+ "https://www.tensorflow.org/install/ for installation instructions."
127
+ )
128
+ raise
129
+ tf_path = os.path.abspath(gpt2_checkpoint_path)
130
+ logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
131
+ # Load weights from TF model
132
+ init_vars = tf.train.list_variables(tf_path)
133
+ names = []
134
+ arrays = []
135
+ for name, shape in init_vars:
136
+ logger.info("Loading TF weight {} with shape {}".format(name, shape))
137
+ array = tf.train.load_variable(tf_path, name)
138
+ names.append(name)
139
+ arrays.append(array.squeeze())
140
+
141
+ import copy
142
+
143
+ orig_model = copy.deepcopy(model)
144
+
145
+ for name, array in zip(names, arrays):
146
+ name = name[6:] # skip "model/"
147
+ name = name.split("/")
148
+ pointer = model
149
+
150
+ attn_layer = ""
151
+ for m_name in name:
152
+ if re.fullmatch(r"[A-Za-z]+\d+", m_name):
153
+ scope_names = re.split(r"(\d+)", m_name)
154
+ else:
155
+ scope_names = [m_name]
156
+ sname = scope_names[0]
157
+
158
+ if sname == "" or sname == "embeddings":
159
+ continue
160
+ elif sname not in _GPT2_ML_TF_TO_TORCH:
161
+ print("=========================================================")
162
+ logger.info("Skip var name {}".format(scope_names))
163
+ pointer = None
164
+ break
165
+ else:
166
+ tname = _GPT2_ML_TF_TO_TORCH[sname]
167
+ if "." in tname:
168
+ parent, child = tname.split(".")
169
+ pointer = getattr(pointer, parent)
170
+ pointer = getattr(pointer, child)
171
+ else:
172
+ pointer = getattr(pointer, tname)
173
+
174
+ if tname == "attn.c_attn":
175
+ attn_layer = sname
176
+
177
+ if len(scope_names) >= 2:
178
+ num = int(scope_names[1])
179
+ pointer = pointer[num]
180
+
181
+ if pointer is None:
182
+ continue
183
+ if attn_layer == "":
184
+ try:
185
+ assert pointer.shape == array.shape
186
+ except AssertionError as e:
187
+ e.args += (pointer.shape, array.shape)
188
+ raise
189
+ logger.info(
190
+ "Initialize PyTorch weight {}, {}, {}".format(
191
+ name, array.mean(), pointer.mean()
192
+ )
193
+ )
194
+ if attn_layer == "":
195
+ pointer.data = torch.from_numpy(array)
196
+ else:
197
+ shape = pointer.shape
198
+ d = torch.from_numpy(array)
199
+ is_bias = len(shape) == 1
200
+ end = int(shape[0 if is_bias else 1] / 3)
201
+ m = dict(
202
+ query_layer=0,
203
+ key_layer=end,
204
+ value_layer=end * 2,
205
+ )
206
+ start = m[attn_layer]
207
+ end = start + end
208
+ if is_bias:
209
+ pointer.data[start:end] = d
210
+ else:
211
+ pointer.data[:, start:end] = d
212
+ logger.info(
213
+ "Initialize PyTorch weight {}, {}, {}".format(
214
+ name, array.mean(), pointer.mean()
215
+ )
216
+ )
217
+
218
+ for name, params in orig_model.named_parameters():
219
+ for n, p in model.named_parameters():
220
+ if name == n:
221
+ if params.equal(p):
222
+ print("--------------------------")
223
+ print(" %s not changed!" % n)
224
+ return model
225
+
226
+
227
+ class AraGPT2Attention(nn.Module):
228
+ def __init__(self, config, is_cross_attention=False, layer_idx=None):
229
+ super().__init__()
230
+
231
+ max_positions = config.max_position_embeddings
232
+ self.register_buffer(
233
+ "bias",
234
+ torch.tril(
235
+ torch.ones((max_positions, max_positions), dtype=torch.bool)
236
+ ).view(1, 1, max_positions, max_positions),
237
+ persistent=False,
238
+ )
239
+ self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
240
+
241
+ self.embed_dim = config.hidden_size
242
+ self.num_heads = config.num_attention_heads
243
+ self.head_dim = self.embed_dim // self.num_heads
244
+ self.split_size = self.embed_dim
245
+ if self.head_dim * self.num_heads != self.embed_dim:
246
+ raise ValueError(
247
+ f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
248
+ f" {self.num_heads})."
249
+ )
250
+
251
+ self.scale_attn_weights = config.scale_attn_weights
252
+ self.is_cross_attention = is_cross_attention
253
+
254
+ # Layer-wise attention scaling, reordering, and upcasting
255
+ self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
256
+ self.layer_idx = layer_idx
257
+ self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
258
+
259
+ if self.is_cross_attention:
260
+ self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
261
+ self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
262
+ else:
263
+ self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
264
+ self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
265
+
266
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
267
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
268
+
269
+ self.pruned_heads = set()
270
+
271
+ def prune_heads(self, heads):
272
+ if len(heads) == 0:
273
+ return
274
+ heads, index = find_pruneable_heads_and_indices(
275
+ heads, self.num_heads, self.head_dim, self.pruned_heads
276
+ )
277
+ index_attn = torch.cat(
278
+ [index, index + self.split_size, index + (2 * self.split_size)]
279
+ )
280
+
281
+ # Prune conv1d layers
282
+ self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
283
+ self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
284
+
285
+ # Update hyper params
286
+ self.split_size = (self.split_size // self.num_heads) * (
287
+ self.num_heads - len(heads)
288
+ )
289
+ self.num_heads = self.num_heads - len(heads)
290
+ self.pruned_heads = self.pruned_heads.union(heads)
291
+
292
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
293
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
294
+
295
+ if self.scale_attn_weights:
296
+ attn_weights = attn_weights / torch.full(
297
+ [],
298
+ value.size(-1) ** 0.5,
299
+ dtype=attn_weights.dtype,
300
+ device=attn_weights.device,
301
+ )
302
+
303
+ # Layer-wise attention scaling
304
+ if self.scale_attn_by_inverse_layer_idx:
305
+ attn_weights = attn_weights / float(self.layer_idx + 1)
306
+
307
+ if not self.is_cross_attention:
308
+ # if only "normal" attention layer implements causal mask
309
+ query_length, key_length = query.size(-2), key.size(-2)
310
+ causal_mask = self.bias[
311
+ :, :, key_length - query_length : key_length, :key_length
312
+ ]
313
+ mask_value = torch.finfo(attn_weights.dtype).min
314
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
315
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
316
+ mask_value = torch.full(
317
+ [], mask_value, dtype=attn_weights.dtype, device=attn_weights.device
318
+ )
319
+ attn_weights = torch.where(
320
+ causal_mask, attn_weights.to(attn_weights.dtype), mask_value
321
+ )
322
+
323
+ if attention_mask is not None:
324
+ # Apply the attention mask
325
+ attn_weights = attn_weights + attention_mask
326
+
327
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
328
+
329
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
330
+ attn_weights = attn_weights.type(value.dtype)
331
+ attn_weights = self.attn_dropout(attn_weights)
332
+
333
+ # Mask heads if we want to
334
+ if head_mask is not None:
335
+ attn_weights = attn_weights * head_mask
336
+
337
+ attn_output = torch.matmul(attn_weights, value)
338
+
339
+ return attn_output, attn_weights
340
+
341
+ def _upcast_and_reordered_attn(
342
+ self, query, key, value, attention_mask=None, head_mask=None
343
+ ):
344
+ # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
345
+ bsz, num_heads, q_seq_len, dk = query.size()
346
+ _, _, k_seq_len, _ = key.size()
347
+
348
+ # Preallocate attn_weights for `baddbmm`
349
+ attn_weights = torch.empty(
350
+ bsz * num_heads,
351
+ q_seq_len,
352
+ k_seq_len,
353
+ dtype=torch.float32,
354
+ device=query.device,
355
+ )
356
+
357
+ # Compute Scale Factor
358
+ scale_factor = 1.0
359
+ if self.scale_attn_weights:
360
+ scale_factor /= float(value.size(-1)) ** 0.5
361
+
362
+ if self.scale_attn_by_inverse_layer_idx:
363
+ scale_factor /= float(self.layer_idx + 1)
364
+
365
+ # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
366
+ with autocast(enabled=False):
367
+ q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(
368
+ -1, dk, k_seq_len
369
+ )
370
+ attn_weights = torch.baddbmm(
371
+ attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor
372
+ )
373
+ attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
374
+
375
+ if not self.is_cross_attention:
376
+ # if only "normal" attention layer implements causal mask
377
+ query_length, key_length = query.size(-2), key.size(-2)
378
+ causal_mask = self.bias[
379
+ :, :, key_length - query_length : key_length, :key_length
380
+ ]
381
+ mask_value = torch.finfo(attn_weights.dtype).min
382
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
383
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
384
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
385
+ attn_weights.device
386
+ )
387
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
388
+
389
+ if attention_mask is not None:
390
+ # Apply the attention mask
391
+ attn_weights = attn_weights + attention_mask
392
+
393
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
394
+
395
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
396
+ if attn_weights.dtype != torch.float32:
397
+ raise RuntimeError(
398
+ "Error with upcasting, attn_weights does not have dtype torch.float32"
399
+ )
400
+ attn_weights = attn_weights.type(value.dtype)
401
+ attn_weights = self.attn_dropout(attn_weights)
402
+
403
+ # Mask heads if we want to
404
+ if head_mask is not None:
405
+ attn_weights = attn_weights * head_mask
406
+
407
+ attn_output = torch.matmul(attn_weights, value)
408
+
409
+ return attn_output, attn_weights
410
+
411
+ def _split_heads(self, tensor, num_heads, attn_head_size):
412
+ """
413
+ Splits hidden_size dim into attn_head_size and num_heads
414
+ """
415
+ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
416
+ tensor = tensor.view(new_shape)
417
+ return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
418
+
419
+ def _merge_heads(self, tensor, num_heads, attn_head_size):
420
+ """
421
+ Merges attn_head_size dim and num_attn_heads dim into hidden_size
422
+ """
423
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
424
+ new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
425
+ return tensor.view(new_shape)
426
+
427
+ def forward(
428
+ self,
429
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
430
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
431
+ attention_mask: Optional[torch.FloatTensor] = None,
432
+ head_mask: Optional[torch.FloatTensor] = None,
433
+ encoder_hidden_states: Optional[torch.Tensor] = None,
434
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
435
+ use_cache: Optional[bool] = False,
436
+ output_attentions: Optional[bool] = False,
437
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
438
+ if encoder_hidden_states is not None:
439
+ if not hasattr(self, "q_attn"):
440
+ raise ValueError(
441
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
442
+ "Please make sure to instantiate class with `AraGPT2Attention(..., is_cross_attention=True)`."
443
+ )
444
+
445
+ query = self.q_attn(hidden_states)
446
+ key, value = self.c_attn(encoder_hidden_states).split(
447
+ self.split_size, dim=2
448
+ )
449
+ attention_mask = encoder_attention_mask
450
+ else:
451
+ query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
452
+
453
+ query = self._split_heads(query, self.num_heads, self.head_dim)
454
+ key = self._split_heads(key, self.num_heads, self.head_dim)
455
+ value = self._split_heads(value, self.num_heads, self.head_dim)
456
+
457
+ if layer_past is not None:
458
+ past_key, past_value = layer_past
459
+ key = torch.cat((past_key, key), dim=-2)
460
+ value = torch.cat((past_value, value), dim=-2)
461
+
462
+ if use_cache is True:
463
+ present = (key, value)
464
+ else:
465
+ present = None
466
+
467
+ if self.reorder_and_upcast_attn:
468
+ attn_output, attn_weights = self._upcast_and_reordered_attn(
469
+ query, key, value, attention_mask, head_mask
470
+ )
471
+ else:
472
+ attn_output, attn_weights = self._attn(
473
+ query, key, value, attention_mask, head_mask
474
+ )
475
+
476
+ attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
477
+ attn_output = self.c_proj(attn_output)
478
+ attn_output = self.resid_dropout(attn_output)
479
+
480
+ outputs = (attn_output, present)
481
+ if output_attentions:
482
+ outputs += (attn_weights,)
483
+
484
+ return outputs # a, present, (attentions)
485
+
486
+
487
+ class AraGPT2MLP(nn.Module):
488
+ def __init__(self, intermediate_size, config):
489
+ super().__init__()
490
+ embed_dim = config.hidden_size
491
+ self.c_fc = Conv1D(intermediate_size, embed_dim)
492
+ self.c_proj = Conv1D(embed_dim, intermediate_size)
493
+ self.act = ACT2FN[config.activation_function]
494
+ self.dropout = nn.Dropout(config.resid_pdrop)
495
+
496
+ def forward(
497
+ self, hidden_states: Optional[Tuple[torch.FloatTensor]]
498
+ ) -> torch.FloatTensor:
499
+ hidden_states = self.c_fc(hidden_states)
500
+ hidden_states = self.act(hidden_states)
501
+ hidden_states = self.c_proj(hidden_states)
502
+ hidden_states = self.dropout(hidden_states)
503
+ return hidden_states
504
+
505
+
506
+ class AraGPT2Block(nn.Module):
507
+ def __init__(self, config, layer_idx=None):
508
+ super().__init__()
509
+ hidden_size = config.hidden_size
510
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
511
+
512
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
513
+ self.attn = AraGPT2Attention(config, layer_idx=layer_idx)
514
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
515
+
516
+ if config.add_cross_attention:
517
+ self.crossattention = AraGPT2Attention(
518
+ config, is_cross_attention=True, layer_idx=layer_idx
519
+ )
520
+ self.ln_cross_attn = nn.LayerNorm(
521
+ hidden_size, eps=config.layer_norm_epsilon
522
+ )
523
+
524
+ self.mlp = AraGPT2MLP(inner_dim, config)
525
+
526
+ def forward(
527
+ self,
528
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
529
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
530
+ attention_mask: Optional[torch.FloatTensor] = None,
531
+ head_mask: Optional[torch.FloatTensor] = None,
532
+ encoder_hidden_states: Optional[torch.Tensor] = None,
533
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
534
+ use_cache: Optional[bool] = False,
535
+ output_attentions: Optional[bool] = False,
536
+ ) -> Union[
537
+ Tuple[torch.Tensor],
538
+ Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]],
539
+ ]:
540
+
541
+ # removed in GROVER
542
+ # residual = hidden_states
543
+ # hidden_states = self.ln_1(hidden_states)
544
+ attn_outputs = self.attn(
545
+ hidden_states,
546
+ layer_past=layer_past,
547
+ attention_mask=attention_mask,
548
+ head_mask=head_mask,
549
+ use_cache=use_cache,
550
+ output_attentions=output_attentions,
551
+ )
552
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
553
+ outputs = attn_outputs[1:]
554
+ # residual connection
555
+ hidden_states = attn_output + hidden_states
556
+
557
+ if encoder_hidden_states is not None:
558
+ # add one self-attention block for cross-attention
559
+ if not hasattr(self, "crossattention"):
560
+ raise ValueError(
561
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
562
+ "cross-attention layers by setting `config.add_cross_attention=True`"
563
+ )
564
+ # removed in GROVER
565
+ # residual = hidden_states
566
+ # hidden_states = self.ln_cross_attn(hidden_states)
567
+ cross_attn_outputs = self.crossattention(
568
+ hidden_states,
569
+ attention_mask=attention_mask,
570
+ head_mask=head_mask,
571
+ encoder_hidden_states=encoder_hidden_states,
572
+ encoder_attention_mask=encoder_attention_mask,
573
+ output_attentions=output_attentions,
574
+ )
575
+ attn_output = cross_attn_outputs[0]
576
+ # residual connection
577
+ hidden_states = attn_output + hidden_states
578
+ outputs = (
579
+ outputs + cross_attn_outputs[2:]
580
+ ) # add cross attentions if we output attention weights
581
+
582
+ residual = hidden_states
583
+ hidden_states = self.ln_1(hidden_states)
584
+ feed_forward_hidden_states = self.mlp(hidden_states)
585
+ # residual connection
586
+ hidden_states = residual + feed_forward_hidden_states
587
+
588
+ hidden_states = self.ln_2(hidden_states) # Added in GROVER
589
+
590
+ if use_cache:
591
+ outputs = (hidden_states,) + outputs
592
+ else:
593
+ outputs = (hidden_states,) + outputs[1:]
594
+
595
+ return outputs # hidden_states, present, (attentions, cross_attentions)
596
+
597
+
598
+ class AraGPT2PreTrainedModel(PreTrainedModel):
599
+ """
600
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
601
+ models.
602
+ """
603
+
604
+ config_class = AraGPT2Config
605
+ load_tf_weights = load_tf_weights_in_gpt2
606
+ base_model_prefix = "transformer"
607
+ is_parallelizable = True
608
+ supports_gradient_checkpointing = True
609
+ _no_split_modules = ["AraGPT2Block"]
610
+ _skip_keys_device_placement = "past_key_values"
611
+
612
+ def __init__(self, *inputs, **kwargs):
613
+ super().__init__(*inputs, **kwargs)
614
+
615
+ def _init_weights(self, module):
616
+ """Initialize the weights."""
617
+ if isinstance(module, (nn.Linear, Conv1D)):
618
+ # Slightly different from the TF version which uses truncated_normal for initialization
619
+ # cf https://github.com/pytorch/pytorch/pull/5617
620
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
621
+ if module.bias is not None:
622
+ module.bias.data.zero_()
623
+ elif isinstance(module, nn.Embedding):
624
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
625
+ if module.padding_idx is not None:
626
+ module.weight.data[module.padding_idx].zero_()
627
+ elif isinstance(module, nn.LayerNorm):
628
+ module.bias.data.zero_()
629
+ module.weight.data.fill_(1.0)
630
+
631
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
632
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
633
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
634
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
635
+ #
636
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
637
+ for name, p in module.named_parameters():
638
+ if "c_proj" in name and "weight" in name:
639
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
640
+ p.data.normal_(
641
+ mean=0.0,
642
+ std=(
643
+ self.config.initializer_range
644
+ / math.sqrt(2 * self.config.n_layer)
645
+ ),
646
+ )
647
+
648
+
649
+ @dataclass
650
+ class AraGPT2DoubleHeadsModelOutput(ModelOutput):
651
+ """
652
+ Base class for outputs of models predicting if two sentences are consecutive or not.
653
+
654
+ Args:
655
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
656
+ Language modeling loss.
657
+ mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
658
+ Multiple choice classification loss.
659
+ logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
660
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
661
+ mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
662
+ Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
663
+ past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
664
+ Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads,
665
+ sequence_length, embed_size_per_head)`).
666
+
667
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
668
+ `past_key_values` input) to speed up sequential decoding.
669
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
670
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
671
+ shape `(batch_size, sequence_length, hidden_size)`.
672
+
673
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
674
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
675
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
676
+ sequence_length)`.
677
+
678
+ GPT2Attentions weights after the attention softmax, used to compute the weighted average in the
679
+ self-attention heads.
680
+ """
681
+
682
+ loss: Optional[torch.FloatTensor] = None
683
+ mc_loss: Optional[torch.FloatTensor] = None
684
+ logits: torch.FloatTensor = None
685
+ mc_logits: torch.FloatTensor = None
686
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
687
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
688
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
689
+
690
+
691
+ AraGPT2_START_DOCSTRING = r"""
692
+
693
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
694
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
695
+ etc.)
696
+
697
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
698
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
699
+ and behavior.
700
+
701
+ Parameters:
702
+ config ([`AraGPT2Config`]): Model configuration class with all the parameters of the model.
703
+ Initializing with a config file does not load the weights associated with the model, only the
704
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
705
+ """
706
+
707
+ GPT2_INPUTS_DOCSTRING = r"""
708
+ Args:
709
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
710
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
711
+ `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
712
+ sequence tokens in the vocabulary.
713
+
714
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
715
+ `input_ids`.
716
+
717
+ Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
718
+ [`PreTrainedTokenizer.__call__`] for details.
719
+
720
+ [What are input IDs?](../glossary#input-ids)
721
+ past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
722
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
723
+ `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
724
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
725
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
726
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
727
+
728
+ - 1 for tokens that are **not masked**,
729
+ - 0 for tokens that are **masked**.
730
+
731
+ If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
732
+ `past_key_values`. In other words, the `attention_mask` always has to have the length:
733
+ `len(past_key_values) + len(input_ids)`
734
+
735
+ [What are attention masks?](../glossary#attention-mask)
736
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
737
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
738
+ 1]`:
739
+
740
+ - 0 corresponds to a *sentence A* token,
741
+ - 1 corresponds to a *sentence B* token.
742
+
743
+ [What are token type IDs?](../glossary#token-type-ids)
744
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
745
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
746
+ config.max_position_embeddings - 1]`.
747
+
748
+ [What are position IDs?](../glossary#position-ids)
749
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
750
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
751
+
752
+ - 1 indicates the head is **not masked**,
753
+ - 0 indicates the head is **masked**.
754
+
755
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
756
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
757
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
758
+ model's internal embedding lookup matrix.
759
+
760
+ If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
761
+ `past_key_values`).
762
+ use_cache (`bool`, *optional*):
763
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
764
+ `past_key_values`).
765
+ output_attentions (`bool`, *optional*):
766
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
767
+ tensors for more detail.
768
+ output_hidden_states (`bool`, *optional*):
769
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
770
+ more detail.
771
+ return_dict (`bool`, *optional*):
772
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
773
+ """
774
+ PARALLELIZE_DOCSTRING = r"""
775
+ This is an experimental feature and is a subject to change at a moment's notice.
776
+
777
+ Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
778
+ it will evenly distribute blocks across all devices.
779
+
780
+ Args:
781
+ device_map (`Dict[int, list]`, optional, defaults to None):
782
+ A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
783
+ automatically mapped to the first device (for esoteric reasons). That means that the first device should
784
+ have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
785
+ following number of attention modules:
786
+
787
+ - aubmindlab/aragpt2-mega: 48
788
+
789
+ Example:
790
+
791
+ ```python
792
+ # Here is an example of a device map on a machine with 4 GPUs using aubmindlab/aragpt2-mega, which has a total of 48 attention modules:
793
+ model = AraGPT2LMHeadModel.from_pretrained("aubmindlab/aragpt2-mega")
794
+ device_map = {
795
+ 0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
796
+ 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
797
+ 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
798
+ 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
799
+ }
800
+ model.parallelize(device_map)
801
+ ```
802
+ """
803
+ DEPARALLELIZE_DOCSTRING = r"""
804
+ Moves the model to cpu from a model parallel state.
805
+
806
+ Example:
807
+
808
+ ```python
809
+ # On a 4 GPU machine with aubmindlab/aragpt2-mega:
810
+ model = AraGPT2LMHeadModel.from_pretrained("aubmindlab/aragpt2-mega")
811
+ device_map = {
812
+ 0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
813
+ 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
814
+ 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
815
+ 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
816
+ }
817
+ model.parallelize(device_map) # Splits the model across several devices
818
+ model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
819
+ ```
820
+ """
821
+
822
+
823
+ @add_start_docstrings(
824
+ "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
825
+ AraGPT2_START_DOCSTRING,
826
+ )
827
+ class AraGPT2Model(AraGPT2PreTrainedModel):
828
+ _keys_to_ignore_on_load_unexpected = ["attn.masked_bias"]
829
+ _keys_to_ignore_on_load_missing = ["attn.masked_bias"]
830
+
831
+ def __init__(self, config):
832
+ super().__init__(config)
833
+
834
+ self.embed_dim = config.hidden_size
835
+
836
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
837
+ self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
838
+ self.emb_norm = nn.LayerNorm(
839
+ config.n_embd, eps=config.layer_norm_epsilon
840
+ ) # Added in GROVER
841
+ self.drop = nn.Dropout(config.embd_pdrop)
842
+ self.h = nn.ModuleList(
843
+ [AraGPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
844
+ )
845
+ # Removed in GROVER
846
+ # self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
847
+
848
+ # Model parallel
849
+ self.model_parallel = False
850
+ self.device_map = None
851
+ self.gradient_checkpointing = False
852
+
853
+ # Initialize weights and apply final processing
854
+ self.post_init()
855
+
856
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
857
+ def parallelize(self, device_map=None):
858
+ # Check validity of device_map
859
+ warnings.warn(
860
+ "`AraGPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
861
+ " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
862
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
863
+ " ...}",
864
+ FutureWarning,
865
+ )
866
+ self.device_map = (
867
+ get_device_map(len(self.h), range(torch.cuda.device_count()))
868
+ if device_map is None
869
+ else device_map
870
+ )
871
+ assert_device_map(self.device_map, len(self.h))
872
+ self.model_parallel = True
873
+ self.first_device = (
874
+ "cpu"
875
+ if "cpu" in self.device_map.keys()
876
+ else "cuda:" + str(min(self.device_map.keys()))
877
+ )
878
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
879
+ self.wte = self.wte.to(self.first_device)
880
+ self.wpe = self.wpe.to(self.first_device)
881
+
882
+ # Added in GROVER
883
+ # Wissam: not sure if it is fine being on cpu or Better on GPU
884
+ self.emb_norm = self.emb_norm.to(
885
+ "cuda:" + str(min(self.device_map.keys()))
886
+ ) # GPU
887
+ # self.emb_norm = self.emb_norm.to(self.first_device) # CPU
888
+
889
+ # Load onto devices
890
+ for k, v in self.device_map.items():
891
+ for block in v:
892
+ cuda_device = "cuda:" + str(k)
893
+ self.h[block] = self.h[block].to(cuda_device)
894
+ # ln_f to last
895
+ # Removed in GROVER
896
+ # self.ln_f = self.ln_f.to(self.last_device)
897
+
898
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
899
+ def deparallelize(self):
900
+ warnings.warn(
901
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
902
+ FutureWarning,
903
+ )
904
+ self.model_parallel = False
905
+ self.device_map = None
906
+ self.first_device = "cpu"
907
+ self.last_device = "cpu"
908
+ self.wte = self.wte.to("cpu")
909
+ self.wpe = self.wpe.to("cpu")
910
+ # Added in GROVER
911
+ self.emb_norm = self.emb_norm.to("cpu")
912
+ for index in range(len(self.h)):
913
+ self.h[index] = self.h[index].to("cpu")
914
+ # Removed in GROVER
915
+ # self.ln_f = self.ln_f.to("cpu")
916
+ torch.cuda.empty_cache()
917
+
918
+ def get_input_embeddings(self):
919
+ return self.wte
920
+
921
+ def set_input_embeddings(self, new_embeddings):
922
+ self.wte = new_embeddings
923
+
924
+ def _prune_heads(self, heads_to_prune):
925
+ """
926
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
927
+ """
928
+ for layer, heads in heads_to_prune.items():
929
+ self.h[layer].attn.prune_heads(heads)
930
+
931
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
932
+ @add_code_sample_docstrings(
933
+ processor_class=_TOKENIZER_FOR_DOC,
934
+ checkpoint=_CHECKPOINT_FOR_DOC,
935
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
936
+ config_class=_CONFIG_FOR_DOC,
937
+ )
938
+ def forward(
939
+ self,
940
+ input_ids: Optional[torch.LongTensor] = None,
941
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
942
+ attention_mask: Optional[torch.FloatTensor] = None,
943
+ token_type_ids: Optional[torch.LongTensor] = None,
944
+ position_ids: Optional[torch.LongTensor] = None,
945
+ head_mask: Optional[torch.FloatTensor] = None,
946
+ inputs_embeds: Optional[torch.FloatTensor] = None,
947
+ encoder_hidden_states: Optional[torch.Tensor] = None,
948
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
949
+ use_cache: Optional[bool] = None,
950
+ output_attentions: Optional[bool] = None,
951
+ output_hidden_states: Optional[bool] = None,
952
+ return_dict: Optional[bool] = None,
953
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
954
+ output_attentions = (
955
+ output_attentions
956
+ if output_attentions is not None
957
+ else self.config.output_attentions
958
+ )
959
+ output_hidden_states = (
960
+ output_hidden_states
961
+ if output_hidden_states is not None
962
+ else self.config.output_hidden_states
963
+ )
964
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
965
+ return_dict = (
966
+ return_dict if return_dict is not None else self.config.use_return_dict
967
+ )
968
+
969
+ if input_ids is not None and inputs_embeds is not None:
970
+ raise ValueError(
971
+ "You cannot specify both input_ids and inputs_embeds at the same time"
972
+ )
973
+ elif input_ids is not None:
974
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
975
+ input_shape = input_ids.size()
976
+ input_ids = input_ids.view(-1, input_shape[-1])
977
+ batch_size = input_ids.shape[0]
978
+ elif inputs_embeds is not None:
979
+ input_shape = inputs_embeds.size()[:-1]
980
+ batch_size = inputs_embeds.shape[0]
981
+ else:
982
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
983
+
984
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
985
+
986
+ if token_type_ids is not None:
987
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
988
+
989
+ if past_key_values is None:
990
+ past_length = 0
991
+ past_key_values = tuple([None] * len(self.h))
992
+ else:
993
+ past_length = past_key_values[0][0].size(-2)
994
+ if position_ids is None:
995
+ position_ids = torch.arange(
996
+ past_length,
997
+ input_shape[-1] + past_length,
998
+ dtype=torch.long,
999
+ device=device,
1000
+ )
1001
+ position_ids = position_ids.unsqueeze(0)
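+ # Illustrative example (added comment, not part of the upstream GPT-2/Grover port): with a
+ # cached prefix of past_length=3 and a new chunk of 2 tokens, the default position ids are
+ # torch.arange(3, 5).unsqueeze(0) -> tensor([[3, 4]]), i.e. positions continue where the
+ # cache left off.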
1002
+
1003
+ # AraGPT2Attention mask.
1004
+ if attention_mask is not None:
1005
+ if batch_size <= 0:
1006
+ raise ValueError("batch_size has to be defined and > 0")
1007
+ attention_mask = attention_mask.view(batch_size, -1)
1008
+ # We create a 3D attention mask from a 2D tensor mask.
1009
+ # Sizes are [batch_size, 1, 1, to_seq_length]
1010
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
1011
+ # this attention mask is more simple than the triangular masking of causal attention
1012
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
1013
+ attention_mask = attention_mask[:, None, None, :]
1014
+
1015
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
1016
+ # masked positions, this operation will create a tensor which is 0.0 for
1017
+ # positions we want to attend and the dtype's smallest value for masked positions.
1018
+ # Since we are adding it to the raw scores before the softmax, this is
1019
+ # effectively the same as removing these entirely.
1020
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
1021
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
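+ # Illustrative example (added comment): a padding mask such as [1, 1, 0] is broadcast to
+ # shape (batch_size, 1, 1, seq_len) and mapped to [0.0, 0.0, torch.finfo(dtype).min], so
+ # padded positions receive a very large negative bias and effectively vanish after softmax.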
1022
+
1023
+ # If a 2D or 3D attention mask is provided for the cross-attention
1024
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1025
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
1026
+ encoder_batch_size, encoder_sequence_length, _ = (
1027
+ encoder_hidden_states.size()
1028
+ )
1029
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1030
+ if encoder_attention_mask is None:
1031
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
1032
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
1033
+ else:
1034
+ encoder_attention_mask = None
1035
+
1036
+ # Prepare head mask if needed
1037
+ # 1.0 in head_mask indicate we keep the head
1038
+ # attention_probs has shape bsz x n_heads x N x N
1039
+ # head_mask has shape n_layer x batch x n_heads x N x N
1040
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
1041
+
1042
+ if inputs_embeds is None:
1043
+ inputs_embeds = self.wte(input_ids)
1044
+ position_embeds = self.wpe(position_ids)
1045
+ hidden_states = inputs_embeds + position_embeds
1046
+
1047
+ if token_type_ids is not None:
1048
+ token_type_embeds = self.wte(token_type_ids)
1049
+ hidden_states = hidden_states + token_type_embeds
1050
+
1051
+ hidden_states = self.drop(hidden_states)
1052
+ # Added in Grover
1053
+ hidden_states = self.emb_norm(hidden_states)
1054
+
1055
+ output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
1056
+
1057
+ if self.gradient_checkpointing and self.training:
1058
+ if use_cache:
1059
+ logger.warning_once(
1060
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1061
+ )
1062
+ use_cache = False
1063
+
1064
+ presents = () if use_cache else None
1065
+ all_self_attentions = () if output_attentions else None
1066
+ all_cross_attentions = (
1067
+ () if output_attentions and self.config.add_cross_attention else None
1068
+ )
1069
+ all_hidden_states = () if output_hidden_states else None
1070
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
1071
+ # Model parallel
1072
+ if self.model_parallel:
1073
+ torch.cuda.set_device(hidden_states.device)
1074
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
1075
+ if layer_past is not None:
1076
+ layer_past = tuple(
1077
+ past_state.to(hidden_states.device) for past_state in layer_past
1078
+ )
1079
+ # Ensure that attention_mask is always on the same device as hidden_states
1080
+ if attention_mask is not None:
1081
+ attention_mask = attention_mask.to(hidden_states.device)
1082
+ if isinstance(head_mask, torch.Tensor):
1083
+ head_mask = head_mask.to(hidden_states.device)
1084
+ if output_hidden_states:
1085
+ all_hidden_states = all_hidden_states + (hidden_states,)
1086
+
1087
+ if self.gradient_checkpointing and self.training:
1088
+ outputs = self._gradient_checkpointing_func(
1089
+ block.__call__,
1090
+ hidden_states,
1091
+ None,
1092
+ attention_mask,
1093
+ head_mask[i],
1094
+ encoder_hidden_states,
1095
+ encoder_attention_mask,
1096
+ use_cache,
1097
+ output_attentions,
1098
+ )
1099
+ else:
1100
+ outputs = block(
1101
+ hidden_states,
1102
+ layer_past=layer_past,
1103
+ attention_mask=attention_mask,
1104
+ head_mask=head_mask[i],
1105
+ encoder_hidden_states=encoder_hidden_states,
1106
+ encoder_attention_mask=encoder_attention_mask,
1107
+ use_cache=use_cache,
1108
+ output_attentions=output_attentions,
1109
+ )
1110
+
1111
+ hidden_states = outputs[0]
1112
+ if use_cache is True:
1113
+ presents = presents + (outputs[1],)
1114
+
1115
+ if output_attentions:
1116
+ all_self_attentions = all_self_attentions + (
1117
+ outputs[2 if use_cache else 1],
1118
+ )
1119
+ if self.config.add_cross_attention:
1120
+ all_cross_attentions = all_cross_attentions + (
1121
+ outputs[3 if use_cache else 2],
1122
+ )
1123
+
1124
+ # Model Parallel: If it's the last layer for that device, put things on the next device
1125
+ if self.model_parallel:
1126
+ for k, v in self.device_map.items():
1127
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
1128
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
1129
+
1130
+ # Removed in Grover
1131
+ # hidden_states = self.ln_f(hidden_states)
1132
+
1133
+ hidden_states = hidden_states.view(output_shape)
1134
+ # Add last hidden state
1135
+ if output_hidden_states:
1136
+ all_hidden_states = all_hidden_states + (hidden_states,)
1137
+
1138
+ if not return_dict:
1139
+ return tuple(
1140
+ v
1141
+ for v in [
1142
+ hidden_states,
1143
+ presents,
1144
+ all_hidden_states,
1145
+ all_self_attentions,
1146
+ all_cross_attentions,
1147
+ ]
1148
+ if v is not None
1149
+ )
1150
+
1151
+ return BaseModelOutputWithPastAndCrossAttentions(
1152
+ last_hidden_state=hidden_states,
1153
+ past_key_values=presents,
1154
+ hidden_states=all_hidden_states,
1155
+ attentions=all_self_attentions,
1156
+ cross_attentions=all_cross_attentions,
1157
+ )
1158
+
1159
+
1160
+ @add_start_docstrings(
1161
+ """
1162
+ The AraGPT2 Model transformer with a language modeling head on top (a linear layer with weights tied to the input
1163
+ embeddings).
1164
+ """,
1165
+ AraGPT2_START_DOCSTRING,
1166
+ )
1167
+ class AraGPT2LMHeadModel(AraGPT2PreTrainedModel):
1168
+ _keys_to_ignore_on_load_unexpected = [
1169
+ r"attn.masked_bias",
1170
+ r"attn.bias",
1171
+ r"lm_head.weight",
1172
+ ]
1173
+ _keys_to_ignore_on_load_missing = [
1174
+ r"attn.masked_bias",
1175
+ r"attn.bias",
1176
+ r"lm_head.weight",
1177
+ ]
1178
+ _tied_weights_keys = ["lm_head.weight"]
1179
+
1180
+ def __init__(self, config):
1181
+ super().__init__(config)
1182
+ self.transformer = AraGPT2Model(config)
1183
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1184
+
1185
+ # Model parallel
1186
+ self.model_parallel = False
1187
+ self.device_map = None
1188
+
1189
+ # Initialize weights and apply final processing
1190
+ self.post_init()
1191
+
1192
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1193
+ def parallelize(self, device_map=None):
1194
+ "`AraGPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
1195
+ "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
1196
+ " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
1197
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
1198
+ " 0, 'transformer.h.1': 1, ...}",
1199
+ FutureWarning,
1200
+ )
1201
+ self.device_map = (
1202
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
1203
+ if device_map is None
1204
+ else device_map
1205
+ )
1206
+ assert_device_map(self.device_map, len(self.transformer.h))
1207
+ self.transformer.parallelize(self.device_map)
1208
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
1209
+ self.model_parallel = True
1210
+
1211
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1212
+ def deparallelize(self):
1213
+ warnings.warn(
1214
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1215
+ FutureWarning,
1216
+ )
1217
+ self.transformer.deparallelize()
1218
+ self.transformer = self.transformer.to("cpu")
1219
+ self.lm_head = self.lm_head.to("cpu")
1220
+ self.model_parallel = False
1221
+ torch.cuda.empty_cache()
1222
+
1223
+ def get_output_embeddings(self):
1224
+ return self.lm_head
1225
+
1226
+ def set_output_embeddings(self, new_embeddings):
1227
+ self.lm_head = new_embeddings
1228
+
1229
+ def prepare_inputs_for_generation(
1230
+ self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
1231
+ ):
1232
+ token_type_ids = kwargs.get("token_type_ids", None)
1233
+ # Omit tokens covered by past_key_values
1234
+ if past_key_values:
1235
+ past_length = past_key_values[0][0].shape[2]
1236
+
1237
+ # Some generation methods already pass only the last input ID
1238
+ if input_ids.shape[1] > past_length:
1239
+ remove_prefix_length = past_length
1240
+ else:
1241
+ # Default to old behavior: keep only final ID
1242
+ remove_prefix_length = input_ids.shape[1] - 1
1243
+
1244
+ input_ids = input_ids[:, remove_prefix_length:]
1245
+ if token_type_ids is not None:
1246
+ token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
1247
+
1248
+ attention_mask = kwargs.get("attention_mask", None)
1249
+ position_ids = kwargs.get("position_ids", None)
1250
+
1251
+ if attention_mask is not None and position_ids is None:
1252
+ # create position_ids on the fly for batch generation
1253
+ position_ids = attention_mask.long().cumsum(-1) - 1
1254
+ position_ids.masked_fill_(attention_mask == 0, 1)
1255
+ if past_key_values:
1256
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1257
+ else:
1258
+ position_ids = None
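+ # Illustrative example (added comment): for a left-padded batch with attention_mask
+ # [[0, 1, 1]], cumsum(-1) - 1 gives [[-1, 0, 1]] and masked_fill_ turns it into
+ # [[1, 0, 1]], so the real tokens get positions 0 and 1 while the pad slot gets a
+ # harmless dummy position.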
1259
+
1260
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1261
+ if inputs_embeds is not None and past_key_values is None:
1262
+ model_inputs = {"inputs_embeds": inputs_embeds}
1263
+ else:
1264
+ model_inputs = {"input_ids": input_ids}
1265
+
1266
+ model_inputs.update(
1267
+ {
1268
+ "past_key_values": past_key_values,
1269
+ "use_cache": kwargs.get("use_cache"),
1270
+ "position_ids": position_ids,
1271
+ "attention_mask": attention_mask,
1272
+ "token_type_ids": token_type_ids,
1273
+ }
1274
+ )
1275
+
1276
+ return model_inputs
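+ # Sketch of the caching behaviour implemented above (added comment, values hypothetical):
+ # on a later generation step with a 5-token prompt already cached (past_length=5) and
+ # input_ids of shape (1, 6), only the final column input_ids[:, 5:] is fed to the model
+ # together with past_key_values, so each step attends over a single new token.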
1277
+
1278
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1279
+ @add_code_sample_docstrings(
1280
+ processor_class=_TOKENIZER_FOR_DOC,
1281
+ checkpoint=_CHECKPOINT_FOR_DOC,
1282
+ output_type=CausalLMOutputWithCrossAttentions,
1283
+ config_class=_CONFIG_FOR_DOC,
1284
+ )
1285
+ def forward(
1286
+ self,
1287
+ input_ids: Optional[torch.LongTensor] = None,
1288
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1289
+ attention_mask: Optional[torch.FloatTensor] = None,
1290
+ token_type_ids: Optional[torch.LongTensor] = None,
1291
+ position_ids: Optional[torch.LongTensor] = None,
1292
+ head_mask: Optional[torch.FloatTensor] = None,
1293
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1294
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1295
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1296
+ labels: Optional[torch.LongTensor] = None,
1297
+ use_cache: Optional[bool] = None,
1298
+ output_attentions: Optional[bool] = None,
1299
+ output_hidden_states: Optional[bool] = None,
1300
+ return_dict: Optional[bool] = None,
1301
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
1302
+ r"""
1303
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1304
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1305
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to `-100`
1306
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`.
1307
+ """
1308
+ return_dict = (
1309
+ return_dict if return_dict is not None else self.config.use_return_dict
1310
+ )
1311
+
1312
+ transformer_outputs = self.transformer(
1313
+ input_ids,
1314
+ past_key_values=past_key_values,
1315
+ attention_mask=attention_mask,
1316
+ token_type_ids=token_type_ids,
1317
+ position_ids=position_ids,
1318
+ head_mask=head_mask,
1319
+ inputs_embeds=inputs_embeds,
1320
+ encoder_hidden_states=encoder_hidden_states,
1321
+ encoder_attention_mask=encoder_attention_mask,
1322
+ use_cache=use_cache,
1323
+ output_attentions=output_attentions,
1324
+ output_hidden_states=output_hidden_states,
1325
+ return_dict=return_dict,
1326
+ )
1327
+ hidden_states = transformer_outputs[0]
1328
+
1329
+ # Set device for model parallelism
1330
+ if self.model_parallel:
1331
+ torch.cuda.set_device(self.transformer.first_device)
1332
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
1333
+
1334
+ lm_logits = self.lm_head(hidden_states)
1335
+
1336
+ loss = None
1337
+ if labels is not None:
1338
+ # move labels to correct device to enable model parallelism
1339
+ labels = labels.to(lm_logits.device)
1340
+ # Shift so that tokens < n predict n
1341
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1342
+ shift_labels = labels[..., 1:].contiguous()
1343
+ # Flatten the tokens
1344
+ loss_fct = CrossEntropyLoss()
1345
+ loss = loss_fct(
1346
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
1347
+ )
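+ # Illustrative example (added comment): with labels = input_ids = [[t0, t1, t2, t3]], the
+ # logits at positions 0..2 are scored against targets t1..t3, i.e. every position predicts
+ # the next token and no loss is computed for the final position.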
1348
+
1349
+ if not return_dict:
1350
+ output = (lm_logits,) + transformer_outputs[1:]
1351
+ return ((loss,) + output) if loss is not None else output
1352
+
1353
+ return CausalLMOutputWithCrossAttentions(
1354
+ loss=loss,
1355
+ logits=lm_logits,
1356
+ past_key_values=transformer_outputs.past_key_values,
1357
+ hidden_states=transformer_outputs.hidden_states,
1358
+ attentions=transformer_outputs.attentions,
1359
+ cross_attentions=transformer_outputs.cross_attentions,
1360
+ )
1361
+
1362
+ @staticmethod
1363
+ def _reorder_cache(
1364
+ past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1365
+ ) -> Tuple[Tuple[torch.Tensor]]:
1366
+ """
1367
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1368
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1369
+ beam_idx at every generation step.
1370
+ """
1371
+ return tuple(
1372
+ tuple(
1373
+ past_state.index_select(0, beam_idx.to(past_state.device))
1374
+ for past_state in layer_past
1375
+ )
1376
+ for layer_past in past_key_values
1377
+ )
1378
+
1379
+
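+ # Usage sketch (added comment, not part of the original file; it assumes the repository's
+ # GPT2-compatible tokenizer files, as in the AraGPT2DoubleHeadsModel docstring below, and a
+ # hypothetical Arabic prompt):
+ #
+ # >>> from transformers import GPT2Tokenizer
+ # >>> from modeling_aragpt2 import AraGPT2LMHeadModel
+ # >>> tokenizer = GPT2Tokenizer.from_pretrained("aubmindlab/aragpt2-mega")
+ # >>> model = AraGPT2LMHeadModel.from_pretrained("aubmindlab/aragpt2-mega")
+ # >>> inputs = tokenizer("مرحبا بكم", return_tensors="pt")
+ # >>> output_ids = model.generate(**inputs, max_length=50, no_repeat_ngram_size=3)
+ # >>> print(tokenizer.decode(output_ids[0], skip_special_tokens=True))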
1380
+ @add_start_docstrings(
1381
+ """
1382
+ The AraGPT2 Model transformer with a language modeling and a multiple-choice classification head on top, e.g. for
1383
+ RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
1384
+ input embeddings; the classification head takes as input the hidden state of a specified classification token index
1385
+ in the input sequence.
1386
+ """,
1387
+ AraGPT2_START_DOCSTRING,
1388
+ )
1389
+ class AraGPT2DoubleHeadsModel(AraGPT2PreTrainedModel):
1390
+ _keys_to_ignore_on_load_unexpected = [
1391
+ r"attn.masked_bias",
1392
+ r"attn.bias",
1393
+ r"lm_head.weight",
1394
+ ]
1395
+ _keys_to_ignore_on_load_missing = [
1396
+ r"attn.masked_bias",
1397
+ r"attn.bias",
1398
+ r"lm_head.weight",
1399
+ ]
1400
+ _tied_weights_keys = ["lm_head.weight"]
1401
+
1402
+ def __init__(self, config):
1403
+ super().__init__(config)
1404
+ config.num_labels = 1
1405
+ self.transformer = AraGPT2Model(config)
1406
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1407
+ self.multiple_choice_head = SequenceSummary(config)
1408
+
1409
+ # Model parallel
1410
+ self.model_parallel = False
1411
+ self.device_map = None
1412
+
1413
+ # Initialize weights and apply final processing
1414
+ self.post_init()
1415
+
1416
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1417
+ def parallelize(self, device_map=None):
1418
+ warnings.warn(
1419
+ "`AraGPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should"
1420
+ " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your"
1421
+ " own `device_map` but it needs to be a dictionary module_name to device, so for instance"
1422
+ " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}",
1423
+ FutureWarning,
1424
+ )
1425
+ self.device_map = (
1426
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
1427
+ if device_map is None
1428
+ else device_map
1429
+ )
1430
+ assert_device_map(self.device_map, len(self.transformer.h))
1431
+ self.transformer.parallelize(self.device_map)
1432
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
1433
+ self.multiple_choice_head = self.multiple_choice_head.to(
1434
+ self.transformer.first_device
1435
+ )
1436
+ self.model_parallel = True
1437
+
1438
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1439
+ def deparallelize(self):
1440
+ warnings.warn(
1441
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1442
+ FutureWarning,
1443
+ )
1444
+ self.transformer.deparallelize()
1445
+ self.transformer = self.transformer.to("cpu")
1446
+ self.lm_head = self.lm_head.to("cpu")
1447
+ self.multiple_choice_head = self.multiple_choice_head.to("cpu")
1448
+ self.model_parallel = False
1449
+ torch.cuda.empty_cache()
1450
+
1451
+ def get_output_embeddings(self):
1452
+ return self.lm_head
1453
+
1454
+ def set_output_embeddings(self, new_embeddings):
1455
+ self.lm_head = new_embeddings
1456
+
1457
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
1458
+ token_type_ids = kwargs.get("token_type_ids", None)
1459
+ # Omit tokens covered by past_key_values
1460
+ if past_key_values:
1461
+ past_length = past_key_values[0][0].shape[2]
1462
+
1463
+ # Some generation methods already pass only the last input ID
1464
+ if input_ids.shape[1] > past_length:
1465
+ remove_prefix_length = past_length
1466
+ else:
1467
+ # Default to old behavior: keep only final ID
1468
+ remove_prefix_length = input_ids.shape[1] - 1
1469
+
1470
+ input_ids = input_ids[:, remove_prefix_length:]
1471
+ if token_type_ids is not None:
1472
+ token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
1473
+
1474
+ attention_mask = kwargs.get("attention_mask", None)
1475
+ position_ids = kwargs.get("position_ids", None)
1476
+
1477
+ if attention_mask is not None and position_ids is None:
1478
+ # create position_ids on the fly for batch generation
1479
+ position_ids = attention_mask.long().cumsum(-1) - 1
1480
+ position_ids.masked_fill_(attention_mask == 0, 1)
1481
+ if past_key_values:
1482
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1483
+ else:
1484
+ position_ids = None
1485
+
1486
+ return {
1487
+ "input_ids": input_ids,
1488
+ "past_key_values": past_key_values,
1489
+ "use_cache": kwargs.get("use_cache"),
1490
+ "position_ids": position_ids,
1491
+ "attention_mask": attention_mask,
1492
+ "token_type_ids": token_type_ids,
1493
+ }
1494
+
1495
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1496
+ @replace_return_docstrings(
1497
+ output_type=AraGPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC
1498
+ )
1499
+ def forward(
1500
+ self,
1501
+ input_ids: Optional[torch.LongTensor] = None,
1502
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1503
+ attention_mask: Optional[torch.FloatTensor] = None,
1504
+ token_type_ids: Optional[torch.LongTensor] = None,
1505
+ position_ids: Optional[torch.LongTensor] = None,
1506
+ head_mask: Optional[torch.FloatTensor] = None,
1507
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1508
+ mc_token_ids: Optional[torch.LongTensor] = None,
1509
+ labels: Optional[torch.LongTensor] = None,
1510
+ mc_labels: Optional[torch.LongTensor] = None,
1511
+ use_cache: Optional[bool] = None,
1512
+ output_attentions: Optional[bool] = None,
1513
+ output_hidden_states: Optional[bool] = None,
1514
+ return_dict: Optional[bool] = None,
1515
+ **kwargs,
1516
+ ) -> Union[Tuple, AraGPT2DoubleHeadsModelOutput]:
1517
+ r"""
1518
+ mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, defaults to index of the last token of the input):
1519
+ Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
1520
+ 1]`.
1521
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1522
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1523
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to
1524
+ `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`
1525
+ mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
1526
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices - 1]`
1527
+ where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above)
1528
+
1529
+ Return:
1530
+
1531
+ Example:
1532
+
1533
+ ```python
1534
+ >>> import torch
1535
+ >>> from transformers import GPT2Tokenizer
1536
+ >>> from modeling_aragpt2 import AraGPT2DoubleHeadsModel  # class defined in this file
1537
+ >>> tokenizer = GPT2Tokenizer.from_pretrained("aubmindlab/aragpt2-mega")
1538
+ >>> model = AraGPT2DoubleHeadsModel.from_pretrained("aubmindlab/aragpt2-mega")
1539
+
1540
+ >>> # Add a [CLS] to the vocabulary (we should train it also!)
1541
+ >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"})
1542
+ >>> # Update the model embeddings with the new vocabulary size
1543
+ >>> embedding_layer = model.resize_token_embeddings(len(tokenizer))
1544
+
1545
+ >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
1546
+ >>> encoded_choices = [tokenizer.encode(s) for s in choices]
1547
+ >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]
1548
+
1549
+ >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2
1550
+ >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1
1551
+
1552
+ >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
1553
+ >>> lm_logits = outputs.logits
1554
+ >>> mc_logits = outputs.mc_logits
1555
+ ```"""
1556
+ return_dict = (
1557
+ return_dict if return_dict is not None else self.config.use_return_dict
1558
+ )
1559
+
1560
+ transformer_outputs = self.transformer(
1561
+ input_ids,
1562
+ past_key_values=past_key_values,
1563
+ attention_mask=attention_mask,
1564
+ token_type_ids=token_type_ids,
1565
+ position_ids=position_ids,
1566
+ head_mask=head_mask,
1567
+ inputs_embeds=inputs_embeds,
1568
+ use_cache=use_cache,
1569
+ output_attentions=output_attentions,
1570
+ output_hidden_states=output_hidden_states,
1571
+ return_dict=return_dict,
1572
+ )
1573
+
1574
+ hidden_states = transformer_outputs[0]
1575
+
1576
+ # Set device for model parallelism
1577
+ if self.model_parallel:
1578
+ torch.cuda.set_device(self.transformer.first_device)
1579
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
1580
+
1581
+ lm_logits = self.lm_head(hidden_states)
1582
+ mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
1583
+
1584
+ mc_loss = None
1585
+ if mc_labels is not None:
1586
+ loss_fct = CrossEntropyLoss()
1587
+ mc_loss = loss_fct(
1588
+ mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)
1589
+ )
1590
+ lm_loss = None
1591
+ if labels is not None:
1592
+ labels = labels.to(lm_logits.device)
1593
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1594
+ shift_labels = labels[..., 1:].contiguous()
1595
+ loss_fct = CrossEntropyLoss()
1596
+ lm_loss = loss_fct(
1597
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
1598
+ )
1599
+
1600
+ if not return_dict:
1601
+ output = (lm_logits, mc_logits) + transformer_outputs[1:]
1602
+ if mc_loss is not None:
1603
+ output = (mc_loss,) + output
1604
+ return ((lm_loss,) + output) if lm_loss is not None else output
1605
+
1606
+ return AraGPT2DoubleHeadsModelOutput(
1607
+ loss=lm_loss,
1608
+ mc_loss=mc_loss,
1609
+ logits=lm_logits,
1610
+ mc_logits=mc_logits,
1611
+ past_key_values=transformer_outputs.past_key_values,
1612
+ hidden_states=transformer_outputs.hidden_states,
1613
+ attentions=transformer_outputs.attentions,
1614
+ )
1615
+
1616
+ @staticmethod
1617
+ def _reorder_cache(
1618
+ past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1619
+ ) -> Tuple[Tuple[torch.Tensor]]:
1620
+ """
1621
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1622
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1623
+ beam_idx at every generation step.
1624
+ """
1625
+ return tuple(
1626
+ tuple(
1627
+ past_state.index_select(0, beam_idx.to(past_state.device))
1628
+ for past_state in layer_past
1629
+ )
1630
+ for layer_past in past_key_values
1631
+ )
1632
+
1633
+
1634
+ @add_start_docstrings(
1635
+ """
1636
+ The AraGPT2 Model transformer with a sequence classification head on top (a linear layer).
1637
+
1638
+ [`AraGPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1639
+ (e.g. GPT-1) do.
1640
+
1641
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1642
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1643
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1644
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1645
+ each row of the batch).
1646
+ """,
1647
+ AraGPT2_START_DOCSTRING,
1648
+ )
1649
+ class AraGPT2ForSequenceClassification(AraGPT2PreTrainedModel):
1650
+ _keys_to_ignore_on_load_unexpected = [
1651
+ r"h\.\d+\.attn\.masked_bias",
1652
+ r"lm_head.weight",
1653
+ ]
1654
+ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"]
1655
+
1656
+ def __init__(self, config):
1657
+ super().__init__(config)
1658
+ self.num_labels = config.num_labels
1659
+ self.transformer = AraGPT2Model(config)
1660
+ self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
1661
+
1662
+ # Model parallel
1663
+ self.model_parallel = False
1664
+ self.device_map = None
1665
+
1666
+ # Initialize weights and apply final processing
1667
+ self.post_init()
1668
+
1669
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1670
+ @add_code_sample_docstrings(
1671
+ processor_class=_TOKENIZER_FOR_DOC,
1672
+ output_type=SequenceClassifierOutputWithPast,
1673
+ config_class=_CONFIG_FOR_DOC,
1674
+ )
1675
+ def forward(
1676
+ self,
1677
+ input_ids: Optional[torch.LongTensor] = None,
1678
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1679
+ attention_mask: Optional[torch.FloatTensor] = None,
1680
+ token_type_ids: Optional[torch.LongTensor] = None,
1681
+ position_ids: Optional[torch.LongTensor] = None,
1682
+ head_mask: Optional[torch.FloatTensor] = None,
1683
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1684
+ labels: Optional[torch.LongTensor] = None,
1685
+ use_cache: Optional[bool] = None,
1686
+ output_attentions: Optional[bool] = None,
1687
+ output_hidden_states: Optional[bool] = None,
1688
+ return_dict: Optional[bool] = None,
1689
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1690
+ r"""
1691
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1692
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1693
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
1694
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1695
+ """
1696
+ return_dict = (
1697
+ return_dict if return_dict is not None else self.config.use_return_dict
1698
+ )
1699
+
1700
+ transformer_outputs = self.transformer(
1701
+ input_ids,
1702
+ past_key_values=past_key_values,
1703
+ attention_mask=attention_mask,
1704
+ token_type_ids=token_type_ids,
1705
+ position_ids=position_ids,
1706
+ head_mask=head_mask,
1707
+ inputs_embeds=inputs_embeds,
1708
+ use_cache=use_cache,
1709
+ output_attentions=output_attentions,
1710
+ output_hidden_states=output_hidden_states,
1711
+ return_dict=return_dict,
1712
+ )
1713
+ hidden_states = transformer_outputs[0]
1714
+ logits = self.score(hidden_states)
1715
+
1716
+ if input_ids is not None:
1717
+ batch_size, sequence_length = input_ids.shape[:2]
1718
+ else:
1719
+ batch_size, sequence_length = inputs_embeds.shape[:2]
1720
+
1721
+ assert (
1722
+ self.config.pad_token_id is not None or batch_size == 1
1723
+ ), "Cannot handle batch sizes > 1 if no padding token is defined."
1724
+ if self.config.pad_token_id is None:
1725
+ sequence_lengths = -1
1726
+ else:
1727
+ if input_ids is not None:
1728
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1729
+ sequence_lengths = (
1730
+ torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1731
+ )
1732
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1733
+ sequence_lengths = sequence_lengths.to(logits.device)
1734
+ else:
1735
+ sequence_lengths = -1
1736
+ logger.warning(
1737
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1738
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1739
+ )
1740
+
1741
+ pooled_logits = logits[
1742
+ torch.arange(batch_size, device=logits.device), sequence_lengths
1743
+ ]
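+ # Illustrative example (added comment): with right padding and pad_token_id defined, e.g.
+ # input_ids = [[w1, w2, pad, pad]], the first pad is found at index 2, so sequence_lengths
+ # becomes 1 and the logits of the last real token (index 1) are pooled for classification.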
1744
+
1745
+ loss = None
1746
+ if labels is not None:
1747
+ if self.config.problem_type is None:
1748
+ if self.num_labels == 1:
1749
+ self.config.problem_type = "regression"
1750
+ elif self.num_labels > 1 and (
1751
+ labels.dtype == torch.long or labels.dtype == torch.int
1752
+ ):
1753
+ self.config.problem_type = "single_label_classification"
1754
+ else:
1755
+ self.config.problem_type = "multi_label_classification"
1756
+
1757
+ if self.config.problem_type == "regression":
1758
+ loss_fct = MSELoss()
1759
+ if self.num_labels == 1:
1760
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1761
+ else:
1762
+ loss = loss_fct(pooled_logits, labels)
1763
+ elif self.config.problem_type == "single_label_classification":
1764
+ loss_fct = CrossEntropyLoss()
1765
+ loss = loss_fct(
1766
+ pooled_logits.view(-1, self.num_labels), labels.view(-1)
1767
+ )
1768
+ elif self.config.problem_type == "multi_label_classification":
1769
+ loss_fct = BCEWithLogitsLoss()
1770
+ loss = loss_fct(pooled_logits, labels)
1771
+ if not return_dict:
1772
+ output = (pooled_logits,) + transformer_outputs[1:]
1773
+ return ((loss,) + output) if loss is not None else output
1774
+
1775
+ return SequenceClassifierOutputWithPast(
1776
+ loss=loss,
1777
+ logits=pooled_logits,
1778
+ past_key_values=transformer_outputs.past_key_values,
1779
+ hidden_states=transformer_outputs.hidden_states,
1780
+ attentions=transformer_outputs.attentions,
1781
+ )
1782
+
1783
+
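+ # Usage sketch (added comment; a minimal example under the assumption that the freshly
+ # initialized classification head is fine-tuned afterwards):
+ #
+ # >>> from modeling_aragpt2 import AraGPT2ForSequenceClassification
+ # >>> model = AraGPT2ForSequenceClassification.from_pretrained(
+ # ...     "aubmindlab/aragpt2-mega", num_labels=2
+ # ... )
+ # >>> model.config.pad_token_id = model.config.eos_token_id  # needed for batch sizes > 1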
1784
+ @add_start_docstrings(
1785
+ """
1786
+ AraGPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output), e.g. for
1787
+ Named-Entity-Recognition (NER) tasks.
1788
+ """,
1789
+ AraGPT2_START_DOCSTRING,
1790
+ )
1791
+ class AraGPT2ForTokenClassification(AraGPT2PreTrainedModel):
1792
+ def __init__(self, config):
1793
+ super().__init__(config)
1794
+ self.num_labels = config.num_labels
1795
+
1796
+ self.transformer = AraGPT2Model(config)
1797
+ if (
1798
+ hasattr(config, "classifier_dropout")
1799
+ and config.classifier_dropout is not None
1800
+ ):
1801
+ classifier_dropout = config.classifier_dropout
1802
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1803
+ classifier_dropout = config.hidden_dropout
1804
+ else:
1805
+ classifier_dropout = 0.1
1806
+ self.dropout = nn.Dropout(classifier_dropout)
1807
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1808
+
1809
+ # Model parallel
1810
+ self.model_parallel = False
1811
+ self.device_map = None
1812
+
1813
+ # Initialize weights and apply final processing
1814
+ self.post_init()
1815
+
1816
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1817
+ # fmt: off
1818
+ @add_code_sample_docstrings(
1819
+ processor_class=_TOKENIZER_FOR_DOC,
1820
+ output_type=TokenClassifierOutput,
1821
+ config_class=_CONFIG_FOR_DOC,
1822
+ )
1823
+ # fmt: on
1824
+ def forward(
1825
+ self,
1826
+ input_ids: Optional[torch.LongTensor] = None,
1827
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1828
+ attention_mask: Optional[torch.FloatTensor] = None,
1829
+ token_type_ids: Optional[torch.LongTensor] = None,
1830
+ position_ids: Optional[torch.LongTensor] = None,
1831
+ head_mask: Optional[torch.FloatTensor] = None,
1832
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1833
+ labels: Optional[torch.LongTensor] = None,
1834
+ use_cache: Optional[bool] = None,
1835
+ output_attentions: Optional[bool] = None,
1836
+ output_hidden_states: Optional[bool] = None,
1837
+ return_dict: Optional[bool] = None,
1838
+ ) -> Union[Tuple, TokenClassifierOutput]:
1839
+ r"""
1840
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1841
+ Labels for computing the token classification loss. Indices should be in `[0, ...,
1842
+ config.num_labels - 1]`. Labels set to `-100` are ignored (masked); the loss is only computed for labels in
1843
+ `[0, ..., config.num_labels - 1]`.
1844
+ """
1845
+ return_dict = (
1846
+ return_dict if return_dict is not None else self.config.use_return_dict
1847
+ )
1848
+
1849
+ transformer_outputs = self.transformer(
1850
+ input_ids,
1851
+ past_key_values=past_key_values,
1852
+ attention_mask=attention_mask,
1853
+ token_type_ids=token_type_ids,
1854
+ position_ids=position_ids,
1855
+ head_mask=head_mask,
1856
+ inputs_embeds=inputs_embeds,
1857
+ use_cache=use_cache,
1858
+ output_attentions=output_attentions,
1859
+ output_hidden_states=output_hidden_states,
1860
+ return_dict=return_dict,
1861
+ )
1862
+
1863
+ hidden_states = transformer_outputs[0]
1864
+ hidden_states = self.dropout(hidden_states)
1865
+ logits = self.classifier(hidden_states)
1866
+
1867
+ loss = None
1868
+ if labels is not None:
1869
+ labels = labels.to(logits.device)
1870
+ loss_fct = CrossEntropyLoss()
1871
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
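+ # Illustrative note (added comment): logits of shape (batch, seq_len, num_labels) and labels
+ # of shape (batch, seq_len) are flattened to (batch * seq_len, num_labels) and
+ # (batch * seq_len,), so every token position is scored independently; labels set to -100
+ # are ignored by CrossEntropyLoss.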
1872
+
1873
+ if not return_dict:
1874
+ output = (logits,) + transformer_outputs[2:]
1875
+ return ((loss,) + output) if loss is not None else output
1876
+
1877
+ return TokenClassifierOutput(
1878
+ loss=loss,
1879
+ logits=logits,
1880
+ hidden_states=transformer_outputs.hidden_states,
1881
+ attentions=transformer_outputs.attentions,
1882
+ )
1883
+
1884
+
1885
+ @add_start_docstrings(
1886
+ """
1887
+ The AraGPT2 Model transformer with a span classification head on top for extractive question-answering tasks like
1888
+ SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1889
+ """,
1890
+ AraGPT2_START_DOCSTRING,
1891
+ )
1892
+ class AraGPT2ForQuestionAnswering(AraGPT2PreTrainedModel):
1893
+ def __init__(self, config):
1894
+ super().__init__(config)
1895
+ self.num_labels = config.num_labels
1896
+ self.transformer = AraGPT2Model(config)
1897
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
1898
+
1899
+ # Model parallel
1900
+ self.model_parallel = False
1901
+ self.device_map = None
1902
+
1903
+ # Initialize weights and apply final processing
1904
+ self.post_init()
1905
+
1906
+ @add_start_docstrings_to_model_forward(
1907
+ GPT2_INPUTS_DOCSTRING.format("batch_size, sequence_length")
1908
+ )
1909
+ @add_code_sample_docstrings(
1910
+ checkpoint=_CHECKPOINT_FOR_DOC,
1911
+ output_type=QuestionAnsweringModelOutput,
1912
+ config_class=_CONFIG_FOR_DOC,
1913
+ real_checkpoint=_CHECKPOINT_FOR_DOC,
1914
+ )
1915
+ def forward(
1916
+ self,
1917
+ input_ids: Optional[torch.LongTensor] = None,
1918
+ attention_mask: Optional[torch.FloatTensor] = None,
1919
+ token_type_ids: Optional[torch.LongTensor] = None,
1920
+ position_ids: Optional[torch.LongTensor] = None,
1921
+ head_mask: Optional[torch.FloatTensor] = None,
1922
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1923
+ start_positions: Optional[torch.LongTensor] = None,
1924
+ end_positions: Optional[torch.LongTensor] = None,
1925
+ output_attentions: Optional[bool] = None,
1926
+ output_hidden_states: Optional[bool] = None,
1927
+ return_dict: Optional[bool] = None,
1928
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1929
+ r"""
1930
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1931
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1932
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1933
+ are not taken into account for computing the loss.
1934
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1935
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1936
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1937
+ are not taken into account for computing the loss.
1938
+ """
1939
+ return_dict = (
1940
+ return_dict if return_dict is not None else self.config.use_return_dict
1941
+ )
1942
+
1943
+ outputs = self.transformer(
1944
+ input_ids,
1945
+ attention_mask=attention_mask,
1946
+ token_type_ids=token_type_ids,
1947
+ position_ids=position_ids,
1948
+ head_mask=head_mask,
1949
+ inputs_embeds=inputs_embeds,
1950
+ output_attentions=output_attentions,
1951
+ output_hidden_states=output_hidden_states,
1952
+ return_dict=return_dict,
1953
+ )
1954
+
1955
+ sequence_output = outputs[0]
1956
+
1957
+ logits = self.qa_outputs(sequence_output)
1958
+ start_logits, end_logits = logits.split(1, dim=-1)
1959
+ start_logits = start_logits.squeeze(-1).contiguous()
1960
+ end_logits = end_logits.squeeze(-1).contiguous()
1961
+
1962
+ total_loss = None
1963
+ if start_positions is not None and end_positions is not None:
1964
+ # On multi-GPU setups the position tensors may carry an extra dimension; squeeze it away
1965
+ if len(start_positions.size()) > 1:
1966
+ start_positions = start_positions.squeeze(-1).to(start_logits.device)
1967
+ if len(end_positions.size()) > 1:
1968
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
1969
+ # sometimes the start/end positions are outside our model inputs; we ignore these terms
1970
+ ignored_index = start_logits.size(1)
1971
+ start_positions = start_positions.clamp(0, ignored_index)
1972
+ end_positions = end_positions.clamp(0, ignored_index)
1973
+
1974
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1975
+ start_loss = loss_fct(start_logits, start_positions)
1976
+ end_loss = loss_fct(end_logits, end_positions)
1977
+ total_loss = (start_loss + end_loss) / 2
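+ # Illustrative note (added comment): the span loss is the mean of two cross-entropy terms,
+ # one over start positions and one over end positions; positions clamped to ignored_index
+ # (i.e. falling outside the truncated input) contribute nothing, since ignore_index is set
+ # to ignored_index above.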
1978
+
1979
+ if not return_dict:
1980
+ output = (start_logits, end_logits) + outputs[2:]
1981
+ return ((total_loss,) + output) if total_loss is not None else output
1982
+
1983
+ return QuestionAnsweringModelOutput(
1984
+ loss=total_loss,
1985
+ start_logits=start_logits,
1986
+ end_logits=end_logits,
1987
+ hidden_states=outputs.hidden_states,
1988
+ attentions=outputs.attentions,
1989
+ )
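+ # General usage note (added comment, a sketch under the assumption that the repository's
+ # config.json maps the Auto classes to this module via `auto_map`): because these classes
+ # live in the model repository rather than inside the transformers library, they are
+ # typically loaded with remote code enabled, e.g.
+ #
+ # >>> from transformers import AutoModel
+ # >>> model = AutoModel.from_pretrained("aubmindlab/aragpt2-mega", trust_remote_code=True)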