pszemraj committed on
Commit
a5629da
1 Parent(s): b67934e

🐛 update defaults

Browse files

Signed-off-by: peter szemraj <peterszemraj@gmail.com>

Files changed (2) hide show
  1. app.py +23 -16
  2. converse.py +8 -2
app.py CHANGED
@@ -44,11 +44,15 @@ import transformers
44
 
45
  transformers.logging.set_verbosity_error()
46
  cwd = Path.cwd()
47
- my_cwd = str(cwd.resolve()) # string so it can be passed to os.path() objects
48
 
49
 
50
  def chat(
51
- prompt_message, temperature: float = 0.5, top_p: float = 0.95, top_k: int = 20, constrained_generation: str = "False"
 
 
 
 
52
  ) -> str:
53
  """
54
  chat - the main function for the chatbot. This is the function that is called when the user
@@ -84,7 +88,7 @@ def ask_gpt(
84
  chat_pipe,
85
  speaker="person alpha",
86
  responder="person beta",
87
- min_length=4,
88
  max_length=48,
89
  top_p=0.95,
90
  top_k=25,
@@ -99,7 +103,7 @@ def ask_gpt(
99
  :param chat_pipe: the pipeline object for the model, created by the pipeline() function
100
  :param str speaker: the name of the speaker, defaults to "person alpha"
101
  :param str responder: the name of the responder, defaults to "person beta"
102
- :param int min_length: the minimum length of the response, defaults to 4
103
  :param int max_length: the maximum length of the response, defaults to 64
104
  :param float top_p: the top_p value for the model, defaults to 0.95
105
  :param int top_k: the top_k value for the model, defaults to 25
@@ -128,22 +132,20 @@ def ask_gpt(
128
  temperature=temperature,
129
  max_length=max_length,
130
  min_length=min_length,
131
- constrained_beam_search = constrained_generation,
132
  )
133
  gpt_et = time.perf_counter()
134
  gpt_rt = round(gpt_et - st, 2)
135
  rawtxt = resp["out_text"]
136
  # check for proper nouns
137
  if basic_sc:
138
- cln_resp = symspeller(rawtxt, sym_checker=schnellspell)
139
  else:
140
  cln_resp = synthesize_grammar(corrector=grammarbot, message=rawtxt)
141
  bot_resp_a = corr(remove_repeated_words(cln_resp))
142
  bot_resp = fix_punct_spacing(bot_resp_a)
143
  corr_rt = round(time.perf_counter() - gpt_et, 4)
144
- print(
145
- f"{gpt_rt + corr_rt} to respond, {gpt_rt} GPT, {corr_rt} for correction\n"
146
- )
147
  return remove_trailing_punctuation(bot_resp)
148
 
149
 
@@ -163,6 +165,7 @@ def get_parser():
163
  help="the model to use for the chatbot on https://huggingface.co/models OR a path to a local model",
164
  )
165
  parser.add_argument(
 
166
  "--gram-model",
167
  required=False,
168
  type=str,
@@ -173,9 +176,9 @@ def get_parser():
173
  parser.add_argument(
174
  "--basic-sc",
175
  required=False,
176
- default=False, # TODO: change this back to False once Neuspell issues are resolved.
177
  action="store_true",
178
- help="turn on symspell (baseline) correction instead of the more advanced neural net models",
179
  )
180
 
181
  parser.add_argument(
@@ -188,7 +191,7 @@ def get_parser():
188
  "--test",
189
  action="store_true",
190
  default=False,
191
- help="load the smallest model for simple testing",
192
  )
193
 
194
  return parser
@@ -207,7 +210,7 @@ if __name__ == "__main__":
207
  gram_model = str(args.gram_model)
208
  device = 0 if torch.cuda.is_available() else -1
209
 
210
- print(f"CUDA avail is {torch.cuda.is_available()}")
211
 
212
  my_chatbot = (
213
  pipeline("text-generation", model=model_loc.resolve(), device=device)
@@ -218,12 +221,12 @@ if __name__ == "__main__":
218
 
219
  if basic_sc:
220
  print("Using the baseline spellchecker")
221
- schnellspell = build_symspell_obj()
222
  else:
223
  print("using neural spell checker")
224
  grammarbot = pipeline("text2text-generation", gram_model, device=device)
225
 
226
- print(f"using model stored here: \n {model_loc} \n")
227
  iface = gr.Interface(
228
  chat,
229
  inputs=[
@@ -238,7 +241,11 @@ if __name__ == "__main__":
238
  ),
239
  Slider(minimum=0.0, maximum=1.0, step=0.01, default=0.95, label="top_p"),
240
  Slider(minimum=0, maximum=100, step=5, default=20, label="top_k"),
241
- Radio(choices=["True", "False"], default="False", label="constrained_generation"),
 
 
 
 
242
  ],
243
  outputs="html",
244
  examples_per_page=8,
 
44
 
45
  transformers.logging.set_verbosity_error()
46
  cwd = Path.cwd()
47
+ _cwd_str = str(cwd.resolve()) # string so it can be passed to os.path() objects
48
 
49
 
50
  def chat(
51
+ prompt_message,
52
+ temperature: float = 0.5,
53
+ top_p: float = 0.95,
54
+ top_k: int = 20,
55
+ constrained_generation: str = "False",
56
  ) -> str:
57
  """
58
  chat - the main function for the chatbot. This is the function that is called when the user
 
88
  chat_pipe,
89
  speaker="person alpha",
90
  responder="person beta",
91
+ min_length=12,
92
  max_length=48,
93
  top_p=0.95,
94
  top_k=25,
 
103
  :param chat_pipe: the pipeline object for the model, created by the pipeline() function
104
  :param str speaker: the name of the speaker, defaults to "person alpha"
105
  :param str responder: the name of the responder, defaults to "person beta"
106
+ :param int min_length: the minimum length of the response, defaults to 12
107
  :param int max_length: the maximum length of the response, defaults to 64
108
  :param float top_p: the top_p value for the model, defaults to 0.95
109
  :param int top_k: the top_k value for the model, defaults to 25
 
132
  temperature=temperature,
133
  max_length=max_length,
134
  min_length=min_length,
135
+ constrained_beam_search=constrained_generation,
136
  )
137
  gpt_et = time.perf_counter()
138
  gpt_rt = round(gpt_et - st, 2)
139
  rawtxt = resp["out_text"]
140
  # check for proper nouns
141
  if basic_sc:
142
+ cln_resp = symspeller(rawtxt, sym_checker=basic_spell)
143
  else:
144
  cln_resp = synthesize_grammar(corrector=grammarbot, message=rawtxt)
145
  bot_resp_a = corr(remove_repeated_words(cln_resp))
146
  bot_resp = fix_punct_spacing(bot_resp_a)
147
  corr_rt = round(time.perf_counter() - gpt_et, 4)
148
+ print(f"{gpt_rt + corr_rt} to respond, {gpt_rt} GPT, {corr_rt} for correction\n")
 
 
149
  return remove_trailing_punctuation(bot_resp)
150
 
151
 
 
165
  help="the model to use for the chatbot on https://huggingface.co/models OR a path to a local model",
166
  )
167
  parser.add_argument(
168
+ "-gm",
169
  "--gram-model",
170
  required=False,
171
  type=str,
 
176
  parser.add_argument(
177
  "--basic-sc",
178
  required=False,
179
+ default=False,
180
  action="store_true",
181
+ help="use symspell (statistical spelling correction) instead of neural spell correction",
182
  )
183
 
184
  parser.add_argument(
 
191
  "--test",
192
  action="store_true",
193
  default=False,
194
+ help="load the smallest model for simple testing (ethzanalytics/distilgpt2-tiny-conversational)",
195
  )
196
 
197
  return parser
 
210
  gram_model = str(args.gram_model)
211
  device = 0 if torch.cuda.is_available() else -1
212
 
213
+ logging.info(f"CUDA avail is {torch.cuda.is_available()}")
214
 
215
  my_chatbot = (
216
  pipeline("text-generation", model=model_loc.resolve(), device=device)
 
221
 
222
  if basic_sc:
223
  print("Using the baseline spellchecker")
224
+ basic_spell = build_symspell_obj()
225
  else:
226
  print("using neural spell checker")
227
  grammarbot = pipeline("text2text-generation", gram_model, device=device)
228
 
229
+ logging.info(f"using model stored here: \n {model_loc} \n")
230
  iface = gr.Interface(
231
  chat,
232
  inputs=[
 
241
  ),
242
  Slider(minimum=0.0, maximum=1.0, step=0.01, default=0.95, label="top_p"),
243
  Slider(minimum=0, maximum=100, step=5, default=20, label="top_k"),
244
+ Radio(
245
+ choices=["True", "False"],
246
+ default="False",
247
+ label="constrained_generation",
248
+ ),
249
  ],
250
  outputs="html",
251
  examples_per_page=8,
converse.py CHANGED
@@ -23,7 +23,7 @@ def discussion(
23
  responder: str,
24
  pipeline,
25
  timeout=45,
26
- min_length=4,
27
  max_length=64,
28
  top_p=0.95,
29
  top_k=50,
@@ -60,6 +60,8 @@ def discussion(
60
  str, the generated text
61
  """
62
 
 
 
63
  p_list = [] # track conversation
64
  p_list.append(speaker.lower() + ":" + "\n")
65
  p_list.append(prompt_text.lower() + "\n")
@@ -75,6 +77,8 @@ def discussion(
75
  response = constrained_generation(
76
  prompt=this_prompt,
77
  pipeline=pipeline,
 
 
78
  no_repeat_ngram_size=no_repeat_ngram_size,
79
  length_penalty=length_penalty,
80
  repetition_penalty=1.0,
@@ -101,6 +105,7 @@ def discussion(
101
  speaker,
102
  responder,
103
  timeout=timeout,
 
104
  max_length=max_length,
105
  top_p=top_p,
106
  top_k=top_k,
@@ -112,6 +117,7 @@ def discussion(
112
  device=device,
113
  verbose=verbose,
114
  )
 
115
  if isinstance(bot_dialogue, list) and len(bot_dialogue) > 1:
116
  bot_resp = ", ".join(bot_dialogue)
117
  elif isinstance(bot_dialogue, list) and len(bot_dialogue) == 1:
@@ -123,12 +129,12 @@ def discussion(
123
  # remove the last ',' '.' chars
124
  bot_resp = remove_trailing_punctuation(bot_resp)
125
  if verbose:
 
126
  print("\n... bot response:\n")
127
  pp.pprint(bot_resp)
128
  p_list.append(bot_resp + "\n")
129
  p_list.append("\n")
130
 
131
- print("\nfinished!")
132
  logging.info(f"finished generating response:\n\t{bot_resp}")
133
  # return the bot response and the full conversation
134
 
 
23
  responder: str,
24
  pipeline,
25
  timeout=45,
26
+ min_length=8,
27
  max_length=64,
28
  top_p=0.95,
29
  top_k=50,
 
60
  str, the generated text
61
  """
62
 
63
+ logging.debug(f"input args: {locals()}")
64
+
65
  p_list = [] # track conversation
66
  p_list.append(speaker.lower() + ":" + "\n")
67
  p_list.append(prompt_text.lower() + "\n")
 
77
  response = constrained_generation(
78
  prompt=this_prompt,
79
  pipeline=pipeline,
80
+ min_generated_tokens=min_length,
81
+ max_generated_tokens=max_length,
82
  no_repeat_ngram_size=no_repeat_ngram_size,
83
  length_penalty=length_penalty,
84
  repetition_penalty=1.0,
 
105
  speaker,
106
  responder,
107
  timeout=timeout,
108
+ min_length=min_length,
109
  max_length=max_length,
110
  top_p=top_p,
111
  top_k=top_k,
 
117
  device=device,
118
  verbose=verbose,
119
  )
120
+ logging.debug(f"generation done. bot_dialogue: {bot_dialogue}")
121
  if isinstance(bot_dialogue, list) and len(bot_dialogue) > 1:
122
  bot_resp = ", ".join(bot_dialogue)
123
  elif isinstance(bot_dialogue, list) and len(bot_dialogue) == 1:
 
129
  # remove the last ',' '.' chars
130
  bot_resp = remove_trailing_punctuation(bot_resp)
131
  if verbose:
132
+ print("\nfinished!")
133
  print("\n... bot response:\n")
134
  pp.pprint(bot_resp)
135
  p_list.append(bot_resp + "\n")
136
  p_list.append("\n")
137
 
 
138
  logging.info(f"finished generating response:\n\t{bot_resp}")
139
  # return the bot response and the full conversation
140