socialcomp commited on
Commit
b6af7c5
1 Parent(s): eeabe37

Upload predict.py

Browse files
Files changed (1) hide show
  1. predict.py +143 -0
predict.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # %%
3
+ from tqdm import tqdm
4
+ import unicodedata
5
+ import re
6
+ import pickle
7
+ import torch
8
+ import NER_medNLP as ner
9
+ from bs4 import BeautifulSoup
10
+
11
+
12
+ # import from_XML_to_json as XtC
13
+ # import itertools
14
+ # import random
15
+ # import json
16
+ # from torch.utils.data import DataLoader
17
+ # from transformers import BertJapaneseTokenizer, BertForTokenClassification
18
+ # import pytorch_lightning as pl
19
+ # import pandas as pd
20
+ # import numpy as np
21
+ # import codecs
22
+
23
+
24
+ #%% global変数として使う
25
+ dict_key = {}
26
+
27
+ #%%
28
+ def to_xml(data):
29
+ with open("key_attr.pkl", "rb") as tf:
30
+ key_attr = pickle.load(tf)
31
+
32
+ text = data['text']
33
+ count = 0
34
+ for i, entities in enumerate(data['entities_predicted']):
35
+ if entities == "":
36
+ return
37
+ span = entities['span']
38
+ type_id = id_to_tags[entities['type_id']].split('_')
39
+ tag = type_id[0]
40
+
41
+ if not type_id[1] == "":
42
+ attr = ' ' + value_to_key(type_id[1], key_attr) + '=' + '"' + type_id[1] + '"'
43
+ else:
44
+ attr = ""
45
+
46
+ add_tag = "<" + str(tag) + str(attr) + ">"
47
+ text = text[:span[0]+count] + add_tag + text[span[0]+count:]
48
+ count += len(add_tag)
49
+
50
+ add_tag = "</" + str(tag) + ">"
51
+ text = text[:span[1]+count] + add_tag + text[span[1]+count:]
52
+ count += len(add_tag)
53
+ return text
54
+
55
+
56
+ def predict_entities(modelpath, sentences_list, len_num_entity_type):
57
+ # model = ner.BertForTokenClassification_pl.load_from_checkpoint(
58
+ # checkpoint_path = modelpath + ".ckpt"
59
+ # )
60
+ # bert_tc = model.bert_tc.cuda()
61
+
62
+ model = ner.BertForTokenClassification_pl(modelpath, num_labels=81, lr=1e-5)
63
+ bert_tc = model.bert_tc.cuda()
64
+
65
+ MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'
66
+ tokenizer = ner.NER_tokenizer_BIO.from_pretrained(
67
+ MODEL_NAME,
68
+ num_entity_type = len_num_entity_type#Entityの数を変え忘れないように!
69
+ )
70
+
71
+ #entities_list = [] # 正解の固有表現を追加していく
72
+ entities_predicted_list = [] # 抽出された固有表現を追加していく
73
+
74
+ text_entities_set = []
75
+ for dataset in sentences_list:
76
+ text_entities = []
77
+ for sample in tqdm(dataset):
78
+ text = sample
79
+ encoding, spans = tokenizer.encode_plus_untagged(
80
+ text, return_tensors='pt'
81
+ )
82
+ encoding = { k: v.cuda() for k, v in encoding.items() }
83
+
84
+ with torch.no_grad():
85
+ output = bert_tc(**encoding)
86
+ scores = output.logits
87
+ scores = scores[0].cpu().numpy().tolist()
88
+
89
+ # 分類スコアを固有表現に変換する
90
+ entities_predicted = tokenizer.convert_bert_output_to_entities(
91
+ text, scores, spans
92
+ )
93
+
94
+ #entities_list.append(sample['entities'])
95
+ entities_predicted_list.append(entities_predicted)
96
+ text_entities.append({'text': text, 'entities_predicted': entities_predicted})
97
+ text_entities_set.append(text_entities)
98
+ return text_entities_set
99
+
100
+ def combine_sentences(text_entities_set, insert: str):
101
+ documents = []
102
+ for text_entities in tqdm(text_entities_set):
103
+ document = []
104
+ for t in text_entities:
105
+ document.append(to_xml(t))
106
+ documents.append('\n'.join(document))
107
+ return documents
108
+
109
+ def value_to_key(value, key_attr):#attributeから属性名を取得
110
+ global dict_key
111
+ if dict_key.get(value) != None:
112
+ return dict_key[value]
113
+ for k in key_attr.keys():
114
+ for v in key_attr[k]:
115
+ if value == v:
116
+ dict_key[v]=k
117
+ return k
118
+
119
+ # %%
120
+ if __name__ == '__main__':
121
+ with open("id_to_tags.pkl", "rb") as tf:
122
+ id_to_tags = pickle.load(tf)
123
+ with open("key_attr.pkl", "rb") as tf:
124
+ key_attr = pickle.load(tf)
125
+ with open('text.txt') as f:
126
+ articles_raw = f.read()
127
+
128
+
129
+ article_norm = unicodedata.normalize('NFKC', articles_raw)
130
+
131
+ sentences_raw = [s for s in re.split(r'\n', articles_raw) if s != '']
132
+ sentences_norm = [s for s in re.split(r'\n', article_norm) if s != '']
133
+
134
+ text_entities_set = predict_entities("Tomohiro/RealMedNLP_CR_JA", [sentences_norm], len(id_to_tags))
135
+
136
+
137
+ for i, texts_ent in enumerate(text_entities_set[0]):
138
+ texts_ent['text'] = sentences_raw[i]
139
+
140
+
141
+ documents = combine_sentences(text_entities_set, '\n')
142
+
143
+ print(documents[0])