gabrielandrade2 commited on
Commit
c837b79
1 Parent(s): cd5ed10

Add normalization methods

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
  model.safetensors filter=lfs diff=lfs merge=lfs -text
 
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
  model.safetensors filter=lfs diff=lfs merge=lfs -text
36
+ *.csv filter=lfs diff=lfs merge=lfs -text
EntityNormalizer.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import mojimoji
2
+ import pandas as pd
3
+ from rapidfuzz import fuzz, process
4
+
5
+
6
+ class EntityDictionary:
7
+
8
+ def __init__(self, path):
9
+ self.df = pd.read_csv(path)
10
+
11
+ def get_candidates_list(self):
12
+ return self.df.iloc[:, 0].to_list()
13
+
14
+ def get_normalization_list(self):
15
+ return self.df.iloc[:, 2].to_list()
16
+
17
+ def get_normalized_term(self, term):
18
+ return self.df[self.df.iloc[:, 0] == term].iloc[:, 2].item()
19
+
20
+
21
+ class DiseaseDict(EntityDictionary):
22
+
23
+ def __init__(self):
24
+ super().__init__('dictionaries/disease_dict.csv')
25
+
26
+
27
+ class DrugDict(EntityDictionary):
28
+
29
+ def __init__(self):
30
+ super().__init__('dictionaries/drug_dict.csv')
31
+
32
+
33
+ class EntityNormalizer:
34
+
35
+ def __init__(self, database: EntityDictionary, matching_method=fuzz.ratio, matching_threshold=0):
36
+ self.database = database
37
+ self.matching_method = matching_method
38
+ self.matching_threshold = matching_threshold
39
+ self.candidates = [mojimoji.han_to_zen(x) for x in self.database.get_candidates_list()]
40
+
41
+ def normalize(self, term):
42
+ term = mojimoji.han_to_zen(term)
43
+ preferred_candidate = process.extractOne(term, self.candidates, scorer=self.matching_method)
44
+ score = preferred_candidate[1]
45
+
46
+ if score > self.matching_threshold:
47
+ ret = self.database.get_normalized_term(preferred_candidate[0])
48
+ return ('' if pd.isna(ret) else ret), score
49
+ else:
50
+ return '', score
51
+
README.md CHANGED
@@ -1,15 +1,15 @@
1
  ---
2
- language:
3
- - ja
4
  license:
5
- - cc-by-4.0
6
  tags:
7
- - NER
8
- - medical documents
9
  datasets:
10
- - MedTxt-CR-JA-training-v2.xml
11
  metrics:
12
- - NTCIR-16 Real-MedNLP subtask 1
13
  ---
14
 
15
 
@@ -18,18 +18,45 @@ This is a model for named entity recognition of Japanese medical documents.
18
  ### How to use
19
 
20
  Download the following five files and put into the same folder.
 
21
  - id_to_tags.pkl
22
  - key_attr.pkl
23
  - NER_medNLP.py
24
  - predict.py
25
  - text.txt (This is an input file which should be predicted, which could be changed.)
26
 
27
- You can use this model by running predict.py.
28
 
29
  ```
30
  python3 predict.py
31
  ```
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  ### Input Example
34
 
35
  ```
@@ -40,10 +67,9 @@ python3 predict.py
40
  ### Output Example
41
 
42
  ```
43
- <d certainty="positive">肥大型心筋症、心房細動</d>に対して<m-key state="executed">WF</m-key>投与が開始となった。
44
- <timex3 type="med">治療経過中</timex3>に<d certainty="positive">非持続性心室頻拍</d>が認められたため<m-key state="executed">アミオダロン</m-key>が併用となった。
45
  ```
46
 
47
  ### Publication
48
 
49
- Tomohiro Nishiyama, Aki Ando, Mihiro Nishidani, Shuntaro Yada, Shoko Wakamiya, Eiji Aramaki: NAISTSOC at the NTCIR-16 Real-MedNLP Task, In Proceedings of the 16th NTCIR Conference on Evaluation of Information Access Technologies (NTCIR-16), pp. 330-333, 2022
1
  ---
2
+ language:
3
+ - ja
4
  license:
5
+ - cc-by-4.0
6
  tags:
7
+ - NER
8
+ - medical documents
9
  datasets:
10
+ - MedTxt-CR-JA-training-v2.xml
11
  metrics:
12
+ - NTCIR-16 Real-MedNLP subtask 1
13
  ---
14
 
15
 
18
  ### How to use
19
 
20
  Download the following five files and put into the same folder.
21
+
22
  - id_to_tags.pkl
23
  - key_attr.pkl
24
  - NER_medNLP.py
25
  - predict.py
26
  - text.txt (This is an input file which should be predicted, which could be changed.)
27
 
28
+ You can use this model by running `predict.py`.
29
 
30
  ```
31
  python3 predict.py
32
  ```
33
 
34
+ #### Entity normalization
35
+
36
+ This model supports entity normalization via dictionary matching. The dictionary is a list of medical terms or
37
+ drugs and their standard forms.
38
+
39
+ Two different dictionaries are used for drug and disease normalization, stored in the `dictionaries` folder as
40
+ `drug_dict.csv` and `disease_dict.csv`, respectively.
41
+
42
+ To enable normalization you can add the `--normalize` flag to the `predict.py` command.
43
+
44
+ ```
45
+ python3 predict.py --normalize
46
+ ```
47
+
48
+ Normalization will add the `norm` attribute to the output XML tags. This attribute can be empty if a normalized form of
49
+ the term is not found.
50
+
51
+ The provided disease normalization dictionary (`dictionaties/disease_dict.csv`) is based on the [Manbyo Dictionary](https://sociocom.naist.jp/manbyo-dic-en/) and provides normalization to the standard ICD code for the diseases.
52
+
53
+ The default drug dictionary (`dictionaties/drug_dict.csv`) is based on the [Hyakuyaku Dictionary](https://sociocom.naist.jp/hyakuyaku-dic-en/).
54
+
55
+ The dictionary is a CSV file with three columns: the first column is the surface form term and the third column contain
56
+ its standard form. The second column is not used.
57
+
58
+ User can freely change the dictionary to fit their needs, as long as the format and filename are kept.
59
+
60
  ### Input Example
61
 
62
  ```
67
  ### Output Example
68
 
69
  ```
70
+ <d certainty="positive" norm="I422">肥大型心筋症、心房細動</d>に対して<m-key state="executed" norm="ワルファリンカリウム">WF</m-key>投与が開始となった。
71
+ <timex3 type="med">治療経過中</timex3>に<d certainty="positive" norm="I472">非持続性心室頻拍</d>が認められたため<m-key state="executed" norm="アミオダロン塩酸塩">アミオダロン</m-key>が併用となった。
72
  ```
73
 
74
  ### Publication
75
 
 
dictionaries/disease_dict.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e104832f7bc912497936c11c7196f7a7949a5e69d9414f47a6a6a3bf7caec6b
3
+ size 20832536
dictionaries/drug_dict.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab7b62baab746dc53ef6909840dfc2f8d2a80e591c930fbb0a035907082a4bec
3
+ size 5442740
predict.py CHANGED
@@ -1,54 +1,49 @@
1
-
2
  # %%
 
 
3
  from tqdm import tqdm
4
  import unicodedata
5
  import re
6
  import pickle
7
  import torch
8
  import NER_medNLP as ner
9
- from bs4 import BeautifulSoup
10
-
11
-
12
- # import from_XML_to_json as XtC
13
- # import itertools
14
- # import random
15
- # import json
16
- # from torch.utils.data import DataLoader
17
- # from transformers import BertJapaneseTokenizer, BertForTokenClassification
18
- # import pytorch_lightning as pl
19
- # import pandas as pd
20
- # import numpy as np
21
- # import codecs
22
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
23
 
24
- #%% global変数として使う
25
  dict_key = {}
26
 
27
- #%%
 
28
  def to_xml(data):
29
  with open("key_attr.pkl", "rb") as tf:
30
  key_attr = pickle.load(tf)
31
-
32
  text = data['text']
33
  count = 0
34
  for i, entities in enumerate(data['entities_predicted']):
35
  if entities == "":
36
- return
37
  span = entities['span']
38
  type_id = id_to_tags[entities['type_id']].split('_')
39
  tag = type_id[0]
40
-
41
  if not type_id[1] == "":
42
- attr = ' ' + value_to_key(type_id[1], key_attr) + '=' + '"' + type_id[1] + '"'
43
  else:
44
  attr = ""
45
-
 
 
 
46
  add_tag = "<" + str(tag) + str(attr) + ">"
47
- text = text[:span[0]+count] + add_tag + text[span[0]+count:]
48
  count += len(add_tag)
49
 
50
  add_tag = "</" + str(tag) + ">"
51
- text = text[:span[1]+count] + add_tag + text[span[1]+count:]
52
  count += len(add_tag)
53
  return text
54
 
@@ -58,18 +53,18 @@ def predict_entities(modelpath, sentences_list, len_num_entity_type):
58
  # checkpoint_path = modelpath + ".ckpt"
59
  # )
60
  # bert_tc = model.bert_tc.cuda()
61
-
62
- model = ner.BertForTokenClassification_pl(modelpath, num_labels=81, lr=1e-5)
63
  bert_tc = model.bert_tc.to(device)
64
 
65
  MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'
66
  tokenizer = ner.NER_tokenizer_BIO.from_pretrained(
67
  MODEL_NAME,
68
- num_entity_type = len_num_entity_type#Entityの数を変え忘れないように!
69
  )
70
 
71
- #entities_list = [] # 正解の固有表現を追加していく
72
- entities_predicted_list = [] # 抽出された固有表現を追加していく
73
 
74
  text_entities_set = []
75
  for dataset in sentences_list:
@@ -79,24 +74,25 @@ def predict_entities(modelpath, sentences_list, len_num_entity_type):
79
  encoding, spans = tokenizer.encode_plus_untagged(
80
  text, return_tensors='pt'
81
  )
82
- encoding = { k: v.to(device) for k, v in encoding.items() }
83
-
84
  with torch.no_grad():
85
  output = bert_tc(**encoding)
86
  scores = output.logits
87
  scores = scores[0].cpu().numpy().tolist()
88
-
89
  # 分類スコアを固有表現に変換する
90
  entities_predicted = tokenizer.convert_bert_output_to_entities(
91
  text, scores, spans
92
  )
93
 
94
- #entities_list.append(sample['entities'])
95
  entities_predicted_list.append(entities_predicted)
96
  text_entities.append({'text': text, 'entities_predicted': entities_predicted})
97
  text_entities_set.append(text_entities)
98
  return text_entities_set
99
 
 
100
  def combine_sentences(text_entities_set, insert: str):
101
  documents = []
102
  for text_entities in tqdm(text_entities_set):
@@ -106,25 +102,51 @@ def combine_sentences(text_entities_set, insert: str):
106
  documents.append('\n'.join(document))
107
  return documents
108
 
109
- def value_to_key(value, key_attr):#attributeから属性名を取得
 
110
  global dict_key
111
  if dict_key.get(value) != None:
112
  return dict_key[value]
113
  for k in key_attr.keys():
114
  for v in key_attr[k]:
115
  if value == v:
116
- dict_key[v]=k
117
  return k
118
 
 
119
  # %%
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  if __name__ == '__main__':
 
 
 
 
121
  with open("id_to_tags.pkl", "rb") as tf:
122
  id_to_tags = pickle.load(tf)
123
  with open("key_attr.pkl", "rb") as tf:
124
  key_attr = pickle.load(tf)
125
  with open('text.txt') as f:
126
  articles_raw = f.read()
127
-
128
 
129
  article_norm = unicodedata.normalize('NFKC', articles_raw)
130
 
@@ -133,10 +155,11 @@ if __name__ == '__main__':
133
 
134
  text_entities_set = predict_entities("sociocom/RealMedNLP_CR_JA", [sentences_norm], len(id_to_tags))
135
 
136
-
137
  for i, texts_ent in enumerate(text_entities_set[0]):
138
  texts_ent['text'] = sentences_raw[i]
139
 
 
 
140
 
141
  documents = combine_sentences(text_entities_set, '\n')
142
 
 
1
  # %%
2
+ import argparse
3
+
4
  from tqdm import tqdm
5
  import unicodedata
6
  import re
7
  import pickle
8
  import torch
9
  import NER_medNLP as ner
10
+
11
+ from EntityNormalizer import EntityNormalizer, DiseaseDict, DrugDict
12
+
 
 
 
 
 
 
 
 
 
 
13
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
14
 
15
+ # %% global変数として使う
16
  dict_key = {}
17
 
18
+
19
+ # %%
20
  def to_xml(data):
21
  with open("key_attr.pkl", "rb") as tf:
22
  key_attr = pickle.load(tf)
23
+
24
  text = data['text']
25
  count = 0
26
  for i, entities in enumerate(data['entities_predicted']):
27
  if entities == "":
28
+ return
29
  span = entities['span']
30
  type_id = id_to_tags[entities['type_id']].split('_')
31
  tag = type_id[0]
32
+
33
  if not type_id[1] == "":
34
+ attr = ' ' + value_to_key(type_id[1], key_attr) + '=' + '"' + type_id[1] + '"'
35
  else:
36
  attr = ""
37
+
38
+ if 'norm' in entities:
39
+ attr = attr + ' norm="' + str(entities['norm']) + '"'
40
+
41
  add_tag = "<" + str(tag) + str(attr) + ">"
42
+ text = text[:span[0] + count] + add_tag + text[span[0] + count:]
43
  count += len(add_tag)
44
 
45
  add_tag = "</" + str(tag) + ">"
46
+ text = text[:span[1] + count] + add_tag + text[span[1] + count:]
47
  count += len(add_tag)
48
  return text
49
 
53
  # checkpoint_path = modelpath + ".ckpt"
54
  # )
55
  # bert_tc = model.bert_tc.cuda()
56
+
57
+ model = ner.BertForTokenClassification_pl(modelpath, num_labels=81, lr=1e-5)
58
  bert_tc = model.bert_tc.to(device)
59
 
60
  MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'
61
  tokenizer = ner.NER_tokenizer_BIO.from_pretrained(
62
  MODEL_NAME,
63
+ num_entity_type=len_num_entity_type # Entityの数を変え忘れないように!
64
  )
65
 
66
+ # entities_list = [] # 正解の固有表現を追加していく
67
+ entities_predicted_list = [] # 抽出された固有表現を追加していく
68
 
69
  text_entities_set = []
70
  for dataset in sentences_list:
74
  encoding, spans = tokenizer.encode_plus_untagged(
75
  text, return_tensors='pt'
76
  )
77
+ encoding = {k: v.to(device) for k, v in encoding.items()}
78
+
79
  with torch.no_grad():
80
  output = bert_tc(**encoding)
81
  scores = output.logits
82
  scores = scores[0].cpu().numpy().tolist()
83
+
84
  # 分類スコアを固有表現に変換する
85
  entities_predicted = tokenizer.convert_bert_output_to_entities(
86
  text, scores, spans
87
  )
88
 
89
+ # entities_list.append(sample['entities'])
90
  entities_predicted_list.append(entities_predicted)
91
  text_entities.append({'text': text, 'entities_predicted': entities_predicted})
92
  text_entities_set.append(text_entities)
93
  return text_entities_set
94
 
95
+
96
  def combine_sentences(text_entities_set, insert: str):
97
  documents = []
98
  for text_entities in tqdm(text_entities_set):
102
  documents.append('\n'.join(document))
103
  return documents
104
 
105
+
106
+ def value_to_key(value, key_attr): # attributeから属性名を取得
107
  global dict_key
108
  if dict_key.get(value) != None:
109
  return dict_key[value]
110
  for k in key_attr.keys():
111
  for v in key_attr[k]:
112
  if value == v:
113
+ dict_key[v] = k
114
  return k
115
 
116
+
117
  # %%
118
+ def normalize_entities(text_entities_set):
119
+ disease_normalizer = EntityNormalizer(DiseaseDict(), matching_threshold=50)
120
+ drug_normalizer = EntityNormalizer(DrugDict(), matching_threshold=50)
121
+
122
+ for entry in text_entities_set:
123
+ for text_entities in entry:
124
+ entities = text_entities['entities_predicted']
125
+ for entity in entities:
126
+ tag = id_to_tags[entity['type_id']].split('_')[0]
127
+
128
+ normalizer = drug_normalizer if tag == 'm-key' \
129
+ else disease_normalizer if tag == 'd' \
130
+ else None
131
+
132
+ if normalizer is None:
133
+ continue
134
+
135
+ normalization, score = normalizer.normalize(entity['name'])
136
+ entity['norm'] = str(normalization)
137
+
138
+
139
  if __name__ == '__main__':
140
+ parser = argparse.ArgumentParser(description='Predict entities from text')
141
+ parser.add_argument('--normalize', action=argparse.BooleanOptionalAction, help='Enable entity normalization')
142
+ args = parser.parse_args()
143
+
144
  with open("id_to_tags.pkl", "rb") as tf:
145
  id_to_tags = pickle.load(tf)
146
  with open("key_attr.pkl", "rb") as tf:
147
  key_attr = pickle.load(tf)
148
  with open('text.txt') as f:
149
  articles_raw = f.read()
 
150
 
151
  article_norm = unicodedata.normalize('NFKC', articles_raw)
152
 
155
 
156
  text_entities_set = predict_entities("sociocom/RealMedNLP_CR_JA", [sentences_norm], len(id_to_tags))
157
 
 
158
  for i, texts_ent in enumerate(text_entities_set[0]):
159
  texts_ent['text'] = sentences_raw[i]
160
 
161
+ if args.normalize:
162
+ normalize_entities(text_entities_set)
163
 
164
  documents = combine_sentences(text_entities_set, '\n')
165
 
requirements.txt ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohttp==3.8.4
2
+ aiosignal==1.3.1
3
+ async-timeout==4.0.2
4
+ attrs==22.2.0
5
+ certifi==2022.12.7
6
+ charset-normalizer==3.1.0
7
+ et-xmlfile==1.1.0
8
+ filelock==3.11.0
9
+ frozenlist==1.3.3
10
+ fsspec==2023.4.0
11
+ fugashi==1.2.1
12
+ huggingface-hub==0.13.4
13
+ idna==3.4
14
+ ipadic==1.0.0
15
+ Jinja2==3.1.2
16
+ Levenshtein==0.20.9
17
+ lightning-utilities==0.8.0
18
+ MarkupSafe==2.1.2
19
+ mojimoji==0.0.12
20
+ mpmath==1.3.0
21
+ multidict==6.0.4
22
+ networkx==3.1
23
+ numpy==1.24.2
24
+ openpyxl==3.1.2
25
+ packaging==23.0
26
+ pandas==2.0.0
27
+ python-dateutil==2.8.2
28
+ pytorch-lightning==2.0.1.post0
29
+ pytz==2023.3
30
+ PyYAML==6.0
31
+ rapidfuzz==2.15.1
32
+ regex==2023.3.23
33
+ requests==2.28.2
34
+ six==1.16.0
35
+ soupsieve==2.4
36
+ sympy==1.11.1
37
+ tokenizers==0.13.3
38
+ torch==2.0.0
39
+ torchmetrics==0.11.4
40
+ tqdm==4.65.0
41
+ transformers==4.27.4
42
+ typing_extensions==4.5.0
43
+ tzdata==2023.3
44
+ urllib3==1.26.15
45
+ yarl==1.8.2