abrakjamson commited on
Commit
85e58bb
·
1 Parent(s): 4da1fb0

Disable input while generating

Browse files
Files changed (1) hide show
  1. app.py +47 -10
app.py CHANGED
@@ -24,10 +24,12 @@ model = AutoModelForCausalLM.from_pretrained(
24
  trust_remote_code=True,
25
  use_safetensors=True
26
  )
27
- model = model.to("cuda:0" if torch.cuda.is_available() else "cpu")
28
- print(f"Is CUDA available: {torch.cuda.is_available()}")
29
- if torch.cuda.is_available():
 
30
  print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
 
31
 
32
  model = ControlModel(model, list(range(-5, -18, -1)))
33
 
@@ -87,7 +89,8 @@ def generate_response(system_prompt, user_message, history, max_new_tokens, repi
87
  Returns a list of tuples, the user message and the assistant response,
88
  which Gradio uses to update the chatbot history
89
  """
90
-
 
91
  # Separate checkboxes and sliders based on type
92
  # The first x in args are the checkbox names (the file names)
93
  # The second x in args are the slider values
@@ -139,7 +142,10 @@ def generate_response(system_prompt, user_message, history, max_new_tokens, repi
139
  "repetition_penalty": repetition_penalty.value,
140
  }
141
 
142
- _streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=False,)
 
 
 
143
 
144
  generate_kwargs = dict(
145
  input_ids,
@@ -155,6 +161,9 @@ def generate_response(system_prompt, user_message, history, max_new_tokens, repi
155
 
156
  # Display the response as it streams in, prepending the control vector info
157
  partial_message = ""
 
 
 
158
  for new_token in _streamer:
159
  if new_token != '<' and new_token != '</s>': # seems to hit EOS correctly without this needed
160
  partial_message += new_token
@@ -181,14 +190,17 @@ def generate_response(system_prompt, user_message, history, max_new_tokens, repi
181
 
182
  # Update conversation history
183
  history.append((user_message, assistant_response_display))
184
- yield history
185
 
186
  def generate_response_with_retry(system_prompt, user_message, history, max_new_tokens, repitition_penalty, do_sample, *args):
187
  # Remove last user input and assistant response from history, then call generate_response()
 
 
188
  if history:
189
  history = history[0:-1]
190
- for output in generate_response(system_prompt, user_message, history, max_new_tokens, repetition_penalty, do_sample, *args):
191
- yield output
 
192
 
193
  # Function to reset the conversation history
194
  def reset_chat():
@@ -281,7 +293,7 @@ def set_preset_stoner(*args):
281
  for check in model_names_and_indexes:
282
  if check == "Angry":
283
  new_checkbox_values.append(True)
284
- new_slider_values.append(0.5)
285
  elif check == "Right-leaning":
286
  new_checkbox_values.append(True)
287
  new_slider_values.append(-0.5)
@@ -323,6 +335,15 @@ def set_preset_facts(*args):
323
 
324
  return new_checkbox_values + new_slider_values
325
 
 
 
 
 
 
 
 
 
 
326
  tooltip_css = """
327
  /* Tooltip container */
328
  .tooltip {
@@ -560,10 +581,22 @@ with gr.Blocks(
560
  inputs_list = [system_prompt, user_input, chatbot, max_new_tokens, repetition_penalty, do_sample] + control_checks + control_sliders
561
 
562
  # Define button actions
 
 
 
 
 
 
563
  submit_button.click(
564
  generate_response,
565
  inputs=inputs_list,
566
  outputs=[chatbot]
 
 
 
 
 
 
567
  )
568
 
569
  user_input.submit(
@@ -575,7 +608,11 @@ with gr.Blocks(
575
  retry_button.click(
576
  generate_response_with_retry,
577
  inputs=inputs_list,
578
- outputs=[chatbot]
 
 
 
 
579
  )
580
 
581
  new_chat_button.click(
 
24
  trust_remote_code=True,
25
  use_safetensors=True
26
  )
27
+ cuda = torch.cuda.is_available()
28
+ print(f"Is CUDA available: {cuda}")
29
+ model = model.to("cuda:0" if cuda else "cpu")
30
+ if cuda:
31
  print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
32
+
33
 
34
  model = ControlModel(model, list(range(-5, -18, -1)))
35
 
 
89
  Returns a list of tuples, the user message and the assistant response,
90
  which Gradio uses to update the chatbot history
91
  """
92
+ global previous_turn
93
+ previous_turn = user_message
94
  # Separate checkboxes and sliders based on type
95
  # The first x in args are the checkbox names (the file names)
96
  # The second x in args are the slider values
 
142
  "repetition_penalty": repetition_penalty.value,
143
  }
144
 
145
+ timeout = 120.0
146
+ if cuda:
147
+ timeout = 10.0
148
+ _streamer = TextIteratorStreamer(tokenizer, timeout=timeout, skip_prompt=True, skip_special_tokens=False,)
149
 
150
  generate_kwargs = dict(
151
  input_ids,
 
161
 
162
  # Display the response as it streams in, prepending the control vector info
163
  partial_message = ""
164
+ #show the control vector info while we wait for the first token
165
+ temp_output = "*" + assistant_message_title + "*" + "\n\n*Please wait*..." + partial_message
166
+ yield history + [(user_message, temp_output)]
167
  for new_token in _streamer:
168
  if new_token != '<' and new_token != '</s>': # seems to hit EOS correctly without this needed
169
  partial_message += new_token
 
190
 
191
  # Update conversation history
192
  history.append((user_message, assistant_response_display))
193
+ return history
194
 
195
  def generate_response_with_retry(system_prompt, user_message, history, max_new_tokens, repitition_penalty, do_sample, *args):
196
  # Remove last user input and assistant response from history, then call generate_response()
197
+ global previous_turn
198
+ previous_ueser_message = previous_turn
199
  if history:
200
  history = history[0:-1]
201
+ # Using the previous turn's text, even though it isn't in the textbox anymore
202
+ for output in generate_response(system_prompt, previous_ueser_message, history, max_new_tokens, repetition_penalty, do_sample, *args):
203
+ yield [output, previous_ueser_message]
204
 
205
  # Function to reset the conversation history
206
  def reset_chat():
 
293
  for check in model_names_and_indexes:
294
  if check == "Angry":
295
  new_checkbox_values.append(True)
296
+ new_slider_values.append(0.4)
297
  elif check == "Right-leaning":
298
  new_checkbox_values.append(True)
299
  new_slider_values.append(-0.5)
 
335
 
336
  return new_checkbox_values + new_slider_values
337
 
338
+ def disable_controls():
339
+ return gr.update(interactive= False, value= "⌛ Processing"), gr.update(interactive=False)
340
+
341
+ def enable_controls():
342
+ return gr.update(interactive= True, value= "💬 Submit"), gr.update(interactive= True)
343
+
344
+ def clear_input(input_textbox):
345
+ return ""
346
+
347
  tooltip_css = """
348
  /* Tooltip container */
349
  .tooltip {
 
581
  inputs_list = [system_prompt, user_input, chatbot, max_new_tokens, repetition_penalty, do_sample] + control_checks + control_sliders
582
 
583
  # Define button actions
584
+ # Disable the submit button while processing
585
+ submit_button.click(
586
+ disable_controls,
587
+ inputs= None,
588
+ outputs= [submit_button, user_input]
589
+ )
590
  submit_button.click(
591
  generate_response,
592
  inputs=inputs_list,
593
  outputs=[chatbot]
594
+ ).then(
595
+ clear_input,
596
+ inputs= user_input,
597
+ outputs= user_input
598
+ ).then(
599
+ enable_controls, inputs=None, outputs=[submit_button, user_input]
600
  )
601
 
602
  user_input.submit(
 
608
  retry_button.click(
609
  generate_response_with_retry,
610
  inputs=inputs_list,
611
+ outputs=[chatbot, user_input]
612
+ ).then(
613
+ clear_input,
614
+ inputs= user_input,
615
+ outputs= user_input
616
  )
617
 
618
  new_chat_button.click(