orion / src /utils.py
andreslu's picture
Upload 25 files
0f14897
raw
history blame
3.5 kB
from ngram import NGram
def post_process_template(tB):
    """Ensure a template string ends with a period.

    Args:
        tB: Template text to finalize.

    Returns:
        ``tB`` unchanged when it already ends with ``'.'``; otherwise ``tB``
        with a single ``'.'`` appended (so an empty string becomes ``'.'``).
    """
    if not tB.endswith('.'):
        tB += '.'
    return tB
def construct_template(words, templateA, if_then=False):
    """Build masked template sentences for zero, one or two words.

    Args:
        words: Sequence of 0, 1 or 2 strings placed around the ``<mask>``
            slot of the generated template.
        templateA: Premise text containing ``<mask>`` placeholders. It is
            only read when ``if_then`` is true.
        if_then: When true, each word replaces the leftmost remaining
            ``<mask>`` in ``templateA`` and every template is wrapped as
            ``'If <premise> then <template>'``.

    Returns:
        A list of template strings: one entry for one or two words, an empty
        list for no words.

    Raises:
        ValueError: If ``words`` holds more than two items, or if ``if_then``
            is true and ``templateA`` runs out of ``<mask>`` placeholders
            before every word has been substituted.
    """
    mask = '<mask>'
    if len(words) == 2:
        templates = ['{} <mask> {}.'.format(words[0], words[1])]
    elif len(words) == 1:
        templates = ['{} <mask>.'.format(words[0])]
    elif len(words) == 0:
        templates = []
    else:
        # Previously this case left ``templates`` unbound and surfaced as an
        # UnboundLocalError further down; fail early with a clear message.
        raise ValueError(
            'construct_template supports at most 2 words, got {}'.format(len(words))
        )

    if if_then:
        for word in words:
            # str.index raises ValueError when no placeholder is left.
            index = templateA.index(mask)
            templateA = templateA[:index] + word + templateA[index + len(mask):]
        templates = ['If ' + templateA + ' then ' + template for template in templates]

    return templates
def filter_words(words_prob):
    """Penalize repetitive candidates and rank them by adjusted probability.

    Args:
        words_prob: Iterable of sequences whose first two items are ``words``
            (a sequence of strings, possibly empty) and ``prob`` (a number).
            Any further items in an entry are ignored.

    Returns:
        A list of ``[words, adjusted_prob]`` pairs sorted by adjusted
        probability, highest first. Two-word candidates whose two words are
        equal are dropped entirely.
    """
    word_count = {}
    token1_count = {}
    word2_count = {}
    ret = []
    for words, prob, *_ in words_prob:
        # Candidates made of the same word twice are discarded outright and
        # do not contribute to any of the running counts.
        if len(words) == 2 and words[0] == words[1]:
            continue

        # Halve the score when any space-separated token occurs more than
        # once across the candidate's words.
        tokens = [token for word in words for token in word.split(' ')]
        if len(set(tokens)) != len(tokens):
            prob *= 0.5

        # Penalize a first token that earlier candidates already started
        # with: divide by the number of times it has been seen so far.
        # Guarded so an empty ``words`` entry no longer raises IndexError.
        if words:
            token1 = words[0].split(' ')[0]
            seen = token1_count.get(token1, 0) + 1
            token1_count[token1] = seen
            if seen > 1:
                prob /= seen

        # Penalize every word by how often it has appeared so far.
        for word in words:
            word_count[word] = word_count.get(word, 0) + 1
            prob /= word_count[word]

        # Extra penalty for a repeated second word in two-word candidates.
        if len(words) == 2:
            second = words[1]
            word2_count[second] = word2_count.get(second, 0) + 1
            prob /= word2_count[second]

        ret.append([words, prob])
    return sorted(ret, key=lambda x: x[1], reverse=True)
import math
from copy import deepcopy
def convert_for_print(arr):
    """Return a deep copy of ``arr`` with its numbers rounded to 7 decimals.

    Args:
        arr: Sequence of mutable entries (e.g. lists). Item 1 of each entry
            is a number; when an entry has exactly three items, item 2 is a
            mutable sequence of numbers.

    Returns:
        A deep copy of ``arr`` in which item 1 of every entry, and every
        element of item 2 for three-item entries, is rounded to 7 decimal
        places. The input itself is not modified.
    """
    ret = deepcopy(arr)
    for entry in ret:
        entry[1] = round(entry[1], 7)
        if len(entry) == 3:
            scores = entry[2]
            # In-place index assignment keeps the container type of the copy.
            for j, value in enumerate(scores):
                scores[j] = round(value, 7)
    return ret
def formalize_tA(tA):
    """Normalize spacing and the final period of a premise sentence.

    Args:
        tA: Raw sentence text.

    Returns:
        The text with surrounding whitespace removed, exactly one ``'.'``
        appended in place of an optional trailing period (and any whitespace
        right before it), and the space dropped in front of every ``,`` and
        ``'``.
    """
    sentence = tA.strip()
    # Drop one trailing period plus the whitespace in front of it, if any.
    body = sentence[:-1].strip() if sentence.endswith('.') else sentence
    sentence = body + '.'
    for spaced, tight in ((' ,', ','), (" '", "'")):
        sentence = sentence.replace(spaced, tight)
    return sentence
# N-gram size handed to NGram as ``N`` in extract_similar_words; it is also
# the minimum length of the candidate substrings generated there.
ngram_n = 3
def extract_similar_words(txt, words):
    """Find, for each word, a similar substring of ``txt`` via n-gram search.

    Every lower-cased substring of ``txt`` whose length lies between
    ``ngram_n`` and ``len(longest word) + 4`` is indexed in an ``NGram`` set,
    and each lower-cased word is then looked up with ``find`` using a
    threshold of 0.5.

    Args:
        txt: Text to search in.
        words: Iterable of strings to look for.

    Returns:
        A list with the value returned by ``NGram.find`` for each word, in
        the order of ``words``, or ``None`` as soon as one word has no match.
    """
    max_word_length = max((len(word) for word in words), default=0)
    text_length = len(txt)

    # ``j`` is an exclusive slice end, so it has to be allowed to reach
    # ``text_length`` itself; clamping the exclusive range bound at
    # ``text_length`` used to exclude every substring that ends at the last
    # character of ``txt``.
    txt_ngrams = [
        txt[i:j].lower()
        for i in range(text_length)
        for j in range(i + ngram_n, min(text_length + 1, i + max_word_length + 5))
    ]

    n = NGram(txt_ngrams, key=lambda x: x.lower(), N=ngram_n)
    ret = []
    for word in words:
        matched_word = n.find(word.lower(), 0.5)
        if matched_word is None:
            return None
        ret.append(matched_word)
    return ret
def extract_words(txt, words):
    """Return the lower-cased words if every one of them occurs in ``txt``.

    Args:
        txt: Container to search; membership is tested with ``in``, so for a
            string this is a case-sensitive substring check.
        words: Iterable of strings that must all be present.

    Returns:
        A list of the words converted to lower case, or ``None`` when at
        least one word is not found in ``txt``.
    """
    if not all(token in txt for token in words):
        return None
    return [token.lower() for token in words]