Spaces:
Runtime error
Runtime error
ajit
commited on
Commit
•
d154d66
1
Parent(s):
d9023f8
Initial commit
Browse files- BatchInference.py +707 -0
- aggregate_server_json.py +541 -0
- app.py +271 -0
- batched_main_NER.py +905 -0
- bbc/bbc_labels.txt +0 -0
- bbc/desc_bbc_config.json +6 -0
- bbc/ner_bbc_config.json +8 -0
- bbc/vocab.txt +0 -0
- bio/a100_labels.txt +0 -0
- bio/desc_a100_config.json +6 -0
- bio/ner_a100_config.json +8 -0
- bio/vocab.txt +0 -0
- common.py +153 -0
- common_descs.txt +149 -0
- config_utils.py +19 -0
- ensemble_config.json +37 -0
- entity_types_consolidated.txt +18 -0
- logs/failed_queries_log.txt +0 -0
- logs/query_logs.txt +0 -0
- requirements.txt +3 -0
- untagged_terms.txt +0 -0
BatchInference.py
ADDED
@@ -0,0 +1,707 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import subprocess
|
3 |
+
#from pytorch_transformers import *
|
4 |
+
from transformers import *
|
5 |
+
import pdb
|
6 |
+
import operator
|
7 |
+
from collections import OrderedDict
|
8 |
+
import numpy as np
|
9 |
+
import argparse
|
10 |
+
import sys
|
11 |
+
import traceback
|
12 |
+
import string
|
13 |
+
import common as utils
|
14 |
+
import config_utils as cf
|
15 |
+
import requests
|
16 |
+
import json
|
17 |
+
import streamlit as st
|
18 |
+
|
19 |
+
# OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
|
20 |
+
import logging
|
21 |
+
logging.basicConfig(level=logging.INFO)
|
22 |
+
|
23 |
+
|
24 |
+
DEFAULT_TOP_K = 20
|
25 |
+
DEFAULT_CONFIG = "./server_config.json"
|
26 |
+
DEFAULT_MODEL_PATH='./'
|
27 |
+
DEFAULT_LABELS_PATH='./labels.txt'
|
28 |
+
DEFAULT_TO_LOWER=False
|
29 |
+
DESC_FILE="./common_descs.txt"
|
30 |
+
SPECIFIC_TAG=":__entity__"
|
31 |
+
MAX_TOKENIZED_SENT_LENGTH = 500 #additional buffer for CLS SEP and entity term
|
32 |
+
|
33 |
+
try:
|
34 |
+
from subprocess import DEVNULL # Python 3.
|
35 |
+
except ImportError:
|
36 |
+
DEVNULL = open(os.devnull, 'wb')
|
37 |
+
|
38 |
+
|
39 |
+
@st.cache()
|
40 |
+
def load_bert_model(model_name,to_lower):
|
41 |
+
try:
|
42 |
+
bert_tokenizer = BertTokenizer.from_pretrained(model_name,do_lower_case=to_lower)
|
43 |
+
bert_model = BertForMaskedLM.from_pretrained(model_name)
|
44 |
+
return bert_tokenizer,bert_model
|
45 |
+
except Exception as e:
|
46 |
+
pass
|
47 |
+
|
48 |
+
def read_descs(file_name):
|
49 |
+
ret_dict = {}
|
50 |
+
with open(file_name) as fp:
|
51 |
+
line = fp.readline().rstrip("\n")
|
52 |
+
if (len(line) >= 1):
|
53 |
+
ret_dict[line] = 1
|
54 |
+
while line:
|
55 |
+
line = fp.readline().rstrip("\n")
|
56 |
+
if (len(line) >= 1):
|
57 |
+
ret_dict[line] = 1
|
58 |
+
return ret_dict
|
59 |
+
|
60 |
+
def read_vocab(file_name):
|
61 |
+
l_vocab_dict = {}
|
62 |
+
o_vocab_dict = {}
|
63 |
+
with open(file_name) as fp:
|
64 |
+
for line in fp:
|
65 |
+
line = line.rstrip('\n')
|
66 |
+
if (len(line) > 0):
|
67 |
+
l_vocab_dict[line.lower()] = line #If there are multiple cased versions they will be collapsed into one. which is okay since we have the original saved. This is only used
|
68 |
+
#when a word is not found in its pristine form in the original list.
|
69 |
+
o_vocab_dict[line] = line
|
70 |
+
print("Read vocab file:",len(o_vocab_dict))
|
71 |
+
return o_vocab_dict,l_vocab_dict
|
72 |
+
|
73 |
+
def consolidate_labels(existing_node,new_labels,new_counts):
|
74 |
+
"""Consolidates all the labels and counts for terms ignoring casing
|
75 |
+
|
76 |
+
For instance, egfr may not have an entity label associated with it
|
77 |
+
but eGFR and EGFR may have. So if input is egfr, then this function ensures
|
78 |
+
the combined entities set fo eGFR and EGFR is made so as to return that union
|
79 |
+
for egfr
|
80 |
+
"""
|
81 |
+
new_dict = {}
|
82 |
+
existing_labels_arr = existing_node["label"].split('/')
|
83 |
+
existing_counts_arr = existing_node["counts"].split('/')
|
84 |
+
new_labels_arr = new_labels.split('/')
|
85 |
+
new_counts_arr = new_counts.split('/')
|
86 |
+
assert(len(existing_labels_arr) == len(existing_counts_arr))
|
87 |
+
assert(len(new_labels_arr) == len(new_counts_arr))
|
88 |
+
for i in range(len(existing_labels_arr)):
|
89 |
+
new_dict[existing_labels_arr[i]] = int(existing_counts_arr[i])
|
90 |
+
for i in range(len(new_labels_arr)):
|
91 |
+
if (new_labels_arr[i] in new_dict):
|
92 |
+
new_dict[new_labels_arr[i]] += int(new_counts_arr[i])
|
93 |
+
else:
|
94 |
+
new_dict[new_labels_arr[i]] = int(new_counts_arr[i])
|
95 |
+
sorted_d = OrderedDict(sorted(new_dict.items(), key=lambda kv: kv[1], reverse=True))
|
96 |
+
ret_labels_str = ""
|
97 |
+
ret_counts_str = ""
|
98 |
+
count = 0
|
99 |
+
for key in sorted_d:
|
100 |
+
if (count == 0):
|
101 |
+
ret_labels_str = key
|
102 |
+
ret_counts_str = str(sorted_d[key])
|
103 |
+
else:
|
104 |
+
ret_labels_str += '/' + key
|
105 |
+
ret_counts_str += '/' + str(sorted_d[key])
|
106 |
+
count += 1
|
107 |
+
return {"label":ret_labels_str,"counts":ret_counts_str}
|
108 |
+
|
109 |
+
|
110 |
+
def read_labels(labels_file):
|
111 |
+
terms_dict = OrderedDict()
|
112 |
+
lc_terms_dict = OrderedDict()
|
113 |
+
with open(labels_file,encoding="utf-8") as fin:
|
114 |
+
count = 1
|
115 |
+
for term in fin:
|
116 |
+
term = term.strip("\n")
|
117 |
+
term = term.split()
|
118 |
+
if (len(term) == 3):
|
119 |
+
terms_dict[term[2]] = {"label":term[0],"counts":term[1]}
|
120 |
+
lc_term = term[2].lower()
|
121 |
+
if (lc_term in lc_terms_dict):
|
122 |
+
lc_terms_dict[lc_term] = consolidate_labels(lc_terms_dict[lc_term],term[0],term[1])
|
123 |
+
else:
|
124 |
+
lc_terms_dict[lc_term] = {"label":term[0],"counts":term[1]}
|
125 |
+
count += 1
|
126 |
+
else:
|
127 |
+
print("Invalid line:",term)
|
128 |
+
assert(0)
|
129 |
+
print("count of labels in " + labels_file + ":", len(terms_dict))
|
130 |
+
return terms_dict,lc_terms_dict
|
131 |
+
|
132 |
+
|
133 |
+
class BatchInference:
|
134 |
+
def __init__(self, config_file,path,to_lower,patched,topk,abbrev,tokmod,vocab_path,labels_file,delimsep):
|
135 |
+
print("Model path:",path,"lower casing set to:",to_lower," is patched ", patched)
|
136 |
+
self.path = path
|
137 |
+
base_path = cf.read_config(config_file)["BASE_PATH"] if ("BASE_PATH" in cf.read_config(config_file)) else "./"
|
138 |
+
desc_file_path = cf.read_config(config_file)["DESC_FILE"] if ("DESC_FILE" in cf.read_config(config_file)) else DESC_FILE
|
139 |
+
self.labels_dict,self.lc_labels_dict = read_labels(labels_file)
|
140 |
+
#self.tokenizer = BertTokenizer.from_pretrained(path,do_lower_case=to_lower) ### Set this to to True for uncased models
|
141 |
+
#self.model = BertForMaskedLM.from_pretrained(path)
|
142 |
+
self.tokenizer, self.model = load_bert_model(path,to_lower)
|
143 |
+
self.model.eval()
|
144 |
+
#st.info("model loaded")
|
145 |
+
self.descs = read_descs(desc_file_path)
|
146 |
+
#st.info("descs loaded")
|
147 |
+
self.top_k = topk
|
148 |
+
self.patched = patched
|
149 |
+
self.abbrev = abbrev
|
150 |
+
self.tokmod = tokmod
|
151 |
+
self.delimsep = delimsep
|
152 |
+
self.truncated_fp = open(base_path + "truncated_sentences.txt","a")
|
153 |
+
self.always_log_fp = open(base_path + "CI_LOGS.txt","a")
|
154 |
+
if (cf.read_config(config_file)["USE_CLS"] == "1"): #Models like Bert base cased return same prediction for CLS regardless of input. So ignore CLS
|
155 |
+
print("************** USE CLS: Turned ON for this model. ******* ")
|
156 |
+
self.use_cls = True
|
157 |
+
else:
|
158 |
+
print("************** USE CLS: Turned OFF for this model. ******* ")
|
159 |
+
self.use_cls = False
|
160 |
+
if (cf.read_config(config_file)["LOG_DESCS"] == "1"):
|
161 |
+
self.log_descs = True
|
162 |
+
self.ci_fp = open(base_path + "log_ci_predictions.txt","w")
|
163 |
+
self.cs_fp = open(base_path + "log_cs_predictions.txt","w")
|
164 |
+
else:
|
165 |
+
self.log_descs = False
|
166 |
+
self.pos_server_url = cf.read_config(config_file)["POS_SERVER_URL"]
|
167 |
+
#st.info("Attemting to load vocab file")
|
168 |
+
if (tokmod):
|
169 |
+
self.o_vocab_dict,self.l_vocab_dict = read_vocab(vocab_path + "/vocab.txt")
|
170 |
+
else:
|
171 |
+
self.o_vocab_dict = {}
|
172 |
+
self.l_vocab_dict = {}
|
173 |
+
# st.info("Constructor complete")
|
174 |
+
#pdb.set_trace()
|
175 |
+
|
176 |
+
def dispatch_request(self,url):
|
177 |
+
max_retries = 10
|
178 |
+
attempts = 0
|
179 |
+
while True:
|
180 |
+
try:
|
181 |
+
r = requests.get(url,timeout=1000)
|
182 |
+
if (r.status_code == 200):
|
183 |
+
return r
|
184 |
+
except:
|
185 |
+
print("Request:", url, " failed. Retrying...")
|
186 |
+
attempts += 1
|
187 |
+
if (attempts >= max_retries):
|
188 |
+
print("Request:", url, " failed")
|
189 |
+
break
|
190 |
+
|
191 |
+
def modify_text_to_match_vocab(self,text):
|
192 |
+
ret_arr = []
|
193 |
+
text = text.split()
|
194 |
+
for word in text:
|
195 |
+
if (word in self.o_vocab_dict):
|
196 |
+
ret_arr.append(word)
|
197 |
+
else:
|
198 |
+
if (word.lower() in self.l_vocab_dict):
|
199 |
+
ret_arr.append(self.l_vocab_dict[word.lower()])
|
200 |
+
else:
|
201 |
+
ret_arr.append(word)
|
202 |
+
return ' '.join(ret_arr)
|
203 |
+
|
204 |
+
#This is bad hack for prototyping - parsing from text output as opposed to json
|
205 |
+
def extract_POS(self,text):
|
206 |
+
arr = text.split('\n')
|
207 |
+
if (len(arr) > 0):
|
208 |
+
start_pos = 0
|
209 |
+
for i,line in enumerate(arr):
|
210 |
+
if (len(line) > 0):
|
211 |
+
start_pos += 1
|
212 |
+
continue
|
213 |
+
else:
|
214 |
+
break
|
215 |
+
#print(arr[start_pos:])
|
216 |
+
terms_arr = []
|
217 |
+
for i,line in enumerate(arr[start_pos:]):
|
218 |
+
terms = line.split('\t')
|
219 |
+
if (len(terms) == 5):
|
220 |
+
#print(terms)
|
221 |
+
terms_arr.append(terms)
|
222 |
+
return terms_arr
|
223 |
+
|
224 |
+
def masked_word_first_letter_capitalize(self,entity):
|
225 |
+
arr = entity.split()
|
226 |
+
ret_arr = []
|
227 |
+
for term in arr:
|
228 |
+
if (len(term) > 1 and term[0].islower() and term[1].islower()):
|
229 |
+
ret_arr.append(term[0].upper() + term[1:])
|
230 |
+
else:
|
231 |
+
ret_arr.append(term)
|
232 |
+
return ' '.join(ret_arr)
|
233 |
+
|
234 |
+
|
235 |
+
def gen_single_phrase_sentences(self,terms_arr,span_arr):
|
236 |
+
sentence_template = "%s is a entity"
|
237 |
+
#print(span_arr)
|
238 |
+
sentences = []
|
239 |
+
singleton_spans_arr = []
|
240 |
+
run_index = 0
|
241 |
+
entity = ""
|
242 |
+
singleton_span = []
|
243 |
+
while (run_index < len(span_arr)):
|
244 |
+
if (span_arr[run_index] == 1):
|
245 |
+
while (run_index < len(span_arr)):
|
246 |
+
if (span_arr[run_index] == 1):
|
247 |
+
#print(terms_arr[run_index][WORD_POS],end=' ')
|
248 |
+
if (len(entity) == 0):
|
249 |
+
entity = terms_arr[run_index][utils.WORD_POS]
|
250 |
+
else:
|
251 |
+
entity = entity + " " + terms_arr[run_index][utils.WORD_POS]
|
252 |
+
singleton_span.append(1)
|
253 |
+
run_index += 1
|
254 |
+
else:
|
255 |
+
break
|
256 |
+
#print()
|
257 |
+
for i in sentence_template.split():
|
258 |
+
if (i != "%s"):
|
259 |
+
singleton_span.append(0)
|
260 |
+
entity = self.masked_word_first_letter_capitalize(entity)
|
261 |
+
if (self.tokmod):
|
262 |
+
entity = self.modify_text_to_match_vocab(entity)
|
263 |
+
sentence = sentence_template % entity
|
264 |
+
sentences.append(sentence)
|
265 |
+
singleton_spans_arr.append(singleton_span)
|
266 |
+
#print(sentence)
|
267 |
+
#rint(singleton_span)
|
268 |
+
entity = ""
|
269 |
+
singleton_span = []
|
270 |
+
else:
|
271 |
+
run_index += 1
|
272 |
+
return sentences,singleton_spans_arr
|
273 |
+
|
274 |
+
|
275 |
+
|
276 |
+
def gen_padded_sentence(self,text,max_tokenized_sentence_length,tokenized_text_arr,orig_tokenized_length_arr,indexed_tokens_arr,attention_mask_arr,to_replace):
|
277 |
+
if (to_replace):
|
278 |
+
text_arr = text.split()
|
279 |
+
new_text_arr = []
|
280 |
+
for i in range(len(text_arr)):
|
281 |
+
if (text_arr[i] == "entity" ):
|
282 |
+
new_text_arr.append( "[MASK]")
|
283 |
+
else:
|
284 |
+
new_text_arr.append(text_arr[i])
|
285 |
+
text = ' '.join(new_text_arr)
|
286 |
+
text = '[CLS] ' + text + ' [SEP]'
|
287 |
+
tokenized_text = self.tokenizer.tokenize(text)
|
288 |
+
indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text)
|
289 |
+
tok_length = len(indexed_tokens)
|
290 |
+
max_tokenized_sentence_length = max_tokenized_sentence_length if tok_length <= max_tokenized_sentence_length else tok_length
|
291 |
+
indexed_tokens_arr.append(indexed_tokens)
|
292 |
+
attention_mask_arr.append([1]*tok_length)
|
293 |
+
tokenized_text_arr.append(tokenized_text)
|
294 |
+
orig_tokenized_length_arr.append(tokenized_text)
|
295 |
+
return max_tokenized_sentence_length
|
296 |
+
|
297 |
+
|
298 |
+
|
299 |
+
def find_entity(self,word):
|
300 |
+
entities = self.labels_dict
|
301 |
+
lc_entities = self.lc_labels_dict
|
302 |
+
in_vocab = False
|
303 |
+
#words = self.filter_glue_words(words) #do not filter glue words anymore. Let them pass through
|
304 |
+
l_word = word.lower()
|
305 |
+
if l_word.isdigit():
|
306 |
+
ret_label = "MEASURE"
|
307 |
+
ret_counts = str(1)
|
308 |
+
elif (word in entities):
|
309 |
+
ret_label = entities[word]["label"]
|
310 |
+
ret_counts = entities[word]["counts"]
|
311 |
+
in_vocab = True
|
312 |
+
elif (l_word in entities):
|
313 |
+
ret_label = entities[l_word]["label"]
|
314 |
+
ret_counts = entities[l_word]["counts"]
|
315 |
+
in_vocab = True
|
316 |
+
elif (l_word in lc_entities):
|
317 |
+
ret_label = lc_entities[l_word]["label"]
|
318 |
+
ret_counts = lc_entities[l_word]["counts"]
|
319 |
+
in_vocab = True
|
320 |
+
else:
|
321 |
+
ret_label = "OTHER"
|
322 |
+
ret_counts = "1"
|
323 |
+
if (ret_label == "OTHER"):
|
324 |
+
ret_label = "UNTAGGED_ENTITY"
|
325 |
+
ret_counts = "1"
|
326 |
+
#print(word,ret_label,ret_counts)
|
327 |
+
return ret_label,ret_counts,in_vocab
|
328 |
+
|
329 |
+
#This is just a trivial hack for consistency of CI prediction of numbers
|
330 |
+
def override_ci_number_predictions(self,masked_sent):
|
331 |
+
words = masked_sent.split()
|
332 |
+
words_count = len(words)
|
333 |
+
if (len(words) == 4 and words[words_count-1] == "entity" and words[words_count -2] == "a" and words[words_count -3] == "is" and words[0].isnumeric()): #only integers skipped
|
334 |
+
return True,"two","1","NUMBER"
|
335 |
+
else:
|
336 |
+
return False,"","",""
|
337 |
+
|
338 |
+
def override_ci_for_vocab_terms(self,masked_sent):
|
339 |
+
words = masked_sent.split()
|
340 |
+
words_count = len(words)
|
341 |
+
if (len(words) == 4 and words[words_count-1] == "entity" and words[words_count -2] == "a" and words[words_count -3] == "is"):
|
342 |
+
entity,entity_count,in_vocab = self.find_entity(words[0])
|
343 |
+
if (in_vocab):
|
344 |
+
return True,words[0],entity_count,entity
|
345 |
+
return False,"","",""
|
346 |
+
|
347 |
+
|
348 |
+
|
349 |
+
def normalize_sent(self,sent):
|
350 |
+
normalized_tokens = "!\"%();?[]`{}"
|
351 |
+
end_tokens = "!,.:;?"
|
352 |
+
sent = sent.rstrip()
|
353 |
+
if (len(sent) > 1):
|
354 |
+
if (self.delimsep):
|
355 |
+
for i in range(len(normalized_tokens)):
|
356 |
+
sent = sent.replace(normalized_tokens[i],' ' + normalized_tokens[i] + ' ')
|
357 |
+
sent = sent.rstrip()
|
358 |
+
if (not sent.endswith(":__entity__")):
|
359 |
+
last_char = sent[-1]
|
360 |
+
if (last_char not in end_tokens): #End all sentences with a period if not already present in sentence.
|
361 |
+
sent = sent + ' . '
|
362 |
+
print("Normalized sent",sent)
|
363 |
+
return sent
|
364 |
+
|
365 |
+
def truncate_sent_if_too_long(self,text):
|
366 |
+
truncated_count = 0
|
367 |
+
orig_sent = text
|
368 |
+
while (True):
|
369 |
+
tok_text = '[CLS] ' + text + ' [SEP]'
|
370 |
+
tokenized_text = self.tokenizer.tokenize(tok_text)
|
371 |
+
if (len(tokenized_text) < MAX_TOKENIZED_SENT_LENGTH):
|
372 |
+
break
|
373 |
+
text = ' '.join(text.split()[:-1])
|
374 |
+
truncated_count += 1
|
375 |
+
if (truncated_count > 0):
|
376 |
+
print("Input sentence was truncated by: ", truncated_count, " tokens")
|
377 |
+
self.truncated_fp.write("Input sentence was truncated by: " + str(truncated_count) + " tokens\n")
|
378 |
+
self.truncated_fp.write(orig_sent + "\n")
|
379 |
+
self.truncated_fp.write(text + "\n\n")
|
380 |
+
return text
|
381 |
+
|
382 |
+
|
383 |
+
def get_descriptors(self,sent,pos_arr):
|
384 |
+
'''
|
385 |
+
Batched creation of descriptors given a sentence.
|
386 |
+
1) Find noun phrases to tag in a sentence if user did not explicitly tag.
|
387 |
+
2) Create 'N' CS and CI sentences if there are N phrases to tag. Total 2*N sentences
|
388 |
+
3) Create a batch padding all sentences to the maximum sentence length.
|
389 |
+
4) Perform inference on batch
|
390 |
+
5) Return json of descriptors for the ooriginal sentence as well as all CI sentences
|
391 |
+
'''
|
392 |
+
#Truncate sent if the tokenized sent is longer than max sent length
|
393 |
+
#st.info("in get descriptors")
|
394 |
+
sent = self.truncate_sent_if_too_long(sent)
|
395 |
+
#This is a modification of input text to words in vocab that match it in case insensitive manner.
|
396 |
+
#This is *STILL* required when we are using subwords too for prediction. The prediction quality is still better.
|
397 |
+
#An example is Mesothelioma is caused by exposure to asbestos. The quality of prediction is better when Mesothelioma is not split by lowercasing with A100 model
|
398 |
+
if (self.tokmod):
|
399 |
+
sent = self.modify_text_to_match_vocab(sent)
|
400 |
+
|
401 |
+
#The input sentence is normalized. Specifically all input is terminated with a punctuation if not already present. Also some of the punctuation marks are separated from text if glued to a word(disabled by default for test set sync)
|
402 |
+
sent = self.normalize_sent(sent)
|
403 |
+
|
404 |
+
#Step 1. Find entities to tag if user did not explicitly tag terms
|
405 |
+
#All noun phrases are tagged for prediction
|
406 |
+
if (SPECIFIC_TAG in sent):
|
407 |
+
terms_arr = utils.set_POS_based_on_entities(sent)
|
408 |
+
else:
|
409 |
+
if (pos_arr is None):
|
410 |
+
assert(0)
|
411 |
+
url = self.pos_server_url + sent.replace('"','\'')
|
412 |
+
r = self.dispatch_request(url)
|
413 |
+
terms_arr = self.extract_POS(r.text)
|
414 |
+
else:
|
415 |
+
# st.info("Reusing Pos arr")
|
416 |
+
terms_arr = pos_arr
|
417 |
+
|
418 |
+
print(terms_arr)
|
419 |
+
#Note span arr only contains phrases in the input that need to be tagged - not the span of all phrases in sentences
|
420 |
+
#Step 2. Create N CS sentences
|
421 |
+
#This returns masked sentences for all positions
|
422 |
+
main_sent_arr,masked_sent_arr,span_arr = utils.detect_masked_positions(terms_arr)
|
423 |
+
ignore_cs = True if (len(masked_sent_arr) == 1 and len(masked_sent_arr[0]) == 2 and masked_sent_arr[0][0] == "__entity__" and masked_sent_arr[0][1] == ".") else False #This is a boundary condition to avoid using cs if the input is just trying to get entity type for a phrase. There is no sentence context in that case.
|
424 |
+
|
425 |
+
|
426 |
+
#Step 2. Create N CI sentences
|
427 |
+
singleton_sentences,not_used_singleton_spans_arr = self.gen_single_phrase_sentences(terms_arr,span_arr)
|
428 |
+
|
429 |
+
|
430 |
+
#We now have 2*N sentences
|
431 |
+
max_tokenized_sentence_length = 0
|
432 |
+
tokenized_text_arr = []
|
433 |
+
indexed_tokens_arr = []
|
434 |
+
attention_mask_arr = []
|
435 |
+
all_sentences_arr = []
|
436 |
+
orig_tokenized_length_arr = []
|
437 |
+
assert(len(masked_sent_arr) == len(singleton_sentences))
|
438 |
+
for ci_s,cs_s in zip(singleton_sentences,masked_sent_arr):
|
439 |
+
all_sentences_arr.append(ci_s)
|
440 |
+
max_tokenized_sentence_length = self.gen_padded_sentence(ci_s,max_tokenized_sentence_length,tokenized_text_arr,orig_tokenized_length_arr,indexed_tokens_arr,attention_mask_arr,True)
|
441 |
+
cs_s = ' '.join(cs_s).replace("__entity__","entity")
|
442 |
+
all_sentences_arr.append(cs_s)
|
443 |
+
max_tokenized_sentence_length = self.gen_padded_sentence(cs_s,max_tokenized_sentence_length,tokenized_text_arr,orig_tokenized_length_arr,indexed_tokens_arr,attention_mask_arr,True)
|
444 |
+
|
445 |
+
|
446 |
+
#pad all sentences with length less than max sentence length. This includes the full sentence too since we used indexed_tokens_arr
|
447 |
+
for i in range(len(indexed_tokens_arr)):
|
448 |
+
padding = [self.tokenizer.pad_token_id]*(max_tokenized_sentence_length - len(indexed_tokens_arr[i]))
|
449 |
+
att_padding = [0]*(max_tokenized_sentence_length - len(indexed_tokens_arr[i]))
|
450 |
+
if (len(padding) > 0):
|
451 |
+
indexed_tokens_arr[i].extend(padding)
|
452 |
+
attention_mask_arr[i].extend(att_padding)
|
453 |
+
|
454 |
+
|
455 |
+
assert(len(main_sent_arr) == len(span_arr))
|
456 |
+
assert(len(all_sentences_arr) == len(indexed_tokens_arr))
|
457 |
+
assert(len(all_sentences_arr) == len(attention_mask_arr))
|
458 |
+
assert(len(all_sentences_arr) == len(tokenized_text_arr))
|
459 |
+
assert(len(all_sentences_arr) == len(orig_tokenized_length_arr))
|
460 |
+
# Convert inputs to PyTorch tensors
|
461 |
+
tokens_tensor = torch.tensor(indexed_tokens_arr)
|
462 |
+
attention_tensors = torch.tensor(attention_mask_arr)
|
463 |
+
|
464 |
+
|
465 |
+
print("Input:",sent)
|
466 |
+
ret_obj = OrderedDict()
|
467 |
+
with torch.no_grad():
|
468 |
+
predictions = self.model(tokens_tensor, attention_mask=attention_tensors)
|
469 |
+
for sent_index in range(len(predictions[0])):
|
470 |
+
|
471 |
+
#print("*** Current sentence ***",all_sentences_arr[sent_index])
|
472 |
+
if (self.log_descs):
|
473 |
+
fp = self.cs_fp if sent_index %2 != 0 else self.ci_fp
|
474 |
+
fp.write("\nCurrent sentence: " + all_sentences_arr[sent_index] + "\n")
|
475 |
+
prediction = "ci_prediction" if (sent_index %2 == 0 ) else "cs_prediction"
|
476 |
+
out_index = int(sent_index/2) + 1
|
477 |
+
if (out_index not in ret_obj):
|
478 |
+
ret_obj[out_index] = {}
|
479 |
+
assert(prediction not in ret_obj[out_index])
|
480 |
+
ret_obj[out_index][prediction] = {}
|
481 |
+
ret_obj[out_index][prediction]["sentence"] = all_sentences_arr[sent_index]
|
482 |
+
curr_sent_arr = []
|
483 |
+
ret_obj[out_index][prediction]["descs"] = curr_sent_arr
|
484 |
+
|
485 |
+
for word in range(len(tokenized_text_arr[sent_index])):
|
486 |
+
if (word == len(tokenized_text_arr[sent_index]) - 1): # SEP is skipped for CI and CS
|
487 |
+
continue
|
488 |
+
if (sent_index %2 == 0 and (word != 0 and word != len(orig_tokenized_length_arr[sent_index]) - 2)): #For all CI sentences pick only the neighbors of CLS and the last word of the sentence (X is a entity)
|
489 |
+
#if (sent_index %2 == 0 and (word != 0 and word != len(orig_tokenized_length_arr[sent_index]) - 2) and word != len(orig_tokenized_length_arr[sent_index]) - 3): #For all CI sentences - just pick CLS, "a" and "entity"
|
490 |
+
#if (sent_index %2 == 0 and (word != 0 and (word == len(orig_tokenized_length_arr[sent_index]) - 4))): #For all CI sentences pick ALL terms excluding "is" in "X is a entity"
|
491 |
+
continue
|
492 |
+
if (sent_index %2 == 0 and (word == 0 and not self.use_cls)): #This is for models like bert base cased where we cant use CLS - it is the same for all words.
|
493 |
+
continue
|
494 |
+
|
495 |
+
if (sent_index %2 != 0 and tokenized_text_arr[sent_index][word] != "[MASK]"): # for all CS sentences skip all terms except the mask position
|
496 |
+
continue
|
497 |
+
|
498 |
+
|
499 |
+
results_dict = {}
|
500 |
+
masked_index = word
|
501 |
+
#pick all model predictions for current position word
|
502 |
+
if (self.patched):
|
503 |
+
for j in range(len(predictions[0][0][sent_index][masked_index])):
|
504 |
+
tok = tokenizer.convert_ids_to_tokens([j])[0]
|
505 |
+
results_dict[tok] = float(predictions[0][0][sent_index][masked_index][j].tolist())
|
506 |
+
else:
|
507 |
+
for j in range(len(predictions[0][sent_index][masked_index])):
|
508 |
+
tok = self.tokenizer.convert_ids_to_tokens([j])[0]
|
509 |
+
results_dict[tok] = float(predictions[0][sent_index][masked_index][j].tolist())
|
510 |
+
k = 0
|
511 |
+
#sort it - big to small
|
512 |
+
sorted_d = OrderedDict(sorted(results_dict.items(), key=lambda kv: kv[1], reverse=True))
|
513 |
+
|
514 |
+
|
515 |
+
#print("********* Top predictions for token: ",tokenized_text_arr[sent_index][word])
|
516 |
+
if (self.log_descs):
|
517 |
+
fp.write("********* Top predictions for token: " + tokenized_text_arr[sent_index][word] + "\n")
|
518 |
+
if (sent_index %2 == 0): #For CI sentences, just pick half for CLS and entity position to match with CS counts
|
519 |
+
if (self.use_cls): #If we are not using [CLS] for models like BBC, then take all top k from the entity prediction
|
520 |
+
top_k = self.top_k/2
|
521 |
+
else:
|
522 |
+
top_k = self.top_k
|
523 |
+
else:
|
524 |
+
top_k = self.top_k
|
525 |
+
#Looping through each descriptor prediction for a position and picking it up subject to some conditions
|
526 |
+
for index in sorted_d:
|
527 |
+
#if (index in string.punctuation or index.startswith('##') or len(index) == 1 or index.startswith('.') or index.startswith('[')):
|
528 |
+
if index.lower() in self.descs: #these have almost no entity info - glue words like "the","a"
|
529 |
+
continue
|
530 |
+
#if (index in string.punctuation or len(index) == 1 or index.startswith('.') or index.startswith('[') or index.startswith("#")):
|
531 |
+
if (index in string.punctuation or len(index) == 1 or index.startswith('.') or index.startswith('[')):
|
532 |
+
continue
|
533 |
+
if (index.startswith("#")): #subwords suggest model is trying to predict a multi word term that generally tends to be noisy. So penalize. Count and skip
|
534 |
+
k += 1
|
535 |
+
continue
|
536 |
+
#print(index,round(float(sorted_d[index]),4))
|
537 |
+
if (sent_index % 2 != 0):
|
538 |
+
#CS predictions
|
539 |
+
entity,entity_count,dummy = self.find_entity(index)
|
540 |
+
if (self.log_descs):
|
541 |
+
self.cs_fp.write(index + " " + entity + " " + entity_count + " " + str(round(float(sorted_d[index]),4)) + "\n")
|
542 |
+
if (not ignore_cs):
|
543 |
+
curr_sent_arr.append({"desc":index,"e":entity,"e_count":entity_count,"v":str(round(float(sorted_d[index]),4))})
|
544 |
+
if (all_sentences_arr[sent_index].strip().rstrip(".").strip().endswith("entity")):
|
545 |
+
self.always_log_fp.write(' '.join(all_sentences_arr[sent_index].split()[:-1]) + " " + index + " :__entity__\n")
|
546 |
+
else:
|
547 |
+
#CI predictions of the form X is a entity
|
548 |
+
entity,entity_count,dummy = self.find_entity(index) #index is one of the predicted descs for the [CLS]/[MASK] psition
|
549 |
+
number_override,override_index,override_entity_count,override_entity = self.override_ci_number_predictions(all_sentences_arr[sent_index]) #Note this override just uses the sentence to override all descs
|
550 |
+
if (number_override): #note the prediction for this position still takes the prediction float values model returns
|
551 |
+
index = override_index
|
552 |
+
entity_count = override_entity_count
|
553 |
+
entity = override_entity
|
554 |
+
else:
|
555 |
+
if (not self.use_cls or word != 0):
|
556 |
+
override,override_index,override_entity_count,override_entity = self.override_ci_for_vocab_terms(all_sentences_arr[sent_index]) #this also uses the sentence to override, ignoring descs, except reusing the prediction score
|
557 |
+
if (override): #note the prediction for this position still takes the prediction float values model returns
|
558 |
+
index = override_index
|
559 |
+
entity_count = override_entity_count
|
560 |
+
entity = override_entity
|
561 |
+
k = top_k #just add this override once. We dont have to add this override for each descripor and inundate downstream NER with the same signature
|
562 |
+
|
563 |
+
if (self.log_descs):
|
564 |
+
self.ci_fp.write(index + " " + entity + " " + entity_count + " " + str(round(float(sorted_d[index]),4)) + "\n")
|
565 |
+
curr_sent_arr.append({"desc":index,"e":entity,"e_count":entity_count,"v":str(round(float(sorted_d[index]),4))})
|
566 |
+
#if (index != "two" and not index.startswith("#") and not all_sentences_arr[sent_index].strip().startswith("is ")):
|
567 |
+
if (index != "two" and not all_sentences_arr[sent_index].strip().startswith("is ")):
|
568 |
+
self.always_log_fp.write(' '.join(all_sentences_arr[sent_index].split()[:-1]) + " " + index + " :__entity__\n")
|
569 |
+
k += 1
|
570 |
+
if (k >= top_k):
|
571 |
+
break
|
572 |
+
#print()
|
573 |
+
#print(ret_obj)
|
574 |
+
#print(ret_obj)
|
575 |
+
#st.info("Enf. of prediciton")
|
576 |
+
#pdb.set_trace()
|
577 |
+
#final_obj = {"terms_arr":main_sent_arr,"span_arr":span_arr,"descs_and_entities":ret_obj,"all_sentences":all_sentences_arr}
|
578 |
+
final_obj = {"input":sent,"terms_arr":main_sent_arr,"span_arr":span_arr,"descs_and_entities":ret_obj}
|
579 |
+
if (self.log_descs):
|
580 |
+
self.ci_fp.flush()
|
581 |
+
self.cs_fp.flush()
|
582 |
+
self.always_log_fp.flush()
|
583 |
+
self.truncated_fp.flush()
|
584 |
+
return final_obj
|
585 |
+
|
586 |
+
|
587 |
+
test_arr = [
|
588 |
+
"ajit? is an engineer .",
|
589 |
+
"Sam:__entity__ Malone:__entity__ .",
|
590 |
+
"1. Jesper:__entity__ Ronnback:__entity__ ( Sweden:__entity__ ) 25.76 points",
|
591 |
+
"He felt New York has a chance:__entity__ to win this year's competition .",
|
592 |
+
"The new omicron variant could increase the likelihood that people will need a fourth coronavirus vaccine dose earlier than expected, executives at Prin dummy:__entity__ said Wednesday .",
|
593 |
+
"The new omicron variant could increase the likelihood that people will need a fourth coronavirus vaccine dose earlier than expected, executives at pharmaceutical:__entity__ giant:__entity__ Pfizer:__entity__ said Wednesday .",
|
594 |
+
"The conditions:__entity__ in the camp were very poor",
|
595 |
+
"Imatinib:__entity__ is used to treat nsclc",
|
596 |
+
"imatinib:__entity__ is used to treat nsclc",
|
597 |
+
"imatinib:__entity__ mesylate:__entity__ is used to treat nsclc",
|
598 |
+
"Staten is a :__entity__",
|
599 |
+
"John is a :__entity__",
|
600 |
+
"I met my best friend at eighteen :__entity__",
|
601 |
+
"I met my best friend at Parkinson's",
|
602 |
+
"e",
|
603 |
+
"Bandolier - Budgie ' , a free itunes app for ipad , iphone and ipod touch , released in December 2011 , tells the story of the making of Bandolier in the band 's own words - including an extensive audio interview with Burke Shelley",
|
604 |
+
"The portfolio manager of the new cryptocurrency firm underwent a bone marrow biopsy: for AML:__entity__:",
|
605 |
+
"Coronavirus:__entity__ disease 2019 (COVID-19) is a contagious disease caused by severe acute respiratory syndrome coronavirus 2 (SARS-CoV-2). The first known case was identified in Wuhan, China, in December 2019.[7] The disease has since spread worldwide, leading to an ongoing pandemic.[8]Symptoms of COVID-19 are variable, but often include fever,[9] cough, headache,[10] fatigue, breathing difficulties, and loss of smell and taste.[11][12][13] Symptoms may begin one to fourteen days after exposure to the virus. At least a third of people who are infected do not develop noticeable symptoms.[14] Of those people who develop symptoms noticeable enough to be classed as patients, most (81%) develop mild to moderate symptoms (up to mild pneumonia), while 14% develop severe symptoms (dyspnea, hypoxia, or more than 50% lung involvement on imaging), and 5% suffer critical symptoms (respiratory failure, shock, or multiorgan dysfunction).[15] Older people are at a higher risk of developing severe symptoms. Some people continue to experience a range of effects (long COVID) for months after recovery, and damage to organs has been observed.[16] Multi-year studies are underway to further investigate the long-term effects of the disease.[16]COVID-19 transmits when people breathe in air contaminated by droplets and small airborne particles containing the virus. The risk of breathing these in is highest when people are in close proximity, but they can be inhaled over longer distances, particularly indoors. Transmission can also occur if splashed or sprayed with contaminated fluids in the eyes, nose or mouth, and, rarely, via contaminated surfaces. People remain contagious for up to 20 days, and can spread the virus even if they do not develop symptoms.[17][18]Several testing methods have been developed to diagnose the disease. The standard diagnostic method is by detection of the virus' nucleic acid by real-time reverse transcription polymerase chain reaction (rRT-PCR), transcription-mediated amplification (TMA), or by reverse transcription loop-mediated isothermal amplification (RT-LAMP) from a nasopharyngeal swab.Several COVID-19 vaccines have been approved and distributed in various countries, which have initiated mass vaccination campaigns. Other preventive measures include physical or social distancing, quarantining, ventilation of indoor spaces, covering coughs and sneezes, hand washing, and keeping unwashed hands away from the face. The use of face masks or coverings has been recommended in public settings to minimize the risk of transmissions. While work is underway to develop drugs that inhibit the virus, the primary treatment is symptomatic. Management involves the treatment of symptoms, supportive care, isolation, and experimental measures.",
|
606 |
+
"imatinib was used to treat Michael Jackson . ",
|
607 |
+
"eg .",
|
608 |
+
"mesothelioma is caused by exposure to organic :__entity__",
|
609 |
+
"Mesothelioma is caused by exposure to asbestos:__entity__",
|
610 |
+
"Asbestos is a highly :__entity__",
|
611 |
+
"Fyodor:__entity__ Mikhailovich:__entity__ Dostoevsky:__entity__ was treated for Parkinsons:__entity__ and later died of lung carcinoma",
|
612 |
+
"Fyodor:__entity__ Mikhailovich:__entity__ Dostoevsky:__entity__",
|
613 |
+
"imatinib was used to treat Michael:__entity__ Jackson:__entity__",
|
614 |
+
"Ajit flew to Boston:__entity__",
|
615 |
+
"Ajit:__entity__ flew to Boston",
|
616 |
+
"A eGFR below 60:__entity__ indicates chronic kidney disease",
|
617 |
+
"imatinib was used to treat Michael Jackson",
|
618 |
+
"Ajit Valath:__entity__ Rajasekharan is an engineer at nFerence headquartered in Cambrigde MA",
|
619 |
+
"imatinib:__entity__",
|
620 |
+
"imatinib",
|
621 |
+
"iplimumab:__entity__",
|
622 |
+
"iplimumab",
|
623 |
+
"engineer:__entity__",
|
624 |
+
"engineer",
|
625 |
+
"Complications include peritonsillar:__entity__ abscess::__entity__",
|
626 |
+
"Imatinib was the first signal transduction inhibitor (STI,, used in a clinical setting. It prevents a BCR-ABL protein from exerting its role in the oncogenic pathway in chronic:__entity__ myeloid:__entity__ leukemia:__entity__ (CML,",
|
627 |
+
"Imatinib was the first signal transduction inhibitor (STI,, used in a clinical setting. It prevents a BCR-ABL protein from exerting its role in the oncogenic pathway in chronic myeloid leukemia (CML,",
|
628 |
+
"Imatinib was the first signal transduction inhibitor (STI,, used in a clinical setting. It prevents a BCR-ABL protein from exerting its role in the oncogenic pathway in chronic:__entity__ myeloid:___entity__ leukemia:__entity__ (CML,",
|
629 |
+
"Ajit Rajasekharan is an engineer:__entity__ at nFerence:__entity__",
|
630 |
+
"Imatinib was the first signal transduction inhibitor (STI,, used in a clinical setting. It prevents a BCR-ABL protein from exerting its role in the oncogenic pathway in chronic myeloid leukemia (CML,",
|
631 |
+
"Ajit:__entity__ Rajasekharan:__entity__ is an engineer",
|
632 |
+
"Imatinib:__entity__ was the first signal transduction inhibitor (STI,, used in a clinical setting. It prevents a BCR-ABL protein from exerting its role in the oncogenic pathway in chronic myeloid leukemia (CML,",
|
633 |
+
"Ajit Valath Rajasekharan is an engineer at nFerence headquartered in Cambrigde MA",
|
634 |
+
"Ajit:__entity__ Valath Rajasekharan is an engineer:__entity__ at nFerence headquartered in Cambrigde MA",
|
635 |
+
"Ajit:__entity__ Valath:__entity__ Rajasekharan is an engineer:__entity__ at nFerence headquartered in Cambrigde MA",
|
636 |
+
"Ajit:__entity__ Valath:__entity__ Rajasekharan:__entity__ is an engineer:__entity__ at nFerence headquartered in Cambrigde MA",
|
637 |
+
"Ajit Raj is an engineer:__entity__ at nFerence",
|
638 |
+
"Ajit Valath:__entity__ Rajasekharan is an engineer:__entity__ at nFerence headquartered in Cambrigde:__entity__ MA",
|
639 |
+
"Ajit Valath Rajasekharan is an engineer:__entity__ at nFerence headquartered in Cambrigde:__entity__ MA",
|
640 |
+
"Ajit Valath Rajasekharan is an engineer:__entity__ at nFerence headquartered in Cambrigde MA",
|
641 |
+
"Ajit Valath Rajasekharan is an engineer at nFerence headquartered in Cambrigde MA",
|
642 |
+
"Ajit:__entity__ Rajasekharan:__entity__ is an engineer at nFerence:__entity__",
|
643 |
+
"Imatinib mesylate is used to treat non small cell lung cancer",
|
644 |
+
"Imatinib mesylate is used to treat :__entity__",
|
645 |
+
"Imatinib is a term:__entity__",
|
646 |
+
"nsclc is a term:__entity__",
|
647 |
+
"Ajit Rajasekharan is a term:__entity__",
|
648 |
+
"ajit rajasekharan is a term:__entity__",
|
649 |
+
"John Doe is a term:__entity__"
|
650 |
+
]
|
651 |
+
|
652 |
+
|
653 |
+
def test_sentences(singleton,iter_val):
|
654 |
+
with open("debug.txt","w") as fp:
|
655 |
+
for test in iter_val:
|
656 |
+
test = test.rstrip('\n')
|
657 |
+
fp.write(test + "\n")
|
658 |
+
print(test)
|
659 |
+
out = singleton.get_descriptors(test)
|
660 |
+
print(out)
|
661 |
+
fp.write(json.dumps(out,indent=4))
|
662 |
+
fp.flush()
|
663 |
+
print()
|
664 |
+
pdb.set_trace()
|
665 |
+
|
666 |
+
|
667 |
+
if __name__ == '__main__':
|
668 |
+
parser = argparse.ArgumentParser(description='BERT descriptor service given a sentence. The word to be masked is specified as the special token entity ',formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
669 |
+
parser.add_argument('-config', action="store", dest="config", default=DEFAULT_CONFIG,help='config file path')
|
670 |
+
parser.add_argument('-model', action="store", dest="model", default=DEFAULT_MODEL_PATH,help='BERT pretrained models, or custom model path')
|
671 |
+
parser.add_argument('-input', action="store", dest="input", default="",help='Optional input file with sentences. If not specified, assumed to be canned sentence run (default behavior)')
|
672 |
+
parser.add_argument('-topk', action="store", dest="topk", default=DEFAULT_TOP_K,type=int,help='Number of neighbors to display')
|
673 |
+
parser.add_argument('-tolower', dest="tolower", action='store_true',help='Convert tokens to lowercase. Set to True only for uncased models')
|
674 |
+
parser.add_argument('-no-tolower', dest="tolower", action='store_false',help='Convert tokens to lowercase. Set to True only for uncased models')
|
675 |
+
parser.set_defaults(tolower=False)
|
676 |
+
parser.add_argument('-patched', dest="patched", action='store_true',help='Is pytorch code patched to harvest [CLS]')
|
677 |
+
parser.add_argument('-no-patched', dest="patched", action='store_false',help='Is pytorch code patched to harvest [CLS]')
|
678 |
+
parser.add_argument('-abbrev', dest="abbrev", action='store_true',help='Just output pivots - not all neighbors')
|
679 |
+
parser.add_argument('-no-abbrev', dest="abbrev", action='store_false',help='Just output pivots - not all neighbors')
|
680 |
+
parser.add_argument('-tokmod', dest="tokmod", action='store_true',help='Modify input token casings to match vocab - meaningful only for cased models')
|
681 |
+
parser.add_argument('-no-tokmod', dest="tokmod", action='store_false',help='Modify input token casings to match vocab - meaningful only for cased models')
|
682 |
+
parser.add_argument('-vocab', action="store", dest="vocab", default=DEFAULT_MODEL_PATH,help='Path to vocab file. This is required only if tokmod is true')
|
683 |
+
parser.add_argument('-labels', action="store", dest="labels", default=DEFAULT_LABELS_PATH,help='Path to labels file. This returns labels also')
|
684 |
+
parser.add_argument('-delimsep', dest="delimsep", action='store_true',help='Modify input tokens where delimiters are stuck to tokens. Turned off by default to be in sync with test sets')
|
685 |
+
parser.add_argument('-no-delimsep', dest="delimsep", action='store_true',help='Modify input tokens where delimiters are stuck to tokens. Turned off by default to be in sync with test sets')
|
686 |
+
parser.set_defaults(tolower=False)
|
687 |
+
parser.set_defaults(patched=False)
|
688 |
+
parser.set_defaults(abbrev=True)
|
689 |
+
parser.set_defaults(tokmod=True)
|
690 |
+
parser.set_defaults(delimsep=False)
|
691 |
+
|
692 |
+
results = parser.parse_args()
|
693 |
+
try:
|
694 |
+
singleton = BatchInference(results.config,results.model,results.tolower,results.patched,results.topk,results.abbrev,results.tokmod,results.vocab,results.labels,results.delimsep)
|
695 |
+
print("To lower casing is set to:",results.tolower)
|
696 |
+
if (len(results.input) == 0):
|
697 |
+
print("Canned test mode")
|
698 |
+
test_sentences(singleton,test_arr)
|
699 |
+
else:
|
700 |
+
print("Batch file test mode")
|
701 |
+
fp = open(results.input)
|
702 |
+
test_sentences(singleton,fp)
|
703 |
+
|
704 |
+
except:
|
705 |
+
print("Unexpected error:", sys.exc_info()[0])
|
706 |
+
traceback.print_exc(file=sys.stdout)
|
707 |
+
|
aggregate_server_json.py
ADDED
@@ -0,0 +1,541 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
import threading
|
3 |
+
import time
|
4 |
+
import math
|
5 |
+
import sys
|
6 |
+
import pdb
|
7 |
+
import requests
|
8 |
+
import urllib.parse
|
9 |
+
from common import *
|
10 |
+
import config_utils as cf
|
11 |
+
import json
|
12 |
+
from collections import OrderedDict
|
13 |
+
import argparse
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
|
17 |
+
MASK = ":__entity__"
|
18 |
+
RESULT_MASK = "NER_FINAL_RESULTS:"
|
19 |
+
DEFAULT_CONFIG = "./ensemble_config.json"
|
20 |
+
|
21 |
+
DEFAULT_TEST_BATCH_FILE="bootstrap_test_set.txt"
|
22 |
+
NER_OUTPUT_FILE="ner_output.txt"
|
23 |
+
DEFAULT_THRESHOLD = 1 #1 standard deviation from nean - for cross over prediction
|
24 |
+
|
25 |
+
actions_arr = []
|
26 |
+
|
27 |
+
class AggregateNER:
|
28 |
+
def __init__(self,config_file):
|
29 |
+
global actions_arr
|
30 |
+
base_path = cf.read_config(config_file)["BASE_PATH"] if ("BASE_PATH" in cf.read_config(config_file)) else "./"
|
31 |
+
self.error_fp = open(base_path + "failed_queries_log.txt","a")
|
32 |
+
self.rfp = open(base_path + "query_response_log.txt","a")
|
33 |
+
self.query_log_fp = open(base_path + "query_logs.txt","a")
|
34 |
+
self.inferred_entities_log_fp = open(base_path + "inferred_entities_log.txt","a")
|
35 |
+
self.threshold = DEFAULT_THRESHOLD #TBD read this from confg. cf.read_config()["CROSS_OVER_THRESHOLD_SIGMA"]
|
36 |
+
self.servers = cf.read_config(config_file)["NER_SERVERS"]
|
37 |
+
actions_arr = [
|
38 |
+
{"url":cf.read_config(config_file)["actions_arr"][0]["url"],"desc":cf.read_config(config_file)["actions_arr"][0]["desc"], "precedence":cf.read_config(config_file)["bio_precedence_arr"],"common":cf.read_config(config_file)["common_entities_arr"]},
|
39 |
+
{"url":cf.read_config(config_file)["actions_arr"][1]["url"],"desc":cf.read_config(config_file)["actions_arr"][1]["desc"],"precedence":cf.read_config(config_file)["phi_precedence_arr"],"common":cf.read_config(config_file)["common_entities_arr"]},
|
40 |
+
]
|
41 |
+
|
42 |
+
def add_term_punct(self,sent):
|
43 |
+
if (len(sent) > 1):
|
44 |
+
end_tokens = "!,.:;?"
|
45 |
+
last_char = sent[-1]
|
46 |
+
if (last_char not in end_tokens): #End all sentences with a period if not already present in sentence.
|
47 |
+
sent = sent + ' . '
|
48 |
+
print("End punctuated sent:",sent)
|
49 |
+
return sent
|
50 |
+
|
51 |
+
def fetch_all(self,inp,model_results_arr):
|
52 |
+
|
53 |
+
self.query_log_fp.write(inp+"\n")
|
54 |
+
self.query_log_fp.flush()
|
55 |
+
inp = self.add_term_punct(inp)
|
56 |
+
results = model_results_arr
|
57 |
+
#print(json.dumps(results,indent=4))
|
58 |
+
|
59 |
+
#this updates results with ensembled results
|
60 |
+
results = self.ensemble_processing(inp,results)
|
61 |
+
|
62 |
+
return_stat = "Failed" if len(results["ensembled_ner"]) == 0 else "Success"
|
63 |
+
results["stats"] = { "Ensemble server count" : str(len(model_results_arr)), "return_status": return_stat}
|
64 |
+
|
65 |
+
self.rfp.write( "\n" + json.dumps(results,indent=4))
|
66 |
+
self.rfp.flush()
|
67 |
+
return results
|
68 |
+
|
69 |
+
|
70 |
+
def get_conflict_resolved_entity(self,results,term_index,terms_count,servers_arr):
|
71 |
+
pos_index = str(term_index + 1)
|
72 |
+
s1_entity = extract_main_entity(results,0,pos_index)
|
73 |
+
s2_entity = extract_main_entity(results,1,pos_index)
|
74 |
+
span_count1 = get_span_info(results,0,term_index,terms_count)
|
75 |
+
span_count2 = get_span_info(results,1,term_index,terms_count)
|
76 |
+
if(span_count1 != span_count2):
|
77 |
+
print("Both input spans dont match. This is the effect of normalized casing that is model specific. Picking min span length")
|
78 |
+
span_count1 = span_count1 if span_count1 <= span_count2 else span_count2
|
79 |
+
if (s1_entity == s2_entity):
|
80 |
+
server_index = 0 if (s1_entity in servers_arr[0]["precedence"]) else 1
|
81 |
+
if (s1_entity != "O"):
|
82 |
+
print("Both servers agree on prediction for term:",results[0]["ner"][pos_index]["term"],":",s1_entity)
|
83 |
+
return server_index,span_count1,-1
|
84 |
+
else:
|
85 |
+
print("Servers do not agree on prediction for term:",results[0]["ner"][pos_index]["term"],":",s1_entity,s2_entity)
|
86 |
+
if (s2_entity == "O"):
|
87 |
+
print("Server 2 returned O. Picking server 1")
|
88 |
+
return 0,span_count1,-1
|
89 |
+
if (s1_entity == "O"):
|
90 |
+
print("Server 1 returned O. Picking server 2")
|
91 |
+
return 1,span_count2,-1
|
92 |
+
#Both the servers dont agree on their predictions. First server is BIO server. Second is PHI
|
93 |
+
#Examine both server predictions.
|
94 |
+
#Case 1: If just one of them makes a single prediction, then just pick that - it indicates one model is confident while the other isnt.
|
95 |
+
#Else.
|
96 |
+
# If the top prediction of one of them is a cross prediction, then again drop that prediction and pick the server being cross predicted.
|
97 |
+
# Else. Return both predictions, but with the higher confidence prediction first
|
98 |
+
#Case 2: Both dont cross predict. Then just return both predictions with higher confidence prediction listed first
|
99 |
+
#Cross prediction is checked only for predictions a server makes ABOVE prediction mean.
|
100 |
+
picked_server_index,cross_prediction_count = self.pick_single_server_if_possible(results,term_index,servers_arr)
|
101 |
+
return picked_server_index,span_count1,cross_prediction_count
|
102 |
+
|
103 |
+
def pick_single_server_if_possible(self,results,term_index,servers_arr):
|
104 |
+
'''
|
105 |
+
Return param : index of picked server
|
106 |
+
'''
|
107 |
+
pos_index = str(term_index + 1)
|
108 |
+
predictions_dict = {}
|
109 |
+
orig_cs_predictions_dict = {}
|
110 |
+
single_prediction_count = 0
|
111 |
+
single_prediction_server_index = -1
|
112 |
+
for server_index in range(len(results)):
|
113 |
+
if (pos_index in results[server_index]["entity_distribution"]):
|
114 |
+
predictions = self.get_predictions_above_threshold(results[server_index]["entity_distribution"][pos_index])
|
115 |
+
predictions_dict[server_index] = predictions #This is used below to only return top server prediction
|
116 |
+
|
117 |
+
orig_cs_predictions = self.get_predictions_above_threshold(results[server_index]["orig_cs_prediction_details"][pos_index])
|
118 |
+
orig_cs_predictions_dict[server_index] = orig_cs_predictions #this is used below for cross prediction determination since it is just a CS prediction
|
119 |
+
#single_prediction_count += 1 if (len(orig_cs_predictions) == 1) else 0
|
120 |
+
#if (len(orig_cs_predictions) == 1):
|
121 |
+
# single_prediction_server_index = server_index
|
122 |
+
if (single_prediction_count == 1):
|
123 |
+
is_included = is_included_in_server_entities(orig_cs_predictions_dict[single_prediction_server_index],servers_arr[single_prediction_server_index],False)
|
124 |
+
if(is_included == False) :
|
125 |
+
print("This is an odd case of single server prediction, that is a cross over")
|
126 |
+
ret_index = 0 if single_prediction_server_index == 1 else 1
|
127 |
+
return ret_index,-1
|
128 |
+
else:
|
129 |
+
print("Returning the index of single prediction server")
|
130 |
+
return single_prediction_server_index,-1
|
131 |
+
elif (single_prediction_count == 2):
|
132 |
+
print("Both have single predictions")
|
133 |
+
cross_predictions = {}
|
134 |
+
cross_prediction_count = 0
|
135 |
+
for server_index in range(len(results)):
|
136 |
+
if (pos_index in results[server_index]["entity_distribution"]):
|
137 |
+
is_included = is_included_in_server_entities(orig_cs_predictions_dict[server_index],servers_arr[server_index],False)
|
138 |
+
cross_predictions[server_index] = not is_included
|
139 |
+
cross_prediction_count += 1 if not is_included else 0
|
140 |
+
if (cross_prediction_count == 2):
|
141 |
+
#this is an odd case of both cross predicting with high confidence. Not sure if we will ever come here.
|
142 |
+
print("*********** BOTH servers are cross predicting! ******")
|
143 |
+
return self.pick_top_server_prediction(predictions_dict),2
|
144 |
+
elif (cross_prediction_count == 0):
|
145 |
+
#Neither are cross predecting
|
146 |
+
print("*********** BOTH servers have single predictions within their domain - returning both ******")
|
147 |
+
return self.pick_top_server_prediction(predictions_dict),2
|
148 |
+
else:
|
149 |
+
print("Returning just the server that is not cross predicting, dumping the cross prediction")
|
150 |
+
ret_index = 1 if cross_predictions[0] == True else 0 #Given a server cross predicts, return the other server index
|
151 |
+
return ret_index,-1
|
152 |
+
else:
|
153 |
+
print("*** Both servers have multiple predictions above mean")
|
154 |
+
#both have multiple predictions above mean
|
155 |
+
cross_predictions = {}
|
156 |
+
strict_cross_predictions = {}
|
157 |
+
cross_prediction_count = 0
|
158 |
+
strict_cross_prediction_count = 0
|
159 |
+
for server_index in range(len(results)):
|
160 |
+
if (pos_index in results[server_index]["entity_distribution"]):
|
161 |
+
is_included = is_included_in_server_entities(orig_cs_predictions_dict[server_index],servers_arr[server_index],False)
|
162 |
+
strict_is_included = strict_is_included_in_server_entities(orig_cs_predictions_dict[server_index],servers_arr[server_index],False)
|
163 |
+
cross_predictions[server_index] = not is_included
|
164 |
+
strict_cross_predictions[server_index] = not strict_is_included
|
165 |
+
cross_prediction_count += 1 if not is_included else 0
|
166 |
+
strict_cross_prediction_count += 1 if not strict_is_included else 0
|
167 |
+
if (cross_prediction_count == 2):
|
168 |
+
print("*********** BOTH servers are ALSO cross predicting and have multiple predictions above mean ******")
|
169 |
+
return self.pick_top_server_prediction(predictions_dict),2
|
170 |
+
elif (cross_prediction_count == 0):
|
171 |
+
print("*********** BOTH servers are ALSO predicting within their domain ******")
|
172 |
+
#if just one of them is predicting in the common set, then just pick the server that is predicting in its primary set.
|
173 |
+
#if (strict_cross_prediction_count == 1):
|
174 |
+
# ret_index = 1 if (0 not in strict_cross_predictions or strict_cross_predictions[0] == True) else 0 #Given a server cross predicts, return the other server index
|
175 |
+
# return ret_index,-1
|
176 |
+
#else:
|
177 |
+
# return self.pick_top_server_prediction(predictions_dict),2
|
178 |
+
return self.pick_top_server_prediction(predictions_dict),2
|
179 |
+
else:
|
180 |
+
print("Returning just the server that is not cross predicting, dumping the cross prediction. This is mainly to reduce the noise in prefix predictions that show up in CS context predictions")
|
181 |
+
ret_index = 1 if (0 not in cross_predictions or cross_predictions[0] == True) else 0 #Given a server cross predicts, return the other server index
|
182 |
+
return ret_index,-1
|
183 |
+
#print("*********** One of them is also cross predicting ******")
|
184 |
+
#return self.pick_top_server_prediction(predictions_dict),2
|
185 |
+
|
186 |
+
|
187 |
+
|
188 |
+
def pick_top_server_prediction(self,predictions_dict):
|
189 |
+
'''
|
190 |
+
'''
|
191 |
+
if (len(predictions_dict) != 2):
|
192 |
+
return 0
|
193 |
+
assert(len(predictions_dict) == 2)
|
194 |
+
return 0 if (predictions_dict[0][0]["conf"] >= predictions_dict[1][0]["conf"]) else 1
|
195 |
+
|
196 |
+
|
197 |
+
def get_predictions_above_threshold(self,predictions):
|
198 |
+
dist = predictions["cs_distribution"]
|
199 |
+
sum_predictions = 0
|
200 |
+
ret_arr = []
|
201 |
+
if(len(dist) != 0):
|
202 |
+
mean_score = 1.0/len(dist) #input is a prob distriubution. so sum is 1
|
203 |
+
else:
|
204 |
+
mean_score = 0
|
205 |
+
#sum_deviation = 0
|
206 |
+
#for node in dist:
|
207 |
+
# sum_deviation += (mean_score - node["confidence"])*(mean_score - node["confidence"])
|
208 |
+
#variance = sum_deviation/len(dist)
|
209 |
+
#std_dev = math.sqrt(variance)
|
210 |
+
#threshold = mean_score + std_dev*self.threshold #default is 1 standard deviation from mean
|
211 |
+
threshold = mean_score
|
212 |
+
pick_count = 1
|
213 |
+
for node in dist:
|
214 |
+
if (node["confidence"] >= threshold):
|
215 |
+
ret_arr.append({"e":node["e"],"conf":node["confidence"]})
|
216 |
+
pick_count += 1
|
217 |
+
else:
|
218 |
+
break #this is a reverse sorted list. So no need to check anymore
|
219 |
+
if (len(dist) > 0):
|
220 |
+
assert(len(ret_arr) > 0)
|
221 |
+
return ret_arr
|
222 |
+
|
223 |
+
def check_if_entity_in_arr(self,entity,arr):
|
224 |
+
for node in arr:
|
225 |
+
if (entity == node["e"]):
|
226 |
+
return True
|
227 |
+
return False
|
228 |
+
|
229 |
+
def gen_resolved_entity(self,results,server_index,pivot_index,run_index,cross_prediction_count,servers_arr):
|
230 |
+
if (cross_prediction_count == 1 or cross_prediction_count == -1):
|
231 |
+
#This is the case where we are emitting just one server prediction. In this case, if CS and consolidated dont match, emit both
|
232 |
+
if (pivot_index in results[server_index]["orig_cs_prediction_details"]):
|
233 |
+
if (len(results[server_index]["orig_cs_prediction_details"][pivot_index]['cs_distribution']) == 0):
|
234 |
+
#just use the ci prediction in this case. This happens only for boundary cases of a single entity in a sentence and there is no context
|
235 |
+
orig_cs_entity = results[server_index]["orig_ci_prediction_details"][pivot_index]['cs_distribution'][0]
|
236 |
+
else:
|
237 |
+
orig_cs_entity = results[server_index]["orig_cs_prediction_details"][pivot_index]['cs_distribution'][0]
|
238 |
+
orig_ci_entity = results[server_index]["orig_ci_prediction_details"][pivot_index]['cs_distribution'][0]
|
239 |
+
m1 = orig_cs_entity["e"].split('[')[0]
|
240 |
+
m1_ci = orig_ci_entity["e"].split('[')[0]
|
241 |
+
is_ci_included = True if (m1_ci in servers_arr[server_index]["precedence"]) else False
|
242 |
+
consolidated_entity = results[server_index]["ner"][pivot_index]
|
243 |
+
m2,dummy = prefix_strip(consolidated_entity["e"].split('[')[0])
|
244 |
+
if (m1 != m2):
|
245 |
+
#if we come here consolidated is not same as cs prediction. So we emit both consolidated and cs
|
246 |
+
ret_obj = results[server_index]["ner"][run_index].copy()
|
247 |
+
dummy,prefix = prefix_strip(ret_obj["e"])
|
248 |
+
n1 = flip_category(orig_cs_entity)
|
249 |
+
n1["e"] = prefix + n1["e"]
|
250 |
+
n2 = flip_category(consolidated_entity)
|
251 |
+
ret_obj["e"] = n2["e"] + "/" + n1["e"]
|
252 |
+
return ret_obj
|
253 |
+
else:
|
254 |
+
#if we come here consolidated is same as cs prediction. So we try to either use ci or the second cs prediction if ci is out of domain
|
255 |
+
if (m1 != m1_ci):
|
256 |
+
#CS and CI are not same
|
257 |
+
if (is_ci_included):
|
258 |
+
#Emity both CS and CI
|
259 |
+
ret_obj = results[server_index]["ner"][run_index].copy()
|
260 |
+
dummy,prefix = prefix_strip(ret_obj["e"])
|
261 |
+
n1 = flip_category(orig_cs_entity)
|
262 |
+
n1["e"] = prefix + n1["e"]
|
263 |
+
n2 = flip_category(orig_ci_entity)
|
264 |
+
n2["e"] = prefix + n2["e"]
|
265 |
+
ret_obj["e"] = n1["e"] + "/" + n2["e"]
|
266 |
+
return ret_obj
|
267 |
+
else:
|
268 |
+
#We come here for the case where CI is not in server list. So we pick the second cs as an option if meaningful
|
269 |
+
if (len(results[server_index]["orig_cs_prediction_details"][pivot_index]['cs_distribution']) >= 2):
|
270 |
+
ret_arr = self.get_predictions_above_threshold(results[server_index]["orig_cs_prediction_details"][pivot_index])
|
271 |
+
orig_cs_second_entity = results[server_index]["orig_cs_prediction_details"][pivot_index]['cs_distribution'][1]
|
272 |
+
m2_cs = orig_cs_second_entity["e"].split('[')[0]
|
273 |
+
is_cs_included = True if (m2_cs in servers_arr[server_index]["precedence"]) else False
|
274 |
+
is_cs_included = True #Disabling cs included check. If prediction above threshold is cross prediction, then letting it through
|
275 |
+
assert (m2_cs != m1)
|
276 |
+
if (is_cs_included and self.check_if_entity_in_arr(m2_cs,ret_arr)):
|
277 |
+
ret_obj = results[server_index]["ner"][run_index].copy()
|
278 |
+
dummy,prefix = prefix_strip(ret_obj["e"])
|
279 |
+
n1 = flip_category(orig_cs_second_entity)
|
280 |
+
n1["e"] = prefix + n1["e"]
|
281 |
+
n2 = flip_category(orig_cs_entity)
|
282 |
+
n2["e"] = prefix + n2["e"]
|
283 |
+
ret_obj["e"] = n2["e"] + "/" + n1["e"]
|
284 |
+
return ret_obj
|
285 |
+
else:
|
286 |
+
return flip_category(results[server_index]["ner"][run_index])
|
287 |
+
else:
|
288 |
+
return flip_category(results[server_index]["ner"][run_index])
|
289 |
+
else:
|
290 |
+
#here cs and ci are same. So use two cs predictions if meaningful
|
291 |
+
if (len(results[server_index]["orig_cs_prediction_details"][pivot_index]['cs_distribution']) >= 2):
|
292 |
+
ret_arr = self.get_predictions_above_threshold(results[server_index]["orig_cs_prediction_details"][pivot_index])
|
293 |
+
orig_cs_second_entity = results[server_index]["orig_cs_prediction_details"][pivot_index]['cs_distribution'][1]
|
294 |
+
m2_cs = orig_cs_second_entity["e"].split('[')[0]
|
295 |
+
is_cs_included = True if (m2_cs in servers_arr[server_index]["precedence"]) else False
|
296 |
+
is_cs_included = True #Disabling cs included check. If prediction above threshold is cross prediction, then letting it through
|
297 |
+
assert (m2_cs != m1)
|
298 |
+
if (is_cs_included and self.check_if_entity_in_arr(m2_cs,ret_arr)):
|
299 |
+
ret_obj = results[server_index]["ner"][run_index].copy()
|
300 |
+
dummy,prefix = prefix_strip(ret_obj["e"])
|
301 |
+
n1 = flip_category(orig_cs_second_entity)
|
302 |
+
n1["e"] = prefix + n1["e"]
|
303 |
+
n2 = flip_category(orig_cs_entity)
|
304 |
+
n2["e"] = prefix + n2["e"]
|
305 |
+
ret_obj["e"] = n2["e"] + "/" + n1["e"]
|
306 |
+
return ret_obj
|
307 |
+
else:
|
308 |
+
return flip_category(results[server_index]["ner"][run_index])
|
309 |
+
else:
|
310 |
+
return flip_category(results[server_index]["ner"][run_index])
|
311 |
+
else:
|
312 |
+
return flip_category(results[server_index]["ner"][run_index])
|
313 |
+
else:
|
314 |
+
#Case where both servers dont match
|
315 |
+
ret_obj = results[server_index]["ner"][run_index].copy()
|
316 |
+
#ret_obj["e"] = results[0]["ner"][run_index]["e"] + "/" + results[1]["ner"][run_index]["e"]
|
317 |
+
index2 = 1 if server_index == 0 else 0 #this is the index of the dominant server with hihgher prediction confidence
|
318 |
+
n1 = flip_category(results[server_index]["ner"][run_index])
|
319 |
+
n2 = flip_category(results[index2]["ner"][run_index])
|
320 |
+
ret_obj["e"] = n1["e"] + "/" + n2["e"]
|
321 |
+
return ret_obj
|
322 |
+
|
323 |
+
|
324 |
+
def confirm_same_size_responses(self,sent,results):
|
325 |
+
count = 0
|
326 |
+
for i in range(len(results)):
|
327 |
+
if ("ner" in results[i]):
|
328 |
+
ner = results[i]["ner"]
|
329 |
+
else:
|
330 |
+
print("Server",i," returned invalid response;",results[i])
|
331 |
+
self.error_fp.write("Server " + str(i) + " failed for query: " + sent + "\n")
|
332 |
+
self.error_fp.flush()
|
333 |
+
return 0
|
334 |
+
if(count == 0):
|
335 |
+
assert(len(ner) > 0)
|
336 |
+
count = len(ner)
|
337 |
+
else:
|
338 |
+
if (count != len(ner)):
|
339 |
+
print("Warning. The return sizes of both servers do not match. This must be truncated sentence, where tokenization causes different length truncations. Using min length")
|
340 |
+
count = count if count < len(ner) else len(ner)
|
341 |
+
return count
|
342 |
+
|
343 |
+
|
344 |
+
def get_ensembled_entities(self,sent,results,servers_arr):
|
345 |
+
ensembled_ner = OrderedDict()
|
346 |
+
orig_cs_predictions = OrderedDict()
|
347 |
+
orig_ci_predictions = OrderedDict()
|
348 |
+
ensembled_conf = OrderedDict()
|
349 |
+
ambig_ensembled_conf = OrderedDict()
|
350 |
+
ensembled_ci = OrderedDict()
|
351 |
+
ensembled_cs = OrderedDict()
|
352 |
+
ambig_ensembled_ci = OrderedDict()
|
353 |
+
ambig_ensembled_cs = OrderedDict()
|
354 |
+
print("Ensemble candidates")
|
355 |
+
terms_count = self.confirm_same_size_responses(sent,results)
|
356 |
+
if (terms_count == 0):
|
357 |
+
return ensembled_ner,ensembled_conf,ensembled_ci,ensembled_cs,ambig_ensembled_conf,ambig_ensembled_ci,ambig_ensembled_cs,orig_cs_predictions,orig_ci_predictions
|
358 |
+
assert(len(servers_arr) == len(results))
|
359 |
+
term_index = 0
|
360 |
+
while (term_index < terms_count):
|
361 |
+
pos_index = str(term_index + 1)
|
362 |
+
assert(len(servers_arr) == 2) #TBD. Currently assumes two servers in prototype to see if this approach works. To be extended to multiple servers
|
363 |
+
server_index,span_count,cross_prediction_count = self.get_conflict_resolved_entity(results,term_index,terms_count,servers_arr)
|
364 |
+
pivot_index = str(term_index + 1)
|
365 |
+
for span_index in range(span_count):
|
366 |
+
run_index = str(term_index + 1 + span_index)
|
367 |
+
ensembled_ner[run_index] = self.gen_resolved_entity(results,server_index,pivot_index,run_index,cross_prediction_count,servers_arr)
|
368 |
+
if (run_index in results[server_index]["entity_distribution"]):
|
369 |
+
ensembled_conf[run_index] = results[server_index]["entity_distribution"][run_index]
|
370 |
+
ensembled_conf[run_index]["e"] = strip_prefixes(ensembled_ner[run_index]["e"]) #this is to make sure the same tag can be taken from NER result or this structure.
|
371 |
+
#When both server responses are required, just return the details of first server for now
|
372 |
+
ensembled_ci[run_index] = results[server_index]["ci_prediction_details"][run_index]
|
373 |
+
ensembled_cs[run_index] = results[server_index]["cs_prediction_details"][run_index]
|
374 |
+
orig_cs_predictions[run_index] = results[server_index]["orig_cs_prediction_details"][run_index]
|
375 |
+
orig_ci_predictions[run_index] = results[server_index]["orig_ci_prediction_details"][run_index]
|
376 |
+
|
377 |
+
if (cross_prediction_count == 0 or cross_prediction_count == 2): #This is an ambiguous prediction. Send both server responses
|
378 |
+
second_server = 1 if server_index == 0 else 1
|
379 |
+
if (run_index in results[second_server]["entity_distribution"]): #It may not be present if the B/I tags are out of sync from servers.
|
380 |
+
ambig_ensembled_conf[run_index] = results[second_server]["entity_distribution"][run_index]
|
381 |
+
ambig_ensembled_conf[run_index]["e"] = ensembled_ner[run_index]["e"] #this is to make sure the same tag can be taken from NER result or this structure.
|
382 |
+
ambig_ensembled_ci[run_index] = results[second_server]["ci_prediction_details"][run_index]
|
383 |
+
if (ensembled_ner[run_index]["e"] != "O"):
|
384 |
+
self.inferred_entities_log_fp.write(results[0]["ner"][run_index]["term"] + " " + ensembled_ner[run_index]["e"] + "\n")
|
385 |
+
term_index += span_count
|
386 |
+
self.inferred_entities_log_fp.flush()
|
387 |
+
return ensembled_ner,ensembled_conf,ensembled_ci,ensembled_cs,ambig_ensembled_conf,ambig_ensembled_ci,ambig_ensembled_cs,orig_cs_predictions,orig_ci_predictions
|
388 |
+
|
389 |
+
|
390 |
+
|
391 |
+
def ensemble_processing(self,sent,results):
|
392 |
+
global actions_arr
|
393 |
+
ensembled_ner,ensembled_conf,ci_details,cs_details,ambig_ensembled_conf,ambig_ci_details,ambig_cs_details,orig_cs_predictions,orig_ci_predictions = self.get_ensembled_entities(sent,results,actions_arr)
|
394 |
+
final_ner = OrderedDict()
|
395 |
+
final_ner["ensembled_ner"] = ensembled_ner
|
396 |
+
final_ner["ensembled_prediction_details"] = ensembled_conf
|
397 |
+
final_ner["ci_prediction_details"] = ci_details
|
398 |
+
final_ner["cs_prediction_details"] = cs_details
|
399 |
+
final_ner["ambig_prediction_details_conf"] = ambig_ensembled_conf
|
400 |
+
final_ner["ambig_prediction_details_ci"] = ambig_ci_details
|
401 |
+
final_ner["ambig_prediction_details_cs"] = ambig_cs_details
|
402 |
+
final_ner["orig_cs_prediction_details"] = orig_cs_predictions
|
403 |
+
final_ner["orig_ci_prediction_details"] = orig_ci_predictions
|
404 |
+
#final_ner["individual"] = results
|
405 |
+
return final_ner
|
406 |
+
|
407 |
+
|
408 |
+
|
409 |
+
|
410 |
+
class myThread (threading.Thread):
|
411 |
+
def __init__(self, url,param,desc):
|
412 |
+
threading.Thread.__init__(self)
|
413 |
+
self.url = url
|
414 |
+
self.param = param
|
415 |
+
self.desc = desc
|
416 |
+
self.results = {}
|
417 |
+
def run(self):
|
418 |
+
print ("Starting " + self.url + self.param)
|
419 |
+
escaped_url = self.url + self.param.replace("#","-") #TBD. This is a nasty hack for client side handling of #. To be fixed. For some reason, even replacing with parse.quote or just with %23 does not help. The fragment after # is not sent to server. Works just fine in wget with %23
|
420 |
+
print("ESCAPED:",escaped_url)
|
421 |
+
out = requests.get(escaped_url)
|
422 |
+
try:
|
423 |
+
self.results = json.loads(out.text,object_pairs_hook=OrderedDict)
|
424 |
+
except:
|
425 |
+
print("Empty response from server for input:",self.param)
|
426 |
+
self.results = json.loads("{}",object_pairs_hook=OrderedDict)
|
427 |
+
self.results["server"] = self.desc
|
428 |
+
print ("Exiting " + self.url + self.param)
|
429 |
+
|
430 |
+
|
431 |
+
|
432 |
+
# Create new threads
|
433 |
+
def create_workers(inp_dict,inp):
|
434 |
+
threads_arr = []
|
435 |
+
for i in range(len(inp_dict)):
|
436 |
+
threads_arr.append(myThread(inp_dict[i]["url"],inp,inp_dict[i]["desc"]))
|
437 |
+
return threads_arr
|
438 |
+
|
439 |
+
def start_workers(threads_arr):
|
440 |
+
for thread in threads_arr:
|
441 |
+
thread.start()
|
442 |
+
|
443 |
+
def wait_for_completion(threads_arr):
|
444 |
+
for thread in threads_arr:
|
445 |
+
thread.join()
|
446 |
+
|
447 |
+
def get_results(threads_arr):
|
448 |
+
results = []
|
449 |
+
for thread in threads_arr:
|
450 |
+
results.append(thread.results)
|
451 |
+
return results
|
452 |
+
|
453 |
+
|
454 |
+
|
455 |
+
def prefix_strip(term):
|
456 |
+
prefix = ""
|
457 |
+
if (term.startswith("B_") or term.startswith("I_")):
|
458 |
+
prefix = term[:2]
|
459 |
+
term = term[2:]
|
460 |
+
return term,prefix
|
461 |
+
|
462 |
+
def strip_prefixes(term):
|
463 |
+
split_entities = term.split('/')
|
464 |
+
if (len(split_entities) == 2):
|
465 |
+
term1,dummy = prefix_strip(split_entities[0])
|
466 |
+
term2,dummy = prefix_strip(split_entities[1])
|
467 |
+
return term1 + '/' + term2
|
468 |
+
else:
|
469 |
+
assert(len(split_entities) == 1)
|
470 |
+
term1,dummy = prefix_strip(split_entities[0])
|
471 |
+
return term1
|
472 |
+
|
473 |
+
|
474 |
+
#This hack is simply done for downstream API used for UI displays the entity instead of the class. Details has all additional info
|
475 |
+
def flip_category(obj):
|
476 |
+
new_obj = obj.copy()
|
477 |
+
entity_type_arr = obj["e"].split("[")
|
478 |
+
if (len(entity_type_arr) > 1):
|
479 |
+
term = entity_type_arr[0]
|
480 |
+
if (term.startswith("B_") or term.startswith("I_")):
|
481 |
+
prefix = term[:2]
|
482 |
+
new_obj["e"] = prefix + entity_type_arr[1].rstrip("]") + "[" + entity_type_arr[0][2:] + "]"
|
483 |
+
else:
|
484 |
+
new_obj["e"] = entity_type_arr[1].rstrip("]") + "[" + entity_type_arr[0] + "]"
|
485 |
+
return new_obj
|
486 |
+
|
487 |
+
|
488 |
+
def extract_main_entity(results,server_index,pos_index):
|
489 |
+
main_entity = results[server_index]["ner"][pos_index]["e"].split('[')[0]
|
490 |
+
main_entity,dummy = prefix_strip(main_entity)
|
491 |
+
return main_entity
|
492 |
+
|
493 |
+
|
494 |
+
def get_span_info(results,server_index,term_index,terms_count):
|
495 |
+
pos_index = str(term_index + 1)
|
496 |
+
entity = results[server_index]["ner"][pos_index]["e"]
|
497 |
+
span_count = 1
|
498 |
+
if (entity.startswith("I_")):
|
499 |
+
print("Skipping an I tag for server:",server_index,". This has to be done because of mismatched span because of model specific casing normalization that changes POS tagging. This happens only for sentencees user does not explicirly tag with ':__entity__'")
|
500 |
+
return span_count
|
501 |
+
assert(not entity.startswith("I_"))
|
502 |
+
if (entity.startswith("B_")):
|
503 |
+
term_index += 1
|
504 |
+
while(term_index < terms_count):
|
505 |
+
pos_index = str(term_index + 1)
|
506 |
+
entity = results[server_index]["ner"][pos_index]["e"]
|
507 |
+
if (entity == "O"):
|
508 |
+
break
|
509 |
+
span_count += 1
|
510 |
+
term_index += 1
|
511 |
+
return span_count
|
512 |
+
|
513 |
+
def is_included_in_server_entities(predictions,s_arr,check_first_only):
|
514 |
+
for entity in predictions:
|
515 |
+
entity = entity['e'].split('[')[0]
|
516 |
+
if ((entity not in s_arr["precedence"]) and (entity not in s_arr["common"])): #do not treat the presence of an entity in common as a cross over
|
517 |
+
return False
|
518 |
+
if (check_first_only):
|
519 |
+
return True #Just check the top prediction for inclusion in the new semantics
|
520 |
+
return True
|
521 |
+
|
522 |
+
def strict_is_included_in_server_entities(predictions,s_arr,check_first_only):
|
523 |
+
for entity in predictions:
|
524 |
+
entity = entity['e'].split('[')[0]
|
525 |
+
if ((entity not in s_arr["precedence"])): #do not treat the presence of an entity in common as a cross over
|
526 |
+
return False
|
527 |
+
if (check_first_only):
|
528 |
+
return True #Just check the top prediction for inclusion in the new semantics
|
529 |
+
return True
|
530 |
+
|
531 |
+
|
532 |
+
|
533 |
+
if __name__ == '__main__':
|
534 |
+
parser = argparse.ArgumentParser(description='main NER for a single model ',formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
535 |
+
parser.add_argument('-input', action="store", dest="input",default=DEFAULT_TEST_BATCH_FILE,help='Input file for batch run option')
|
536 |
+
parser.add_argument('-config', action="store", dest="config", default=DEFAULT_CONFIG,help='config file path')
|
537 |
+
parser.add_argument('-output', action="store", dest="output",default=NER_OUTPUT_FILE,help='Output file for batch run option')
|
538 |
+
parser.add_argument('-option', action="store", dest="option",default="canned",help='Valid options are canned,batch,interactive. canned - test few canned sentences used in medium artice. batch - tag sentences in input file. Entities to be tagged are determing used POS tagging to find noun phrases.interactive - input one sentence at a time')
|
539 |
+
results = parser.parse_args()
|
540 |
+
config_file = results.config
|
541 |
+
|
app.py
ADDED
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import streamlit as st
|
3 |
+
import torch
|
4 |
+
import string
|
5 |
+
from annotated_text import annotated_text
|
6 |
+
|
7 |
+
from flair.data import Sentence
|
8 |
+
from flair.models import SequenceTagger
|
9 |
+
from transformers import BertTokenizer, BertForMaskedLM
|
10 |
+
import BatchInference as bd
|
11 |
+
import batched_main_NER as ner
|
12 |
+
import aggregate_server_json as aggr
|
13 |
+
import json
|
14 |
+
|
15 |
+
|
16 |
+
DEFAULT_TOP_K = 20
|
17 |
+
SPECIFIC_TAG=":__entity__"
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
|
22 |
+
def POS_get_model(model_name):
|
23 |
+
val = SequenceTagger.load(model_name) # Load the model
|
24 |
+
return val
|
25 |
+
|
26 |
+
def getPos(s: Sentence):
|
27 |
+
texts = []
|
28 |
+
labels = []
|
29 |
+
for t in s.tokens:
|
30 |
+
for label in t.annotation_layers.keys():
|
31 |
+
texts.append(t.text)
|
32 |
+
labels.append(t.get_labels(label)[0].value)
|
33 |
+
return texts, labels
|
34 |
+
|
35 |
+
def getDictFromPOS(texts, labels):
|
36 |
+
return [["dummy",t,l,"dummy","dummy" ] for t, l in zip(texts, labels)]
|
37 |
+
|
38 |
+
def decode(tokenizer, pred_idx, top_clean):
|
39 |
+
ignore_tokens = string.punctuation + '[PAD]'
|
40 |
+
tokens = []
|
41 |
+
for w in pred_idx:
|
42 |
+
token = ''.join(tokenizer.decode(w).split())
|
43 |
+
if token not in ignore_tokens:
|
44 |
+
tokens.append(token.replace('##', ''))
|
45 |
+
return '\n'.join(tokens[:top_clean])
|
46 |
+
|
47 |
+
def encode(tokenizer, text_sentence, add_special_tokens=True):
|
48 |
+
text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
|
49 |
+
# if <mask> is the last token, append a "." so that models dont predict punctuation.
|
50 |
+
if tokenizer.mask_token == text_sentence.split()[-1]:
|
51 |
+
text_sentence += ' .'
|
52 |
+
|
53 |
+
input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
|
54 |
+
mask_idx = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()[0]
|
55 |
+
return input_ids, mask_idx
|
56 |
+
|
57 |
+
def get_all_predictions(text_sentence, top_clean=5):
|
58 |
+
# ========================= BERT =================================
|
59 |
+
input_ids, mask_idx = encode(bert_tokenizer, text_sentence)
|
60 |
+
with torch.no_grad():
|
61 |
+
predict = bert_model(input_ids)[0]
|
62 |
+
bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(top_k).indices.tolist(), top_clean)
|
63 |
+
return {'bert': bert}
|
64 |
+
|
65 |
+
def get_bert_prediction(input_text,top_k):
|
66 |
+
try:
|
67 |
+
input_text += ' <mask>'
|
68 |
+
res = get_all_predictions(input_text, top_clean=int(top_k))
|
69 |
+
return res
|
70 |
+
except Exception as error:
|
71 |
+
pass
|
72 |
+
|
73 |
+
|
74 |
+
def load_pos_model():
|
75 |
+
checkpoint = "flair/pos-english"
|
76 |
+
return POS_get_model(checkpoint)
|
77 |
+
|
78 |
+
|
79 |
+
|
80 |
+
|
81 |
+
def init_session_states():
|
82 |
+
if 'top_k' not in st.session_state:
|
83 |
+
st.session_state['top_k'] = 20
|
84 |
+
if 'pos_model' not in st.session_state:
|
85 |
+
st.session_state['pos_model'] = None
|
86 |
+
if 'bio_model' not in st.session_state:
|
87 |
+
st.session_state['bio_model'] = None
|
88 |
+
if 'phi_model' not in st.session_state:
|
89 |
+
st.session_state['phi_model'] = None
|
90 |
+
if 'ner_bio' not in st.session_state:
|
91 |
+
st.session_state['ner_bio'] = None
|
92 |
+
if 'ner_phi' not in st.session_state:
|
93 |
+
st.session_state['ner_phi'] = None
|
94 |
+
if 'aggr' not in st.session_state:
|
95 |
+
st.session_state['aggr'] = None
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
def get_pos_arr(input_text,display_area):
|
100 |
+
if (st.session_state['pos_model'] is None):
|
101 |
+
display_area.text("Loading model 3 of 3.Loading POS model...")
|
102 |
+
st.session_state['pos_model'] = load_pos_model()
|
103 |
+
s = Sentence(input_text)
|
104 |
+
st.session_state['pos_model'].predict(s)
|
105 |
+
texts, labels = getPos(s)
|
106 |
+
pos_results = getDictFromPOS(texts, labels)
|
107 |
+
return pos_results
|
108 |
+
|
109 |
+
def perform_inference(text,display_area):
|
110 |
+
|
111 |
+
if (st.session_state['bio_model'] is None):
|
112 |
+
display_area.text("Loading model 1 of 3. Bio model...")
|
113 |
+
st.session_state['bio_model'] = bd.BatchInference("bio/desc_a100_config.json",'ajitrajasekharan/biomedical',False,False,DEFAULT_TOP_K,True,True, "bio/","bio/a100_labels.txt",False)
|
114 |
+
|
115 |
+
if (st.session_state['phi_model'] is None):
|
116 |
+
display_area.text("Loading model 2 of 3. PHI model...")
|
117 |
+
st.session_state['phi_model'] = bd.BatchInference("bbc/desc_bbc_config.json",'bert-base-cased',False,False,DEFAULT_TOP_K,True,True, "bbc/","bbc/bbc_labels.txt",False)
|
118 |
+
|
119 |
+
#Load POS model if needed and gets POS tags
|
120 |
+
if (SPECIFIC_TAG not in text):
|
121 |
+
pos_arr = get_pos_arr(text,display_area)
|
122 |
+
else:
|
123 |
+
pos_arr = None
|
124 |
+
|
125 |
+
if (st.session_state['ner_bio'] is None):
|
126 |
+
display_area.text("Initializing BIO module...")
|
127 |
+
st.session_state['ner_bio'] = ner.UnsupNER("bio/ner_a100_config.json")
|
128 |
+
|
129 |
+
if (st.session_state['ner_phi'] is None):
|
130 |
+
display_area.text("Initializing PHI module...")
|
131 |
+
st.session_state['ner_phi'] = ner.UnsupNER("bbc/ner_bbc_config.json")
|
132 |
+
|
133 |
+
if (st.session_state['aggr'] is None):
|
134 |
+
display_area.text("Initializing Aggregation modeule...")
|
135 |
+
st.session_state['aggr'] = aggr.AggregateNER("./ensemble_config.json")
|
136 |
+
|
137 |
+
|
138 |
+
|
139 |
+
display_area.text("Getting results from BIO model...")
|
140 |
+
bio_descs = st.session_state['bio_model'].get_descriptors(text,pos_arr)
|
141 |
+
display_area.text("Getting results from PHI model...")
|
142 |
+
phi_results = st.session_state['phi_model'].get_descriptors(text,pos_arr)
|
143 |
+
display_area.text("Aggregating BIO & PHI results...")
|
144 |
+
bio_ner = st.session_state['ner_bio'].tag_sentence_service(text,bio_descs)
|
145 |
+
phi_ner = st.session_state['ner_phi'].tag_sentence_service(text,phi_results)
|
146 |
+
|
147 |
+
combined_arr = [json.loads(bio_ner),json.loads(phi_ner)]
|
148 |
+
|
149 |
+
aggregate_results = st.session_state['aggr'].fetch_all(text,combined_arr)
|
150 |
+
return aggregate_results
|
151 |
+
|
152 |
+
|
153 |
+
sent_arr = [
|
154 |
+
"Lou Gehrig who works for XCorp and lives in New York suffers from Parkinson's ",
|
155 |
+
"Parkinson who works for XCorp and lives in New York suffers from Lou Gehrig's",
|
156 |
+
"lou gehrig was diagnosed with Parkinson's ",
|
157 |
+
"A eGFR below 60 indicates chronic kidney disease",
|
158 |
+
"Overexpression of EGFR occurs across a wide range of different cancers",
|
159 |
+
"Stanford called",
|
160 |
+
"He was diagnosed with non small cell lung cancer",
|
161 |
+
"I met my girl friends at the pub ",
|
162 |
+
"I met my New York friends at the pub",
|
163 |
+
"I met my XCorp friends at the pub",
|
164 |
+
"I met my two friends at the pub",
|
165 |
+
"Bio-Techne's genomic tools include advanced tissue-based in-situ hybridization assays sold under the ACD brand as well as a portfolio of assays for prostate cancer diagnosis ",
|
166 |
+
"There are no treatment options specifically indicated for ACD and physicians must utilize agents approved for other dermatology conditions", "As ACD has been implicated in apoptosis-resistant glioblastoma (GBM), there is a high medical need for identifying novel ACD-inducing drugs ",
|
167 |
+
"Located in the heart of Dublin , in the family home of acclaimed writer Oscar Wilde , ACD provides the perfect backdrop to inspire Irish (and Irish-at-heart) students to excel in business and the arts",
|
168 |
+
"Patients treated with anticancer chemotherapy drugs ( ACD ) are vulnerable to infectious diseases due to immunosuppression and to the direct impact of ACD on their intestinal microbiota ",
|
169 |
+
"In the LASOR trial , increasing daily imatinib dose from 400 to 600mg induced MMR at 12 and 24 months in 25% and 36% of the patients, respectively, who had suboptimal cytogenetic responses ",
|
170 |
+
"The sky turned dark in advance of the storm that was coming from the east ",
|
171 |
+
"She loves to watch Sunday afternoon football with her family ",
|
172 |
+
"Paul Erdos died at 83 "
|
173 |
+
]
|
174 |
+
|
175 |
+
|
176 |
+
sent_arr_masked = [
|
177 |
+
"Lou Gehrig:__entity__ who works for XCorp:__entity__ and lives in New:__entity__ York:__entity__ suffers from Parkinson's:__entity__ ",
|
178 |
+
"Parkinson:__entity__ who works for XCorp:__entity__ and lives in New:__entity__ York:__entity__ suffers from Lou Gehrig's:__entity__",
|
179 |
+
"lou:__entity__ gehrig:__entity__ was diagnosed with Parkinson's:__entity__ ",
|
180 |
+
"A eGFR:__entity__ below 60 indicates chronic kidney disease",
|
181 |
+
"Overexpression of EGFR:__entity__ occurs across a wide range of different cancers",
|
182 |
+
"Stanford:__entity__ called",
|
183 |
+
"He was diagnosed with non:__entity__ small:__entity__ cell:__entity__ lung:__entity__ cancer:__entity__",
|
184 |
+
"I met my girl:__entity__ friends at the pub ",
|
185 |
+
"I met my New:__entity__ York:__entity__ friends at the pub",
|
186 |
+
"I met my XCorp:__entity__ friends at the pub",
|
187 |
+
"I met my two:__entity__ friends at the pub",
|
188 |
+
"Bio-Techne's genomic tools include advanced tissue-based in-situ hybridization assays sold under the ACD:__entity__ brand as well as a portfolio of assays for prostate cancer diagnosis ",
|
189 |
+
"There are no treatment options specifically indicated for ACD:__entity__ and physicians must utilize agents approved for other dermatology conditions",
|
190 |
+
"As ACD:__entity__ has been implicated in apoptosis-resistant glioblastoma (GBM), there is a high medical need for identifying novel ACD-inducing drugs ",
|
191 |
+
"Located in the heart of Dublin , in the family home of acclaimed writer Oscar Wilde , ACD:__entity__ provides the perfect backdrop to inspire Irish (and Irish-at-heart) students to excel in business and the arts",
|
192 |
+
"Patients treated with anticancer chemotherapy drugs ( ACD:__entity__ ) are vulnerable to infectious diseases due to immunosuppression and to the direct impact of ACD on their intestinal microbiota ",
|
193 |
+
"In the LASOR:__entity__ trial:__entity__ , increasing daily imatinib dose from 400 to 600mg induced MMR at 12 and 24 months in 25% and 36% of the patients, respectively, who had suboptimal cytogenetic responses ",
|
194 |
+
"The sky turned dark:__entity__ in advance of the storm that was coming from the east ",
|
195 |
+
"She loves to watch Sunday afternoon football:__entity__ with her family ",
|
196 |
+
"Paul:__entity__ Erdos:__entity__ died at 83:__entity__ "
|
197 |
+
]
|
198 |
+
|
199 |
+
def init_selectbox():
|
200 |
+
return st.selectbox(
|
201 |
+
'Choose any of the sentences in pull-down below',
|
202 |
+
sent_arr,key='my_choice')
|
203 |
+
|
204 |
+
|
205 |
+
def on_text_change():
|
206 |
+
text = st.session_state.my_text
|
207 |
+
print("in callback: " + text)
|
208 |
+
perform_inference(text)
|
209 |
+
|
210 |
+
def main():
|
211 |
+
try:
|
212 |
+
|
213 |
+
init_session_states()
|
214 |
+
|
215 |
+
st.markdown("<h3 style='text-align: center;'>NER using pretrained models with <a href='https://ajitrajasekharan.github.io/2021/01/02/my-first-post.html'>no fine tuning</a></h3>", unsafe_allow_html=True)
|
216 |
+
#st.markdown("""
|
217 |
+
#<h3 style="font-size:16px; color: #ff0000; text-align: center"><b>App under construction... (not in working condition yet)</b></h3>
|
218 |
+
#""", unsafe_allow_html=True)
|
219 |
+
|
220 |
+
|
221 |
+
st.markdown("""
|
222 |
+
<p style="text-align:center;"><img src="https://ajitrajasekharan.github.io/images/1.png" width="700"></p>
|
223 |
+
<br/>
|
224 |
+
<br/>
|
225 |
+
""", unsafe_allow_html=True)
|
226 |
+
|
227 |
+
st.write("This app uses 3 models. Two Pretrained Bert models (**no fine tuning**) and a POS tagger")
|
228 |
+
|
229 |
+
|
230 |
+
with st.form('my_form'):
|
231 |
+
selected_sentence = init_selectbox()
|
232 |
+
text_input = st.text_area(label='Type any sentence below',value="")
|
233 |
+
submit_button = st.form_submit_button('Submit')
|
234 |
+
input_status_area = st.empty()
|
235 |
+
display_area = st.empty()
|
236 |
+
if submit_button:
|
237 |
+
start = time.time()
|
238 |
+
if (len(text_input) == 0):
|
239 |
+
text_input = sent_arr_masked[sent_arr.index(selected_sentence)]
|
240 |
+
input_status_area.text("Input sentence: " + text_input)
|
241 |
+
results = perform_inference(text_input,display_area)
|
242 |
+
display_area.empty()
|
243 |
+
with display_area.container():
|
244 |
+
st.text(f"prediction took {time.time() - start:.2f}s")
|
245 |
+
st.json(results)
|
246 |
+
|
247 |
+
|
248 |
+
|
249 |
+
|
250 |
+
|
251 |
+
#input_text = st.text_area(
|
252 |
+
# label="Type any sentence",
|
253 |
+
# on_change=on_text_change,key='my_text'
|
254 |
+
# )
|
255 |
+
|
256 |
+
st.markdown("""
|
257 |
+
<small style="font-size:16px; color: #7f7f7f; text-align: left"><br/><br/>Models used: <br/>(1) <a href='https://huggingface.co/ajitrajasekharan/biomedical' target='_blank'>Biomedical model</a> pretrained on Pubmed,Clinical trials and BookCorpus subset.<br/>(2) Bert-base-cased (for PHI entities - Person/location/organization etc.)<br/>(3) Flair POS tagger</small>
|
258 |
+
#""", unsafe_allow_html=True)
|
259 |
+
st.markdown("""
|
260 |
+
<h3 style="font-size:16px; color: #9f9f9f; text-align: center"><b> <a href='https://huggingface.co/spaces/ajitrajasekharan/Qualitative-pretrained-model-evaluation' target='_blank'>App link to examine pretrained models</a> used to perform NER without fine tuning</b></h3>
|
261 |
+
""", unsafe_allow_html=True)
|
262 |
+
st.markdown("""
|
263 |
+
<h3 style="font-size:16px; color: #9f9f9f; text-align: center">Github <a href='http://github.com/ajitrajasekharan/unsupervised_NER' target='_blank'>link to same working code </a>(without UI) as separate microservices</h3>
|
264 |
+
""", unsafe_allow_html=True)
|
265 |
+
|
266 |
+
except Exception as e:
|
267 |
+
print("Some error occurred in main")
|
268 |
+
st.exception(e)
|
269 |
+
|
270 |
+
if __name__ == "__main__":
|
271 |
+
main()
|
batched_main_NER.py
ADDED
@@ -0,0 +1,905 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pdb
|
2 |
+
import config_utils as cf
|
3 |
+
import requests
|
4 |
+
import sys
|
5 |
+
import urllib.parse
|
6 |
+
import numpy as np
|
7 |
+
from collections import OrderedDict
|
8 |
+
import argparse
|
9 |
+
from common import *
|
10 |
+
import json
|
11 |
+
|
12 |
+
#WORD_POS = 1
|
13 |
+
#TAG_POS = 2
|
14 |
+
#MASK_TAG = "__entity__"
|
15 |
+
DEFAULT_CONFIG = "./config.json"
|
16 |
+
DISPATCH_MASK_TAG = "entity"
|
17 |
+
DESC_HEAD = "PIVOT_DESCRIPTORS:"
|
18 |
+
#TYPE2_AMB = "AMB2-"
|
19 |
+
TYPE2_AMB = ""
|
20 |
+
DUMMY_DESCS=10
|
21 |
+
DEFAULT_ENTITY_MAP = "entity_types_consolidated.txt"
|
22 |
+
|
23 |
+
#RESET_POS_TAG='RESET'
|
24 |
+
SPECIFIC_TAG=":__entity__"
|
25 |
+
|
26 |
+
|
27 |
+
def softmax(x):
|
28 |
+
"""Compute softmax values for each sets of scores in x."""
|
29 |
+
return np.exp(x) / np.sum(np.exp(x), axis=0)
|
30 |
+
|
31 |
+
|
32 |
+
#noun_tags = ['NFP','JJ','NN','FW','NNS','NNPS','JJS','JJR','NNP','POS','CD']
|
33 |
+
#cap_tags = ['NFP','JJ','NN','FW','NNS','NNPS','JJS','JJR','NNP','PRP']
|
34 |
+
|
35 |
+
def read_common_descs(file_name):
|
36 |
+
common_descs = {}
|
37 |
+
with open(file_name) as fp:
|
38 |
+
for line in fp:
|
39 |
+
common_descs[line.strip()] = 1
|
40 |
+
print("Common descs for filtering read:",len(common_descs))
|
41 |
+
return common_descs
|
42 |
+
|
43 |
+
def read_entity_map(file_name):
|
44 |
+
emap = {}
|
45 |
+
with open(file_name) as fp:
|
46 |
+
for line in fp:
|
47 |
+
line = line.rstrip('\n')
|
48 |
+
entities = line.split()
|
49 |
+
if (len(entities) == 1):
|
50 |
+
assert(entities[0] not in emap)
|
51 |
+
emap[entities[0]] = entities[0]
|
52 |
+
else:
|
53 |
+
assert(len(entities) == 2)
|
54 |
+
entity_arr = entities[1].split('/')
|
55 |
+
if (entities[0] not in emap):
|
56 |
+
emap[entities[0]] = entities[0]
|
57 |
+
for entity in entity_arr:
|
58 |
+
assert(entity not in emap)
|
59 |
+
emap[entity] = entities[0]
|
60 |
+
print("Entity map:",len(emap))
|
61 |
+
return emap
|
62 |
+
|
63 |
+
class UnsupNER:
|
64 |
+
def __init__(self,config_file):
|
65 |
+
print("NER service handler started")
|
66 |
+
base_path = cf.read_config(config_file)["BASE_PATH"] if ("BASE_PATH" in cf.read_config(config_file)) else "./"
|
67 |
+
self.pos_server_url = cf.read_config(config_file)["POS_SERVER_URL"]
|
68 |
+
self.desc_server_url = cf.read_config(config_file)["DESC_SERVER_URL"]
|
69 |
+
self.entity_server_url = cf.read_config(config_file)["ENTITY_SERVER_URL"]
|
70 |
+
self.common_descs = read_common_descs(cf.read_config(config_file)["COMMON_DESCS_FILE"])
|
71 |
+
self.entity_map = read_entity_map(cf.read_config(config_file)["EMAP_FILE"])
|
72 |
+
self.rfp = open(base_path + "log_results.txt","a")
|
73 |
+
self.dfp = open(base_path + "log_debug.txt","a")
|
74 |
+
self.algo_ci_tag_fp = open(base_path + "algorthimic_ci_tags.txt","a")
|
75 |
+
print(self.pos_server_url)
|
76 |
+
print(self.desc_server_url)
|
77 |
+
print(self.entity_server_url)
|
78 |
+
np.set_printoptions(suppress=True) #this suppresses exponential representation when np is used to round
|
79 |
+
if (cf.read_config(config_file)["SUPPRESS_UNTAGGED"] == "1"):
|
80 |
+
self.suppress_untagged = True
|
81 |
+
else:
|
82 |
+
self.suppress_untagged = False #This is disabled in full debug text mode
|
83 |
+
|
84 |
+
|
85 |
+
#This is bad hack for prototyping - parsing from text output as opposed to json
|
86 |
+
def extract_POS(self,text):
|
87 |
+
arr = text.split('\n')
|
88 |
+
if (len(arr) > 0):
|
89 |
+
start_pos = 0
|
90 |
+
for i,line in enumerate(arr):
|
91 |
+
if (len(line) > 0):
|
92 |
+
start_pos += 1
|
93 |
+
continue
|
94 |
+
else:
|
95 |
+
break
|
96 |
+
#print(arr[start_pos:])
|
97 |
+
terms_arr = []
|
98 |
+
for i,line in enumerate(arr[start_pos:]):
|
99 |
+
terms = line.split('\t')
|
100 |
+
if (len(terms) == 5):
|
101 |
+
#print(terms)
|
102 |
+
terms_arr.append(terms)
|
103 |
+
return terms_arr
|
104 |
+
|
105 |
+
def normalize_casing(self,sent):
|
106 |
+
sent_arr = sent.split()
|
107 |
+
ret_sent_arr = []
|
108 |
+
for i,word in enumerate(sent_arr):
|
109 |
+
if (len(word) > 1):
|
110 |
+
norm_word = word[0] + word[1:].lower()
|
111 |
+
else:
|
112 |
+
norm_word = word[0]
|
113 |
+
ret_sent_arr.append(norm_word)
|
114 |
+
return ' '.join(ret_sent_arr)
|
115 |
+
|
116 |
+
#Full sentence tag call also generates json output.
|
117 |
+
def tag_sentence_service(self,text,desc_obj):
|
118 |
+
ret_str = self.tag_sentence(text,self.rfp,self.dfp,True,desc_obj)
|
119 |
+
return ret_str
|
120 |
+
|
121 |
+
def dictify_ner_response(self,ner_str):
|
122 |
+
arr = ner_str.split('\n')
|
123 |
+
ret_dict = OrderedDict()
|
124 |
+
count = 1
|
125 |
+
ref_indices_arr = []
|
126 |
+
for line in arr:
|
127 |
+
terms = line.split()
|
128 |
+
if (len(terms) == 2):
|
129 |
+
ret_dict[count] = {"term":terms[0],"e":terms[1]}
|
130 |
+
if (terms[1] != "O" and terms[1].startswith("B_")):
|
131 |
+
ref_indices_arr.append(count)
|
132 |
+
count += 1
|
133 |
+
elif (len(terms) == 1):
|
134 |
+
ret_dict[count] = {"term":"empty","e":terms[0]}
|
135 |
+
if (terms[0] != "O" and terms[0].startswith("B_")):
|
136 |
+
ref_indices_arr.append(count)
|
137 |
+
count += 1
|
138 |
+
if (len(ret_dict) > 3): #algorithmic harvesting of CI labels for human verification and adding to bootstrap list
|
139 |
+
self.algo_ci_tag_fp.write("SENT:" + ner_str.replace('\n',' ') + "\n")
|
140 |
+
out = terms[0].replace('[',' ').replace(']','').split()[-1]
|
141 |
+
out = '_'.join(out.split('_')[1:]) if out.startswith("B_") else out
|
142 |
+
print(out)
|
143 |
+
self.algo_ci_tag_fp.write(ret_dict[count-2]["term"] + " " + out + "\n")
|
144 |
+
self.algo_ci_tag_fp.flush()
|
145 |
+
else:
|
146 |
+
assert(len(terms) == 0) #If not empty something is not right
|
147 |
+
return ret_dict,ref_indices_arr
|
148 |
+
|
149 |
+
def blank_entity_sentence(self,sent,dfp):
|
150 |
+
value = True if sent.endswith(" :__entity__\n") else False
|
151 |
+
if (value == True):
|
152 |
+
print("\n\n**************** Skipping CI prediction in pooling for sent:",sent)
|
153 |
+
dfp.write("\n\n**************** Skipping CI prediction in pooling for sent:" + sent + "\n")
|
154 |
+
return value
|
155 |
+
|
156 |
+
def pool_confidences(self,ci_entities,ci_confidences,ci_subtypes,cs_entities,cs_confidences,cs_subtypes,debug_str_arr,sent,dfp):
|
157 |
+
main_classes = {}
|
158 |
+
assert(len(cs_entities) == len(cs_confidences))
|
159 |
+
assert(len(cs_subtypes) == len(cs_entities))
|
160 |
+
assert(len(ci_entities) == len(ci_confidences))
|
161 |
+
assert(len(ci_subtypes) == len(ci_entities))
|
162 |
+
#Pool entity classes across CI and CS
|
163 |
+
is_blank_statement = self.blank_entity_sentence(sent,dfp) #Do not pool CI confidences of the sentences of the form " is a entity". These sentences are sent for purely algo harvesting of CS terms. CI predictions will add noise.
|
164 |
+
if (not is_blank_statement): #Do not pool CI confidences of the sentences of the form " is a entity". These sentences are sent for purely algo harvesting of CS terms. CI predictions will add noise.
|
165 |
+
for e,c in zip(ci_entities,ci_confidences):
|
166 |
+
e_base = e.split('[')[0]
|
167 |
+
main_classes[e_base] = float(c)
|
168 |
+
for e,c in zip(cs_entities,cs_confidences):
|
169 |
+
e_base = e.split('[')[0]
|
170 |
+
if (e_base in main_classes):
|
171 |
+
main_classes[e_base] += float(c)
|
172 |
+
else:
|
173 |
+
main_classes[e_base] = float(c)
|
174 |
+
final_sorted_d = OrderedDict(sorted(main_classes.items(), key=lambda kv: kv[1], reverse=True))
|
175 |
+
main_dist = self.convert_positive_nums_to_dist(final_sorted_d)
|
176 |
+
main_classes_arr = list(final_sorted_d.keys())
|
177 |
+
#print("\nIn pooling confidences")
|
178 |
+
#print(main_classes_arr)
|
179 |
+
#print(main_dist)
|
180 |
+
#Pool subtypes across CI and CS for a particular entity class
|
181 |
+
subtype_factors = {}
|
182 |
+
for e_class in final_sorted_d:
|
183 |
+
if e_class in cs_subtypes:
|
184 |
+
stypes = cs_subtypes[e_class]
|
185 |
+
if (e_class not in subtype_factors):
|
186 |
+
subtype_factors[e_class] = {}
|
187 |
+
for st in stypes:
|
188 |
+
if (st in subtype_factors[e_class]):
|
189 |
+
subtype_factors[e_class][st] += stypes[st]
|
190 |
+
else:
|
191 |
+
subtype_factors[e_class][st] = stypes[st]
|
192 |
+
if (is_blank_statement):
|
193 |
+
continue
|
194 |
+
if e_class in ci_subtypes:
|
195 |
+
stypes = ci_subtypes[e_class]
|
196 |
+
if (e_class not in subtype_factors):
|
197 |
+
subtype_factors[e_class] = {}
|
198 |
+
for st in stypes:
|
199 |
+
if (st in subtype_factors[e_class]):
|
200 |
+
subtype_factors[e_class][st] += stypes[st]
|
201 |
+
else:
|
202 |
+
subtype_factors[e_class][st] = stypes[st]
|
203 |
+
sorted_subtype_factors = {}
|
204 |
+
for e_class in subtype_factors:
|
205 |
+
stypes = subtype_factors[e_class]
|
206 |
+
final_sorted_d = OrderedDict(sorted(stypes.items(), key=lambda kv: kv[1], reverse=True))
|
207 |
+
stypes_dist = self.convert_positive_nums_to_dist(final_sorted_d)
|
208 |
+
stypes_class_arr = list(final_sorted_d.keys())
|
209 |
+
sorted_subtype_factors[e_class] = {"stypes":stypes_class_arr,"dist":stypes_dist}
|
210 |
+
pooled_results = OrderedDict()
|
211 |
+
assert(len(main_classes_arr) == len(main_dist))
|
212 |
+
d_str_arr = []
|
213 |
+
d_str_arr.append("\n***CONSOLIDATED ENTITY:")
|
214 |
+
for e,c in zip(main_classes_arr,main_dist):
|
215 |
+
pooled_results[e] = {"e":e,"confidence":c}
|
216 |
+
d_str_arr.append(e + " " + str(c))
|
217 |
+
stypes_dict = sorted_subtype_factors[e]
|
218 |
+
pooled_st = OrderedDict()
|
219 |
+
for st,sd in zip(stypes_dict["stypes"],stypes_dict["dist"]):
|
220 |
+
pooled_st[st] = sd
|
221 |
+
pooled_results[e]["stypes"] = pooled_st
|
222 |
+
debug_str_arr.append(' '.join(d_str_arr))
|
223 |
+
print(' '.join(d_str_arr))
|
224 |
+
return pooled_results
|
225 |
+
|
226 |
+
|
227 |
+
|
228 |
+
|
229 |
+
|
230 |
+
|
231 |
+
|
232 |
+
|
233 |
+
|
234 |
+
def init_entity_info(self,entity_info_dict,index):
|
235 |
+
curr_term_dict = OrderedDict()
|
236 |
+
entity_info_dict[index] = curr_term_dict
|
237 |
+
curr_term_dict["ci"] = OrderedDict()
|
238 |
+
curr_term_dict["ci"]["entities"] = []
|
239 |
+
curr_term_dict["ci"]["descs"] = []
|
240 |
+
curr_term_dict["cs"] = OrderedDict()
|
241 |
+
curr_term_dict["cs"]["entities"] = []
|
242 |
+
curr_term_dict["cs"]["descs"] = []
|
243 |
+
|
244 |
+
|
245 |
+
|
246 |
+
|
247 |
+
#This now does specific tagging if there is a __entity__ in sentence; else does full tagging. TBD.
|
248 |
+
#TBD. Make response params same regardlesss of output format. Now it is different
|
249 |
+
def tag_sentence(self,sent,rfp,dfp,json_output,desc_obj):
|
250 |
+
print("Input: ", sent)
|
251 |
+
dfp.write("\n\n++++-------------------------------\n")
|
252 |
+
dfp.write("NER_INPUT: " + sent + "\n")
|
253 |
+
debug_str_arr = []
|
254 |
+
entity_info_dict = OrderedDict()
|
255 |
+
#url = self.desc_server_url + sent.replace('"','\'')
|
256 |
+
#r = self.dispatch_request(url)
|
257 |
+
#if (r is None):
|
258 |
+
# print("Empty response. Desc server is probably down: ",self.desc_server_url)
|
259 |
+
# return json.loads("[]")
|
260 |
+
#main_obj = json.loads(r.text)
|
261 |
+
main_obj = desc_obj
|
262 |
+
#print(json.dumps(main_obj,indent=4))
|
263 |
+
#Find CI predictions for ALL masked predictios in sentence
|
264 |
+
ci_predictions,orig_ci_entities = self.find_ci_entities(main_obj,debug_str_arr,entity_info_dict) #ci_entities is the same info as ci_predictions except packed differently for output
|
265 |
+
#Find CS predictions for ALL masked predictios in sentence. Use the CI predictions from previous step to
|
266 |
+
#pool
|
267 |
+
detected_entities_arr,ner_str,full_pooled_results,orig_cs_entities = self.find_cs_entities(sent,main_obj,rfp,dfp,debug_str_arr,ci_predictions,entity_info_dict)
|
268 |
+
assert(len(detected_entities_arr) == len(entity_info_dict))
|
269 |
+
print("--------")
|
270 |
+
if (json_output):
|
271 |
+
if (len(detected_entities_arr) != len(entity_info_dict)):
|
272 |
+
if (len(entity_info_dict) == 0):
|
273 |
+
self.init_entity_info(entity_info_dict,index)
|
274 |
+
entity_info_dict[1]["cs"]["entities"].append([{"e":"O","confidence":1}])
|
275 |
+
entity_info_dict[1]["ci"]["entities"].append([{"e":"O","confidence":1}])
|
276 |
+
ret_dict,ref_indices_arr = self.dictify_ner_response(ner_str) #Convert ner string to a dictionary for json output
|
277 |
+
assert(len(ref_indices_arr) == len(detected_entities_arr))
|
278 |
+
assert(len(entity_info_dict) == len(detected_entities_arr))
|
279 |
+
cs_aux_dict = OrderedDict()
|
280 |
+
ci_aux_dict = OrderedDict()
|
281 |
+
cs_aux_orig_entities = OrderedDict()
|
282 |
+
ci_aux_orig_entities = OrderedDict()
|
283 |
+
pooled_pred_dict = OrderedDict()
|
284 |
+
count = 0
|
285 |
+
assert(len(full_pooled_results) == len(detected_entities_arr))
|
286 |
+
assert(len(full_pooled_results) == len(orig_cs_entities))
|
287 |
+
assert(len(full_pooled_results) == len(orig_ci_entities))
|
288 |
+
for e,c,p,o,i in zip(detected_entities_arr,entity_info_dict,full_pooled_results,orig_cs_entities,orig_ci_entities):
|
289 |
+
val = entity_info_dict[c]
|
290 |
+
#cs_aux_dict[ref_indices_arr[count]] = {"e":e,"cs_distribution":val["cs"]["entities"],"cs_descs":val["cs"]["descs"]}
|
291 |
+
pooled_pred_dict[ref_indices_arr[count]] = {"e": e, "cs_distribution": list(p.values())}
|
292 |
+
cs_aux_dict[ref_indices_arr[count]] = {"e":e,"cs_descs":val["cs"]["descs"]}
|
293 |
+
#ci_aux_dict[ref_indices_arr[count]] = {"ci_distribution":val["ci"]["entities"],"ci_descs":val["ci"]["descs"]}
|
294 |
+
ci_aux_dict[ref_indices_arr[count]] = {"ci_descs":val["ci"]["descs"]}
|
295 |
+
cs_aux_orig_entities[ref_indices_arr[count]] = {"e":e,"cs_distribution": o}
|
296 |
+
ci_aux_orig_entities[ref_indices_arr[count]] = {"e":e,"cs_distribution": i}
|
297 |
+
count += 1
|
298 |
+
#print(ret_dict)
|
299 |
+
#print(aux_dict)
|
300 |
+
final_ret_dict = {"total_terms_count":len(ret_dict),"detected_entity_phrases_count":len(detected_entities_arr),"ner":ret_dict,"entity_distribution":pooled_pred_dict,"cs_prediction_details":cs_aux_dict,"ci_prediction_details":ci_aux_dict,"orig_cs_prediction_details":cs_aux_orig_entities,"orig_ci_prediction_details":ci_aux_orig_entities,"debug":debug_str_arr}
|
301 |
+
json_str = json.dumps(final_ret_dict,indent = 4)
|
302 |
+
#print (json_str)
|
303 |
+
#with open("single_debug.txt","w") as fp:
|
304 |
+
# fp.write(json_str)
|
305 |
+
|
306 |
+
dfp.write('\n'.join(debug_str_arr))
|
307 |
+
dfp.write("\n\nEND-------------------------------\n")
|
308 |
+
dfp.flush()
|
309 |
+
return json_str
|
310 |
+
else:
|
311 |
+
print(detected_entities_arr)
|
312 |
+
debug_str_arr.append("NER_FINAL_RESULTS: " + ' '.join(detected_entities_arr))
|
313 |
+
print("--------")
|
314 |
+
dfp.write('\n'.join(debug_str_arr))
|
315 |
+
dfp.write("\n\nEND-------------------------------\n")
|
316 |
+
dfp.flush()
|
317 |
+
return detected_entities_arr,span_arr,terms_arr,ner_str,debug_str_arr
|
318 |
+
|
319 |
+
def masked_word_first_letter_capitalize(self,entity):
|
320 |
+
arr = entity.split()
|
321 |
+
ret_arr = []
|
322 |
+
for term in arr:
|
323 |
+
if (len(term) > 1 and term[0].islower() and term[1].islower()):
|
324 |
+
ret_arr.append(term[0].upper() + term[1:])
|
325 |
+
else:
|
326 |
+
ret_arr.append(term)
|
327 |
+
return ' '.join(ret_arr)
|
328 |
+
|
329 |
+
|
330 |
+
def gen_single_phrase_sentences(self,terms_arr,masked_sent_arr,span_arr,rfp,dfp):
|
331 |
+
sentence_template = "%s is a entity"
|
332 |
+
print(span_arr)
|
333 |
+
sentences = []
|
334 |
+
singleton_spans_arr = []
|
335 |
+
run_index = 0
|
336 |
+
entity = ""
|
337 |
+
singleton_span = []
|
338 |
+
while (run_index < len(span_arr)):
|
339 |
+
if (span_arr[run_index] == 1):
|
340 |
+
while (run_index < len(span_arr)):
|
341 |
+
if (span_arr[run_index] == 1):
|
342 |
+
#print(terms_arr[run_index][WORD_POS],end=' ')
|
343 |
+
if (len(entity) == 0):
|
344 |
+
entity = terms_arr[run_index][WORD_POS]
|
345 |
+
else:
|
346 |
+
entity = entity + " " + terms_arr[run_index][WORD_POS]
|
347 |
+
singleton_span.append(1)
|
348 |
+
run_index += 1
|
349 |
+
else:
|
350 |
+
break
|
351 |
+
#print()
|
352 |
+
for i in sentence_template.split():
|
353 |
+
if (i != "%s"):
|
354 |
+
singleton_span.append(0)
|
355 |
+
entity = self.masked_word_first_letter_capitalize(entity)
|
356 |
+
sentence = sentence_template % entity
|
357 |
+
sentences.append(sentence)
|
358 |
+
singleton_spans_arr.append(singleton_span)
|
359 |
+
print(sentence)
|
360 |
+
print(singleton_span)
|
361 |
+
entity = ""
|
362 |
+
singleton_span = []
|
363 |
+
else:
|
364 |
+
run_index += 1
|
365 |
+
return sentences,singleton_spans_arr
|
366 |
+
|
367 |
+
|
368 |
+
def find_ci_entities(self,main_obj,debug_str_arr,entity_info_dict):
|
369 |
+
ci_predictions = []
|
370 |
+
orig_ci_confidences = []
|
371 |
+
term_index = 1
|
372 |
+
batch_obj = main_obj["descs_and_entities"]
|
373 |
+
for key in batch_obj:
|
374 |
+
masked_sent = batch_obj[key]["ci_prediction"]["sentence"]
|
375 |
+
print("\n**CI: ", masked_sent)
|
376 |
+
debug_str_arr.append(masked_sent)
|
377 |
+
#entity_info_dict["masked_sent"].append(masked_sent)
|
378 |
+
inp_arr = batch_obj[key]["ci_prediction"]["descs"]
|
379 |
+
descs = self.get_descriptors_for_masked_position(inp_arr)
|
380 |
+
self.init_entity_info(entity_info_dict,term_index)
|
381 |
+
entities,confidences,subtypes = self.get_entities_for_masked_position(inp_arr,descs,debug_str_arr,entity_info_dict[term_index]["ci"])
|
382 |
+
ci_predictions.append({"entities":entities,"confidences":confidences,"subtypes":subtypes})
|
383 |
+
orig_ci_confidences.append(self.pack_confidences(entities,confidences)) #this is sent for ensemble server to detect cross predictions. CS predicitons are more reflective of cross over than consolidated predictions, since CI may overwhelm CS
|
384 |
+
term_index += 1
|
385 |
+
return ci_predictions,orig_ci_confidences
|
386 |
+
|
387 |
+
|
388 |
+
def pack_confidences(self,cs_entities,cs_confidences):
|
389 |
+
assert(len(cs_entities) == len(cs_confidences))
|
390 |
+
orig_cs_arr = []
|
391 |
+
for e,c in zip(cs_entities,cs_confidences):
|
392 |
+
print(e,c)
|
393 |
+
e_split = e.split('[')
|
394 |
+
e_main = e_split[0]
|
395 |
+
if (len(e_split) > 1):
|
396 |
+
e_sub = e_split[1].split(',')[0].rstrip(']')
|
397 |
+
if (e_main != e_sub):
|
398 |
+
e = e_main + '[' + e_sub + ']'
|
399 |
+
else:
|
400 |
+
e = e_main
|
401 |
+
else:
|
402 |
+
e = e_main
|
403 |
+
orig_cs_arr.append({"e":e,"confidence":c})
|
404 |
+
return orig_cs_arr
|
405 |
+
|
406 |
+
|
407 |
+
#We have multiple masked versions of a single sentence. Tag each one of them
|
408 |
+
#and create a complete tagged version for a sentence
|
409 |
+
def find_cs_entities(self,sent,main_obj,rfp,dfp,debug_str_arr,ci_predictions,entity_info_dict):
|
410 |
+
#print(sent)
|
411 |
+
batch_obj = main_obj["descs_and_entities"]
|
412 |
+
dfp.write(sent + "\n")
|
413 |
+
term_index = 1
|
414 |
+
detected_entities_arr = []
|
415 |
+
full_pooled_results = []
|
416 |
+
orig_cs_confidences = []
|
417 |
+
for index,key in enumerate(batch_obj):
|
418 |
+
position_info = batch_obj[key]["cs_prediction"]["descs"]
|
419 |
+
ci_entities = ci_predictions[index]["entities"]
|
420 |
+
ci_confidences = ci_predictions[index]["confidences"]
|
421 |
+
ci_subtypes = ci_predictions[index]["subtypes"]
|
422 |
+
debug_str_arr.append("\n++++++ nth Masked term : " + str(key))
|
423 |
+
#dfp.write(key + "\n")
|
424 |
+
masked_sent = batch_obj[key]["cs_prediction"]["sentence"]
|
425 |
+
print("\n**CS: ",masked_sent)
|
426 |
+
descs = self.get_descriptors_for_masked_position(position_info)
|
427 |
+
#dfp.write(str(descs) + "\n")
|
428 |
+
if (len(descs) > 0):
|
429 |
+
cs_entities,cs_confidences,cs_subtypes = self.get_entities_for_masked_position(position_info,descs,debug_str_arr,entity_info_dict[term_index]["cs"])
|
430 |
+
else:
|
431 |
+
cs_entities = []
|
432 |
+
cs_confidences = []
|
433 |
+
cs_subtypes = []
|
434 |
+
#dfp.write(str(cs_entities) + "\n")
|
435 |
+
pooled_results = self.pool_confidences(ci_entities,ci_confidences,ci_subtypes,cs_entities,cs_confidences,cs_subtypes,debug_str_arr,sent,dfp)
|
436 |
+
self.fill_detected_entities(detected_entities_arr,pooled_results) #just picks the top prediction
|
437 |
+
full_pooled_results.append(pooled_results)
|
438 |
+
orig_cs_confidences.append(self.pack_confidences(cs_entities,cs_confidences)) #this is sent for ensemble server to detect cross predictions. CS predicitons are more reflective of cross over than consolidated predictions, since CI may overwhelm CS
|
439 |
+
#self.old_resolve_entities(i,singleton_entities,detected_entities_arr) #This decides how to pick entities given CI and CS predictions
|
440 |
+
term_index += 1
|
441 |
+
#out of the full loop over sentences. Now create NER sentence
|
442 |
+
terms_arr = main_obj["terms_arr"]
|
443 |
+
span_arr = main_obj["span_arr"]
|
444 |
+
ner_str = self.emit_sentence_entities(sent,terms_arr,detected_entities_arr,span_arr,rfp) #just outputs results in NER Conll format
|
445 |
+
dfp.flush()
|
446 |
+
return detected_entities_arr,ner_str,full_pooled_results,orig_cs_confidences
|
447 |
+
|
448 |
+
|
449 |
+
def fill_detected_entities(self,detected_entities_arr,entities):
|
450 |
+
if (len(entities) > 0):
|
451 |
+
top_e_class = next(iter(entities))
|
452 |
+
top_subtype = next(iter(entities[top_e_class]["stypes"]))
|
453 |
+
if (top_e_class != top_subtype):
|
454 |
+
top_prediction = top_e_class + "[" + top_subtype + "]"
|
455 |
+
else:
|
456 |
+
top_prediction = top_e_class
|
457 |
+
detected_entities_arr.append(top_prediction)
|
458 |
+
else:
|
459 |
+
detected_entities_arr.append("OTHER")
|
460 |
+
|
461 |
+
|
462 |
+
def fill_detected_entities_old(self,detected_entities_arr,entities,pan_arr):
|
463 |
+
entities_dict = {}
|
464 |
+
count = 1
|
465 |
+
for i in entities:
|
466 |
+
cand = i.split("-")
|
467 |
+
for j in cand:
|
468 |
+
terms = j.split("/")
|
469 |
+
for k in terms:
|
470 |
+
if (k not in entities_dict):
|
471 |
+
entities_dict[k] = 1.0/count
|
472 |
+
else:
|
473 |
+
entities_dict[k] += 1.0/count
|
474 |
+
count += 1
|
475 |
+
final_sorted_d = OrderedDict(sorted(entities_dict.items(), key=lambda kv: kv[1], reverse=True))
|
476 |
+
first = "OTHER"
|
477 |
+
for first in final_sorted_d:
|
478 |
+
break
|
479 |
+
detected_entities_arr.append(first)
|
480 |
+
|
481 |
+
#Contextual entity is picked as first candidate before context independent candidate
|
482 |
+
def old_resolve_entities(self,index,singleton_entities,detected_entities_arr):
|
483 |
+
if (singleton_entities[index].split('[')[0] != detected_entities_arr[index].split('[')[0]):
|
484 |
+
if (singleton_entities[index].split('[')[0] != "OTHER" and detected_entities_arr[index].split('[')[0] != "OTHER"):
|
485 |
+
detected_entities_arr[index] = detected_entities_arr[index] + "/" + singleton_entities[index]
|
486 |
+
elif (detected_entities_arr[index].split('[')[0] == "OTHER"):
|
487 |
+
detected_entities_arr[index] = singleton_entities[index]
|
488 |
+
else:
|
489 |
+
pass
|
490 |
+
else:
|
491 |
+
#this is the case when both CI and CS entity type match. Since the subtypes are already ordered, just merge(CS/CI,CS/CI...) the two picking unique subtypes
|
492 |
+
main_entity = detected_entities_arr[index].split('[')[0]
|
493 |
+
cs_arr = detected_entities_arr[index].split('[')[1].rstrip(']').split(',')
|
494 |
+
ci_arr = singleton_entities[index].split('[')[1].rstrip(']').split(',')
|
495 |
+
cs_arr_len = len(cs_arr)
|
496 |
+
ci_arr_len = len(ci_arr)
|
497 |
+
max_len = ci_arr_len if ci_arr_len > cs_arr_len else cs_arr_len
|
498 |
+
merged_unique_subtype_dict = OrderedDict()
|
499 |
+
for i in range(cs_arr_len):
|
500 |
+
if (i < cs_arr_len and cs_arr[i] not in merged_unique_subtype_dict):
|
501 |
+
merged_unique_subtype_dict[cs_arr[i]] = 1
|
502 |
+
if (i < ci_arr_len and ci_arr[i] not in merged_unique_subtype_dict):
|
503 |
+
merged_unique_subtype_dict[ci_arr[i]] = 1
|
504 |
+
new_subtypes_str = ','.join(list(merged_unique_subtype_dict.keys()))
|
505 |
+
detected_entities_arr[index] = main_entity + '[' + new_subtypes_str + ']'
|
506 |
+
|
507 |
+
|
508 |
+
|
509 |
+
|
510 |
+
|
511 |
+
|
512 |
+
def emit_sentence_entities(self,sent,terms_arr,detected_entities_arr,span_arr,rfp):
|
513 |
+
print("Final result")
|
514 |
+
ret_str = ""
|
515 |
+
for i,term in enumerate(terms_arr):
|
516 |
+
print(term,' ',end='')
|
517 |
+
print()
|
518 |
+
sent_arr = sent.split()
|
519 |
+
assert(len(terms_arr) == len(span_arr))
|
520 |
+
entity_index = 0
|
521 |
+
i = 0
|
522 |
+
in_span = False
|
523 |
+
while (i < len(span_arr)):
|
524 |
+
if (span_arr[i] == 0):
|
525 |
+
tag = "O"
|
526 |
+
if (in_span):
|
527 |
+
in_span = False
|
528 |
+
entity_index += 1
|
529 |
+
else:
|
530 |
+
if (in_span):
|
531 |
+
tag = "I_" + detected_entities_arr[entity_index]
|
532 |
+
else:
|
533 |
+
in_span = True
|
534 |
+
tag = "B_" + detected_entities_arr[entity_index]
|
535 |
+
rfp.write(terms_arr[i] + ' ' + tag + "\n")
|
536 |
+
ret_str = ret_str + terms_arr[i] + ' ' + tag + "\n"
|
537 |
+
print(tag + ' ',end='')
|
538 |
+
i += 1
|
539 |
+
print()
|
540 |
+
rfp.write("\n")
|
541 |
+
ret_str += "\n"
|
542 |
+
rfp.flush()
|
543 |
+
return ret_str
|
544 |
+
|
545 |
+
|
546 |
+
|
547 |
+
|
548 |
+
|
549 |
+
def get_descriptors_for_masked_position(self,inp_arr):
|
550 |
+
desc_arr = []
|
551 |
+
for i in range(len(inp_arr)):
|
552 |
+
desc_arr.append(inp_arr[i]["desc"])
|
553 |
+
desc_arr.append(inp_arr[i]["v"])
|
554 |
+
return desc_arr
|
555 |
+
|
556 |
+
def dispatch_request(self,url):
|
557 |
+
max_retries = 10
|
558 |
+
attempts = 0
|
559 |
+
while True:
|
560 |
+
try:
|
561 |
+
r = requests.get(url,timeout=1000)
|
562 |
+
if (r.status_code == 200):
|
563 |
+
return r
|
564 |
+
except:
|
565 |
+
print("Request:", url, " failed. Retrying...")
|
566 |
+
attempts += 1
|
567 |
+
if (attempts >= max_retries):
|
568 |
+
print("Request:", url, " failed")
|
569 |
+
break
|
570 |
+
|
571 |
+
def convert_positive_nums_to_dist(self,final_sorted_d):
|
572 |
+
factors = list(final_sorted_d.values()) #convert dict values to an array
|
573 |
+
factors = list(map(float, factors))
|
574 |
+
total = float(sum(factors))
|
575 |
+
if (total == 0):
|
576 |
+
total = 1
|
577 |
+
factors[0] = 1 #just make the sum 100%. This a boundary case for numbers for instance
|
578 |
+
factors = np.array(factors)
|
579 |
+
#factors = softmax(factors)
|
580 |
+
factors = factors/total
|
581 |
+
factors = np.round(factors,4)
|
582 |
+
return factors
|
583 |
+
|
584 |
+
def get_desc_weights_total(self,count,desc_weights):
|
585 |
+
i = 0
|
586 |
+
total = 0
|
587 |
+
while (i < count):
|
588 |
+
total += float(desc_weights[i+1])
|
589 |
+
i += 2
|
590 |
+
total = 1 if total == 0 else total
|
591 |
+
return total
|
592 |
+
|
593 |
+
|
594 |
+
def aggregate_entities(self,entities,desc_weights,debug_str_arr,entity_info_dict_entities):
|
595 |
+
''' Given a masked position, whose entity we are trying to determine,
|
596 |
+
First get descriptors for that postion 2*N array [desc1,score1,desc2,score2,...]
|
597 |
+
Then for each descriptor, get entity predictions which is an array 2*N of the form [e1,score1,e2,score2,...] where e1 could be DRUG/DISEASE and score1 is 10/8 etc.
|
598 |
+
In this function we aggregate each unique entity prediction (e.g. DISEASE) by summing up its weighted scores across all N predictions.
|
599 |
+
The result factor array is normalized to create a probability distribution
|
600 |
+
'''
|
601 |
+
count = len(entities)
|
602 |
+
assert(count %2 == 0)
|
603 |
+
aggregate_entities = {}
|
604 |
+
i = 0
|
605 |
+
subtypes = {}
|
606 |
+
while (i < count):
|
607 |
+
#entities[i] contains entity names and entities[i+] contains counts. Example PROTEIN/GENE/PERSON is i and 10/4/7 is i+1
|
608 |
+
curr_counts = entities[i+1].split('/') #this is one of the N predictions - this single prediction is itself a list of entities
|
609 |
+
trunc_e,trunc_counts = self.map_entities(entities[i].split('/'),curr_counts,subtypes) # Aggregate the subtype entities for this predictions. Subtypes aggregation is **across** the N predictions
|
610 |
+
#Also trunc_e contains the consolidated entity names.
|
611 |
+
assert(len(trunc_e) <= len(curr_counts)) # can be less if untagged is skipped
|
612 |
+
assert(len(trunc_e) == len(trunc_counts))
|
613 |
+
trunc_counts = softmax(trunc_counts) #this normalization is done to reduce the effect of absolute count of certain labeled entities, while aggregating the entity vectors across descriptors
|
614 |
+
curr_counts_sum = sum(map(int,trunc_counts)) #Using truncated count
|
615 |
+
curr_counts_sum = 1 if curr_counts_sum == 0 else curr_counts_sum
|
616 |
+
for j in range(len(trunc_e)): #this is iterating through the current instance of all *consolidated* tagged entity predictons (that is except UNTAGGED_ENTITY)
|
617 |
+
if (self.skip_untagged(trunc_e[j])):
|
618 |
+
continue
|
619 |
+
if (trunc_e[j] not in aggregate_entities):
|
620 |
+
aggregate_entities[trunc_e[j]] = (float(trunc_counts[j]))*float(desc_weights[i+1])
|
621 |
+
#aggregate_entities[trunc_e[j]] = (float(trunc_counts[j])/curr_counts_sum)*float(desc_weights[i+1])
|
622 |
+
#aggregate_entities[trunc_e[j]] = float(desc_weights[i+1])
|
623 |
+
else:
|
624 |
+
aggregate_entities[trunc_e[j]] += (float(trunc_counts[j]))*float(desc_weights[i+1])
|
625 |
+
#aggregate_entities[trunc_e[j]] += (float(trunc_counts[j])/curr_counts_sum)*float(desc_weights[i+1])
|
626 |
+
#aggregate_entities[trunc_e[j]] += float(desc_weights[i+1])
|
627 |
+
i += 2
|
628 |
+
final_sorted_d = OrderedDict(sorted(aggregate_entities.items(), key=lambda kv: kv[1], reverse=True))
|
629 |
+
if (len(final_sorted_d) == 0): #Case where all terms are tagged OTHER
|
630 |
+
final_sorted_d = {"OTHER":1}
|
631 |
+
subtypes["OTHER"] = {"OTHER":1}
|
632 |
+
factors = self.convert_positive_nums_to_dist(final_sorted_d)
|
633 |
+
ret_entities = list(final_sorted_d.keys())
|
634 |
+
confidences = factors.tolist()
|
635 |
+
print(ret_entities)
|
636 |
+
sorted_subtypes = self.sort_subtypes(subtypes)
|
637 |
+
ret_entities = self.update_entities_with_subtypes(ret_entities,sorted_subtypes)
|
638 |
+
print(ret_entities)
|
639 |
+
debug_str_arr.append(" ")
|
640 |
+
debug_str_arr.append(' '.join(ret_entities))
|
641 |
+
print(confidences)
|
642 |
+
assert(len(confidences) == len(ret_entities))
|
643 |
+
arr = []
|
644 |
+
for e,c in zip(ret_entities,confidences):
|
645 |
+
arr.append({"e":e,"confidence":c})
|
646 |
+
entity_info_dict_entities.append(arr)
|
647 |
+
debug_str_arr.append(' '.join([str(x) for x in confidences]))
|
648 |
+
debug_str_arr.append("\n\n")
|
649 |
+
return ret_entities,confidences,subtypes
|
650 |
+
|
651 |
+
|
652 |
+
def sort_subtypes(self,subtypes):
|
653 |
+
sorted_subtypes = OrderedDict()
|
654 |
+
for ent in subtypes:
|
655 |
+
final_sorted_d = OrderedDict(sorted(subtypes[ent].items(), key=lambda kv: kv[1], reverse=True))
|
656 |
+
sorted_subtypes[ent] = list(final_sorted_d.keys())
|
657 |
+
return sorted_subtypes
|
658 |
+
|
659 |
+
def update_entities_with_subtypes(self,ret_entities,subtypes):
|
660 |
+
new_entities = []
|
661 |
+
|
662 |
+
for ent in ret_entities:
|
663 |
+
#if (len(ret_entities) == 1):
|
664 |
+
# new_entities.append(ent) #avoid creating a subtype for a single case
|
665 |
+
# return new_entities
|
666 |
+
if (ent in subtypes):
|
667 |
+
new_entities.append(ent + '[' + ','.join(subtypes[ent]) + ']')
|
668 |
+
else:
|
669 |
+
new_entities.append(ent)
|
670 |
+
return new_entities
|
671 |
+
|
672 |
+
def skip_untagged(self,term):
|
673 |
+
if (self.suppress_untagged == True and (term == "OTHER" or term == "UNTAGGED_ENTITY")):
|
674 |
+
return True
|
675 |
+
return False
|
676 |
+
|
677 |
+
|
678 |
+
def map_entities(self,arr,counts_arr,subtypes_dict):
|
679 |
+
ret_arr = []
|
680 |
+
new_counts_arr = []
|
681 |
+
for index,term in enumerate(arr):
|
682 |
+
if (self.skip_untagged(term)):
|
683 |
+
continue
|
684 |
+
ret_arr.append(self.entity_map[term])
|
685 |
+
new_counts_arr.append(int(counts_arr[index]))
|
686 |
+
if (self.entity_map[term] not in subtypes_dict):
|
687 |
+
subtypes_dict[self.entity_map[term]] = {}
|
688 |
+
if (term not in subtypes_dict[self.entity_map[term]]):
|
689 |
+
#subtypes_dict[self.entity_map[i]][i] = 1
|
690 |
+
subtypes_dict[self.entity_map[term]][term] = int(counts_arr[index])
|
691 |
+
else:
|
692 |
+
#subtypes_dict[self.entity_map[i]][i] += 1
|
693 |
+
subtypes_dict[self.entity_map[term]][term] += int(counts_arr[index])
|
694 |
+
return ret_arr,new_counts_arr
|
695 |
+
|
696 |
+
def get_entities_from_batch(self,inp_arr):
|
697 |
+
entities_arr = []
|
698 |
+
for i in range(len(inp_arr)):
|
699 |
+
entities_arr.append(inp_arr[i]["e"])
|
700 |
+
entities_arr.append(inp_arr[i]["e_count"])
|
701 |
+
return entities_arr
|
702 |
+
|
703 |
+
|
704 |
+
def get_entities_for_masked_position(self,inp_arr,descs,debug_str_arr,entity_info_dict):
|
705 |
+
entities = self.get_entities_from_batch(inp_arr)
|
706 |
+
debug_combined_arr =[]
|
707 |
+
desc_arr =[]
|
708 |
+
assert(len(descs) %2 == 0)
|
709 |
+
assert(len(entities) %2 == 0)
|
710 |
+
index = 0
|
711 |
+
for d,e in zip(descs,entities):
|
712 |
+
p_e = '/'.join(e.split('/')[:5])
|
713 |
+
debug_combined_arr.append(d + " " + p_e)
|
714 |
+
if (index % 2 == 0):
|
715 |
+
temp_dict = OrderedDict()
|
716 |
+
temp_dict["d"] = d
|
717 |
+
temp_dict["e"] = e
|
718 |
+
else:
|
719 |
+
temp_dict["mlm"] = d
|
720 |
+
temp_dict["l_score"] = e
|
721 |
+
desc_arr.append(temp_dict)
|
722 |
+
index += 1
|
723 |
+
debug_str_arr.append("\n" + ', '.join(debug_combined_arr))
|
724 |
+
print(debug_combined_arr)
|
725 |
+
entity_info_dict["descs"] = desc_arr
|
726 |
+
#debug_str_arr.append(' '.join(entities))
|
727 |
+
assert(len(entities) == len(descs))
|
728 |
+
entities,confidences,subtypes = self.aggregate_entities(entities,descs,debug_str_arr,entity_info_dict["entities"])
|
729 |
+
return entities,confidences,subtypes
|
730 |
+
|
731 |
+
|
732 |
+
#This is again a bad hack for prototyping purposes - extracting fields from a raw text output as opposed to a structured output like json
|
733 |
+
def extract_descs(self,text):
|
734 |
+
arr = text.split('\n')
|
735 |
+
desc_arr = []
|
736 |
+
if (len(arr) > 0):
|
737 |
+
for i,line in enumerate(arr):
|
738 |
+
if (line.startswith(DESC_HEAD)):
|
739 |
+
terms = line.split(':')
|
740 |
+
desc_arr = ' '.join(terms[1:]).strip().split()
|
741 |
+
break
|
742 |
+
return desc_arr
|
743 |
+
|
744 |
+
|
745 |
+
def generate_masked_sentences(self,terms_arr):
|
746 |
+
size = len(terms_arr)
|
747 |
+
sentence_arr = []
|
748 |
+
span_arr = []
|
749 |
+
i = 0
|
750 |
+
while (i < size):
|
751 |
+
term_info = terms_arr[i]
|
752 |
+
if (term_info[TAG_POS] in noun_tags):
|
753 |
+
skip = self.gen_sentence(sentence_arr,terms_arr,i)
|
754 |
+
i += skip
|
755 |
+
for j in range(skip):
|
756 |
+
span_arr.append(1)
|
757 |
+
else:
|
758 |
+
i += 1
|
759 |
+
span_arr.append(0)
|
760 |
+
#print(sentence_arr)
|
761 |
+
return sentence_arr,span_arr
|
762 |
+
|
763 |
+
def gen_sentence(self,sentence_arr,terms_arr,index):
|
764 |
+
size = len(terms_arr)
|
765 |
+
new_sent = []
|
766 |
+
for prefix,term in enumerate(terms_arr[:index]):
|
767 |
+
new_sent.append(term[WORD_POS])
|
768 |
+
i = index
|
769 |
+
skip = 0
|
770 |
+
while (i < size):
|
771 |
+
if (terms_arr[i][TAG_POS] in noun_tags):
|
772 |
+
skip += 1
|
773 |
+
i += 1
|
774 |
+
else:
|
775 |
+
break
|
776 |
+
new_sent.append(MASK_TAG)
|
777 |
+
i = index + skip
|
778 |
+
while (i < size):
|
779 |
+
new_sent.append(terms_arr[i][WORD_POS])
|
780 |
+
i += 1
|
781 |
+
assert(skip != 0)
|
782 |
+
sentence_arr.append(new_sent)
|
783 |
+
return skip
|
784 |
+
|
785 |
+
|
786 |
+
|
787 |
+
|
788 |
+
|
789 |
+
|
790 |
+
|
791 |
+
|
792 |
+
def run_test(file_name,obj):
|
793 |
+
rfp = open("results.txt","w")
|
794 |
+
dfp = open("debug.txt","w")
|
795 |
+
with open(file_name) as fp:
|
796 |
+
count = 1
|
797 |
+
for line in fp:
|
798 |
+
if (len(line) > 1):
|
799 |
+
print(str(count) + "] ",line,end='')
|
800 |
+
obj.tag_sentence(line,rfp,dfp)
|
801 |
+
count += 1
|
802 |
+
rfp.close()
|
803 |
+
dfp.close()
|
804 |
+
|
805 |
+
|
806 |
+
def tag_single_entity_in_sentence(file_name,obj):
|
807 |
+
rfp = open("results.txt","w")
|
808 |
+
dfp = open("debug.txt","w")
|
809 |
+
sfp = open("se_results.txt","w")
|
810 |
+
with open(file_name) as fp:
|
811 |
+
count = 1
|
812 |
+
for line in fp:
|
813 |
+
if (len(line) > 1):
|
814 |
+
print(str(count) + "] ",line,end='')
|
815 |
+
#entity_arr,span_arr,terms_arr,ner_str,debug_str = obj.tag_sentence(line,rfp,dfp,False) # False for json output
|
816 |
+
json_str = obj.tag_sentence(line,rfp,dfp,True) # True for json output
|
817 |
+
#print("*******************:",terms_arr[span_arr.index(1)][WORD_POS].rstrip(":"),entity_arr[0])
|
818 |
+
#sfp.write(terms_arr[span_arr.index(1)][WORD_POS].rstrip(":") + " " + entity_arr[0] + "\n")
|
819 |
+
count += 1
|
820 |
+
sfp.flush()
|
821 |
+
#pdb.set_trace()
|
822 |
+
rfp.close()
|
823 |
+
sfp.close()
|
824 |
+
dfp.close()
|
825 |
+
|
826 |
+
|
827 |
+
|
828 |
+
|
829 |
+
test_arr = [
|
830 |
+
"He felt New:__entity__ York:__entity__ has a chance to win this year's competition",
|
831 |
+
"Ajit rajasekharan is an engineer at nFerence:__entity__",
|
832 |
+
"Ajit:__entity__ rajasekharan is an engineer:__entity__ at nFerence:__entity__",
|
833 |
+
"Mesothelioma:__entity__ is caused by exposure to asbestos:__entity__",
|
834 |
+
"Fyodor:__entity__ Mikhailovich:__entity__ Dostoevsky:__entity__ was treated for Parkinsons",
|
835 |
+
"Ajit:__entity__ Rajasekharan:__entity__ is an engineer at nFerence",
|
836 |
+
"A eGFR:__entity__ below 60 indicates chronic kidney disease",
|
837 |
+
"A eGFR below 60:__entity__ indicates chronic kidney disease",
|
838 |
+
"A eGFR:__entity__ below 60:__entity__ indicates chronic:__entity__ kidney:__entity__ disease:__entity__",
|
839 |
+
"Ajit:__entity__ rajasekharan is an engineer at nFerence",
|
840 |
+
"Her hypophysitis secondary to ipilimumab was well managed with supplemental hormones",
|
841 |
+
"In Seattle:__entity__ , Pete Incaviglia 's grand slam with one out in the sixth snapped a tie and lifted the Baltimore Orioles past the Seattle Mariners , 5-2 .",
|
842 |
+
"engineer",
|
843 |
+
"Austin:__entity__ called",
|
844 |
+
"Paul Erdős died at 83",
|
845 |
+
"Imatinib mesylate is a drug and is used to treat nsclc",
|
846 |
+
"In Seattle , Pete Incaviglia 's grand slam with one out in the sixth snapped a tie and lifted the Baltimore Orioles past the Seattle Mariners , 5-2 .",
|
847 |
+
"It was Incaviglia 's sixth grand slam and 200th homer of his career .",
|
848 |
+
"Add Women 's singles , third round Lisa Raymond ( U.S. ) beat Kimberly Po ( U.S. ) 6-3 6-2 .",
|
849 |
+
"1880s marked the beginning of Jazz",
|
850 |
+
"He flew from New York to SFO",
|
851 |
+
"Lionel Ritchie was popular in the 1980s",
|
852 |
+
"Lionel Ritchie was popular in the late eighties",
|
853 |
+
"John Doe flew from New York to Rio De Janiro via Miami",
|
854 |
+
"He felt New York has a chance to win this year's competition",
|
855 |
+
"Bandolier - Budgie ' , a free itunes app for ipad , iphone and ipod touch , released in December 2011 , tells the story of the making of Bandolier in the band 's own words - including an extensive audio interview with Burke Shelley",
|
856 |
+
"In humans mutations in Foxp2 leads to verbal dyspraxia",
|
857 |
+
"The recent spread of Corona virus flu from China to Italy,Iran, South Korea and Japan has caused global concern",
|
858 |
+
"Hotel California topped the singles chart",
|
859 |
+
"Elon Musk said Telsa will open a manufacturing plant in Europe",
|
860 |
+
"He flew from New York to SFO",
|
861 |
+
"After studies at Hofstra University , He worked for New York Telephone before He was elected to the New York State Assembly to represent the 16th District in Northwest Nassau County ",
|
862 |
+
"Everyday he rode his bicycle from Rajakilpakkam to Tambaram",
|
863 |
+
"If he loses Saturday , it could devalue his position as one of the world 's great boxers , \" Panamanian Boxing Association President Ramon Manzanares said .",
|
864 |
+
"West Indian all-rounder Phil Simmons took four for 38 on Friday as Leicestershire beat Somerset by an innings and 39 runs in two days to take over at the head of the county championship .",
|
865 |
+
"they are his friends ",
|
866 |
+
"they flew from Boston to Rio De Janiro and had a mocha",
|
867 |
+
"he flew from Boston to Rio De Janiro and had a mocha",
|
868 |
+
"X,Y,Z are medicines"]
|
869 |
+
|
870 |
+
|
871 |
+
def test_canned_sentences(obj):
|
872 |
+
rfp = open("results.txt","w")
|
873 |
+
dfp = open("debug.txt","w")
|
874 |
+
pdb.set_trace()
|
875 |
+
for line in test_arr:
|
876 |
+
ret_val = obj.tag_sentence(line,rfp,dfp,True)
|
877 |
+
pdb.set_trace()
|
878 |
+
rfp.close()
|
879 |
+
dfp.close()
|
880 |
+
|
881 |
+
if __name__ == '__main__':
|
882 |
+
parser = argparse.ArgumentParser(description='main NER for a single model ',formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
883 |
+
parser.add_argument('-input', action="store", dest="input",default="",help='Input file required for run options batch,single')
|
884 |
+
parser.add_argument('-config', action="store", dest="config", default=DEFAULT_CONFIG,help='config file path')
|
885 |
+
parser.add_argument('-option', action="store", dest="option",default="canned",help='Valid options are canned,batch,single. canned - test few canned sentences used in medium artice. batch - tag sentences in input file. Entities to be tagged are determing used POS tagging to find noun phrases. specific - tag specific entities in input file. The tagged word or phrases needs to be of the form w1:__entity_ w2:__entity_ Example:Her hypophysitis:__entity__ secondary to ipilimumab was well managed with supplemental:__entity__ hormones:__entity__')
|
886 |
+
results = parser.parse_args()
|
887 |
+
|
888 |
+
obj = UnsupNER(results.config)
|
889 |
+
if (results.option == "canned"):
|
890 |
+
test_canned_sentences(obj)
|
891 |
+
elif (results.option == "batch"):
|
892 |
+
if (len(results.input) == 0):
|
893 |
+
print("Input file needs to be specified")
|
894 |
+
else:
|
895 |
+
run_test(results.input,obj)
|
896 |
+
print("Tags and sentences are written in results.txt and debug.txt")
|
897 |
+
elif (results.option == "specific"):
|
898 |
+
if (len(results.input) == 0):
|
899 |
+
print("Input file needs to be specified")
|
900 |
+
else:
|
901 |
+
tag_single_entity_in_sentence(results.input,obj)
|
902 |
+
print("Tags and sentences are written in results.txt and debug.txt")
|
903 |
+
else:
|
904 |
+
print("Invalid argument:\n")
|
905 |
+
parser.print_help()
|
bbc/bbc_labels.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
bbc/desc_bbc_config.json
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{"POS_SERVER_URL": "http://127.0.0.1:8073/",
|
2 |
+
"LOG_DESCS": "0",
|
3 |
+
"USE_CLS": "0",
|
4 |
+
"BASE_PATH":"./bbc/",
|
5 |
+
"COMMON_DESCS_FILE": "untagged_terms.txt"
|
6 |
+
}
|
bbc/ner_bbc_config.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{"POS_SERVER_URL": "http://127.0.0.1:8073/",
|
2 |
+
"DESC_SERVER_URL": "http://127.0.0.1:8088/dummy/0/",
|
3 |
+
"ENTITY_SERVER_URL": "http://127.0.0.1:8043/",
|
4 |
+
"EMAP_FILE": "entity_types_consolidated.txt",
|
5 |
+
"FULL_SENTENCE_TAG": "1",
|
6 |
+
"SUPPRESS_UNTAGGED": "1",
|
7 |
+
"BASE_PATH":"./bbc/",
|
8 |
+
"COMMON_DESCS_FILE": "untagged_terms.txt"}
|
bbc/vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
bio/a100_labels.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
bio/desc_a100_config.json
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{"POS_SERVER_URL": "http://127.0.0.1:8073/",
|
2 |
+
"LOG_DESCS": "0",
|
3 |
+
"USE_CLS": "1",
|
4 |
+
"BASE_PATH":"./bio/",
|
5 |
+
"COMMON_DESCS_FILE": "untagged_terms.txt"
|
6 |
+
}
|
bio/ner_a100_config.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{"POS_SERVER_URL": "http://127.0.0.1:8073/",
|
2 |
+
"DESC_SERVER_URL": "http://127.0.0.1:8087/dummy/0/",
|
3 |
+
"ENTITY_SERVER_URL": "http://127.0.0.1:8043/",
|
4 |
+
"EMAP_FILE": "entity_types_consolidated.txt",
|
5 |
+
"FULL_SENTENCE_TAG": "1",
|
6 |
+
"SUPPRESS_UNTAGGED": "1",
|
7 |
+
"BASE_PATH":"./bio/",
|
8 |
+
"COMMON_DESCS_FILE": "untagged_terms.txt"}
|
bio/vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
common.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pdb
|
2 |
+
import sys
|
3 |
+
|
4 |
+
WORD_POS = 1
|
5 |
+
TAG_POS = 2
|
6 |
+
MASK_TAG = "__entity__"
|
7 |
+
INPUT_MASK_TAG = ":__entity__"
|
8 |
+
RESET_POS_TAG='RESET'
|
9 |
+
|
10 |
+
|
11 |
+
noun_tags = ['NFP','JJ','NN','FW','NNS','NNPS','JJS','JJR','NNP','POS','CD']
|
12 |
+
cap_tags = ['NFP','JJ','NN','FW','NNS','NNPS','JJS','JJR','NNP','PRP']
|
13 |
+
|
14 |
+
|
15 |
+
def detect_masked_positions(terms_arr):
|
16 |
+
sentence_arr,span_arr = generate_masked_sentences(terms_arr)
|
17 |
+
new_sent_arr = []
|
18 |
+
for i in range(len(terms_arr)):
|
19 |
+
new_sent_arr.append(terms_arr[i][WORD_POS])
|
20 |
+
return new_sent_arr,sentence_arr,span_arr
|
21 |
+
|
22 |
+
def generate_masked_sentences(terms_arr):
|
23 |
+
size = len(terms_arr)
|
24 |
+
sentence_arr = []
|
25 |
+
span_arr = []
|
26 |
+
i = 0
|
27 |
+
hack_for_no_nouns_case(terms_arr)
|
28 |
+
while (i < size):
|
29 |
+
term_info = terms_arr[i]
|
30 |
+
if (term_info[TAG_POS] in noun_tags):
|
31 |
+
skip = gen_sentence(sentence_arr,terms_arr,i)
|
32 |
+
i += skip
|
33 |
+
for j in range(skip):
|
34 |
+
span_arr.append(1)
|
35 |
+
else:
|
36 |
+
i += 1
|
37 |
+
span_arr.append(0)
|
38 |
+
#print(sentence_arr)
|
39 |
+
return sentence_arr,span_arr
|
40 |
+
|
41 |
+
def hack_for_no_nouns_case(terms_arr):
|
42 |
+
'''
|
43 |
+
This is just a hack for case user enters a sentence with no entity to be tagged specifically and the sentence has no nouns
|
44 |
+
Happens for odd inputs like a single word like "eg" etc.
|
45 |
+
Just make the first term as a noun to proceed.
|
46 |
+
'''
|
47 |
+
size = len(terms_arr)
|
48 |
+
i = 0
|
49 |
+
found = False
|
50 |
+
while (i < size):
|
51 |
+
term_info = terms_arr[i]
|
52 |
+
if (term_info[TAG_POS] in noun_tags):
|
53 |
+
found = True
|
54 |
+
break
|
55 |
+
else:
|
56 |
+
i += 1
|
57 |
+
if (not found and len(terms_arr) >= 1):
|
58 |
+
term_info = terms_arr[0]
|
59 |
+
term_info[TAG_POS] = noun_tags[0]
|
60 |
+
|
61 |
+
|
62 |
+
def gen_sentence(sentence_arr,terms_arr,index):
|
63 |
+
size = len(terms_arr)
|
64 |
+
new_sent = []
|
65 |
+
for prefix,term in enumerate(terms_arr[:index]):
|
66 |
+
new_sent.append(term[WORD_POS])
|
67 |
+
i = index
|
68 |
+
skip = 0
|
69 |
+
while (i < size):
|
70 |
+
if (terms_arr[i][TAG_POS] in noun_tags):
|
71 |
+
skip += 1
|
72 |
+
i += 1
|
73 |
+
else:
|
74 |
+
break
|
75 |
+
new_sent.append(MASK_TAG)
|
76 |
+
i = index + skip
|
77 |
+
while (i < size):
|
78 |
+
new_sent.append(terms_arr[i][WORD_POS])
|
79 |
+
i += 1
|
80 |
+
assert(skip != 0)
|
81 |
+
sentence_arr.append(new_sent)
|
82 |
+
return skip
|
83 |
+
|
84 |
+
|
85 |
+
|
86 |
+
def capitalize(terms_arr):
|
87 |
+
for i,term_tag in enumerate(terms_arr):
|
88 |
+
#print(term_tag)
|
89 |
+
if (term_tag[TAG_POS] in cap_tags):
|
90 |
+
word = term_tag[WORD_POS][0].upper() + term_tag[WORD_POS][1:]
|
91 |
+
term_tag[WORD_POS] = word
|
92 |
+
#print(terms_arr)
|
93 |
+
|
94 |
+
def set_POS_based_on_entities(sent):
|
95 |
+
terms_arr = []
|
96 |
+
sent_arr = sent.split()
|
97 |
+
for i,word in enumerate(sent_arr):
|
98 |
+
#print(term_tag)
|
99 |
+
term_tag = ['-']*5
|
100 |
+
if (word.endswith(INPUT_MASK_TAG)):
|
101 |
+
term_tag[TAG_POS] = noun_tags[0]
|
102 |
+
term_tag[WORD_POS] = word.replace(INPUT_MASK_TAG,"")
|
103 |
+
else:
|
104 |
+
term_tag[TAG_POS] = RESET_POS_TAG
|
105 |
+
term_tag[WORD_POS] = word
|
106 |
+
terms_arr.append(term_tag)
|
107 |
+
return terms_arr
|
108 |
+
#print(terms_arr)
|
109 |
+
|
110 |
+
def filter_common_noun_spans(span_arr,masked_sent_arr,terms_arr,common_descs):
|
111 |
+
ret_span_arr = span_arr.copy()
|
112 |
+
ret_masked_sent_arr = []
|
113 |
+
sent_index = 0
|
114 |
+
loop_span_index = 0
|
115 |
+
while (loop_span_index < len(span_arr)):
|
116 |
+
span_val = span_arr[loop_span_index]
|
117 |
+
orig_index = loop_span_index
|
118 |
+
if (span_val == 1):
|
119 |
+
curr_index = orig_index
|
120 |
+
is_all_common = True
|
121 |
+
while (curr_index < len(span_arr) and span_arr[curr_index] == 1):
|
122 |
+
term = terms_arr[curr_index]
|
123 |
+
if (term[WORD_POS].lower() not in common_descs):
|
124 |
+
is_all_common = False
|
125 |
+
curr_index += 1
|
126 |
+
loop_span_index = curr_index #note the loop scan index is updated
|
127 |
+
if (is_all_common):
|
128 |
+
curr_index = orig_index
|
129 |
+
print("Filtering common span: ",end='')
|
130 |
+
while (curr_index < len(span_arr) and span_arr[curr_index] == 1):
|
131 |
+
print(terms_arr[curr_index][WORD_POS],' ',end='')
|
132 |
+
ret_span_arr[curr_index] = 0
|
133 |
+
curr_index += 1
|
134 |
+
print()
|
135 |
+
sent_index += 1 # we are skipping a span
|
136 |
+
else:
|
137 |
+
ret_masked_sent_arr.append(masked_sent_arr[sent_index])
|
138 |
+
sent_index += 1
|
139 |
+
else:
|
140 |
+
loop_span_index += 1
|
141 |
+
return ret_masked_sent_arr,ret_span_arr
|
142 |
+
|
143 |
+
def normalize_casing(sent):
|
144 |
+
sent_arr = sent.split()
|
145 |
+
ret_sent_arr = []
|
146 |
+
for i,word in enumerate(sent_arr):
|
147 |
+
if (len(word) > 1):
|
148 |
+
norm_word = word[0] + word[1:].lower()
|
149 |
+
else:
|
150 |
+
norm_word = word[0]
|
151 |
+
ret_sent_arr.append(norm_word)
|
152 |
+
return ' '.join(ret_sent_arr)
|
153 |
+
|
common_descs.txt
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
a
|
2 |
+
all
|
3 |
+
an
|
4 |
+
and
|
5 |
+
any
|
6 |
+
are
|
7 |
+
as
|
8 |
+
at
|
9 |
+
away
|
10 |
+
be
|
11 |
+
beside
|
12 |
+
but
|
13 |
+
by
|
14 |
+
can
|
15 |
+
come
|
16 |
+
did
|
17 |
+
do
|
18 |
+
each
|
19 |
+
etc
|
20 |
+
far
|
21 |
+
free
|
22 |
+
get
|
23 |
+
gets
|
24 |
+
getting
|
25 |
+
give
|
26 |
+
given
|
27 |
+
gives
|
28 |
+
giving
|
29 |
+
go
|
30 |
+
goes
|
31 |
+
going
|
32 |
+
gonna
|
33 |
+
good
|
34 |
+
got
|
35 |
+
gotta
|
36 |
+
greatly
|
37 |
+
grow
|
38 |
+
growing
|
39 |
+
guess
|
40 |
+
had
|
41 |
+
has
|
42 |
+
how
|
43 |
+
in
|
44 |
+
is
|
45 |
+
it
|
46 |
+
its
|
47 |
+
itself
|
48 |
+
keep
|
49 |
+
keeps
|
50 |
+
kept
|
51 |
+
key
|
52 |
+
lack
|
53 |
+
led
|
54 |
+
let
|
55 |
+
lets
|
56 |
+
like
|
57 |
+
liked
|
58 |
+
likely
|
59 |
+
long
|
60 |
+
look
|
61 |
+
looking
|
62 |
+
looks
|
63 |
+
lose
|
64 |
+
loss
|
65 |
+
lost
|
66 |
+
lot
|
67 |
+
lots
|
68 |
+
lou
|
69 |
+
loud
|
70 |
+
made
|
71 |
+
make
|
72 |
+
matter
|
73 |
+
mean
|
74 |
+
meaning
|
75 |
+
means
|
76 |
+
meant
|
77 |
+
meet
|
78 |
+
meeting
|
79 |
+
meets
|
80 |
+
mere
|
81 |
+
merely
|
82 |
+
more
|
83 |
+
most
|
84 |
+
mostly
|
85 |
+
move
|
86 |
+
much
|
87 |
+
must
|
88 |
+
need
|
89 |
+
needed
|
90 |
+
needing
|
91 |
+
needs
|
92 |
+
new
|
93 |
+
next
|
94 |
+
nice
|
95 |
+
nobody
|
96 |
+
of
|
97 |
+
off
|
98 |
+
on
|
99 |
+
once
|
100 |
+
ongoing
|
101 |
+
only
|
102 |
+
or
|
103 |
+
place
|
104 |
+
placed
|
105 |
+
reach
|
106 |
+
same
|
107 |
+
saying
|
108 |
+
show
|
109 |
+
side
|
110 |
+
some
|
111 |
+
the
|
112 |
+
then
|
113 |
+
this
|
114 |
+
thence
|
115 |
+
thing
|
116 |
+
though
|
117 |
+
until
|
118 |
+
unto
|
119 |
+
usual
|
120 |
+
usually
|
121 |
+
wanna
|
122 |
+
want
|
123 |
+
wanted
|
124 |
+
wanting
|
125 |
+
wants
|
126 |
+
was
|
127 |
+
when
|
128 |
+
where
|
129 |
+
whereas
|
130 |
+
whereby
|
131 |
+
wherein
|
132 |
+
whether
|
133 |
+
which
|
134 |
+
while
|
135 |
+
whilst
|
136 |
+
whoever
|
137 |
+
whom
|
138 |
+
why
|
139 |
+
with
|
140 |
+
within
|
141 |
+
without
|
142 |
+
would
|
143 |
+
both
|
144 |
+
high
|
145 |
+
called
|
146 |
+
from
|
147 |
+
entitled
|
148 |
+
using
|
149 |
+
to
|
config_utils.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
|
4 |
+
|
5 |
+
def write_config(configs,file_name='server_config.json'):
|
6 |
+
print(json.dumps(configs))
|
7 |
+
with open(file_name, 'w') as outfile:
|
8 |
+
json.dump(configs, outfile)
|
9 |
+
|
10 |
+
|
11 |
+
def read_config(file_name='server_config.json'):
|
12 |
+
try:
|
13 |
+
with open(file_name) as json_file:
|
14 |
+
data = json.load(json_file)
|
15 |
+
#print(data)
|
16 |
+
return data
|
17 |
+
except:
|
18 |
+
print("Unable to open config file:",file_name)
|
19 |
+
return {}
|
ensemble_config.json
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{"NER_SERVERS": ["http://127.0.0.1:9088/dummy/","http://127.0.0.1:9089/dummy/"],
|
2 |
+
"BASE_PATH":"./logs/",
|
3 |
+
"bio_precedence_arr": [
|
4 |
+
"THERAPEUTIC_OR_PREVENTIVE_PROCEDURE",
|
5 |
+
"DISEASE",
|
6 |
+
"GENE",
|
7 |
+
"BODY_PART_OR_ORGAN_COMPONENT",
|
8 |
+
"BIO",
|
9 |
+
"ORGANISM_FUNCTION"
|
10 |
+
],
|
11 |
+
|
12 |
+
"phi_precedence_arr" : [
|
13 |
+
"PERSON",
|
14 |
+
"ORGANIZATION",
|
15 |
+
"ENT",
|
16 |
+
"COLOR",
|
17 |
+
"LANGUAGE",
|
18 |
+
"GRAMMAR_CONSTRUCT",
|
19 |
+
"LOCATION",
|
20 |
+
"SOCIAL_CIRCUMSTANCES"
|
21 |
+
],
|
22 |
+
|
23 |
+
"common_entities_arr":
|
24 |
+
[
|
25 |
+
"UNTAGGED_ENTITY",
|
26 |
+
"OTHER",
|
27 |
+
"GRAMMAR_CONSTRUCT",
|
28 |
+
"OBJECT",
|
29 |
+
"MEASURE",
|
30 |
+
"LOCATION"
|
31 |
+
],
|
32 |
+
|
33 |
+
"actions_arr" : [
|
34 |
+
{"url":"http://127.0.0.1:8089/dummy/","desc":"****************** A100 trained Bio model (Pubmed,Clincial trials, Bookcorpus(subset) **********"},
|
35 |
+
{"url":"http://127.0.0.1:8090/dummy/","desc":"********** Bert base cased (bookcorpus and Wikipedia) ***********"}
|
36 |
+
]
|
37 |
+
}
|
entity_types_consolidated.txt
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
THERAPEUTIC_OR_PREVENTIVE_PROCEDURE DRUG/CHEMICAL_SUBSTANCE/HAZARDOUS_OR_POISONOUS_SUBSTANCE/ESTABLISHED_PHARMACOLOGIC_CLASS/CHEMICAL_CLASS/VITAMIN/LAB_PROCEDURE/SURGICAL_AND_MEDICAL_PROCEDURES/DIAGNOSTIC_PROCEDURE/LAB_TEST_COMPONENT/STUDY/DRUG_ADJECTIVE
|
2 |
+
DISEASE MENTAL_OR_BEHAVIORAL_DYSFUNCTION/CONGENITAL_ABNORMALITY/CELL_OR_MOLECULAR_DYSFUNCTION/DISEASE_ADJECTIVE
|
3 |
+
GENE PROTEIN/ENZYME/VIRAL_PROTEIN/RECEPTOR/PROTEIN_FAMILY/MOUSE_PROTEIN_FAMILY/MOUSE_GENE/NUCLEOTIDE_SEQUENCE/GENE_EXPRESSION_ADJECTIVE
|
4 |
+
BODY_PART_OR_ORGAN_COMPONENT BODY_LOCATION_OR_REGION/BODY_SUBSTANCE/CELL/CELL_LINE/CELL_COMPONENT/BIO_MOLECULE/METABOLITE/HORMONE/BODY_ADJECTIVE
|
5 |
+
ORGANISM_FUNCTION ORGAN_OR_TISSUE_FUNCTION/PHYSIOLOGIC_FUNCTION/CELL_FUNCTION/FUNCTION_ADJECTIVE
|
6 |
+
BIO SPECIES/BACTERIUM/VIRUS/BIO_ADJECTIVE
|
7 |
+
OBJECT PRODUCT/MEDICAL_DEVICE/DEVICE/DEVICE_ADJECTIVE
|
8 |
+
MEASURE NUMBER/TIME/SEQUENCE/MEASURE_ADJECTIVE
|
9 |
+
PERSON PERSON_ADJECTIVE
|
10 |
+
ORGANIZATION UNIV/GOV/EDU/ORGANIZATION_ADJECTIVE
|
11 |
+
ENT SPORT/MOV/MUSIC/ENT_ADJECTIVE
|
12 |
+
LOCATION LOCATION_ADJECTIVE
|
13 |
+
SOCIAL_CIRCUMSTANCES RELIGION/SOCIAL_CIRCUMSTANCES_ADJECTIVE
|
14 |
+
COLOR COLOR_ADJECTIVE
|
15 |
+
LANGUAGE LANGUAGE_ADJECTIVE
|
16 |
+
GRAMMAR_CONSTRUCT
|
17 |
+
OTHER
|
18 |
+
UNTAGGED_ENTITY
|
logs/failed_queries_log.txt
ADDED
File without changes
|
logs/query_logs.txt
ADDED
File without changes
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
flair
|
2 |
+
st-annotated-text
|
3 |
+
|
untagged_terms.txt
ADDED
File without changes
|