|
import os |
|
import numpy as np |
|
|
|
from transformers import Pipeline, TensorType |
|
|
|
|
|
class GectorBase(object):
    """Tag-decoding mix-in for GECToR-style grammatical error correction.

    Translates edit labels predicted by a sequence-tagging model
    (``$KEEP``, ``$DELETE``, ``$APPEND_*``, ``$REPLACE_*``,
    ``$TRANSFORM_*``, ``$MERGE_*``) into a corrected token sequence.
    Intended to be mixed into a pipeline that provides ``self.model``
    (only :meth:`decode_verb_form` reads it).
    """

    # NOTE(review): the misspelling "DELIMINTER" is kept deliberately —
    # other code references this attribute by name.
    DELIMINTER = " "

    # Synthetic token prepended to every sentence so that position 0 can
    # host $APPEND edits in front of the first real word.
    START_TOKEN = "$START"

    PAD = "<PAD>"
    UNK = "<UNK>"

    # Surface-form normalizations (not used inside this class itself).
    REPLACEMENTS = {
        "''": '"',
        "--": "—",
        "`": "'",
        "'ve": "' ve",
    }

    def decode_verb_form(self, original):
        """Look up the target verb form for ``original`` in the model's
        verb-form vocabulary.

        ``original`` is expected in ``"<word>_<FORM_TAG>"`` shape (see
        :meth:`convert_using_verb`). Returns ``None`` when unknown.
        """
        return self.model.config.verb_form_vocab["decode"].get(original)

    def get_target_sent_by_edits(self, source_tokens, edits):
        """Apply ``edits`` to a copy of ``source_tokens``.

        Each edit is a ``(start, end, label, prob)`` tuple as produced by
        :meth:`get_token_action`; edits are assumed ordered left-to-right.
        ``shift_idx`` tracks how much earlier insertions/deletions have
        displaced the positions of later edits.

        Returns the corrected token list (with any $MERGE_ markers
        collapsed by :meth:`replace_merge_transforms`).
        """
        target_tokens = source_tokens[:]
        shift_idx = 0
        for start, end, label, _ in edits:
            target_pos = start + shift_idx
            # Guard: a position shifted out of range yields an empty token.
            if 0 <= target_pos < len(target_tokens):
                source_token = target_tokens[target_pos]
            else:
                source_token = ""

            if label == "":
                # $DELETE — remove the token and shift later edits left.
                del target_tokens[target_pos]
                shift_idx -= 1
            elif start == end:
                # $APPEND_* (and $MERGE_* markers, which get_token_action
                # also emits with start == end): insert before target_pos.
                word = label.replace("$APPEND_", "")
                target_tokens[target_pos:target_pos] = [word]
                shift_idx += 1
            elif label.startswith("$TRANSFORM_"):
                # In-place transformation; fall back to the original token
                # when the transformation cannot be decoded.
                word = self.apply_reverse_transformation(source_token, label)
                if word is None:
                    word = source_token
                target_tokens[target_pos] = word
            elif start == end - 1:
                # $REPLACE_* — single-token substitution.
                target_tokens[target_pos] = label.replace("$REPLACE_", "")
            elif label.startswith("$MERGE_"):
                # Insert the merge marker after the token; collapsed below.
                target_tokens[target_pos + 1 : target_pos + 1] = [label]
                shift_idx += 1

        return self.replace_merge_transforms(target_tokens)

    def replace_merge_transforms(self, tokens):
        """Collapse ``$MERGE_HYPHEN`` / ``$MERGE_SPACE`` marker tokens by
        joining their neighbours with ``-`` or nothing, respectively."""
        if all(not x.startswith("$MERGE_") for x in tokens):
            return tokens

        # NOTE(review): a marker at the very first or last position keeps
        # one missing outer space and would survive the replace — assumed
        # not to occur in practice; confirm against the tagger's output.
        target_line = " ".join(tokens)
        target_line = target_line.replace(" $MERGE_HYPHEN ", "-")
        target_line = target_line.replace(" $MERGE_SPACE ", "")
        return target_line.split()

    def convert_using_case(self, token, smart_action):
        """Apply a ``$TRANSFORM_CASE_*`` action to ``token``.

        Non-case actions and unrecognized case suffixes return the token
        unchanged.  The suffix checks below do not shadow each other:
        e.g. "..._CAPITAL_1" does not end with "CAPITAL".
        """
        if not smart_action.startswith("$TRANSFORM_CASE_"):
            return token
        if smart_action.endswith("LOWER"):
            return token.lower()
        if smart_action.endswith("UPPER"):
            return token.upper()
        if smart_action.endswith("CAPITAL"):
            return token.capitalize()
        if smart_action.endswith("CAPITAL_1"):
            # Keep the first character, capitalize from the second on.
            return token[0] + token[1:].capitalize()
        if smart_action.endswith("UPPER_-1"):
            # Upper-case everything except the final character.
            return token[:-1].upper() + token[-1]
        return token

    def convert_using_verb(self, token, smart_action):
        """Apply a ``$TRANSFORM_VERB_*`` action via the verb-form vocab.

        Returns ``None`` when the (token, form) pair is not in the
        vocabulary.  Raises for non-verb actions.
        """
        key_word = "$TRANSFORM_VERB_"
        if not smart_action.startswith(key_word):
            raise Exception(f"Unknown action type {smart_action}")
        # Vocabulary keys look like "<word>_<FORM_TAG>".
        encoding_part = f"{token}_{smart_action[len(key_word):]}"
        return self.decode_verb_form(encoding_part)

    def convert_using_split(self, token, smart_action):
        """Apply a ``$TRANSFORM_SPLIT*`` action: break a hyphenated token
        into space-separated words.  Raises for non-split actions."""
        key_word = "$TRANSFORM_SPLIT"
        if not smart_action.startswith(key_word):
            raise Exception(f"Unknown action type {smart_action}")
        return " ".join(token.split("-"))

    def convert_using_plural(self, token, smart_action):
        """Apply a ``$TRANSFORM_AGREEMENT_*`` action.

        Naive number agreement: PLURAL appends "s", SINGULAR strips the
        last character.  Raises for any other suffix.  (The prefix is
        already validated by :meth:`apply_reverse_transformation`.)
        """
        if smart_action.endswith("PLURAL"):
            return token + "s"
        if smart_action.endswith("SINGULAR"):
            return token[:-1]
        raise Exception(f"Unknown action type {smart_action}")

    def apply_reverse_transformation(self, source_token, transform):
        """Dispatch a ``$TRANSFORM_*`` label to its handler.

        Non-transform labels return ``source_token`` unchanged; an
        unrecognized ``$TRANSFORM_*`` label raises.

        (The original code also special-cased ``"$KEEP"`` inside the
        transform branch; that check was unreachable because "$KEEP"
        never starts with "$TRANSFORM", so it has been removed.)
        """
        if not transform.startswith("$TRANSFORM"):
            return source_token

        if transform.startswith("$TRANSFORM_CASE"):
            return self.convert_using_case(source_token, transform)
        if transform.startswith("$TRANSFORM_VERB"):
            return self.convert_using_verb(source_token, transform)
        if transform.startswith("$TRANSFORM_SPLIT"):
            return self.convert_using_split(source_token, transform)
        if transform.startswith("$TRANSFORM_AGREEMENT"):
            return self.convert_using_plural(source_token, transform)

        raise Exception(f"Unknown action type {transform}")

    def get_token_action(self, token, index, prob, sugg_token, min_error_probability):
        """Build the suggested edit for one token position.

        Returns a ``(start, end, label, prob)`` tuple positioned relative
        to the sentence *without* the $START token (hence the trailing
        ``- 1``), or ``None`` when the suggestion should be skipped
        (below threshold, padding/unknown, or $KEEP).
        """
        if prob < min_error_probability or sugg_token in [self.UNK, self.PAD, "$KEEP"]:
            return None

        if sugg_token.startswith(("$REPLACE_", "$TRANSFORM_")) or sugg_token == "$DELETE":
            # Edits that act on the token itself.
            start_pos = index
            end_pos = index + 1
        elif sugg_token.startswith(("$APPEND_", "$MERGE_")):
            # Edits that insert after the token.
            start_pos = index + 1
            end_pos = index + 1
        # NOTE(review): any other label family would leave start_pos/end_pos
        # unbound (UnboundLocalError).  Labels come from the model's own
        # vocabulary, assumed to contain only the families above — confirm.

        if sugg_token == "$DELETE":
            sugg_token_clear = ""
        elif sugg_token.startswith(("$TRANSFORM_", "$MERGE_")):
            # Marker kept verbatim; decoded later during edit application.
            sugg_token_clear = sugg_token
        else:
            # Strip the "$REPLACE"/"$APPEND" prefix up to the first "_".
            sugg_token_clear = sugg_token[sugg_token.index("_") + 1 :]

        return start_pos - 1, end_pos - 1, sugg_token_clear, prob
|
|
|
|
|
class GrammarErrorCorrectionPipeline(Pipeline, GectorBase):
    """HuggingFace pipeline running a GECToR tagger and applying its edits.

    ``preprocess`` tokenizes a single sentence string, ``_forward`` runs
    the tagging model and converts predicted labels into token edits, and
    ``postprocess`` returns the result unchanged.
    """

    def _sanitize_parameters(self, **kwargs):
        """Split user kwargs into preprocess / forward / postprocess groups."""
        max_len = int(kwargs.get("max_len", 50))
        preprocess_kwargs = {
            "max_len": max_len,
            # NOTE(review): collected but not consumed by preprocess() —
            # confirm whether lowercasing was ever wired up.
            "lowercase_tokens": bool(kwargs.get("lowercase_tokens", False)),
        }
        forward_kwargs = {
            "iterations": int(kwargs.get("iterations", 1)),
            "max_len": max_len,
            # NOTE(review): min_len is passed through but never read below.
            "min_len": int(kwargs.get("min_len", 3)),
            "min_error_probability": float(kwargs.get("min_error_probability", 0.0)),
        }
        postprocess_kwargs = {}
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def add_word_offsets(self, tokenized_input):
        """Attach word-level offset/mask tensors to ``tokenized_input``.

        ``word_offsets`` marks the index of the first sub-token of every
        word so the model can pool sub-token representations back to word
        level; ``word_mask`` is all ones with the same shape.
        """
        word_ids = tokenized_input.word_ids()
        # First sub-token of each word: position 0, or wherever the word
        # id changes.  (Assumes no None entries — special tokens are off.)
        offsets = [
            i
            for i, word_id in enumerate(word_ids)
            if i == 0 or word_id != word_ids[i - 1]
        ]
        # TensorType.PYTORCH compares equal to the framework string "pt".
        if self.framework == TensorType.PYTORCH:
            import torch

            offsets = torch.tensor([offsets], dtype=torch.long)
            mask = torch.ones_like(offsets)
            tokenized_input["word_offsets"] = offsets
            tokenized_input["word_mask"] = mask
        return tokenized_input

    def preprocess(self, model_input, **kwargs):
        """Tokenize one input sentence (a plain string) for the model."""
        # $START gives position 0 a slot for $APPEND edits before word 1.
        tokens = [self.START_TOKEN] + model_input.split(self.DELIMINTER)
        tokenized_input = self.tokenizer(
            tokens,
            max_length=kwargs.get("max_len"),
            add_special_tokens=False,
            truncation=True,
            is_split_into_words=True,
            return_token_type_ids=True,
            return_tensors=self.framework,
        )
        # Keep the raw words (without $START) for edit application later.
        # (Key was previously misspelled "oriignal_tokens"; it is produced
        # and consumed only inside this class.)
        tokenized_input["original_tokens"] = tokens[1:]
        tokenized_input = self.add_word_offsets(tokenized_input)
        return tokenized_input

    def _forward_iterative(self, batch, **forward_kwargs):
        """Run one tagging pass and turn predicted tags into edited tokens."""
        # Shallow-copy before popping so repeated iterations over the same
        # inputs do not fail (the original mutated the caller's batch and
        # raised KeyError on the second iteration).
        batch = dict(batch)
        original_tokens = batch.pop("original_tokens")
        model_outputs = self.model(**batch)

        error_probs = model_outputs.max_error_probabilities.numpy()
        class_probabilities_correct = model_outputs.class_probabilities_correct.numpy()
        # Per position: best label probability and its vocabulary index.
        all_probabilities = np.amax(class_probabilities_correct, axis=-1)
        all_idxs = np.argmax(class_probabilities_correct, axis=-1)

        all_results = []
        noop_index = self.model.config.detect_label2id.get("$CORRECT")
        # NOTE(review): original_tokens is a flat token list while the
        # other zip arguments are batch-indexed arrays — confirm the
        # batching convention matches upstream callers.
        for tokens, probabilities, idxs, error_prob in zip(
            original_tokens, all_probabilities, all_idxs, error_probs
        ):
            length = min(len(tokens), forward_kwargs.get("max_len"))
            edits = []

            # Every position predicts label id 0 (assumed $KEEP): no edits.
            if max(idxs) == 0:
                all_results.append(tokens)
                continue

            # Sentence-level error probability below threshold: keep as-is.
            if error_prob < forward_kwargs.get("min_error_probability"):
                all_results.append(tokens)
                continue

            # length + 1 positions: index 0 is the synthetic $START token.
            for i in range(length + 1):
                token = self.START_TOKEN if i == 0 else tokens[i - 1]

                if idxs[i] == noop_index:
                    continue

                sugg_token = self.model.config.id2label[str(idxs[i])]
                action = self.get_token_action(
                    token,
                    i,
                    probabilities[i],
                    sugg_token,
                    forward_kwargs.get("min_error_probability"),
                )
                if not action:
                    continue
                edits.append(action)

            all_results.append(self.get_target_sent_by_edits(tokens, edits))
        return all_results

    def _forward(self, model_inputs, **forward_kwargs):
        """Run the tagger for the configured number of iterations."""
        outputs = []
        # NOTE(review): every iteration re-runs the *same* model_inputs and
        # overwrites `outputs`; true iterative refinement would re-tokenize
        # the previous iteration's output before the next pass — confirm.
        for _ in range(forward_kwargs.get("iterations")):
            outputs = self._forward_iterative(model_inputs, **forward_kwargs)
        return {"output": outputs}

    def postprocess(self, model_outputs):
        """Return the corrected output unchanged."""
        return model_outputs
|
|