import difflib
import regtag
import random


def merge_span(words, tags):
    # Collapse word-level BIO tags into (span_text, tag) pairs.
    spans, span_tags = [], []
    current_tag = 'O'
    span = []
    for w, t in zip(words, tags):
        w = w.strip(":-")
        if len(w) == 0:
            continue
        t_info = t.split('-')
        # Start a new span when the tag changes or a 'B-' prefix appears.
        if t_info[-1] != current_tag or t_info[0] == 'B':
            if len(span) > 0:
                spans.append(' '.join(span))
                span_tags.append(current_tag)
            span = [w]
            current_tag = t_info[-1]
        else:
            span.append(w)
    if len(span) > 0:
        spans.append(' '.join(span))
        span_tags.append(current_tag)
    return spans, span_tags


def make_spoken(text, do_split=True):
    # Build parallel (spoken `src`, written `tgt`) word lists for training data.
    src, tgt = [], []
    if do_split:
        # Chunk sizes of 10-34 words are sampled four times as often as 0-9.
        chunk_size = random.choice(list(range(0, 10)) + list(range(10, 35)) * 4)
        if chunk_size > 0:
            text = random.choice(split_chunk_input(text, chunk_size))
        else:
            text = ''
    words, word_tags = merge_span(*regtag.tagging(text))
    for span, t in zip(words, word_tags):
        if t == 'O':
            for w in span.split():
                w = w.strip('/.,?!').lower()
                if len(w) > 0:
                    src.append(w)
                    tgt.append(w)
                if random.random() < 0.01:
                    # Occasionally inject a random augmented span.
                    random_value = regtag.augment.get_random_span()
                    tgt.append(random_value[0])
                    src.append(random_value[1].lower())
        else:
            random_value = regtag.augment.get_random_span(t, span.lower())
            tgt.append(random_value[0])
            src.append(random_value[1].lower())
    if len(src) == 0:
        tgt, src = regtag.get_random_span()
        src = [src]
        tgt = [tgt]
    return src, tgt


def split_chunk_input(raw_text, chunk_size):
    # Split the text into fixed-size word chunks, then join each pair of
    # consecutive chunks so neighbouring outputs share context.
    input_words = raw_text.strip().split()
    clean_data = [input_words[i:i + chunk_size] for i in range(0, len(input_words), chunk_size)]
    if len(clean_data) > 1:
        clean_data = [" ".join(clean_data[i] + clean_data[i + 1]) for i in range(len(clean_data) - 1)]
    else:
        clean_data = [" ".join(clean_data[0])]
    return clean_data


# NOTE: this second definition shadows the one above.
def split_chunk_input(raw_text, chunk_size=40, overlap=10):
    # Split the text into chunks of `chunk_size` words that overlap by `overlap` words.
    input_words = raw_text.strip().split()
    part_per_chunk = chunk_size // overlap
    clean_data = [input_words[i:i + overlap] for i in range(0, len(input_words), overlap)]
    if len(clean_data) > 1:
        merge_data = []
        for i in range(0, len(clean_data) - 1, part_per_chunk - 1):
            merge_data.append(' '.join([y for x in clean_data[i:i + part_per_chunk] for y in x]))
    else:
        merge_data = [" ".join(clean_data[0])]
    return merge_data


def merge_two_chunk(chunk_1, chunk_2, overlap, debug=False):
    # Tagged phrases are assumed to look like '<tag>[written](spoken words)'.
    # The string literals inside the startswith() checks below were lost in the
    # original formatting; '<' is assumed here, consistent with has_tag().
    def extract_phrase_word(phrase):
        if phrase.startswith('<'):
            # Strip the tag prefix and return the spoken words inside '(...)'.
            return phrase[7:].split('](')[1][:-1].split()
        else:
            return [phrase]

    def has_tag(phrase):
        if phrase.startswith('<') and phrase.endswith(')'):
            return True
        return False

    def extract_compete_region(list_phrases, is_head):
        # Collect roughly `overlap` words from the tail (is_head=True) or the
        # head (is_head=False) of the chunk as the competing region.
        if is_head:
            list_phrases = list_phrases[::-1]
        compete = []
        remain = []
        handle_count = 0
        for phrase in list_phrases:
            phrase_word = extract_phrase_word(phrase)
            if len(phrase_word) + handle_count <= overlap:
                compete.append(phrase)
                handle_count += len(phrase_word)
            else:
                if handle_count < overlap:
                    remain_compete_count = overlap - handle_count
                    remain.append(phrase)
                    if not is_head:
                        compete.extend(["({})".format(item) for item in phrase_word[:remain_compete_count]])
                    else:
                        compete.extend(["({})".format(item) for item in phrase_word[::-1][:remain_compete_count]])
                    handle_count = overlap
                else:
                    remain.append(phrase)
        if is_head:
            compete = compete[::-1]
            remain = remain[::-1]
        return remain, compete

    def is_equal(phrase_1, phrase_2):
        if phrase_1 == phrase_2:
            return True
        if extract_phrase_word(phrase_1) == extract_phrase_word(phrase_2):
            if phrase_1.startswith('<') and phrase_2.startswith('<'):
                return True
        return False

    def merge_compete(list_1, list_2):
        idx_list_1, idx_list_2, combine_phrases = [], [], []
        mark_term_complete = []
        list_raw = [extract_phrase_word(item) for item in list_1]
        list_raw = [y for x in list_raw for y in x]
        for idx, phrase in enumerate(list_1):
            idx_list_1.extend([idx] * len(extract_phrase_word(phrase)))
        for idx, phrase in enumerate(list_2):
            idx_list_2.extend([idx] * len(extract_phrase_word(phrase)))
        # print(idx_list_1, idx_list_2)
        for idx, (idx_1, idx_2) in enumerate(zip(idx_list_1, idx_list_2)):
            if list_1[idx_1].startswith('<') or list_2[idx_2].startswith('<'):
                continue
            elif is_equal(list_1[idx_1], list_2[idx_2]):
                # print(list_1[idx_1])
                if '1_{}'.format(idx_1) not in mark_term_complete and '2_{}'.format(idx_2) not in mark_term_complete:
                    # Prefer chunk 1 in the first half of the overlap, chunk 2 in the second half.
                    if idx <= overlap // 2:
                        combine_phrases.append(list_1[idx_1])
                        mark_term_complete.append('1_{}'.format(idx_1))
                    else:
                        combine_phrases.append(list_2[idx_2])
                        mark_term_complete.append('2_{}'.format(idx_2))
            else:
                # Mismatch: fall back to the raw word and mark both phrases handled.
                combine_phrases.append(list_raw[idx])
                mark_term_complete.extend(['1_{}'.format(idx_1), '2_{}'.format(idx_2)])
        # print(mark_term_complete)
        return combine_phrases

    remain_1, compete_1 = extract_compete_region(chunk_1, is_head=True)
    remain_2, compete_2 = extract_compete_region(chunk_2[1:-1], is_head=False)
    compromise = merge_compete(compete_1, compete_2)
    if debug:
        print(remain_1, '\n', compete_1)
        print('-----------------------')
        print(compete_2, '\n', remain_2)
        print('-----------------------')
        print(compromise, '\n\n')
    return remain_1 + compromise + remain_2


def merge_chunk_pre_norm(list_chunks, overlap, debug=False):
    # Fold a list of chunks left-to-right, merging each new chunk into the result.
    if len(list_chunks) == 0:
        return []
    if len(list_chunks) == 1:
        return list_chunks[0][1:-1]
    current_chunk = list_chunks[0][1:-1]
    for tmp_chunk in list_chunks[1:]:
        current_chunk = merge_two_chunk(current_chunk, tmp_chunk, overlap, debug=debug)
    return current_chunk


def equalize(s1, s2):
    # Align two strings word by word; unmatched positions are padded with '_'.
    # `combine` keeps unmatched words from s1 in the first half and from s2 in
    # the second half, plus every matched word.
    l1 = s1.split()
    l2 = s2.split()
    res1 = []
    res2 = []
    combine = []
    prev = difflib.Match(0, 0, 0)
    for match in difflib.SequenceMatcher(a=l1, b=l2).get_matching_blocks():
        if prev.a + prev.size != match.a:
            for i in range(prev.a + prev.size, match.a):
                res2 += ['_' * len(l1[i])]
            res1 += l1[prev.a + prev.size:match.a]
            for i in l1[prev.a + prev.size:match.a]:
                if len(combine) < len(l1) // 2:
                    print(l1[prev.a + prev.size:match.a])
                    combine.append(i)
        if prev.b + prev.size != match.b:
            for i in range(prev.b + prev.size, match.b):
                res1 += ['_' * len(l2[i])]
            res2 += l2[prev.b + prev.size:match.b]
            for i in l2[prev.b + prev.size:match.b]:
                if len(combine) >= len(l2) // 2:
                    print(l2[prev.b + prev.size:match.b])
                    combine.append(i)
        res1 += l1[match.a:match.a + match.size]
        res2 += l2[match.b:match.b + match.size]
        combine += l2[match.b:match.b + match.size]
        prev = match
    return ' '.join(res1), ' '.join(res2), combine


def count_overlap(words_1, words_2):
    # print(words_1, words_2)
    assert len(words_1) == len(words_2)
    len_overlap = 0
    for match in difflib.SequenceMatcher(a=words_1, b=words_2).get_matching_blocks():
        len_overlap += match.size
    # for w1, w2 in zip(words_1, words_2):
    #     if w1 == w2:
    #         len_overlap += 1
    return len_overlap


def find_overlap_chunk(txt_1, txt_2):
    # Slide a growing window over the tail of txt_1 and the head of txt_2 and
    # keep the window with the highest word overlap.
    # print(txt_1)
    # print(txt_2)
    window_view = 1
    idx_1 = len(txt_1) - window_view
    idx_2 = window_view
    over_lap = 0
    current_best_idx_1 = len(txt_1)
    current_best_idx_2 = 0
    while window_view <= len(txt_1) and window_view <= len(txt_2):
        current_overlap = count_overlap(txt_1[idx_1:], txt_2[:idx_2])
        print(current_overlap)
        if over_lap < current_overlap:
            over_lap = current_overlap
            current_best_idx_1 = idx_1
            current_best_idx_2 = idx_2
        window_view += 1
        idx_1 = len(txt_1) - window_view
        idx_2 = window_view
        # else:
        #     break
    print('----->', txt_1[current_best_idx_1:], txt_2[:current_best_idx_2])
    return txt_1[current_best_idx_1:], txt_2[:current_best_idx_2]
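
# Illustrative note (not from the original source): given word lists such as
# ['a', 'b', 'c', 'd'] and ['c', 'd', 'e'], find_overlap_chunk grows a
# tail/head window one word at a time and returns the best-matching pair of
# slices, here (['c', 'd'], ['c', 'd']).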

def concat_chunks(list_chunks):
    # Stitch decoded chunks back together: find the words shared between the
    # running result and the next chunk, reconcile them with equalize(), then
    # append the non-overlapping remainder of the new chunk.
    concat_string = list_chunks[0].split()
    for i in range(1, len(list_chunks)):
        remain_string = list_chunks[i].split()
        s1, s2 = find_overlap_chunk(concat_string, remain_string)
        s1 = ' '.join(s1)
        s2 = ' '.join(s2)
        _, _, overlap_merged = equalize(s1, s2)
        merge_len = len(s1.split())
        concat_string = concat_string[:len(concat_string) - merge_len] + overlap_merged + remain_string[merge_len:]
    concat_string = ' '.join(concat_string)
    return concat_string
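

# Minimal usage sketch (not part of the original module). The sample strings
# below are invented; they assume chunks that share a word-for-word overlap,
# as produced by split_chunk_input(). Note that find_overlap_chunk() prints
# its intermediate overlap counts while running.
if __name__ == '__main__':
    # Overlapping chunking: 4-word windows that advance by 2 words.
    print(split_chunk_input('a b c d e f g h', chunk_size=4, overlap=2))
    # -> ['a b c d', 'c d e f', 'e f g h']

    # Re-joining overlapping chunks into a single transcript.
    chunks = ['please call me tomorrow at nine', 'tomorrow at nine in the morning']
    print(concat_chunks(chunks))
    # -> 'please call me tomorrow at nine in the morning'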