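# LLM-Blender / app.py
# (Hugging Face Space file-page header captured with the source: author DongfuJiang,
#  commit bf79ee8 "update", 23.7 kB; page links "raw", "history", "blame"; badge "No virus".)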
import ast
import os
import random
import sys
from typing import List

import gradio as gr
from datasets import load_dataset

import llm_blender
from llm_blender.blender.blender_utils import get_topk_candidates_from_ranks

import descriptions
# Bounds for the "Number of base llms" slider.
MIN_BASE_LLM_NUM = 3
MAX_BASE_LLM_NUM = 20

# Length / generation limits: each pair is (slider maximum, slider initial value).
SOURCE_MAX_LENGTH, DEFAULT_SOURCE_MAX_LENGTH = 256, 128
CANDIDATE_MAX_LENGTH, DEFAULT_CANDIDATE_MAX_LENGTH = 256, 128
FUSER_MAX_NEW_TOKENS, DEFAULT_FUSER_MAX_NEW_TOKENS = 512, 256
# Example prompts for the LLM-Blender tab: the first 100 rows of a seeded shuffle of
# the mix-instruct validation split (loaded in streaming mode).
EXAMPLES_DATASET = load_dataset("llm-blender/mix-instruct", split='validation', streaming=True)
SHUFFLED_EXAMPLES_DATASET = EXAMPLES_DATASET.shuffle(seed=42, buffer_size=1000)
EXAMPLES = []
# Keyed by instruction + input (plain string concatenation); value is the row's candidates.
CANDIDATE_EXAMPLES = {}
for example in SHUFFLED_EXAMPLES_DATASET.take(100):
    example_key = example['instruction'] + example['input']
    EXAMPLES.append([example['instruction'], example['input']])
    CANDIDATE_EXAMPLES[example_key] = example['candidates']
# Pairwise examples for the PairRM tab from every hhh_alignment subset (test split).
# A seeded coin flip decides whether the two choices are displayed in dataset order or
# swapped; the label names whichever displayed response carries labels[0] == 1.
HHH_EXAMPLES = []
subsets = ['harmless', 'helpful', 'honest', 'other']
random.seed(42)
for subset in subsets:
    dataset = load_dataset("HuggingFaceH4/hhh_alignment", subset)
    for example in dataset['test']:
        choices = example['targets']['choices']
        first_choice_preferred = example['targets']['labels'][0] == 1
        # Exactly one random draw per example, same as before.
        swap_order = random.random() >= 0.5
        if swap_order:
            shown_first, shown_second = choices[1], choices[0]
        else:
            shown_first, shown_second = choices[0], choices[1]
        # The preferred label follows the first choice to wherever it is displayed.
        better = "Response 1" if first_choice_preferred != swap_order else "Response 2"
        HHH_EXAMPLES.append([subset, example['input'], shown_first, shown_second, better])
def get_hhh_examples(subset, instruction, response1, response2, dummy_text):
    """Forward an hhh_alignment example row into the comparison inputs.

    `subset` and `dummy_text` (the "better response" label) are accepted but unused.

    Returns:
        (instruction, response1, response2)
    """
    selected = (instruction, response1, response2)
    return selected
# First-turn MT-Bench human judgments, used as examples for the multi-turn comparison
# tab. Conversations are stored via str() so they fit in hidden Textboxes;
# get_mt_bench_human_judge_examples parses them back.
MT_BENCH_HUMAN_JUDGE_EXAMPLES = []
# Display labels for the `winner` field. Any other value (NOTE(review): presumably tie
# verdicts — confirm against the dataset card) is shown verbatim instead of being
# misreported as "Model B".
MT_BENCH_WINNER_LABELS = {"model_a": "Model A", "model_b": "Model B"}
dataset = load_dataset("lmsys/mt_bench_human_judgments")
for example in dataset['human']:
    if example['turn'] != 1:
        continue
    MT_BENCH_HUMAN_JUDGE_EXAMPLES.append([
        example['model_a'],
        example['model_b'],
        str(example['conversation_a']),
        str(example['conversation_b']),
        MT_BENCH_WINNER_LABELS.get(example['winner'], str(example['winner'])),
    ])
def _conversation_str_to_chat_history(conversation):
    """Parse a str()-serialized list of {'role', 'content'} dicts into (user, assistant) pairs.

    Raises:
        ValueError: if the string is not a Python literal, has an odd number of
            messages, or its turns do not alternate 'user' / 'assistant'.
    """
    # These strings travel through Gradio components, so treat them as untrusted:
    # ast.literal_eval only accepts literals, whereas eval() would execute any code.
    messages = ast.literal_eval(conversation)
    if len(messages) % 2 != 0:
        raise ValueError("Conversation must consist of complete user/assistant turn pairs")
    chat_history = []
    for user_msg, assistant_msg in zip(messages[0::2], messages[1::2]):
        if user_msg['role'] != 'user' or assistant_msg['role'] != 'assistant':
            raise ValueError("Conversation turns must alternate between 'user' and 'assistant'")
        chat_history.append((user_msg['content'], assistant_msg['content']))
    return chat_history


def get_mt_bench_human_judge_examples(model_a, model_b, conversation_a, conversation_b, dummy_text):
    """Convert an MT-Bench example row into two Chatbot histories.

    Args:
        model_a, model_b, dummy_text: shown in the examples table only; unused here.
        conversation_a, conversation_b: str() of a list of {'role', 'content'} dicts.

    Returns:
        (chat_history_a, chat_history_b): lists of (user, assistant) message tuples.

    Raises:
        ValueError: if either conversation string is malformed (see helper).
    """
    chat_history_a = _conversation_str_to_chat_history(conversation_a)
    chat_history_b = _conversation_str_to_chat_history(conversation_b)
    return chat_history_a, chat_history_b
# Single shared Blender instance used by every callback below, created at import time.
# The ranker (PairRM) must be loaded before ranking/compare calls and the fuser
# (gen_fuser_3b) before fuse calls.
blender = llm_blender.Blender()
blender.loadranker("llm-blender/PairRM")
blender.loadfuser("llm-blender/gen_fuser_3b")
def update_base_llms_num(k, llm_outputs):
    """Resize the base-LLM selector and the saved outputs to `k` entries.

    Args:
        k: requested number of base LLMs (slider value; converted with int()).
        llm_outputs: mapping "LLM-<i>" -> saved output text.

    Returns:
        [dropdown, outputs]: a Dropdown listing LLM-1..LLM-k (LLM-1 selected), and a
        mapping with exactly those keys, keeping saved text and filling gaps with "".
    """
    num_llms = int(k)
    llm_names = [f"LLM-{idx}" for idx in range(1, num_llms + 1)]
    default_choice = "LLM-1" if num_llms >= 1 else ""
    dropdown = gr.Dropdown(choices=llm_names, value=default_choice, visible=True)
    resized_outputs = {name: llm_outputs.get(name, "") for name in llm_names}
    return [dropdown, resized_outputs]
def display_llm_output(llm_outputs, selected_base_llm_name):
    """Build the editor Textbox for the selected base LLM, prefilled with its saved text."""
    current_text = llm_outputs.get(selected_base_llm_name, "")
    textbox_label = selected_base_llm_name + " (Click Save to save current content)"
    return gr.Textbox(
        value=current_text,
        label=textbox_label,
        placeholder=f"Enter {selected_base_llm_name} output here",
        show_label=True,
    )
def save_llm_output(selected_base_llm_name, selected_base_llm_output, llm_outputs):
    """Store the editor text for the selected base LLM.

    Args:
        selected_base_llm_name: key such as "LLM-1".
        selected_base_llm_output: text to save for that LLM.
        llm_outputs: mapping "LLM-<i>" -> saved output text (updated in place);
            None is treated as an empty mapping.

    Returns:
        The updated mapping, so Gradio writes it back to the State component.
    """
    if llm_outputs is None:
        llm_outputs = {}
    # Assign into the dict; calling the dict itself (`llm_outputs({...})`) raises TypeError.
    llm_outputs[selected_base_llm_name] = selected_base_llm_output
    return llm_outputs
def get_preprocess_examples(inst, input):
    """Prepare a clicked mix-instruct example for the UI.

    Returns:
        (inst, input, number of stored candidates, lookup key) — the key
        (inst + input) is written to a hidden textbox whose change event loads
        the candidates.
    """
    example_key = inst + input
    num_candidates = len(CANDIDATE_EXAMPLES[example_key])
    return inst, input, num_candidates, example_key
def update_base_llm_dropdown_along_examples(dummy_text):
    """Load an example's candidate texts as saved base-LLM outputs.

    Args:
        dummy_text: example key (instruction + input) into CANDIDATE_EXAMPLES.

    Returns:
        (mapping "LLM-<i>" -> candidate text, "", "") — the two empty strings
        clear the ranking and fusing output boxes.
    """
    ex_llm_outputs = {}
    for idx, candidate in enumerate(CANDIDATE_EXAMPLES[dummy_text], start=1):
        ex_llm_outputs[f"LLM-{idx}"] = candidate['text']
    return ex_llm_outputs, "", ""
def check_save_ranker_inputs(inst, input, llm_outputs, blender_config):
    """Validate ranking inputs and snapshot them for the later fuse step.

    Args:
        inst: instruction text.
        input: input-context text.
        llm_outputs: mapping "LLM-<i>" -> saved output text.
        blender_config: unused; kept so the Gradio event signature is unchanged.

    Returns:
        dict with "inst", "input" and "candidates" (saved outputs in mapping order).

    Raises:
        gr.Error: if instruction and input are both empty, nothing is saved, or any
            saved output is empty.
    """
    if not inst and not input:
        raise gr.Error("Please enter instruction or input context")
    if not llm_outputs:
        raise gr.Error("Please enter and save base LLM outputs first")
    empty_llm_names = [llm_name for llm_name, llm_output in llm_outputs.items() if not llm_output]
    if empty_llm_names:
        # Format the message before building gr.Error; calling .format() on the
        # exception object raised AttributeError and hid this message.
        raise gr.Error("Please enter base LLM outputs for LLMs: {}".format(empty_llm_names))
    return {
        "inst": inst,
        "input": input,
        "candidates": list(llm_outputs.values()),
    }
def check_fuser_inputs(blender_state, blender_config, ranks):
    """Ensure a ranking has been produced before fusing.

    Args:
        blender_state: snapshot dict from check_save_ranker_inputs.
        blender_config: unused; part of the Gradio event signature.
        ranks: saved ranking result (must be non-empty).

    Raises:
        gr.Error: if there is no instruction/input, or no ranked candidates yet.
    """
    has_source = blender_state.get("inst", None) or blender_state.get("input", None)
    if not has_source:
        raise gr.Error("Please enter instruction or input context")
    ranked = "candidates" in blender_state and len(ranks) != 0
    if not ranked:
        raise gr.Error("Please rank LLM outputs first")
    return None
def llms_rank(inst, input, llm_outputs, blender_config):
    """Rank the saved base-LLM outputs with the loaded ranker.

    Args:
        inst: instruction text.
        input: input-context text.
        llm_outputs: mapping "LLM-<i>" -> output text; ranked in mapping order.
        blender_config: dict holding "source_max_length" and "candidate_max_length".

    Returns:
        [ranks, summary]: the ranker's result for this single example, and a
        display string "LLM-1: <rank>, LLM-2: <rank>, ...".
    """
    candidates = list(llm_outputs.values())
    # Forward the "Advanced options" length limits; previously this dict was built
    # but never passed, so the sliders had no effect on ranking.
    rank_params = {
        "source_max_length": blender_config['source_max_length'],
        "candidate_max_length": blender_config['candidate_max_length'],
    }
    ranks = blender.rank(instructions=[inst], inputs=[input], candidates=[candidates], **rank_params)[0]
    summary = ", ".join(f"LLM-{i+1}: {rank}" for i, rank in enumerate(ranks))
    return [ranks, summary]
def llms_fuse(blender_state, blender_config, ranks):
    """Fuse the top-k ranked candidates into one response.

    Args:
        blender_state: snapshot dict with "inst", "input", "candidates".
        blender_config: UI settings; "top_k_for_fuser" selects how many ranked
            candidates are fused, "source_max_length" is dropped, and the rest is
            forwarded to blender.fuse as keyword arguments.
        ranks: ranking result for the candidates.

    Returns:
        [fused_text, fused_text] — one copy for the State, one for the Textbox.
    """
    fuse_params = dict(blender_config)
    top_k = fuse_params.pop("top_k_for_fuser")
    del fuse_params["source_max_length"]
    fuse_params['no_repeat_ngram_size'] = 3

    source_inst = blender_state['inst']
    source_input = blender_state['input']
    all_candidates = blender_state['candidates']
    selected = get_topk_candidates_from_ranks([ranks], [all_candidates], top_k=top_k)[0]
    fused = blender.fuse(
        instructions=[source_inst],
        inputs=[source_input],
        candidates=[selected],
        **fuse_params,
        batch_size=1,
    )[0]
    return [fused, fused]
def display_fuser_output(fuser_output):
    """Return the fuser output unchanged (pass-through for a display component)."""
    displayed = fuser_output
    return displayed
with gr.Blocks(theme='ParityError/Anime') as demo:
    with gr.Tab("LLM-Blender"):
        # llm-blender interface: enter an instruction/input plus several base-LLM
        # outputs, rank them, then fuse the top-k ranked outputs.
        with gr.Row():
            gr.Markdown(descriptions.LLM_BLENDER_OVERALL_DESC)
            gr.Image("https://github.com/yuchenlin/LLM-Blender/blob/main/docs/llm_blender.png?raw=true", height=300)
        gr.Markdown("## Input and Base LLMs")
        with gr.Row():
            with gr.Column():
                inst_textbox = gr.Textbox(lines=1, label="Instruction", placeholder="Enter instruction here", show_label=True)
                input_textbox = gr.Textbox(lines=4, label="Input Context", placeholder="Enter input context here", show_label=True)
            with gr.Column():
                # Per-session mapping "LLM-<i>" -> saved output text.
                saved_llm_outputs = gr.State(value={})
                with gr.Group():
                    selected_base_llm_name_dropdown = gr.Dropdown(
                        label="Base LLM",
                        choices=[f"LLM-{i+1}" for i in range(MIN_BASE_LLM_NUM)],
                        value="LLM-1",
                        show_label=True,
                    )
                    selected_base_llm_output = gr.Textbox(
                        lines=4,
                        label="LLM-1 (Click Save to save current content)",
                        placeholder="Enter LLM-1 output here",
                        show_label=True,
                    )
                with gr.Row():
                    base_llm_outputs_save_button = gr.Button('Save', variant='primary')
                    base_llm_outputs_clear_single_button = gr.Button('Clear Single', variant='primary')
                    base_llm_outputs_clear_all_button = gr.Button('Clear All', variant='primary')
                base_llms_num = gr.Slider(
                    label='Number of base llms',
                    minimum=MIN_BASE_LLM_NUM,
                    maximum=MAX_BASE_LLM_NUM,
                    step=1,
                    value=MIN_BASE_LLM_NUM,
                )

        # Snapshot of the ranked inputs (written by check_save_ranker_inputs), plus the
        # latest rank and fuse results.
        blender_state = gr.State(value={})
        saved_rank_outputs = gr.State(value=[])
        saved_fuse_outputs = gr.State(value=[])
        gr.Markdown("## Blender Outputs")
        with gr.Group():
            rank_outputs = gr.Textbox(lines=1, label="Ranking outputs", placeholder="Ranking outputs", show_label=True)
            fuser_outputs = gr.Textbox(lines=4, label="Fusing outputs", placeholder="Fusing outputs", show_label=True)
        with gr.Row():
            rank_button = gr.Button('Rank LLM Outputs', variant='primary')
            fuse_button = gr.Button('Fuse Top-K ranked outputs', variant='primary')
            clear_button = gr.Button('Clear Blender Outputs', variant='primary')

        # Ranking/fusion settings, kept in sync with the "Advanced options" sliders below.
        blender_config = gr.State(value={
            "source_max_length": DEFAULT_SOURCE_MAX_LENGTH,
            "candidate_max_length": DEFAULT_CANDIDATE_MAX_LENGTH,
            "top_k_for_fuser": 3,
            "max_new_tokens": DEFAULT_FUSER_MAX_NEW_TOKENS,
            "temperature": 0.7,
            "top_p": 1.0,
        })
        with gr.Accordion(label='Advanced options', open=False):
            source_max_length = gr.Slider(
                label='Max length of Instruction + Input',
                minimum=1,
                maximum=SOURCE_MAX_LENGTH,
                step=1,
                value=DEFAULT_SOURCE_MAX_LENGTH,
            )
            candidate_max_length = gr.Slider(
                label='Max length of LLM-Output Candidate',
                minimum=1,
                maximum=CANDIDATE_MAX_LENGTH,
                step=1,
                value=DEFAULT_CANDIDATE_MAX_LENGTH,
            )
            top_k_for_fuser = gr.Slider(
                label='Top-k ranked candidates to fuse',
                minimum=1,
                maximum=3,
                step=1,
                value=3,
            )
            max_new_tokens = gr.Slider(
                label='Max new tokens fuser can generate',
                minimum=1,
                maximum=FUSER_MAX_NEW_TOKENS,
                step=1,
                value=DEFAULT_FUSER_MAX_NEW_TOKENS,
            )
            temperature = gr.Slider(
                label='Temperature of fuser generation',
                minimum=0.1,
                maximum=2.0,
                step=0.1,
                value=0.7,
            )
            top_p = gr.Slider(
                label='Top-p of fuser generation',
                minimum=0.05,
                maximum=1.0,
                step=0.05,
                value=1.0,
            )

        # Hidden textbox that receives the example key (instruction + input) when an
        # example is clicked; its change event loads that example's candidates.
        examples_dummy_textbox = gr.Textbox(lines=1, label="", placeholder="", show_label=False, visible=False)
        batch_examples = gr.Examples(
            examples=EXAMPLES,
            fn=get_preprocess_examples,
            cache_examples=True,
            examples_per_page=5,
            inputs=[inst_textbox, input_textbox],
            outputs=[inst_textbox, input_textbox, base_llms_num, examples_dummy_textbox],
        )

        base_llms_num.change(
            fn=update_base_llms_num,
            inputs=[base_llms_num, saved_llm_outputs],
            outputs=[selected_base_llm_name_dropdown, saved_llm_outputs],
        )
        examples_dummy_textbox.change(
            fn=update_base_llm_dropdown_along_examples,
            inputs=[examples_dummy_textbox],
            outputs=[saved_llm_outputs, rank_outputs, fuser_outputs],
        ).then(
            fn=display_llm_output,
            inputs=[saved_llm_outputs, selected_base_llm_name_dropdown],
            outputs=selected_base_llm_output,
        )
        selected_base_llm_name_dropdown.change(
            fn=display_llm_output,
            inputs=[saved_llm_outputs, selected_base_llm_name_dropdown],
            outputs=selected_base_llm_output,
        )
        base_llm_outputs_save_button.click(
            fn=save_llm_output,
            inputs=[selected_base_llm_name_dropdown, selected_base_llm_output, saved_llm_outputs],
            outputs=saved_llm_outputs,
        )
        base_llm_outputs_clear_all_button.click(
            fn=lambda: [{}, ""],
            inputs=[],
            outputs=[saved_llm_outputs, selected_base_llm_output],
        )
        base_llm_outputs_clear_single_button.click(
            fn=lambda: "",
            inputs=[],
            outputs=selected_base_llm_output,
        )
        # Validate and snapshot the inputs first; rank only if that step succeeded.
        rank_button.click(
            fn=check_save_ranker_inputs,
            inputs=[inst_textbox, input_textbox, saved_llm_outputs, blender_config],
            outputs=blender_state,
        ).success(
            fn=llms_rank,
            inputs=[inst_textbox, input_textbox, saved_llm_outputs, blender_config],
            outputs=[saved_rank_outputs, rank_outputs],
        )
        fuse_button.click(
            fn=check_fuser_inputs,
            inputs=[blender_state, blender_config, saved_rank_outputs],
            outputs=[],
        ).success(
            fn=llms_fuse,
            inputs=[blender_state, blender_config, saved_rank_outputs],
            outputs=[saved_fuse_outputs, fuser_outputs],
        )
        clear_button.click(
            fn=lambda: ["", "", {}, []],
            inputs=[],
            outputs=[rank_outputs, fuser_outputs, blender_state, saved_rank_outputs],
        )

        # update blender config: each slider writes its value under one key.
        def _make_blender_config_updater(config_key):
            """Return a callback that stores a slider value under `config_key`."""
            def _update_blender_config(value, config):
                config[config_key] = value
                return config
            return _update_blender_config

        for _config_slider, _config_key in (
            (source_max_length, "source_max_length"),
            (candidate_max_length, "candidate_max_length"),
            (top_k_for_fuser, "top_k_for_fuser"),
            (max_new_tokens, "max_new_tokens"),
            (temperature, "temperature"),
            (top_p, "top_p"),
        ):
            _config_slider.change(
                fn=_make_blender_config_updater(_config_key),
                inputs=[_config_slider, blender_config],
                outputs=blender_config,
            )

    with gr.Tab("PairRM"):
        # PairRM interface
        with gr.Row():
            gr.Markdown(descriptions.PairRM_OVERALL_DESC)
            gr.Image("https://yuchenlin.xyz/LLM-Blender/pairranker.png")
        with gr.Tab("Compare two responses"):
            instruction = gr.Textbox(lines=1, label="Instruction", placeholder="Enter instruction here", show_label=True)
            with gr.Row():
                response1 = gr.Textbox(lines=4, label="Response 1", placeholder="Enter response 1 here", show_label=True)
                response2 = gr.Textbox(lines=4, label="Response 2", placeholder="Enter response 2 here", show_label=True)
            with gr.Row():
                compare_button = gr.Button('Compare', variant='primary')
                # Rebinds `clear_button`; the LLM-Blender tab's handler is already attached.
                clear_button = gr.Button('Clear', variant='primary')
            with gr.Row():
                compare_result = gr.Textbox(lines=1, label="Compare Result", placeholder="", show_label=True)
                compare_result_prob = gr.Textbox(lines=1, label="PairRM Confidence", placeholder="", show_label=True)

            def compare_fn(inst, response1, response2):
                """Compare two responses to one instruction with PairRM.

                Returns:
                    [result, prob]: a verdict sentence and a confidence string derived
                    from the comparison logit (positive favours Response 1).

                Raises:
                    gr.Error: if the instruction or either response is empty.
                """
                if not inst:
                    raise gr.Error("Please enter instruction")
                if not response1 or not response2:
                    raise gr.Error("Please enter response 1 and response 2")
                comparison_results = blender.compare([inst], [response1], [response2], return_logits=True)
                logit = comparison_results[0]
                if logit > 0:
                    result = "Response 1 is better than Response 2"
                    prob = f"Confidence: {round(logit, 2)}"
                elif logit < 0:
                    result = "Response 2 is better than Response 1"
                    prob = f"Confidence: {round(abs(logit), 2)}"
                else:
                    result = "Response 1 and Response 2 are equally good"
                    prob = "No confidence for tie"
                return [result, prob]

            compare_button.click(
                fn=compare_fn,
                inputs=[instruction, response1, response2],
                outputs=[compare_result, compare_result_prob],
            )
            clear_button.click(
                fn=lambda: ["", ""],
                inputs=[],
                outputs=[compare_result, compare_result_prob],
            )
            hhh_dummy_textbox1 = gr.Textbox(lines=1, label="subset", placeholder="", show_label=False, visible=False)
            hhh_dummy_textbox2 = gr.Textbox(lines=1, label="Better Response", placeholder="", show_label=False, visible=False)
            gr.Markdown("## Examples from [HuggingFaceH4/hhh_alignment](https://huggingface.co/datasets/HuggingFaceH4/hhh_alignment)")
            gr.Examples(
                HHH_EXAMPLES,
                fn=get_hhh_examples,
                cache_examples=True,
                examples_per_page=5,
                inputs=[hhh_dummy_textbox1, instruction, response1, response2, hhh_dummy_textbox2],
                outputs=[instruction, response1, response2],
            )

        with gr.Tab("Compare assistant's response in two multi-turn conversations"):
            gr.Markdown("NOTE: Comparison of two conversations is based on that the user query in each turn is the same of two conversations.")
            # Bot-side placeholder for a turn whose user message has been entered but
            # whose bot response has not.
            PENDING_BOT_RESPONSE = "(Please enter your bot response)"

            def append_message(message, chat_history):
                """Add `message` to a Chatbot history, alternating user and bot turns.

                A new user message opens a turn whose bot side is PENDING_BOT_RESPONSE;
                the next message fills that placeholder. Returns ("", history) so the
                input textbox is cleared.
                """
                if chat_history is None:
                    chat_history = []
                if not message:
                    return "", chat_history
                if chat_history and chat_history[-1][1] == PENDING_BOT_RESPONSE:
                    chat_history[-1] = (chat_history[-1][0], message)
                else:
                    chat_history.append((message, PENDING_BOT_RESPONSE))
                return "", chat_history

            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Conversation A")
                    chatbot1 = gr.Chatbot()
                    msg1 = gr.Textbox(lines=1, label="Enter Chat history for Conversation A", placeholder="Enter your message here", show_label=True)
                    clear1 = gr.ClearButton([msg1, chatbot1])
                    msg1.submit(append_message, [msg1, chatbot1], [msg1, chatbot1])
                with gr.Column():
                    gr.Markdown("### Conversation B")
                    chatbot2 = gr.Chatbot()
                    msg2 = gr.Textbox(lines=1, label="Enter Chat history for Conversation B", placeholder="Enter your message here", show_label=True)
                    clear2 = gr.ClearButton([msg2, chatbot2])
                    msg2.submit(append_message, [msg2, chatbot2], [msg2, chatbot2])
            with gr.Row():
                # Rebinds `compare_button` / `compare_result*` from the previous sub-tab;
                # that sub-tab's handlers are already attached.
                compare_button = gr.Button('Compare', variant='primary')
            with gr.Row():
                compare_result = gr.Textbox(lines=1, label="Compare Result", placeholder="", show_label=True)
                compare_result_prob = gr.Textbox(lines=1, label="PairRM Confidence", placeholder="", show_label=True)

            def _chat_history_to_messages(chat_history):
                """Flatten (user, bot) pairs into the USER/ASSISTANT message list passed to compare_conversations."""
                messages = []
                for turn in chat_history:
                    messages.append({"role": "USER", "content": turn[0]})
                    messages.append({"role": "ASSISTANT", "content": turn[1]})
                return messages

            def compare_conv_fn(chat_history1, chat_history2):
                """Compare the assistant turns of two conversations with PairRM.

                Returns:
                    [result, prob]: a verdict sentence and a confidence string derived
                    from the comparison logit (positive favours Conversation A).

                Raises:
                    gr.Error: if either history is empty or unfinished, or the user
                        queries differ between the two conversations.
                """
                if not chat_history1 or not chat_history2:
                    raise gr.Error("Please enter chat history for both conversations")
                # User-facing validation: raise gr.Error rather than `assert`, which shows
                # a generic error and is stripped under `python -O`.
                if (chat_history1[-1][1] == PENDING_BOT_RESPONSE
                        or chat_history2[-1][1] == PENDING_BOT_RESPONSE):
                    raise gr.Error("Please complete chat history for both conversations")
                # The comparison assumes identical user turns (see the NOTE above).
                if [turn[0] for turn in chat_history1] != [turn[0] for turn in chat_history2]:
                    raise gr.Error("The user queries must be the same in both conversations")
                comparison_results = blender.compare_conversations(
                    [_chat_history_to_messages(chat_history1)],
                    [_chat_history_to_messages(chat_history2)],
                    return_logits=True,
                )
                logit = comparison_results[0]
                if logit > 0:
                    result = "Assistant's response in Conversation A is better than Conversation B"
                    prob = f"Confidence: {round(logit, 2)}"
                elif logit < 0:
                    result = "Assistant's response in Conversation B is better than Conversation A"
                    prob = f"Confidence: {round(abs(logit), 2)}"
                else:
                    result = "Assistant's response in Conversation A and Conversation B are equally good"
                    prob = "No confidence for tie"
                return [result, prob]

            compare_button.click(
                fn=compare_conv_fn,
                inputs=[chatbot1, chatbot2],
                outputs=[compare_result, compare_result_prob],
            )
            model_a_dummy_textbox = gr.Textbox(lines=1, label="Model A", placeholder="", show_label=False, visible=False)
            model_b_dummy_textbox = gr.Textbox(lines=1, label="Model B", placeholder="", show_label=False, visible=False)
            winner_dummy_textbox = gr.Textbox(lines=1, label="Better Model in conversation", placeholder="", show_label=False, visible=False)
            chatbot1_dummy_textbox = gr.Textbox(lines=1, label="Conversation A", placeholder="", show_label=False, visible=False)
            chatbot2_dummy_textbox = gr.Textbox(lines=1, label="Conversation B", placeholder="", show_label=False, visible=False)
            gr.Markdown("## Examples from [lmsys/mt_bench_human_judgments](https://huggingface.co/datasets/lmsys/mt_bench_human_judgments)")
            gr.Examples(
                MT_BENCH_HUMAN_JUDGE_EXAMPLES,
                fn=get_mt_bench_human_judge_examples,
                cache_examples=True,
                examples_per_page=5,
                inputs=[model_a_dummy_textbox, model_b_dummy_textbox, chatbot1_dummy_textbox, chatbot2_dummy_textbox, winner_dummy_textbox],
                outputs=[chatbot1, chatbot2],
            )
    gr.Markdown(descriptions.CITATION)
# Enable request queuing (queue capped at 20 events), then start the server.
demo.queue(max_size=20)
demo.launch()