WORD_POS = 1
TAG_POS = 2
MASK_TAG = "__entity__"
INPUT_MASK_TAG = ":__entity__"
RESET_POS_TAG = 'RESET'
# Penn Treebank tags treated as maskable noun-phrase candidates, and tags whose
# words get their first letter capitalized, respectively.
noun_tags = ['NFP', 'JJ', 'NN', 'FW', 'NNS', 'NNPS', 'JJS', 'JJR', 'NNP', 'POS', 'CD']
cap_tags = ['NFP', 'JJ', 'NN', 'FW', 'NNS', 'NNPS', 'JJS', 'JJR', 'NNP', 'PRP']


def detect_masked_positions(terms_arr):
    '''Return the original words, one masked sentence per noun span, and a
    0/1 flag per term marking whether it belongs to a masked span.'''
    sentence_arr, span_arr = generate_masked_sentences(terms_arr)
    new_sent_arr = [term[WORD_POS] for term in terms_arr]
    return new_sent_arr, sentence_arr, span_arr


def generate_masked_sentences(terms_arr):
    '''Scan the tagged terms; for each maximal run of noun-tagged terms, emit a
    copy of the sentence with that run replaced by MASK_TAG.'''
    size = len(terms_arr)
    sentence_arr = []
    span_arr = []
    i = 0
    hack_for_no_nouns_case(terms_arr)
    while i < size:
        term_info = terms_arr[i]
        if term_info[TAG_POS] in noun_tags:
            skip = gen_sentence(sentence_arr, terms_arr, i)
            i += skip
            span_arr.extend([1] * skip)
        else:
            i += 1
            span_arr.append(0)
    return sentence_arr, span_arr


def hack_for_no_nouns_case(terms_arr):
    '''Hack for the case where the user enters a sentence with no entity
    specifically tagged and the sentence contains no nouns. This happens for
    odd inputs such as a single word like "eg". Simply mark the first term as
    a noun so processing can proceed.'''
    found = any(term_info[TAG_POS] in noun_tags for term_info in terms_arr)
    if not found and len(terms_arr) >= 1:
        terms_arr[0][TAG_POS] = noun_tags[0]


def gen_sentence(sentence_arr, terms_arr, index):
    '''Build one masked sentence: copy the words before index, replace the run
    of consecutive noun-tagged terms starting at index with a single MASK_TAG,
    then copy the remaining words. Returns the length of the masked run.'''
    size = len(terms_arr)
    new_sent = [term[WORD_POS] for term in terms_arr[:index]]
    # Measure the run of consecutive noun tags starting at index.
    skip = 0
    i = index
    while i < size and terms_arr[i][TAG_POS] in noun_tags:
        skip += 1
        i += 1
    assert skip != 0  # callers only invoke this at a noun-tagged position
    new_sent.append(MASK_TAG)
    new_sent.extend(term[WORD_POS] for term in terms_arr[index + skip:])
    sentence_arr.append(new_sent)
    return skip


def capitalize(terms_arr):
    '''Uppercase the first letter of every word whose tag is in cap_tags (in place).'''
    for term_tag in terms_arr:
        if term_tag[TAG_POS] in cap_tags:
            term_tag[WORD_POS] = term_tag[WORD_POS][0].upper() + term_tag[WORD_POS][1:]


def set_POS_based_on_entities(sent):
    '''Build a terms array directly from user markup: words suffixed with
    ":__entity__" are tagged as nouns (to be masked); all other words get the
    sentinel RESET_POS_TAG.'''
    terms_arr = []
    for word in sent.split():
        term_tag = ['-'] * 5
        if word.endswith(INPUT_MASK_TAG):
            term_tag[TAG_POS] = noun_tags[0]
            term_tag[WORD_POS] = word.replace(INPUT_MASK_TAG, "")
        else:
            term_tag[TAG_POS] = RESET_POS_TAG
            term_tag[WORD_POS] = word
        terms_arr.append(term_tag)
    return terms_arr


def filter_common_noun_spans(span_arr, masked_sent_arr, terms_arr, common_descs):
    '''Drop masked spans in which every word is a common description word:
    zero out their span flags and omit the corresponding masked sentence.'''
    ret_span_arr = span_arr.copy()
    ret_masked_sent_arr = []
    sent_index = 0
    loop_span_index = 0
    while loop_span_index < len(span_arr):
        span_val = span_arr[loop_span_index]
        orig_index = loop_span_index
        if span_val == 1:
            # Check whether every word in this span is a common description.
            curr_index = orig_index
            is_all_common = True
            while curr_index < len(span_arr) and span_arr[curr_index] == 1:
                term = terms_arr[curr_index]
                if term[WORD_POS].lower() not in common_descs:
                    is_all_common = False
                curr_index += 1
            loop_span_index = curr_index  # note the loop scan index is updated past the span
            if is_all_common:
                curr_index = orig_index
                print("Filtering common span: ", end='')
                while curr_index < len(span_arr) and span_arr[curr_index] == 1:
                    print(terms_arr[curr_index][WORD_POS], ' ', end='')
                    ret_span_arr[curr_index] = 0
                    curr_index += 1
                print()
                sent_index += 1  # we are skipping a span
            else:
                ret_masked_sent_arr.append(masked_sent_arr[sent_index])
                sent_index += 1
        else:
            loop_span_index += 1
    return ret_masked_sent_arr, ret_span_arr


def normalize_casing(sent):
    '''Lowercase everything except the first letter of each word.'''
    sent_arr = sent.split()
    ret_sent_arr = []
    for word in sent_arr:
        if len(word) > 1:
            norm_word = word[0] + word[1:].lower()
        else:
            norm_word = word
        ret_sent_arr.append(norm_word)
    return ' '.join(ret_sent_arr)
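

# A minimal usage sketch (an illustration, not part of the original module).
# It assumes the user-markup path: entities are flagged by suffixing words
# with ":__entity__", which set_POS_based_on_entities turns into noun tags.
# The POS-tagger-driven path (real Penn Treebank tags in TAG_POS) is not
# exercised here.
if __name__ == "__main__":
    sample = normalize_casing("MESSI:__entity__ plays for BARCELONA:__entity__")
    terms = set_POS_based_on_entities(sample)
    capitalize(terms)
    words, masked_sents, spans = detect_masked_positions(terms)
    print("Words:          ", words)   # ['Messi', 'plays', 'for', 'Barcelona']
    print("Masked variants:", [' '.join(s) for s in masked_sents])
    # ['__entity__ plays for Barcelona', 'Messi plays for __entity__']
    print("Span flags:     ", spans)   # [1, 0, 0, 1]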