Commit 883e16a by zetavg
1 parent: 4623b35

resolve generation canceling issue
Files changed:
- llama_lora/globals.py (+4, -0)
- llama_lora/ui/inference_ui.py (+105, -15)
llama_lora/globals.py CHANGED
@@ -25,6 +25,10 @@ class Global:
     # Training Control
     should_stop_training = False
 
+    # Generation Control
+    should_stop_generating = False
+    generation_force_stopped_at = None
+
     # Model related
     model_has_been_used = False
     loaded_base_model_with_lora = None
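The two new flags follow the same pattern as the existing `should_stop_training`: module-level mutable state that the Stop button handler sets and the generation code polls. A minimal sketch of that pattern in isolation (the `Flags` class and `worker` function here are illustrative, not part of the repo):

```python
import threading
import time


class Flags:
    # Process-wide control state, mirroring the Global class above.
    should_stop = False


def worker():
    # A long-running loop that polls the flag between units of work,
    # the way the generation loop polls Global.should_stop_generating.
    for step in range(100):
        if Flags.should_stop:
            print(f"stopped at step {step}")
            return
        time.sleep(0.1)  # stand-in for one decoding step


t = threading.Thread(target=worker)
t.start()
time.sleep(0.35)
Flags.should_stop = True  # what the Stop button handler does
t.join()
```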
llama_lora/ui/inference_ui.py CHANGED
@@ -19,6 +19,7 @@ from ..utils.callbacks import Iteratorize, Stream
 device = get_device()
 
 default_show_raw = True
+inference_output_lines = 12
 
 
 def do_inference(
@@ -37,6 +38,15 @@ def do_inference(
    progress=gr.Progress(track_tqdm=True),
 ):
     try:
+        if Global.generation_force_stopped_at is not None:
+            required_elapsed_time_after_forced_stop = 1
+            current_unix_time = time.time()
+            remaining_time = required_elapsed_time_after_forced_stop - \
+                (current_unix_time - Global.generation_force_stopped_at)
+            if remaining_time > 0:
+                time.sleep(remaining_time)
+            Global.generation_force_stopped_at = None
+
         variables = [variable_0, variable_1, variable_2, variable_3,
                      variable_4, variable_5, variable_6, variable_7]
         prompter = Prompter(prompt_template)
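This block enforces a one-second cooldown after a forced stop: if the previous generation was force-stopped less than `required_elapsed_time_after_forced_stop` seconds ago, the new call sleeps out the remainder before proceeding. For example, a stop 0.4 s in the past leaves `remaining_time = 1 - 0.4 = 0.6` s of sleep. The same computation as a standalone sketch (the function name is hypothetical):

```python
import time

COOLDOWN_SECONDS = 1  # required_elapsed_time_after_forced_stop in the diff


def wait_out_cooldown(force_stopped_at):
    """Sleep out whatever remains of the post-stop cooldown window."""
    if force_stopped_at is None:
        return
    remaining = COOLDOWN_SECONDS - (time.time() - force_stopped_at)
    if remaining > 0:
        time.sleep(remaining)


# A stop that happened 0.4 s ago leaves roughly 0.6 s to sleep.
wait_out_cooldown(time.time() - 0.4)
```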
@@ -69,12 +79,20 @@ def do_inference(
                         yield out
 
                 for partial_sentence in word_generator(message):
-                    yield partial_sentence, json.dumps(list(range(len(partial_sentence.split()))), indent=2)
+                    yield (
+                        gr.Textbox.update(
+                            value=partial_sentence, lines=inference_output_lines),
+                        json.dumps(
+                            list(range(len(partial_sentence.split()))), indent=2)
+                    )
                     time.sleep(0.05)
 
                 return
             time.sleep(1)
-            yield message, json.dumps(list(range(len(message.split()))), indent=2)
+            yield (
+                gr.Textbox.update(value=message, lines=1),  # TODO
+                json.dumps(list(range(len(message.split()))), indent=2)
+            )
             return
 
         model = get_base_model()
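Yielding `gr.Textbox.update(...)` tuples from the handler is Gradio's generator-based streaming pattern: each `yield` pushes a fresh value (and, here, a `lines=` re-render) to the bound output components. A self-contained sketch of the pattern, assuming the Gradio 3.x API the diff uses; the demo UI is illustrative, not the app's actual layout:

```python
import time

import gradio as gr


def stream_words(text):
    # Each yield updates the output textbox in place, like the
    # word_generator loop in the diff above.
    words = text.split()
    for i in range(1, len(words) + 1):
        yield gr.Textbox.update(value=" ".join(words[:i]))
        time.sleep(0.05)


with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    inp.submit(fn=stream_words, inputs=inp, outputs=out)

demo.queue()  # generator handlers require the queue in Gradio 3.x
demo.launch()
```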
@@ -100,6 +118,19 @@ def do_inference(
             "max_new_tokens": max_new_tokens,
         }
 
+        def ui_generation_stopping_criteria(input_ids, score, **kwargs):
+            if Global.should_stop_generating:
+                return True
+            return False
+
+        Global.should_stop_generating = False
+        generate_params.setdefault(
+            "stopping_criteria", transformers.StoppingCriteriaList()
+        )
+        generate_params["stopping_criteria"].append(
+            ui_generation_stopping_criteria
+        )
+
         if stream_output:
             # Stream the reply 1 token at a time.
             # This is based on the trick of using 'stopping_criteria' to create an iterator,
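This is the heart of the fix: `transformers` consults every entry of `stopping_criteria` after each generated token and halts as soon as one returns `True`, so a plain function reading the shared flag makes `model.generate` interruptible mid-decode. The same hook in isolation (a sketch; `model`, `input_ids`, and the module-level flag are placeholders):

```python
import transformers

should_stop = False  # stands in for Global.should_stop_generating


def stop_on_flag(input_ids, scores, **kwargs):
    # Called once per generated token; returning True ends generation.
    return should_stop


generate_params = {"max_new_tokens": 256}
generate_params.setdefault(
    "stopping_criteria", transformers.StoppingCriteriaList()
)
generate_params["stopping_criteria"].append(stop_on_flag)

# Elsewhere (e.g. a Stop button handler) sets should_stop = True, and
# the next token boundary ends the call:
# output = model.generate(input_ids=input_ids, **generate_params)
```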
@@ -131,29 +162,61 @@ def do_inference(
                 raw_output = None
                 if show_raw:
                     raw_output = str(output)
-                yield prompter.get_response(decoded_output), raw_output
+                response = prompter.get_response(decoded_output)
+
+                if Global.should_stop_generating:
+                    return
+
+                yield (
+                    gr.Textbox.update(
+                        value=response, lines=inference_output_lines),
+                    raw_output)
+
+                if Global.should_stop_generating:
+                    # If the user stops the generation and then clicks the
+                    # generate button again, they may mysteriously land
+                    # here, in the previous, should-be-stopped generation
+                    # function call, with the new generation function not
+                    # called at all. To work around this, we yield a message
+                    # with lines=1, and if the front-end JS detects that
+                    # lines has been set to 1 (rows="1" in HTML), it will
+                    # automatically click the generate button again
+                    # (gr.Textbox.update() does not support updating
+                    # elem_classes or elem_id).
+                    # [WORKAROUND-UI01]
+                    yield (
+                        gr.Textbox.update(
+                            value="Please retry", lines=1),
+                        None)
             return  # early return for stream_output
 
         # Without streaming
         with torch.no_grad():
-            generation_output = model.generate(
-                input_ids=input_ids,
-                generation_config=generation_config,
-                return_dict_in_generate=True,
-                output_scores=True,
-                max_new_tokens=max_new_tokens,
-            )
+            generation_output = model.generate(**generate_params)
         s = generation_output.sequences[0]
         output = tokenizer.decode(s)
         raw_output = None
         if show_raw:
             raw_output = str(s)
-        yield prompter.get_response(output), raw_output
+
+        response = prompter.get_response(output)
+        if Global.should_stop_generating:
+            return
+
+        yield (
+            gr.Textbox.update(value=response, lines=inference_output_lines),
+            raw_output)
+
 
     except Exception as e:
         raise gr.Error(e)
 
 
+def handle_stop_generate():
+    Global.generation_force_stopped_at = time.time()
+    Global.should_stop_generating = True
+
+
 def reload_selections(current_lora_model, current_prompt_template):
     available_template_names = get_available_template_names()
     available_template_names_with_none = available_template_names + ["None"]
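The repeated `should_stop_generating` checks before each final `yield` guard against a stale generator publishing output after the user pressed Stop; the comment labelled [WORKAROUND-UI01] covers the remaining corner case, where Gradio resumes the old, should-be-stopped generator instead of starting a new one. The guard itself reduces to a small pattern (a condensed sketch with hypothetical names):

```python
def guarded_stream(chunks, is_cancelled):
    # is_cancelled: zero-argument callable, e.g.
    # lambda: Global.should_stop_generating
    for chunk in chunks:
        if is_cancelled():
            # A stale, should-be-stopped call returns without
            # publishing anything further.
            return
        yield chunk


for piece in guarded_stream(["Hello", ", ", "world"], lambda: False):
    print(piece, end="")
```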
@@ -186,7 +249,8 @@ def handle_prompt_template_change(prompt_template, lora_model):
         gr_updates.append(gr.Textbox.update(
             label="Not Used", visible=False))
 
-    model_prompt_template_message_update = gr.Markdown.update("", visible=False)
+    model_prompt_template_message_update = gr.Markdown.update(
+        "", visible=False)
     lora_mode_info = get_info_of_available_lora_model(lora_model)
     if lora_mode_info and isinstance(lora_mode_info, dict):
         model_prompt_template = lora_mode_info.get("prompt_template")
@@ -352,7 +416,7 @@ def inference_ui():
             with gr.Column(elem_id="inference_output_group_container"):
                 with gr.Column(elem_id="inference_output_group"):
                     inference_output = gr.Textbox(
-                        lines=12, label="Output", elem_id="inference_output")
+                        lines=inference_output_lines, label="Output", elem_id="inference_output")
                     inference_output.style(show_copy_button=True)
                     with gr.Accordion(
                         "Raw Output",
@@ -413,8 +477,12 @@ def inference_ui():
                 outputs=[inference_output, inference_raw_output],
                 api_name="inference"
             )
-            stop_btn.click(
-                fn=None, inputs=None, outputs=None, cancels=[generate_event])
+            stop_btn.click(
+                fn=handle_stop_generate,
+                inputs=None,
+                outputs=None,
+                cancels=[generate_event]
+            )
 
             update_prompt_preview_event = update_prompt_preview_btn.click(fn=update_prompt_preview, inputs=[prompt_template,
                                                                                                             variable_0, variable_1, variable_2, variable_3,
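`cancels=[generate_event]` is Gradio's built-in event cancellation: clicking Stop aborts the queued or streaming `generate_event`. The commit pairs it with `fn=handle_stop_generate` because cancelling the event only detaches the UI; the module-level flags are still needed so the stopping criteria can end the underlying `model.generate` call. A minimal sketch of the wiring (Gradio 3.x; the demo functions are placeholders):

```python
import time

import gradio as gr


def slow_generate():
    # Placeholder for the real streaming inference handler.
    for i in range(20):
        time.sleep(0.5)
        yield f"token {i}"


def handle_stop():
    # The app's handle_stop_generate sets the Global flags here.
    print("stop requested")


with gr.Blocks() as demo:
    out = gr.Textbox(label="Output")
    generate_btn = gr.Button("Generate")
    stop_btn = gr.Button("Stop")
    generate_event = generate_btn.click(
        fn=slow_generate, inputs=None, outputs=out)
    stop_btn.click(fn=handle_stop, inputs=None, outputs=None,
                   cancels=[generate_event])

demo.queue()
demo.launch()
```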
@@ -624,5 +692,27 @@ def inference_ui():
           });
         }
       }, 100);
+
+      // [WORKAROUND-UI01]
+      setTimeout(function () {
+        const inference_output_textarea = document.querySelector(
+          '#inference_output textarea'
+        );
+        if (!inference_output_textarea) return;
+        const observer = new MutationObserver(function () {
+          if (inference_output_textarea.getAttribute('rows') === '1') {
+            setTimeout(function () {
+              const inference_generate_btn = document.getElementById(
+                'inference_generate_btn'
+              );
+              if (inference_generate_btn) inference_generate_btn.click();
+            }, 10);
+          }
+        });
+        observer.observe(inference_output_textarea, {
+          attributes: true,
+          attributeFilter: ['rows'],
+        });
+      }, 100);
     }
     """)