peter szemraj committed on
Commit
c566631
1 Parent(s): 47de588

:art: format code to black

Browse files
Files changed (3) hide show
  1. app.py +11 -3
  2. grammar_improve.py +88 -48
  3. utils.py +0 -2
app.py CHANGED
@@ -27,7 +27,14 @@ import os
27
  import sys
28
  from os.path import dirname
29
  import nltk
30
- from grammar_improve import load_ns_checker, neuspell_correct, remove_repeated_words, remove_trailing_punctuation, build_symspell_obj, symspeller
 
 
 
 
 
 
 
31
  from scratch.grammar_tests import load_ns_checker, neuspell_correct
32
 
33
  from utils import (
@@ -112,7 +119,8 @@ def get_parser():
112
  description="submit a question, GPT model responds"
113
  )
114
  parser.add_argument(
115
- "-m", "--model",
 
116
  required=False,
117
  type=str,
118
  default="ballpark-trivia-L",
@@ -170,4 +178,4 @@ if __name__ == "__main__":
170
  iface.launch(
171
  share=True,
172
  enable_queue=True, # also allows for dealing with multiple users simultaneously (per newer gradio version)
173
- )
 
27
  import sys
28
  from os.path import dirname
29
  import nltk
30
+ from grammar_improve import (
31
+ load_ns_checker,
32
+ neuspell_correct,
33
+ remove_repeated_words,
34
+ remove_trailing_punctuation,
35
+ build_symspell_obj,
36
+ symspeller,
37
+ )
38
  from scratch.grammar_tests import load_ns_checker, neuspell_correct
39
 
40
  from utils import (
 
119
  description="submit a question, GPT model responds"
120
  )
121
  parser.add_argument(
122
+ "-m",
123
+ "--model",
124
  required=False,
125
  type=str,
126
  default="ballpark-trivia-L",
 
178
  iface.launch(
179
  share=True,
180
  enable_queue=True, # also allows for dealing with multiple users simultaneously (per newer gradio version)
181
+ )
grammar_improve.py CHANGED
@@ -14,7 +14,6 @@ import re
14
  from symspellpy.symspellpy import SymSpell
15
 
16
 
17
-
18
  def fix_punct_spaces(string):
19
  """
20
  fix_punct_spaces - replace spaces around punctuation with punctuation. For example, "hello , there" -> "hello, there"
@@ -28,9 +27,8 @@ def fix_punct_spaces(string):
28
  str, corrected string
29
  """
30
 
31
- fix_spaces = re.compile(r'\s*([?!.,]+(?:\s+[?!.,]+)*)\s*')
32
- string = fix_spaces.sub(lambda x: "{} ".format(
33
- x.group(1).replace(" ", "")), string)
34
  return string.strip()
35
 
36
 
@@ -90,9 +88,16 @@ start of SymSpell code
90
  """
91
 
92
 
93
- def symspeller(my_string: str, sym_checker=None, max_dist: int = 3, prefix_length: int = 7,
94
- ignore_non_words=True,
95
- dictionary_path: str = None, bigram_path: str = None, verbose=False):
 
 
 
 
 
 
 
96
  """
97
  symspeller - a wrapper for the SymSpell class from symspellpy
98
 
@@ -110,13 +115,21 @@ def symspeller(my_string: str, sym_checker=None, max_dist: int = 3, prefix_lengt
110
  if verbose:
111
  print("creating new SymSpell object")
112
  sym_checker = build_symspell_obj(
113
- edit_dist=max_dist, prefix_length=prefix_length, dictionary_path=dictionary_path, bigram_path=bigram_path,)
 
 
 
 
114
  else:
115
  if verbose:
116
  print("using existing SymSpell object")
117
  # max edit distance per lookup (per single word, not per whole input string)
118
  suggestions = sym_checker.lookup_compound(
119
- my_string, max_edit_distance=max_dist, ignore_non_words=ignore_non_words, ignore_term_with_digits=True, transfer_casing=True,
 
 
 
 
120
  )
121
 
122
  if verbose:
@@ -132,7 +145,12 @@ def symspeller(my_string: str, sym_checker=None, max_dist: int = 3, prefix_lengt
132
  return first_result._term
133
 
134
 
135
- def build_symspell_obj(edit_dist=3, prefix_length=7, dictionary_path=None, bigram_path=None,):
 
 
 
 
 
136
  """
137
  build_symspell_obj [build a SymSpell object]
138
 
@@ -142,18 +160,27 @@ def build_symspell_obj(edit_dist=3, prefix_length=7, dictionary_path=None, bigra
142
  Returns:
143
  SymSpell: a SymSpell object
144
  """
145
- dictionary_path = r"symspell_rsc/frequency_dictionary_en_82_765.txt" if dictionary_path is None else dictionary_path
146
- bigram_path = r"symspell_rsc/frequency_bigramdictionary_en_243_342.txt" if bigram_path is None else bigram_path
 
 
 
 
 
 
 
 
147
  sym_checker = SymSpell(
148
- max_dictionary_edit_distance=edit_dist, prefix_length=prefix_length)
 
149
  # term_index is the column of the term and count_index is the
150
  # column of the term frequency
151
  sym_checker.load_dictionary(dictionary_path, term_index=0, count_index=1)
152
- sym_checker.load_bigram_dictionary(
153
- bigram_path, term_index=0, count_index=2)
154
 
155
  return sym_checker
156
 
 
157
  """
158
  NEEDED FOR T5
159
  import torch
@@ -167,6 +194,7 @@ gc_model = T5ForConditionalGeneration.from_pretrained(model_name).to(torch_devic
167
 
168
  """
169
 
 
170
  def t5b_correction(prompt: str, korrektor, verbose=False, beams=4):
171
  """
172
  t5b_correction - correct a string using a text2textgen pipeline model from transformers
@@ -186,18 +214,19 @@ def t5b_correction(prompt: str, korrektor, verbose=False, beams=4):
186
  p_min_len = int(math.ceil(0.9 * len(prompt)))
187
  p_max_len = int(math.ceil(1.1 * len(prompt)))
188
  if verbose:
189
- print(f'setting min to {p_min_len} and max to {p_max_len}\n')
190
- gcorr_result = korrektor(f"grammar: {prompt}",
191
- return_text=True,
192
- clean_up_tokenization_spaces=True,
193
- num_beams=beams,
194
- max_length=p_max_len,
195
- repetition_penalty=1.3,
196
- length_penalty=0.2,
197
- no_repeat_ngram_size=3,
198
- )
 
199
  if verbose:
200
- print(f'grammar correction result: \n\t{gcorr_result}\n')
201
  return gcorr_result
202
 
203
 
@@ -244,7 +273,7 @@ def load_ns_checker(customckr=None):
244
 
245
  def neuspell_correct(input_text: str, checker=None, verbose=False):
246
  """
247
- neuspell_correct - correct a string using neuspell.
248
  note that modifications to the checker are needed if doing list-based corrections
249
 
250
  Parameters
@@ -264,7 +293,7 @@ def neuspell_correct(input_text: str, checker=None, verbose=False):
264
  cleaned_txt = fix_punct_spaces(corrected)
265
 
266
  if verbose:
267
- print(f'neuspell correction result: \n\t{cleaned_txt}\n')
268
  return cleaned_txt
269
 
270
 
@@ -310,11 +339,17 @@ def DLA_correct(qphrase: str):
310
  return " ".join(full_cor)
311
 
312
 
313
- def correct_grammar(input_text: str, tokenizer, model,
314
- n_results: int = 1,
315
- beams: int = 8,
316
- temp=1, uniq_ngrams=2, rep_penalty=1.5,
317
- device='cpu'):
 
 
 
 
 
 
318
  """
319
  correct_grammar - correct a string using a text2textgen pipeline model from transformers.
320
  This function is an alternative to the t5b_correction function.
@@ -337,21 +372,26 @@ def correct_grammar(input_text: str, tokenizer, model,
337
  """
338
  if len(input_text) < 5:
339
  return input_text
340
- max_length = min(int(math.ceil(len(input_text)*1.2)), 128)
341
- batch = tokenizer([input_text], truncation=True,
342
- padding='max_length',
343
- max_length=max_length, return_tensors="pt").to(device)
344
- translated = model.generate(**batch,
345
- max_length=max_length,
346
- min_length=min(10, len(input_text)),
347
- no_repeat_ngram_size=uniq_ngrams,
348
- repetition_penalty=rep_penalty,
349
- num_beams=beams,
350
- num_return_sequences=n_results,
351
- temperature=temp)
352
-
353
- tgt_text = tokenizer.batch_decode(translated,
354
- skip_special_tokens=True)
 
 
 
 
 
355
 
356
  if isinstance(tgt_text, list):
357
  return tgt_text[0]
 
14
  from symspellpy.symspellpy import SymSpell
15
 
16
 
 
17
  def fix_punct_spaces(string):
18
  """
19
  fix_punct_spaces - replace spaces around punctuation with punctuation. For example, "hello , there" -> "hello, there"
 
27
  str, corrected string
28
  """
29
 
30
+ fix_spaces = re.compile(r"\s*([?!.,]+(?:\s+[?!.,]+)*)\s*")
31
+ string = fix_spaces.sub(lambda x: "{} ".format(x.group(1).replace(" ", "")), string)
 
32
  return string.strip()
33
 
34
 
 
88
  """
89
 
90
 
91
+ def symspeller(
92
+ my_string: str,
93
+ sym_checker=None,
94
+ max_dist: int = 3,
95
+ prefix_length: int = 7,
96
+ ignore_non_words=True,
97
+ dictionary_path: str = None,
98
+ bigram_path: str = None,
99
+ verbose=False,
100
+ ):
101
  """
102
  symspeller - a wrapper for the SymSpell class from symspellpy
103
 
 
115
  if verbose:
116
  print("creating new SymSpell object")
117
  sym_checker = build_symspell_obj(
118
+ edit_dist=max_dist,
119
+ prefix_length=prefix_length,
120
+ dictionary_path=dictionary_path,
121
+ bigram_path=bigram_path,
122
+ )
123
  else:
124
  if verbose:
125
  print("using existing SymSpell object")
126
  # max edit distance per lookup (per single word, not per whole input string)
127
  suggestions = sym_checker.lookup_compound(
128
+ my_string,
129
+ max_edit_distance=max_dist,
130
+ ignore_non_words=ignore_non_words,
131
+ ignore_term_with_digits=True,
132
+ transfer_casing=True,
133
  )
134
 
135
  if verbose:
 
145
  return first_result._term
146
 
147
 
148
+ def build_symspell_obj(
149
+ edit_dist=3,
150
+ prefix_length=7,
151
+ dictionary_path=None,
152
+ bigram_path=None,
153
+ ):
154
  """
155
  build_symspell_obj [build a SymSpell object]
156
 
 
160
  Returns:
161
  SymSpell: a SymSpell object
162
  """
163
+ dictionary_path = (
164
+ r"symspell_rsc/frequency_dictionary_en_82_765.txt"
165
+ if dictionary_path is None
166
+ else dictionary_path
167
+ )
168
+ bigram_path = (
169
+ r"symspell_rsc/frequency_bigramdictionary_en_243_342.txt"
170
+ if bigram_path is None
171
+ else bigram_path
172
+ )
173
  sym_checker = SymSpell(
174
+ max_dictionary_edit_distance=edit_dist, prefix_length=prefix_length
175
+ )
176
  # term_index is the column of the term and count_index is the
177
  # column of the term frequency
178
  sym_checker.load_dictionary(dictionary_path, term_index=0, count_index=1)
179
+ sym_checker.load_bigram_dictionary(bigram_path, term_index=0, count_index=2)
 
180
 
181
  return sym_checker
182
 
183
+
184
  """
185
  NEEDED FOR T5
186
  import torch
 
194
 
195
  """
196
 
197
+
198
  def t5b_correction(prompt: str, korrektor, verbose=False, beams=4):
199
  """
200
  t5b_correction - correct a string using a text2textgen pipeline model from transformers
 
214
  p_min_len = int(math.ceil(0.9 * len(prompt)))
215
  p_max_len = int(math.ceil(1.1 * len(prompt)))
216
  if verbose:
217
+ print(f"setting min to {p_min_len} and max to {p_max_len}\n")
218
+ gcorr_result = korrektor(
219
+ f"grammar: {prompt}",
220
+ return_text=True,
221
+ clean_up_tokenization_spaces=True,
222
+ num_beams=beams,
223
+ max_length=p_max_len,
224
+ repetition_penalty=1.3,
225
+ length_penalty=0.2,
226
+ no_repeat_ngram_size=3,
227
+ )
228
  if verbose:
229
+ print(f"grammar correction result: \n\t{gcorr_result}\n")
230
  return gcorr_result
231
 
232
 
 
273
 
274
  def neuspell_correct(input_text: str, checker=None, verbose=False):
275
  """
276
+ neuspell_correct - correct a string using neuspell.
277
  note that modifications to the checker are needed if doing list-based corrections
278
 
279
  Parameters
 
293
  cleaned_txt = fix_punct_spaces(corrected)
294
 
295
  if verbose:
296
+ print(f"neuspell correction result: \n\t{cleaned_txt}\n")
297
  return cleaned_txt
298
 
299
 
 
339
  return " ".join(full_cor)
340
 
341
 
342
+ def correct_grammar(
343
+ input_text: str,
344
+ tokenizer,
345
+ model,
346
+ n_results: int = 1,
347
+ beams: int = 8,
348
+ temp=1,
349
+ uniq_ngrams=2,
350
+ rep_penalty=1.5,
351
+ device="cpu",
352
+ ):
353
  """
354
  correct_grammar - correct a string using a text2textgen pipeline model from transformers.
355
  This function is an alternative to the t5b_correction function.
 
372
  """
373
  if len(input_text) < 5:
374
  return input_text
375
+ max_length = min(int(math.ceil(len(input_text) * 1.2)), 128)
376
+ batch = tokenizer(
377
+ [input_text],
378
+ truncation=True,
379
+ padding="max_length",
380
+ max_length=max_length,
381
+ return_tensors="pt",
382
+ ).to(device)
383
+ translated = model.generate(
384
+ **batch,
385
+ max_length=max_length,
386
+ min_length=min(10, len(input_text)),
387
+ no_repeat_ngram_size=uniq_ngrams,
388
+ repetition_penalty=rep_penalty,
389
+ num_beams=beams,
390
+ num_return_sequences=n_results,
391
+ temperature=temp,
392
+ )
393
+
394
+ tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
395
 
396
  if isinstance(tgt_text, list):
397
  return tgt_text[0]
utils.py CHANGED
@@ -39,7 +39,6 @@ def print_spacer(n=1):
39
  print("\n -------- " * n)
40
 
41
 
42
-
43
  def fast_scandir(dirname: str):
44
  """
45
  fast_scandir [an os.path-based means to return all subfolders in a given filepath]
@@ -350,7 +349,6 @@ def dl_extract_zip(
350
  return extract_loc
351
 
352
 
353
-
354
  def cleantxt_wrap(ugly_text):
355
  """
356
  cleantxt_wrap - applies the clean function to a string.
 
39
  print("\n -------- " * n)
40
 
41
 
 
42
  def fast_scandir(dirname: str):
43
  """
44
  fast_scandir [an os.path-based means to return all subfolders in a given filepath]
 
349
  return extract_loc
350
 
351
 
 
352
  def cleantxt_wrap(ugly_text):
353
  """
354
  cleantxt_wrap - applies the clean function to a string.