File size: 4,547 Bytes
b6af7c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b26a83f
b6af7c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b26a83f
b6af7c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b26a83f
b6af7c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61b6f64
b6af7c5
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144

# %%
import functools
import pickle
import re
import unicodedata

import torch
from bs4 import BeautifulSoup
from tqdm import tqdm

import NER_medNLP as ner


# import from_XML_to_json as XtC
# import itertools
# import random
# import json
# from torch.utils.data import DataLoader
# from transformers import BertJapaneseTokenizer, BertForTokenClassification
# import pytorch_lightning as pl
# import pandas as pd
# import numpy as np
# import codecs
# Inference device: the first CUDA GPU if one is present, else the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

#%% Module-level cache, used as a global by value_to_key():
# attribute value -> attribute name.
dict_key = dict()

#%%
@functools.lru_cache(maxsize=1)
def _load_key_attr():
    """Load key_attr.pkl once and reuse it on every later call."""
    # NOTE: pickle.load can execute arbitrary code; only load a trusted file.
    with open("key_attr.pkl", "rb") as tf:
        return pickle.load(tf)


def to_xml(data, id2tags=None, key_attr=None):
    """Return data['text'] with every predicted entity wrapped in an XML-style tag.

    Args:
        data: dict with 'text' (str) and 'entities_predicted', a sequence of
            entity dicts holding 'span' (start, end character offsets into
            'text') and 'type_id' (key into the tag table). Entries equal to
            "" are skipped.
        id2tags: mapping type_id -> label of the form "<tag>_<attr value>".
            The attribute part may be empty or the underscore absent. Defaults
            to the module-level id_to_tags, which is only defined when this
            file runs as a script.
        key_attr: attribute table passed to value_to_key(). Defaults to the
            contents of key_attr.pkl, loaded once and cached.

    Returns:
        The tagged text. The text itself is not XML-escaped.

    Raises:
        KeyError: if an attribute value has no attribute name in key_attr.
    """
    if id2tags is None:
        id2tags = id_to_tags  # module global set in the __main__ section
    if key_attr is None:
        key_attr = _load_key_attr()

    text = data['text']
    # Spans index the untagged text, so each later insertion point is shifted
    # by the number of tag characters inserted so far. This assumes entities
    # are ordered by position and do not overlap (TODO confirm with tokenizer).
    count = 0
    for entity in data['entities_predicted']:
        if entity == "":
            # Returning here would hand None to '\n'.join in
            # combine_sentences; skip the placeholder instead.
            continue
        start, end = entity['span'][0], entity['span'][1]
        parts = id2tags[entity['type_id']].split('_')
        tag = parts[0]
        value = parts[1] if len(parts) > 1 else ""

        if value:
            attr_name = value_to_key(value, key_attr)
            if attr_name is None:
                raise KeyError(f"attribute value {value!r} not found in key_attr")
            attr = f' {attr_name}="{value}"'
        else:
            attr = ""

        open_tag = f"<{tag}{attr}>"
        text = text[:start + count] + open_tag + text[start + count:]
        count += len(open_tag)

        close_tag = f"</{tag}>"
        text = text[:end + count] + close_tag + text[end + count:]
        count += len(close_tag)
    return text


def predict_entities(modelpath, sentences_list, len_num_entity_type,
                     num_labels=81,
                     tokenizer_name='cl-tohoku/bert-base-japanese-whole-word-masking'):
    """Run the token-classification model over each sentence and decode entities.

    Args:
        modelpath: model identifier/path passed to
            ner.BertForTokenClassification_pl.
        sentences_list: list of datasets; each dataset is an iterable of
            sentence strings.
        len_num_entity_type: num_entity_type for the BIO tokenizer. Do not
            forget to update it when the entity set changes.
        num_labels: size of the model's classification head (default 81).
        tokenizer_name: pretrained tokenizer name passed to
            ner.NER_tokenizer_BIO.from_pretrained.

    Returns:
        A list parallel to sentences_list. Each element is a list of
        {'text': sentence, 'entities_predicted': entities} dicts, one per sentence.
    """
    model = ner.BertForTokenClassification_pl(modelpath, num_labels=num_labels, lr=1e-5)
    bert_tc = model.bert_tc.to(device)
    # Inference only: eval mode disables training-time behaviour such as dropout.
    bert_tc.eval()

    tokenizer = ner.NER_tokenizer_BIO.from_pretrained(
        tokenizer_name,
        num_entity_type=len_num_entity_type,
    )

    text_entities_set = []
    for dataset in sentences_list:
        text_entities = []
        for text in tqdm(dataset):
            encoding, spans = tokenizer.encode_plus_untagged(text, return_tensors='pt')
            encoding = {k: v.to(device) for k, v in encoding.items()}

            with torch.no_grad():
                scores = bert_tc(**encoding).logits[0].cpu().numpy().tolist()

            # Convert the per-token classification scores into entities.
            entities_predicted = tokenizer.convert_bert_output_to_entities(
                text, scores, spans
            )
            text_entities.append({'text': text, 'entities_predicted': entities_predicted})
        text_entities_set.append(text_entities)
    return text_entities_set

def combine_sentences(text_entities_set, insert: str = '\n'):
    """Tag every sentence with to_xml() and join each document's sentences.

    Args:
        text_entities_set: list of documents; each document is a list of
            {'text', 'entities_predicted'} dicts as built by predict_entities.
        insert: separator placed between the tagged sentences of a document
            (defaults to a newline).

    Returns:
        One tagged string per document.
    """
    return [
        insert.join(to_xml(t) for t in text_entities)
        for text_entities in tqdm(text_entities_set)
    ]

def value_to_key(value, key_attr):
    """Return the attribute name whose value collection in key_attr contains value.

    key_attr maps attribute name -> iterable of attribute values. The first
    matching name, in key_attr's iteration order, is returned and memoized in
    the module-level dict_key. The memo is keyed on value only, so a result
    found with one key_attr is reused even if a different table is passed later.

    Returns:
        The attribute name, or None when no collection contains value. Misses
        are not memoized.
    """
    cached = dict_key.get(value)
    if cached is not None:
        return cached
    for key, values in key_attr.items():
        # Element-wise equality, matching a plain scan even when values is a str.
        if any(value == v for v in values):
            dict_key[value] = key
            return key
    return None

# %%
if __name__ == '__main__':
    # type_id -> "<tag>_<attr value>" table. It is kept at module scope
    # because to_xml() reads it as a global.
    with open("id_to_tags.pkl", "rb") as tf:
        id_to_tags = pickle.load(tf)
    # Assumes text.txt is UTF-8; without an explicit encoding the platform
    # default (e.g. cp932 on Windows) would be used.
    with open('text.txt', encoding='utf-8') as f:
        articles_raw = f.read()

    # The model is run on NFKC-normalised text.
    article_norm = unicodedata.normalize('NFKC', articles_raw)

    sentences_raw = [s for s in articles_raw.split('\n') if s != '']
    sentences_norm = [s for s in article_norm.split('\n') if s != '']

    text_entities_set = predict_entities("sociocom/MedNER-CR-JA", [sentences_norm], len(id_to_tags))

    # Restore the original characters for output. The predicted spans index
    # the normalised sentence, so swap the raw sentence in only when
    # normalisation kept its length; otherwise the tags would be misplaced.
    for texts_ent, raw in zip(text_entities_set[0], sentences_raw):
        if len(raw) == len(texts_ent['text']):
            texts_ent['text'] = raw

    documents = combine_sentences(text_entities_set, '\n')

    print(documents[0])