sjrhuschlee committed on
Commit
9ac08c2
1 Parent(s): 1e85266

Upload modeling_t5qa.py with huggingface_hub

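The commit message above refers to the programmatic upload path. A minimal, hypothetical sketch of what such an upload typically looks like with huggingface_hub follows; the repo id is a placeholder, and an authenticated session (e.g. via `huggingface-cli login`) is assumed:

    # Hypothetical upload sketch; "<user>/<repo>" is a placeholder, not the real repo id.
    from huggingface_hub import HfApi

    api = HfApi()
    api.upload_file(
        path_or_fileobj="modeling_t5qa.py",  # local file to push
        path_in_repo="modeling_t5qa.py",     # destination path in the repo
        repo_id="<user>/<repo>",
    )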
Files changed (1)
  1. modeling_t5qa.py +205 -0
modeling_t5qa.py ADDED
@@ -0,0 +1,205 @@
+ import copy
+ import warnings
+ from typing import Optional, Tuple, Union
+
+ import torch
+ from torch import nn
+ from torch.nn import CrossEntropyLoss
+
+ from transformers import AutoModelForQuestionAnswering
+ from transformers.modeling_outputs import (
+     BaseModelOutput,
+     Seq2SeqQuestionAnsweringModelOutput,
+ )
+ from transformers.models.t5.configuration_t5 import T5Config
+ from transformers.models.t5.modeling_t5 import T5PreTrainedModel, T5Stack
+
+ # Deprecation notice shown when a single `head_mask` is reused for the decoder.
+ # Mirrors the constant in transformers' modeling_t5.py; defined here with a single
+ # leading underscore so the module-level name is visible inside the class body
+ # (a double-underscore name would be mangled there).
+ _HEAD_MASK_WARNING_MSG = """
+ The input argument `head_mask` was split into two input arguments `head_mask` and `decoder_head_mask`. Currently,
+ `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
+ If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
+ num_heads)`.
+ """
+
+
+ class T5ForQuestionAnswering(T5PreTrainedModel):
+     _keys_to_ignore_on_load_missing = [
+         r"encoder.embed_tokens.weight",
+         r"decoder.embed_tokens.weight",
+     ]
+     _keys_to_ignore_on_load_unexpected = [
+         r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
+     ]
+
+     def __init__(self, config: T5Config):
+         super().__init__(config)
+         self.model_dim = config.d_model
+
+         # Token embeddings shared between encoder and decoder
+         self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+         encoder_config = copy.deepcopy(config)
+         encoder_config.is_decoder = False
+         encoder_config.use_cache = False
+         encoder_config.is_encoder_decoder = False
+         self.encoder = T5Stack(encoder_config, self.shared)
+
+         decoder_config = copy.deepcopy(config)
+         decoder_config.is_decoder = True
+         decoder_config.is_encoder_decoder = False
+         decoder_config.num_layers = config.num_decoder_layers
+         self.decoder = T5Stack(decoder_config, self.shared)
+
+         # Span-prediction head: projects each decoder state to (start, end) scores
+         self.num_labels = config.num_labels
+         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+         # Model parallel
+         self.model_parallel = False
+         self.device_map = None
+
+     def get_input_embeddings(self):
+         return self.shared
+
+     def set_input_embeddings(self, new_embeddings):
+         self.shared = new_embeddings
+         self.encoder.set_input_embeddings(new_embeddings)
+         self.decoder.set_input_embeddings(new_embeddings)
+
+     def get_encoder(self):
+         return self.encoder
+
+     def get_decoder(self):
+         return self.decoder
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         decoder_input_ids: Optional[torch.LongTensor] = None,
+         decoder_attention_mask: Optional[torch.BoolTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         decoder_head_mask: Optional[torch.FloatTensor] = None,
+         cross_attn_head_mask: Optional[torch.Tensor] = None,
+         encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+         start_positions: Optional[torch.LongTensor] = None,
+         end_positions: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.FloatTensor], Seq2SeqQuestionAnsweringModelOutput]:
+         r"""
+         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for position (index) of the start of the labelled span for computing the token classification loss.
+             Positions are clamped to the length of the sequence (*sequence_length*). Positions outside of the sequence
+             are not taken into account for computing the loss.
+         end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for position (index) of the end of the labelled span for computing the token classification loss.
+             Positions are clamped to the length of the sequence (*sequence_length*). Positions outside of the sequence
+             are not taken into account for computing the loss.
+
+         Returns:
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         if start_positions is not None and end_positions is not None:
+             use_cache = False
+
+         # Copied from models.bart.modeling_bart.BartModel.forward
+         # Unlike other models, T5 automatically creates decoder_input_ids from
+         # input_ids if no decoder_input_ids are provided
+         if decoder_input_ids is None and decoder_inputs_embeds is None:
+             if input_ids is None:
+                 raise ValueError(
+                     "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
+                     "passed, `input_ids` cannot be `None`. Please pass either "
+                     "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
+                 )
+             decoder_input_ids = self._shift_right(input_ids)
+
+         # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+         if head_mask is not None and decoder_head_mask is None:
+             if self.config.num_layers == self.config.num_decoder_layers:
+                 warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning)
+                 decoder_head_mask = head_mask
+
+         # Encode if needed (training, first prediction pass)
+         if encoder_outputs is None:
+             encoder_outputs = self.encoder(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 inputs_embeds=inputs_embeds,
+                 head_mask=head_mask,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+             )
+         elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+             encoder_outputs = BaseModelOutput(
+                 last_hidden_state=encoder_outputs[0],
+                 hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                 attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+             )
+
+         hidden_states = encoder_outputs[0]
+
+         # Decode
+         decoder_outputs = self.decoder(
+             input_ids=decoder_input_ids,
+             attention_mask=decoder_attention_mask,
+             inputs_embeds=decoder_inputs_embeds,
+             past_key_values=None,
+             encoder_hidden_states=hidden_states,
+             encoder_attention_mask=attention_mask,
+             head_mask=decoder_head_mask,
+             cross_attn_head_mask=cross_attn_head_mask,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         sequence_output = decoder_outputs[0]
+
+         # Project decoder states to per-token start/end scores
+         logits = self.qa_outputs(sequence_output)
+         start_logits, end_logits = logits.split(1, dim=-1)
+         start_logits = start_logits.squeeze(-1).contiguous()
+         end_logits = end_logits.squeeze(-1).contiguous()
+
+         total_loss = None
+         if start_positions is not None and end_positions is not None:
+             # If we are on multi-GPU, the split can add a dimension; squeeze it away
+             if len(start_positions.size()) > 1:
+                 start_positions = start_positions.squeeze(-1).to(start_logits.device)
+             if len(end_positions.size()) > 1:
+                 end_positions = end_positions.squeeze(-1).to(end_logits.device)
+             # Sometimes the start/end positions are outside our model inputs; we ignore these terms
+             ignored_index = start_logits.size(1)
+             start_positions = start_positions.clamp(0, ignored_index)
+             end_positions = end_positions.clamp(0, ignored_index)
+
+             loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+             start_loss = loss_fct(start_logits, start_positions)
+             end_loss = loss_fct(end_logits, end_positions)
+             total_loss = (start_loss + end_loss) / 2
+
+         if not return_dict:
+             output = (start_logits, end_logits) + decoder_outputs[1:] + encoder_outputs
+             return ((total_loss,) + output) if total_loss is not None else output
+
+         return Seq2SeqQuestionAnsweringModelOutput(
+             loss=total_loss,
+             start_logits=start_logits,
+             end_logits=end_logits,
+             past_key_values=decoder_outputs.past_key_values,
+             decoder_hidden_states=decoder_outputs.hidden_states,
+             decoder_attentions=decoder_outputs.attentions,
+             cross_attentions=decoder_outputs.cross_attentions,
+             encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+             encoder_hidden_states=encoder_outputs.hidden_states,
+             encoder_attentions=encoder_outputs.attentions,
+         )
+
+
+ # Register so the Auto API resolves T5Config to this custom class
+ AutoModelForQuestionAnswering.register(T5Config, T5ForQuestionAnswering)
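
Because the final line registers T5ForQuestionAnswering for T5Config, checkpoints that ship this file can be loaded through the Auto API. A minimal usage sketch follows; the repo id is a placeholder, and `trust_remote_code=True` is required because the class lives in the repo rather than in transformers:

    # Hypothetical usage sketch; "<user>/<t5-qa-checkpoint>" is a placeholder repo id.
    import torch
    from transformers import AutoModelForQuestionAnswering, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("<user>/<t5-qa-checkpoint>")
    model = AutoModelForQuestionAnswering.from_pretrained(
        "<user>/<t5-qa-checkpoint>", trust_remote_code=True
    )

    question = "Who wrote Hamlet?"
    context = "Hamlet is a tragedy written by William Shakespeare."
    inputs = tokenizer(question, context, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Standard extractive-QA decoding: take the most likely start/end indices
    # and decode the corresponding input-token span.
    start = outputs.start_logits.argmax(dim=-1).item()
    end = outputs.end_logits.argmax(dim=-1).item()
    answer = tokenizer.decode(inputs["input_ids"][0, start : end + 1])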