zetavg commited on
Commit
883e16a
1 Parent(s): 4623b35

resolve generation canceling issue

Browse files
llama_lora/globals.py CHANGED
@@ -25,6 +25,10 @@ class Global:
25
  # Training Control
26
  should_stop_training = False
27
 
 
 
 
 
28
  # Model related
29
  model_has_been_used = False
30
  loaded_base_model_with_lora = None
 
25
  # Training Control
26
  should_stop_training = False
27
 
28
+ # Generation Control
29
+ should_stop_generating = False
30
+ generation_force_stopped_at = None
31
+
32
  # Model related
33
  model_has_been_used = False
34
  loaded_base_model_with_lora = None
llama_lora/ui/inference_ui.py CHANGED
@@ -19,6 +19,7 @@ from ..utils.callbacks import Iteratorize, Stream
19
  device = get_device()
20
 
21
  default_show_raw = True
 
22
 
23
 
24
  def do_inference(
@@ -37,6 +38,15 @@ def do_inference(
37
  progress=gr.Progress(track_tqdm=True),
38
  ):
39
  try:
 
 
 
 
 
 
 
 
 
40
  variables = [variable_0, variable_1, variable_2, variable_3,
41
  variable_4, variable_5, variable_6, variable_7]
42
  prompter = Prompter(prompt_template)
@@ -69,12 +79,20 @@ def do_inference(
69
  yield out
70
 
71
  for partial_sentence in word_generator(message):
72
- yield partial_sentence, json.dumps(list(range(len(partial_sentence.split()))), indent=2)
 
 
 
 
 
73
  time.sleep(0.05)
74
 
75
  return
76
  time.sleep(1)
77
- yield message, json.dumps(list(range(len(message.split()))), indent=2)
 
 
 
78
  return
79
 
80
  model = get_base_model()
@@ -100,6 +118,19 @@ def do_inference(
100
  "max_new_tokens": max_new_tokens,
101
  }
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  if stream_output:
104
  # Stream the reply 1 token at a time.
105
  # This is based on the trick of using 'stopping_criteria' to create an iterator,
@@ -131,29 +162,61 @@ def do_inference(
131
  raw_output = None
132
  if show_raw:
133
  raw_output = str(output)
134
- yield prompter.get_response(decoded_output), raw_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  return # early return for stream_output
136
 
137
  # Without streaming
138
  with torch.no_grad():
139
- generation_output = model.generate(
140
- input_ids=input_ids,
141
- generation_config=generation_config,
142
- return_dict_in_generate=True,
143
- output_scores=True,
144
- max_new_tokens=max_new_tokens,
145
- )
146
  s = generation_output.sequences[0]
147
  output = tokenizer.decode(s)
148
  raw_output = None
149
  if show_raw:
150
  raw_output = str(s)
151
- yield prompter.get_response(output), raw_output
 
 
 
 
 
 
 
 
152
 
153
  except Exception as e:
154
  raise gr.Error(e)
155
 
156
 
 
 
 
 
 
157
  def reload_selections(current_lora_model, current_prompt_template):
158
  available_template_names = get_available_template_names()
159
  available_template_names_with_none = available_template_names + ["None"]
@@ -186,7 +249,8 @@ def handle_prompt_template_change(prompt_template, lora_model):
186
  gr_updates.append(gr.Textbox.update(
187
  label="Not Used", visible=False))
188
 
189
- model_prompt_template_message_update = gr.Markdown.update("", visible=False)
 
190
  lora_mode_info = get_info_of_available_lora_model(lora_model)
191
  if lora_mode_info and isinstance(lora_mode_info, dict):
192
  model_prompt_template = lora_mode_info.get("prompt_template")
@@ -352,7 +416,7 @@ def inference_ui():
352
  with gr.Column(elem_id="inference_output_group_container"):
353
  with gr.Column(elem_id="inference_output_group"):
354
  inference_output = gr.Textbox(
355
- lines=12, label="Output", elem_id="inference_output")
356
  inference_output.style(show_copy_button=True)
357
  with gr.Accordion(
358
  "Raw Output",
@@ -413,8 +477,12 @@ def inference_ui():
413
  outputs=[inference_output, inference_raw_output],
414
  api_name="inference"
415
  )
416
- stop_btn.click(fn=None, inputs=None, outputs=None,
417
- cancels=[generate_event])
 
 
 
 
418
 
419
  update_prompt_preview_event = update_prompt_preview_btn.click(fn=update_prompt_preview, inputs=[prompt_template,
420
  variable_0, variable_1, variable_2, variable_3,
@@ -624,5 +692,27 @@ def inference_ui():
624
  });
625
  }
626
  }, 100);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
627
  }
628
  """)
 
19
  device = get_device()
20
 
21
  default_show_raw = True
22
+ inference_output_lines = 12
23
 
24
 
25
  def do_inference(
 
38
  progress=gr.Progress(track_tqdm=True),
39
  ):
40
  try:
41
+ if Global.generation_force_stopped_at is not None:
42
+ required_elapsed_time_after_forced_stop = 1
43
+ current_unix_time = time.time()
44
+ remaining_time = required_elapsed_time_after_forced_stop - \
45
+ (current_unix_time - Global.generation_force_stopped_at)
46
+ if remaining_time > 0:
47
+ time.sleep(remaining_time)
48
+ Global.generation_force_stopped_at = None
49
+
50
  variables = [variable_0, variable_1, variable_2, variable_3,
51
  variable_4, variable_5, variable_6, variable_7]
52
  prompter = Prompter(prompt_template)
 
79
  yield out
80
 
81
  for partial_sentence in word_generator(message):
82
+ yield (
83
+ gr.Textbox.update(
84
+ value=partial_sentence, lines=inference_output_lines),
85
+ json.dumps(
86
+ list(range(len(partial_sentence.split()))), indent=2)
87
+ )
88
  time.sleep(0.05)
89
 
90
  return
91
  time.sleep(1)
92
+ yield (
93
+ gr.Textbox.update(value=message, lines=1), # TODO
94
+ json.dumps(list(range(len(message.split()))), indent=2)
95
+ )
96
  return
97
 
98
  model = get_base_model()
 
118
  "max_new_tokens": max_new_tokens,
119
  }
120
 
121
+ def ui_generation_stopping_criteria(input_ids, score, **kwargs):
122
+ if Global.should_stop_generating:
123
+ return True
124
+ return False
125
+
126
+ Global.should_stop_generating = False
127
+ generate_params.setdefault(
128
+ "stopping_criteria", transformers.StoppingCriteriaList()
129
+ )
130
+ generate_params["stopping_criteria"].append(
131
+ ui_generation_stopping_criteria
132
+ )
133
+
134
  if stream_output:
135
  # Stream the reply 1 token at a time.
136
  # This is based on the trick of using 'stopping_criteria' to create an iterator,
 
162
  raw_output = None
163
  if show_raw:
164
  raw_output = str(output)
165
+ response = prompter.get_response(decoded_output)
166
+
167
+ if Global.should_stop_generating:
168
+ return
169
+
170
+ yield (
171
+ gr.Textbox.update(
172
+ value=response, lines=inference_output_lines),
173
+ raw_output)
174
+
175
+ if Global.should_stop_generating:
176
+ # If the user stops the generation, and then clicks the
177
+ # generation button again, they may mysteriously landed
178
+ # here, in the previous, should-be-stopped generation
179
+ # function call, with the new generation function not be
180
+ # called at all. To workaround this, we yield a message
181
+ # and setting lines=1, and if the front-end JS detects
182
+ # that lines has been set to 1 (rows="1" in HTML),
183
+ # it will automatically click the generate button again
184
+ # (gr.Textbox.update() does not support updating
185
+ # elem_classes or elem_id).
186
+ # [WORKAROUND-UI01]
187
+ yield (
188
+ gr.Textbox.update(
189
+ value="Please retry", lines=1),
190
+ None)
191
  return # early return for stream_output
192
 
193
  # Without streaming
194
  with torch.no_grad():
195
+ generation_output = model.generate(**generate_params)
 
 
 
 
 
 
196
  s = generation_output.sequences[0]
197
  output = tokenizer.decode(s)
198
  raw_output = None
199
  if show_raw:
200
  raw_output = str(s)
201
+
202
+ response = prompter.get_response(output)
203
+ if Global.should_stop_generating:
204
+ return
205
+
206
+ yield (
207
+ gr.Textbox.update(value=response, lines=inference_output_lines),
208
+ raw_output)
209
+
210
 
211
  except Exception as e:
212
  raise gr.Error(e)
213
 
214
 
215
+ def handle_stop_generate():
216
+ Global.generation_force_stopped_at = time.time()
217
+ Global.should_stop_generating = True
218
+
219
+
220
  def reload_selections(current_lora_model, current_prompt_template):
221
  available_template_names = get_available_template_names()
222
  available_template_names_with_none = available_template_names + ["None"]
 
249
  gr_updates.append(gr.Textbox.update(
250
  label="Not Used", visible=False))
251
 
252
+ model_prompt_template_message_update = gr.Markdown.update(
253
+ "", visible=False)
254
  lora_mode_info = get_info_of_available_lora_model(lora_model)
255
  if lora_mode_info and isinstance(lora_mode_info, dict):
256
  model_prompt_template = lora_mode_info.get("prompt_template")
 
416
  with gr.Column(elem_id="inference_output_group_container"):
417
  with gr.Column(elem_id="inference_output_group"):
418
  inference_output = gr.Textbox(
419
+ lines=inference_output_lines, label="Output", elem_id="inference_output")
420
  inference_output.style(show_copy_button=True)
421
  with gr.Accordion(
422
  "Raw Output",
 
477
  outputs=[inference_output, inference_raw_output],
478
  api_name="inference"
479
  )
480
+ stop_btn.click(
481
+ fn=handle_stop_generate,
482
+ inputs=None,
483
+ outputs=None,
484
+ cancels=[generate_event]
485
+ )
486
 
487
  update_prompt_preview_event = update_prompt_preview_btn.click(fn=update_prompt_preview, inputs=[prompt_template,
488
  variable_0, variable_1, variable_2, variable_3,
 
692
  });
693
  }
694
  }, 100);
695
+
696
+ // [WORKAROUND-UI01]
697
+ setTimeout(function () {
698
+ const inference_output_textarea = document.querySelector(
699
+ '#inference_output textarea'
700
+ );
701
+ if (!inference_output_textarea) return;
702
+ const observer = new MutationObserver(function () {
703
+ if (inference_output_textarea.getAttribute('rows') === '1') {
704
+ setTimeout(function () {
705
+ const inference_generate_btn = document.getElementById(
706
+ 'inference_generate_btn'
707
+ );
708
+ if (inference_generate_btn) inference_generate_btn.click();
709
+ }, 10);
710
+ }
711
+ });
712
+ observer.observe(inference_output_textarea, {
713
+ attributes: true,
714
+ attributeFilter: ['rows'],
715
+ });
716
+ }, 100);
717
  }
718
  """)