bourdoiscatie commited on
Commit
d894af5
·
verified ·
1 Parent(s): 948f07b

Update custom_heads_flash_t5.py

Browse files
Files changed (1) hide show
  1. custom_heads_flash_t5.py +34 -29
custom_heads_flash_t5.py CHANGED
@@ -12,7 +12,7 @@ from transformers.modeling_outputs import (
12
  SequenceClassifierOutput
13
  )
14
 
15
- from .modeling_flash_t5 import FlashT5PreTrainedModel, FlashT5Stack, FlashT5Model, FlashT5EncoderModel
16
  from .configuration_flash_t5 import FlashT5Config
17
 
18
 
@@ -225,15 +225,20 @@ class FlashT5ForQuestionAnswering(FlashT5PreTrainedModel):
225
 
226
  def __init__(self, config: FlashT5Config):
227
  super().__init__(config)
228
- self.transformer = FlashT5EncoderModel(config)
229
 
230
- self.num_labels = config.num_labels
 
 
 
231
  self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
232
 
233
  # Initialize weights and apply final processing
234
  self.post_init()
235
 
236
- # Model parallel
 
 
237
  self.model_parallel = False
238
 
239
  def forward(
@@ -242,37 +247,37 @@ class FlashT5ForQuestionAnswering(FlashT5PreTrainedModel):
242
  attention_mask: Optional[torch.FloatTensor] = None,
243
  head_mask: Optional[torch.FloatTensor] = None,
244
  inputs_embeds: Optional[torch.FloatTensor] = None,
245
- start_positions: Optional[torch.Tensor] = None,
246
- end_positions: Optional[torch.Tensor] = None,
247
  output_attentions: Optional[bool] = None,
248
  output_hidden_states: Optional[bool] = None,
249
  return_dict: Optional[bool] = None,
250
- ) -> Union[Tuple[torch.FloatTensor], QuestionAnsweringModelOutput]:
251
  r"""
252
- start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
253
- Labels for position (index) of the start of the labelled span for computing the token classification loss.
254
- Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence
255
- are not taken into account for computing the loss.
256
- end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
257
- Labels for position (index) of the end of the labelled span for computing the token classification loss.
258
- Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence
259
- are not taken into account for computing the loss.
260
-
261
  Returns:
262
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
264
 
265
- encoder_outputs = self.transformer(
266
- input_ids=input_ids,
267
  attention_mask=attention_mask,
268
  inputs_embeds=inputs_embeds,
269
- head_mask=head_mask,
270
- output_attentions=output_attentions,
271
- output_hidden_states=output_hidden_states,
272
- return_dict=return_dict,
273
  )
274
-
275
- sequence_output = encoder_outputs[0]
276
 
277
  logits = self.qa_outputs(sequence_output)
278
  start_logits, end_logits = logits.split(1, dim=-1)
@@ -297,13 +302,13 @@ class FlashT5ForQuestionAnswering(FlashT5PreTrainedModel):
297
  total_loss = (start_loss + end_loss) / 2
298
 
299
  if not return_dict:
300
- output = (start_logits, end_logits) + encoder_outputs[1:]
301
  return ((total_loss,) + output) if total_loss is not None else output
302
 
303
  return QuestionAnsweringModelOutput(
304
  loss=total_loss,
305
  start_logits=start_logits,
306
  end_logits=end_logits,
307
- hidden_states=encoder_outputs.hidden_states,
308
- attentions=encoder_outputs.attentions,
309
- )
 
12
  SequenceClassifierOutput
13
  )
14
 
15
+ from .modeling_flash_t5 import FlashT5PreTrainedModel, FlashT5Stack, FlashT5Model
16
  from .configuration_flash_t5 import FlashT5Config
17
 
18
 
 
225
 
226
  def __init__(self, config: FlashT5Config):
227
  super().__init__(config)
228
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
229
 
230
+ encoder_config = copy.deepcopy(config)
231
+ encoder_config.is_decoder = False
232
+ encoder_config.is_encoder_decoder = False
233
+ self.encoder = FlashT5Stack(encoder_config, self.shared)
234
  self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
235
 
236
  # Initialize weights and apply final processing
237
  self.post_init()
238
 
239
+ self.qa_outputs.weight.data.normal_(mean=0.0, std=config.initializer_factor * 1.0)
240
+ self.qa_outputs.bias.data.zero_()
241
+
242
  self.model_parallel = False
243
 
244
  def forward(
 
247
  attention_mask: Optional[torch.FloatTensor] = None,
248
  head_mask: Optional[torch.FloatTensor] = None,
249
  inputs_embeds: Optional[torch.FloatTensor] = None,
250
+ start_positions: Optional[torch.LongTensor] = None,
251
+ end_positions: Optional[torch.LongTensor] = None,
252
  output_attentions: Optional[bool] = None,
253
  output_hidden_states: Optional[bool] = None,
254
  return_dict: Optional[bool] = None,
255
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
256
  r"""
 
 
 
 
 
 
 
 
 
257
  Returns:
258
+
259
+ Example:
260
+
261
+ ```python
262
+ >>> from transformers import AutoTokenizer, MTxEncoderForQuestionAnswering
263
+
264
+ >>> tokenizer = AutoTokenizer.from_pretrained("MTx-small")
265
+ >>> model = MTxEncoderForQuestionAnswering.from_pretrained("MTx-small")
266
+ >>> input_ids = tokenizer(
267
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
268
+ ... ).input_ids # Batch size 1
269
+ >>> outputs = model(input_ids=input_ids)
270
+ >>> start_logits = outputs.start_logits
271
+ >>> end_logits = outputs.end_logits
272
+ ```"""
273
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
274
 
275
+ outputs = self.encoder(
276
+ input_ids,
277
  attention_mask=attention_mask,
278
  inputs_embeds=inputs_embeds,
 
 
 
 
279
  )
280
+ sequence_output = outputs[0]
 
281
 
282
  logits = self.qa_outputs(sequence_output)
283
  start_logits, end_logits = logits.split(1, dim=-1)
 
302
  total_loss = (start_loss + end_loss) / 2
303
 
304
  if not return_dict:
305
+ output = (start_logits, end_logits) + outputs[1:]
306
  return ((total_loss,) + output) if total_loss is not None else output
307
 
308
  return QuestionAnsweringModelOutput(
309
  loss=total_loss,
310
  start_logits=start_logits,
311
  end_logits=end_logits,
312
+ hidden_states=outputs.hidden_states,
313
+ attentions=outputs.attentions,
314
+ )