Merge branch 'dev-2'
Changed files:
- llama_lora/lib/inference.py +3 -3
- llama_lora/models.py +60 -16
- llama_lora/ui/inference_ui.py +175 -25
- llama_lora/ui/main_page.py +39 -0
llama_lora/lib/inference.py
CHANGED

@@ -66,14 +66,14 @@ def generate(
         with generate_with_streaming(**generate_params) as generator:
             for output in generator:
                 decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
-                yield decoded_output, output
+                yield decoded_output, output, False
                 if output[-1] in [tokenizer.eos_token_id]:
                     break

         if generation_output:
             output = generation_output.sequences[0]
             decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
-            yield decoded_output, output
+            yield decoded_output, output, True

         return  # early return for stream_output

@@ -82,5 +82,5 @@ def generate(
     generation_output = model.generate(**generate_params)
     output = generation_output.sequences[0]
     decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
-    yield decoded_output, output
+    yield decoded_output, output, True
     return
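For reference, `generate` now yields a third value, a `completed` flag, next to the decoded text and the raw token IDs, so callers can tell streamed partial results from the final one. A minimal sketch of a consumer loop, where `generation_args` stands in for the real keyword arguments built by the caller:

# Minimal consumer sketch; `generation_args` is a placeholder for the real
# arguments (model, tokenizer, generation_config, stream_output, ...).
for decoded_output, output, completed in generate(**generation_args):
    if completed:
        print("final:", decoded_output)    # last, fully generated text
    else:
        print("partial:", decoded_output)  # intermediate result while streaming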
llama_lora/models.py
CHANGED

@@ -5,7 +5,10 @@ import json
 import re

 import torch
-from transformers import
+from transformers import (
+    AutoModelForCausalLM, AutoModel,
+    AutoTokenizer, LlamaTokenizer
+)
 from peft import PeftModel

 from .globals import Global
@@ -27,42 +30,83 @@ def get_new_base_model(base_model_name):
     Global.name_of_new_base_model_that_is_ready_to_be_used = None
     clear_cache()

+    model_class = AutoModelForCausalLM
+    from_tf = False
+    force_download = False
+    has_tried_force_download = False
+    while True:
+        try:
+            model = _get_model_from_pretrained(
+                model_class, base_model_name, from_tf=from_tf, force_download=force_download)
+            break
+        except Exception as e:
+            if 'from_tf' in str(e):
+                print(
+                    f"Got error while loading model {base_model_name} with AutoModelForCausalLM: {e}.")
+                print("Retrying with from_tf=True...")
+                from_tf = True
+                force_download = False
+            elif model_class == AutoModelForCausalLM:
+                print(
+                    f"Got error while loading model {base_model_name} with AutoModelForCausalLM: {e}.")
+                print("Retrying with AutoModel...")
+                model_class = AutoModel
+                force_download = False
+            else:
+                if has_tried_force_download:
+                    raise e
+                print(
+                    f"Got error while loading model {base_model_name}: {e}.")
+                print("Retrying with force_download=True...")
+                model_class = AutoModelForCausalLM
+                from_tf = False
+                force_download = True
+                has_tried_force_download = True
+
+    tokenizer = get_tokenizer(base_model_name)
+
+    if re.match("[^/]+/llama", base_model_name):
+        model.config.pad_token_id = tokenizer.pad_token_id = 0
+        model.config.bos_token_id = tokenizer.bos_token_id = 1
+        model.config.eos_token_id = tokenizer.eos_token_id = 2
+
+    return model
+
+
+def _get_model_from_pretrained(model_class, model_name, from_tf=False, force_download=False):
     device = get_device()

     if device == "cuda":
-
-
+        return model_class.from_pretrained(
+            model_name,
             load_in_8bit=Global.load_8bit,
             torch_dtype=torch.float16,
             # device_map="auto",
             # ? https://github.com/tloen/alpaca-lora/issues/21
             device_map={'': 0},
+            from_tf=from_tf,
+            force_download=force_download,
             trust_remote_code=Global.trust_remote_code
         )
     elif device == "mps":
-
-
+        return model_class.from_pretrained(
+            model_name,
             device_map={"": device},
             torch_dtype=torch.float16,
+            from_tf=from_tf,
+            force_download=force_download,
             trust_remote_code=Global.trust_remote_code
         )
     else:
-
-
+        return model_class.from_pretrained(
+            model_name,
             device_map={"": device},
             low_cpu_mem_usage=True,
+            from_tf=from_tf,
+            force_download=force_download,
             trust_remote_code=Global.trust_remote_code
         )

-    tokenizer = get_tokenizer(base_model_name)
-
-    if re.match("[^/]+/llama", base_model_name):
-        model.config.pad_token_id = tokenizer.pad_token_id = 0
-        model.config.bos_token_id = tokenizer.bos_token_id = 1
-        model.config.eos_token_id = tokenizer.eos_token_id = 2
-
-    return model
-

 def get_tokenizer(base_model_name):
     if Global.ui_dev_mode:
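The new loading logic retries `from_pretrained` with progressively different options: `from_tf=True` when the error mentions a TensorFlow checkpoint, then `AutoModel` in place of `AutoModelForCausalLM`, and finally a single `force_download=True` attempt before giving up. A stripped-down sketch of that fallback ladder, assuming a hypothetical `load` callable that stands in for `_get_model_from_pretrained`:

from transformers import AutoModel, AutoModelForCausalLM

def load_with_fallbacks(model_name, load):
    # `load(model_class, name, from_tf=..., force_download=...)` is assumed to
    # wrap model_class.from_pretrained(), like _get_model_from_pretrained above.
    model_class = AutoModelForCausalLM
    from_tf = False
    force_download = False
    tried_force_download = False
    while True:
        try:
            return load(model_class, model_name,
                        from_tf=from_tf, force_download=force_download)
        except Exception as e:
            if 'from_tf' in str(e) and not from_tf:
                from_tf = True                    # checkpoint looks like a TF one
            elif model_class is AutoModelForCausalLM:
                model_class = AutoModel           # try without the causal-LM head
            elif not tried_force_download:
                model_class = AutoModelForCausalLM
                from_tf = False
                force_download = True             # cache may be corrupt; re-download once
                tried_force_download = True
            else:
                raise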
llama_lora/ui/inference_ui.py
CHANGED

@@ -1,4 +1,5 @@
 import gradio as gr
+import os
 import time
 import json

@@ -21,13 +22,21 @@ default_show_raw = True
 inference_output_lines = 12


+class LoggingItem:
+    def __init__(self, label):
+        self.label = label
+
+    def deserialize(self, value, **kwargs):
+        return value
+
+
 def prepare_inference(lora_model_name, progress=gr.Progress(track_tqdm=True)):
     base_model_name = Global.base_model_name

     try:
         get_tokenizer(base_model_name)
         get_model(base_model_name, lora_model_name)
-        return ("", "")
+        return ("", "", gr.Textbox.update(visible=False))

     except Exception as e:
         raise gr.Error(e)
@@ -65,6 +74,31 @@ def do_inference(
     prompter = Prompter(prompt_template)
     prompt = prompter.generate_prompt(variables)

+    generation_config = GenerationConfig(
+        # to avoid ValueError('`temperature` has to be a strictly positive float, but is 2')
+        temperature=float(temperature),
+        top_p=top_p,
+        top_k=top_k,
+        repetition_penalty=repetition_penalty,
+        num_beams=num_beams,
+        # https://github.com/huggingface/transformers/issues/22405#issuecomment-1485527953
+        do_sample=temperature > 0,
+    )
+
+    def get_output_for_flagging(output, raw_output, completed=True):
+        return json.dumps({
+            'base_model': base_model_name,
+            'adaptor_model': lora_model_name,
+            'prompt': prompt,
+            'output': output,
+            'completed': completed,
+            'raw_output': raw_output,
+            'max_new_tokens': max_new_tokens,
+            'prompt_template': prompt_template,
+            'prompt_template_variables': variables,
+            'generation_config': generation_config.to_dict(),
+        })
+
     if Global.ui_dev_mode:
         message = f"Hi, I'm currently in UI-development mode and do not have access to resources to process your request. However, this behavior is similar to what will actually happen, so you can try and see how it will work!\n\nBase model: {base_model_name}\nLoRA model: {lora_model_name}\n\nThe following is your prompt:\n\n{prompt}"
         print(message)
@@ -83,35 +117,50 @@ def do_inference(
                 out += "\n"
                 yield out

+            output = ""
             for partial_sentence in word_generator(message):
+                output = partial_sentence
                 yield (
                     gr.Textbox.update(
-                        value=
+                        value=output,
+                        lines=inference_output_lines),
                     json.dumps(
-                        list(range(len(
+                        list(range(len(output.split()))),
+                        indent=2),
+                    gr.Textbox.update(
+                        value=get_output_for_flagging(
+                            output, "", completed=False),
+                        visible=True)
                 )
                 time.sleep(0.05)

+            yield (
+                gr.Textbox.update(
+                    value=output,
+                    lines=inference_output_lines),
+                json.dumps(
+                    list(range(len(output.split()))),
+                    indent=2),
+                gr.Textbox.update(
+                    value=get_output_for_flagging(
+                        output, "", completed=True),
+                    visible=True)
+            )
+
             return
         time.sleep(1)
         yield (
             gr.Textbox.update(value=message, lines=inference_output_lines),
-            json.dumps(list(range(len(message.split()))), indent=2)
+            json.dumps(list(range(len(message.split()))), indent=2),
+            gr.Textbox.update(
+                value=get_output_for_flagging(message, ""),
+                visible=True)
         )
         return

     tokenizer = get_tokenizer(base_model_name)
     model = get_model(base_model_name, lora_model_name)

-    generation_config = GenerationConfig(
-        temperature=float(temperature),  # to avoid ValueError('`temperature` has to be a strictly positive float, but is 2')
-        top_p=top_p,
-        top_k=top_k,
-        repetition_penalty=repetition_penalty,
-        num_beams=num_beams,
-        do_sample=temperature > 0,  # https://github.com/huggingface/transformers/issues/22405#issuecomment-1485527953
-    )
-
     def ui_generation_stopping_criteria(input_ids, score, **kwargs):
         if Global.should_stop_generating:
             return True
@@ -129,10 +178,8 @@ def do_inference(
         'stream_output': stream_output
     }

-    for (decoded_output, output) in generate(**generation_args):
-        raw_output_str =
-        if show_raw:
-            raw_output_str = str(output)
+    for (decoded_output, output, completed) in generate(**generation_args):
+        raw_output_str = str(output)
         response = prompter.get_response(decoded_output)

         if Global.should_stop_generating:
@@ -141,7 +188,12 @@ def do_inference(
         yield (
             gr.Textbox.update(
                 value=response, lines=inference_output_lines),
-            raw_output_str
+            raw_output_str,
+            gr.Textbox.update(
+                value=get_output_for_flagging(
+                    decoded_output, raw_output_str, completed=completed),
+                visible=True)
+        )

         if Global.should_stop_generating:
             # If the user stops the generation, and then clicks the
@@ -199,11 +251,13 @@ def get_warning_message_for_lora_model_and_prompt_template(lora_model, prompt_template):
     if lora_mode_info and isinstance(lora_mode_info, dict):
         model_base_model = lora_mode_info.get("base_model")
         if model_base_model and model_base_model != Global.base_model_name:
-            messages.append(
+            messages.append(
+                f"⚠️ This model was trained on top of base model `{model_base_model}`, it might not work properly with the selected base model `{Global.base_model_name}`.")

         model_prompt_template = lora_mode_info.get("prompt_template")
         if model_prompt_template and model_prompt_template != prompt_template:
-            messages.append(
+            messages.append(
+                f"This model was trained with prompt template `{model_prompt_template}`.")

     return " ".join(messages)

@@ -221,7 +275,8 @@ def handle_prompt_template_change(prompt_template, lora_model):

     model_prompt_template_message_update = gr.Markdown.update(
         "", visible=False)
-    warning_message = get_warning_message_for_lora_model_and_prompt_template(
+    warning_message = get_warning_message_for_lora_model_and_prompt_template(
+        lora_model, prompt_template)
     if warning_message:
         model_prompt_template_message_update = gr.Markdown.update(
             warning_message, visible=True)
@@ -241,7 +296,8 @@ def handle_lora_model_change(lora_model, prompt_template):

     model_prompt_template_message_update = gr.Markdown.update(
         "", visible=False)
-    warning_message = get_warning_message_for_lora_model_and_prompt_template(
+    warning_message = get_warning_message_for_lora_model_and_prompt_template(
+        lora_model, prompt_template)
     if warning_message:
         model_prompt_template_message_update = gr.Markdown.update(
             warning_message, visible=True)
@@ -260,6 +316,56 @@ def update_prompt_preview(prompt_template,


 def inference_ui():
+    flagging_dir = os.path.join(Global.data_dir, "flagging", "inference")
+    if not os.path.exists(flagging_dir):
+        os.makedirs(flagging_dir)
+
+    flag_callback = gr.CSVLogger()
+    flag_components = [
+        LoggingItem("Base Model"),
+        LoggingItem("Adaptor Model"),
+        LoggingItem("Type"),
+        LoggingItem("Prompt"),
+        LoggingItem("Output"),
+        LoggingItem("Completed"),
+        LoggingItem("Config"),
+        LoggingItem("Raw Output"),
+        LoggingItem("Max New Tokens"),
+        LoggingItem("Prompt Template"),
+        LoggingItem("Prompt Template Variables"),
+        LoggingItem("Generation Config"),
+    ]
+    flag_callback.setup(flag_components, flagging_dir)
+
+    def get_flag_callback_args(output_for_flagging_str, flag_type):
+        output_for_flagging = json.loads(output_for_flagging_str)
+        generation_config = output_for_flagging.get("generation_config", {})
+        config = []
+        if generation_config.get('do_sample', False):
+            config.append(
+                f"Temperature: {generation_config.get('temperature')}")
+            config.append(f"Top P: {generation_config.get('top_p')}")
+            config.append(f"Top K: {generation_config.get('top_k')}")
+        num_beams = generation_config.get('num_beams', 1)
+        if num_beams > 1:
+            config.append(f"Beams: {generation_config.get('num_beams')}")
+        config.append(f"RP: {generation_config.get('repetition_penalty')}")
+        return [
+            output_for_flagging.get("base_model", ""),
+            output_for_flagging.get("adaptor_model", ""),
+            flag_type,
+            output_for_flagging.get("prompt", ""),
+            output_for_flagging.get("output", ""),
+            str(output_for_flagging.get("completed", "")),
+            ", ".join(config),
+            output_for_flagging.get("raw_output", ""),
+            str(output_for_flagging.get("max_new_tokens", "")),
+            output_for_flagging.get("prompt_template", ""),
+            json.dumps(output_for_flagging.get(
+                "prompt_template_variables", "")),
+            json.dumps(output_for_flagging.get("generation_config", "")),
+        ]
+
     things_that_might_timeout = []

     with gr.Blocks() as inference_ui_blocks:
@@ -387,6 +493,47 @@ def inference_ui():
                     inference_output = gr.Textbox(
                         lines=inference_output_lines, label="Output", elem_id="inference_output")
                     inference_output.style(show_copy_button=True)
+
+                    with gr.Row(elem_id="inference_flagging_group"):
+                        output_for_flagging = gr.Textbox(
+                            interactive=False, visible=False,
+                            elem_id="inference_output_for_flagging")
+                        flag_btn = gr.Button(
+                            "Flag", elem_id="inference_flag_btn")
+                        flag_up_btn = gr.Button(
+                            "👍", elem_id="inference_flag_up_btn")
+                        flag_down_btn = gr.Button(
+                            "👎", elem_id="inference_flag_down_btn")
+                        flag_output = gr.Markdown(
+                            "", elem_id="inference_flag_output")
+                        flag_btn.click(
+                            lambda d: (flag_callback.flag(
+                                get_flag_callback_args(d, "Flag"),
+                                flag_option="Flag",
+                                username=None
+                            ), "")[1],
+                            inputs=[output_for_flagging],
+                            outputs=[flag_output],
+                            preprocess=False)
+                        flag_up_btn.click(
+                            lambda d: (flag_callback.flag(
+                                get_flag_callback_args(d, "👍"),
+                                flag_option="Up Vote",
+                                username=None
+                            ), "")[1],
+                            inputs=[output_for_flagging],
+                            outputs=[flag_output],
+                            preprocess=False)
+                        flag_down_btn.click(
+                            lambda d: (flag_callback.flag(
+                                get_flag_callback_args(d, "👎"),
+                                flag_option="Down Vote",
+                                username=None
+                            ), "")[1],
+                            inputs=[output_for_flagging],
+                            outputs=[flag_output],
+                            preprocess=False)
+
                     with gr.Accordion(
                         "Raw Output",
                         open=not default_show_raw,
@@ -400,7 +547,8 @@ def inference_ui():
                             interactive=False,
                             elem_id="inference_raw_output")

-        reload_selected_models_btn = gr.Button(
+        reload_selected_models_btn = gr.Button(
+            "", elem_id="inference_reload_selected_models_btn")

         show_raw_change_event = show_raw.change(
             fn=lambda show_raw: gr.Accordion.update(visible=show_raw),
@@ -440,7 +588,8 @@ def inference_ui():
         generate_event = generate_btn.click(
             fn=prepare_inference,
             inputs=[lora_model],
-            outputs=[inference_output,
+            outputs=[inference_output,
+                     inference_raw_output, output_for_flagging],
         ).then(
             fn=do_inference,
             inputs=[
@@ -457,7 +606,8 @@ def inference_ui():
                 stream_output,
                 show_raw,
             ],
-            outputs=[inference_output,
+            outputs=[inference_output,
+                     inference_raw_output, output_for_flagging],
            api_name="inference"
         )
         stop_btn.click(
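The flagging support added above serializes each result to a JSON string held in a hidden `output_for_flagging` textbox, and the Flag / 👍 / 👎 buttons turn that string into a row that `gr.CSVLogger` appends under the flagging directory (the `LoggingItem` stand-ins let plain strings be logged without binding real components). A minimal, self-contained sketch of the same `gr.CSVLogger` pattern, with illustrative labels, a hypothetical `demo_flagging` directory, and the Gradio 3.x API assumed throughout:

import gradio as gr

flag_callback = gr.CSVLogger()

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    output = gr.Textbox(label="Output", interactive=False)
    flag_btn = gr.Button("Flag")
    flag_status = gr.Markdown("")

    # setup() records the column layout and prepares log.csv in the directory.
    flag_callback.setup([prompt, output], "demo_flagging")

    # flag() appends one CSV row; the lambda returns "" so flag_status stays empty.
    flag_btn.click(
        lambda p, o: (flag_callback.flag([p, o], flag_option="Flag", username=None), "")[1],
        inputs=[prompt, output],
        outputs=[flag_status],
    )

demo.launch()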
llama_lora/ui/main_page.py
CHANGED

@@ -398,6 +398,45 @@ def main_page_custom_css():
     bottom: 16px;
 }

+#inference_flagging_group {
+    position: relative;
+}
+#inference_flag_output {
+    min-height: 1px !important;
+    position: absolute;
+    top: 0;
+    bottom: 0;
+    right: 0;
+    pointer-events: none;
+    opacity: 0.5;
+}
+#inference_flag_output .wrap {
+    top: 0;
+    bottom: 0;
+    right: 0;
+    justify-content: center;
+    align-items: flex-end;
+    padding: 4px !important;
+}
+#inference_flag_output .wrap svg {
+    display: none;
+}
+.form:has(> #inference_output_for_flagging),
+#inference_output_for_flagging {
+    display: none;
+}
+#inference_flagging_group:has(#inference_output_for_flagging.hidden) {
+    opacity: 0.5;
+    pointer-events: none;
+}
+#inference_flag_up_btn, #inference_flag_down_btn {
+    min-width: 44px;
+    flex-grow: 1;
+}
+#inference_flag_btn {
+    flex-grow: 2;
+}
+
 #dataset_plain_text_input_variables_separator textarea,
 #dataset_plain_text_input_and_output_separator textarea,
 #dataset_plain_text_data_separator textarea {