Spaces:
Runtime error
Runtime error
peter szemraj
committed on
Commit
•
7b10a08
1
Parent(s):
cfd67f6
:truck: consolidate grammar-related functions to one file
Browse files- grammar_improve.py +123 -4
- utils.py +0 -112
grammar_improve.py
CHANGED
@@ -10,7 +10,7 @@ import math
|
|
10 |
from cleantext import clean
|
11 |
import time
|
12 |
import re
|
13 |
-
|
14 |
|
15 |
|
16 |
def fix_punct_spaces(string):
|
@@ -45,6 +45,126 @@ def split_sentences(text: str):
|
|
45 |
return re.split(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s", text)
|
46 |
|
47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
def t5b_correction(prompt: str, korrektor, verbose=False, beams=4):
|
49 |
"""
|
50 |
t5b_correction - correct a string using a text2textgen pipeline model from transformers
|
@@ -100,7 +220,7 @@ def disp_neuspell_chkrs():
|
|
100 |
|
101 |
def load_ns_checker(customckr=None):
|
102 |
"""
|
103 |
-
load_ns_checker - helper function, load a neuspell checker from huggingface transformers
|
104 |
|
105 |
Args:
|
106 |
customckr (neuspell.NeuSpell): [neuspell checker object], optional, if not provided, will load the default checker
|
@@ -211,8 +331,7 @@ def correct_grammar(input_text: str, tokenizer, model,
|
|
211 |
|
212 |
Returns
|
213 |
-------
|
214 |
-
|
215 |
-
[description]
|
216 |
"""
|
217 |
if len(input_text) < 5:
|
218 |
return input_text
|
|
|
10 |
from cleantext import clean
|
11 |
import time
|
12 |
import re
|
13 |
+
|
14 |
|
15 |
|
16 |
def fix_punct_spaces(string):
|
|
|
45 |
return re.split(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s", text)
|
46 |
|
47 |
|
48 |
+
def remove_repeated_words(bot_response):
    """
    remove_repeated_words - remove repeated words from a string, returning only the first instance of each word

    Parameters
    ----------
    bot_response : str
        string to remove repeated words from

    Returns
    -------
    str
        string containing the first instance of each word, in original order
    """
    # dict.fromkeys preserves insertion order (Python 3.7+) and gives O(1)
    # membership checks, replacing the original O(n^2) list-membership loop.
    # Comparison is exact and case-sensitive, matching the original behavior.
    return " ".join(dict.fromkeys(bot_response.split()))
|
68 |
+
|
69 |
+
|
70 |
+
def remove_trailing_punctuation(text: str, fuLL_strip=False):
    """
    remove_trailing_punctuation - strip punctuation from the ends of a string.
    Purpose is to seem more natural to end users.

    NOTE: str.strip removes the characters from the *leading* end as well as
    the trailing end. Parameter name `fuLL_strip` (typo) is kept as-is for
    backward compatibility with existing callers.

    Args:
        text (str): string to be cleaned
        fuLL_strip (bool, optional): also strip '?' and '!' in addition to
            '.,;:'. Defaults to False.

    Returns:
        str: cleaned string
    """
    strip_chars = "?!.,;:" if fuLL_strip else ".,;:"
    return text.strip(strip_chars)
|
84 |
+
|
85 |
+
|
86 |
+
"""
|
87 |
+
start of SymSpell code
|
88 |
+
"""
|
89 |
+
|
90 |
+
|
91 |
+
def symspeller(my_string: str, sym_checker=None, max_dist: int = 3, prefix_length: int = 7,
               ignore_non_words=True,
               dictionary_path: str = None, bigram_path: str = None, verbose=False):
    """
    symspeller - a wrapper for the SymSpell class from symspellpy

    Parameters
    ----------
    my_string : str, required, the string to be checked (must be non-empty)
    sym_checker : SymSpell, optional, default=None, the SymSpell object to use;
        a new one is built via build_symspell_obj() if not provided
    max_dist : int, optional, default=3, the maximum edit distance to look for replacements
    prefix_length : int, optional, default=7, prefix length used when building a new checker
    ignore_non_words : bool, optional, default=True, passed through to lookup_compound
    dictionary_path : str, optional, custom unigram dictionary for a new checker
    bigram_path : str, optional, custom bigram dictionary for a new checker
    verbose : bool, optional, default=False, print progress and suggestion info

    Returns
    -------
    str : the corrected string (cleaned original if no suggestions were found)

    Raises
    ------
    ValueError : if my_string is empty
    """
    # raise instead of assert: assertions are stripped under `python -O`,
    # which would silently disable this input validation
    if len(my_string) < 1:
        raise ValueError("entered string for correction is empty")

    if sym_checker is None:
        # need to create a new class object. user can specify their own dictionary and bigram files
        if verbose:
            print("creating new SymSpell object")
        sym_checker = build_symspell_obj(
            edit_dist=max_dist, prefix_length=prefix_length, dictionary_path=dictionary_path, bigram_path=bigram_path,)
    else:
        if verbose:
            print("using existing SymSpell object")
    # max edit distance per lookup (per single word, not per whole input string)
    suggestions = sym_checker.lookup_compound(
        my_string, max_edit_distance=max_dist, ignore_non_words=ignore_non_words, ignore_term_with_digits=True, transfer_casing=True,
    )

    if verbose:
        print(f"{len(suggestions)} suggestions found")
        print(f"the original string is:\n\t{my_string}")
        sug_list = [sug.term for sug in suggestions]
        print(f"suggestions:\n\t{sug_list}\n")

    if len(suggestions) < 1:
        return clean(my_string)  # no correction because no suggestions
    else:
        first_result = suggestions[0]  # first result is the most likely
        # use the public .term attribute (as the verbose path above already
        # does) instead of the private ._term
        return first_result.term
|
131 |
+
|
132 |
+
|
133 |
+
def build_symspell_obj(edit_dist=3, prefix_length=7, dictionary_path=None, bigram_path=None,):
    """
    build_symspell_obj - construct and initialize a SymSpell checker object.

    Args:
        edit_dist (int, optional): max dictionary edit distance. Defaults to 3.
        prefix_length (int, optional): prefix length for the SymSpell index. Defaults to 7.
        dictionary_path (str, optional): path to a unigram frequency dictionary;
            falls back to the bundled symspell_rsc file when None.
        bigram_path (str, optional): path to a bigram frequency dictionary;
            falls back to the bundled symspell_rsc file when None.

    Returns:
        SymSpell: a SymSpell object with both dictionaries loaded
    """
    if dictionary_path is None:
        dictionary_path = r"symspell_rsc/frequency_dictionary_en_82_765.txt"
    if bigram_path is None:
        bigram_path = r"symspell_rsc/frequency_bigramdictionary_en_243_342.txt"

    checker = SymSpell(
        max_dictionary_edit_distance=edit_dist, prefix_length=prefix_length)
    # term_index is the column of the term and count_index is the column of
    # the term frequency in each dictionary file
    checker.load_dictionary(dictionary_path, term_index=0, count_index=1)
    checker.load_bigram_dictionary(bigram_path, term_index=0, count_index=2)

    return checker
|
154 |
+
|
155 |
+
"""
|
156 |
+
NEEDED FOR T5
|
157 |
+
import torch
|
158 |
+
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
159 |
+
|
160 |
+
model_name = 'deep-learning-analytics/GrammarCorrector'
|
161 |
+
# torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
162 |
+
torch_device = 'cpu'
|
163 |
+
gc_tokenizer = T5Tokenizer.from_pretrained(model_name)
|
164 |
+
gc_model = T5ForConditionalGeneration.from_pretrained(model_name).to(torch_device)
|
165 |
+
|
166 |
+
"""
|
167 |
+
|
168 |
def t5b_correction(prompt: str, korrektor, verbose=False, beams=4):
|
169 |
"""
|
170 |
t5b_correction - correct a string using a text2textgen pipeline model from transformers
|
|
|
220 |
|
221 |
def load_ns_checker(customckr=None):
|
222 |
"""
|
223 |
+
load_ns_checker - helper function, load / "set up" a neuspell checker from huggingface transformers
|
224 |
|
225 |
Args:
|
226 |
customckr (neuspell.NeuSpell): [neuspell checker object], optional, if not provided, will load the default checker
|
|
|
331 |
|
332 |
Returns
|
333 |
-------
|
334 |
+
str, corrected string (or list of strings if n_results > 1)
|
|
|
335 |
"""
|
336 |
if len(input_text) < 5:
|
337 |
return input_text
|
utils.py
CHANGED
@@ -39,68 +39,6 @@ def print_spacer(n=1):
|
|
39 |
print("\n -------- " * n)
|
40 |
|
41 |
|
42 |
-
def build_symspell_obj(edit_dist=3, prefix_length=7, dictionary_path=None, bigram_path=None,):
|
43 |
-
"""
|
44 |
-
build_symspell_obj [build a SymSpell object]
|
45 |
-
|
46 |
-
Args:
|
47 |
-
verbose (bool, optional): Defaults to False.
|
48 |
-
|
49 |
-
Returns:
|
50 |
-
SymSpell: a SymSpell object
|
51 |
-
"""
|
52 |
-
dictionary_path = r"symspell_rsc/frequency_dictionary_en_82_765.txt" if dictionary_path is None else dictionary_path
|
53 |
-
bigram_path = r"symspell_rsc/frequency_bigramdictionary_en_243_342.txt" if bigram_path is None else bigram_path
|
54 |
-
sym_checker = SymSpell(
|
55 |
-
max_dictionary_edit_distance=edit_dist, prefix_length=prefix_length)
|
56 |
-
# term_index is the column of the term and count_index is the
|
57 |
-
# column of the term frequency
|
58 |
-
sym_checker.load_dictionary(dictionary_path, term_index=0, count_index=1)
|
59 |
-
sym_checker.load_bigram_dictionary(bigram_path, term_index=0, count_index=2)
|
60 |
-
|
61 |
-
return sym_checker
|
62 |
-
|
63 |
-
|
64 |
-
def symspeller(my_string: str, sym_checker=None, max_dist: int = 3, prefix_length: int = 7,
|
65 |
-
ignore_non_words=True,
|
66 |
-
dictionary_path:str=None, bigram_path:str=None, verbose=False):
|
67 |
-
"""
|
68 |
-
symspeller - a wrapper for the SymSpell class from symspellpy
|
69 |
-
|
70 |
-
Parameters
|
71 |
-
----------
|
72 |
-
my_string : str, required, default=None, the string to be checked
|
73 |
-
sym_checker : SymSpell, optional, default=None, the SymSpell object to use
|
74 |
-
max_dist : int, optional, default=3, the maximum distance to look for replacements
|
75 |
-
"""
|
76 |
-
|
77 |
-
assert len(my_string) > 0, "entered string for correction is empty"
|
78 |
-
|
79 |
-
if sym_checker is None:
|
80 |
-
# need to create a new class object. user can specify their own dictionary and bigram files
|
81 |
-
if verbose:
|
82 |
-
print("creating new SymSpell object")
|
83 |
-
sym_checker = build_symspell_obj(
|
84 |
-
edit_dist=max_dist, prefix_length=prefix_length, dictionary_path=dictionary_path, bigram_path=bigram_path,)
|
85 |
-
else:
|
86 |
-
if verbose: print("using existing SymSpell object")
|
87 |
-
# max edit distance per lookup (per single word, not per whole input string)
|
88 |
-
suggestions = sym_checker.lookup_compound(
|
89 |
-
my_string, max_edit_distance=max_dist, ignore_non_words=ignore_non_words, ignore_term_with_digits=True, transfer_casing=True,
|
90 |
-
)
|
91 |
-
|
92 |
-
if verbose:
|
93 |
-
print(f"{len(suggestions)} suggestions found")
|
94 |
-
print(f"the original string is:\n\t{my_string}")
|
95 |
-
sug_list = [sug.term for sug in suggestions]
|
96 |
-
print(f"suggestions:\n\t{sug_list}\n")
|
97 |
-
|
98 |
-
if len(suggestions) < 1:
|
99 |
-
return clean(my_string) # no correction because no suggestions
|
100 |
-
else:
|
101 |
-
first_result = suggestions[0] # first result is the most likely
|
102 |
-
return first_result._term
|
103 |
-
|
104 |
|
105 |
def fast_scandir(dirname: str):
|
106 |
"""
|
@@ -412,56 +350,6 @@ def dl_extract_zip(
|
|
412 |
return extract_loc
|
413 |
|
414 |
|
415 |
-
def remove_repeated_words(bot_response):
|
416 |
-
"""
|
417 |
-
remove_repeated_words - remove repeated words from a string, returning only the first instance of each word
|
418 |
-
|
419 |
-
Parameters
|
420 |
-
----------
|
421 |
-
bot_response : str
|
422 |
-
string to remove repeated words from
|
423 |
-
|
424 |
-
Returns
|
425 |
-
-------
|
426 |
-
str
|
427 |
-
string containing the first instance of each word
|
428 |
-
"""
|
429 |
-
words = bot_response.split()
|
430 |
-
unique_words = []
|
431 |
-
for word in words:
|
432 |
-
if word not in unique_words:
|
433 |
-
unique_words.append(word)
|
434 |
-
return " ".join(unique_words)
|
435 |
-
|
436 |
-
|
437 |
-
def remove_trailing_punctuation(text: str, fuLL_strip=False):
|
438 |
-
"""
|
439 |
-
remove_trailing_punctuation - remove trailing punctuation from a string
|
440 |
-
|
441 |
-
Args:
|
442 |
-
text (str): [string to be cleaned]
|
443 |
-
|
444 |
-
Returns:
|
445 |
-
[str]: [cleaned string]
|
446 |
-
"""
|
447 |
-
if fuLL_strip:
|
448 |
-
return text.strip("?!.,;:")
|
449 |
-
else:
|
450 |
-
return text.strip(".,;:")
|
451 |
-
|
452 |
-
|
453 |
-
def split_sentences(text: str):
|
454 |
-
"""
|
455 |
-
split_sentences - split a string into a list of sentences that keep their ending punctuation
|
456 |
-
|
457 |
-
Args:
|
458 |
-
text (str): [string to be split]
|
459 |
-
|
460 |
-
Returns:
|
461 |
-
[list]: [list of sentences]
|
462 |
-
"""
|
463 |
-
return re.split(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s", text)
|
464 |
-
|
465 |
|
466 |
def cleantxt_wrap(ugly_text):
|
467 |
"""
|
|
|
39 |
print("\n -------- " * n)
|
40 |
|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
def fast_scandir(dirname: str):
|
44 |
"""
|
|
|
350 |
return extract_loc
|
351 |
|
352 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
353 |
|
354 |
def cleantxt_wrap(ugly_text):
|
355 |
"""
|