File size: 9,343 Bytes
1ca6f42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 |
# %%
import itertools
from tqdm import tqdm
import numpy as np
import torch
from transformers import BertJapaneseTokenizer, BertForTokenClassification
import pytorch_lightning as pl
# from torch.utils.data import DataLoader
# import from_XML_to_json as XtC
# import random
# import json
# import unicodedata
# import pandas as pd
# %%
# 8-16
# PyTorch Lightningのモデル
class BertForTokenClassification_pl(pl.LightningModule):
def __init__(self, model_name, num_labels, lr):
super().__init__()
self.save_hyperparameters()
self.bert_tc = BertForTokenClassification.from_pretrained(
model_name,
num_labels=num_labels
)
def training_step(self, batch, batch_idx):
output = self.bert_tc(**batch)
loss = output.loss
self.log('train_loss', loss)
return loss
def validation_step(self, batch, batch_idx):
output = self.bert_tc(**batch)
val_loss = output.loss
self.log('val_loss', val_loss)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
# %%
class NER_tokenizer_BIO(BertJapaneseTokenizer):
# 初期化時に固有表現のカテゴリーの数`num_entity_type`を
# 受け入れるようにする。
def __init__(self, *args, **kwargs):
self.num_entity_type = kwargs.pop('num_entity_type')
super().__init__(*args, **kwargs)
def encode_plus_tagged(self, text, entities, max_length):
"""
文章とそれに含まれる固有表現が与えられた時に、
符号化とラベル列の作成を行う。
"""
# 固有表現の前後でtextを分割し、それぞれのラベルをつけておく。
splitted = [] # 分割後の文字列を追加していく
position = 0
for entity in entities:
start = entity['span'][0]
end = entity['span'][1]
label = entity['type_id']
splitted.append({'text':text[position:start], 'label':0})
splitted.append({'text':text[start:end], 'label':label})
position = end
splitted.append({'text': text[position:], 'label':0})
splitted = [ s for s in splitted if s['text'] ]
# 分割されたそれぞれの文章をトークン化し、ラベルをつける。
tokens = [] # トークンを追加していく
labels = [] # ラベルを追加していく
for s in splitted:
tokens_splitted = self.tokenize(s['text'])
label = s['label']
if label > 0: # 固有表現
# まずトークン全てにI-タグを付与
# 番号順O-tag:0, B-tag:1~タグの数,I-tag:タグの数〜
labels_splitted = \
[ label + self.num_entity_type ] * len(tokens_splitted)
# 先頭のトークンをB-タグにする
labels_splitted[0] = label
else: # それ以外
labels_splitted = [0] * len(tokens_splitted)
tokens.extend(tokens_splitted)
labels.extend(labels_splitted)
# 符号化を行いBERTに入力できる形式にする。
input_ids = self.convert_tokens_to_ids(tokens)
encoding = self.prepare_for_model(
input_ids,
max_length=max_length,
padding='max_length',
truncation=True
)
# ラベルに特殊トークンを追加
# max_lengthで切り取って,その前後に[CLS]と[SEP]を追加するためのラベルを入れる
labels = [0] + labels[:max_length-2] + [0]
# max_lengthに満たない場合は,満たない分を後ろ側に追加する
labels = labels + [0]*( max_length - len(labels) )
encoding['labels'] = labels
return encoding
def encode_plus_untagged(
self, text, max_length=None, return_tensors=None
):
"""
文章をトークン化し、それぞれのトークンの文章中の位置も特定しておく。
IO法のトークナイザのencode_plus_untaggedと同じ
"""
# 文章のトークン化を行い、
# それぞれのトークンと文章中の文字列を対応づける。
tokens = [] # トークンを追加していく。
tokens_original = [] # トークンに対応する文章中の文字列を追加していく。
words = self.word_tokenizer.tokenize(text) # MeCabで単語に分割
for word in words:
# 単語をサブワードに分割
tokens_word = self.subword_tokenizer.tokenize(word)
tokens.extend(tokens_word)
if tokens_word[0] == '[UNK]': # 未知語への対応
tokens_original.append(word)
else:
tokens_original.extend([
token.replace('##','') for token in tokens_word
])
# 各トークンの文章中での位置を調べる。(空白の位置を考慮する)
position = 0
spans = [] # トークンの位置を追加していく。
for token in tokens_original:
l = len(token)
while 1:
if token != text[position:position+l]:
position += 1
else:
spans.append([position, position+l])
position += l
break
# 符号化を行いBERTに入力できる形式にする。
input_ids = self.convert_tokens_to_ids(tokens)
encoding = self.prepare_for_model(
input_ids,
max_length=max_length,
padding='max_length' if max_length else False,
truncation=True if max_length else False
)
sequence_length = len(encoding['input_ids'])
# 特殊トークン[CLS]に対するダミーのspanを追加。
spans = [[-1, -1]] + spans[:sequence_length-2]
# 特殊トークン[SEP]、[PAD]に対するダミーのspanを追加。
spans = spans + [[-1, -1]] * ( sequence_length - len(spans) )
# 必要に応じてtorch.Tensorにする。
if return_tensors == 'pt':
encoding = { k: torch.tensor([v]) for k, v in encoding.items() }
return encoding, spans
@staticmethod
def Viterbi(scores_bert, num_entity_type, penalty=10000):
"""
Viterbiアルゴリズムで最適解を求める。
"""
m = 2*num_entity_type + 1
penalty_matrix = np.zeros([m, m])
for i in range(m):
for j in range(1+num_entity_type, m):
if not ( (i == j) or (i+num_entity_type == j) ):
penalty_matrix[i,j] = penalty
path = [ [i] for i in range(m) ]
scores_path = scores_bert[0] - penalty_matrix[0,:]
scores_bert = scores_bert[1:]
for scores in scores_bert:
assert len(scores) == 2*num_entity_type + 1
score_matrix = np.array(scores_path).reshape(-1,1) \
+ np.array(scores).reshape(1,-1) \
- penalty_matrix
scores_path = score_matrix.max(axis=0)
argmax = score_matrix.argmax(axis=0)
path_new = []
for i, idx in enumerate(argmax):
path_new.append( path[idx] + [i] )
path = path_new
labels_optimal = path[np.argmax(scores_path)]
return labels_optimal
def convert_bert_output_to_entities(self, text, scores, spans):
"""
文章、分類スコア、各トークンの位置から固有表現を得る。
分類スコアはサイズが(系列長、ラベル数)の2次元配列
"""
assert len(spans) == len(scores)
num_entity_type = self.num_entity_type
# 特殊トークンに対応する部分を取り除く
scores = [score for score, span in zip(scores, spans) if span[0]!=-1]
spans = [span for span in spans if span[0]!=-1]
# Viterbiアルゴリズムでラベルの予測値を決める。
labels = self.Viterbi(scores, num_entity_type)
# 同じラベルが連続するトークンをまとめて、固有表現を抽出する。
entities = []
for label, group \
in itertools.groupby(enumerate(labels), key=lambda x: x[1]):
group = list(group)
start = spans[group[0][0]][0]
end = spans[group[-1][0]][1]
if label != 0: # 固有表現であれば
if 1 <= label <= num_entity_type:
# ラベルが`B-`ならば、新しいentityを追加
entity = {
"name": text[start:end],
"span": [start, end],
"type_id": label
}
entities.append(entity)
else:
# ラベルが`I-`ならば、直近のentityを更新
entity['span'][1] = end
entity['name'] = text[entity['span'][0]:entity['span'][1]]
return entities
|