gabrielandrade2 committed
Commit: a459477
Parent(s): 8faec96

Avoid reloading the model for each document to be processed, add more detailed progress bars
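In short, model and tokenizer construction moves out of predict_entities() and into run(), so the BERT weights are loaded once and reused for every input file. The pattern, as a minimal runnable sketch (the loader and predictor here are dummy stand-ins, not the real ner/utils APIs from predict.py):

from tqdm import tqdm

def load_resources():
    # Stand-in for the expensive loading done in predict.py
    # (BertForTokenClassification_pl.from_pretrained_bin + NER_tokenizer_BIO).
    print("loading model and tokenizer...")
    return {"kind": "model"}, {"kind": "tokenizer"}

def predict_entities(model, tokenizer, sentences):
    # Stand-in for the real prediction over one document's sentences.
    return [f"({model['kind']}, {tokenizer['kind']}) -> {s}" for s in sentences]

def run(files):
    model, tokenizer = load_resources()  # loaded once, before the file loop
    for file in tqdm(files, desc="Input file"):
        # Reused for every file; before this commit, the equivalent of
        # load_resources() ran inside predict_entities, once per document.
        predict_entities(model, tokenizer, [file])

run(["doc1.txt", "doc2.txt", "doc3.txt"])

Passing the loaded objects into predict_entities() as arguments also decouples it from the loading code, which makes it easier to exercise with a mock model.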

Files changed (1): predict.py (+39, -27)
predict.py CHANGED
@@ -53,14 +53,7 @@ def to_xml(data, id_to_tags):
     return text
 
 
-def predict_entities(modelpath, sentences_list, len_num_entity_type):
-    model = ner.BertForTokenClassification_pl.from_pretrained_bin(model_path=modelpath, num_labels=2 * len_num_entity_type + 1)
-    bert_tc = model.bert_tc.to(device)
-
-    tokenizer = ner.NER_tokenizer_BIO.from_pretrained(
-        'cl-tohoku/bert-base-japanese-whole-word-masking',
-        num_entity_type=len_num_entity_type  # don't forget to update the number of entity types!
-    )
+def predict_entities(model, tokenizer, sentences_list):
 
     # entities_list = []  # accumulate the gold named entities here
     entities_predicted_list = []  # accumulate the extracted named entities here
@@ -68,7 +61,7 @@ def predict_entities(modelpath, sentences_list, len_num_entity_type):
     text_entities_set = []
     for dataset in sentences_list:
         text_entities = []
-        for sample in tqdm(dataset, desc='Predict'):
+        for sample in tqdm(dataset, desc='Predict', leave=False):
             text = sample
             encoding, spans = tokenizer.encode_plus_untagged(
                 text, return_tensors='pt'
@@ -76,7 +69,7 @@ def predict_entities(modelpath, sentences_list, len_num_entity_type):
             encoding = {k: v.to(device) for k, v in encoding.items()}
 
             with torch.no_grad():
-                output = bert_tc(**encoding)
+                output = model(**encoding)
                 scores = output.logits
                 scores = scores[0].cpu().numpy().tolist()
 
@@ -94,7 +87,7 @@ def predict_entities(modelpath, sentences_list, len_num_entity_type):
 
 def combine_sentences(text_entities_set, id_to_tags, insert: str):
     documents = []
-    for text_entities in tqdm(text_entities_set):
+    for text_entities in text_entities_set:
         document = []
         for t in text_entities:
             document.append(to_xml(t, id_to_tags))
@@ -128,7 +121,7 @@ def normalize_entities(text_entities_set, id_to_tags, disease_dict=None, disease
     drug_dict = DefaultDrugDict()
     drug_normalizer = EntityNormalizer(drug_dict, matching_threshold=drug_matching_threshold)
 
-    for entry in text_entities_set:
+    for entry in tqdm(text_entities_set, desc='Normalization', leave=False):
         for text_entities in entry:
             entities = text_entities['entities_predicted']
             for entity in entities:
@@ -148,36 +141,55 @@ def normalize_entities(text_entities_set, id_to_tags, disease_dict=None, disease
 def run(model, input, output=None, normalize=False, **kwargs):
     with open("id_to_tags.pkl", "rb") as tf:
         id_to_tags = pickle.load(tf)
+    len_num_entity_type = len(id_to_tags)
+
+    # Load the model and tokenizer
+    classification_model = ner.BertForTokenClassification_pl.from_pretrained_bin(model_path=model, num_labels=2 * len_num_entity_type + 1)
+    bert_tc = classification_model.bert_tc.to(device)
 
+    tokenizer = ner.NER_tokenizer_BIO.from_pretrained(
+        'cl-tohoku/bert-base-japanese-whole-word-masking',
+        num_entity_type=len_num_entity_type  # don't forget to update the number of entity types!
+    )
+
+    # Load input files
     if (os.path.isdir(input)):
         files = [os.path.join(input, f) for f in os.listdir(input) if os.path.isfile(os.path.join(input, f))]
     else:
        files = [input]
 
     for file in tqdm(files, desc="Input file"):
-        with open(file) as f:
-            articles_raw = f.read()
-
-        article_norm = unicodedata.normalize('NFKC', articles_raw)
-
-        sentences_raw = utils.split_sentences(articles_raw)
-        sentences_norm = utils.split_sentences(article_norm)
-
-        text_entities_set = predict_entities(model, [sentences_norm], len(id_to_tags))
-
-        for i, texts_ent in enumerate(text_entities_set[0]):
-            texts_ent['text'] = sentences_raw[i]
-
-        if normalize:
-            normalize_entities(text_entities_set, id_to_tags, **kwargs)
-
-        documents = combine_sentences(text_entities_set, id_to_tags, '\n')
-
-        print(documents[0])
-
-        if output:
-            with open(file.replace(input, output), 'w') as f:
-                f.write(documents[0])
+        try:
+            with open(file) as f:
+                articles_raw = f.read()
+
+            article_norm = unicodedata.normalize('NFKC', articles_raw)
+
+            sentences_raw = utils.split_sentences(articles_raw)
+            sentences_norm = utils.split_sentences(article_norm)
+
+            text_entities_set = predict_entities(bert_tc, tokenizer, [sentences_norm])
+
+            for i, texts_ent in enumerate(text_entities_set[0]):
+                texts_ent['text'] = sentences_raw[i]
+
+            if normalize:
+                normalize_entities(text_entities_set, id_to_tags, **kwargs)
+
+            documents = combine_sentences(text_entities_set, id_to_tags, '\n')
+
+            tqdm.write(f"File: {file}")
+            tqdm.write(documents[0])
+            tqdm.write("")
+
+            if output:
+                with open(file.replace(input, output), 'w') as f:
+                    f.write(documents[0])
+
+        except Exception as e:
+            tqdm.write("Error while processing file: {}".format(file))
+            tqdm.write(str(e))
+            tqdm.write("")
 
 
 if __name__ == '__main__':
 
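A note on the progress-bar changes, which are easy to miss in the hunks above: the nested bars gain leave=False so they clear themselves when finished, and print() becomes tqdm.write(), which prints above any active bars instead of interleaving with them. A minimal runnable illustration (file names are made up):

import time
from tqdm import tqdm

for file in tqdm(["docA.txt", "docB.txt"], desc="Input file"):
    # leave=False: the inner bar erases itself when its loop ends, so only
    # the outer per-file bar stays on screen (as in predict_entities and
    # normalize_entities after this commit).
    for _ in tqdm(range(40), desc="Predict", leave=False):
        time.sleep(0.01)
    # tqdm.write prints above the active bars instead of through them,
    # which is why print(documents[0]) became tqdm.write(...) in run().
    tqdm.write(f"File: {file}")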