socialcomp commited on
Commit
1ca6f42
1 Parent(s): b6af7c5

Upload NER_medNLP.py

Browse files
Files changed (1) hide show
  1. NER_medNLP.py +238 -0
NER_medNLP.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+
3
+ import itertools
4
+ from tqdm import tqdm
5
+ import numpy as np
6
+ import torch
7
+ from transformers import BertJapaneseTokenizer, BertForTokenClassification
8
+ import pytorch_lightning as pl
9
+
10
+ # from torch.utils.data import DataLoader
11
+ # import from_XML_to_json as XtC
12
+ # import random
13
+ # import json
14
+ # import unicodedata
15
+ # import pandas as pd
16
+
17
+ # %%
18
+ # 8-16
19
+ # PyTorch Lightningのモデル
20
+ class BertForTokenClassification_pl(pl.LightningModule):
21
+
22
+ def __init__(self, model_name, num_labels, lr):
23
+ super().__init__()
24
+ self.save_hyperparameters()
25
+ self.bert_tc = BertForTokenClassification.from_pretrained(
26
+ model_name,
27
+ num_labels=num_labels
28
+ )
29
+
30
+ def training_step(self, batch, batch_idx):
31
+ output = self.bert_tc(**batch)
32
+ loss = output.loss
33
+ self.log('train_loss', loss)
34
+ return loss
35
+
36
+ def validation_step(self, batch, batch_idx):
37
+ output = self.bert_tc(**batch)
38
+ val_loss = output.loss
39
+ self.log('val_loss', val_loss)
40
+
41
+ def configure_optimizers(self):
42
+ return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
43
+
44
+
45
+
46
+ # %%
47
+ class NER_tokenizer_BIO(BertJapaneseTokenizer):
48
+
49
+ # 初期化時に固有表現のカテゴリーの数`num_entity_type`を
50
+ # 受け入れるようにする。
51
+ def __init__(self, *args, **kwargs):
52
+ self.num_entity_type = kwargs.pop('num_entity_type')
53
+ super().__init__(*args, **kwargs)
54
+
55
+ def encode_plus_tagged(self, text, entities, max_length):
56
+ """
57
+ 文章とそれに含まれる固有表現が与えられた時に、
58
+ 符号化とラベル列の作成を行う。
59
+ """
60
+ # 固有表現の前後でtextを分割し、それぞれのラベルをつけておく。
61
+ splitted = [] # 分割後の文字列を追加していく
62
+ position = 0
63
+
64
+ for entity in entities:
65
+ start = entity['span'][0]
66
+ end = entity['span'][1]
67
+ label = entity['type_id']
68
+ splitted.append({'text':text[position:start], 'label':0})
69
+ splitted.append({'text':text[start:end], 'label':label})
70
+ position = end
71
+ splitted.append({'text': text[position:], 'label':0})
72
+ splitted = [ s for s in splitted if s['text'] ]
73
+
74
+ # 分割されたそれぞれの文章をトークン化し、ラベルをつける。
75
+ tokens = [] # トークンを追加していく
76
+ labels = [] # ラベルを追加していく
77
+ for s in splitted:
78
+ tokens_splitted = self.tokenize(s['text'])
79
+ label = s['label']
80
+ if label > 0: # 固有表現
81
+ # まずトークン全てにI-タグを付与
82
+ # 番号順O-tag:0, B-tag:1~タグの数,I-tag:タグの数〜
83
+ labels_splitted = \
84
+ [ label + self.num_entity_type ] * len(tokens_splitted)
85
+ # 先頭のトークンをB-タグにする
86
+ labels_splitted[0] = label
87
+ else: # それ以外
88
+ labels_splitted = [0] * len(tokens_splitted)
89
+
90
+ tokens.extend(tokens_splitted)
91
+ labels.extend(labels_splitted)
92
+
93
+ # 符号化を行いBERTに入力できる形式にする。
94
+ input_ids = self.convert_tokens_to_ids(tokens)
95
+ encoding = self.prepare_for_model(
96
+ input_ids,
97
+ max_length=max_length,
98
+ padding='max_length',
99
+ truncation=True
100
+ )
101
+
102
+ # ラベルに特殊トークンを追加
103
+ # max_lengthで切り取って,その前後に[CLS]と[SEP]を追加するためのラベルを入れる
104
+ labels = [0] + labels[:max_length-2] + [0]
105
+ # max_lengthに満たない場合は,満たない分を後ろ側に追加する
106
+ labels = labels + [0]*( max_length - len(labels) )
107
+ encoding['labels'] = labels
108
+
109
+ return encoding
110
+
111
+ def encode_plus_untagged(
112
+ self, text, max_length=None, return_tensors=None
113
+ ):
114
+ """
115
+ 文章をトークン化し、それぞれのトークンの文章中の位置も特定しておく。
116
+ IO法のトークナイザのencode_plus_untaggedと同じ
117
+ """
118
+ # 文章のトークン化を行い、
119
+ # それぞれのトークンと文章中の文字列を対応づける。
120
+ tokens = [] # トークンを追加していく。
121
+ tokens_original = [] # トークンに対応する文章中の文字列を追加していく。
122
+ words = self.word_tokenizer.tokenize(text) # MeCabで単語に分割
123
+ for word in words:
124
+ # 単語をサブワードに分割
125
+ tokens_word = self.subword_tokenizer.tokenize(word)
126
+ tokens.extend(tokens_word)
127
+ if tokens_word[0] == '[UNK]': # 未知語への対応
128
+ tokens_original.append(word)
129
+ else:
130
+ tokens_original.extend([
131
+ token.replace('##','') for token in tokens_word
132
+ ])
133
+
134
+ # 各トークンの文章中での位置を調べる。(空白の位置を考慮する)
135
+ position = 0
136
+ spans = [] # トークンの位置を追加していく。
137
+ for token in tokens_original:
138
+ l = len(token)
139
+ while 1:
140
+ if token != text[position:position+l]:
141
+ position += 1
142
+ else:
143
+ spans.append([position, position+l])
144
+ position += l
145
+ break
146
+
147
+ # 符号化を行いBERTに入力できる形式にする。
148
+ input_ids = self.convert_tokens_to_ids(tokens)
149
+ encoding = self.prepare_for_model(
150
+ input_ids,
151
+ max_length=max_length,
152
+ padding='max_length' if max_length else False,
153
+ truncation=True if max_length else False
154
+ )
155
+ sequence_length = len(encoding['input_ids'])
156
+ # 特殊トークン[CLS]に対するダミーのspanを追加。
157
+ spans = [[-1, -1]] + spans[:sequence_length-2]
158
+ # 特殊トークン[SEP]、[PAD]に対するダミーのspanを追加。
159
+ spans = spans + [[-1, -1]] * ( sequence_length - len(spans) )
160
+
161
+ # 必要に応じてtorch.Tensorにする。
162
+ if return_tensors == 'pt':
163
+ encoding = { k: torch.tensor([v]) for k, v in encoding.items() }
164
+
165
+ return encoding, spans
166
+
167
+ @staticmethod
168
+ def Viterbi(scores_bert, num_entity_type, penalty=10000):
169
+ """
170
+ Viterbiアルゴリズムで最適解を求める。
171
+ """
172
+ m = 2*num_entity_type + 1
173
+ penalty_matrix = np.zeros([m, m])
174
+ for i in range(m):
175
+ for j in range(1+num_entity_type, m):
176
+ if not ( (i == j) or (i+num_entity_type == j) ):
177
+ penalty_matrix[i,j] = penalty
178
+ path = [ [i] for i in range(m) ]
179
+ scores_path = scores_bert[0] - penalty_matrix[0,:]
180
+ scores_bert = scores_bert[1:]
181
+
182
+
183
+
184
+ for scores in scores_bert:
185
+ assert len(scores) == 2*num_entity_type + 1
186
+ score_matrix = np.array(scores_path).reshape(-1,1) \
187
+ + np.array(scores).reshape(1,-1) \
188
+ - penalty_matrix
189
+ scores_path = score_matrix.max(axis=0)
190
+ argmax = score_matrix.argmax(axis=0)
191
+ path_new = []
192
+ for i, idx in enumerate(argmax):
193
+ path_new.append( path[idx] + [i] )
194
+ path = path_new
195
+
196
+ labels_optimal = path[np.argmax(scores_path)]
197
+ return labels_optimal
198
+
199
+ def convert_bert_output_to_entities(self, text, scores, spans):
200
+ """
201
+ 文章、分類スコア、各トークンの位置から固有表現を得る。
202
+ 分類スコアはサイズが(系列長、ラベル数)の2次元配列
203
+ """
204
+ assert len(spans) == len(scores)
205
+ num_entity_type = self.num_entity_type
206
+
207
+ # 特殊トークンに対応する部分を取り除く
208
+ scores = [score for score, span in zip(scores, spans) if span[0]!=-1]
209
+ spans = [span for span in spans if span[0]!=-1]
210
+
211
+ # Viterbiアルゴリズムでラベルの予測値を決める。
212
+ labels = self.Viterbi(scores, num_entity_type)
213
+
214
+ # 同じラベルが連続するトークンをまとめて、固有表現を抽出する。
215
+ entities = []
216
+ for label, group \
217
+ in itertools.groupby(enumerate(labels), key=lambda x: x[1]):
218
+
219
+ group = list(group)
220
+ start = spans[group[0][0]][0]
221
+ end = spans[group[-1][0]][1]
222
+
223
+ if label != 0: # 固有表現であれば
224
+ if 1 <= label <= num_entity_type:
225
+ # ラベルが`B-`ならば、新しいentityを追加
226
+ entity = {
227
+ "name": text[start:end],
228
+ "span": [start, end],
229
+ "type_id": label
230
+ }
231
+ entities.append(entity)
232
+ else:
233
+ # ラベルが`I-`ならば、直近のentityを更新
234
+ entity['span'][1] = end
235
+ entity['name'] = text[entity['span'][0]:entity['span'][1]]
236
+
237
+ return entities
238
+