DongfuJiang commited on
Commit
a83a6e5
1 Parent(s): bf79ee8
Files changed (2) hide show
  1. app.py +27 -17
  2. requirements.txt +1 -2
app.py CHANGED
@@ -17,17 +17,23 @@ CANDIDATE_MAX_LENGTH = 256
17
  DEFAULT_CANDIDATE_MAX_LENGTH = 128
18
  FUSER_MAX_NEW_TOKENS = 512
19
  DEFAULT_FUSER_MAX_NEW_TOKENS = 256
20
- EXAMPLES_DATASET = load_dataset("llm-blender/mix-instruct", split='validation', streaming=True)
21
- SHUFFLED_EXAMPLES_DATASET = EXAMPLES_DATASET.shuffle(seed=42, buffer_size=1000)
22
- EXAMPLES = []
23
- CANDIDATE_EXAMPLES = {}
24
- for example in SHUFFLED_EXAMPLES_DATASET.take(100):
25
- EXAMPLES.append([
 
 
 
26
  example['instruction'],
27
  example['input'],
28
  ])
29
- CANDIDATE_EXAMPLES[example['instruction']+example['input']] = example['candidates']
 
 
30
 
 
31
  HHH_EXAMPLES = []
32
  subsets = ['harmless', 'helpful', 'honest', 'other']
33
  random.seed(42)
@@ -53,6 +59,7 @@ for subset in subsets:
53
  def get_hhh_examples(subset, instruction, response1, response2, dummy_text):
54
  return instruction, response1, response2
55
 
 
56
  MT_BENCH_HUMAN_JUDGE_EXAMPLES = []
57
  dataset = load_dataset("lmsys/mt_bench_human_judgments")
58
  for example in dataset['human']:
@@ -101,15 +108,17 @@ def save_llm_output(selected_base_llm_name, selected_base_llm_output, llm_output
101
 
102
  def get_preprocess_examples(inst, input):
103
  # get the num_of_base_llms
104
- candidates = CANDIDATE_EXAMPLES[inst+input]
105
  num_candiates = len(candidates)
106
  dummy_text = inst+input
107
  return inst, input, num_candiates, dummy_text
108
 
109
- def update_base_llm_dropdown_along_examples(dummy_text):
110
- candidates = CANDIDATE_EXAMPLES[dummy_text]
111
  ex_llm_outputs = {f"LLM-{i+1}": candidates[i]['text'] for i in range(len(candidates))}
112
- return ex_llm_outputs, "", ""
 
 
113
 
114
  def check_save_ranker_inputs(inst, input, llm_outputs, blender_config):
115
  if not inst and not input:
@@ -125,10 +134,11 @@ def check_save_ranker_inputs(inst, input, llm_outputs, blender_config):
125
  }
126
 
127
  def check_fuser_inputs(blender_state, blender_config, ranks):
128
- if not (blender_state.get("inst", None) or blender_state.get("input", None)):
129
- raise gr.Error("Please enter instruction or input context")
130
  if "candidates" not in blender_state or len(ranks)==0:
131
  raise gr.Error("Please rank LLM outputs first")
 
 
 
132
  return
133
 
134
  def llms_rank(inst, input, llm_outputs, blender_config):
@@ -259,7 +269,7 @@ with gr.Blocks(theme='ParityError/Anime') as demo:
259
 
260
  examples_dummy_textbox = gr.Textbox(lines=1, label="", placeholder="", show_label=False, visible=False)
261
  batch_examples = gr.Examples(
262
- examples=EXAMPLES,
263
  fn=get_preprocess_examples,
264
  cache_examples=True,
265
  examples_per_page=5,
@@ -267,7 +277,7 @@ with gr.Blocks(theme='ParityError/Anime') as demo:
267
  outputs=[inst_textbox, input_textbox, base_llms_num, examples_dummy_textbox],
268
  )
269
 
270
- base_llms_num.change(
271
  fn=update_base_llms_num,
272
  inputs=[base_llms_num, saved_llm_outputs],
273
  outputs=[selected_base_llm_name_dropdown, saved_llm_outputs],
@@ -275,8 +285,8 @@ with gr.Blocks(theme='ParityError/Anime') as demo:
275
 
276
  examples_dummy_textbox.change(
277
  fn=update_base_llm_dropdown_along_examples,
278
- inputs=[examples_dummy_textbox],
279
- outputs=[saved_llm_outputs, rank_outputs, fuser_outputs],
280
  ).then(
281
  fn=display_llm_output,
282
  inputs=[saved_llm_outputs, selected_base_llm_name_dropdown],
 
17
  DEFAULT_CANDIDATE_MAX_LENGTH = 128
18
  FUSER_MAX_NEW_TOKENS = 512
19
  DEFAULT_FUSER_MAX_NEW_TOKENS = 256
20
+
21
+
22
+ # MIX-INSTRUCT
23
+ EXAMPLES_DATASET = load_dataset("llm-blender/mix-instruct", split='validation')
24
+ SHUFFLED_EXAMPLES_DATASET = EXAMPLES_DATASET.shuffle(seed=42)
25
+ MIX_INSTRUCT_EXAMPLES = []
26
+ CANDIDATE_MAP = {}
27
+ for i, example in enumerate(SHUFFLED_EXAMPLES_DATASET):
28
+ MIX_INSTRUCT_EXAMPLES.append([
29
  example['instruction'],
30
  example['input'],
31
  ])
32
+ CANDIDATE_MAP[example['instruction']+example['input']] = example['candidates']
33
+ if i > 100:
34
+ break
35
 
36
+ # HHH ALIGNMENT
37
  HHH_EXAMPLES = []
38
  subsets = ['harmless', 'helpful', 'honest', 'other']
39
  random.seed(42)
 
59
  def get_hhh_examples(subset, instruction, response1, response2, dummy_text):
60
  return instruction, response1, response2
61
 
62
+ # MT_BENCH_HUMAN_JUDGMENTS
63
  MT_BENCH_HUMAN_JUDGE_EXAMPLES = []
64
  dataset = load_dataset("lmsys/mt_bench_human_judgments")
65
  for example in dataset['human']:
 
108
 
109
  def get_preprocess_examples(inst, input):
110
  # get the num_of_base_llms
111
+ candidates = CANDIDATE_MAP[inst+input]
112
  num_candiates = len(candidates)
113
  dummy_text = inst+input
114
  return inst, input, num_candiates, dummy_text
115
 
116
+ def update_base_llm_dropdown_along_examples(inst, input):
117
+ candidates = CANDIDATE_MAP[inst+input]
118
  ex_llm_outputs = {f"LLM-{i+1}": candidates[i]['text'] for i in range(len(candidates))}
119
+ k = len(candidates)
120
+ return ex_llm_outputs, "", "", \
121
+ gr.Dropdown(choices=[f"LLM-{i+1}" for i in range(k)], value=f"LLM-1" if k >= 1 else "", visible=True)
122
 
123
  def check_save_ranker_inputs(inst, input, llm_outputs, blender_config):
124
  if not inst and not input:
 
134
  }
135
 
136
  def check_fuser_inputs(blender_state, blender_config, ranks):
 
 
137
  if "candidates" not in blender_state or len(ranks)==0:
138
  raise gr.Error("Please rank LLM outputs first")
139
+ if not (blender_state.get("inst", None) or blender_state.get("input", None)):
140
+ raise gr.Error("Please enter instruction or input context")
141
+
142
  return
143
 
144
  def llms_rank(inst, input, llm_outputs, blender_config):
 
269
 
270
  examples_dummy_textbox = gr.Textbox(lines=1, label="", placeholder="", show_label=False, visible=False)
271
  batch_examples = gr.Examples(
272
+ examples=MIX_INSTRUCT_EXAMPLES,
273
  fn=get_preprocess_examples,
274
  cache_examples=True,
275
  examples_per_page=5,
 
277
  outputs=[inst_textbox, input_textbox, base_llms_num, examples_dummy_textbox],
278
  )
279
 
280
+ base_llms_num.input(
281
  fn=update_base_llms_num,
282
  inputs=[base_llms_num, saved_llm_outputs],
283
  outputs=[selected_base_llm_name_dropdown, saved_llm_outputs],
 
285
 
286
  examples_dummy_textbox.change(
287
  fn=update_base_llm_dropdown_along_examples,
288
+ inputs=[inst_textbox, input_textbox],
289
+ outputs=[saved_llm_outputs, rank_outputs, fuser_outputs, selected_base_llm_name_dropdown],
290
  ).then(
291
  fn=display_llm_output,
292
  inputs=[saved_llm_outputs, selected_base_llm_name_dropdown],
requirements.txt CHANGED
@@ -1,2 +1 @@
1
- llm_blender @ git+https://github.com/yuchenlin/LLM-Blender.git@main
2
- gdown
 
1
+ llm_blender @ git+https://github.com/yuchenlin/LLM-Blender.git@main