bourdoiscatie
commited on
Update custom_heads_flash_t5.py
Browse files- custom_heads_flash_t5.py +34 -29
custom_heads_flash_t5.py
CHANGED
@@ -12,7 +12,7 @@ from transformers.modeling_outputs import (
|
|
12 |
SequenceClassifierOutput
|
13 |
)
|
14 |
|
15 |
-
from .modeling_flash_t5 import FlashT5PreTrainedModel, FlashT5Stack, FlashT5Model
|
16 |
from .configuration_flash_t5 import FlashT5Config
|
17 |
|
18 |
|
@@ -225,15 +225,20 @@ class FlashT5ForQuestionAnswering(FlashT5PreTrainedModel):
|
|
225 |
|
226 |
def __init__(self, config: FlashT5Config):
|
227 |
super().__init__(config)
|
228 |
-
self.
|
229 |
|
230 |
-
|
|
|
|
|
|
|
231 |
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
232 |
|
233 |
# Initialize weights and apply final processing
|
234 |
self.post_init()
|
235 |
|
236 |
-
|
|
|
|
|
237 |
self.model_parallel = False
|
238 |
|
239 |
def forward(
|
@@ -242,37 +247,37 @@ class FlashT5ForQuestionAnswering(FlashT5PreTrainedModel):
|
|
242 |
attention_mask: Optional[torch.FloatTensor] = None,
|
243 |
head_mask: Optional[torch.FloatTensor] = None,
|
244 |
inputs_embeds: Optional[torch.FloatTensor] = None,
|
245 |
-
start_positions: Optional[torch.
|
246 |
-
end_positions: Optional[torch.
|
247 |
output_attentions: Optional[bool] = None,
|
248 |
output_hidden_states: Optional[bool] = None,
|
249 |
return_dict: Optional[bool] = None,
|
250 |
-
) -> Union[Tuple
|
251 |
r"""
|
252 |
-
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
253 |
-
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
254 |
-
Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence
|
255 |
-
are not taken into account for computing the loss.
|
256 |
-
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
257 |
-
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
258 |
-
Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence
|
259 |
-
are not taken into account for computing the loss.
|
260 |
-
|
261 |
Returns:
|
262 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
263 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
264 |
|
265 |
-
|
266 |
-
input_ids
|
267 |
attention_mask=attention_mask,
|
268 |
inputs_embeds=inputs_embeds,
|
269 |
-
head_mask=head_mask,
|
270 |
-
output_attentions=output_attentions,
|
271 |
-
output_hidden_states=output_hidden_states,
|
272 |
-
return_dict=return_dict,
|
273 |
)
|
274 |
-
|
275 |
-
sequence_output = encoder_outputs[0]
|
276 |
|
277 |
logits = self.qa_outputs(sequence_output)
|
278 |
start_logits, end_logits = logits.split(1, dim=-1)
|
@@ -297,13 +302,13 @@ class FlashT5ForQuestionAnswering(FlashT5PreTrainedModel):
|
|
297 |
total_loss = (start_loss + end_loss) / 2
|
298 |
|
299 |
if not return_dict:
|
300 |
-
output = (start_logits, end_logits) +
|
301 |
return ((total_loss,) + output) if total_loss is not None else output
|
302 |
|
303 |
return QuestionAnsweringModelOutput(
|
304 |
loss=total_loss,
|
305 |
start_logits=start_logits,
|
306 |
end_logits=end_logits,
|
307 |
-
hidden_states=
|
308 |
-
attentions=
|
309 |
-
)
|
|
|
12 |
SequenceClassifierOutput
|
13 |
)
|
14 |
|
15 |
+
from .modeling_flash_t5 import FlashT5PreTrainedModel, FlashT5Stack, FlashT5Model
|
16 |
from .configuration_flash_t5 import FlashT5Config
|
17 |
|
18 |
|
|
|
225 |
|
226 |
def __init__(self, config: FlashT5Config):
|
227 |
super().__init__(config)
|
228 |
+
self.shared = nn.Embedding(config.vocab_size, config.d_model)
|
229 |
|
230 |
+
encoder_config = copy.deepcopy(config)
|
231 |
+
encoder_config.is_decoder = False
|
232 |
+
encoder_config.is_encoder_decoder = False
|
233 |
+
self.encoder = FlashT5Stack(encoder_config, self.shared)
|
234 |
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
235 |
|
236 |
# Initialize weights and apply final processing
|
237 |
self.post_init()
|
238 |
|
239 |
+
self.qa_outputs.weight.data.normal_(mean=0.0, std=config.initializer_factor * 1.0)
|
240 |
+
self.qa_outputs.bias.data.zero_()
|
241 |
+
|
242 |
self.model_parallel = False
|
243 |
|
244 |
def forward(
|
|
|
247 |
attention_mask: Optional[torch.FloatTensor] = None,
|
248 |
head_mask: Optional[torch.FloatTensor] = None,
|
249 |
inputs_embeds: Optional[torch.FloatTensor] = None,
|
250 |
+
start_positions: Optional[torch.LongTensor] = None,
|
251 |
+
end_positions: Optional[torch.LongTensor] = None,
|
252 |
output_attentions: Optional[bool] = None,
|
253 |
output_hidden_states: Optional[bool] = None,
|
254 |
return_dict: Optional[bool] = None,
|
255 |
+
) -> Union[Tuple, QuestionAnsweringModelOutput]:
|
256 |
r"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
257 |
Returns:
|
258 |
+
|
259 |
+
Example:
|
260 |
+
|
261 |
+
```python
|
262 |
+
>>> from transformers import AutoTokenizer, MTxEncoderForQuestionAnswering
|
263 |
+
|
264 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("MTx-small")
|
265 |
+
>>> model = MTxEncoderForQuestionAnswering.from_pretrained("MTx-small")
|
266 |
+
>>> input_ids = tokenizer(
|
267 |
+
... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
|
268 |
+
... ).input_ids # Batch size 1
|
269 |
+
>>> outputs = model(input_ids=input_ids)
|
270 |
+
>>> start_logits = outputs.start_logits
|
271 |
+
>>> end_logits = outputs.end_logits
|
272 |
+
```"""
|
273 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
274 |
|
275 |
+
outputs = self.encoder(
|
276 |
+
input_ids,
|
277 |
attention_mask=attention_mask,
|
278 |
inputs_embeds=inputs_embeds,
|
|
|
|
|
|
|
|
|
279 |
)
|
280 |
+
sequence_output = outputs[0]
|
|
|
281 |
|
282 |
logits = self.qa_outputs(sequence_output)
|
283 |
start_logits, end_logits = logits.split(1, dim=-1)
|
|
|
302 |
total_loss = (start_loss + end_loss) / 2
|
303 |
|
304 |
if not return_dict:
|
305 |
+
output = (start_logits, end_logits) + outputs[1:]
|
306 |
return ((total_loss,) + output) if total_loss is not None else output
|
307 |
|
308 |
return QuestionAnsweringModelOutput(
|
309 |
loss=total_loss,
|
310 |
start_logits=start_logits,
|
311 |
end_logits=end_logits,
|
312 |
+
hidden_states=outputs.hidden_states,
|
313 |
+
attentions=outputs.attentions,
|
314 |
+
)
|