satyaalmasian committed
Commit: dab2163
1 Parent(s): 5c2d5ff

Upload BERTWithCRF.py

Files changed (1)
  1. BERTWithCRF.py +259 -0
BERTWithCRF.py ADDED
@@ -0,0 +1,259 @@
# Code adapted from https://github.com/Louis-udm/NER-BERT-CRF/blob/master/NER_BERT_CRF.py
import torch
from transformers import BertModel, BertConfig  # import these explicitly, otherwise loading the config raises an error
from transformers.models.bert.modeling_bert import BertPreTrainedModel
from torch import nn
from torch.nn import CrossEntropyLoss, BCELoss, LayerNorm
from transformers.modeling_outputs import TokenClassifierOutput

# Hack to guarantee backward compatibility.
BertLayerNorm = LayerNorm


def log_sum_exp_batch(log_Tensor, axis=-1):  # shape (batch_size, n, m)
    # Numerically stable log-sum-exp over the given axis.
    return torch.max(log_Tensor, axis)[0] + \
        torch.log(torch.exp(log_Tensor - torch.max(log_Tensor, axis)[0].view(log_Tensor.shape[0], -1, 1)).sum(axis))


class BERT_CRF_NER(BertPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.hidden_size = 768
        self.start_label_id = config.start_label_id
        self.stop_label_id = config.stop_label_id
        self.num_labels = config.num_classes
        # self.max_seq_length = max_seq_length
        self.batch_size = config.batch_size

        # Use the pretrained BertModel as the encoder.
        self.bert = BertModel(config, add_pooling_layer=False)

        self.dropout = torch.nn.Dropout(0.2)
        # Maps the output of BERT into label space.
        self.hidden2label = nn.Linear(self.hidden_size, self.num_labels)

        # Matrix of transition parameters. Entry i,j is the score of transitioning *to* i *from* j.
        self.transitions = nn.Parameter(
            torch.randn(self.num_labels, self.num_labels))

        # These two statements enforce the constraint that we never transition *to* the start label
        # and never transition *from* the stop label (the model would probably learn this anyway,
        # so this enforcement is likely unimportant).
        self.transitions.data[self.start_label_id, :] = -10000
        self.transitions.data[:, self.stop_label_id] = -10000

        nn.init.xavier_uniform_(self.hidden2label.weight)
        nn.init.constant_(self.hidden2label.bias, 0.0)
        # self.apply(self.init_bert_weights)

    def init_bert_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version, which uses truncated_normal for initialization.
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def _forward_alg(self, feats):
        """
        Alpha (forward) recursion: computes the log-probability of the observations
        summed over all possible label sequences (the partition function).
        """
        # T = self.max_seq_length
        T = feats.shape[1]
        batch_size = feats.shape[0]

        # alpha recursion: alpha(z_t) = p(z_t, x_1:t)
        log_alpha = torch.Tensor(batch_size, 1, self.num_labels).fill_(-10000.).to(self.device)
        # alpha[0]: the start label carries all of the probability mass; in log space, 0 means p = 1
        log_alpha[:, 0, self.start_label_id] = 0

        # feats: sentences -> word embedding -> BERT -> MLP -> feats
        # feats are the emission scores, shape (batch_size, T, tag_size)
        for t in range(1, T):
            log_alpha = (log_sum_exp_batch(self.transitions + log_alpha, axis=-1) + feats[:, t]).unsqueeze(1)

        # log-probability summed over all label sequences
        log_prob_all_barX = log_sum_exp_batch(log_alpha)
        return log_prob_all_barX

    def _get_bert_features(self, input_ids,
                           attention_mask,
                           token_type_ids,
                           position_ids,
                           head_mask,
                           inputs_embeds,
                           output_attentions,
                           output_hidden_states,
                           return_dict):
        """
        sentences -> word embedding -> BERT -> MLP -> feats
        """
        bert_seq_out = self.bert(input_ids,
                                 attention_mask=attention_mask,
                                 token_type_ids=token_type_ids,
                                 position_ids=position_ids,
                                 head_mask=head_mask,
                                 inputs_embeds=inputs_embeds,
                                 output_attentions=output_attentions,
                                 output_hidden_states=output_hidden_states,
                                 return_dict=return_dict)  # output_all_encoded_layers=False removed

        bert_seq_out_last = bert_seq_out[0]
        bert_seq_out_last = self.dropout(bert_seq_out_last)
        bert_feats = self.hidden2label(bert_seq_out_last)
        return bert_feats, bert_seq_out

    def _score_sentence(self, feats, label_ids):
        """
        Gives the score of a provided label sequence:
        p(X=x_1:T, Z=tag_1:T) = ... p(Z_t=tag_t | Z_t-1=tag_t-1) p(x_t | Z_t=tag_t) ...
        """
        # T = self.max_seq_length
        T = feats.shape[1]
        batch_size = feats.shape[0]

        batch_transitions = self.transitions.expand(batch_size, self.num_labels, self.num_labels)
        batch_transitions = batch_transitions.flatten(1)

        score = torch.zeros((feats.shape[0], 1)).to(self.device)
        # The 0th node is start_label -> start_word with probability 1, so t begins at 1.
        for t in range(1, T):
            score = score + \
                batch_transitions.gather(-1, (label_ids[:, t] * self.num_labels + label_ids[:, t - 1]).view(-1, 1)) + \
                feats[:, t].gather(-1, label_ids[:, t].view(-1, 1)).view(-1, 1)
        return score

    def _viterbi_decode(self, feats):
        """
        Max-product (Viterbi) algorithm: argmax p(z_0:T | x_0:T)
        """
        # T = self.max_seq_length
        # feats = feats[0]  # added
        T = feats.shape[1]
        batch_size = feats.shape[0]

        # batch_transitions = self.transitions.expand(batch_size, self.num_labels, self.num_labels)

        log_delta = torch.Tensor(batch_size, 1, self.num_labels).fill_(-10000.).to(self.device)
        log_delta[:, 0, self.start_label_id] = 0

        # psi stores, for each state, the previous state that makes the path probability maximal.
        psi = torch.zeros((batch_size, T, self.num_labels), dtype=torch.long).to(self.device)  # psi[0] is unused
        for t in range(1, T):
            # delta[t][k] = max_{z_1:t-1} p(x_1, ..., x_t, z_1, ..., z_t-1, z_t=k | theta)
            # delta[t] is the max probability of the path from z_t-1 to z_t[k]
            log_delta, psi[:, t] = torch.max(self.transitions + log_delta, -1)
            # psi[t][k] = argmax_{z_1:t-1} p(x_1, ..., x_t, z_1, ..., z_t-1, z_t=k | theta)
            # psi[t][k] is the path chosen from z_t-1 to z_t[k]; the value is the state index of z_t-1

            log_delta = (log_delta + feats[:, t]).unsqueeze(1)

        # trace back
        path = torch.zeros((batch_size, T), dtype=torch.long).to(self.device)

        # max p(z_1:T, all_x | theta)
        max_logLL_allz_allx, path[:, -1] = torch.max(log_delta.squeeze(), -1)

        for t in range(T - 2, -1, -1):
            # choose the state of z_t according to the state chosen for z_t+1.
            path[:, t] = psi[:, t + 1].gather(-1, path[:, t + 1].view(-1, 1)).squeeze()

        return max_logLL_allz_allx, path

    def neg_log_likelihood(self, input_ids,
                           attention_mask,
                           token_type_ids,
                           position_ids,
                           head_mask,
                           inputs_embeds,
                           output_attentions,
                           output_hidden_states,
                           return_dict,
                           label_ids):

        bert_feats, _ = self._get_bert_features(input_ids,
                                                attention_mask,
                                                token_type_ids,
                                                position_ids,
                                                head_mask,
                                                inputs_embeds,
                                                output_attentions,
                                                output_hidden_states,
                                                return_dict)

        forward_score = self._forward_alg(bert_feats)
        # p(X=x_1:T, Z=tag_1:T) = ... p(Z_t=tag_t | Z_t-1=tag_t-1) p(x_t | Z_t=tag_t) ...
        gold_score = self._score_sentence(bert_feats, label_ids)
        # -log[ p(X=x_1:T, Z=tag_1:T) / p(X=x_1:T) ] = -log[ p(Z=tag_1:T | X=x_1:T) ]
        return torch.mean(forward_score - gold_score)

    # This forward pass always runs Viterbi decoding; the CRF negative log-likelihood
    # is only computed when inference_mode is False.
    # Do not confuse it with _forward_alg above.
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        inference_mode=False,
    ):
        # Get the emission scores from BERT.
        bert_feats, bert_out = self._get_bert_features(input_ids,
                                                       attention_mask,
                                                       token_type_ids,
                                                       position_ids,
                                                       head_mask,
                                                       inputs_embeds,
                                                       output_attentions,
                                                       output_hidden_states,
                                                       return_dict)

        # Find the best path, given the features.
        score, label_seq_ids = self._viterbi_decode(bert_feats)

        if not inference_mode:
            neg_log_likelihood = self.neg_log_likelihood(input_ids,
                                                         attention_mask,
                                                         token_type_ids,
                                                         position_ids,
                                                         head_mask,
                                                         inputs_embeds,
                                                         output_attentions,
                                                         output_hidden_states,
                                                         return_dict,
                                                         labels)

            return TokenClassifierOutput(
                loss=neg_log_likelihood,
                logits=label_seq_ids,
                hidden_states=bert_out.hidden_states,
                attentions=bert_out.attentions,
            )
        else:
            neg_log_likelihood = None
            return TokenClassifierOutput(
                loss=neg_log_likelihood,
                logits=label_seq_ids,
                hidden_states=bert_out.hidden_states,
                attentions=bert_out.attentions,
            )
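Usage note (not part of the uploaded file): a minimal sketch of how this model might be instantiated and called for inference. The extra config attributes (num_classes, start_label_id, stop_label_id, batch_size) mirror what __init__ reads from the config; the checkpoint name, label set, and example sentence below are illustrative assumptions, not something defined by this commit.

# Minimal usage sketch (assumes a bert-base-cased checkpoint and a hypothetical
# 5-label scheme with dedicated start/stop labels).
import torch
from transformers import BertConfig, BertTokenizerFast

labels = ["O", "B-TIME", "I-TIME", "[START]", "[STOP]"]  # hypothetical label set

config = BertConfig.from_pretrained("bert-base-cased")
# BERT_CRF_NER.__init__ expects these extra attributes on the config:
config.num_classes = len(labels)
config.start_label_id = labels.index("[START]")
config.stop_label_id = labels.index("[STOP]")
config.batch_size = 1

# The CRF transitions and hidden2label head are freshly initialized here;
# for real predictions one would load a fine-tuned checkpoint instead.
model = BERT_CRF_NER.from_pretrained("bert-base-cased", config=config)
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

inputs = tokenizer("The meeting is on Friday at 3 pm.", return_tensors="pt")
model.eval()
with torch.no_grad():
    out = model(**inputs, inference_mode=True)  # Viterbi decoding only, no loss
print(out.logits)  # best label-id path per token, shape (batch_size, seq_len)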