zetavg commited on
Commit
889210b
Β·
unverified Β·
2 Parent(s): 6201a81 750c900

Merge branch 'dev-2'

Browse files
llama_lora/lib/inference.py CHANGED
@@ -66,14 +66,14 @@ def generate(
66
  with generate_with_streaming(**generate_params) as generator:
67
  for output in generator:
68
  decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
69
- yield decoded_output, output
70
  if output[-1] in [tokenizer.eos_token_id]:
71
  break
72
 
73
  if generation_output:
74
  output = generation_output.sequences[0]
75
  decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
76
- yield decoded_output, output
77
 
78
  return # early return for stream_output
79
 
@@ -82,5 +82,5 @@ def generate(
82
  generation_output = model.generate(**generate_params)
83
  output = generation_output.sequences[0]
84
  decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
85
- yield decoded_output, output
86
  return
 
66
  with generate_with_streaming(**generate_params) as generator:
67
  for output in generator:
68
  decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
69
+ yield decoded_output, output, False
70
  if output[-1] in [tokenizer.eos_token_id]:
71
  break
72
 
73
  if generation_output:
74
  output = generation_output.sequences[0]
75
  decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
76
+ yield decoded_output, output, True
77
 
78
  return # early return for stream_output
79
 
 
82
  generation_output = model.generate(**generate_params)
83
  output = generation_output.sequences[0]
84
  decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
85
+ yield decoded_output, output, True
86
  return
llama_lora/models.py CHANGED
@@ -5,7 +5,10 @@ import json
5
  import re
6
 
7
  import torch
8
- from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
 
 
 
9
  from peft import PeftModel
10
 
11
  from .globals import Global
@@ -27,42 +30,83 @@ def get_new_base_model(base_model_name):
27
  Global.name_of_new_base_model_that_is_ready_to_be_used = None
28
  clear_cache()
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  device = get_device()
31
 
32
  if device == "cuda":
33
- model = AutoModelForCausalLM.from_pretrained(
34
- base_model_name,
35
  load_in_8bit=Global.load_8bit,
36
  torch_dtype=torch.float16,
37
  # device_map="auto",
38
  # ? https://github.com/tloen/alpaca-lora/issues/21
39
  device_map={'': 0},
 
 
40
  trust_remote_code=Global.trust_remote_code
41
  )
42
  elif device == "mps":
43
- model = AutoModelForCausalLM.from_pretrained(
44
- base_model_name,
45
  device_map={"": device},
46
  torch_dtype=torch.float16,
 
 
47
  trust_remote_code=Global.trust_remote_code
48
  )
49
  else:
50
- model = AutoModelForCausalLM.from_pretrained(
51
- base_model_name,
52
  device_map={"": device},
53
  low_cpu_mem_usage=True,
 
 
54
  trust_remote_code=Global.trust_remote_code
55
  )
56
 
57
- tokenizer = get_tokenizer(base_model_name)
58
-
59
- if re.match("[^/]+/llama", base_model_name):
60
- model.config.pad_token_id = tokenizer.pad_token_id = 0
61
- model.config.bos_token_id = tokenizer.bos_token_id = 1
62
- model.config.eos_token_id = tokenizer.eos_token_id = 2
63
-
64
- return model
65
-
66
 
67
  def get_tokenizer(base_model_name):
68
  if Global.ui_dev_mode:
 
5
  import re
6
 
7
  import torch
8
+ from transformers import (
9
+ AutoModelForCausalLM, AutoModel,
10
+ AutoTokenizer, LlamaTokenizer
11
+ )
12
  from peft import PeftModel
13
 
14
  from .globals import Global
 
30
  Global.name_of_new_base_model_that_is_ready_to_be_used = None
31
  clear_cache()
32
 
33
+ model_class = AutoModelForCausalLM
34
+ from_tf = False
35
+ force_download = False
36
+ has_tried_force_download = False
37
+ while True:
38
+ try:
39
+ model = _get_model_from_pretrained(
40
+ model_class, base_model_name, from_tf=from_tf, force_download=force_download)
41
+ break
42
+ except Exception as e:
43
+ if 'from_tf' in str(e):
44
+ print(
45
+ f"Got error while loading model {base_model_name} with AutoModelForCausalLM: {e}.")
46
+ print("Retrying with from_tf=True...")
47
+ from_tf = True
48
+ force_download = False
49
+ elif model_class == AutoModelForCausalLM:
50
+ print(
51
+ f"Got error while loading model {base_model_name} with AutoModelForCausalLM: {e}.")
52
+ print("Retrying with AutoModel...")
53
+ model_class = AutoModel
54
+ force_download = False
55
+ else:
56
+ if has_tried_force_download:
57
+ raise e
58
+ print(
59
+ f"Got error while loading model {base_model_name}: {e}.")
60
+ print("Retrying with force_download=True...")
61
+ model_class = AutoModelForCausalLM
62
+ from_tf = False
63
+ force_download = True
64
+ has_tried_force_download = True
65
+
66
+ tokenizer = get_tokenizer(base_model_name)
67
+
68
+ if re.match("[^/]+/llama", base_model_name):
69
+ model.config.pad_token_id = tokenizer.pad_token_id = 0
70
+ model.config.bos_token_id = tokenizer.bos_token_id = 1
71
+ model.config.eos_token_id = tokenizer.eos_token_id = 2
72
+
73
+ return model
74
+
75
+
76
+ def _get_model_from_pretrained(model_class, model_name, from_tf=False, force_download=False):
77
  device = get_device()
78
 
79
  if device == "cuda":
80
+ return model_class.from_pretrained(
81
+ model_name,
82
  load_in_8bit=Global.load_8bit,
83
  torch_dtype=torch.float16,
84
  # device_map="auto",
85
  # ? https://github.com/tloen/alpaca-lora/issues/21
86
  device_map={'': 0},
87
+ from_tf=from_tf,
88
+ force_download=force_download,
89
  trust_remote_code=Global.trust_remote_code
90
  )
91
  elif device == "mps":
92
+ return model_class.from_pretrained(
93
+ model_name,
94
  device_map={"": device},
95
  torch_dtype=torch.float16,
96
+ from_tf=from_tf,
97
+ force_download=force_download,
98
  trust_remote_code=Global.trust_remote_code
99
  )
100
  else:
101
+ return model_class.from_pretrained(
102
+ model_name,
103
  device_map={"": device},
104
  low_cpu_mem_usage=True,
105
+ from_tf=from_tf,
106
+ force_download=force_download,
107
  trust_remote_code=Global.trust_remote_code
108
  )
109
 
 
 
 
 
 
 
 
 
 
110
 
111
  def get_tokenizer(base_model_name):
112
  if Global.ui_dev_mode:
llama_lora/ui/inference_ui.py CHANGED
@@ -1,4 +1,5 @@
1
  import gradio as gr
 
2
  import time
3
  import json
4
 
@@ -21,13 +22,21 @@ default_show_raw = True
21
  inference_output_lines = 12
22
 
23
 
 
 
 
 
 
 
 
 
24
  def prepare_inference(lora_model_name, progress=gr.Progress(track_tqdm=True)):
25
  base_model_name = Global.base_model_name
26
 
27
  try:
28
  get_tokenizer(base_model_name)
29
  get_model(base_model_name, lora_model_name)
30
- return ("", "")
31
 
32
  except Exception as e:
33
  raise gr.Error(e)
@@ -65,6 +74,31 @@ def do_inference(
65
  prompter = Prompter(prompt_template)
66
  prompt = prompter.generate_prompt(variables)
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  if Global.ui_dev_mode:
69
  message = f"Hi, I’m currently in UI-development mode and do not have access to resources to process your request. However, this behavior is similar to what will actually happen, so you can try and see how it will work!\n\nBase model: {base_model_name}\nLoRA model: {lora_model_name}\n\nThe following is your prompt:\n\n{prompt}"
70
  print(message)
@@ -83,35 +117,50 @@ def do_inference(
83
  out += "\n"
84
  yield out
85
 
 
86
  for partial_sentence in word_generator(message):
 
87
  yield (
88
  gr.Textbox.update(
89
- value=partial_sentence, lines=inference_output_lines),
 
90
  json.dumps(
91
- list(range(len(partial_sentence.split()))), indent=2)
 
 
 
 
 
92
  )
93
  time.sleep(0.05)
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  return
96
  time.sleep(1)
97
  yield (
98
  gr.Textbox.update(value=message, lines=inference_output_lines),
99
- json.dumps(list(range(len(message.split()))), indent=2)
 
 
 
100
  )
101
  return
102
 
103
  tokenizer = get_tokenizer(base_model_name)
104
  model = get_model(base_model_name, lora_model_name)
105
 
106
- generation_config = GenerationConfig(
107
- temperature=float(temperature), # to avoid ValueError('`temperature` has to be a strictly positive float, but is 2')
108
- top_p=top_p,
109
- top_k=top_k,
110
- repetition_penalty=repetition_penalty,
111
- num_beams=num_beams,
112
- do_sample=temperature > 0, # https://github.com/huggingface/transformers/issues/22405#issuecomment-1485527953
113
- )
114
-
115
  def ui_generation_stopping_criteria(input_ids, score, **kwargs):
116
  if Global.should_stop_generating:
117
  return True
@@ -129,10 +178,8 @@ def do_inference(
129
  'stream_output': stream_output
130
  }
131
 
132
- for (decoded_output, output) in generate(**generation_args):
133
- raw_output_str = None
134
- if show_raw:
135
- raw_output_str = str(output)
136
  response = prompter.get_response(decoded_output)
137
 
138
  if Global.should_stop_generating:
@@ -141,7 +188,12 @@ def do_inference(
141
  yield (
142
  gr.Textbox.update(
143
  value=response, lines=inference_output_lines),
144
- raw_output_str)
 
 
 
 
 
145
 
146
  if Global.should_stop_generating:
147
  # If the user stops the generation, and then clicks the
@@ -199,11 +251,13 @@ def get_warning_message_for_lora_model_and_prompt_template(lora_model, prompt_te
199
  if lora_mode_info and isinstance(lora_mode_info, dict):
200
  model_base_model = lora_mode_info.get("base_model")
201
  if model_base_model and model_base_model != Global.base_model_name:
202
- messages.append(f"⚠️ This model was trained on top of base model `{model_base_model}`, it might not work properly with the selected base model `{Global.base_model_name}`.")
 
203
 
204
  model_prompt_template = lora_mode_info.get("prompt_template")
205
  if model_prompt_template and model_prompt_template != prompt_template:
206
- messages.append(f"This model was trained with prompt template `{model_prompt_template}`.")
 
207
 
208
  return " ".join(messages)
209
 
@@ -221,7 +275,8 @@ def handle_prompt_template_change(prompt_template, lora_model):
221
 
222
  model_prompt_template_message_update = gr.Markdown.update(
223
  "", visible=False)
224
- warning_message = get_warning_message_for_lora_model_and_prompt_template(lora_model, prompt_template)
 
225
  if warning_message:
226
  model_prompt_template_message_update = gr.Markdown.update(
227
  warning_message, visible=True)
@@ -241,7 +296,8 @@ def handle_lora_model_change(lora_model, prompt_template):
241
 
242
  model_prompt_template_message_update = gr.Markdown.update(
243
  "", visible=False)
244
- warning_message = get_warning_message_for_lora_model_and_prompt_template(lora_model, prompt_template)
 
245
  if warning_message:
246
  model_prompt_template_message_update = gr.Markdown.update(
247
  warning_message, visible=True)
@@ -260,6 +316,56 @@ def update_prompt_preview(prompt_template,
260
 
261
 
262
  def inference_ui():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  things_that_might_timeout = []
264
 
265
  with gr.Blocks() as inference_ui_blocks:
@@ -387,6 +493,47 @@ def inference_ui():
387
  inference_output = gr.Textbox(
388
  lines=inference_output_lines, label="Output", elem_id="inference_output")
389
  inference_output.style(show_copy_button=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
  with gr.Accordion(
391
  "Raw Output",
392
  open=not default_show_raw,
@@ -400,7 +547,8 @@ def inference_ui():
400
  interactive=False,
401
  elem_id="inference_raw_output")
402
 
403
- reload_selected_models_btn = gr.Button("", elem_id="inference_reload_selected_models_btn")
 
404
 
405
  show_raw_change_event = show_raw.change(
406
  fn=lambda show_raw: gr.Accordion.update(visible=show_raw),
@@ -440,7 +588,8 @@ def inference_ui():
440
  generate_event = generate_btn.click(
441
  fn=prepare_inference,
442
  inputs=[lora_model],
443
- outputs=[inference_output, inference_raw_output],
 
444
  ).then(
445
  fn=do_inference,
446
  inputs=[
@@ -457,7 +606,8 @@ def inference_ui():
457
  stream_output,
458
  show_raw,
459
  ],
460
- outputs=[inference_output, inference_raw_output],
 
461
  api_name="inference"
462
  )
463
  stop_btn.click(
 
1
  import gradio as gr
2
+ import os
3
  import time
4
  import json
5
 
 
22
  inference_output_lines = 12
23
 
24
 
25
+ class LoggingItem:
26
+ def __init__(self, label):
27
+ self.label = label
28
+
29
+ def deserialize(self, value, **kwargs):
30
+ return value
31
+
32
+
33
  def prepare_inference(lora_model_name, progress=gr.Progress(track_tqdm=True)):
34
  base_model_name = Global.base_model_name
35
 
36
  try:
37
  get_tokenizer(base_model_name)
38
  get_model(base_model_name, lora_model_name)
39
+ return ("", "", gr.Textbox.update(visible=False))
40
 
41
  except Exception as e:
42
  raise gr.Error(e)
 
74
  prompter = Prompter(prompt_template)
75
  prompt = prompter.generate_prompt(variables)
76
 
77
+ generation_config = GenerationConfig(
78
+ # to avoid ValueError('`temperature` has to be a strictly positive float, but is 2')
79
+ temperature=float(temperature),
80
+ top_p=top_p,
81
+ top_k=top_k,
82
+ repetition_penalty=repetition_penalty,
83
+ num_beams=num_beams,
84
+ # https://github.com/huggingface/transformers/issues/22405#issuecomment-1485527953
85
+ do_sample=temperature > 0,
86
+ )
87
+
88
+ def get_output_for_flagging(output, raw_output, completed=True):
89
+ return json.dumps({
90
+ 'base_model': base_model_name,
91
+ 'adaptor_model': lora_model_name,
92
+ 'prompt': prompt,
93
+ 'output': output,
94
+ 'completed': completed,
95
+ 'raw_output': raw_output,
96
+ 'max_new_tokens': max_new_tokens,
97
+ 'prompt_template': prompt_template,
98
+ 'prompt_template_variables': variables,
99
+ 'generation_config': generation_config.to_dict(),
100
+ })
101
+
102
  if Global.ui_dev_mode:
103
  message = f"Hi, I’m currently in UI-development mode and do not have access to resources to process your request. However, this behavior is similar to what will actually happen, so you can try and see how it will work!\n\nBase model: {base_model_name}\nLoRA model: {lora_model_name}\n\nThe following is your prompt:\n\n{prompt}"
104
  print(message)
 
117
  out += "\n"
118
  yield out
119
 
120
+ output = ""
121
  for partial_sentence in word_generator(message):
122
+ output = partial_sentence
123
  yield (
124
  gr.Textbox.update(
125
+ value=output,
126
+ lines=inference_output_lines),
127
  json.dumps(
128
+ list(range(len(output.split()))),
129
+ indent=2),
130
+ gr.Textbox.update(
131
+ value=get_output_for_flagging(
132
+ output, "", completed=False),
133
+ visible=True)
134
  )
135
  time.sleep(0.05)
136
 
137
+ yield (
138
+ gr.Textbox.update(
139
+ value=output,
140
+ lines=inference_output_lines),
141
+ json.dumps(
142
+ list(range(len(output.split()))),
143
+ indent=2),
144
+ gr.Textbox.update(
145
+ value=get_output_for_flagging(
146
+ output, "", completed=True),
147
+ visible=True)
148
+ )
149
+
150
  return
151
  time.sleep(1)
152
  yield (
153
  gr.Textbox.update(value=message, lines=inference_output_lines),
154
+ json.dumps(list(range(len(message.split()))), indent=2),
155
+ gr.Textbox.update(
156
+ value=get_output_for_flagging(message, ""),
157
+ visible=True)
158
  )
159
  return
160
 
161
  tokenizer = get_tokenizer(base_model_name)
162
  model = get_model(base_model_name, lora_model_name)
163
 
 
 
 
 
 
 
 
 
 
164
  def ui_generation_stopping_criteria(input_ids, score, **kwargs):
165
  if Global.should_stop_generating:
166
  return True
 
178
  'stream_output': stream_output
179
  }
180
 
181
+ for (decoded_output, output, completed) in generate(**generation_args):
182
+ raw_output_str = str(output)
 
 
183
  response = prompter.get_response(decoded_output)
184
 
185
  if Global.should_stop_generating:
 
188
  yield (
189
  gr.Textbox.update(
190
  value=response, lines=inference_output_lines),
191
+ raw_output_str,
192
+ gr.Textbox.update(
193
+ value=get_output_for_flagging(
194
+ decoded_output, raw_output_str, completed=completed),
195
+ visible=True)
196
+ )
197
 
198
  if Global.should_stop_generating:
199
  # If the user stops the generation, and then clicks the
 
251
  if lora_mode_info and isinstance(lora_mode_info, dict):
252
  model_base_model = lora_mode_info.get("base_model")
253
  if model_base_model and model_base_model != Global.base_model_name:
254
+ messages.append(
255
+ f"⚠️ This model was trained on top of base model `{model_base_model}`, it might not work properly with the selected base model `{Global.base_model_name}`.")
256
 
257
  model_prompt_template = lora_mode_info.get("prompt_template")
258
  if model_prompt_template and model_prompt_template != prompt_template:
259
+ messages.append(
260
+ f"This model was trained with prompt template `{model_prompt_template}`.")
261
 
262
  return " ".join(messages)
263
 
 
275
 
276
  model_prompt_template_message_update = gr.Markdown.update(
277
  "", visible=False)
278
+ warning_message = get_warning_message_for_lora_model_and_prompt_template(
279
+ lora_model, prompt_template)
280
  if warning_message:
281
  model_prompt_template_message_update = gr.Markdown.update(
282
  warning_message, visible=True)
 
296
 
297
  model_prompt_template_message_update = gr.Markdown.update(
298
  "", visible=False)
299
+ warning_message = get_warning_message_for_lora_model_and_prompt_template(
300
+ lora_model, prompt_template)
301
  if warning_message:
302
  model_prompt_template_message_update = gr.Markdown.update(
303
  warning_message, visible=True)
 
316
 
317
 
318
  def inference_ui():
319
+ flagging_dir = os.path.join(Global.data_dir, "flagging", "inference")
320
+ if not os.path.exists(flagging_dir):
321
+ os.makedirs(flagging_dir)
322
+
323
+ flag_callback = gr.CSVLogger()
324
+ flag_components = [
325
+ LoggingItem("Base Model"),
326
+ LoggingItem("Adaptor Model"),
327
+ LoggingItem("Type"),
328
+ LoggingItem("Prompt"),
329
+ LoggingItem("Output"),
330
+ LoggingItem("Completed"),
331
+ LoggingItem("Config"),
332
+ LoggingItem("Raw Output"),
333
+ LoggingItem("Max New Tokens"),
334
+ LoggingItem("Prompt Template"),
335
+ LoggingItem("Prompt Template Variables"),
336
+ LoggingItem("Generation Config"),
337
+ ]
338
+ flag_callback.setup(flag_components, flagging_dir)
339
+
340
+ def get_flag_callback_args(output_for_flagging_str, flag_type):
341
+ output_for_flagging = json.loads(output_for_flagging_str)
342
+ generation_config = output_for_flagging.get("generation_config", {})
343
+ config = []
344
+ if generation_config.get('do_sample', False):
345
+ config.append(
346
+ f"Temperature: {generation_config.get('temperature')}")
347
+ config.append(f"Top P: {generation_config.get('top_p')}")
348
+ config.append(f"Top K: {generation_config.get('top_k')}")
349
+ num_beams = generation_config.get('num_beams', 1)
350
+ if num_beams > 1:
351
+ config.append(f"Beams: {generation_config.get('num_beams')}")
352
+ config.append(f"RP: {generation_config.get('repetition_penalty')}")
353
+ return [
354
+ output_for_flagging.get("base_model", ""),
355
+ output_for_flagging.get("adaptor_model", ""),
356
+ flag_type,
357
+ output_for_flagging.get("prompt", ""),
358
+ output_for_flagging.get("output", ""),
359
+ str(output_for_flagging.get("completed", "")),
360
+ ", ".join(config),
361
+ output_for_flagging.get("raw_output", ""),
362
+ str(output_for_flagging.get("max_new_tokens", "")),
363
+ output_for_flagging.get("prompt_template", ""),
364
+ json.dumps(output_for_flagging.get(
365
+ "prompt_template_variables", "")),
366
+ json.dumps(output_for_flagging.get("generation_config", "")),
367
+ ]
368
+
369
  things_that_might_timeout = []
370
 
371
  with gr.Blocks() as inference_ui_blocks:
 
493
  inference_output = gr.Textbox(
494
  lines=inference_output_lines, label="Output", elem_id="inference_output")
495
  inference_output.style(show_copy_button=True)
496
+
497
+ with gr.Row(elem_id="inference_flagging_group"):
498
+ output_for_flagging = gr.Textbox(
499
+ interactive=False, visible=False,
500
+ elem_id="inference_output_for_flagging")
501
+ flag_btn = gr.Button(
502
+ "Flag", elem_id="inference_flag_btn")
503
+ flag_up_btn = gr.Button(
504
+ "πŸ‘", elem_id="inference_flag_up_btn")
505
+ flag_down_btn = gr.Button(
506
+ "πŸ‘Ž", elem_id="inference_flag_down_btn")
507
+ flag_output = gr.Markdown(
508
+ "", elem_id="inference_flag_output")
509
+ flag_btn.click(
510
+ lambda d: (flag_callback.flag(
511
+ get_flag_callback_args(d, "Flag"),
512
+ flag_option="Flag",
513
+ username=None
514
+ ), "")[1],
515
+ inputs=[output_for_flagging],
516
+ outputs=[flag_output],
517
+ preprocess=False)
518
+ flag_up_btn.click(
519
+ lambda d: (flag_callback.flag(
520
+ get_flag_callback_args(d, "πŸ‘"),
521
+ flag_option="Up Vote",
522
+ username=None
523
+ ), "")[1],
524
+ inputs=[output_for_flagging],
525
+ outputs=[flag_output],
526
+ preprocess=False)
527
+ flag_down_btn.click(
528
+ lambda d: (flag_callback.flag(
529
+ get_flag_callback_args(d, "πŸ‘Ž"),
530
+ flag_option="Down Vote",
531
+ username=None
532
+ ), "")[1],
533
+ inputs=[output_for_flagging],
534
+ outputs=[flag_output],
535
+ preprocess=False)
536
+
537
  with gr.Accordion(
538
  "Raw Output",
539
  open=not default_show_raw,
 
547
  interactive=False,
548
  elem_id="inference_raw_output")
549
 
550
+ reload_selected_models_btn = gr.Button(
551
+ "", elem_id="inference_reload_selected_models_btn")
552
 
553
  show_raw_change_event = show_raw.change(
554
  fn=lambda show_raw: gr.Accordion.update(visible=show_raw),
 
588
  generate_event = generate_btn.click(
589
  fn=prepare_inference,
590
  inputs=[lora_model],
591
+ outputs=[inference_output,
592
+ inference_raw_output, output_for_flagging],
593
  ).then(
594
  fn=do_inference,
595
  inputs=[
 
606
  stream_output,
607
  show_raw,
608
  ],
609
+ outputs=[inference_output,
610
+ inference_raw_output, output_for_flagging],
611
  api_name="inference"
612
  )
613
  stop_btn.click(
llama_lora/ui/main_page.py CHANGED
@@ -398,6 +398,45 @@ def main_page_custom_css():
398
  bottom: 16px;
399
  }
400
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401
  #dataset_plain_text_input_variables_separator textarea,
402
  #dataset_plain_text_input_and_output_separator textarea,
403
  #dataset_plain_text_data_separator textarea {
 
398
  bottom: 16px;
399
  }
400
 
401
+ #inference_flagging_group {
402
+ position: relative;
403
+ }
404
+ #inference_flag_output {
405
+ min-height: 1px !important;
406
+ position: absolute;
407
+ top: 0;
408
+ bottom: 0;
409
+ right: 0;
410
+ pointer-events: none;
411
+ opacity: 0.5;
412
+ }
413
+ #inference_flag_output .wrap {
414
+ top: 0;
415
+ bottom: 0;
416
+ right: 0;
417
+ justify-content: center;
418
+ align-items: flex-end;
419
+ padding: 4px !important;
420
+ }
421
+ #inference_flag_output .wrap svg {
422
+ display: none;
423
+ }
424
+ .form:has(> #inference_output_for_flagging),
425
+ #inference_output_for_flagging {
426
+ display: none;
427
+ }
428
+ #inference_flagging_group:has(#inference_output_for_flagging.hidden) {
429
+ opacity: 0.5;
430
+ pointer-events: none;
431
+ }
432
+ #inference_flag_up_btn, #inference_flag_down_btn {
433
+ min-width: 44px;
434
+ flex-grow: 1;
435
+ }
436
+ #inference_flag_btn {
437
+ flex-grow: 2;
438
+ }
439
+
440
  #dataset_plain_text_input_variables_separator textarea,
441
  #dataset_plain_text_input_and_output_separator textarea,
442
  #dataset_plain_text_data_separator textarea {