gector-xlnet-base-cased-5k / grammar_error_correction_pipeline.py
ktzsh's picture
Upload folder using huggingface_hub
010f214
import os
import numpy as np
from transformers import Pipeline, TensorType
class GectorBase(object):
DELIMINTER = " "
START_TOKEN = "$START"
PAD = "<PAD>"
UNK = "<UNK>"
REPLACEMENTS = {
"''": '"',
"--": "—",
"`": "'",
"'ve": "' ve",
}
def decode_verb_form(self, original):
return self.model.config.verb_form_vocab["decode"].get(original)
def get_target_sent_by_edits(self, source_tokens, edits):
target_tokens = source_tokens[:]
shift_idx = 0
for edit in edits:
start, end, label, _ = edit
target_pos = start + shift_idx
source_token = (
target_tokens[target_pos]
if len(target_tokens) > target_pos >= 0
else ""
)
if label == "":
del target_tokens[target_pos]
shift_idx -= 1
elif start == end:
word = label.replace("$APPEND_", "")
target_tokens[target_pos:target_pos] = [word]
shift_idx += 1
elif label.startswith("$TRANSFORM_"):
word = self.apply_reverse_transformation(source_token, label)
if word is None:
word = source_token
target_tokens[target_pos] = word
elif start == end - 1:
word = label.replace("$REPLACE_", "")
target_tokens[target_pos] = word
elif label.startswith("$MERGE_"):
target_tokens[target_pos + 1 : target_pos + 1] = [label]
shift_idx += 1
return self.replace_merge_transforms(target_tokens)
def replace_merge_transforms(self, tokens):
if all(not x.startswith("$MERGE_") for x in tokens):
return tokens
target_line = " ".join(tokens)
target_line = target_line.replace(" $MERGE_HYPHEN ", "-")
target_line = target_line.replace(" $MERGE_SPACE ", "")
return target_line.split()
def convert_using_case(self, token, smart_action):
if not smart_action.startswith("$TRANSFORM_CASE_"):
return token
if smart_action.endswith("LOWER"):
return token.lower()
elif smart_action.endswith("UPPER"):
return token.upper()
elif smart_action.endswith("CAPITAL"):
return token.capitalize()
elif smart_action.endswith("CAPITAL_1"):
return token[0] + token[1:].capitalize()
elif smart_action.endswith("UPPER_-1"):
return token[:-1].upper() + token[-1]
else:
return token
def convert_using_verb(self, token, smart_action):
key_word = "$TRANSFORM_VERB_"
if not smart_action.startswith(key_word):
raise Exception(f"Unknown action type {smart_action}")
encoding_part = f"{token}_{smart_action[len(key_word):]}"
decoded_target_word = self.decode_verb_form(encoding_part)
return decoded_target_word
def convert_using_split(self, token, smart_action):
key_word = "$TRANSFORM_SPLIT"
if not smart_action.startswith(key_word):
raise Exception(f"Unknown action type {smart_action}")
target_words = token.split("-")
return " ".join(target_words)
def convert_using_plural(self, token, smart_action):
if smart_action.endswith("PLURAL"):
return token + "s"
elif smart_action.endswith("SINGULAR"):
return token[:-1]
else:
raise Exception(f"Unknown action type {smart_action}")
def apply_reverse_transformation(self, source_token, transform):
if transform.startswith("$TRANSFORM"):
# deal with equal
if transform == "$KEEP":
return source_token
# deal with case
if transform.startswith("$TRANSFORM_CASE"):
return self.convert_using_case(source_token, transform)
# deal with verb
if transform.startswith("$TRANSFORM_VERB"):
return self.convert_using_verb(source_token, transform)
# deal with split
if transform.startswith("$TRANSFORM_SPLIT"):
return self.convert_using_split(source_token, transform)
# deal with single/plural
if transform.startswith("$TRANSFORM_AGREEMENT"):
return self.convert_using_plural(source_token, transform)
# raise exception if not find correct type
raise Exception(f"Unknown action type {transform}")
else:
return source_token
def get_token_action(self, token, index, prob, sugg_token, min_error_probability):
"""Get lost of suggested actions for token."""
# cases when we don't need to do anything
if prob < min_error_probability or sugg_token in [self.UNK, self.PAD, "$KEEP"]:
return None
if (
sugg_token.startswith("$REPLACE_")
or sugg_token.startswith("$TRANSFORM_")
or sugg_token == "$DELETE"
):
start_pos = index
end_pos = index + 1
elif sugg_token.startswith("$APPEND_") or sugg_token.startswith("$MERGE_"):
start_pos = index + 1
end_pos = index + 1
if sugg_token == "$DELETE":
sugg_token_clear = ""
elif sugg_token.startswith("$TRANSFORM_") or sugg_token.startswith("$MERGE_"):
sugg_token_clear = sugg_token[:]
else:
sugg_token_clear = sugg_token[sugg_token.index("_") + 1 :]
return start_pos - 1, end_pos - 1, sugg_token_clear, prob
class GrammarErrorCorrectionPipeline(Pipeline, GectorBase):
def _sanitize_parameters(self, **kwargs):
preprocess_kwargs = {
"max_len": int(kwargs.get("max_len", 50)),
"lowercase_tokens": bool(kwargs.get("lowercase_tokens", False)),
}
forward_kwargs = {
"iterations": int(kwargs.get("iterations", 1)),
"max_len": int(kwargs.get("max_len", 50)),
"min_len": int(kwargs.get("min_len", 3)),
"min_error_probability": float(kwargs.get("min_error_probability", 0.0)),
}
postprocess_kwargs = {}
return preprocess_kwargs, forward_kwargs, postprocess_kwargs
def add_word_offsets(self, tokenized_input):
word_ids = tokenized_input.word_ids()
offsets = [i for i, x in enumerate(word_ids) if i == 0 or x != word_ids[i - 1]]
if self.framework == TensorType.PYTORCH:
import torch
offsets = torch.tensor([offsets], dtype=torch.long)
mask = torch.ones_like(offsets)
tokenized_input["word_offsets"] = offsets
tokenized_input["word_mask"] = mask
return tokenized_input
def preprocess(self, model_input, **kwargs):
tokens = [self.START_TOKEN] + model_input.split(self.DELIMINTER)
tokenized_input = self.tokenizer(
tokens,
max_length=kwargs.get("max_len"),
add_special_tokens=False,
truncation=True,
is_split_into_words=True,
return_token_type_ids=True,
return_tensors=self.framework,
)
tokenized_input["oriignal_tokens"] = tokens[1:]
tokenized_input = self.add_word_offsets(tokenized_input)
return tokenized_input
def _forward_iterative(self, batch, **forward_kwargs):
oriignal_tokens = batch.pop("oriignal_tokens")
model_outputs = self.model(**batch)
error_probs = model_outputs.max_error_probabilities.numpy()
class_probabilities_correct = model_outputs.class_probabilities_correct.numpy()
all_probabilities = np.amax(class_probabilities_correct, axis=-1)
all_idxs = np.argmax(class_probabilities_correct, axis=-1)
all_results = []
noop_index = self.model.config.detect_label2id.get("$CORRECT")
for tokens, probabilities, idxs, error_prob in zip(
oriignal_tokens, all_probabilities, all_idxs, error_probs
):
length = min(len(tokens), forward_kwargs.get("max_len"))
edits = []
# skip whole sentences if there no errors
if max(idxs) == 0:
all_results.append(tokens)
continue
# skip whole sentence if probability of correctness is not high
if error_prob < forward_kwargs.get("min_error_probability"):
all_results.append(tokens)
continue
for i in range(length + 1):
# because of START token
if i == 0:
token = self.START_TOKEN
else:
token = tokens[i - 1]
# skip if there is no error
if idxs[i] == noop_index:
continue
sugg_token = self.model.config.id2label[str(idxs[i])]
action = self.get_token_action(
token,
i,
probabilities[i],
sugg_token,
forward_kwargs.get("min_error_probability"),
)
if not action:
continue
edits.append(action)
all_results.append(self.get_target_sent_by_edits(tokens, edits))
return all_results
def _forward(self, model_inputs, **forward_kwargs):
outputs = []
for iter in range(forward_kwargs.get("iterations")):
outputs = self._forward_iterative(model_inputs, **forward_kwargs)
return {"output": outputs}
def postprocess(self, model_outputs):
return model_outputs