spoken-norm-taggen / utils.py
nguyenvulebinh's picture
add chunk merging
5b512d5
import difflib
import regtag
import random
def merge_span(words, tags):
spans, span_tags = [], []
current_tag = 'O'
span = []
for w, t in zip(words, tags):
w = w.strip(":-")
if len(w) == 0:
continue
t_info = t.split('-')
if t_info[-1] != current_tag or t_info[0] == 'B':
if len(span) > 0:
spans.append(' '.join(span))
span_tags.append(current_tag)
span = [w]
current_tag = t_info[-1]
else:
span.append(w)
if len(span) > 0:
spans.append(' '.join(span))
span_tags.append(current_tag)
return spans, span_tags
def make_spoken(text, do_split=True):
src, tgt = [], []
if do_split:
chunk_size = random.choice(list(range(0, 10)) + list(range(10, 35)) * 4)
if chunk_size > 0:
text = random.choice(split_chunk_input(text, chunk_size))
else:
text = ''
words, word_tags = merge_span(*regtag.tagging(text))
for span, t in zip(words, word_tags):
if t == 'O':
for w in span.split():
w = w.strip('/.,?!').lower()
if len(w) > 0:
src.append(w)
tgt.append(w)
if random.random() < 0.01:
random_value = regtag.augment.get_random_span()
tgt.append(random_value[0])
src.append(random_value[1].lower())
else:
random_value = regtag.augment.get_random_span(t, span.lower())
tgt.append(random_value[0])
src.append(random_value[1].lower())
if len(src) == 0:
tgt, src = regtag.get_random_span()
src = [src]
tgt = [tgt]
return src, tgt
def split_chunk_input(raw_text, chunk_size):
input_words = raw_text.strip().split()
clean_data = [input_words[i:i + chunk_size] for i in range(0, len(input_words), chunk_size)]
if len(clean_data) > 1:
clean_data = [" ".join(clean_data[i] + clean_data[i + 1]) for i in range(len(clean_data) - 1)]
else:
clean_data = [" ".join(clean_data[0])]
return clean_data
def split_chunk_input(raw_text, chunk_size=40, overlap=10):
input_words = raw_text.strip().split()
part_per_chunk = chunk_size // overlap
clean_data = [input_words[i:i + overlap] for i in range(0, len(input_words), overlap)]
if len(clean_data) > 1:
merge_data = []
for i in range(0, len(clean_data) - 1, part_per_chunk - 1):
merge_data.append(' '.join([y for x in clean_data[i:i + part_per_chunk] for y in x]))
else:
merge_data = [" ".join(clean_data[0])]
return merge_data
def merge_two_chunk(chunk_1, chunk_2, overlap, debug=False):
def extract_phrase_word(phrase):
if phrase.startswith('<mask>'):
return phrase[7:].split('](')[1][:-1].split()
else:
return [phrase]
def has_tag(phrase):
if phrase.startswith('<') and phrase.endswith(')'):
return True
return False
def extract_compete_region(list_phrases, is_head):
if is_head:
list_phrases = list_phrases[::-1]
compete = []
remain = []
handle_count = 0
for phrase in list_phrases:
phrase_word = extract_phrase_word(phrase)
if len(phrase_word) + handle_count <= overlap:
compete.append(phrase)
handle_count += len(phrase_word)
else:
if handle_count < overlap:
remain_compete_count = overlap - handle_count
remain.append(phrase)
if not is_head:
compete.extend(["<delete>({})".format(item) for item in phrase_word[:remain_compete_count]])
else:
compete.extend(
["<delete>({})".format(item) for item in phrase_word[::-1][:remain_compete_count]])
handle_count = overlap
else:
remain.append(phrase)
if is_head:
compete = compete[::-1]
remain = remain[::-1]
return remain, compete
def is_equal(phrase_1, phrase_2):
if phrase_1 == phrase_2:
return True
if extract_phrase_word(phrase_1) == extract_phrase_word(phrase_2):
if phrase_1.startswith('<mask>') and phrase_2.startswith('<mask>'):
return True
return False
def merge_compete(list_1, list_2):
idx_list_1, idx_list_2, combine_phrases = [], [], []
mark_term_complete = []
list_raw = [extract_phrase_word(item) for item in list_1]
list_raw = [y for x in list_raw for y in x]
for idx, phrase in enumerate(list_1):
idx_list_1.extend([idx] * len(extract_phrase_word(phrase)))
for idx, phrase in enumerate(list_2):
idx_list_2.extend([idx] * len(extract_phrase_word(phrase)))
# print(idx_list_1, idx_list_2)
for idx, (idx_1, idx_2) in enumerate(zip(idx_list_1, idx_list_2)):
if list_1[idx_1].startswith('<delete>') or list_2[idx_2].startswith('<delete>'):
continue
elif is_equal(list_1[idx_1], list_2[idx_2]):
# print(list_1[idx_1])
if '1_{}'.format(idx_1) not in mark_term_complete and '2_{}'.format(idx_2) not in mark_term_complete:
if idx <= overlap//2:
combine_phrases.append(list_1[idx_1])
mark_term_complete.append('1_{}'.format(idx_1))
else:
combine_phrases.append(list_2[idx_2])
mark_term_complete.append('2_{}'.format(idx_2))
else:
combine_phrases.append(list_raw[idx])
mark_term_complete.extend(['1_{}'.format(idx_1), '2_{}'.format(idx_2)])
# print(mark_term_complete)
return combine_phrases
remain_1, compete_1 = extract_compete_region(chunk_1, is_head=True)
remain_2, compete_2 = extract_compete_region(chunk_2[1:-1], is_head=False)
compromise = merge_compete(compete_1, compete_2)
if debug:
print(remain_1, '\n', compete_1)
print('-----------------------')
print(compete_2, '\n', remain_2)
print('-----------------------')
print(compromise, '\n\n')
return remain_1 + compromise + remain_2
def merge_chunk_pre_norm(list_chunks, overlap, debug=False):
if len(list_chunks) == 0:
return []
if len(list_chunks) == 1:
return list_chunks[0][1:-1]
current_chunk = list_chunks[0][1:-1]
for tmp_chunk in list_chunks[1:]:
current_chunk = merge_two_chunk(current_chunk, tmp_chunk, overlap, debug=debug)
return current_chunk
def equalize(s1, s2):
l1 = s1.split()
l2 = s2.split()
res1 = []
res2 = []
combine = []
prev = difflib.Match(0, 0, 0)
for match in difflib.SequenceMatcher(a=l1, b=l2).get_matching_blocks():
if prev.a + prev.size != match.a:
for i in range(prev.a + prev.size, match.a):
res2 += ['_' * len(l1[i])]
res1 += l1[prev.a + prev.size:match.a]
for i in l1[prev.a + prev.size:match.a]:
if len(combine) < len(l1) // 2:
print(l1[prev.a + prev.size:match.a])
combine.append(i)
if prev.b + prev.size != match.b:
for i in range(prev.b + prev.size, match.b):
res1 += ['_' * len(l2[i])]
res2 += l2[prev.b + prev.size:match.b]
for i in l2[prev.b + prev.size:match.b]:
if len(combine) >= len(l2) // 2:
print(l2[prev.b + prev.size:match.b])
combine.append(i)
res1 += l1[match.a:match.a + match.size]
res2 += l2[match.b:match.b + match.size]
combine += l2[match.b:match.b + match.size]
prev = match
return ' '.join(res1), ' '.join(res2), combine
def count_overlap(words_1, words_2):
# print(words_1, words_2)
assert len(words_1) == len(words_2)
len_overlap = 0
for match in difflib.SequenceMatcher(a=words_1, b=words_2).get_matching_blocks():
len_overlap += match.size
# for w1, w2 in zip(words_1, words_2):
# if w1 == w2:
# len_overlap += 1
return len_overlap
def find_overlap_chunk(txt_1, txt_2):
# print(txt_1)
# print(txt_2)
window_view = 1
idx_1 = len(txt_1) - window_view
idx_2 = window_view
over_lap = 0
current_best_idx_1 = len(txt_1)
current_best_idx_2 = 0
while window_view <= len(txt_1) and window_view <= len(txt_2):
current_overlap = count_overlap(txt_1[idx_1:], txt_2[:idx_2])
print(current_overlap)
if over_lap < current_overlap:
over_lap = current_overlap
current_best_idx_1 = idx_1
current_best_idx_2 = idx_2
window_view += 1
idx_1 = len(txt_1) - window_view
idx_2 = window_view
# else:
# break
print('----->', txt_1[current_best_idx_1:], txt_2[:current_best_idx_2])
return txt_1[current_best_idx_1:], txt_2[:current_best_idx_2]
def concat_chunks(list_chunks):
concat_string = list_chunks[0].split()
for i in range(1, len(list_chunks)):
remain_string = list_chunks[i].split()
s1, s2 = find_overlap_chunk(concat_string, remain_string)
s1 = ' '.join(s1)
s2 = ' '.join(s2)
_, _, overlap_merged = equalize(s1, s2)
merge_len = len(s1.split())
concat_string = concat_string[:len(concat_string) - merge_len] + overlap_merged + remain_string[merge_len:]
concat_string = ' '.join(concat_string)
return concat_string