Markus28 commited on
Commit
ec37ae5
·
1 Parent(s): ba24fb1

feat: added further GLUE models

Browse files
Files changed (1) hide show
  1. modeling_for_glue.py +160 -1
modeling_for_glue.py CHANGED
@@ -3,7 +3,7 @@ from typing import Optional, Union, Tuple
3
  import torch
4
  from torch import nn
5
  from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
6
- from transformers.modeling_outputs import SequenceClassifierOutput
7
 
8
  from .modeling_bert import BertPreTrainedModel, BertModel
9
  from .configuration_bert import JinaBertConfig
@@ -102,3 +102,162 @@ class BertForSequenceClassification(BertPreTrainedModel):
102
  hidden_states=outputs.hidden_states,
103
  attentions=outputs.attentions,
104
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import torch
4
  from torch import nn
5
  from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
6
+ from transformers.modeling_outputs import SequenceClassifierOutput, QuestionAnsweringModelOutput, TokenClassifierOutput
7
 
8
  from .modeling_bert import BertPreTrainedModel, BertModel
9
  from .configuration_bert import JinaBertConfig
 
102
  hidden_states=outputs.hidden_states,
103
  attentions=outputs.attentions,
104
  )
105
+
106
+ class BertForQuestionAnswering(BertPreTrainedModel):
107
+ def __init__(self, config: JinaBertConfig):
108
+ super().__init__(config)
109
+ self.num_labels = config.num_labels
110
+
111
+ self.bert = BertModel(config, add_pooling_layer=False)
112
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
113
+
114
+ # Initialize weights and apply final processing
115
+ self.post_init()
116
+
117
+ def forward(
118
+ self,
119
+ input_ids: Optional[torch.Tensor] = None,
120
+ attention_mask: Optional[torch.Tensor] = None,
121
+ token_type_ids: Optional[torch.Tensor] = None,
122
+ position_ids: Optional[torch.Tensor] = None,
123
+ head_mask: Optional[torch.Tensor] = None,
124
+ inputs_embeds: Optional[torch.Tensor] = None,
125
+ start_positions: Optional[torch.Tensor] = None,
126
+ end_positions: Optional[torch.Tensor] = None,
127
+ output_attentions: Optional[bool] = None,
128
+ output_hidden_states: Optional[bool] = None,
129
+ return_dict: Optional[bool] = None,
130
+ ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
131
+ r"""
132
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
133
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
134
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
135
+ are not taken into account for computing the loss.
136
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
137
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
138
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
139
+ are not taken into account for computing the loss.
140
+ """
141
+ return_dict = (
142
+ return_dict if return_dict is not None else self.config.use_return_dict
143
+ )
144
+
145
+ outputs = self.bert(
146
+ input_ids,
147
+ attention_mask=attention_mask,
148
+ token_type_ids=token_type_ids,
149
+ position_ids=position_ids,
150
+ head_mask=head_mask,
151
+ inputs_embeds=inputs_embeds,
152
+ output_attentions=output_attentions,
153
+ output_hidden_states=output_hidden_states,
154
+ return_dict=return_dict,
155
+ )
156
+
157
+ sequence_output = outputs[0]
158
+
159
+ logits = self.qa_outputs(sequence_output)
160
+ start_logits, end_logits = logits.split(1, dim=-1)
161
+ start_logits = start_logits.squeeze(-1).contiguous()
162
+ end_logits = end_logits.squeeze(-1).contiguous()
163
+
164
+ total_loss = None
165
+ if start_positions is not None and end_positions is not None:
166
+ # If we are on multi-GPU, split add a dimension
167
+ if len(start_positions.size()) > 1:
168
+ start_positions = start_positions.squeeze(-1)
169
+ if len(end_positions.size()) > 1:
170
+ end_positions = end_positions.squeeze(-1)
171
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
172
+ ignored_index = start_logits.size(1)
173
+ start_positions = start_positions.clamp(0, ignored_index)
174
+ end_positions = end_positions.clamp(0, ignored_index)
175
+
176
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
177
+ start_loss = loss_fct(start_logits, start_positions)
178
+ end_loss = loss_fct(end_logits, end_positions)
179
+ total_loss = (start_loss + end_loss) / 2
180
+
181
+ if not return_dict:
182
+ output = (start_logits, end_logits) + outputs[2:]
183
+ return ((total_loss,) + output) if total_loss is not None else output
184
+
185
+ return QuestionAnsweringModelOutput(
186
+ loss=total_loss,
187
+ start_logits=start_logits,
188
+ end_logits=end_logits,
189
+ hidden_states=outputs.hidden_states,
190
+ attentions=outputs.attentions,
191
+ )
192
+
193
+
194
+ class BertForTokenClassification(BertPreTrainedModel):
195
+ def __init__(self, config: JinaBertConfig):
196
+ super().__init__(config)
197
+ self.num_labels = config.num_labels
198
+
199
+ self.bert = BertModel(config, add_pooling_layer=False)
200
+ classifier_dropout = (
201
+ config.classifier_dropout
202
+ if config.classifier_dropout is not None
203
+ else config.hidden_dropout_prob
204
+ )
205
+ self.dropout = nn.Dropout(classifier_dropout)
206
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
207
+
208
+ # Initialize weights and apply final processing
209
+ self.post_init()
210
+
211
+ def forward(
212
+ self,
213
+ input_ids: Optional[torch.Tensor] = None,
214
+ attention_mask: Optional[torch.Tensor] = None,
215
+ token_type_ids: Optional[torch.Tensor] = None,
216
+ position_ids: Optional[torch.Tensor] = None,
217
+ head_mask: Optional[torch.Tensor] = None,
218
+ inputs_embeds: Optional[torch.Tensor] = None,
219
+ labels: Optional[torch.Tensor] = None,
220
+ output_attentions: Optional[bool] = None,
221
+ output_hidden_states: Optional[bool] = None,
222
+ return_dict: Optional[bool] = None,
223
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
224
+ r"""
225
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
226
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
227
+ """
228
+ return_dict = (
229
+ return_dict if return_dict is not None else self.config.use_return_dict
230
+ )
231
+
232
+ outputs = self.bert(
233
+ input_ids,
234
+ attention_mask=attention_mask,
235
+ token_type_ids=token_type_ids,
236
+ position_ids=position_ids,
237
+ head_mask=head_mask,
238
+ inputs_embeds=inputs_embeds,
239
+ output_attentions=output_attentions,
240
+ output_hidden_states=output_hidden_states,
241
+ return_dict=return_dict,
242
+ )
243
+
244
+ sequence_output = outputs[0]
245
+
246
+ sequence_output = self.dropout(sequence_output)
247
+ logits = self.classifier(sequence_output)
248
+
249
+ loss = None
250
+ if labels is not None:
251
+ loss_fct = CrossEntropyLoss()
252
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
253
+
254
+ if not return_dict:
255
+ output = (logits,) + outputs[2:]
256
+ return ((loss,) + output) if loss is not None else output
257
+
258
+ return TokenClassifierOutput(
259
+ loss=loss,
260
+ logits=logits,
261
+ hidden_states=outputs.hidden_states,
262
+ attentions=outputs.attentions,
263
+ )