"""Hugging Face pipeline for GECToR-style grammatical error correction.

The model tags every token with an edit label ($KEEP, $DELETE, $APPEND_*,
$REPLACE_*, $TRANSFORM_*, $MERGE_*); the tags are then applied to the source
tokens, optionally over several refinement iterations.
"""

import numpy as np
from transformers import Pipeline, TensorType


class GectorBase:
    DELIMITER = " "
    START_TOKEN = "$START"
    PAD = ""
    UNK = ""
    REPLACEMENTS = {
        "''": '"',
        "--": "—",
        "`": "'",
        "'ve": "' ve",
    }

    def decode_verb_form(self, original):
        # Look up the target verb form in the model's verb-form vocabulary.
        return self.model.config.verb_form_vocab["decode"].get(original)

    def get_target_sent_by_edits(self, source_tokens, edits):
        """Apply a list of (start, end, label, prob) edits to the source tokens."""
        target_tokens = source_tokens[:]
        shift_idx = 0
        for edit in edits:
            start, end, label, _ = edit
            target_pos = start + shift_idx
            source_token = (
                target_tokens[target_pos]
                if len(target_tokens) > target_pos >= 0
                else ""
            )
            if label == "":
                # $DELETE was cleared to an empty label: drop the token.
                del target_tokens[target_pos]
                shift_idx -= 1
            elif start == end:
                # Zero-width span: insert the appended word.
                word = label.replace("$APPEND_", "")
                target_tokens[target_pos:target_pos] = [word]
                shift_idx += 1
            elif label.startswith("$TRANSFORM_"):
                word = self.apply_reverse_transformation(source_token, label)
                if word is None:
                    word = source_token
                target_tokens[target_pos] = word
            elif start == end - 1:
                word = label.replace("$REPLACE_", "")
                target_tokens[target_pos] = word
            elif label.startswith("$MERGE_"):
                # Keep the merge marker in place; it is resolved below.
                target_tokens[target_pos + 1 : target_pos + 1] = [label]
                shift_idx += 1
        return self.replace_merge_transforms(target_tokens)

    def replace_merge_transforms(self, tokens):
        if all(not x.startswith("$MERGE_") for x in tokens):
            return tokens
        target_line = " ".join(tokens)
        target_line = target_line.replace(" $MERGE_HYPHEN ", "-")
        target_line = target_line.replace(" $MERGE_SPACE ", "")
        return target_line.split()

    def convert_using_case(self, token, smart_action):
        if not smart_action.startswith("$TRANSFORM_CASE_"):
            return token
        if smart_action.endswith("LOWER"):
            return token.lower()
        elif smart_action.endswith("UPPER"):
            return token.upper()
        elif smart_action.endswith("CAPITAL"):
            return token.capitalize()
        elif smart_action.endswith("CAPITAL_1"):
            return token[0] + token[1:].capitalize()
        elif smart_action.endswith("UPPER_-1"):
            return token[:-1].upper() + token[-1]
        else:
            return token

    def convert_using_verb(self, token, smart_action):
        key_word = "$TRANSFORM_VERB_"
        if not smart_action.startswith(key_word):
            raise Exception(f"Unknown action type {smart_action}")
        encoding_part = f"{token}_{smart_action[len(key_word):]}"
        return self.decode_verb_form(encoding_part)

    def convert_using_split(self, token, smart_action):
        key_word = "$TRANSFORM_SPLIT"
        if not smart_action.startswith(key_word):
            raise Exception(f"Unknown action type {smart_action}")
        target_words = token.split("-")
        return " ".join(target_words)

    def convert_using_plural(self, token, smart_action):
        if smart_action.endswith("PLURAL"):
            return token + "s"
        elif smart_action.endswith("SINGULAR"):
            return token[:-1]
        else:
            raise Exception(f"Unknown action type {smart_action}")

    def apply_reverse_transformation(self, source_token, transform):
        if transform.startswith("$TRANSFORM"):
            # deal with equal
            if transform == "$KEEP":
                return source_token
            # deal with case
            if transform.startswith("$TRANSFORM_CASE"):
                return self.convert_using_case(source_token, transform)
            # deal with verb
            if transform.startswith("$TRANSFORM_VERB"):
                return self.convert_using_verb(source_token, transform)
            # deal with split
            if transform.startswith("$TRANSFORM_SPLIT"):
                return self.convert_using_split(source_token, transform)
            # deal with singular/plural
            if transform.startswith("$TRANSFORM_AGREEMENT"):
                return self.convert_using_plural(source_token, transform)
            # unknown transform type
            raise Exception(f"Unknown action type {transform}")
        else:
            return source_token

    def get_token_action(
        self,
        token,
        index,
        prob,
        sugg_token,
        min_error_probability,
    ):
        """Return the suggested edit action for a token, or None."""
        # Cases where nothing needs to be done.
        if prob < min_error_probability or sugg_token in [self.UNK, self.PAD, "$KEEP"]:
            return None

        if (
            sugg_token.startswith("$REPLACE_")
            or sugg_token.startswith("$TRANSFORM_")
            or sugg_token == "$DELETE"
        ):
            start_pos = index
            end_pos = index + 1
        elif sugg_token.startswith("$APPEND_") or sugg_token.startswith("$MERGE_"):
            start_pos = index + 1
            end_pos = index + 1

        if sugg_token == "$DELETE":
            sugg_token_clear = ""
        elif sugg_token.startswith("$TRANSFORM_") or sugg_token.startswith("$MERGE_"):
            sugg_token_clear = sugg_token[:]
        else:
            sugg_token_clear = sugg_token[sugg_token.index("_") + 1 :]

        # Shift positions back by one to account for the START token.
        return start_pos - 1, end_pos - 1, sugg_token_clear, prob


class GrammarErrorCorrectionPipeline(Pipeline, GectorBase):
    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {
            "max_len": int(kwargs.get("max_len", 50)),
            "lowercase_tokens": bool(kwargs.get("lowercase_tokens", False)),
        }
        forward_kwargs = {
            "iterations": int(kwargs.get("iterations", 1)),
            "max_len": int(kwargs.get("max_len", 50)),
            "min_len": int(kwargs.get("min_len", 3)),
            "min_error_probability": float(kwargs.get("min_error_probability", 0.0)),
        }
        postprocess_kwargs = {}
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def add_word_offsets(self, tokenized_input):
        # Record the index of the first sub-token of each word, so that
        # predictions can be read off at word boundaries.
        word_ids = tokenized_input.word_ids()
        offsets = [i for i, x in enumerate(word_ids) if i == 0 or x != word_ids[i - 1]]
        if self.framework == TensorType.PYTORCH:
            import torch

            offsets = torch.tensor([offsets], dtype=torch.long)
            mask = torch.ones_like(offsets)
            tokenized_input["word_offsets"] = offsets
            tokenized_input["word_mask"] = mask
        return tokenized_input

    def preprocess(self, model_input, **kwargs):
        tokens = [self.START_TOKEN] + model_input.split(self.DELIMITER)
        tokenized_input = self.tokenizer(
            tokens,
            max_length=kwargs.get("max_len"),
            add_special_tokens=False,
            truncation=True,
            is_split_into_words=True,
            return_token_type_ids=True,
            return_tensors=self.framework,
        )
        # Keep the raw words (without START); wrapped in a list so it lines up
        # with the batch dimension when iterated in _forward_iterative.
        tokenized_input["original_tokens"] = [tokens[1:]]
        tokenized_input = self.add_word_offsets(tokenized_input)
        return tokenized_input

    def _forward_iterative(self, batch, **forward_kwargs):
        original_tokens = batch.pop("original_tokens")
        model_outputs = self.model(**batch)
        error_probs = model_outputs.max_error_probabilities.numpy()
        class_probabilities_correct = model_outputs.class_probabilities_correct.numpy()
        all_probabilities = np.amax(class_probabilities_correct, axis=-1)
        all_idxs = np.argmax(class_probabilities_correct, axis=-1)

        all_results = []
        noop_index = self.model.config.detect_label2id.get("$CORRECT")
        for tokens, probabilities, idxs, error_prob in zip(
            original_tokens, all_probabilities, all_idxs, error_probs
        ):
            length = min(len(tokens), forward_kwargs.get("max_len"))
            edits = []

            # Skip the whole sentence if no edits are predicted.
            if max(idxs) == 0:
                all_results.append(tokens)
                continue

            # Skip the whole sentence if its error probability is below the threshold.
            if error_prob < forward_kwargs.get("min_error_probability"):
                all_results.append(tokens)
                continue

            for i in range(length + 1):  # +1 accounts for the START token
                if i == 0:
                    token = self.START_TOKEN
                else:
                    token = tokens[i - 1]
                # Skip positions where no error is predicted.
                if idxs[i] == noop_index:
                    continue
                sugg_token = self.model.config.id2label[str(idxs[i])]
                action = self.get_token_action(
                    token,
                    i,
                    probabilities[i],
                    sugg_token,
                    forward_kwargs.get("min_error_probability"),
                )
                if not action:
                    continue
                edits.append(action)
            all_results.append(self.get_target_sent_by_edits(tokens, edits))
        return all_results

    def _forward(self, model_inputs, **forward_kwargs):
        outputs = []
        for iteration in range(forward_kwargs.get("iterations")):
            outputs = self._forward_iterative(model_inputs, **forward_kwargs)
            # Feed the corrected sentence back in so the next pass can refine
            # it further; GECToR applies its edits iteratively.
            if iteration + 1 < forward_kwargs.get("iterations"):
                model_inputs = self.preprocess(
                    self.DELIMITER.join(outputs[0]),
                    max_len=forward_kwargs.get("max_len"),
                )
        return {"output": outputs}

    def postprocess(self, model_outputs):
        return model_outputs
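
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The task string and checkpoint id below
# are placeholders, not real Hub artifacts; this assumes a GECToR-style
# checkpoint that registers this class as a custom pipeline.
#
#   from transformers import pipeline
#
#   corrector = pipeline(
#       "grammar-error-correction",          # assumed custom task name
#       model="your-org/gector-checkpoint",  # hypothetical model id
#       trust_remote_code=True,
#   )
#   result = corrector(
#       "she go to school every days",
#       iterations=3,
#       min_error_probability=0.5,
#   )
#   # The pipeline returns {"output": [token_list, ...]}; join the first
#   # (and only) sentence's tokens back into a string.
#   print(" ".join(result["output"][0]))
# ---------------------------------------------------------------------------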