peter szemraj committed on
Commit
2830ef7
1 Parent(s): 63bb54c

:construction: adding features for using symspell as a checker

Browse files
Files changed (2) hide show
  1. app.py +15 -5
  2. utils.py +47 -18
app.py CHANGED
@@ -33,6 +33,8 @@ from utils import (
33
  remove_trailing_punctuation,
34
  cleantxt_wrap,
35
  corr,
 
 
36
  )
37
 
38
  nltk.download("stopwords") # TODO: find where this requirement originates from
@@ -44,7 +46,7 @@ cwd = Path.cwd()
44
  my_cwd = str(cwd.resolve()) # string so it can be passed to os.path() objects
45
 
46
 
47
- def gramformer_correct(corrector, qphrase: str):
48
  """
49
  gramformer_correct - correct a string using a text2textgen pipeline model from transformers
50
  Args:
@@ -58,8 +60,8 @@ def gramformer_correct(corrector, qphrase: str):
58
  clean(qphrase), return_text=True, clean_up_tokenization_spaces=True
59
  )
60
  return corrected[0]["generated_text"]
61
- except:
62
- print("NOTE - failed to correct with gramformer")
63
  return clean(qphrase)
64
 
65
 
@@ -86,8 +88,8 @@ def ask_gpt(message: str):
86
  temp=0.75,
87
  top_p=0.65,
88
  )
89
- uniques = remove_repeated_words(resp["out_text"])
90
- bot_resp = corr((uniques))
91
  rt = round(time.perf_counter() - st, 2)
92
  print(f"took {rt} sec to respond")
93
  return remove_trailing_punctuation(bot_resp)
@@ -134,6 +136,13 @@ def get_parser():
134
  help="folder - with respect to git directory of your repo that has the model files in it (pytorch.bin + "
135
  "config.json)",
136
  )
 
 
 
 
 
 
 
137
  parser.add_argument(
138
  "--gram-model",
139
  required=False,
@@ -151,6 +160,7 @@ if __name__ == "__main__":
151
  model_loc = cwd.parent / default_model
152
  model_loc = str(model_loc.resolve())
153
  gram_model = args.gram_model
 
154
  print(f"using model stored here: \n {model_loc} \n")
155
  iface = gr.Interface(
156
  chat,
 
33
  remove_trailing_punctuation,
34
  cleantxt_wrap,
35
  corr,
36
+ build_symspell_obj,
37
+ symspeller,
38
  )
39
 
40
  nltk.download("stopwords") # TODO: find where this requirement originates from
 
46
  my_cwd = str(cwd.resolve()) # string so it can be passed to os.path() objects
47
 
48
 
49
+ def grammarpipe(corrector, qphrase: str):
50
  """
51
  gramformer_correct - correct a string using a text2textgen pipeline model from transformers
52
  Args:
 
60
  clean(qphrase), return_text=True, clean_up_tokenization_spaces=True
61
  )
62
  return corrected[0]["generated_text"]
63
+ except Exception as e:
64
+ print(f"NOTE - failed to correct with grammarpipe:\n {e}")
65
  return clean(qphrase)
66
 
67
 
 
88
  temp=0.75,
89
  top_p=0.65,
90
  )
91
+ cln_resp = symspeller(resp["out_text"], sym_checker=schnellspell)
92
+ bot_resp = corr(remove_repeated_words(cln_resp))
93
  rt = round(time.perf_counter() - st, 2)
94
  print(f"took {rt} sec to respond")
95
  return remove_trailing_punctuation(bot_resp)
 
136
  help="folder - with respect to git directory of your repo that has the model files in it (pytorch.bin + "
137
  "config.json)",
138
  )
139
+ parser.add_argument(
140
+ "--adv-correct",
141
+ required=False,
142
+ default=False,
143
+ action="store_true",
144
+ help="turn off symspell (baseline) correction to use a more advanced spell checker",
145
+ )
146
  parser.add_argument(
147
  "--gram-model",
148
  required=False,
 
160
  model_loc = cwd.parent / default_model
161
  model_loc = str(model_loc.resolve())
162
  gram_model = args.gram_model
163
+ schnellspell = build_symspell_obj()
164
  print(f"using model stored here: \n {model_loc} \n")
165
  iface = gr.Interface(
166
  chat,
utils.py CHANGED
@@ -39,37 +39,66 @@ def print_spacer(n=1):
39
  print("\n -------- " * n)
40
 
41
 
42
- def correct_phrase_load(my_string: str):
43
  """
44
- correct_phrase_load [basic / unoptimized implementation of SymSpell to correct a string]
45
 
46
  Args:
47
- my_string (str): [text to be corrected]
48
 
49
  Returns:
50
- str: the corrected string
51
  """
52
- sym_spell = SymSpell(max_dictionary_edit_distance=2, prefix_length=7)
53
-
54
- dictionary_path = (
55
- r"symspell_rsc/frequency_dictionary_en_82_765.txt" # from repo root
56
- )
57
- bigram_path = (
58
- r"symspell_rsc/frequency_bigramdictionary_en_243_342.txt" # from repo root
59
- )
60
  # term_index is the column of the term and count_index is the
61
  # column of the term frequency
62
- sym_spell.load_dictionary(dictionary_path, term_index=0, count_index=1)
63
- sym_spell.load_bigram_dictionary(bigram_path, term_index=0, count_index=2)
 
 
 
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  # max edit distance per lookup (per single word, not per whole input string)
66
- suggestions = sym_spell.lookup_compound(
67
- clean(my_string), max_edit_distance=2, ignore_non_words=True
68
  )
 
 
 
 
 
 
 
69
  if len(suggestions) < 1:
70
- return my_string
71
  else:
72
- first_result = suggestions[0]
73
  return first_result._term
74
 
75
 
 
39
  print("\n -------- " * n)
40
 
41
 
42
def build_symspell_obj(
    edit_dist=3,
    prefix_length=7,
    dictionary_path=None,
    bigram_path=None,
):
    """
    build_symspell_obj - create a SymSpell checker with dictionaries loaded.

    Args:
        edit_dist (int, optional): max dictionary edit distance. Defaults to 3.
        prefix_length (int, optional): SymSpell prefix length. Defaults to 7.
        dictionary_path (str, optional): path to a unigram frequency dictionary.
            Defaults to the bundled file under symspell_rsc/ (relative to repo root).
        bigram_path (str, optional): path to a bigram frequency dictionary.
            Defaults to the bundled file under symspell_rsc/ (relative to repo root).

    Returns:
        SymSpell: an initialized SymSpell object ready for lookups
    """
    # fall back to the dictionaries shipped with the repo
    if dictionary_path is None:
        dictionary_path = r"symspell_rsc/frequency_dictionary_en_82_765.txt"
    if bigram_path is None:
        bigram_path = r"symspell_rsc/frequency_bigramdictionary_en_243_342.txt"

    sym_checker = SymSpell(
        max_dictionary_edit_distance=edit_dist,
        prefix_length=prefix_length,
    )
    # term_index is the column of the term and count_index is the
    # column of the term frequency
    sym_checker.load_dictionary(dictionary_path, term_index=0, count_index=1)
    sym_checker.load_bigram_dictionary(bigram_path, term_index=0, count_index=2)

    return sym_checker
62
+
63
 
64
def symspeller(
    my_string: str,
    sym_checker=None,
    max_dist: int = 3,
    prefix_length: int = 7,
    ignore_non_words=True,
    dictionary_path: str = None,
    bigram_path: str = None,
    verbose=False,
):
    """
    symspeller - a wrapper for the SymSpell class from symspellpy

    Parameters
    ----------
    my_string : str, required, the string to be checked
    sym_checker : SymSpell, optional, default=None, the SymSpell object to use;
        a new one is built (honoring dictionary_path/bigram_path) when None
    max_dist : int, optional, default=3, the maximum edit distance per word
    prefix_length : int, optional, default=7, prefix length for a new SymSpell
    ignore_non_words : bool, optional, default=True, skip non-word tokens
    dictionary_path : str, optional, default=None, custom unigram dictionary file
    bigram_path : str, optional, default=None, custom bigram dictionary file
    verbose : bool, optional, default=False, print diagnostic information

    Returns
    -------
    str : the corrected string, or the cleaned input if no suggestions found

    Raises
    ------
    ValueError : if my_string is empty
    """
    # explicit validation: `assert` would be stripped under `python -O`
    if not my_string:
        raise ValueError("entered string for correction is empty")

    if sym_checker is None:
        # need to create a new class object. user can specify their own
        # dictionary and bigram files
        if verbose:
            print("creating new SymSpell object")
        sym_checker = build_symspell_obj(
            edit_dist=max_dist,
            prefix_length=prefix_length,
            dictionary_path=dictionary_path,
            bigram_path=bigram_path,
        )
    elif verbose:
        print("using existing SymSpell object")

    # max edit distance per lookup (per single word, not per whole input string)
    suggestions = sym_checker.lookup_compound(
        my_string,
        max_edit_distance=max_dist,
        ignore_non_words=ignore_non_words,
        ignore_term_with_digits=True,
        transfer_casing=True,
    )

    if verbose:
        print(f"{len(suggestions)} suggestions found")
        print(f"the original string is:\n\t{my_string}")
        sug_list = [sug.term for sug in suggestions]
        print(f"suggestions:\n\t{sug_list}\n")

    if len(suggestions) < 1:
        return clean(my_string)  # no correction because no suggestions
    # first result is the most likely; use the public `.term` attribute
    # rather than the private `_term`
    return suggestions[0].term
103
 
104