ajit
Initial commit
854a552
import pdb
import sys
# Column indices into a per-token record (e.g. the 5-slot list built by
# set_POS_based_on_entities).
WORD_POS = 1
TAG_POS = 2

# Placeholder that stands in for a whole noun span in a masked sentence.
MASK_TAG = "__entity__"
# Suffix a user appends to a word to mark it explicitly as an entity.
INPUT_MASK_TAG = ":__entity__"
# Tag assigned to words that are not explicitly marked as entities.
RESET_POS_TAG = 'RESET'

# Tags treated as part of a maskable noun span. noun_tags[0] is also used
# elsewhere in this module as the default "noun" tag, so its order matters.
noun_tags = [
    'NFP', 'JJ', 'NN', 'FW', 'NNS', 'NNPS', 'JJS', 'JJR', 'NNP', 'POS', 'CD',
]
# Tags whose words get their first letter upper-cased by capitalize().
cap_tags = [
    'NFP', 'JJ', 'NN', 'FW', 'NNS', 'NNPS', 'JJS', 'JJR', 'NNP', 'PRP',
]
def detect_masked_positions(terms_arr):
    """Generate masked sentences for every noun span in a tagged sentence.

    Args:
        terms_arr: list of per-token records, each indexable with WORD_POS
            (the word) and TAG_POS (its tag). generate_masked_sentences() may
            overwrite the tag of the first record in place when no record
            carries a noun tag.

    Returns:
        A tuple (words, sentence_arr, span_arr):
            words: the word of every record, in order.
            sentence_arr: one masked sentence (list of words) per noun span.
            span_arr: per record, 1 if it lies inside a noun span, else 0.
    """
    sentence_arr, span_arr = generate_masked_sentences(terms_arr)
    words = [term[WORD_POS] for term in terms_arr]
    return words, sentence_arr, span_arr
def generate_masked_sentences(terms_arr):
    """Build one masked sentence per run of consecutive noun-tagged terms.

    Walks terms_arr left to right. Each run of consecutive terms whose tag is
    in noun_tags is handed to gen_sentence(), which records a sentence with
    that whole run replaced by MASK_TAG. Other terms are stepped over.

    hack_for_no_nouns_case() runs first and may re-tag terms_arr[0] in place
    when no term carries a noun tag.

    Returns:
        (sentence_arr, span_arr): sentence_arr holds one masked sentence per
        noun run; span_arr holds, per term, 1 if the term belongs to a noun
        run and 0 otherwise.
    """
    total = len(terms_arr)
    sentences = []
    spans = []
    hack_for_no_nouns_case(terms_arr)
    pos = 0
    while pos < total:
        if terms_arr[pos][TAG_POS] not in noun_tags:
            spans.append(0)
            pos += 1
            continue
        # gen_sentence reports how many consecutive noun terms it consumed.
        run_len = gen_sentence(sentences, terms_arr, pos)
        spans.extend([1] * run_len)
        pos += run_len
    return sentences, spans
def hack_for_no_nouns_case(terms_arr):
    '''
    Ensure at least one term carries a noun tag so masking can proceed.

    This is a hack for the case where the user enters a sentence that neither
    marks an entity explicitly nor contains any nouns, e.g. an odd input such
    as the single word "eg". In that case the first term is re-tagged in place
    with noun_tags[0]. Empty input, and input that already contains a noun
    tag, is left untouched.

    Args:
        terms_arr: list of per-token records indexable with TAG_POS; the first
            record is mutated when no noun tag is found.
    '''
    if terms_arr and not any(term[TAG_POS] in noun_tags for term in terms_arr):
        terms_arr[0][TAG_POS] = noun_tags[0]
def gen_sentence(sentence_arr, terms_arr, index):
    """Append one masked sentence for the noun run starting at ``index``.

    The new sentence consists of the words before ``index``, then a single
    MASK_TAG standing in for the whole run of consecutive noun-tagged terms
    that starts at ``index``, then the words after that run.

    Args:
        sentence_arr: list that receives the new sentence (mutated).
        terms_arr: list of per-token records indexable with WORD_POS/TAG_POS.
        index: position of the first term of the run. The caller must ensure
            terms_arr[index] has a tag in noun_tags.

    Returns:
        The length of the noun run (always >= 1), i.e. how many terms the
        caller should step over.
    """
    prefix = [term[WORD_POS] for term in terms_arr[:index]]
    end = index
    # Extend the run over every consecutive noun-tagged term.
    while end < len(terms_arr) and terms_arr[end][TAG_POS] in noun_tags:
        end += 1
    skip = end - index
    suffix = [term[WORD_POS] for term in terms_arr[end:]]
    # Internal invariant: callers only invoke this on a noun-tagged term.
    assert skip != 0
    sentence_arr.append(prefix + [MASK_TAG] + suffix)
    return skip
def capitalize(terms_arr):
    """Upper-case the first letter of every word whose tag is in cap_tags.

    Records are modified in place; the rest of each word is left unchanged,
    and empty words stay empty. An empty word can occur, for example, when a
    bare ":__entity__" token is passed through set_POS_based_on_entities.

    Args:
        terms_arr: list of per-token records indexable with WORD_POS/TAG_POS.
    """
    for term_tag in terms_arr:
        if term_tag[TAG_POS] in cap_tags:
            word = term_tag[WORD_POS]
            # Slice instead of word[0] so an empty word does not raise IndexError.
            term_tag[WORD_POS] = word[:1].upper() + word[1:]
def set_POS_based_on_entities(sent):
    """Build per-token records for a sentence with user-marked entities.

    Words ending in INPUT_MASK_TAG are treated as entities: the marker suffix
    is stripped and the word is tagged noun_tags[0]. Every other word is
    tagged RESET_POS_TAG.

    Args:
        sent: whitespace-separated sentence.

    Returns:
        A list with one 5-slot record per word. WORD_POS holds the word,
        TAG_POS holds its tag, and the remaining slots are '-'.
    """
    terms_arr = []
    for word in sent.split():
        term_tag = ['-'] * 5
        if word.endswith(INPUT_MASK_TAG):
            term_tag[TAG_POS] = noun_tags[0]
            # Strip only the trailing marker; str.replace would also delete
            # any other occurrence of the marker inside the word.
            term_tag[WORD_POS] = word[:-len(INPUT_MASK_TAG)]
        else:
            term_tag[TAG_POS] = RESET_POS_TAG
            term_tag[WORD_POS] = word
        terms_arr.append(term_tag)
    return terms_arr
def filter_common_noun_spans(span_arr,masked_sent_arr,terms_arr,common_descs):
    """Drop noun spans that consist entirely of common words.

    A span is a maximal run of consecutive 1s in span_arr. The k-th span is
    paired with masked_sent_arr[k]. If every word in a span (lower-cased) is
    in common_descs, the span is filtered out:
      - its entries in the returned span list are set to 0,
      - its masked sentence is omitted from the returned sentences,
      - its words are printed to stdout.

    span_arr itself is not modified; a copy is returned.

    Returns:
        (kept_masked_sentences, new_span_arr)
    """
    kept_sentences = []
    new_spans = span_arr.copy()
    total = len(span_arr)
    span_no = 0
    pos = 0
    while pos < total:
        if span_arr[pos] != 1:
            pos += 1
            continue
        start = pos
        all_common = True
        # Consume the whole span, checking every word in it.
        while pos < total and span_arr[pos] == 1:
            if terms_arr[pos][WORD_POS].lower() not in common_descs:
                all_common = False
            pos += 1
        if all_common:
            print("Filtering common span: ",end='')
            for k in range(start, pos):
                print(terms_arr[k][WORD_POS],' ',end='')
                new_spans[k] = 0
            print()
        else:
            kept_sentences.append(masked_sent_arr[span_no])
        # Every span owns one masked sentence, whether kept or skipped.
        span_no += 1
    return kept_sentences, new_spans
def normalize_casing(sent):
    """Lower-case every character of each word except its first one.

    The sentence is split on whitespace and re-joined with single spaces, so
    runs of whitespace collapse and leading/trailing whitespace is dropped.
    """
    return ' '.join(word[:1] + word[1:].lower() for word in sent.split())