peter szemraj committed on
Commit
7b10a08
1 Parent(s): cfd67f6

:truck: consolidate grammar-related functions to one file

Browse files
Files changed (2) hide show
  1. grammar_improve.py +123 -4
  2. utils.py +0 -112
grammar_improve.py CHANGED
@@ -10,7 +10,7 @@ import math
10
  from cleantext import clean
11
  import time
12
  import re
13
- ""
14
 
15
 
16
  def fix_punct_spaces(string):
@@ -45,6 +45,126 @@ def split_sentences(text: str):
45
  return re.split(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s", text)
46
 
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  def t5b_correction(prompt: str, korrektor, verbose=False, beams=4):
49
  """
50
  t5b_correction - correct a string using a text2textgen pipeline model from transformers
@@ -100,7 +220,7 @@ def disp_neuspell_chkrs():
100
 
101
  def load_ns_checker(customckr=None):
102
  """
103
- load_ns_checker - helper function, load a neuspell checker from huggingface transformers
104
 
105
  Args:
106
  customckr (neuspell.NeuSpell): [neuspell checker object], optional, if not provided, will load the default checker
@@ -211,8 +331,7 @@ def correct_grammar(input_text: str, tokenizer, model,
211
 
212
  Returns
213
  -------
214
- [type]
215
- [description]
216
  """
217
  if len(input_text) < 5:
218
  return input_text
 
10
  from cleantext import clean
11
  import time
12
  import re
13
+
14
 
15
 
16
  def fix_punct_spaces(string):
 
45
  return re.split(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s", text)
46
 
47
 
48
+ def remove_repeated_words(bot_response):
49
+ """
50
+ remove_repeated_words - remove repeated words from a string, returning only the first instance of each word
51
+
52
+ Parameters
53
+ ----------
54
+ bot_response : str
55
+ string to remove repeated words from
56
+
57
+ Returns
58
+ -------
59
+ str
60
+ string containing the first instance of each word
61
+ """
62
+ words = bot_response.split()
63
+ unique_words = []
64
+ for word in words:
65
+ if word not in unique_words:
66
+ unique_words.append(word)
67
+ return " ".join(unique_words)
68
+
69
+
70
+ def remove_trailing_punctuation(text: str, fuLL_strip=False):
71
+ """
72
+ remove_trailing_punctuation - remove trailing punctuation from a string. Purpose is to seem more natural to end users
73
+
74
+ Args:
75
+ text (str): [string to be cleaned]
76
+
77
+ Returns:
78
+ [str]: [cleaned string]
79
+ """
80
+ if fuLL_strip:
81
+ return text.strip("?!.,;:")
82
+ else:
83
+ return text.strip(".,;:")
84
+
85
+
86
+ """
87
+ start of SymSpell code
88
+ """
89
+
90
+
91
+ def symspeller(my_string: str, sym_checker=None, max_dist: int = 3, prefix_length: int = 7,
92
+ ignore_non_words=True,
93
+ dictionary_path: str = None, bigram_path: str = None, verbose=False):
94
+ """
95
+ symspeller - a wrapper for the SymSpell class from symspellpy
96
+
97
+ Parameters
98
+ ----------
99
+ my_string : str, required, default=None, the string to be checked
100
+ sym_checker : SymSpell, optional, default=None, the SymSpell object to use
101
+ max_dist : int, optional, default=3, the maximum distance to look for replacements
102
+ """
103
+
104
+ assert len(my_string) > 0, "entered string for correction is empty"
105
+
106
+ if sym_checker is None:
107
+ # need to create a new class object. user can specify their own dictionary and bigram files
108
+ if verbose:
109
+ print("creating new SymSpell object")
110
+ sym_checker = build_symspell_obj(
111
+ edit_dist=max_dist, prefix_length=prefix_length, dictionary_path=dictionary_path, bigram_path=bigram_path,)
112
+ else:
113
+ if verbose:
114
+ print("using existing SymSpell object")
115
+ # max edit distance per lookup (per single word, not per whole input string)
116
+ suggestions = sym_checker.lookup_compound(
117
+ my_string, max_edit_distance=max_dist, ignore_non_words=ignore_non_words, ignore_term_with_digits=True, transfer_casing=True,
118
+ )
119
+
120
+ if verbose:
121
+ print(f"{len(suggestions)} suggestions found")
122
+ print(f"the original string is:\n\t{my_string}")
123
+ sug_list = [sug.term for sug in suggestions]
124
+ print(f"suggestions:\n\t{sug_list}\n")
125
+
126
+ if len(suggestions) < 1:
127
+ return clean(my_string) # no correction because no suggestions
128
+ else:
129
+ first_result = suggestions[0] # first result is the most likely
130
+ return first_result._term
131
+
132
+
133
+ def build_symspell_obj(edit_dist=3, prefix_length=7, dictionary_path=None, bigram_path=None,):
134
+ """
135
+ build_symspell_obj [build a SymSpell object]
136
+
137
+ Args:
138
+ verbose (bool, optional): Defaults to False.
139
+
140
+ Returns:
141
+ SymSpell: a SymSpell object
142
+ """
143
+ dictionary_path = r"symspell_rsc/frequency_dictionary_en_82_765.txt" if dictionary_path is None else dictionary_path
144
+ bigram_path = r"symspell_rsc/frequency_bigramdictionary_en_243_342.txt" if bigram_path is None else bigram_path
145
+ sym_checker = SymSpell(
146
+ max_dictionary_edit_distance=edit_dist, prefix_length=prefix_length)
147
+ # term_index is the column of the term and count_index is the
148
+ # column of the term frequency
149
+ sym_checker.load_dictionary(dictionary_path, term_index=0, count_index=1)
150
+ sym_checker.load_bigram_dictionary(
151
+ bigram_path, term_index=0, count_index=2)
152
+
153
+ return sym_checker
154
+
155
+ """
156
+ NEEDED FOR T5
157
+ import torch
158
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
159
+
160
+ model_name = 'deep-learning-analytics/GrammarCorrector'
161
+ # torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
162
+ torch_device = 'cpu'
163
+ gc_tokenizer = T5Tokenizer.from_pretrained(model_name)
164
+ gc_model = T5ForConditionalGeneration.from_pretrained(model_name).to(torch_device)
165
+
166
+ """
167
+
168
  def t5b_correction(prompt: str, korrektor, verbose=False, beams=4):
169
  """
170
  t5b_correction - correct a string using a text2textgen pipeline model from transformers
 
220
 
221
  def load_ns_checker(customckr=None):
222
  """
223
+ load_ns_checker - helper function, load / "set up" a neuspell checker from huggingface transformers
224
 
225
  Args:
226
  customckr (neuspell.NeuSpell): [neuspell checker object], optional, if not provided, will load the default checker
 
331
 
332
  Returns
333
  -------
334
+ str, corrected string (or list of strings if n_results > 1)
 
335
  """
336
  if len(input_text) < 5:
337
  return input_text
utils.py CHANGED
@@ -39,68 +39,6 @@ def print_spacer(n=1):
39
  print("\n -------- " * n)
40
 
41
 
42
- def build_symspell_obj(edit_dist=3, prefix_length=7, dictionary_path=None, bigram_path=None,):
43
- """
44
- build_symspell_obj [build a SymSpell object]
45
-
46
- Args:
47
- verbose (bool, optional): Defaults to False.
48
-
49
- Returns:
50
- SymSpell: a SymSpell object
51
- """
52
- dictionary_path = r"symspell_rsc/frequency_dictionary_en_82_765.txt" if dictionary_path is None else dictionary_path
53
- bigram_path = r"symspell_rsc/frequency_bigramdictionary_en_243_342.txt" if bigram_path is None else bigram_path
54
- sym_checker = SymSpell(
55
- max_dictionary_edit_distance=edit_dist, prefix_length=prefix_length)
56
- # term_index is the column of the term and count_index is the
57
- # column of the term frequency
58
- sym_checker.load_dictionary(dictionary_path, term_index=0, count_index=1)
59
- sym_checker.load_bigram_dictionary(bigram_path, term_index=0, count_index=2)
60
-
61
- return sym_checker
62
-
63
-
64
- def symspeller(my_string: str, sym_checker=None, max_dist: int = 3, prefix_length: int = 7,
65
- ignore_non_words=True,
66
- dictionary_path:str=None, bigram_path:str=None, verbose=False):
67
- """
68
- symspeller - a wrapper for the SymSpell class from symspellpy
69
-
70
- Parameters
71
- ----------
72
- my_string : str, required, default=None, the string to be checked
73
- sym_checker : SymSpell, optional, default=None, the SymSpell object to use
74
- max_dist : int, optional, default=3, the maximum distance to look for replacements
75
- """
76
-
77
- assert len(my_string) > 0, "entered string for correction is empty"
78
-
79
- if sym_checker is None:
80
- # need to create a new class object. user can specify their own dictionary and bigram files
81
- if verbose:
82
- print("creating new SymSpell object")
83
- sym_checker = build_symspell_obj(
84
- edit_dist=max_dist, prefix_length=prefix_length, dictionary_path=dictionary_path, bigram_path=bigram_path,)
85
- else:
86
- if verbose: print("using existing SymSpell object")
87
- # max edit distance per lookup (per single word, not per whole input string)
88
- suggestions = sym_checker.lookup_compound(
89
- my_string, max_edit_distance=max_dist, ignore_non_words=ignore_non_words, ignore_term_with_digits=True, transfer_casing=True,
90
- )
91
-
92
- if verbose:
93
- print(f"{len(suggestions)} suggestions found")
94
- print(f"the original string is:\n\t{my_string}")
95
- sug_list = [sug.term for sug in suggestions]
96
- print(f"suggestions:\n\t{sug_list}\n")
97
-
98
- if len(suggestions) < 1:
99
- return clean(my_string) # no correction because no suggestions
100
- else:
101
- first_result = suggestions[0] # first result is the most likely
102
- return first_result._term
103
-
104
 
105
  def fast_scandir(dirname: str):
106
  """
@@ -412,56 +350,6 @@ def dl_extract_zip(
412
  return extract_loc
413
 
414
 
415
- def remove_repeated_words(bot_response):
416
- """
417
- remove_repeated_words - remove repeated words from a string, returning only the first instance of each word
418
-
419
- Parameters
420
- ----------
421
- bot_response : str
422
- string to remove repeated words from
423
-
424
- Returns
425
- -------
426
- str
427
- string containing the first instance of each word
428
- """
429
- words = bot_response.split()
430
- unique_words = []
431
- for word in words:
432
- if word not in unique_words:
433
- unique_words.append(word)
434
- return " ".join(unique_words)
435
-
436
-
437
- def remove_trailing_punctuation(text: str, fuLL_strip=False):
438
- """
439
- remove_trailing_punctuation - remove trailing punctuation from a string
440
-
441
- Args:
442
- text (str): [string to be cleaned]
443
-
444
- Returns:
445
- [str]: [cleaned string]
446
- """
447
- if fuLL_strip:
448
- return text.strip("?!.,;:")
449
- else:
450
- return text.strip(".,;:")
451
-
452
-
453
- def split_sentences(text: str):
454
- """
455
- split_sentences - split a string into a list of sentences that keep their ending punctuation
456
-
457
- Args:
458
- text (str): [string to be split]
459
-
460
- Returns:
461
- [list]: [list of sentences]
462
- """
463
- return re.split(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s", text)
464
-
465
 
466
  def cleantxt_wrap(ugly_text):
467
  """
 
39
  print("\n -------- " * n)
40
 
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  def fast_scandir(dirname: str):
44
  """
 
350
  return extract_loc
351
 
352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
 
354
  def cleantxt_wrap(ugly_text):
355
  """