Spaces:
Runtime error
Runtime error
peter szemraj
commited on
Commit
•
c566631
1
Parent(s):
47de588
:art: format code to black
Browse files- app.py +11 -3
- grammar_improve.py +88 -48
- utils.py +0 -2
app.py
CHANGED
@@ -27,7 +27,14 @@ import os
|
|
27 |
import sys
|
28 |
from os.path import dirname
|
29 |
import nltk
|
30 |
-
from grammar_improve import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
from scratch.grammar_tests import load_ns_checker, neuspell_correct
|
32 |
|
33 |
from utils import (
|
@@ -112,7 +119,8 @@ def get_parser():
|
|
112 |
description="submit a question, GPT model responds"
|
113 |
)
|
114 |
parser.add_argument(
|
115 |
-
"-m",
|
|
|
116 |
required=False,
|
117 |
type=str,
|
118 |
default="ballpark-trivia-L",
|
@@ -170,4 +178,4 @@ if __name__ == "__main__":
|
|
170 |
iface.launch(
|
171 |
share=True,
|
172 |
enable_queue=True, # also allows for dealing with multiple users simultaneously (per newer gradio version)
|
173 |
-
)
|
|
|
27 |
import sys
|
28 |
from os.path import dirname
|
29 |
import nltk
|
30 |
+
from grammar_improve import (
|
31 |
+
load_ns_checker,
|
32 |
+
neuspell_correct,
|
33 |
+
remove_repeated_words,
|
34 |
+
remove_trailing_punctuation,
|
35 |
+
build_symspell_obj,
|
36 |
+
symspeller,
|
37 |
+
)
|
38 |
from scratch.grammar_tests import load_ns_checker, neuspell_correct
|
39 |
|
40 |
from utils import (
|
|
|
119 |
description="submit a question, GPT model responds"
|
120 |
)
|
121 |
parser.add_argument(
|
122 |
+
"-m",
|
123 |
+
"--model",
|
124 |
required=False,
|
125 |
type=str,
|
126 |
default="ballpark-trivia-L",
|
|
|
178 |
iface.launch(
|
179 |
share=True,
|
180 |
enable_queue=True, # also allows for dealing with multiple users simultaneously (per newer gradio version)
|
181 |
+
)
|
grammar_improve.py
CHANGED
@@ -14,7 +14,6 @@ import re
|
|
14 |
from symspellpy.symspellpy import SymSpell
|
15 |
|
16 |
|
17 |
-
|
18 |
def fix_punct_spaces(string):
|
19 |
"""
|
20 |
fix_punct_spaces - replace spaces around punctuation with punctuation. For example, "hello , there" -> "hello, there"
|
@@ -28,9 +27,8 @@ def fix_punct_spaces(string):
|
|
28 |
str, corrected string
|
29 |
"""
|
30 |
|
31 |
-
fix_spaces = re.compile(r
|
32 |
-
string = fix_spaces.sub(lambda x: "{} ".format(
|
33 |
-
x.group(1).replace(" ", "")), string)
|
34 |
return string.strip()
|
35 |
|
36 |
|
@@ -90,9 +88,16 @@ start of SymSpell code
|
|
90 |
"""
|
91 |
|
92 |
|
93 |
-
def symspeller(
|
94 |
-
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
"""
|
97 |
symspeller - a wrapper for the SymSpell class from symspellpy
|
98 |
|
@@ -110,13 +115,21 @@ def symspeller(my_string: str, sym_checker=None, max_dist: int = 3, prefix_lengt
|
|
110 |
if verbose:
|
111 |
print("creating new SymSpell object")
|
112 |
sym_checker = build_symspell_obj(
|
113 |
-
edit_dist=max_dist,
|
|
|
|
|
|
|
|
|
114 |
else:
|
115 |
if verbose:
|
116 |
print("using existing SymSpell object")
|
117 |
# max edit distance per lookup (per single word, not per whole input string)
|
118 |
suggestions = sym_checker.lookup_compound(
|
119 |
-
my_string,
|
|
|
|
|
|
|
|
|
120 |
)
|
121 |
|
122 |
if verbose:
|
@@ -132,7 +145,12 @@ def symspeller(my_string: str, sym_checker=None, max_dist: int = 3, prefix_lengt
|
|
132 |
return first_result._term
|
133 |
|
134 |
|
135 |
-
def build_symspell_obj(
|
|
|
|
|
|
|
|
|
|
|
136 |
"""
|
137 |
build_symspell_obj [build a SymSpell object]
|
138 |
|
@@ -142,18 +160,27 @@ def build_symspell_obj(edit_dist=3, prefix_length=7, dictionary_path=None, bigra
|
|
142 |
Returns:
|
143 |
SymSpell: a SymSpell object
|
144 |
"""
|
145 |
-
dictionary_path =
|
146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
sym_checker = SymSpell(
|
148 |
-
max_dictionary_edit_distance=edit_dist, prefix_length=prefix_length
|
|
|
149 |
# term_index is the column of the term and count_index is the
|
150 |
# column of the term frequency
|
151 |
sym_checker.load_dictionary(dictionary_path, term_index=0, count_index=1)
|
152 |
-
sym_checker.load_bigram_dictionary(
|
153 |
-
bigram_path, term_index=0, count_index=2)
|
154 |
|
155 |
return sym_checker
|
156 |
|
|
|
157 |
"""
|
158 |
NEEDED FOR T5
|
159 |
import torch
|
@@ -167,6 +194,7 @@ gc_model = T5ForConditionalGeneration.from_pretrained(model_name).to(torch_devic
|
|
167 |
|
168 |
"""
|
169 |
|
|
|
170 |
def t5b_correction(prompt: str, korrektor, verbose=False, beams=4):
|
171 |
"""
|
172 |
t5b_correction - correct a string using a text2textgen pipeline model from transformers
|
@@ -186,18 +214,19 @@ def t5b_correction(prompt: str, korrektor, verbose=False, beams=4):
|
|
186 |
p_min_len = int(math.ceil(0.9 * len(prompt)))
|
187 |
p_max_len = int(math.ceil(1.1 * len(prompt)))
|
188 |
if verbose:
|
189 |
-
print(f
|
190 |
-
gcorr_result = korrektor(
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
|
|
199 |
if verbose:
|
200 |
-
print(f
|
201 |
return gcorr_result
|
202 |
|
203 |
|
@@ -244,7 +273,7 @@ def load_ns_checker(customckr=None):
|
|
244 |
|
245 |
def neuspell_correct(input_text: str, checker=None, verbose=False):
|
246 |
"""
|
247 |
-
neuspell_correct - correct a string using neuspell.
|
248 |
note that modificaitons to the checker are needed if doing list-based corrections
|
249 |
|
250 |
Parameters
|
@@ -264,7 +293,7 @@ def neuspell_correct(input_text: str, checker=None, verbose=False):
|
|
264 |
cleaned_txt = fix_punct_spaces(corrected)
|
265 |
|
266 |
if verbose:
|
267 |
-
print(f
|
268 |
return cleaned_txt
|
269 |
|
270 |
|
@@ -310,11 +339,17 @@ def DLA_correct(qphrase: str):
|
|
310 |
return " ".join(full_cor)
|
311 |
|
312 |
|
313 |
-
def correct_grammar(
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
318 |
"""
|
319 |
correct_grammar - correct a string using a text2textgen pipeline model from transformers.
|
320 |
This function is an alternative to the t5b_correction function.
|
@@ -337,21 +372,26 @@ def correct_grammar(input_text: str, tokenizer, model,
|
|
337 |
"""
|
338 |
if len(input_text) < 5:
|
339 |
return input_text
|
340 |
-
max_length = min(int(math.ceil(len(input_text)*1.2)), 128)
|
341 |
-
batch = tokenizer(
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
|
|
|
|
|
|
|
|
|
|
355 |
|
356 |
if isinstance(tgt_text, list):
|
357 |
return tgt_text[0]
|
|
|
14 |
from symspellpy.symspellpy import SymSpell
|
15 |
|
16 |
|
|
|
17 |
def fix_punct_spaces(string):
|
18 |
"""
|
19 |
fix_punct_spaces - replace spaces around punctuation with punctuation. For example, "hello , there" -> "hello, there"
|
|
|
27 |
str, corrected string
|
28 |
"""
|
29 |
|
30 |
+
fix_spaces = re.compile(r"\s*([?!.,]+(?:\s+[?!.,]+)*)\s*")
|
31 |
+
string = fix_spaces.sub(lambda x: "{} ".format(x.group(1).replace(" ", "")), string)
|
|
|
32 |
return string.strip()
|
33 |
|
34 |
|
|
|
88 |
"""
|
89 |
|
90 |
|
91 |
+
def symspeller(
|
92 |
+
my_string: str,
|
93 |
+
sym_checker=None,
|
94 |
+
max_dist: int = 3,
|
95 |
+
prefix_length: int = 7,
|
96 |
+
ignore_non_words=True,
|
97 |
+
dictionary_path: str = None,
|
98 |
+
bigram_path: str = None,
|
99 |
+
verbose=False,
|
100 |
+
):
|
101 |
"""
|
102 |
symspeller - a wrapper for the SymSpell class from symspellpy
|
103 |
|
|
|
115 |
if verbose:
|
116 |
print("creating new SymSpell object")
|
117 |
sym_checker = build_symspell_obj(
|
118 |
+
edit_dist=max_dist,
|
119 |
+
prefix_length=prefix_length,
|
120 |
+
dictionary_path=dictionary_path,
|
121 |
+
bigram_path=bigram_path,
|
122 |
+
)
|
123 |
else:
|
124 |
if verbose:
|
125 |
print("using existing SymSpell object")
|
126 |
# max edit distance per lookup (per single word, not per whole input string)
|
127 |
suggestions = sym_checker.lookup_compound(
|
128 |
+
my_string,
|
129 |
+
max_edit_distance=max_dist,
|
130 |
+
ignore_non_words=ignore_non_words,
|
131 |
+
ignore_term_with_digits=True,
|
132 |
+
transfer_casing=True,
|
133 |
)
|
134 |
|
135 |
if verbose:
|
|
|
145 |
return first_result._term
|
146 |
|
147 |
|
148 |
+
def build_symspell_obj(
|
149 |
+
edit_dist=3,
|
150 |
+
prefix_length=7,
|
151 |
+
dictionary_path=None,
|
152 |
+
bigram_path=None,
|
153 |
+
):
|
154 |
"""
|
155 |
build_symspell_obj [build a SymSpell object]
|
156 |
|
|
|
160 |
Returns:
|
161 |
SymSpell: a SymSpell object
|
162 |
"""
|
163 |
+
dictionary_path = (
|
164 |
+
r"symspell_rsc/frequency_dictionary_en_82_765.txt"
|
165 |
+
if dictionary_path is None
|
166 |
+
else dictionary_path
|
167 |
+
)
|
168 |
+
bigram_path = (
|
169 |
+
r"symspell_rsc/frequency_bigramdictionary_en_243_342.txt"
|
170 |
+
if bigram_path is None
|
171 |
+
else bigram_path
|
172 |
+
)
|
173 |
sym_checker = SymSpell(
|
174 |
+
max_dictionary_edit_distance=edit_dist, prefix_length=prefix_length
|
175 |
+
)
|
176 |
# term_index is the column of the term and count_index is the
|
177 |
# column of the term frequency
|
178 |
sym_checker.load_dictionary(dictionary_path, term_index=0, count_index=1)
|
179 |
+
sym_checker.load_bigram_dictionary(bigram_path, term_index=0, count_index=2)
|
|
|
180 |
|
181 |
return sym_checker
|
182 |
|
183 |
+
|
184 |
"""
|
185 |
NEEDED FOR T5
|
186 |
import torch
|
|
|
194 |
|
195 |
"""
|
196 |
|
197 |
+
|
198 |
def t5b_correction(prompt: str, korrektor, verbose=False, beams=4):
|
199 |
"""
|
200 |
t5b_correction - correct a string using a text2textgen pipeline model from transformers
|
|
|
214 |
p_min_len = int(math.ceil(0.9 * len(prompt)))
|
215 |
p_max_len = int(math.ceil(1.1 * len(prompt)))
|
216 |
if verbose:
|
217 |
+
print(f"setting min to {p_min_len} and max to {p_max_len}\n")
|
218 |
+
gcorr_result = korrektor(
|
219 |
+
f"grammar: {prompt}",
|
220 |
+
return_text=True,
|
221 |
+
clean_up_tokenization_spaces=True,
|
222 |
+
num_beams=beams,
|
223 |
+
max_length=p_max_len,
|
224 |
+
repetition_penalty=1.3,
|
225 |
+
length_penalty=0.2,
|
226 |
+
no_repeat_ngram_size=3,
|
227 |
+
)
|
228 |
if verbose:
|
229 |
+
print(f"grammar correction result: \n\t{gcorr_result}\n")
|
230 |
return gcorr_result
|
231 |
|
232 |
|
|
|
273 |
|
274 |
def neuspell_correct(input_text: str, checker=None, verbose=False):
|
275 |
"""
|
276 |
+
neuspell_correct - correct a string using neuspell.
|
277 |
note that modificaitons to the checker are needed if doing list-based corrections
|
278 |
|
279 |
Parameters
|
|
|
293 |
cleaned_txt = fix_punct_spaces(corrected)
|
294 |
|
295 |
if verbose:
|
296 |
+
print(f"neuspell correction result: \n\t{cleaned_txt}\n")
|
297 |
return cleaned_txt
|
298 |
|
299 |
|
|
|
339 |
return " ".join(full_cor)
|
340 |
|
341 |
|
342 |
+
def correct_grammar(
|
343 |
+
input_text: str,
|
344 |
+
tokenizer,
|
345 |
+
model,
|
346 |
+
n_results: int = 1,
|
347 |
+
beams: int = 8,
|
348 |
+
temp=1,
|
349 |
+
uniq_ngrams=2,
|
350 |
+
rep_penalty=1.5,
|
351 |
+
device="cpu",
|
352 |
+
):
|
353 |
"""
|
354 |
correct_grammar - correct a string using a text2textgen pipeline model from transformers.
|
355 |
This function is an alternative to the t5b_correction function.
|
|
|
372 |
"""
|
373 |
if len(input_text) < 5:
|
374 |
return input_text
|
375 |
+
max_length = min(int(math.ceil(len(input_text) * 1.2)), 128)
|
376 |
+
batch = tokenizer(
|
377 |
+
[input_text],
|
378 |
+
truncation=True,
|
379 |
+
padding="max_length",
|
380 |
+
max_length=max_length,
|
381 |
+
return_tensors="pt",
|
382 |
+
).to(device)
|
383 |
+
translated = model.generate(
|
384 |
+
**batch,
|
385 |
+
max_length=max_length,
|
386 |
+
min_length=min(10, len(input_text)),
|
387 |
+
no_repeat_ngram_size=uniq_ngrams,
|
388 |
+
repetition_penalty=rep_penalty,
|
389 |
+
num_beams=beams,
|
390 |
+
num_return_sequences=n_results,
|
391 |
+
temperature=temp,
|
392 |
+
)
|
393 |
+
|
394 |
+
tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
|
395 |
|
396 |
if isinstance(tgt_text, list):
|
397 |
return tgt_text[0]
|
utils.py
CHANGED
@@ -39,7 +39,6 @@ def print_spacer(n=1):
|
|
39 |
print("\n -------- " * n)
|
40 |
|
41 |
|
42 |
-
|
43 |
def fast_scandir(dirname: str):
|
44 |
"""
|
45 |
fast_scandir [an os.path-based means to return all subfolders in a given filepath]
|
@@ -350,7 +349,6 @@ def dl_extract_zip(
|
|
350 |
return extract_loc
|
351 |
|
352 |
|
353 |
-
|
354 |
def cleantxt_wrap(ugly_text):
|
355 |
"""
|
356 |
cleantxt_wrap - applies the clean function to a string.
|
|
|
39 |
print("\n -------- " * n)
|
40 |
|
41 |
|
|
|
42 |
def fast_scandir(dirname: str):
|
43 |
"""
|
44 |
fast_scandir [an os.path-based means to return all subfolders in a given filepath]
|
|
|
349 |
return extract_loc
|
350 |
|
351 |
|
|
|
352 |
def cleantxt_wrap(ugly_text):
|
353 |
"""
|
354 |
cleantxt_wrap - applies the clean function to a string.
|