"""Helpers for applying token-level edit labels and for simple line-based I/O.

The module turns edit labels ("$APPEND_x", "$REPLACE_x", "$TRANSFORM_*",
"$MERGE_*") into target tokens, maps verb forms through a vocabulary file that
is expected to sit next to this module, and resolves short transformer names
to pretrained weight names.
"""
import os
import re
import warnings
from pathlib import Path

VOCAB_DIR = Path(__file__).resolve().parent
PAD = "@@PADDING@@"
UNK = "@@UNKNOWN@@"
START_TOKEN = "$START"
# NOTE: the spelling "DELIMETERS" is kept on purpose; renaming a public
# constant would break any module that imports it.
SEQ_DELIMETERS = {"tokens": " ", "labels": "SEPL|||SEPR", "operations": "SEPL__SEPR"}

# Collapses a run of "<punctuation><whitespace>" groups into the last group.
_REPEATED_PUNCT_RE = re.compile(r'([\.\,\?\:]\s+)+')

_UNCASED_WARNING = 'Warning! This model was trained only on uncased sentences.'
_CASED_WARNING = 'Warning! This model was trained only on cased sentences.'

# Short names whose weights exist only in an uncased variant.
_UNCASED_ONLY_WEIGHTS = {
    'distilbert': 'distilbert-base-uncased',
    'albert': 'albert-base-v1',
}

# Short names whose weights exist only in a cased variant.
_CASED_ONLY_WEIGHTS = {
    'roberta': 'roberta-base',
    'roberta-large': 'roberta-large',
    'gpt2': 'gpt2',
    'transformerxl': 'transfo-xl-wt103',
    'xlnet': 'xlnet-base-cased',
    'xlnet-large': 'xlnet-large-cased',
}


def get_verb_form_dicts(path_to_dict=None):
    """Load the verb-form vocabulary into an encode and a decode table.

    Each non-blank line must look like ``word1_word2:tag1_tag2``.

    Args:
        path_to_dict: Path of the vocabulary file. Defaults to
            ``verb-form-vocab.txt`` inside ``VOCAB_DIR``.

    Returns:
        A tuple ``(encode, decode)`` where ``encode`` maps ``"word1_word2"``
        to the raw tag text after the colon (including its trailing newline,
        which ``encode_verb_form`` strips) and ``decode`` maps
        ``"word1_tag1_tag2"`` to ``word2``. When several lines produce the
        same decode key, only the first one is kept in both tables.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a non-blank line does not match the expected format.
    """
    if path_to_dict is None:
        path_to_dict = os.path.join(VOCAB_DIR, "verb-form-vocab.txt")
    encode, decode = {}, {}
    with open(path_to_dict, encoding="utf-8") as f:
        for line in f:
            # A stray blank line (e.g. at the end of the file) is not an entry.
            if not line.strip():
                continue
            words, tags = line.split(":")
            word1, word2 = words.split("_")
            tag1, tag2 = tags.split("_")
            decode_key = f"{word1}_{tag1}_{tag2.strip()}"
            if decode_key not in decode:
                encode[words] = tags
                decode[decode_key] = word2
    return encode, decode


try:
    ENCODE_VERB_DICT, DECODE_VERB_DICT = get_verb_form_dicts()
except FileNotFoundError:
    # Keep the module importable so the helpers that do not need the verb
    # vocabulary still work; verb lookups will simply find nothing.
    warnings.warn(
        f"verb-form-vocab.txt not found in {VOCAB_DIR}; "
        "verb-form transformations are disabled.",
        RuntimeWarning,
    )
    ENCODE_VERB_DICT, DECODE_VERB_DICT = {}, {}


def get_target_sent_by_edits(source_tokens, edits):
    """Apply a sequence of edits to ``source_tokens`` and return the result.

    Args:
        source_tokens: List of source tokens; it is copied, not modified.
        edits: Iterable of 4-item edits ``(start, end, label, _)``. ``start``
            and ``end`` index into ``source_tokens``; the fourth item is
            ignored here. Edits with a negative ``start`` are skipped.

    Returns:
        The list of target tokens after all edits and after
        ``replace_merge_transforms`` has resolved any ``$MERGE_*`` markers.
    """
    target_tokens = source_tokens[:]
    # Running offset between source positions and positions in target_tokens,
    # updated whenever an edit inserts or deletes a token.
    shift_idx = 0
    for start, end, label, _ in edits:
        if start < 0:
            continue
        target_pos = start + shift_idx
        if len(target_tokens) > target_pos:
            source_token = target_tokens[target_pos]
        else:
            source_token = ""

        if label == "":
            # An empty label means "delete the token".
            del target_tokens[target_pos]
            shift_idx -= 1
        elif start == end:
            # Zero-width span: insert a token before target_pos.
            word = label.replace("$APPEND_", "")
            # Avoid appending same token twice
            same_as_next = (
                target_pos < len(target_tokens)
                and target_tokens[target_pos] == word
            )
            same_as_prev = target_pos > 0 and target_tokens[target_pos - 1] == word
            if same_as_next or same_as_prev:
                continue
            target_tokens[target_pos:target_pos] = [word]
            shift_idx += 1
        elif label.startswith("$TRANSFORM_"):
            word = apply_reverse_transformation(source_token, label)
            # A failed lookup (e.g. unknown verb form) leaves the token as is.
            if word is None:
                word = source_token
            target_tokens[target_pos] = word
        elif start == end - 1:
            word = label.replace("$REPLACE_", "")
            target_tokens[target_pos] = word
        elif label.startswith("$MERGE_"):
            # Insert the marker after the token; it is resolved at the end.
            target_tokens[target_pos + 1:target_pos + 1] = [label]
            shift_idx += 1

    return replace_merge_transforms(target_tokens)


def replace_merge_transforms(tokens):
    """Resolve ``$MERGE_HYPHEN`` / ``$MERGE_SPACE`` markers in a token list.

    Args:
        tokens: List of string tokens, possibly containing ``$MERGE_*``
            markers.

    Returns:
        ``tokens`` itself when it contains no marker. Otherwise a new list in
        which a leading and a trailing marker are dropped, the neighbours of
        ``$MERGE_HYPHEN`` are joined with "-", the neighbours of
        ``$MERGE_SPACE`` are concatenated, and runs of
        ``<.,?:><whitespace>`` groups are collapsed to their last group.
    """
    if all(not x.startswith("$MERGE_") for x in tokens):
        return tokens

    if tokens[0].startswith("$MERGE_"):
        tokens = tokens[1:]
    # The list may be empty now if the only token was a marker.
    if tokens and tokens[-1].startswith("$MERGE_"):
        tokens = tokens[:-1]

    target_line = " ".join(tokens)
    target_line = target_line.replace(" $MERGE_HYPHEN ", "-")
    target_line = target_line.replace(" $MERGE_SPACE ", "")
    target_line = _REPEATED_PUNCT_RE.sub(r'\1', target_line)
    return target_line.split()


def convert_using_case(token, smart_action):
    """Apply a ``$TRANSFORM_CASE_*`` action to ``token``.

    Args:
        token: The token to re-case.
        smart_action: The action label; its suffix selects the operation.

    Returns:
        The re-cased token, or ``token`` unchanged when the label does not
        start with ``$TRANSFORM_CASE_`` or has an unrecognised suffix.
    """
    if not smart_action.startswith("$TRANSFORM_CASE_"):
        return token
    if smart_action.endswith("LOWER"):
        return token.lower()
    if smart_action.endswith("UPPER"):
        return token.upper()
    if smart_action.endswith("CAPITAL"):
        return token.capitalize()
    if smart_action.endswith("CAPITAL_1"):
        # Keep the first character, capitalize the rest ("iphone" -> "iPhone").
        # Slicing instead of indexing keeps an empty token from raising.
        return token[:1] + token[1:].capitalize()
    if smart_action.endswith("UPPER_-1"):
        # Upper-case everything except the last character ("abc" -> "ABc").
        return token[:-1].upper() + token[-1:]
    return token


def convert_using_verb(token, smart_action):
    """Look up the verb form of ``token`` selected by a ``$TRANSFORM_VERB_`` label.

    Returns:
        The decoded word, or ``None`` when the vocabulary has no entry.

    Raises:
        Exception: If ``smart_action`` does not start with ``$TRANSFORM_VERB_``.
    """
    key_word = "$TRANSFORM_VERB_"
    if not smart_action.startswith(key_word):
        raise Exception(f"Unknown action type {smart_action}")
    encoding_part = f"{token}_{smart_action[len(key_word):]}"
    return decode_verb_form(encoding_part)


def convert_using_split(token, smart_action):
    """Replace every hyphen in ``token`` with a space.

    Raises:
        Exception: If ``smart_action`` does not start with ``$TRANSFORM_SPLIT``.
    """
    key_word = "$TRANSFORM_SPLIT"
    if not smart_action.startswith(key_word):
        raise Exception(f"Unknown action type {smart_action}")
    return token.replace("-", " ")


def convert_using_plural(token, smart_action):
    """Add or drop a final character according to the label suffix.

    ``...PLURAL`` appends "s"; ``...SINGULAR`` removes the last character.

    Raises:
        Exception: If the label ends with neither suffix.
    """
    if smart_action.endswith("PLURAL"):
        return token + "s"
    if smart_action.endswith("SINGULAR"):
        return token[:-1]
    raise Exception(f"Unknown action type {smart_action}")


def apply_reverse_transformation(source_token, transform):
    """Apply a ``$TRANSFORM*`` label to ``source_token``.

    Args:
        source_token: The token to transform.
        transform: The label. Labels that do not start with ``$TRANSFORM``
            (for example ``$KEEP``) leave the token untouched.

    Returns:
        The transformed token; ``None`` is possible for verb transforms whose
        form is not in the vocabulary.

    Raises:
        Exception: If the label starts with ``$TRANSFORM`` but matches none of
            the CASE / VERB / SPLIT / AGREEMENT families.
    """
    if not transform.startswith("$TRANSFORM"):
        return source_token
    # deal with case
    if transform.startswith("$TRANSFORM_CASE"):
        return convert_using_case(source_token, transform)
    # deal with verb
    if transform.startswith("$TRANSFORM_VERB"):
        return convert_using_verb(source_token, transform)
    # deal with split
    if transform.startswith("$TRANSFORM_SPLIT"):
        return convert_using_split(source_token, transform)
    # deal with single/plural
    if transform.startswith("$TRANSFORM_AGREEMENT"):
        return convert_using_plural(source_token, transform)
    # raise exception if not find correct type
    raise Exception(f"Unknown action type {transform}")


def read_parallel_lines(fn1, fn2):
    """Lazily yield stripped line pairs from two UTF-8 text files.

    Pairs are produced with ``zip``, so iteration stops at the end of the
    shorter file. Blank lines are not filtered out; a blank line is yielded
    as an empty string.

    Yields:
        Tuples ``(line1, line2)`` of stripped strings.
    """
    with open(fn1, encoding='utf-8') as f1, open(fn2, encoding='utf-8') as f2:
        for line1, line2 in zip(f1, f2):
            yield line1.strip(), line2.strip()


def read_lines(fn, skip_strip=False):
    """Read a UTF-8 text file into a list of stripped lines.

    Args:
        fn: Path of the file.
        skip_strip: When false (the default) blank lines are dropped. When
            true they are kept, as empty strings. Lines are stripped of
            surrounding whitespace in both cases.

    Returns:
        The list of lines, or an empty list when ``fn`` does not exist.
    """
    if not os.path.exists(fn):
        return []
    with open(fn, 'r', encoding='utf-8') as f:
        stripped = (line.strip() for line in f)
        return [s for s in stripped if s or skip_strip]


def write_lines(fn, lines, mode='w'):
    """Write ``lines`` to ``fn``, one item per line, encoded as UTF-8.

    Args:
        fn: Path of the output file.
        lines: Iterable of items; each is formatted with ``str`` and followed
            by a newline.
        mode: File mode passed to ``open``. With ``'w'`` an existing file is
            removed first.
    """
    if mode == 'w' and os.path.exists(fn):
        os.remove(fn)
    with open(fn, encoding='utf-8', mode=mode) as f:
        f.writelines(f'{s}\n' for s in lines)


def decode_verb_form(original):
    """Return the word stored for a ``"word_tag1_tag2"`` key, or ``None``."""
    return DECODE_VERB_DICT.get(original)


def encode_verb_form(original_word, corrected_word):
    """Return the tag pair that turns ``original_word`` into ``corrected_word``.

    Returns:
        The stripped tag text from the verb vocabulary, or ``None`` when
        ``original_word`` is empty or the pair is not in the vocabulary.
    """
    decoding_request = original_word + "_" + corrected_word
    decoding_response = ENCODE_VERB_DICT.get(decoding_request, "").strip()
    if original_word and decoding_response:
        return decoding_response
    return None


def get_weights_name(transformer_name, lowercase):
    """Map a short transformer name to a pretrained weights name.

    Args:
        transformer_name: Short name such as ``'bert'`` or ``'roberta'``.
        lowercase: Whether the input text is lower-cased.

    Returns:
        The weights name. Names that are not recognised are returned
        unchanged. A warning is printed when ``lowercase`` does not match the
        casing of the selected weights.
    """
    if transformer_name == 'bert':
        return 'bert-base-uncased' if lowercase else 'bert-base-cased'
    if transformer_name == 'bert-large' and not lowercase:
        return 'bert-large-cased'
    # NOTE(review): 'bert-large' with lowercase=True has no mapping of its
    # own; it falls through, prints the cased-only warning and is returned
    # unchanged. Confirm whether an uncased mapping is wanted.

    if transformer_name in _UNCASED_ONLY_WEIGHTS:
        if not lowercase:
            print(_UNCASED_WARNING)
        return _UNCASED_ONLY_WEIGHTS[transformer_name]

    if lowercase:
        print(_CASED_WARNING)
    return _CASED_ONLY_WEIGHTS.get(transformer_name, transformer_name)