""" grammar_improve.py - this .py script contains functions to improve the grammar of a user's input or the models output. """ from datetime import datetime import os import pprint as pp from neuspell import BertChecker, SclstmChecker import neuspell import math from cleantext import clean import time import re import sys from symspellpy.symspellpy import SymSpell from utils import suppress_stdout def detect_propers(text: str): """ detect_propers - detect if a string contains proper nouns Args: text (str): [string to be checked] Returns: [bool]: [True if string contains proper nouns] """ pat = re.compile(r"(?:\w+['’])?\w+(?:-(?:\w+['’])?\w+)*") return bool(pat.search(text)) def fix_punct_spaces(string): """ fix_punct_spaces - replace spaces around punctuation with punctuation. For example, "hello , there" -> "hello, there" Parameters ---------- string : str, required, input string to be corrected Returns ------- str, corrected string """ fix_spaces = re.compile(r"\s*([?!.,]+(?:\s+[?!.,]+)*)\s*") string = fix_spaces.sub(lambda x: "{} ".format(x.group(1).replace(" ", "")), string) return string.strip() def split_sentences(text: str): """ split_sentences - split a string into a list of sentences that keep their ending punctuation. powered by regex witchcraft Args: text (str): [string to be split] Returns: [list]: [list of strings] """ return re.split(r"(? 0, "entered string for correction is empty" if sym_checker is None: # need to create a new class object. user can specify their own dictionary and bigram files if verbose: print("creating new SymSpell object") sym_checker = build_symspell_obj( edit_dist=max_dist, prefix_length=prefix_length, dictionary_path=dictionary_path, bigram_path=bigram_path, ) else: if verbose: print("using existing SymSpell object") # max edit distance per lookup (per single word, not per whole input string) suggestions = sym_checker.lookup_compound( my_string, max_edit_distance=max_dist, ignore_non_words=ignore_non_words, ignore_term_with_digits=True, transfer_casing=True, ) if verbose: print(f"{len(suggestions)} suggestions found") print(f"the original string is:\n\t{my_string}") sug_list = [sug.term for sug in suggestions] print(f"suggestions:\n\t{sug_list}\n") if len(suggestions) < 1: return clean(my_string) # no correction because no suggestions else: first_result = suggestions[0] # first result is the most likely return first_result._term def build_symspell_obj( edit_dist=2, prefix_length=7, dictionary_path=None, bigram_path=None, ): """ build_symspell_obj [build a SymSpell object] Args: verbose (bool, optional): Defaults to False. Returns: SymSpell: a SymSpell object """ dictionary_path = ( r"symspell_rsc/frequency_dictionary_en_82_765.txt" if dictionary_path is None else dictionary_path ) bigram_path = ( r"symspell_rsc/frequency_bigramdictionary_en_243_342.txt" if bigram_path is None else bigram_path ) sym_checker = SymSpell( max_dictionary_edit_distance=edit_dist + 2, prefix_length=prefix_length ) # term_index is the column of the term and count_index is the # column of the term frequency sym_checker.load_dictionary(dictionary_path, term_index=0, count_index=1) sym_checker.load_bigram_dictionary(bigram_path, term_index=0, count_index=2) return sym_checker """ # if using t5b_correction to check for spelling errors, use this code to initialize the objects import torch from transformers import T5Tokenizer, T5ForConditionalGeneration model_name = 'deep-learning-analytics/GrammarCorrector' # torch_device = 'cuda' if torch.cuda.is_available() else 'cpu' torch_device = 'cpu' gc_tokenizer = T5Tokenizer.from_pretrained(model_name) gc_model = T5ForConditionalGeneration.from_pretrained(model_name).to(torch_device) """ def t5b_correction(prompt: str, korrektor, verbose=False, beams=4): """ t5b_correction - correct a string using a text2textgen pipeline model from transformers Parameters ---------- prompt : str, required, input prompt to be corrected korrektor : transformers.pipeline, required, pipeline object verbose : bool, optional, whether to print the corrected prompt. Defaults to False. beams : int, optional, number of beams to use for the correction. Defaults to 4. Returns ------- str, corrected prompt """ p_min_len = int(math.ceil(0.9 * len(prompt))) p_max_len = int(math.ceil(1.1 * len(prompt))) if verbose: print(f"setting min to {p_min_len} and max to {p_max_len}\n") gcorr_result = korrektor( f"grammar: {prompt}", return_text=True, clean_up_tokenization_spaces=True, num_beams=beams, max_length=p_max_len, repetition_penalty=1.3, length_penalty=0.2, no_repeat_ngram_size=2, ) if verbose: print(f"grammar correction result: \n\t{gcorr_result}\n") return gcorr_result def all_neuspell_chkrs(): """ disp_neuspell_chkrs - display the neuspell checkers available Parameters ---------- None Returns ------- checker_opts - list of checkers available """ checker_opts = dir(neuspell) print(f"\navailable checkers:") pp.pprint(checker_opts, indent=4, compact=True) return checker_opts def load_ns_checker(customckr=None, fast=False): """ load_ns_checker - helper function, load / "set up" a neuspell checker from huggingface transformers Args: customckr (neuspell.NeuSpell): [neuspell checker object], optional, if not provided, will load the default checker Returns: [neuspell.NeuSpell]: [neuspell checker object] """ st = time.perf_counter() # stop all printing to the console with suppress_stdout(): if customckr is None and not fast: checker = BertChecker( pretrained=True ) # load the default checker, has the best balance elif customckr is None and fast: checker = SclstmChecker( pretrained=True ) # this one is faster but not as accurate else: checker = customckr(pretrained=True) rt_min = (time.perf_counter() - st) / 60 # return to standard logging level print(f"\n\nloaded checker in {rt_min} minutes") return checker def neuspell_correct(input_text: str, checker=None, verbose=False): """ neuspell_correct - correct a string using neuspell. note that modificaitons to the checker are needed if doing list-based corrections Parameters ---------- input_text : str, required, input string to be corrected checker : neuspell.NeuSpell, optional, neuspell checker object. Defaults to None. verbose : bool, optional, whether to print the corrected string. Defaults to False. Returns ------- str, corrected string """ if isinstance(input_text, str) and len(input_text) < 4: print(f"input text of {input_text} is too short to be corrected") return input_text if checker is None: print("NOTE - no checker provided, loading default checker") checker = SclstmChecker(pretrained=True) corrected = checker.correct(input_text) cleaned_txt = fix_punct_spaces(corrected) if verbose: print(f"neuspell correction result: \n\t{cleaned_txt}\n") return cleaned_txt def grammarpipe(corrector, qphrase: str): """ gramformer_correct - THE ORIGINAL ONE USED IN PROJECT AND NEEDS TO BE CHANGED. Idea is to correct a string using a text2textgen pipeline model from transformers Args: corrector (transformers.pipeline): [transformers pipeline object, already created w/ relevant model] qphrase (str): [text to be corrected] Returns: [str]: [corrected text] """ if isinstance(qphrase, str) and len(qphrase) < 4: print(f"input text of {qphrase} is too short to be corrected") return qphrase try: corrected = corrector( clean(qphrase), return_text=True, clean_up_tokenization_spaces=True ) return corrected[0]["generated_text"] except Exception as e: print(f"NOTE - failed to correct with grammarpipe:\n {e}") return clean(qphrase) def DLA_correct(qphrase: str): """ DLA_correct - an "overhead" function to call correct_grammar() on a string, allowing for each newline to be corrected individually Args: qphrase (str): [string to be corrected] Returns: str, the list of the corrected strings joined under " " """ if isinstance(qphrase, str) and len(qphrase) < 4: print(f"input text of {qphrase} is too short to be corrected") return qphrase sentences = split_sentences(qphrase) if len(sentences) == 1: corrected = correct_grammar(sentences[0]) return corrected else: full_cor = [] for sen in sentences: corr_sen = correct_grammar(clean(sen)) full_cor.append(corr_sen) return " ".join(full_cor) def correct_grammar( input_text: str, tokenizer, model, n_results: int = 1, beams: int = 8, temp=1, uniq_ngrams=2, rep_penalty=1.5, device="cpu", ): """ correct_grammar - correct a string using a text2textgen pipeline model from transformers. This function is an alternative to the t5b_correction function. Parameters ---------- input_text : str, required, input string to be corrected tokenizer : transformers.T5Tokenizer, required, tokenizer object, already created w/ relevant model model : transformers.T5ForConditionalGeneration, required, model object, already created w/ relevant model n_results : int, optional, number of results to return. Defaults to 1. beams : int, optional, number of beams to use for the correction. Defaults to 8. temp : int, optional, temperature to use for the correction. Defaults to 1. uniq_ngrams : int, optional, number of ngrams to use for the correction. Defaults to 2. rep_penalty : float, optional, penalty to use for the correction. Defaults to 1.5. device : str, optional, device to use for the correction. Defaults to 'cpu'. Returns ------- str, corrected string (or list of strings if n_results > 1) """ st = time.perf_counter() if len(input_text) < 5: return input_text max_length = min(int(math.ceil(len(input_text) * 1.2)), 128) batch = tokenizer( [input_text], truncation=True, padding="max_length", max_length=max_length, return_tensors="pt", ).to(device) translated = model.generate( **batch, max_length=max_length, min_length=min(10, len(input_text)), no_repeat_ngram_size=uniq_ngrams, repetition_penalty=rep_penalty, num_beams=beams, num_return_sequences=n_results, temperature=temp, ) tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True) rt_min = (time.perf_counter() - st) / 60 print(f"\n\ncorrected in {rt_min} minutes") if isinstance(tgt_text, list): return tgt_text[0] else: return tgt_text