sjrhuschlee committed on
Commit
8544700
1 Parent(s): 33c155a

Upload modeling_t5qa.py with huggingface_hub

Files changed (1)
  1. modeling_t5qa.py +201 -0
modeling_t5qa.py ADDED
@@ -0,0 +1,201 @@
import copy
import warnings
from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers.modeling_outputs import (
    BaseModelOutput,
    Seq2SeqQuestionAnsweringModelOutput,
)
from transformers.models.t5.configuration_t5 import T5Config
from transformers.models.t5.modeling_t5 import T5PreTrainedModel, T5Stack

# Deprecation notice emitted when only `head_mask` is passed; mirrors the upstream
# transformers T5 warning. Defined with a single leading underscore so the name is
# not mangled when referenced inside the class body below.
_HEAD_MASK_WARNING_MSG = (
    "The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. "
    "Currently, `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be "
    "removed in future versions. If you do not want to use any `decoder_head_mask` now, please set "
    "`decoder_head_mask = torch.ones(num_layers, num_heads)`."
)


class T5ForQuestionAnswering(T5PreTrainedModel):
    _keys_to_ignore_on_load_missing = [
        r"encoder.embed_tokens.weight",
        r"decoder.embed_tokens.weight",
    ]
    _keys_to_ignore_on_load_unexpected = [
        r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
    ]

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

        self.num_labels = config.num_labels
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqQuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for the position (index) of the start of the labelled span, used to compute the token
            classification loss. Positions are clamped to the length of the sequence (`sequence_length`);
            positions outside of the sequence are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for the position (index) of the end of the labelled span, used to compute the token
            classification loss. Positions are clamped to the length of the sequence (`sequence_length`);
            positions outside of the sequence are not taken into account for computing the loss.

        Returns:
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        if start_positions is not None and end_positions is not None:
            use_cache = False

        # Copied from models.bart.modeling_bart.BartModel.forward
        # Different from other models, T5 automatically creates decoder_input_ids from
        # input_ids if no decoder_input_ids are provided
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            if input_ids is None:
                raise ValueError(
                    "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
                    "passed, `input_ids` cannot be `None`. Please pass either "
                    "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
                )
            decoder_input_ids = self._shift_right(input_ids)

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask
        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=None,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()
        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, splitting adds a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1).to(start_logits.device)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1).to(end_logits.device)
            # Sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
        if not return_dict:
            output = (start_logits, end_logits) + decoder_outputs[1:] + encoder_outputs
            return ((total_loss,) + output) if total_loss is not None else output

        return Seq2SeqQuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
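
For context, a minimal usage sketch of how a remote-code model file like this is typically loaded from the Hub: AutoModelForQuestionAnswering can resolve the custom T5ForQuestionAnswering class when the repository's config maps it via auto_map and trust_remote_code=True is passed. The repository id, question, and context below are placeholders, not values confirmed by this commit.

# Minimal sketch: load the custom class via remote code and extract an answer span.
import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

repo_id = "sjrhuschlee/<t5-qa-model>"  # placeholder repository id
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForQuestionAnswering.from_pretrained(repo_id, trust_remote_code=True)

question = "Who introduced the T5 model?"
context = "The T5 model was introduced by Raffel et al. in 2019."
inputs = tokenizer(question, context, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)  # decoder_input_ids are created from input_ids inside forward

# Pick the most likely start/end token indices and decode the span between them.
start = int(outputs.start_logits.argmax())
end = int(outputs.end_logits.argmax())
print(tokenizer.decode(inputs["input_ids"][0, start : end + 1]))

Note that the question/context pair is passed once as input_ids; forward shifts it right to build decoder_input_ids, and the span logits come from the qa_outputs head applied to the decoder's hidden states.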