yangwang825 committed on
Commit 88cefe9 · 1 Parent(s): 7b4021e

Create modeling_bert.py

Files changed (1)
  1. modeling_bert.py +289 -0
modeling_bert.py ADDED
@@ -0,0 +1,289 @@
import torch
import torch.nn as nn
from typing import Optional, List, Union, Tuple
from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    AutoTokenizer,
    AutoConfig,
    AutoModel,
    AutoModelForSequenceClassification
)
from transformers.models.bert.modeling_bert import (
    BertEmbeddings,
    BertEncoder,
    load_tf_weights_in_bert
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPoolingAndCrossAttentions,
    SequenceClassifierOutput
)

from .configuration_bert import BertClsConfig


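# Note: the classes below closely follow the reference BERT implementation in
# `transformers.models.bert.modeling_bert`; the main differences are the custom
# `BertClsConfig` configuration class and the `BertClsPooler` pooling head.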
class BertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertClsConfig
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"
    supports_gradient_checkpointing = True

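    # _init_weights is invoked via post_init(): Linear and Embedding weights are drawn
    # from a normal distribution with std=config.initializer_range, padding embeddings
    # are zeroed, and LayerNorm is reset to weight=1, bias=0.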
    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


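# BertClsPooler takes the final hidden state of the first ([CLS]) token and passes it
# through a dense layer with a tanh activation, mirroring the stock BertPooler.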
class BertClsPooler(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


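# The bare BERT encoder: embeddings -> stacked self-attention layers -> optional CLS
# pooler. Passing add_pooling_layer=False drops the pooler, e.g. when only token-level
# representations are needed.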
class BertModel(BertPreTrainedModel):

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)

        self.pooler = BertClsPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

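    # forward() follows the reference implementation: resolve the output flags from the
    # config, build default attention and token-type masks when they are not supplied,
    # broadcast the attention mask to all heads, and, in decoder mode, prepare the
    # cross-attention mask before running the embeddings and encoder.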
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

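        # Embed the inputs (word + position + token-type embeddings), run them through
        # the stacked encoder layers, and pool the first-token hidden state if a pooler
        # was created.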
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )


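# Sequence classification model: a dropout layer and a linear classifier on top of the
# pooled [CLS] representation produced by BertModel.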
class BertForSequenceClassification(BertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.bert = BertModel(config)
        classifier_dropout = (
            config.classifier_dropout
            if config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

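    # The loss is selected from config.problem_type (inferred from num_labels and the
    # label dtype on the first call): MSELoss for regression, CrossEntropyLoss for
    # single-label classification, BCEWithLogitsLoss for multi-label classification.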
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss);
            if `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
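

# A minimal smoke-test sketch of the classes above. It assumes that `BertClsConfig`
# accepts the standard BERT hyperparameters (vocab_size, hidden_size, num_hidden_layers,
# num_attention_heads, intermediate_size, num_labels, ...); adjust the arguments if the
# actual configuration class differs.
if __name__ == "__main__":
    config = BertClsConfig(
        vocab_size=128,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=64,
        num_labels=3,
    )
    model = BertForSequenceClassification(config)
    model.eval()

    # Two toy sequences of length 16 with integer labels in [0, num_labels).
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    labels = torch.tensor([0, 2])

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    print(outputs.loss.item(), outputs.logits.shape)  # logits shape: (2, 3)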