File size: 8,530 Bytes
cd5ed10
c837b79
576d564
cd5ed10
576d564
 
cd5ed10
576d564
c837b79
576d564
 
 
c837b79
576d564
cd5ed10
c837b79
cd5ed10
 
c837b79
 
576d564
cd5ed10
 
c837b79
cd5ed10
 
 
 
c837b79
cd5ed10
576d564
 
 
 
 
cd5ed10
c837b79
cd5ed10
c837b79
cd5ed10
 
c837b79
 
 
 
cd5ed10
c837b79
cd5ed10
 
 
c837b79
cd5ed10
 
 
 
a459477
cd5ed10
c837b79
 
cd5ed10
 
 
 
a459477
cd5ed10
 
 
 
c837b79
 
cd5ed10
a459477
cd5ed10
 
c837b79
cd5ed10
 
 
 
 
c837b79
cd5ed10
 
 
 
 
c837b79
576d564
cd5ed10
a459477
cd5ed10
 
576d564
cd5ed10
 
 
c837b79
 
cd5ed10
 
 
 
 
 
c837b79
cd5ed10
 
c837b79
cd5ed10
576d564
 
 
 
 
 
 
 
 
 
 
 
 
c837b79
a459477
c837b79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
576d564
cd5ed10
 
a459477
 
 
 
 
cd5ed10
a459477
7316338
a459477
 
 
 
576d564
8faec96
576d564
 
 
 
a459477
 
 
 
 
576d564
a459477
 
576d564
a459477
cd5ed10
a459477
 
cd5ed10
a459477
 
cd5ed10
a459477
cd5ed10
a459477
 
 
cd5ed10
a459477
 
 
576d564
a459477
 
 
 
576d564
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd5ed10
576d564
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
# %%
import argparse
import os.path
import pickle
import unicodedata

import torch
from tqdm import tqdm

import NER_medNLP as ner
import utils
from EntityNormalizer import EntityNormalizer, EntityDictionary, DefaultDiseaseDict, DefaultDrugDict

# Compute device shared by the whole script: Apple MPS when available,
# otherwise CUDA when available, otherwise the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu')

# %% Used as a global variable: cache filled by value_to_key (attribute value -> attribute name)
dict_key = {}


# %%
def to_xml(data, id_to_tags, key_attr=None):
    """Return the sentence text with predicted entities wrapped in XML-style tags.

    Every entity is rendered as ``<tag attr="value" norm="...">...</tag>``.
    The tag and the attribute value come from the ``"<tag>_<value>"`` label
    stored in ``id_to_tags``; the attribute name is resolved with
    ``value_to_key``; ``norm`` is appended when the entity has a ``'norm'``
    entry.

    Args:
        data: Dict with ``'text'`` (the sentence) and ``'entities_predicted'``
            (entity dicts holding ``'span'`` as ``[start, end]`` offsets into
            the text, ``'type_id'`` and optionally ``'norm'``). Tags are
            inserted in list order with a running offset, so the entities are
            expected to be ordered by position and non-overlapping.
        id_to_tags: Lookup from type id to a ``"<tag>_<value>"`` label.
        key_attr: Optional mapping of attribute name to its possible values.
            When omitted it is read from ``key_attr.pkl`` in the working
            directory.

    Returns:
        The annotated text. Entities whose type id cannot be resolved are
        reported on stdout and skipped; an empty-string placeholder entity
        stops the annotation and the text built so far is returned.

    Raises:
        TypeError: If ``value_to_key`` cannot resolve an attribute value
            (it returns ``None``, which cannot be concatenated).
    """
    if key_attr is None:
        # NOTE: pickle executes arbitrary code on load; only use a trusted
        # key_attr.pkl.
        with open("key_attr.pkl", "rb") as tf:
            key_attr = pickle.load(tf)

    text = data['text']
    offset = 0  # characters inserted so far; shifts every later span
    for entity in data['entities_predicted']:
        if entity == "":
            # A bare ``return`` here used to hand ``None`` to the caller,
            # which then failed when joining sentences.
            break
        span = entity['span']
        try:
            label_parts = id_to_tags[entity['type_id']].split('_')
        except (KeyError, IndexError, TypeError):
            print("out of rage type_id", entity)
            continue
        tag = label_parts[0]
        # Labels without an underscore simply carry no attribute value.
        attr_value = label_parts[1] if len(label_parts) > 1 else ""

        attr = ""
        if attr_value != "":
            attr = ' ' + value_to_key(attr_value, key_attr) + '="' + attr_value + '"'

        if 'norm' in entity:
            attr = attr + ' norm="' + str(entity['norm']) + '"'

        open_tag = f"<{tag}{attr}>"
        start = span[0] + offset
        text = text[:start] + open_tag + text[start:]
        offset += len(open_tag)

        close_tag = f"</{tag}>"
        end = span[1] + offset
        text = text[:end] + close_tag + text[end:]
        offset += len(close_tag)
    return text


def predict_entities(model, tokenizer, sentences_list):
    """Run the token-classification model on every sentence and collect entities.

    Args:
        model: Callable invoked as ``model(**encoding)``; its result must
            expose a ``logits`` attribute.
        tokenizer: Object providing ``encode_plus_untagged`` (returns the
            encoding and the token spans) and
            ``convert_bert_output_to_entities``.
        sentences_list: Iterable of documents, each an iterable of sentence
            strings.

    Returns:
        A list with one entry per document; each entry is a list of dicts
        ``{'text': sentence, 'entities_predicted': entities}`` in sentence
        order.
    """
    # NOTE(review): the model is never switched to eval mode in this
    # function - confirm that the loader already does so.
    text_entities_set = []
    for dataset in sentences_list:
        text_entities = []
        for text in tqdm(dataset, desc='Predict', leave=False):
            encoding, spans = tokenizer.encode_plus_untagged(
                text, return_tensors='pt'
            )
            # Move every input tensor to the module-level compute device.
            encoding = {k: v.to(device) for k, v in encoding.items()}

            with torch.no_grad():
                output = model(**encoding)
                # Only the first item of the batch is used.
                scores = output.logits[0].cpu().numpy().tolist()

            # Convert the classification scores into entities.
            entities_predicted = tokenizer.convert_bert_output_to_entities(
                text, scores, spans
            )
            text_entities.append({'text': text, 'entities_predicted': entities_predicted})
        text_entities_set.append(text_entities)
    return text_entities_set


def combine_sentences(text_entities_set, id_to_tags, insert: str):
    """Render every document as one string of XML-annotated sentences.

    Args:
        text_entities_set: List of documents; each document is a list of
            per-sentence dicts accepted by ``to_xml``.
        id_to_tags: Lookup from type id to tag label, forwarded to ``to_xml``.
        insert: Separator placed between the annotated sentences of a
            document (previously ignored in favour of a hard-coded newline).

    Returns:
        A list with one annotated string per document.
    """
    return [
        insert.join(to_xml(sentence, id_to_tags) for sentence in text_entities)
        for text_entities in text_entities_set
    ]


def value_to_key(value, key_attr):
    """Return the attribute name whose value list contains ``value``.

    Results are memoised in the module-level ``dict_key`` cache. The cache is
    keyed by the value alone, so a cached answer is returned even when a
    different ``key_attr`` mapping is passed later.

    Args:
        value: Attribute value to look up.
        key_attr: Mapping of attribute name to an iterable of its values.

    Returns:
        The attribute name, or ``None`` when no attribute lists ``value``.
    """
    cached = dict_key.get(value)
    if cached is not None:
        return cached
    for name, candidates in key_attr.items():
        if any(value == candidate for candidate in candidates):
            dict_key[value] = name
            return name
    return None


# %%
def normalize_entities(text_entities_set, id_to_tags, disease_dict=None, disease_candidate_col=None, disease_normalization_col=None, disease_matching_threshold=None, drug_dict=None,
                       drug_candidate_col=None, drug_normalization_col=None, drug_matching_threshold=None):
    """Attach a normalized form to every drug and disease entity, in place.

    Entities tagged ``m-key`` are normalized with the drug dictionary and
    entities tagged ``d`` with the disease dictionary; the result is stored as
    a string under the entity's ``'norm'`` key. Other tags are left untouched.

    Args:
        text_entities_set: List of documents, each a list of per-sentence
            dicts with an ``'entities_predicted'`` list.
        id_to_tags: Lookup from type id to a ``"<tag>_<value>"`` label.
        disease_dict: Source for a custom disease ``EntityDictionary``; when
            falsy, ``DefaultDiseaseDict`` is used.
        disease_candidate_col: Forwarded to ``EntityDictionary``.
        disease_normalization_col: Forwarded to ``EntityDictionary``.
        disease_matching_threshold: Forwarded to ``EntityNormalizer``.
        drug_dict: Source for a custom drug ``EntityDictionary``; when falsy,
            ``DefaultDrugDict`` is used.
        drug_candidate_col: Forwarded to ``EntityDictionary``.
        drug_normalization_col: Forwarded to ``EntityDictionary``.
        drug_matching_threshold: Forwarded to ``EntityNormalizer``.

    Returns:
        None. The entity dicts inside ``text_entities_set`` are mutated.
    """
    if disease_dict:
        disease_dict = EntityDictionary(disease_dict, disease_candidate_col, disease_normalization_col)
    else:
        disease_dict = DefaultDiseaseDict()
    disease_normalizer = EntityNormalizer(disease_dict, matching_threshold=disease_matching_threshold)

    if drug_dict:
        drug_dict = EntityDictionary(drug_dict, drug_candidate_col, drug_normalization_col)
    else:
        drug_dict = DefaultDrugDict()
    drug_normalizer = EntityNormalizer(drug_dict, matching_threshold=drug_matching_threshold)

    # Tag -> normalizer dispatch; tags not listed here are not normalized.
    normalizers = {'m-key': drug_normalizer, 'd': disease_normalizer}

    for entry in tqdm(text_entities_set, desc='Normalization', leave=False):
        for text_entities in entry:
            for entity in text_entities['entities_predicted']:
                try:
                    tag = id_to_tags[entity['type_id']].split('_')[0]
                except (KeyError, IndexError, TypeError):
                    # Unresolvable type ids are skipped, matching how to_xml
                    # treats them, instead of aborting the whole document.
                    continue

                normalizer = normalizers.get(tag)
                if normalizer is None:
                    continue

                # normalize() returns (normalized form, score); the score is
                # not used.
                normalization, _score = normalizer.normalize(entity['name'])
                entity['norm'] = str(normalization)


def run(model, input, output=None, normalize=False, **kwargs):
    """Annotate one text file, or every file in a directory, with predicted entities.

    For each input file the text is NFKC-normalized, split into sentences,
    tagged by the model, optionally normalized against the drug/disease
    dictionaries, rendered as XML-tagged text, printed, and optionally written
    to ``output``. A failure on one file is reported and the next file is
    processed.

    Args:
        model: Path to the model checkpoint.
        input: Path to a text file or to a directory of text files.
        output: Optional destination. For a file input it is the output file
            path; for a directory input it is a directory (created when
            missing) that receives one file per input file, same file name.
        normalize: When true, run ``normalize_entities`` on the predictions.
        **kwargs: Forwarded to ``normalize_entities``.

    Returns:
        None.
    """
    # NOTE: pickle executes arbitrary code on load; only use a trusted
    # id_to_tags.pkl.
    with open("id_to_tags.pkl", "rb") as tf:
        id_to_tags = pickle.load(tf)
    num_entity_types = len(id_to_tags)

    # Load the model and tokenizer
    classification_model = ner.BertForTokenClassification_pl.from_pretrained_bin(model_path=model, num_labels=2 * num_entity_types + 1)
    bert_tc = classification_model.bert_tc.to(device)

    tokenizer = ner.NER_tokenizer_BIO.from_pretrained(
        'tohoku-nlp/bert-base-japanese-whole-word-masking',
        num_entity_type=num_entity_types  # keep in sync with the number of entity types
    )

    # Load input files
    input_is_dir = os.path.isdir(input)
    if input_is_dir:
        # Sorted so that the processing order is deterministic.
        files = [os.path.join(input, f) for f in sorted(os.listdir(input)) if os.path.isfile(os.path.join(input, f))]
    else:
        files = [input]

    for file in tqdm(files, desc="Input file"):
        try:
            # NOTE(review): files are read and written with the platform
            # default encoding - confirm whether UTF-8 should be forced.
            with open(file) as f:
                articles_raw = f.read()

            article_norm = unicodedata.normalize('NFKC', articles_raw)

            sentences_raw = utils.split_sentences(articles_raw)
            sentences_norm = utils.split_sentences(article_norm)

            text_entities_set = predict_entities(bert_tc, tokenizer, [sentences_norm])

            # Show the original (un-normalized) sentence in the output, but
            # only where the predicted spans still line up with it: NFKC can
            # change the sentence count or a sentence's length, in which case
            # the normalized text is kept so the tags stay correctly placed.
            predicted = text_entities_set[0]
            if len(sentences_raw) == len(predicted):
                for texts_ent, raw in zip(predicted, sentences_raw):
                    if len(raw) == len(texts_ent['text']):
                        texts_ent['text'] = raw

            if normalize:
                normalize_entities(text_entities_set, id_to_tags, **kwargs)

            documents = combine_sentences(text_entities_set, id_to_tags, '\n')

            tqdm.write(f"File: {file}")
            tqdm.write(documents[0])
            tqdm.write("")

            if output:
                if input_is_dir:
                    # Build the path from the file name; a plain string
                    # replace of ``input`` would also rewrite matching text
                    # inside the file name itself.
                    os.makedirs(output, exist_ok=True)
                    out_path = os.path.join(output, os.path.basename(file))
                else:
                    out_path = output
                with open(out_path, 'w') as f:
                    f.write(documents[0])

        except Exception as e:
            # Best-effort batch processing: report and move on to the next file.
            tqdm.write("Error while processing file: {}".format(file))
            tqdm.write(str(e))
            tqdm.write("")


if __name__ == '__main__':
    # Command-line entry point: parse the options and forward all of them to
    # run() as keyword arguments (argparse turns dashes into underscores, so
    # the option names match the parameters of run / normalize_entities).
    parser = argparse.ArgumentParser(description='Predict entities from text')
    parser.add_argument('-m', '--model', type=str, default='pytorch_model.bin', help='Path to model checkpoint')
    parser.add_argument('-i', '--input', type=str, default='text.txt', help='Path to text file or directory')
    parser.add_argument('-o', '--output', type=str, default=None, help='Path to output file or directory')
    parser.add_argument('-n', '--normalize', action=argparse.BooleanOptionalAction, help='Enable entity normalization', default=False)

    # Dictionary override arguments
    # NOTE(review): the *-col help texts say "Column name" but the values are
    # parsed as int - confirm whether a column index is meant.
    parser.add_argument("--drug-dict", help="File path for overriding the default drug dictionary")
    parser.add_argument("--drug-candidate-col", type=int, help="Column name for drug candidates in the CSV file (required if --drug-dict is specified)")
    parser.add_argument("--drug-normalization-col", type=int, help="Column name for drug normalization in the CSV file (required if --drug-dict is specified)")
    parser.add_argument('--drug-matching-threshold', type=int, default=50, help='Matching threshold for drug dictionary')

    parser.add_argument("--disease-dict", help="File path for overriding the default disease dictionary")
    parser.add_argument("--disease-candidate-col", type=int, help="Column name for disease candidates in the CSV file (required if --disease-dict is specified)")
    parser.add_argument("--disease-normalization-col", type=int, help="Column name for disease normalization in the CSV file (required if --disease-dict is specified)")
    parser.add_argument('--disease-matching-threshold', type=int, default=50, help='Matching threshold for disease dictionary')
    args = parser.parse_args()

    argument_dict = vars(args)
    run(**argument_dict)