bourdoiscatie committed on
Commit c98d82a
1 Parent(s): bbd070b

Upload 13 files

config.json CHANGED
@@ -4,18 +4,13 @@
  "FlashT5ForConditionalGeneration"
  ],
  "attention_dropout_rate": 0.0,
- "attention_scale": 1.0,
- "attention_type": "ref",
  "auto_map": {
  "AutoConfig": "configuration_flash_t5.FlashT5Config",
  "AutoModel": "modeling_flash_t5.FlashT5ForConditionalGeneration",
- "AutoModelForQuestionAnswering": "custom_heads_flash_t5.FlashT5ForQuestionAnswering",
- "AutoModelForSeq2SeqLM": "modeling_flash_t5.FlashT5ForConditionalGeneration",
- "AutoModelForSequenceClassification": "custom_heads_flash_t5.FlashT5ForSequenceClassification",
- "AutoModelForTokenClassification": "custom_heads_flash_t5.FlashT5ForTokenClassification"
+ "AutoModelForSeq2SeqLM": "modeling_flash_t5.FlashT5ForConditionalGeneration"
  },
  "classifier_dropout": 0.0,
- "d_ff": 1024,
+ "d_ff": 2048,
  "d_kv": 64,
  "d_model": 512,
  "decoder_start_token_id": 0,
@@ -23,17 +18,16 @@
  "dropout_rate": 0.0,
  "eos_token_id": 1,
  "feed_forward_proj": "relu",
- "fire_mlp_width": 32,
  "initializer_factor": 1.0,
  "is_encoder_decoder": false,
  "is_gated_act": false,
- "label_smoothing": 0.1,
+ "label_smoothing": 0.0,
  "layer_norm_epsilon": 1e-06,
  "max_sequence_length": 1024,
  "model_type": "flash_t5",
- "num_decoder_layers": 8,
- "num_heads": 6,
- "num_layers": 8,
+ "num_decoder_layers": 12,
+ "num_heads": 8,
+ "num_layers": 12,
  "pad_token_id": 0,
  "position_encoding_type": "t5",
  "relative_attention_max_distance": 128,
@@ -44,17 +38,16 @@
  "rotary_scale_base": null,
  "tie_word_embeddings": false,
  "torch_dtype": "float32",
- "transformers_version": "4.39.3",
+ "transformers_version": "4.37.2",
  "use_cache": true,
- "use_flash_attention": "triton",
+ "use_flash_attention": "ref",
  "use_full_bias_size": false,
  "use_gelu_act": true,
  "use_glu_mlp": true,
- "use_masking": false,
  "use_randomized_position_encoding": false,
- "use_triton_crossentropy": false,
+ "use_triton_crossentropy": true,
  "use_triton_gated_mlp": false,
  "use_triton_layernorm": false,
- "vocab_size": 32128,
+ "vocab_size": 32768,
  "z_loss": 0.0001
  }
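
The new `auto_map` keeps only `AutoConfig`, `AutoModel` and `AutoModelForSeq2SeqLM`, so the checkpoint is now meant to be loaded as a seq2seq model with remote code enabled. A minimal loading sketch is below; the repository id is a placeholder and the tokenizer files are assumed to be part of this upload.

```python
# Hedged sketch, not part of the commit: loading through the auto_map above.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

repo_id = "user/flash-t5-checkpoint"  # placeholder, not the actual repository name

tokenizer = AutoTokenizer.from_pretrained(repo_id)
# trust_remote_code=True makes transformers use configuration_flash_t5.py and
# modeling_flash_t5.py from the repo to build FlashT5ForConditionalGeneration.
model = AutoModelForSeq2SeqLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("Paris est la capitale de la <extra_id_0>.", return_tensors="pt")
# FlashT5ForConditionalGeneration ships its own greedy generate(); see modeling_flash_t5(1).py below.
output_ids = model.generate(inputs.input_ids, attention_mask=inputs.attention_mask, max_length=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```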
custom_heads_flash_t5(1).py ADDED
@@ -0,0 +1,404 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
4
+ import copy
5
+ from typing import Optional, Union, Tuple, List
6
+ from transformers.modeling_outputs import (
7
+ Seq2SeqQuestionAnsweringModelOutput,
8
+ QuestionAnsweringModelOutput,
9
+ TokenClassifierOutput,
10
+ BaseModelOutput,
11
+ Seq2SeqSequenceClassifierOutput,
12
+ SequenceClassifierOutput
13
+ )
14
+
15
+ from .modeling_flash_t5 import FlashT5PreTrainedModel, FlashT5Stack, FlashT5Model, FlashT5EncoderModel
16
+ from .configuration_flash_t5 import FlashT5Config
17
+
18
+
19
+ ################## Encoder only head ##################
20
+ class FlashT5ForTokenClassification(FlashT5PreTrainedModel):
21
+
22
+ def __init__(self, config: FlashT5Config):
23
+ super().__init__(config)
24
+ self.num_labels = config.num_labels
25
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
26
+
27
+ self.encoder = FlashT5Stack(config, self.shared)
28
+ self.dropout = nn.Dropout(config.classifier_dropout)
29
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
30
+
31
+ # Initialize weights and apply final processing
32
+ self.post_init()
33
+
34
+ # Initialize classifier
35
+ self.classifier.weight.data.normal_(mean=0.0, std=config.initializer_factor * 1.0)
36
+ self.classifier.bias.data.zero_()
37
+
38
+ self.model_parallel = False
39
+
40
+ def forward(
41
+ self,
42
+ input_ids: Optional[torch.Tensor] = None,
43
+ attention_mask: Optional[torch.Tensor] = None,
44
+ head_mask: Optional[torch.Tensor] = None,
45
+ inputs_embeds: Optional[torch.Tensor] = None,
46
+ labels: Optional[torch.Tensor] = None,
47
+ output_attentions: Optional[bool] = None,
48
+ output_hidden_states: Optional[bool] = None,
49
+ return_dict: Optional[bool] = None,
50
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
51
+ r"""
52
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
53
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
54
+ Returns:
55
+ """
56
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
57
+
58
+ outputs = self.encoder(
59
+ input_ids=input_ids,
60
+ attention_mask=attention_mask,
61
+ inputs_embeds=inputs_embeds,
62
+ head_mask=head_mask,
63
+ output_attentions=output_attentions,
64
+ output_hidden_states=output_hidden_states,
65
+ return_dict=return_dict,
66
+ )
67
+
68
+ hidden_states = outputs[0]
69
+ hidden_states = self.dropout(hidden_states)
70
+ logits = self.classifier(hidden_states)
71
+
72
+ loss = None
73
+ if labels is not None:
74
+ loss_fct = nn.CrossEntropyLoss()
75
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
76
+
77
+ if not return_dict:
78
+ output = (logits, outputs[2:-1])
79
+ return ((loss,) + output) if loss is not None else output
80
+
81
+ return TokenClassifierOutput(
82
+ loss=loss,
83
+ logits=logits,
84
+ hidden_states=outputs.hidden_states,
85
+ attentions=outputs.attentions,
86
+ )
87
+
88
+
89
+ class FlashT5ClassificationHead(nn.Module):
90
+ """Head for sentence-level classification tasks."""
91
+
92
+ def __init__(self, config: FlashT5Config):
93
+ super().__init__()
94
+ self.dense = nn.Linear(config.d_model, config.d_model)
95
+ self.dropout = nn.Dropout(p=config.classifier_dropout)
96
+ self.out_proj = nn.Linear(config.d_model, config.num_labels)
97
+
98
+ # initialize weights
99
+ factor = config.initializer_factor
100
+ self.dense.weight.data.normal_(mean=0.0, std=factor * ((config.d_model) ** -0.5))
101
+ if hasattr(self.dense, "bias") and self.dense.bias is not None:
102
+ self.dense.bias.data.zero_()
103
+ self.out_proj.weight.data.normal_(mean=0.0, std=factor * ((config.d_model) ** -0.5))
104
+ if hasattr(self.out_proj, "bias") and self.out_proj.bias is not None:
105
+ self.out_proj.bias.data.zero_()
106
+
107
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
108
+ hidden_states = self.dropout(hidden_states)
109
+ hidden_states = self.dense(hidden_states)
110
+ hidden_states = torch.tanh(hidden_states)
111
+ hidden_states = self.dropout(hidden_states)
112
+ hidden_states = self.out_proj(hidden_states)
113
+ return hidden_states
114
+
115
+
116
+ class FlashT5ForSequenceClassification(FlashT5PreTrainedModel):
117
+ _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
118
+
119
+ def __init__(self, config: FlashT5Config):
120
+ super().__init__(config)
121
+ self.model_dim = config.d_model
122
+ self.config.problem_type = None
123
+ self.config.is_encoder_decoder = False
124
+
125
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
126
+
127
+ encoder_config = copy.deepcopy(config)
128
+ encoder_config.is_decoder = False
129
+ encoder_config.is_encoder_decoder = False
130
+ encoder_config.use_cache = False
131
+ self.encoder = FlashT5Stack(encoder_config, self.shared)
132
+ self.classification_head = FlashT5ClassificationHead(config)
133
+
134
+ # Initialize weights and apply final processing
135
+ self.post_init()
136
+
137
+ self.model_parallel = False
138
+
139
+ def forward(
140
+ self,
141
+ input_ids: torch.LongTensor = None,
142
+ attention_mask: Optional[torch.Tensor] = None,
143
+ head_mask: Optional[torch.Tensor] = None,
144
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
145
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
146
+ inputs_embeds: Optional[torch.FloatTensor] = None,
147
+ labels: Optional[torch.LongTensor] = None,
148
+ use_cache: Optional[bool] = None,
149
+ output_attentions: Optional[bool] = None,
150
+ output_hidden_states: Optional[bool] = None,
151
+ return_dict: Optional[bool] = None,
152
+ ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
153
+ r"""
154
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
155
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
156
+ config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
157
+ Returns:
158
+ """
159
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
160
+ if labels is not None:
161
+ use_cache = False
162
+
163
+ if input_ids is None and inputs_embeds is not None:
164
+ raise NotImplementedError(
165
+ f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
166
+ )
167
+
168
+
169
+ outputs = self.encoder(
170
+ input_ids=input_ids,
171
+ attention_mask=attention_mask,
172
+ inputs_embeds=inputs_embeds,
173
+ head_mask=head_mask,
174
+ output_attentions=output_attentions,
175
+ output_hidden_states=output_hidden_states,
176
+ return_dict=return_dict,
177
+ )
178
+ sequence_output = outputs[0]
179
+
180
+ eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device)
181
+
182
+ if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
183
+ raise ValueError("All examples must have the same number of <eos> tokens.")
184
+ batch_size, _, hidden_size = sequence_output.shape
185
+ sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :]
186
+ logits = self.classification_head(sentence_representation)
187
+
188
+ loss = None
189
+ if labels is not None:
190
+ labels = labels.to(logits.device)
191
+ if self.config.problem_type is None:
192
+ if self.config.num_labels == 1:
193
+ self.config.problem_type = "regression"
194
+ elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
195
+ self.config.problem_type = "single_label_classification"
196
+ else:
197
+ self.config.problem_type = "multi_label_classification"
198
+
199
+ if self.config.problem_type == "regression":
200
+ loss_fct = nn.MSELoss()
201
+ if self.config.num_labels == 1:
202
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
203
+ else:
204
+ loss = loss_fct(logits, labels)
205
+ elif self.config.problem_type == "single_label_classification":
206
+ loss_fct = nn.CrossEntropyLoss()
207
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
208
+ elif self.config.problem_type == "multi_label_classification":
209
+ loss_fct = nn.BCEWithLogitsLoss()
210
+ loss = loss_fct(logits, labels)
211
+ if not return_dict:
212
+ output = (logits,) + outputs[1:]
213
+ return ((loss,) + output) if loss is not None else output
214
+
215
+ return SequenceClassifierOutput(
216
+ loss=loss,
217
+ logits=logits,
218
+ hidden_states=outputs.hidden_states,
219
+ attentions=outputs.attentions
220
+ )
221
+
222
+
223
+ class FlashT5ForQuestionAnswering(FlashT5PreTrainedModel):
224
+ _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
225
+
226
+ def __init__(self, config: FlashT5Config):
227
+ super().__init__(config)
228
+ self.transformer = FlashT5EncoderModel(config)
229
+
230
+ self.num_labels = config.num_labels
231
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
232
+
233
+ # Initialize weights and apply final processing
234
+ self.post_init()
235
+
236
+ # Model parallel
237
+ self.model_parallel = False
238
+
239
+ def forward(
240
+ self,
241
+ input_ids: Optional[torch.LongTensor] = None,
242
+ attention_mask: Optional[torch.FloatTensor] = None,
243
+ head_mask: Optional[torch.FloatTensor] = None,
244
+ inputs_embeds: Optional[torch.FloatTensor] = None,
245
+ start_positions: Optional[torch.Tensor] = None,
246
+ end_positions: Optional[torch.Tensor] = None,
247
+ output_attentions: Optional[bool] = None,
248
+ output_hidden_states: Optional[bool] = None,
249
+ return_dict: Optional[bool] = None,
250
+ ) -> Union[Tuple[torch.FloatTensor], QuestionAnsweringModelOutput]:
251
+ r"""
252
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
253
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
254
+ Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence
255
+ are not taken into account for computing the loss.
256
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
257
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
258
+ Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence
259
+ are not taken into account for computing the loss.
260
+
261
+ Returns:
262
+ """
263
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
264
+
265
+ encoder_outputs = self.transformer(
266
+ input_ids=input_ids,
267
+ attention_mask=attention_mask,
268
+ inputs_embeds=inputs_embeds,
269
+ head_mask=head_mask,
270
+ output_attentions=output_attentions,
271
+ output_hidden_states=output_hidden_states,
272
+ return_dict=return_dict,
273
+ )
274
+
275
+ sequence_output = encoder_outputs[0]
276
+
277
+ logits = self.qa_outputs(sequence_output)
278
+ start_logits, end_logits = logits.split(1, dim=-1)
279
+ start_logits = start_logits.squeeze(-1).contiguous()
280
+ end_logits = end_logits.squeeze(-1).contiguous()
281
+
282
+ total_loss = None
283
+ if start_positions is not None and end_positions is not None:
284
+ # If we are on multi-GPU, split add a dimension
285
+ if len(start_positions.size()) > 1:
286
+ start_positions = start_positions.squeeze(-1).to(start_logits.device)
287
+ if len(end_positions.size()) > 1:
288
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
289
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
290
+ ignored_index = start_logits.size(1)
291
+ start_positions = start_positions.clamp(0, ignored_index)
292
+ end_positions = end_positions.clamp(0, ignored_index)
293
+
294
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
295
+ start_loss = loss_fct(start_logits, start_positions)
296
+ end_loss = loss_fct(end_logits, end_positions)
297
+ total_loss = (start_loss + end_loss) / 2
298
+
299
+ if not return_dict:
300
+ output = (start_logits, end_logits) + encoder_outputs[1:]
301
+ return ((total_loss,) + output) if total_loss is not None else output
302
+
303
+ return QuestionAnsweringModelOutput(
304
+ loss=total_loss,
305
+ start_logits=start_logits,
306
+ end_logits=end_logits,
307
+ hidden_states=encoder_outputs.hidden_states,
308
+ attentions=encoder_outputs.attentions,
309
+ )
310
+
311
+
312
+
313
+ class FlashT5ForQuestionAnswering(FlashT5PreTrainedModel):
314
+ _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
315
+
316
+ def __init__(self, config: FlashT5Config):
317
+ super().__init__(config)
318
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
319
+
320
+ encoder_config = copy.deepcopy(config)
321
+ encoder_config.is_decoder = False
322
+ encoder_config.is_encoder_decoder = False
323
+ self.encoder = FlashT5Stack(encoder_config, self.shared)
324
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
325
+
326
+ # Initialize weights and apply final processing
327
+ self.post_init()
328
+
329
+ self.qa_outputs.weight.data.normal_(mean=0.0, std=config.initializer_factor * 1.0)
330
+ self.qa_outputs.bias.data.zero_()
331
+
332
+ self.model_parallel = False
333
+
334
+ def forward(
335
+ self,
336
+ input_ids: Optional[torch.LongTensor] = None,
337
+ attention_mask: Optional[torch.FloatTensor] = None,
338
+ head_mask: Optional[torch.FloatTensor] = None,
339
+ inputs_embeds: Optional[torch.FloatTensor] = None,
340
+ start_positions: Optional[torch.LongTensor] = None,
341
+ end_positions: Optional[torch.LongTensor] = None,
342
+ output_attentions: Optional[bool] = None,
343
+ output_hidden_states: Optional[bool] = None,
344
+ return_dict: Optional[bool] = None,
345
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
346
+ r"""
347
+ Returns:
348
+
349
+ Example:
350
+
351
+ ```python
352
+ >>> from transformers import AutoTokenizer, MTxEncoderForQuestionAnswering
353
+
354
+ >>> tokenizer = AutoTokenizer.from_pretrained("MTx-small")
355
+ >>> model = MTxEncoderForQuestionAnswering.from_pretrained("MTx-small")
356
+ >>> input_ids = tokenizer(
357
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
358
+ ... ).input_ids # Batch size 1
359
+ >>> outputs = model(input_ids=input_ids)
360
+ >>> start_logits = outputs.start_logits
361
+ >>> end_logits = outputs.end_logits
362
+ ```"""
363
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
364
+
365
+ outputs = self.encoder(
366
+ input_ids,
367
+ attention_mask=attention_mask,
368
+ inputs_embeds=inputs_embeds,
369
+ )
370
+ sequence_output = outputs[0]
371
+
372
+ logits = self.qa_outputs(sequence_output)
373
+ start_logits, end_logits = logits.split(1, dim=-1)
374
+ start_logits = start_logits.squeeze(-1).contiguous()
375
+ end_logits = end_logits.squeeze(-1).contiguous()
376
+
377
+ total_loss = None
378
+ if start_positions is not None and end_positions is not None:
379
+ # If we are on multi-GPU, split add a dimension
380
+ if len(start_positions.size()) > 1:
381
+ start_positions = start_positions.squeeze(-1).to(start_logits.device)
382
+ if len(end_positions.size()) > 1:
383
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
384
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
385
+ ignored_index = start_logits.size(1)
386
+ start_positions = start_positions.clamp(0, ignored_index)
387
+ end_positions = end_positions.clamp(0, ignored_index)
388
+
389
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
390
+ start_loss = loss_fct(start_logits, start_positions)
391
+ end_loss = loss_fct(end_logits, end_positions)
392
+ total_loss = (start_loss + end_loss) / 2
393
+
394
+ if not return_dict:
395
+ output = (start_logits, end_logits) + outputs[1:]
396
+ return ((total_loss,) + output) if total_loss is not None else output
397
+
398
+ return QuestionAnsweringModelOutput(
399
+ loss=total_loss,
400
+ start_logits=start_logits,
401
+ end_logits=end_logits,
402
+ hidden_states=outputs.hidden_states,
403
+ attentions=outputs.attentions,
404
+ )
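
Since the new config.json no longer exposes these heads through `auto_map`, they would have to be imported explicitly. A hedged sketch for the token-classification head follows; the repo id and the `flash_t5` package name are placeholders, and the classifier weights are expected to be freshly initialized when loading a seq2seq checkpoint.

```python
# Hedged sketch, not part of the commit: using the encoder-only token-classification head directly.
import torch
from transformers import AutoConfig, AutoTokenizer

# Assumes the repo files are arranged as an importable package named flash_t5 (placeholder).
from flash_t5.custom_heads_flash_t5 import FlashT5ForTokenClassification

repo_id = "user/flash-t5-checkpoint"  # placeholder
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
config.num_labels = 9  # task-specific, e.g. an NER tag set

# The classification head is new, so its weights are randomly initialized on load.
model = FlashT5ForTokenClassification.from_pretrained(repo_id, config=config)

tokenizer = AutoTokenizer.from_pretrained(repo_id)
inputs = tokenizer("Gustave Eiffel a vécu à Paris", return_tensors="pt")
with torch.no_grad():
    out = model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)
print(out.logits.shape)  # (batch, seq_len, num_labels)
```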
generation_config.json CHANGED
@@ -3,5 +3,5 @@
  "decoder_start_token_id": 0,
  "eos_token_id": 1,
  "pad_token_id": 0,
- "transformers_version": "4.39.3"
+ "transformers_version": "4.37.2"
  }
modeling_flash_t5(1).py ADDED
@@ -0,0 +1,840 @@
1
+ # From: https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import copy
6
+ import math
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import torch
10
+ from torch import nn
11
+ from torch.nn import CrossEntropyLoss
12
+ import torch.nn.functional as F
13
+
14
+ from transformers.modeling_utils import ModuleUtilsMixin
15
+ from transformers.modeling_outputs import ModelOutput, Seq2SeqModelOutput, BaseModelOutput
16
+ from transformers import PreTrainedModel
17
+
18
+ try:
19
+ from .rms_norm import fast_rms_layernorm
20
+ except ImportError:
21
+ fast_rms_layernorm = None
22
+
23
+ try:
24
+ from .cross_entropy_loss import fast_cross_entropy_loss
25
+ except ImportError:
26
+ fast_cross_entropy_loss = None
27
+
28
+ try:
29
+ from .flash_attention_v2_bias import attention as flash_attention_triton
30
+ except ImportError:
31
+ flash_attention_triton = None
32
+
33
+ try:
34
+ from .gated_mlp import gated_mlp
35
+ except ImportError:
36
+ gated_mlp = None
37
+
38
+ try:
39
+ #from flash_attn import flash_attn_kvpacked_func, flash_attn_func
40
+ from .fa2_compilable import flash_attn_kvpacked_func, flash_attn_func
41
+ except ImportError:
42
+ flash_attn_kvpacked_func, flash_attn_func = None, None
43
+
44
+ from .attn_ref import attn_ref
45
+
46
+ from .configuration_flash_t5 import FlashT5Config
47
+ from .positional_encoding import ALiBiPositionalEncoding, RelativePositionalEncoding, RotaryPositionalEncoding
48
+
49
+ @dataclass
50
+ class EncoderOutput(ModelOutput):
51
+ hidden_states: torch.FloatTensor = None
52
+ attention_mask: torch.FloatTensor = None
53
+
54
+ @dataclass
55
+ class Seq2SeqLMOutput(ModelOutput):
56
+ loss: torch.FloatTensor = None
57
+ logits: torch.FloatTensor = None
58
+ encoder_outputs: EncoderOutput = None
59
+
60
+
61
+ class FlashT5CrossEntropyLoss(nn.Module):
62
+ def __init__(self, z_loss_factor=0.0, label_smoothing=0.0, use_triton_crossentropy=False):
63
+
64
+ super().__init__()
65
+
66
+ if use_triton_crossentropy and fast_cross_entropy_loss is None:
67
+ raise ImportError("fast_cross_entropy_loss is not available")
68
+
69
+ self.use_triton_crossentropy = use_triton_crossentropy
70
+ self.z_loss_factor = z_loss_factor
71
+
72
+ self.cross_entropy_loss = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
73
+
74
+ def compute_zloss(self, logits: torch.Tensor, z_loss: float):
75
+ logits_sum = torch.logsumexp(logits, dim=-1, keepdim=True)
76
+ log_z = torch.squeeze(logits_sum, axis=-1)
77
+ total_z_loss = z_loss * torch.square(log_z)
78
+ return total_z_loss.mean()
79
+
80
+ def forward(self, logits, labels):
81
+
82
+ if self.use_triton_crossentropy:
83
+ return fast_cross_entropy_loss(logits, labels, z_loss_factor=self.z_loss_factor)
84
+
85
+ # use standard method
86
+ batch, seq_len, d = logits.shape
87
+ logits_flatten = logits.float().view(batch*seq_len, d) # Must cast to float32 for numerical stability
88
+ labels_flatten = labels.view(-1)
89
+ loss = self.cross_entropy_loss(logits_flatten, labels_flatten)
90
+ z_loss = 0.0
91
+ if self.z_loss_factor != 0.0:
92
+ z_loss = self.compute_zloss(logits_flatten[labels_flatten != -100],
93
+ z_loss=self.z_loss_factor)
94
+ return loss, z_loss
95
+
96
+ class FlashT5LayerNorm(nn.Module):
97
+ def __init__(self, hidden_size, eps=1e-6, use_triton_layernorm=False):
98
+ """
99
+ Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
100
+ """
101
+ super().__init__()
102
+
103
+ if use_triton_layernorm and fast_rms_layernorm is None:
104
+ raise ImportError("fast_rms_layernorm is not available")
105
+
106
+ self.use_triton_layernorm = use_triton_layernorm
107
+ self.weight = nn.Parameter(torch.ones(hidden_size))
108
+ self.variance_epsilon = eps
109
+
110
+ def forward(self, hidden_states):
111
+
112
+ if self.use_triton_layernorm:
113
+ return fast_rms_layernorm(hidden_states, self.weight, self.variance_epsilon)
114
+
115
+ # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
116
+ # Square Layer Normalization https://arxiv.org/abs/1910.07467, thus variance is calculated
117
+ # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
118
+ # half-precision inputs is done in fp32
119
+
120
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
121
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
122
+
123
+ # convert into half-precision if necessary
124
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
125
+ hidden_states = hidden_states.to(self.weight.dtype)
126
+
127
+ return self.weight * hidden_states
128
+
129
+ class FlashT5DenseAct(nn.Module):
130
+ def __init__(self, config: FlashT5Config):
131
+ super().__init__()
132
+ self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
133
+ self.dropout = nn.Dropout(config.dropout_rate)
134
+ self.act = torch.nn.GELU(approximate='tanh') if config.use_gelu_act else torch.nn.ReLU()
135
+
136
+ def forward(self, hidden_states):
137
+ hidden_states = self.wi(hidden_states)
138
+ hidden_states = self.act(hidden_states)
139
+ hidden_states = self.dropout(hidden_states)
140
+ if (
141
+ hasattr(self, "wo") and isinstance(self.wo.weight, torch.Tensor)  # guard: this module defines no wo (it lives in FlashT5LayerFF)
142
+ and hidden_states.dtype != self.wo.weight.dtype
143
+ and self.wo.weight.dtype != torch.int8
144
+ ):
145
+ hidden_states = hidden_states.to(self.wo.weight.dtype)
146
+
147
+ return hidden_states
148
+
149
+ class FlashT5DenseGatedAct(nn.Module):
150
+ def __init__(self, config: FlashT5Config):
151
+ super().__init__()
152
+ self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
153
+ self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
154
+ self.dropout = nn.Dropout(config.dropout_rate)
155
+ self.act = torch.nn.GELU(approximate='tanh') if config.use_gelu_act else torch.nn.ReLU()
156
+
157
+ self.use_triton_gated_mlp = config.use_triton_gated_mlp
158
+ if self.use_triton_gated_mlp and gated_mlp is None:
159
+ raise ImportError("gated_mlp is not available")
160
+ self.use_gelu_act = config.use_gelu_act
161
+
162
+ def forward(self, hidden_states):
163
+
164
+ if self.use_triton_gated_mlp:
165
+ return gated_mlp(hidden_states, self.wi_0.weight, self.wi_1.weight, self.use_gelu_act)
166
+
167
+ hidden_act = self.act(self.wi_0(hidden_states))
168
+ hidden_linear = self.wi_1(hidden_states)
169
+ hidden_states = hidden_act * hidden_linear
170
+ hidden_states = self.dropout(hidden_states)
171
+
172
+ return hidden_states
173
+
174
+ class FlashT5LayerFF(nn.Module):
175
+ def __init__(self, config: FlashT5Config):
176
+ super().__init__()
177
+ if config.use_glu_mlp:
178
+ self.act = FlashT5DenseGatedAct(config)
179
+ else:
180
+ self.act = FlashT5DenseAct(config)
181
+
182
+ self.layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm)
183
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
184
+ self.dropout = nn.Dropout(config.dropout_rate)
185
+
186
+ def forward(self, hidden_states):
187
+ forwarded_states = self.layer_norm(hidden_states).type_as(hidden_states)
188
+ forwarded_states = self.act(forwarded_states)
189
+ forwarded_states = self.wo(forwarded_states)
190
+ hidden_states = hidden_states + self.dropout(forwarded_states)
191
+ return hidden_states
192
+
193
+
194
+ class FlashT5Attention(nn.Module, ModuleUtilsMixin):
195
+ def __init__(self, config: FlashT5Config, has_positional_encoding=False, is_causal=False):
196
+ super().__init__()
197
+ self.is_decoder = config.is_decoder
198
+ self.has_positional_encoding = has_positional_encoding
199
+ self.is_causal = is_causal
200
+ self.relative_attention_num_buckets = config.relative_attention_num_buckets
201
+ self.relative_attention_max_distance = config.relative_attention_max_distance
202
+ self.d_model = config.d_model
203
+ self.key_value_proj_dim = config.d_kv
204
+ self.n_heads = config.num_heads
205
+ self.p_dropout = config.attention_dropout_rate
206
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
207
+ self.use_flash_attention = config.use_flash_attention
208
+ self.position_encoding_type = config.position_encoding_type
209
+ self.max_sequence_length = config.max_sequence_length
210
+ self.softmax_scale = 1.0/math.sqrt(self.n_heads)
211
+ self.use_full_bias_size = config.use_full_bias_size
212
+
213
+ if self.use_flash_attention == "triton" and flash_attention_triton is None:
214
+ raise ImportError("flash_attention_triton is not available")
215
+ elif self.use_flash_attention == "fa2" and flash_attn_func is None:
216
+ raise ImportError("Flash Attention 2 is not available")
217
+
218
+ assert (self.p_dropout == 0.0) or (self.use_flash_attention != "triton"), "Triton attention does not support dropout"
219
+
220
+ self.pe_encoding = None
221
+ if self.position_encoding_type == "ALiBi" and has_positional_encoding:
222
+ # build alibi matrix with an upper bound on seq length
223
+ self.pe_encoding = ALiBiPositionalEncoding(self.max_sequence_length, self.n_heads, config.alibi_mode, config.use_randomized_position_encoding)
224
+ elif self.position_encoding_type == "t5" and has_positional_encoding:
225
+ self.pe_encoding = RelativePositionalEncoding(self.relative_attention_num_buckets, self.relative_attention_max_distance, self.n_heads, self.max_sequence_length, config.use_randomized_position_encoding)
226
+ elif self.position_encoding_type == "RoPE":
227
+ self.pe_encoding = RotaryPositionalEncoding(int(self.key_value_proj_dim * config.rotary_emb_fraction), self.max_sequence_length, config.rotary_base, config.rotary_interleaved, config.rotary_scale_base, config.use_randomized_position_encoding)
228
+
229
+ self.Wq = nn.Linear(self.d_model, self.inner_dim, bias=False)
230
+ self.Wk = nn.Linear(self.d_model, self.inner_dim, bias=False)
231
+ self.Wv = nn.Linear(self.d_model, self.inner_dim, bias=False)
232
+ self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
233
+
234
+ def forward(
235
+ self,
236
+ hidden_states,
237
+ mask=None,
238
+ key_value_states=None,
239
+ position_bias=None,
240
+ ):
241
+ """
242
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
243
+ """
244
+ # Input is (batch_size, seq_length, dim)
245
+ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
246
+ batch_size, seq_length = hidden_states.shape[:2]
247
+ key_length = seq_length if key_value_states is None else key_value_states.shape[1]
248
+ q = self.Wq(hidden_states)
249
+ if key_value_states is None:
250
+ k = self.Wk(hidden_states)
251
+ v = self.Wv(hidden_states)
252
+ else:
253
+ k = self.Wk(key_value_states)
254
+ v = self.Wv(key_value_states)
255
+
256
+ q = q.view(batch_size, seq_length, self.n_heads, self.key_value_proj_dim)
257
+ k = k.view(batch_size, key_length, self.n_heads, self.key_value_proj_dim)
258
+ v = v.view(batch_size, key_length, self.n_heads, self.key_value_proj_dim)
259
+
260
+ if position_bias is None and self.pe_encoding is not None:
261
+ q, k, v, position_bias = self.pe_encoding(q, k, v)
262
+
263
+ if position_bias is not None and self.use_full_bias_size and (self.use_flash_attention == "fa2" or self.use_flash_attention == "triton"):
264
+ position_bias = position_bias.expand(q.shape[0], q.shape[2], q.shape[1], k.shape[1]).contiguous()
265
+
266
+ if self.use_flash_attention == "fa2":
267
+ output = flash_attn_func(q, k, v, dropout_p=self.p_dropout, softmax_scale=self.softmax_scale, attn_bias=position_bias, causal=self.is_causal)
268
+ elif self.use_flash_attention == "triton":
269
+ q = q.permute(0, 2, 1, 3)
270
+ k = k.permute(0, 2, 1, 3)
271
+ v = v.permute(0, 2, 1, 3)
272
+ output = flash_attention_triton(q, k, v, position_bias, self.is_causal, self.softmax_scale)
273
+ output = output.permute(0, 2, 1, 3)
274
+ else: # use flash attention
275
+ q = q.permute(0, 2, 1, 3)
276
+ k = k.permute(0, 2, 1, 3)
277
+ v = v.permute(0, 2, 1, 3)
278
+ output = attn_ref(q, k, v, position_bias, dropout_p=self.p_dropout, sm_scale=self.softmax_scale, causal=self.is_causal)
279
+ output = output.permute(0, 2, 1, 3)
280
+
281
+ output = self.o(output.reshape(output.shape[0], output.shape[1], self.inner_dim))
282
+ return (output, position_bias)
283
+
284
+
285
+ class FlashT5LayerSelfAttention(nn.Module):
286
+ def __init__(self, config, has_positional_encoding=False):
287
+ super().__init__()
288
+ self.self_attention = FlashT5Attention(config, has_positional_encoding=has_positional_encoding, is_causal=config.is_decoder)
289
+ self.layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm)
290
+ self.dropout = nn.Dropout(config.dropout_rate)
291
+
292
+ def forward(
293
+ self,
294
+ hidden_states,
295
+ attention_mask=None,
296
+ position_bias=None,
297
+ ):
298
+ normed_hidden_states = self.layer_norm(hidden_states).type_as(hidden_states)
299
+ attention_output = self.self_attention(
300
+ normed_hidden_states,
301
+ mask=attention_mask,
302
+ position_bias=position_bias,
303
+ )
304
+ hidden_states = hidden_states + self.dropout(attention_output[0])
305
+ outputs = (hidden_states,) + attention_output[1:]
306
+ return outputs
307
+
308
+
309
+ class FlashT5LayerCrossAttention(nn.Module):
310
+ def __init__(self, config):
311
+ super().__init__()
312
+ self.cross_attention = FlashT5Attention(config, has_positional_encoding=False)
313
+ self.layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm)
314
+ self.dropout = nn.Dropout(config.dropout_rate)
315
+
316
+ def forward(
317
+ self,
318
+ hidden_states,
319
+ key_value_states,
320
+ attention_mask=None,
321
+ position_bias=None,
322
+ ):
323
+ normed_hidden_states = self.layer_norm(hidden_states)
324
+ attention_output = self.cross_attention(
325
+ normed_hidden_states,
326
+ mask=attention_mask,
327
+ key_value_states=key_value_states,
328
+ position_bias=position_bias,
329
+ )
330
+ layer_output = hidden_states + self.dropout(attention_output[0])
331
+ outputs = (layer_output,) + attention_output[1:]
332
+ return outputs
333
+
334
+
335
+ class FlashT5Block(nn.Module):
336
+ def __init__(self, config, has_positional_encoding=False):
337
+ super().__init__()
338
+ self.is_decoder = config.is_decoder
339
+
340
+ self.self_attention_layer = FlashT5LayerSelfAttention(config, has_positional_encoding=has_positional_encoding)
341
+
342
+ if self.is_decoder:
343
+ self.cross_attention_layer = FlashT5LayerCrossAttention(config)
344
+
345
+ self.ff_layer = FlashT5LayerFF(config)
346
+
347
+ def forward(
348
+ self,
349
+ hidden_states,
350
+ attention_mask=None,
351
+ position_bias=None,
352
+ encoder_hidden_states=None,
353
+ encoder_attention_mask=None,
354
+ encoder_decoder_position_bias=None,
355
+ ):
356
+ self_attention_outputs = self.self_attention_layer(
357
+ hidden_states,
358
+ attention_mask=attention_mask,
359
+ position_bias=position_bias,
360
+ )
361
+ hidden_states = self_attention_outputs[0]
362
+ attention_outputs = self_attention_outputs[1:] # Relative position weights
363
+
364
+ if self.is_decoder and encoder_hidden_states is not None:
365
+ cross_attention_outputs = self.cross_attention_layer(
366
+ hidden_states,
367
+ key_value_states=encoder_hidden_states,
368
+ attention_mask=encoder_attention_mask,
369
+ position_bias=encoder_decoder_position_bias,
370
+ )
371
+ hidden_states = cross_attention_outputs[0]
372
+
373
+ # Keep relative position weights
374
+ attention_outputs = attention_outputs + cross_attention_outputs[1:]
375
+
376
+ # Apply Feed Forward layer
377
+ hidden_states = self.ff_layer(hidden_states)
378
+
379
+ outputs = (hidden_states,) + attention_outputs
380
+ return outputs # hidden-states, (self-attention position bias), (cross-attention position bias)
381
+
382
+ class FlashT5Stack(nn.Module, ModuleUtilsMixin):
383
+ def __init__(self, config, embed_tokens):
384
+ super().__init__()
385
+ assert embed_tokens is not None
386
+
387
+ self.config = config
388
+ self.embed_tokens = embed_tokens
389
+ self.is_decoder = config.is_decoder
390
+ self.use_flash_attention = config.use_flash_attention
391
+
392
+ self.block = nn.ModuleList(
393
+ [FlashT5Block(config, has_positional_encoding=bool(i == 0)) for i in range(config.num_layers)]
394
+ )
395
+
396
+ self.final_layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm)
397
+ self.dropout = nn.Dropout(config.dropout_rate)
398
+
399
+ def forward(
400
+ self,
401
+ input_ids=None,
402
+ attention_mask=None,
403
+ encoder_hidden_states=None,
404
+ encoder_attention_mask=None,
405
+ inputs_embeds=None,
406
+ head_mask=None,
407
+ cross_attn_head_mask=None,
408
+ past_key_values=None,
409
+ use_cache=None,
410
+ output_attentions=None,
411
+ output_hidden_states=None,
412
+ return_dict=None) -> BaseModelOutput:
413
+ input_shape = input_ids.size()
414
+ batch_size, seq_length = input_shape
415
+
416
+ if inputs_embeds is None:
417
+ inputs_embeds = self.embed_tokens(input_ids)
418
+
419
+ if torch.is_autocast_enabled() and input_ids.device.type == 'cuda':
420
+ inputs_embeds = inputs_embeds.to(torch.get_autocast_gpu_dtype())
421
+
422
+ # Masking
423
+ if attention_mask is None:
424
+ attention_mask = torch.ones(batch_size, seq_length, device=inputs_embeds.device, dtype=torch.bool)
425
+
426
+ if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
427
+ encoder_seq_length = encoder_hidden_states.shape[1]
428
+ encoder_attention_mask = torch.ones(
429
+ batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.bool
430
+ )
431
+
432
+ position_bias = None
433
+ encoder_decoder_position_bias = None
434
+
435
+ hidden_states = self.dropout(inputs_embeds)
436
+
437
+ for _, layer_module in enumerate(self.block):
438
+ layer_outputs = layer_module(
439
+ hidden_states,
440
+ attention_mask=attention_mask,
441
+ position_bias=position_bias,
442
+ encoder_hidden_states=encoder_hidden_states,
443
+ encoder_attention_mask=encoder_attention_mask,
444
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
445
+ )
446
+
447
+ # We share the position biases between the layers - the first layer stores them
448
+ position_bias = layer_outputs[1]
449
+ if self.is_decoder and encoder_hidden_states is not None:
450
+ encoder_decoder_position_bias = layer_outputs[2]
451
+
452
+ hidden_states = layer_outputs[0]
453
+
454
+ hidden_states = self.final_layer_norm(hidden_states).type_as(hidden_states)
455
+ hidden_states = self.dropout(hidden_states)
456
+
457
+ return BaseModelOutput(
458
+ last_hidden_state=hidden_states
459
+ )
460
+
461
+
462
+ class FlashT5PreTrainedModel(PreTrainedModel):
463
+ """
464
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
465
+ models.
466
+ """
467
+
468
+ config_class = FlashT5Config
469
+ base_model_prefix = "transformer"
470
+ is_parallelizable = False
471
+ supports_gradient_checkpointing = True
472
+ _no_split_modules = ["FlashT5Block"]
473
+ _keep_in_fp32_modules = []
474
+
475
+ def _init_weights(self, module):
476
+ factor = self.config.initializer_factor # Used for testing weights initialization
477
+ if isinstance(module, FlashT5LayerNorm):
478
+ module.weight.data.fill_(factor * 1.0)
479
+ elif isinstance(module, (FlashT5ForConditionalGeneration)):
480
+ module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
481
+ if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
482
+ module.lm_head.weight.data.normal_(mean=0.0, std=factor * self.config.d_model ** -0.5)
483
+ elif isinstance(module, FlashT5DenseGatedAct):
484
+ d_ff, d_model = module.wi_0.weight.data.size()
485
+ module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
486
+ module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
487
+ elif isinstance(module, FlashT5LayerFF):
488
+ d_ff, d_model = module.wo.weight.data.size()
489
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((d_ff) ** -0.5))
490
+ elif isinstance(module, FlashT5Attention):
491
+ d_model = self.config.d_model
492
+ key_value_proj_dim = self.config.d_kv
493
+ n_heads = self.config.num_heads
494
+ module.Wq.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
495
+ module.Wk.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
496
+ module.Wv.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
497
+ module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
498
+ if module.has_positional_encoding:
499
+ if hasattr(module.pe_encoding, "relative_attention_bias"):
500
+ module.pe_encoding.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
501
+
502
+ def _shift_right(self, input_ids):
503
+ decoder_start_token_id = self.config.decoder_start_token_id
504
+ pad_token_id = self.config.pad_token_id
505
+
506
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
507
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
508
+ shifted_input_ids[..., 0] = decoder_start_token_id
509
+
510
+ # replace possible -100 values in labels by `pad_token_id`
511
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
512
+
513
+ return shifted_input_ids
514
+
515
+
516
+ class FlashT5Model(FlashT5PreTrainedModel):
517
+
518
+ def __init__(self, config: FlashT5Config):
519
+ super().__init__(config)
520
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
521
+
522
+ encoder_config = copy.deepcopy(config)
523
+ encoder_config.is_decoder = False
524
+ encoder_config.use_cache = False
525
+ encoder_config.is_encoder_decoder = False
526
+ self.encoder = FlashT5Stack(encoder_config, self.shared)
527
+
528
+ decoder_config = copy.deepcopy(config)
529
+ decoder_config.is_decoder = True
530
+ decoder_config.is_encoder_decoder = False
531
+ decoder_config.num_layers = config.num_decoder_layers
532
+ self.decoder = FlashT5Stack(decoder_config, self.shared)
533
+
534
+ # Initialize weights and apply final processing
535
+ self.post_init()
536
+
537
+ # Model parallel
538
+ self.model_parallel = False
539
+ self.device_map = None
540
+
541
+ def get_input_embeddings(self):
542
+ return self.shared
543
+
544
+ def set_input_embeddings(self, new_embeddings):
545
+ self.shared = new_embeddings
546
+ self.encoder.set_input_embeddings(new_embeddings)
547
+ self.decoder.set_input_embeddings(new_embeddings)
548
+
549
+ def get_encoder(self):
550
+ return self.encoder
551
+
552
+ def get_decoder(self):
553
+ return self.decoder
554
+
555
+ def forward(
556
+ self,
557
+ input_ids: Optional[torch.LongTensor] = None,
558
+ attention_mask: Optional[torch.FloatTensor] = None,
559
+ decoder_input_ids: Optional[torch.LongTensor] = None,
560
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
561
+ head_mask: Optional[torch.FloatTensor] = None,
562
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
563
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
564
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
565
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
566
+ inputs_embeds: Optional[torch.Tensor] = None,
567
+ decoder_inputs_embeds: Optional[torch.Tensor] = None,
568
+ use_cache: Optional[bool] = None,
569
+ output_attentions: Optional[bool] = None,
570
+ output_hidden_states: Optional[bool] = None,
571
+ return_dict: Optional[bool] = None,
572
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
573
+
574
+ # Encode if needed (training, first prediction pass)
575
+ if encoder_outputs is None:
576
+ encoder_outputs = self.encoder(
577
+ input_ids=input_ids,
578
+ attention_mask=attention_mask,
579
+ inputs_embeds=inputs_embeds
580
+ )
581
+
582
+ hidden_states = encoder_outputs[0]
583
+
584
+ # Decode
585
+ decoder_outputs = self.decoder(
586
+ input_ids=decoder_input_ids,
587
+ attention_mask=decoder_attention_mask,
588
+ inputs_embeds=decoder_inputs_embeds,
589
+ encoder_hidden_states=hidden_states,
590
+ encoder_attention_mask=attention_mask
591
+ )
592
+
593
+ return Seq2SeqModelOutput(
594
+ last_hidden_state=decoder_outputs.last_hidden_state,
595
+ decoder_hidden_states=decoder_outputs.hidden_states,
596
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
597
+ encoder_hidden_states=encoder_outputs.hidden_states,
598
+ )
599
+
600
+ class FlashT5ForConditionalGeneration(FlashT5PreTrainedModel):
601
+
602
+ def __init__(self, config: FlashT5Config):
603
+ super().__init__(config)
604
+ config.is_encoder_decoder = False
605
+ assert not config.tie_word_embeddings
606
+
607
+ self.config = config
608
+ self.model_dim = config.d_model
609
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
610
+
611
+ encoder_config = copy.deepcopy(config)
612
+ encoder_config.is_decoder = False
613
+ self.encoder = FlashT5Stack(encoder_config, self.shared)
614
+
615
+ decoder_config = copy.deepcopy(config)
616
+ decoder_config.is_decoder = True
617
+ decoder_config.num_layers = config.num_decoder_layers
618
+ self.decoder = FlashT5Stack(decoder_config, self.shared)
619
+
620
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
621
+
622
+ self.loss_fct = FlashT5CrossEntropyLoss(z_loss_factor=config.z_loss,
623
+ label_smoothing=config.label_smoothing,
624
+ use_triton_crossentropy=config.use_triton_crossentropy)
625
+
626
+ # Initialize weights and apply final processing
627
+ self.post_init()
628
+
629
+ def prepare_inputs_for_generation(
630
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
631
+ ):
632
+ # do nothing
633
+ model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
634
+
635
+ return model_inputs
636
+
637
+ def get_input_embeddings(self):
638
+ return self.shared
639
+
640
+ def set_input_embeddings(self, value):
641
+ self.shared = value
642
+
643
+ def generate(
644
+ self,
645
+ input_ids: Optional[torch.LongTensor] = None,
646
+ attention_mask: Optional[torch.FloatTensor] = None,
647
+ max_length = 32,
648
+ **kwargs,
649
+ ) -> torch.LongTensor:
650
+ """
651
+ input_ids: B x L_encoder, int64
652
+ attention_mask: B x L_encoder, int64
653
+ 1 for tokens to attend to, 0 for tokens to ignore
654
+
655
+ Generation:
656
+ Starts with 0, ends with 1, padding is 0
657
+
658
+ # For 20 input/outputs, the diff between my implementation and HF is 9.8s vs 11.4s
659
+ """
660
+ B, _ = input_ids.size()
661
+ labels = torch.zeros(B, 1, dtype=torch.long, device=input_ids.device)
662
+ encoder_outputs = None
663
+
664
+ for _ in range(max_length):
665
+ out = self.forward(
666
+ input_ids=input_ids,
667
+ attention_mask=attention_mask,
668
+ decoder_input_ids=labels,
669
+ encoder_outputs=encoder_outputs,
670
+ )
671
+ encoder_outputs = out.encoder_outputs
672
+ top_labels = out.logits[:, -1].argmax(-1).unsqueeze(-1)
673
+ labels = torch.cat([labels, top_labels], dim=-1)
674
+
675
+ if (labels == 1).sum(-1).clamp(min=0, max=1).sum().item() == B:
676
+ break
677
+
678
+ labels[:, -1] = 1
679
+
680
+ # Mask out the padding, i.e., all positions after the first 1 with 0
681
+ B, L = labels.size()
682
+ mask = torch.arange(L, device=labels.device).unsqueeze(0) <= (labels == 1).long().argmax(-1).unsqueeze(-1)
683
+ labels = labels.masked_fill(~mask, 0)
684
+
685
+ return labels
686
+
687
+ def forward(
688
+ self,
689
+ input_ids: Optional[torch.LongTensor] = None,
690
+ attention_mask: Optional[torch.FloatTensor] = None,
691
+ decoder_input_ids: Optional[torch.LongTensor] = None,
692
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
693
+ labels: Optional[torch.LongTensor] = None,
694
+ encoder_outputs = None,
695
+ ) -> Seq2SeqLMOutput:
696
+ """
697
+ input_ids: B x L_encoder, int64
698
+ attention_mask: B x L_encoder, int64
699
+ 1 for tokens to attend to, 0 for tokens to ignore
700
+ labels: B x L_decoder, int64
701
+ """
702
+ if encoder_outputs is None:
703
+ encoder_outputs = self.encoder(
704
+ input_ids=input_ids,
705
+ attention_mask=attention_mask,
706
+ )
707
+
708
+ hidden_states = encoder_outputs.hidden_states
709
+
710
+ if labels is not None and decoder_input_ids is None:
711
+ decoder_input_ids = self._shift_right(labels)
712
+
713
+ decoder_outputs = self.decoder(
714
+ input_ids=decoder_input_ids,
715
+ attention_mask=decoder_attention_mask,
716
+ encoder_hidden_states=hidden_states,
717
+ encoder_attention_mask=attention_mask,
718
+ )
719
+
720
+ sequence_output = decoder_outputs[0]
721
+ lm_logits = self.lm_head(sequence_output)
722
+
723
+ loss = None
724
+ if labels is not None:
725
+ loss, z_loss = self.loss_fct(lm_logits, labels)
726
+ loss += z_loss
727
+
728
+ return Seq2SeqLMOutput(
729
+ loss=loss,
730
+ logits=lm_logits,
731
+ encoder_outputs=encoder_outputs,
732
+ )
733
+
734
+
735
+
736
+ class FlashT5EncoderModel(FlashT5PreTrainedModel):
737
+ _tied_weights_keys = ["encoder.embed_tokens.weight"]
738
+
739
+ def __init__(self, config: FlashT5Config):
740
+ super().__init__(config)
741
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
742
+
743
+ encoder_config = copy.deepcopy(config)
744
+ encoder_config.use_cache = False
745
+ encoder_config.is_encoder_decoder = False
746
+ self.encoder = FlashT5Stack(encoder_config, self.shared)
747
+
748
+ # Initialize weights and apply final processing
749
+ self.post_init()
750
+
751
+ # Model parallel
752
+ self.model_parallel = False
753
+ self.device_map = None
754
+
755
+
756
+ def parallelize(self, device_map=None):
757
+ warnings.warn(
758
+ "`T5EncoderModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
759
+ " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
760
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
761
+ " 'block.1': 1, ...}",
762
+ FutureWarning,
763
+ )
764
+ self.device_map = (
765
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
766
+ if device_map is None
767
+ else device_map
768
+ )
769
+ assert_device_map(self.device_map, len(self.encoder.block))
770
+ self.encoder.parallelize(self.device_map)
771
+ self.model_parallel = True
772
+
773
+ def deparallelize(self):
774
+ warnings.warn(
775
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
776
+ FutureWarning,
777
+ )
778
+ self.encoder.deparallelize()
779
+ self.encoder = self.encoder.to("cpu")
780
+ self.model_parallel = False
781
+ self.device_map = None
782
+ torch.cuda.empty_cache()
783
+
784
+ def get_input_embeddings(self):
785
+ return self.shared
786
+
787
+ def set_input_embeddings(self, new_embeddings):
788
+ self.shared = new_embeddings
789
+ self.encoder.set_input_embeddings(new_embeddings)
790
+
791
+ def get_encoder(self):
792
+ return self.encoder
793
+
794
+ def _prune_heads(self, heads_to_prune):
795
+ """
796
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
797
+ class PreTrainedModel
798
+ """
799
+ for layer, heads in heads_to_prune.items():
800
+ self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
801
+
802
+ def forward(
803
+ self,
804
+ input_ids: Optional[torch.LongTensor] = None,
805
+ attention_mask: Optional[torch.FloatTensor] = None,
806
+ head_mask: Optional[torch.FloatTensor] = None,
807
+ inputs_embeds: Optional[torch.FloatTensor] = None,
808
+ output_attentions: Optional[bool] = None,
809
+ output_hidden_states: Optional[bool] = None,
810
+ return_dict: Optional[bool] = None,
811
+ ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
812
+ r"""
813
+ Returns:
814
+
815
+ Example:
816
+
817
+ ```python
818
+ >>> from transformers import AutoTokenizer, T5EncoderModel
819
+
820
+ >>> tokenizer = AutoTokenizer.from_pretrained("t5-small")
821
+ >>> model = T5EncoderModel.from_pretrained("t5-small")
822
+ >>> input_ids = tokenizer(
823
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
824
+ ... ).input_ids # Batch size 1
825
+ >>> outputs = model(input_ids=input_ids)
826
+ >>> last_hidden_states = outputs.last_hidden_state
827
+ ```"""
828
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
829
+
830
+ encoder_outputs = self.encoder(
831
+ input_ids=input_ids,
832
+ attention_mask=attention_mask,
833
+ inputs_embeds=inputs_embeds,
834
+ head_mask=head_mask,
835
+ output_attentions=output_attentions,
836
+ output_hidden_states=output_hidden_states,
837
+ return_dict=return_dict,
838
+ )
839
+
840
+ return encoder_outputs
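
The `generate` method above is plain greedy decoding: it seeds the decoder with token 0, appends the argmax token at each step, stops once every sequence in the batch has produced the EOS id 1, then zeroes out everything after the first EOS. The final masking step, reproduced standalone on toy label tensors:

```python
import torch

# Toy decoded sequences: 0 is decoder start / padding, 1 is EOS.
labels = torch.tensor([
    [0, 5, 9, 1, 7, 3],  # EOS at position 3, the trailing tokens should become padding
    [0, 4, 1, 2, 2, 2],  # EOS at position 2
])

# Same masking as at the end of FlashT5ForConditionalGeneration.generate():
# keep positions up to and including the first EOS, zero out the rest.
B, L = labels.size()
mask = torch.arange(L, device=labels.device).unsqueeze(0) <= (labels == 1).long().argmax(-1).unsqueeze(-1)
labels = labels.masked_fill(~mask, 0)
print(labels)
# tensor([[0, 5, 9, 1, 0, 0],
#         [0, 4, 1, 0, 0, 0]])
```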
special_tokens_map.json ADDED
@@ -0,0 +1,110 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<extra_id_0>",
4
+ "<extra_id_1>",
5
+ "<extra_id_2>",
6
+ "<extra_id_3>",
7
+ "<extra_id_4>",
8
+ "<extra_id_5>",
9
+ "<extra_id_6>",
10
+ "<extra_id_7>",
11
+ "<extra_id_8>",
12
+ "<extra_id_9>",
13
+ "<extra_id_10>",
14
+ "<extra_id_11>",
15
+ "<extra_id_12>",
16
+ "<extra_id_13>",
17
+ "<extra_id_14>",
18
+ "<extra_id_15>",
19
+ "<extra_id_16>",
20
+ "<extra_id_17>",
21
+ "<extra_id_18>",
22
+ "<extra_id_19>",
23
+ "<extra_id_20>",
24
+ "<extra_id_21>",
25
+ "<extra_id_22>",
26
+ "<extra_id_23>",
27
+ "<extra_id_24>",
28
+ "<extra_id_25>",
29
+ "<extra_id_26>",
30
+ "<extra_id_27>",
31
+ "<extra_id_28>",
32
+ "<extra_id_29>",
33
+ "<extra_id_30>",
34
+ "<extra_id_31>",
35
+ "<extra_id_32>",
36
+ "<extra_id_33>",
37
+ "<extra_id_34>",
38
+ "<extra_id_35>",
39
+ "<extra_id_36>",
40
+ "<extra_id_37>",
41
+ "<extra_id_38>",
42
+ "<extra_id_39>",
43
+ "<extra_id_40>",
44
+ "<extra_id_41>",
45
+ "<extra_id_42>",
46
+ "<extra_id_43>",
47
+ "<extra_id_44>",
48
+ "<extra_id_45>",
49
+ "<extra_id_46>",
50
+ "<extra_id_47>",
51
+ "<extra_id_48>",
52
+ "<extra_id_49>",
53
+ "<extra_id_50>",
54
+ "<extra_id_51>",
55
+ "<extra_id_52>",
56
+ "<extra_id_53>",
57
+ "<extra_id_54>",
58
+ "<extra_id_55>",
59
+ "<extra_id_56>",
60
+ "<extra_id_57>",
61
+ "<extra_id_58>",
62
+ "<extra_id_59>",
63
+ "<extra_id_60>",
64
+ "<extra_id_61>",
65
+ "<extra_id_62>",
66
+ "<extra_id_63>",
67
+ "<extra_id_64>",
68
+ "<extra_id_65>",
69
+ "<extra_id_66>",
70
+ "<extra_id_67>",
71
+ "<extra_id_68>",
72
+ "<extra_id_69>",
73
+ "<extra_id_70>",
74
+ "<extra_id_71>",
75
+ "<extra_id_72>",
76
+ "<extra_id_73>",
77
+ "<extra_id_74>",
78
+ "<extra_id_75>",
79
+ "<extra_id_76>",
80
+ "<extra_id_77>",
81
+ "<extra_id_78>",
82
+ "<extra_id_79>",
83
+ "<extra_id_80>",
84
+ "<extra_id_81>",
85
+ "<extra_id_82>",
86
+ "<extra_id_83>",
87
+ "<extra_id_84>",
88
+ "<extra_id_85>",
89
+ "<extra_id_86>",
90
+ "<extra_id_87>",
91
+ "<extra_id_88>",
92
+ "<extra_id_89>",
93
+ "<extra_id_90>",
94
+ "<extra_id_91>",
95
+ "<extra_id_92>",
96
+ "<extra_id_93>",
97
+ "<extra_id_94>",
98
+ "<extra_id_95>",
99
+ "<extra_id_96>",
100
+ "<extra_id_97>",
101
+ "<extra_id_98>",
102
+ "<extra_id_99>"
103
+ ],
104
+ "cls_token": "<cls>",
105
+ "eos_token": "</s>",
106
+ "mask_token": "<mask>",
107
+ "pad_token": "<pad>",
108
+ "sep_token": "<sep>",
109
+ "unk_token": "<unk>"
110
+ }
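
These follow the T5 convention: `<extra_id_0>` through `<extra_id_99>` are sentinel tokens that mark corrupted spans during pretraining, alongside the usual pad/eos/unk tokens and the extra cls/sep/mask tokens. A hedged sketch of how the sentinels appear in a span-corruption input/target pair (tokenizer call only; the repo id is a placeholder):

```python
from transformers import AutoTokenizer

repo_id = "user/flash-t5-checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(repo_id)

# T5-style span corruption: each masked span in the input is replaced by a sentinel,
# and the target lists each sentinel followed by the tokens it hides.
source = "<extra_id_0> est la capitale de la France."
target = "<extra_id_0> Paris <extra_id_1>"

batch = tokenizer(source, text_target=target, return_tensors="pt")
print(batch.input_ids.shape, batch.labels.shape)
```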