bourdoiscatie committed on
Commit 18bc4bc · verified · 1 Parent(s): c98d82a

Delete custom_heads_flash_t5(1).py

Files changed (1)
  1. custom_heads_flash_t5(1).py +0 -404
custom_heads_flash_t5(1).py DELETED
@@ -1,404 +0,0 @@
-import torch
-import torch.nn as nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-import copy
-from typing import Optional, Union, Tuple, List
-from transformers.modeling_outputs import (
-    Seq2SeqQuestionAnsweringModelOutput,
-    QuestionAnsweringModelOutput,
-    TokenClassifierOutput,
-    BaseModelOutput,
-    Seq2SeqSequenceClassifierOutput,
-    SequenceClassifierOutput
-)
-
-from .modeling_flash_t5 import FlashT5PreTrainedModel, FlashT5Stack, FlashT5Model, FlashT5EncoderModel
-from .configuration_flash_t5 import FlashT5Config
-
-
-################## Encoder only head ##################
-class FlashT5ForTokenClassification(FlashT5PreTrainedModel):
-
-    def __init__(self, config: FlashT5Config):
-        super().__init__(config)
-        self.num_labels = config.num_labels
-        self.shared = nn.Embedding(config.vocab_size, config.d_model)
-
-        self.encoder = FlashT5Stack(config, self.shared)
-        self.dropout = nn.Dropout(config.classifier_dropout)
-        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-        # Initialize classifier
-        self.classifier.weight.data.normal_(mean=0.0, std=config.initializer_factor * 1.0)
-        self.classifier.bias.data.zero_()
-
-        self.model_parallel = False
-
-    def forward(
-        self,
-        input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        labels: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
-        Returns:
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        outputs = self.encoder(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            inputs_embeds=inputs_embeds,
-            head_mask=head_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        hidden_states = outputs[0]
-        hidden_states = self.dropout(hidden_states)
-        logits = self.classifier(hidden_states)
-
-        loss = None
-        if labels is not None:
-            loss_fct = nn.CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-
-        if not return_dict:
-            output = (logits,) + outputs[2:-1]  # keep the returned tuple flat, as in the other heads below
-            return ((loss,) + output) if loss is not None else output
-
-        return TokenClassifierOutput(
-            loss=loss,
-            logits=logits,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
-
-
-class FlashT5ClassificationHead(nn.Module):
-    """Head for sentence-level classification tasks."""
-
-    def __init__(self, config: FlashT5Config):
-        super().__init__()
-        self.dense = nn.Linear(config.d_model, config.d_model)
-        self.dropout = nn.Dropout(p=config.classifier_dropout)
-        self.out_proj = nn.Linear(config.d_model, config.num_labels)
-
-        # initialize weights
-        factor = config.initializer_factor
-        self.dense.weight.data.normal_(mean=0.0, std=factor * ((config.d_model) ** -0.5))
-        if hasattr(self.dense, "bias") and self.dense.bias is not None:
-            self.dense.bias.data.zero_()
-        self.out_proj.weight.data.normal_(mean=0.0, std=factor * ((config.d_model) ** -0.5))
-        if hasattr(self.out_proj, "bias") and self.out_proj.bias is not None:
-            self.out_proj.bias.data.zero_()
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.dropout(hidden_states)
-        hidden_states = self.dense(hidden_states)
-        hidden_states = torch.tanh(hidden_states)
-        hidden_states = self.dropout(hidden_states)
-        hidden_states = self.out_proj(hidden_states)
-        return hidden_states
-
-
-class FlashT5ForSequenceClassification(FlashT5PreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
-
-    def __init__(self, config: FlashT5Config):
-        super().__init__(config)
-        self.model_dim = config.d_model
-        self.config.problem_type = None
-        self.config.is_encoder_decoder = False
-
-        self.shared = nn.Embedding(config.vocab_size, config.d_model)
-
-        encoder_config = copy.deepcopy(config)
-        encoder_config.is_decoder = False
-        encoder_config.is_encoder_decoder = False
-        encoder_config.use_cache = False
-        self.encoder = FlashT5Stack(encoder_config, self.shared)
-        self.classification_head = FlashT5ClassificationHead(config)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-        self.model_parallel = False
-
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        cross_attn_head_mask: Optional[torch.Tensor] = None,
-        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
-            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
-        Returns:
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        if labels is not None:
-            use_cache = False
-
-        if input_ids is None and inputs_embeds is not None:
-            raise NotImplementedError(
-                f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
-            )
-
-
-        outputs = self.encoder(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            inputs_embeds=inputs_embeds,
-            head_mask=head_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        sequence_output = outputs[0]
-
-        eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device)
-
-        if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
-            raise ValueError("All examples must have the same number of <eos> tokens.")
-        batch_size, _, hidden_size = sequence_output.shape
-        sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :]
-        logits = self.classification_head(sentence_representation)
-
-        loss = None
-        if labels is not None:
-            labels = labels.to(logits.device)
-            if self.config.problem_type is None:
-                if self.config.num_labels == 1:
-                    self.config.problem_type = "regression"
-                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
-                    self.config.problem_type = "single_label_classification"
-                else:
-                    self.config.problem_type = "multi_label_classification"
-
-            if self.config.problem_type == "regression":
-                loss_fct = nn.MSELoss()
-                if self.config.num_labels == 1:
-                    loss = loss_fct(logits.squeeze(), labels.squeeze())
-                else:
-                    loss = loss_fct(logits, labels)
-            elif self.config.problem_type == "single_label_classification":
-                loss_fct = nn.CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
-            elif self.config.problem_type == "multi_label_classification":
-                loss_fct = nn.BCEWithLogitsLoss()
-                loss = loss_fct(logits, labels)
-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
-        return SequenceClassifierOutput(
-            loss=loss,
-            logits=logits,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions
-        )
-
-
-class FlashT5ForQuestionAnswering(FlashT5PreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
-
-    def __init__(self, config: FlashT5Config):
-        super().__init__(config)
-        self.transformer = FlashT5EncoderModel(config)
-
-        self.num_labels = config.num_labels
-        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-        # Model parallel
-        self.model_parallel = False
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        start_positions: Optional[torch.Tensor] = None,
-        end_positions: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple[torch.FloatTensor], QuestionAnsweringModelOutput]:
-        r"""
-        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels for position (index) of the start of the labelled span for computing the token classification loss.
-            Positions are clamped to the length of the sequence (*sequence_length*). Positions outside of the sequence
-            are not taken into account for computing the loss.
-        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels for position (index) of the end of the labelled span for computing the token classification loss.
-            Positions are clamped to the length of the sequence (*sequence_length*). Positions outside of the sequence
-            are not taken into account for computing the loss.
-
-        Returns:
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        encoder_outputs = self.transformer(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            inputs_embeds=inputs_embeds,
-            head_mask=head_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        sequence_output = encoder_outputs[0]
-
-        logits = self.qa_outputs(sequence_output)
-        start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1).contiguous()
-        end_logits = end_logits.squeeze(-1).contiguous()
-
-        total_loss = None
-        if start_positions is not None and end_positions is not None:
-            # If we are on multi-GPU, split adds a dimension
-            if len(start_positions.size()) > 1:
-                start_positions = start_positions.squeeze(-1).to(start_logits.device)
-            if len(end_positions.size()) > 1:
-                end_positions = end_positions.squeeze(-1).to(end_logits.device)
-            # sometimes the start/end positions are outside our model inputs, we ignore these terms
-            ignored_index = start_logits.size(1)
-            start_positions = start_positions.clamp(0, ignored_index)
-            end_positions = end_positions.clamp(0, ignored_index)
-
-            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
-            start_loss = loss_fct(start_logits, start_positions)
-            end_loss = loss_fct(end_logits, end_positions)
-            total_loss = (start_loss + end_loss) / 2
-
-        if not return_dict:
-            output = (start_logits, end_logits) + encoder_outputs[1:]
-            return ((total_loss,) + output) if total_loss is not None else output
-
-        return QuestionAnsweringModelOutput(
-            loss=total_loss,
-            start_logits=start_logits,
-            end_logits=end_logits,
-            hidden_states=encoder_outputs.hidden_states,
-            attentions=encoder_outputs.attentions,
-        )
-
-
-
-# NOTE: this second definition shadows the FlashT5ForQuestionAnswering class above,
-# so only this encoder-stack variant is effectively exported by the module.
-class FlashT5ForQuestionAnswering(FlashT5PreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
-
-    def __init__(self, config: FlashT5Config):
-        super().__init__(config)
-        self.shared = nn.Embedding(config.vocab_size, config.d_model)
-
-        encoder_config = copy.deepcopy(config)
-        encoder_config.is_decoder = False
-        encoder_config.is_encoder_decoder = False
-        self.encoder = FlashT5Stack(encoder_config, self.shared)
-        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-        self.qa_outputs.weight.data.normal_(mean=0.0, std=config.initializer_factor * 1.0)
-        self.qa_outputs.bias.data.zero_()
-
-        self.model_parallel = False
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        start_positions: Optional[torch.LongTensor] = None,
-        end_positions: Optional[torch.LongTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
-        r"""
-        Returns:
-
-        Example:
-
-        ```python
-        >>> from transformers import AutoTokenizer, MTxEncoderForQuestionAnswering
-
-        >>> tokenizer = AutoTokenizer.from_pretrained("MTx-small")
-        >>> model = MTxEncoderForQuestionAnswering.from_pretrained("MTx-small")
-        >>> input_ids = tokenizer(
-        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
-        ... ).input_ids  # Batch size 1
-        >>> outputs = model(input_ids=input_ids)
-        >>> start_logits = outputs.start_logits
-        >>> end_logits = outputs.end_logits
-        ```"""
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        outputs = self.encoder(
-            input_ids,
-            attention_mask=attention_mask,
-            inputs_embeds=inputs_embeds,
-        )
-        sequence_output = outputs[0]
-
-        logits = self.qa_outputs(sequence_output)
-        start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1).contiguous()
-        end_logits = end_logits.squeeze(-1).contiguous()
-
-        total_loss = None
-        if start_positions is not None and end_positions is not None:
-            # If we are on multi-GPU, split adds a dimension
-            if len(start_positions.size()) > 1:
-                start_positions = start_positions.squeeze(-1).to(start_logits.device)
-            if len(end_positions.size()) > 1:
-                end_positions = end_positions.squeeze(-1).to(end_logits.device)
-            # sometimes the start/end positions are outside our model inputs, we ignore these terms
-            ignored_index = start_logits.size(1)
-            start_positions = start_positions.clamp(0, ignored_index)
-            end_positions = end_positions.clamp(0, ignored_index)
-
-            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
-            start_loss = loss_fct(start_logits, start_positions)
-            end_loss = loss_fct(end_logits, end_positions)
-            total_loss = (start_loss + end_loss) / 2
-
-        if not return_dict:
-            output = (start_logits, end_logits) + outputs[1:]
-            return ((total_loss,) + output) if total_loss is not None else output
-
-        return QuestionAnsweringModelOutput(
-            loss=total_loss,
-            start_logits=start_logits,
-            end_logits=end_logits,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
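
For context, the deleted module defined encoder-only task heads (token classification, sequence classification, extractive QA) on top of the FlashT5 encoder stack. The snippet below is a minimal usage sketch, not code from this repository: the flat import paths, the ability to construct `FlashT5Config` with default arguments, and the label count are all assumptions, and running the FlashT5 stack may require a GPU with the appropriate attention kernels.

```python
import torch

# Hypothetical import paths: the deleted file used relative imports
# (from .modeling_flash_t5 import ...), so adjust to the actual package layout.
from configuration_flash_t5 import FlashT5Config
from custom_heads_flash_t5 import FlashT5ForTokenClassification

# Assumes FlashT5Config exposes T5-style defaults (vocab_size, d_model, hidden_size,
# classifier_dropout, initializer_factor); num_labels=9 is purely illustrative.
config = FlashT5Config(num_labels=9)
model = FlashT5ForTokenClassification(config)  # randomly initialized, no pretrained weights
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 16))  # dummy batch: 1 sequence of 16 tokens
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)

print(outputs.logits.shape)  # expected: torch.Size([1, 16, 9])
```

Note that `FlashT5ForSequenceClassification` in the same file pools the hidden state at the last `<eos>` token of each sequence, so its inputs must include the tokenizer's EOS token and every example in a batch must contain the same number of them.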