|
import os |
|
from pathlib import Path |
|
import re |
|
|
|
|
|
# Directory that contains this module; vocabulary files are looked up here.
VOCAB_DIR = Path(__file__).resolve().parent

# Special vocabulary entries.
PAD = "@@PADDING@@"
UNK = "@@UNKNOWN@@"
START_TOKEN = "$START"

# Separator strings, keyed by the kind of sequence they delimit.
SEQ_DELIMETERS = {
    "tokens": " ",
    "labels": "SEPL|||SEPR",
    "operations": "SEPL__SEPR",
}
|
|
|
|
|
def get_verb_form_dicts():
    """Load the verb-form vocabulary into an encode and a decode table.

    Reads ``verb-form-vocab.txt`` from ``VOCAB_DIR``. Every line is split
    once at ``:`` into a word pair and a tag pair, and each pair is split at
    ``_`` into exactly two parts.

    Returns:
        A tuple ``(encode, decode)``:

        * ``encode`` maps the raw word pair (``word1_word2``) to the raw tag
          pair exactly as it appears in the file (not stripped).
        * ``decode`` maps ``word1_tag1_tag2`` (with ``tag2`` stripped) to
          ``word2``.

        When several lines produce the same decode key, only the first one
        is recorded in either table.

    Raises:
        ValueError: if a line does not split into exactly two parts at any
            of the three split steps.
    """
    vocab_file = os.path.join(VOCAB_DIR, "verb-form-vocab.txt")
    encode = {}
    decode = {}
    with open(vocab_file, encoding="utf-8") as handle:
        for entry in handle:
            word_pair, tag_pair = entry.split(":")
            source_word, target_word = word_pair.split("_")
            source_tag, target_tag = tag_pair.split("_")
            key = "_".join((source_word, source_tag, target_tag.strip()))
            if key in decode:
                # First occurrence wins; later duplicates are ignored.
                continue
            encode[word_pair] = tag_pair
            decode[key] = target_word
    return encode, decode
|
|
|
|
|
# Verb-form lookup tables, built once when this module is imported.
ENCODE_VERB_DICT, DECODE_VERB_DICT = get_verb_form_dicts()
|
|
|
|
|
def get_target_sent_by_edits(source_tokens, edits):
    """Apply a sequence of edits to a copy of ``source_tokens``.

    Args:
        source_tokens: list of tokens; it is copied, not modified.
        edits: iterable of 4-item entries ``(start, end, label, _)``. The
            fourth item is ignored. Entries with a negative ``start`` are
            skipped.

    Each edit is applied at ``start`` plus the running offset caused by
    earlier insertions and deletions. The first matching rule is used:

    * empty label: delete the token at that position;
    * ``start == end``: insert the label (with ``$APPEND_`` removed) unless
      the same word is already at that position or directly before it;
    * label starts with ``$TRANSFORM_``: replace the token with the result
      of ``apply_reverse_transformation`` (or keep it if that is ``None``);
    * ``start == end - 1``: replace the token with the label (with
      ``$REPLACE_`` removed);
    * label starts with ``$MERGE_``: insert the label itself after the
      token, to be resolved later.

    Returns:
        The result of ``replace_merge_transforms`` on the edited tokens.
    """
    result = source_tokens[:]
    offset = 0  # net count of tokens inserted (+) or deleted (-) so far

    for start, end, label, _ in edits:
        if start < 0:
            continue

        pos = start + offset
        # Token currently at the edit position, or "" when past the end.
        current = result[pos] if pos < len(result) else ""

        if label == "":
            del result[pos]
            offset -= 1
        elif start == end:
            word = label.replace("$APPEND_", "")
            # Avoid inserting a word next to an identical neighbour.
            duplicate = (pos < len(result) and result[pos] == word) or (
                pos > 0 and result[pos - 1] == word
            )
            if not duplicate:
                result.insert(pos, word)
                offset += 1
        elif label.startswith("$TRANSFORM_"):
            transformed = apply_reverse_transformation(current, label)
            result[pos] = current if transformed is None else transformed
        elif start == end - 1:
            result[pos] = label.replace("$REPLACE_", "")
        elif label.startswith("$MERGE_"):
            # Keep the merge marker as a pseudo-token after the current one.
            result.insert(pos + 1, label)
            offset += 1

    return replace_merge_transforms(result)
|
|
|
|
|
def replace_merge_transforms(tokens):
    """Resolve ``$MERGE_*`` marker tokens by joining their neighbours.

    Args:
        tokens: list of string tokens, possibly containing ``$MERGE_HYPHEN``
            or ``$MERGE_SPACE`` markers.

    Returns:
        ``tokens`` itself, unchanged, when it contains no token starting
        with ``$MERGE_``. Otherwise a new list in which

        * one leading and one trailing merge marker (which have no
          neighbour on one side) are dropped,
        * ``a $MERGE_HYPHEN b`` becomes ``a-b``,
        * ``a $MERGE_SPACE b`` becomes ``ab``,
        * a run of consecutive "punctuation + whitespace" groups (for the
          characters ``. , ? :``) is reduced to the last group of the run.

        A list that holds nothing but a single merge marker yields ``[]``.
    """
    if not any(tok.startswith("$MERGE_") for tok in tokens):
        return tokens

    # At least one token exists here, so indexing [0] is safe.
    if tokens[0].startswith("$MERGE_"):
        tokens = tokens[1:]
    # Dropping the leading marker may have emptied the list; guard before
    # looking at the last element.
    if tokens and tokens[-1].startswith("$MERGE_"):
        tokens = tokens[:-1]

    target_line = " ".join(tokens)
    target_line = target_line.replace(" $MERGE_HYPHEN ", "-")
    target_line = target_line.replace(" $MERGE_SPACE ", "")
    target_line = re.sub(r'([\.\,\?\:]\s+)+', r'\1', target_line)
    return target_line.split()
|
|
|
|
|
def convert_using_case(token, smart_action):
    """Apply a ``$TRANSFORM_CASE_*`` action to ``token``.

    Args:
        token: the string to transform; may be empty.
        smart_action: the action label.

    Returns:
        The transformed token, chosen by the suffix of ``smart_action``:

        * ``LOWER``: all lower case;
        * ``UPPER``: all upper case;
        * ``CAPITAL``: ``str.capitalize`` of the token;
        * ``CAPITAL_1``: first character kept, the rest capitalized
          (``iphone`` -> ``iPhone``);
        * ``UPPER_-1``: everything but the last character upper-cased
          (``usas`` -> ``USAs``).

        The token is returned unchanged when the action does not start with
        ``$TRANSFORM_CASE_`` or has an unrecognised suffix.
    """
    if not smart_action.startswith("$TRANSFORM_CASE_"):
        return token
    if smart_action.endswith("LOWER"):
        return token.lower()
    if smart_action.endswith("UPPER"):
        return token.upper()
    if smart_action.endswith("CAPITAL"):
        return token.capitalize()
    if smart_action.endswith("CAPITAL_1"):
        # Slices instead of token[0] so that an empty token does not raise.
        return token[:1] + token[1:].capitalize()
    if smart_action.endswith("UPPER_-1"):
        # Slices instead of token[-1] so that an empty token does not raise.
        return token[:-1].upper() + token[-1:]
    return token
|
|
|
|
|
def convert_using_verb(token, smart_action):
    """Look up the verb form of ``token`` requested by ``smart_action``.

    The part of ``smart_action`` after ``$TRANSFORM_VERB_`` is appended to
    ``token`` with an underscore, and that key is passed to
    ``decode_verb_form``.

    Returns:
        Whatever ``decode_verb_form`` returns for the constructed key.

    Raises:
        Exception: if ``smart_action`` does not start with
            ``$TRANSFORM_VERB_``.
    """
    prefix = "$TRANSFORM_VERB_"
    if smart_action.startswith(prefix):
        tag_suffix = smart_action[len(prefix):]
        return decode_verb_form(f"{token}_{tag_suffix}")
    raise Exception(f"Unknown action type {smart_action}")
|
|
|
|
|
def convert_using_split(token, smart_action):
    """Split a hyphenated token into space-separated words.

    Returns:
        ``token`` with every ``-`` replaced by a single space.

    Raises:
        Exception: if ``smart_action`` does not start with
            ``$TRANSFORM_SPLIT``.
    """
    if smart_action.startswith("$TRANSFORM_SPLIT"):
        return token.replace("-", " ")
    raise Exception(f"Unknown action type {smart_action}")
|
|
|
|
|
def convert_using_plural(token, smart_action):
    """Add or remove a trailing character according to ``smart_action``.

    Returns:
        ``token`` with ``"s"`` appended when the action ends with
        ``PLURAL``; ``token`` without its last character when the action
        ends with ``SINGULAR``.

    Raises:
        Exception: if the action ends with neither suffix.
    """
    to_plural = smart_action.endswith("PLURAL")
    to_singular = smart_action.endswith("SINGULAR")
    if not (to_plural or to_singular):
        raise Exception(f"Unknown action type {smart_action}")
    if to_plural:
        return token + "s"
    return token[:-1]
|
|
|
|
|
def apply_reverse_transformation(source_token, transform):
    """Apply a ``$TRANSFORM*`` label to ``source_token``.

    Args:
        source_token: the token to transform.
        transform: the edit label.

    Returns:
        ``source_token`` unchanged when ``transform`` does not start with
        ``$TRANSFORM`` (this includes ``$KEEP``). Otherwise the result of
        the converter selected by the label prefix:

        * ``$TRANSFORM_CASE``      -> ``convert_using_case``
        * ``$TRANSFORM_VERB``      -> ``convert_using_verb``
        * ``$TRANSFORM_SPLIT``     -> ``convert_using_split``
        * ``$TRANSFORM_AGREEMENT`` -> ``convert_using_plural``

    Raises:
        Exception: if ``transform`` starts with ``$TRANSFORM`` but matches
            none of the prefixes above.
    """
    # Anything that is not a transform label (e.g. "$KEEP") is a no-op.
    if not transform.startswith("$TRANSFORM"):
        return source_token

    if transform.startswith("$TRANSFORM_CASE"):
        return convert_using_case(source_token, transform)
    if transform.startswith("$TRANSFORM_VERB"):
        return convert_using_verb(source_token, transform)
    if transform.startswith("$TRANSFORM_SPLIT"):
        return convert_using_split(source_token, transform)
    if transform.startswith("$TRANSFORM_AGREEMENT"):
        return convert_using_plural(source_token, transform)

    raise Exception(f"Unknown action type {transform}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def read_parallel_lines(fn1, fn2):
    """Yield stripped line pairs from two UTF-8 text files read in lockstep.

    Iteration stops as soon as either file runs out of lines. Blank lines
    are not skipped; they are yielded as empty strings.

    Args:
        fn1: path of the first file.
        fn2: path of the second file.

    Yields:
        ``(line1, line2)`` tuples with surrounding whitespace removed.
    """
    with open(fn1, encoding='utf-8') as left, open(fn2, encoding='utf-8') as right:
        yield from ((a.strip(), b.strip()) for a, b in zip(left, right))
|
|
|
|
|
def read_lines(fn, skip_strip=False):
    """Read a UTF-8 text file into a list of stripped lines.

    Args:
        fn: path of the file to read.
        skip_strip: when true, lines that are empty after stripping are
            kept (as empty strings); when false they are dropped. Lines are
            stripped in both cases.

    Returns:
        The list of stripped lines, or ``[]`` if ``fn`` does not exist.
    """
    if not os.path.exists(fn):
        return []

    kept = []
    with open(fn, 'r', encoding='utf-8') as handle:
        for raw in handle:
            stripped = raw.strip()
            if stripped or skip_strip:
                kept.append(stripped)
    return kept
|
|
|
|
|
def write_lines(fn, lines, mode='w'):
    """Write ``lines`` to ``fn`` as UTF-8, one item per line.

    Args:
        fn: path of the output file.
        lines: iterable of items; each is formatted with ``'%s\\n'``.
        mode: file mode passed to ``open``. With ``'w'`` an existing file
            is removed first and then created again.
    """
    replacing_existing = mode == 'w' and os.path.exists(fn)
    if replacing_existing:
        os.remove(fn)

    # Format everything first, then write it in a single call.
    payload = "".join('%s\n' % s for s in lines)
    with open(fn, encoding='utf-8', mode=mode) as handle:
        handle.write(payload)
|
|
|
|
|
def decode_verb_form(original):
    """Return the ``DECODE_VERB_DICT`` entry for ``original``, or ``None``.

    Args:
        original: the lookup key.
    """
    if original in DECODE_VERB_DICT:
        return DECODE_VERB_DICT[original]
    return None
|
|
|
|
|
def encode_verb_form(original_word, corrected_word):
    """Look up the tag pair recorded for a word pair.

    The key ``original_word + "_" + corrected_word`` is looked up in
    ``ENCODE_VERB_DICT`` and the stored value is stripped of surrounding
    whitespace.

    Returns:
        The stripped value, or ``None`` when ``original_word`` is empty,
        the key is missing, or the stored value is blank.
    """
    lookup_key = original_word + "_" + corrected_word
    tags = ENCODE_VERB_DICT.get(lookup_key, "").strip()
    return tags if original_word and tags else None
|
|
|
|
|
def get_weights_name(transformer_name, lowercase):
    """Map a short transformer name to a pretrained-weights identifier.

    Args:
        transformer_name: short model name such as ``'bert'`` or
            ``'roberta'``.
        lowercase: truthy when the input text is lower-cased.

    Returns:
        The weights identifier for a known name. A name that is not
        recognised (including ``'bert-large'`` combined with a truthy
        ``lowercase``) is returned unchanged.

    A warning is printed when ``lowercase`` does not match the casing of
    the selected weights: the "uncased" warning for ``distilbert`` and
    ``albert`` with a falsy ``lowercase``, and the "cased" warning for
    every name handled after those when ``lowercase`` is truthy.
    """
    if transformer_name == 'bert':
        return 'bert-base-uncased' if lowercase else 'bert-base-cased'
    if transformer_name == 'bert-large' and not lowercase:
        return 'bert-large-cased'

    uncased_only = {
        'distilbert': 'distilbert-base-uncased',
        'albert': 'albert-base-v1',
    }
    if transformer_name in uncased_only:
        if not lowercase:
            print('Warning! This model was trained only on uncased sentences.')
        return uncased_only[transformer_name]

    # Everything from here on is treated as a cased model.
    if lowercase:
        print('Warning! This model was trained only on cased sentences.')

    cased_only = {
        'roberta': 'roberta-base',
        'roberta-large': 'roberta-large',
        'gpt2': 'gpt2',
        'transformerxl': 'transfo-xl-wt103',
        'xlnet': 'xlnet-base-cased',
        'xlnet-large': 'xlnet-large-cased',
    }
    return cased_only.get(transformer_name, transformer_name)
|
|