DongfuJiang commited on
Commit
62174a3
·
1 Parent(s): 7d90e18
Files changed (1) hide show
  1. app.py +123 -46
app.py CHANGED
@@ -6,7 +6,14 @@ from typing import List
6
 
7
  MAX_BASE_LLM_NUM = 20
8
  MIN_BASE_LLM_NUM = 3
9
- DESCRIPTIONS = """\
 
 
 
 
 
 
 
10
  LLM-Blender is an innovative ensembling framework to attain consistently superior performance by leveraging the diverse strengths of multiple open-source large language models (LLMs). LLM-Blender cut the weaknesses through ranking and integrate the strengths through fusing generation to enhance the capability of LLMs.
11
  """
12
  EXAMPLES_DATASET = load_dataset("llm-blender/mix-instruct", split='validation', streaming=True)
@@ -21,7 +28,6 @@ for example in SHUFFLED_EXAMPLES_DATASET.take(100):
21
  CANDIDATE_EXAMPLES[example['instruction']+example['input']] = example['candidates']
22
 
23
  # Download ranker checkpoint
24
- os.system("ls -l /home/user/.local/lib/python3.10/site-packages/llm_blender")
25
  if not os.path.exists("pairranker-deberta-v3-large.zip"):
26
  os.system("gdown https://drive.google.com/uc?id=1EpvFu_qYY0MaIu0BAAhK-sYKHVWtccWg")
27
  if not os.path.exists("pairranker-deberta-v3-large"):
@@ -35,13 +41,13 @@ ranker_config.ranker_type = "pairranker"
35
  ranker_config.model_type = "deberta"
36
  ranker_config.model_name = "microsoft/deberta-v3-large" # ranker backbone
37
  ranker_config.load_checkpoint = "./pairranker-deberta-v3-large" # ranker checkpoint <your checkpoint path>
38
- ranker_config.source_maxlength = 128
39
- ranker_config.candidate_maxlength = 128
40
  ranker_config.n_tasks = 1 # number of singal that has been used to train the ranker. This checkpoint is trained using BARTScore only, thus being 1.
41
  fuser_config = llm_blender.GenFuserConfig()
42
  fuser_config.model_name = "llm-blender/gen_fuser_3b" # our pre-trained fuser
43
  fuser_config.max_length = 1024
44
- fuser_config.candidate_maxlength = 128
45
  blender_config = llm_blender.BlenderConfig()
46
  blender_config.device = "cpu" # blender ranker and fuser device
47
  blender = llm_blender.Blender(blender_config, ranker_config, fuser_config)
@@ -74,7 +80,7 @@ def update_base_llm_dropdown_along_examples(dummy_text):
74
  ex_llm_outputs = {f"LLM-{i+1}": candidates[i]['text'] for i in range(len(candidates))}
75
  return ex_llm_outputs
76
 
77
- def check_save_ranker_inputs(inst, input, llm_outputs):
78
  if not inst and not input:
79
  raise gr.Error("Please enter instruction or input context")
80
 
@@ -87,23 +93,29 @@ def check_save_ranker_inputs(inst, input, llm_outputs):
87
  "candidates": list(llm_outputs.values()),
88
  }
89
 
90
- def check_fuser_inputs(blender_state, top_k_for_fuser, ranks):
91
  pass
92
 
93
- def llms_rank(inst, input, llm_outputs):
94
  candidates = list(llm_outputs.values())
95
-
96
- return blender.rank(instructions=[inst], inputs=[input], candidates=[candidates])[0]
 
 
 
 
97
 
98
- def display_ranks(ranks):
99
- return ", ".join([f"LLM-{i+1}: {rank}" for i, rank in enumerate(ranks)])
100
 
101
- def llms_fuse(blender_state, top_k_for_fuser, ranks):
102
  inst = blender_state['inst']
103
  input = blender_state['input']
104
  candidates = blender_state['candidates']
 
 
 
105
  top_k_candidates = get_topk_candidates_from_ranks([ranks], [candidates], top_k=top_k_for_fuser)[0]
106
- return blender.fuse(instructions=[inst], inputs=[input], candidates=[top_k_candidates])[0]
 
107
 
108
  def display_fuser_output(fuser_output):
109
  return fuser_output
@@ -111,16 +123,18 @@ def display_fuser_output(fuser_output):
111
 
112
  with gr.Blocks(theme='ParityError/Anime') as demo:
113
  gr.Markdown(DESCRIPTIONS)
 
114
  with gr.Row():
115
  with gr.Column():
116
  inst_textbox = gr.Textbox(lines=1, label="Instruction", placeholder="Enter instruction here", show_label=True)
117
  input_textbox = gr.Textbox(lines=4, label="Input Context", placeholder="Enter input context here", show_label=True)
118
  with gr.Column():
119
  saved_llm_outputs = gr.State(value={})
120
- selected_base_llm_name_dropdown = gr.Dropdown(label="Base LLM",
121
- choices=[f"LLM-{i+1}" for i in range(MIN_BASE_LLM_NUM)], value="LLM-1", show_label=True)
122
- selected_base_llm_output = gr.Textbox(lines=4, label="LLM-1 (Click Save to save current content)",
123
- placeholder="Enter LLM-1 output here", show_label=True)
 
124
  with gr.Row():
125
  base_llm_outputs_save_button = gr.Button('Save', variant='primary')
126
 
@@ -136,28 +150,67 @@ with gr.Blocks(theme='ParityError/Anime') as demo:
136
  )
137
 
138
  blender_state = gr.State(value={})
139
- with gr.Tab("Ranking outputs"):
140
- saved_rank_outputs = gr.State(value=[])
141
- rank_outputs = gr.Textbox(lines=4, label="Ranking outputs", placeholder="Ranking outputs", show_label=True)
142
- with gr.Tab("Fusing outputs"):
143
- saved_fuse_outputs = gr.State(value=[])
144
  fuser_outputs = gr.Textbox(lines=4, label="Fusing outputs", placeholder="Fusing outputs", show_label=True)
145
  with gr.Row():
146
- rank_button = gr.Button('Rank LLM Outputs', variant='primary',
147
- scale=1, min_width=0)
148
- fuse_button = gr.Button('Fuse Top-K ranked outputs', variant='primary',
149
- scale=1, min_width=0)
150
- clear_button = gr.Button('Clear Blender', variant='primary',
151
- scale=1, min_width=0)
 
 
 
 
 
152
 
153
  with gr.Accordion(label='Advanced options', open=False):
154
-
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  top_k_for_fuser = gr.Slider(
156
- label='Top k for fuser',
157
  minimum=1,
158
  maximum=3,
159
  step=1,
160
- value=1,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  )
162
 
163
  examples_dummy_textbox = gr.Textbox(lines=1, label="", placeholder="", show_label=False, visible=False)
@@ -211,30 +264,22 @@ with gr.Blocks(theme='ParityError/Anime') as demo:
211
 
212
  rank_button.click(
213
  fn=check_save_ranker_inputs,
214
- inputs=[inst_textbox, input_textbox, saved_llm_outputs],
215
  outputs=blender_state,
216
  ).success(
217
  fn=llms_rank,
218
- inputs=[inst_textbox, input_textbox, saved_llm_outputs],
219
- outputs=[saved_rank_outputs],
220
- ).then(
221
- fn=display_ranks,
222
- inputs=[saved_rank_outputs],
223
- outputs=rank_outputs,
224
  )
225
 
226
  fuse_button.click(
227
  fn=check_fuser_inputs,
228
- inputs=[blender_state, top_k_for_fuser, saved_rank_outputs],
229
  outputs=[],
230
  ).success(
231
  fn=llms_fuse,
232
- inputs=[blender_state, top_k_for_fuser, saved_rank_outputs],
233
- outputs=[saved_fuse_outputs],
234
- ).then(
235
- fn=display_fuser_output,
236
- inputs=[saved_fuse_outputs],
237
- outputs=fuser_outputs,
238
  )
239
 
240
  clear_button.click(
@@ -243,6 +288,38 @@ with gr.Blocks(theme='ParityError/Anime') as demo:
243
  outputs=[rank_outputs, fuser_outputs, blender_state, saved_rank_outputs],
244
  )
245
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
 
248
 
 
6
 
7
  MAX_BASE_LLM_NUM = 20
8
  MIN_BASE_LLM_NUM = 3
9
+ SOURCE_MAX_LENGTH = 256
10
+ DEFAULT_SOURCE_MAX_LENGTH = 128
11
+ CANDIDATE_MAX_LENGTH = 256
12
+ DEFAULT_CANDIDATE_MAX_LENGTH = 128
13
+ FUSER_MAX_NEW_TOKENS = 512
14
+ DEFAULT_FUSER_MAX_NEW_TOKENS = 256
15
+ DESCRIPTIONS = """# LLM-BLENDER
16
+
17
  LLM-Blender is an innovative ensembling framework to attain consistently superior performance by leveraging the diverse strengths of multiple open-source large language models (LLMs). LLM-Blender cut the weaknesses through ranking and integrate the strengths through fusing generation to enhance the capability of LLMs.
18
  """
19
  EXAMPLES_DATASET = load_dataset("llm-blender/mix-instruct", split='validation', streaming=True)
 
28
  CANDIDATE_EXAMPLES[example['instruction']+example['input']] = example['candidates']
29
 
30
  # Download ranker checkpoint
 
31
  if not os.path.exists("pairranker-deberta-v3-large.zip"):
32
  os.system("gdown https://drive.google.com/uc?id=1EpvFu_qYY0MaIu0BAAhK-sYKHVWtccWg")
33
  if not os.path.exists("pairranker-deberta-v3-large"):
 
41
  ranker_config.model_type = "deberta"
42
  ranker_config.model_name = "microsoft/deberta-v3-large" # ranker backbone
43
  ranker_config.load_checkpoint = "./pairranker-deberta-v3-large" # ranker checkpoint <your checkpoint path>
44
+ ranker_config.source_maxlength = DEFAULT_SOURCE_MAX_LENGTH
45
+ ranker_config.candidate_maxlength = DEFAULT_CANDIDATE_MAX_LENGTH
46
  ranker_config.n_tasks = 1 # number of singal that has been used to train the ranker. This checkpoint is trained using BARTScore only, thus being 1.
47
  fuser_config = llm_blender.GenFuserConfig()
48
  fuser_config.model_name = "llm-blender/gen_fuser_3b" # our pre-trained fuser
49
  fuser_config.max_length = 1024
50
+ fuser_config.candidate_maxlength = DEFAULT_CANDIDATE_MAX_LENGTH
51
  blender_config = llm_blender.BlenderConfig()
52
  blender_config.device = "cpu" # blender ranker and fuser device
53
  blender = llm_blender.Blender(blender_config, ranker_config, fuser_config)
 
80
  ex_llm_outputs = {f"LLM-{i+1}": candidates[i]['text'] for i in range(len(candidates))}
81
  return ex_llm_outputs
82
 
83
+ def check_save_ranker_inputs(inst, input, llm_outputs, blender_config):
84
  if not inst and not input:
85
  raise gr.Error("Please enter instruction or input context")
86
 
 
93
  "candidates": list(llm_outputs.values()),
94
  }
95
 
96
+ def check_fuser_inputs(blender_state, blender_config, ranks):
97
  pass
98
 
99
+ def llms_rank(inst, input, llm_outputs, blender_config):
100
  candidates = list(llm_outputs.values())
101
+ rank_params = {
102
+ "source_max_length": blender_config['source_max_length'],
103
+ "candidate_max_length": blender_config['candidate_max_length'],
104
+ }
105
+ ranks = blender.rank(instructions=[inst], inputs=[input], candidates=[candidates])[0]
106
+ return [ranks, ", ".join([f"LLM-{i+1}: {rank}" for i, rank in enumerate(ranks)])]
107
 
 
 
108
 
109
+ def llms_fuse(blender_state, blender_config, ranks):
110
  inst = blender_state['inst']
111
  input = blender_state['input']
112
  candidates = blender_state['candidates']
113
+ top_k_for_fuser = blender_config['top_k_for_fuser']
114
+ fuse_params = blender_config.copy()
115
+ del fuse_params["top_k_for_fuser"]
116
  top_k_candidates = get_topk_candidates_from_ranks([ranks], [candidates], top_k=top_k_for_fuser)[0]
117
+ fuser_outputs = blender.fuse(instructions=[inst], inputs=[input], candidates=[top_k_candidates], **fuse_params)[0]
118
+ return [fuser_outputs, fuser_outputs]
119
 
120
  def display_fuser_output(fuser_output):
121
  return fuser_output
 
123
 
124
  with gr.Blocks(theme='ParityError/Anime') as demo:
125
  gr.Markdown(DESCRIPTIONS)
126
+ gr.Markdown("## Input and Base LLMs")
127
  with gr.Row():
128
  with gr.Column():
129
  inst_textbox = gr.Textbox(lines=1, label="Instruction", placeholder="Enter instruction here", show_label=True)
130
  input_textbox = gr.Textbox(lines=4, label="Input Context", placeholder="Enter input context here", show_label=True)
131
  with gr.Column():
132
  saved_llm_outputs = gr.State(value={})
133
+ with gr.Group():
134
+ selected_base_llm_name_dropdown = gr.Dropdown(label="Base LLM",
135
+ choices=[f"LLM-{i+1}" for i in range(MIN_BASE_LLM_NUM)], value="LLM-1", show_label=True)
136
+ selected_base_llm_output = gr.Textbox(lines=4, label="LLM-1 (Click Save to save current content)",
137
+ placeholder="Enter LLM-1 output here", show_label=True)
138
  with gr.Row():
139
  base_llm_outputs_save_button = gr.Button('Save', variant='primary')
140
 
 
150
  )
151
 
152
  blender_state = gr.State(value={})
153
+ saved_rank_outputs = gr.State(value=[])
154
+ saved_fuse_outputs = gr.State(value=[])
155
+ gr.Markdown("## Blender Outputs")
156
+ with gr.Group():
157
+ rank_outputs = gr.Textbox(lines=1, label="Ranking outputs", placeholder="Ranking outputs", show_label=True)
158
  fuser_outputs = gr.Textbox(lines=4, label="Fusing outputs", placeholder="Fusing outputs", show_label=True)
159
  with gr.Row():
160
+ rank_button = gr.Button('Rank LLM Outputs', variant='primary')
161
+ fuse_button = gr.Button('Fuse Top-K ranked outputs', variant='primary')
162
+ clear_button = gr.Button('Clear Blender Outputs', variant='primary')
163
+ blender_config = gr.State(value={
164
+ "source_max_length": DEFAULT_SOURCE_MAX_LENGTH,
165
+ "candidate_max_length": DEFAULT_CANDIDATE_MAX_LENGTH,
166
+ "top_k_for_fuser": 3,
167
+ "max_new_tokens": DEFAULT_FUSER_MAX_NEW_TOKENS,
168
+ "temperature": 0.7,
169
+ "top_p": 1.0,
170
+ })
171
 
172
  with gr.Accordion(label='Advanced options', open=False):
173
+ source_max_length = gr.Slider(
174
+ label='Max length of Instruction + Input',
175
+ minimum=1,
176
+ maximum=SOURCE_MAX_LENGTH,
177
+ step=1,
178
+ value=DEFAULT_SOURCE_MAX_LENGTH,
179
+ )
180
+ candidate_max_length = gr.Slider(
181
+ label='Max length of LLM-Output Candidate',
182
+ minimum=1,
183
+ maximum=CANDIDATE_MAX_LENGTH,
184
+ step=1,
185
+ value=DEFAULT_CANDIDATE_MAX_LENGTH,
186
+ )
187
  top_k_for_fuser = gr.Slider(
188
+ label='Top-k ranked candidates to fuse',
189
  minimum=1,
190
  maximum=3,
191
  step=1,
192
+ value=3,
193
+ )
194
+ max_new_tokens = gr.Slider(
195
+ label='Max new tokens fuser can generate',
196
+ minimum=1,
197
+ maximum=FUSER_MAX_NEW_TOKENS,
198
+ step=1,
199
+ value=DEFAULT_FUSER_MAX_NEW_TOKENS,
200
+ )
201
+ temperature = gr.Slider(
202
+ label='Temperature of fuser generation',
203
+ minimum=0.1,
204
+ maximum=2.0,
205
+ step=0.1,
206
+ value=0.7,
207
+ )
208
+ top_p = gr.Slider(
209
+ label='Top-p of fuser generation',
210
+ minimum=0.05,
211
+ maximum=1.0,
212
+ step=0.05,
213
+ value=1.0,
214
  )
215
 
216
  examples_dummy_textbox = gr.Textbox(lines=1, label="", placeholder="", show_label=False, visible=False)
 
264
 
265
  rank_button.click(
266
  fn=check_save_ranker_inputs,
267
+ inputs=[inst_textbox, input_textbox, saved_llm_outputs, blender_config],
268
  outputs=blender_state,
269
  ).success(
270
  fn=llms_rank,
271
+ inputs=[inst_textbox, input_textbox, saved_llm_outputs, blender_config],
272
+ outputs=[saved_rank_outputs, rank_outputs],
 
 
 
 
273
  )
274
 
275
  fuse_button.click(
276
  fn=check_fuser_inputs,
277
+ inputs=[blender_state, blender_config, saved_rank_outputs],
278
  outputs=[],
279
  ).success(
280
  fn=llms_fuse,
281
+ inputs=[blender_state, blender_config, saved_rank_outputs],
282
+ outputs=[saved_fuse_outputs, fuser_outputs],
 
 
 
 
283
  )
284
 
285
  clear_button.click(
 
288
  outputs=[rank_outputs, fuser_outputs, blender_state, saved_rank_outputs],
289
  )
290
 
291
+ # update blender config
292
+ source_max_length.change(
293
+ fn=lambda x, y: y.update({"source_max_length": x}) or y,
294
+ inputs=[source_max_length, blender_config],
295
+ outputs=blender_config,
296
+ )
297
+ candidate_max_length.change(
298
+ fn=lambda x, y: y.update({"candidate_max_length": x}) or y,
299
+ inputs=[candidate_max_length, blender_config],
300
+ outputs=blender_config,
301
+ )
302
+ top_k_for_fuser.change(
303
+ fn=lambda x, y: y.update({"top_k_for_fuser": x}) or y,
304
+ inputs=[top_k_for_fuser, blender_config],
305
+ outputs=blender_config,
306
+ )
307
+ max_new_tokens.change(
308
+ fn=lambda x, y: y.update({"max_new_tokens": x}) or y,
309
+ inputs=[max_new_tokens, blender_config],
310
+ outputs=blender_config,
311
+ )
312
+ temperature.change(
313
+ fn=lambda x, y: y.update({"temperature": x}) or y,
314
+ inputs=[temperature, blender_config],
315
+ outputs=blender_config,
316
+ )
317
+ top_p.change(
318
+ fn=lambda x, y: y.update({"top_p": x}) or y,
319
+ inputs=[top_p, blender_config],
320
+ outputs=blender_config,
321
+ )
322
+
323
 
324
 
325