"""Gradio demo app for LLM-Blender (PairRM ranker + GenFuser) and PairRM comparisons.

At import time this script downloads example data from the Hugging Face Hub,
loads the PairRM ranker and the gen_fuser_3b fuser, builds the Blocks UI and
launches it.
"""
import ast
import gradio as gr
import sys
import os
import random
import llm_blender
import descriptions
from datasets import load_dataset
from llm_blender.blender.blender_utils import get_topk_candidates_from_ranks
from typing import List

MAX_BASE_LLM_NUM = 20
MIN_BASE_LLM_NUM = 3
SOURCE_MAX_LENGTH = 256
DEFAULT_SOURCE_MAX_LENGTH = 128
CANDIDATE_MAX_LENGTH = 256
DEFAULT_CANDIDATE_MAX_LENGTH = 128
FUSER_MAX_NEW_TOKENS = 512
DEFAULT_FUSER_MAX_NEW_TOKENS = 256

# Placeholder shown in a chat turn whose assistant reply has not been entered yet.
PENDING_BOT_RESPONSE = "(Please enter your bot response)"

# --- Examples for the LLM-Blender tab (llm-blender/mix-instruct) -------------
# Stream the validation split and take a seeded, shuffled sample of 100 rows.
EXAMPLES_DATASET = load_dataset("llm-blender/mix-instruct", split='validation', streaming=True)
SHUFFLED_EXAMPLES_DATASET = EXAMPLES_DATASET.shuffle(seed=42, buffer_size=1000)
EXAMPLES = []
# Maps instruction + input (string concatenation) to that example's candidates.
CANDIDATE_EXAMPLES = {}
for example in SHUFFLED_EXAMPLES_DATASET.take(100):
    EXAMPLES.append([
        example['instruction'],
        example['input'],
    ])
    CANDIDATE_EXAMPLES[example['instruction'] + example['input']] = example['candidates']

# --- Examples for "Compare two responses" (HuggingFaceH4/hhh_alignment) -------
# Each row: [subset, instruction, response 1, response 2, which response is better].
# The two choices are put in a random (seeded) order so the better one is not
# always shown first.
HHH_EXAMPLES = []
subsets = ['harmless', 'helpful', 'honest', 'other']
random.seed(42)
for subset in subsets:
    dataset = load_dataset("HuggingFaceH4/hhh_alignment", subset)
    for example in dataset['test']:
        if random.random() < 0.5:
            HHH_EXAMPLES.append([
                subset,
                example['input'],
                example['targets']['choices'][0],
                example['targets']['choices'][1],
                "Response 1" if example['targets']['labels'][0] == 1 else "Response 2",
            ])
        else:
            HHH_EXAMPLES.append([
                subset,
                example['input'],
                example['targets']['choices'][1],
                example['targets']['choices'][0],
                "Response 2" if example['targets']['labels'][0] == 1 else "Response 1",
            ])


def get_hhh_examples(subset, instruction, response1, response2, dummy_text):
    """Example loader: forward only the instruction and the two responses to the UI."""
    return instruction, response1, response2


# --- Examples for the multi-turn tab (lmsys/mt_bench_human_judgments) ---------
# Only first-turn judgments are used. Conversations are stored as their str()
# representation so they fit in hidden textboxes; they are parsed back in
# get_mt_bench_human_judge_examples.
MT_BENCH_HUMAN_JUDGE_EXAMPLES = []
dataset = load_dataset("lmsys/mt_bench_human_judgments")
for example in dataset['human']:
    if example['turn'] != 1:
        continue
    MT_BENCH_HUMAN_JUDGE_EXAMPLES.append([
        example['model_a'],
        example['model_b'],
        str(example['conversation_a']),
        str(example['conversation_b']),
        # Any winner other than model_a / model_b (e.g. a tie) is not a "Model B" win.
        {"model_a": "Model A", "model_b": "Model B"}.get(example['winner'], "Tie"),
    ])


def _conversation_to_chat_history(conversation_repr):
    """Parse a stringified conversation into (user, assistant) pairs for gr.Chatbot.

    Args:
        conversation_repr: ``str()`` of a list of ``{"role": ..., "content": ...}``
            dicts alternating user / assistant, as stored in
            MT_BENCH_HUMAN_JUDGE_EXAMPLES.

    Returns:
        A list of ``(user_content, assistant_content)`` tuples.

    Raises:
        ValueError: if the messages do not form complete user -> assistant turns.
    """
    # literal_eval only parses Python literals; plain eval would execute any
    # code placed in the (hidden) textbox value.
    conversation = ast.literal_eval(conversation_repr)
    if len(conversation) % 2 != 0:
        raise ValueError("Conversation must contain complete user/assistant turns")
    chat_history = []
    for user_msg, assistant_msg in zip(conversation[0::2], conversation[1::2]):
        if user_msg['role'] != 'user' or assistant_msg['role'] != 'assistant':
            raise ValueError("Conversation must alternate user and assistant messages")
        chat_history.append((user_msg['content'], assistant_msg['content']))
    return chat_history


def get_mt_bench_human_judge_examples(model_a, model_b, conversation_a, conversation_b, dummy_text):
    """Example loader: turn the two stored conversations into Chatbot histories."""
    return _conversation_to_chat_history(conversation_a), _conversation_to_chat_history(conversation_b)


blender = llm_blender.Blender()
blender.loadranker("llm-blender/PairRM")
blender.loadfuser("llm-blender/gen_fuser_3b")


def update_base_llms_num(k, llm_outputs):
    """Resize the base-LLM dropdown and the saved-outputs dict to ``k`` entries.

    Outputs already saved for LLM-1..LLM-k are kept, slots beyond ``k`` are
    dropped, and new slots start as empty strings.

    Returns:
        ``[dropdown update, new llm_outputs dict]``.
    """
    k = int(k)
    llm_names = [f"LLM-{i+1}" for i in range(k)]
    return [
        gr.Dropdown(choices=llm_names, value="LLM-1" if k >= 1 else "", visible=True),
        {name: llm_outputs.get(name, "") for name in llm_names},
    ]


def display_llm_output(llm_outputs, selected_base_llm_name):
    """Return a Textbox update showing the saved output of the selected base LLM."""
    return gr.Textbox(
        value=llm_outputs.get(selected_base_llm_name, ""),
        label=selected_base_llm_name + " (Click Save to save current content)",
        placeholder=f"Enter {selected_base_llm_name} output here",
        show_label=True,
    )


def save_llm_output(selected_base_llm_name, selected_base_llm_output, llm_outputs):
    """Store the textbox content as the output of the selected base LLM.

    The dict is updated in place and returned so Gradio writes it back to the
    ``saved_llm_outputs`` state.
    """
    llm_outputs.update({selected_base_llm_name: selected_base_llm_output})
    return llm_outputs


def get_preprocess_examples(inst, input):
    """Example loader: also report the candidate count and a lookup key.

    Returns:
        ``(inst, input, number of candidates, inst + input)``. The last value is
        written to a hidden textbox whose change event loads the candidates.
    """
    candidates = CANDIDATE_EXAMPLES[inst + input]
    num_candidates = len(candidates)
    dummy_text = inst + input
    return inst, input, num_candidates, dummy_text


def update_base_llm_dropdown_along_examples(dummy_text):
    """Load an example's candidates as LLM-1..LLM-n outputs and clear blender outputs.

    Returns:
        ``(llm_outputs dict, "" for rank outputs, "" for fuser outputs)``.
    """
    candidates = CANDIDATE_EXAMPLES[dummy_text]
    ex_llm_outputs = {f"LLM-{i+1}": candidate['text'] for i, candidate in enumerate(candidates)}
    return ex_llm_outputs, "", ""


def check_save_ranker_inputs(inst, input, llm_outputs, blender_config):
    """Validate ranker inputs and snapshot them into the blender state.

    Raises:
        gr.Error: if both instruction and input are empty, if no base LLM output
            was saved, or if any saved output is empty.

    Returns:
        ``{"inst", "input", "candidates"}``, where candidates is the list of
        saved outputs in dict order.
    """
    if not inst and not input:
        raise gr.Error("Please enter instruction or input context")
    if not llm_outputs:
        raise gr.Error("Please enter and save base LLM outputs first")
    empty_llm_names = [llm_name for llm_name, llm_output in llm_outputs.items() if not llm_output]
    if empty_llm_names:
        # Format the message string itself, not the exception object.
        raise gr.Error("Please enter base LLM outputs for LLMs: {}".format(empty_llm_names))
    return {
        "inst": inst,
        "input": input,
        "candidates": list(llm_outputs.values()),
    }


def check_fuser_inputs(blender_state, blender_config, ranks):
    """Ensure ranking has run (with an instruction or input) before fusing.

    Raises:
        gr.Error: if there is no instruction/input, or no ranks are available.
    """
    if not (blender_state.get("inst", None) or blender_state.get("input", None)):
        raise gr.Error("Please enter instruction or input context")
    if "candidates" not in blender_state or len(ranks) == 0:
        raise gr.Error("Please rank LLM outputs first")
    return


def llms_rank(inst, input, llm_outputs, blender_config):
    """Rank the saved base-LLM outputs with PairRM.

    Returns:
        ``[ranks, summary]``. ``ranks`` is the ranker's per-candidate rank for
        this single instance, and ``summary`` is a string such as
        "LLM-1: 2, LLM-2: 1, ...".
    """
    candidates = list(llm_outputs.values())
    # NOTE(review): the source/candidate max-length options in blender_config are
    # not forwarded to blender.rank here; confirm which kwargs Blender.rank accepts
    # before wiring them in.
    ranks = blender.rank(instructions=[inst], inputs=[input], candidates=[candidates])[0]
    rank_summary = ", ".join(f"LLM-{i+1}: {rank}" for i, rank in enumerate(ranks))
    return [ranks, rank_summary]


def llms_fuse(blender_state, blender_config, ranks):
    """Fuse the top-k ranked candidates with GenFuser.

    Returns:
        ``[fused text, fused text]``: one copy for the saved state and one for
        the textbox.
    """
    inst = blender_state['inst']
    input = blender_state['input']
    candidates = blender_state['candidates']
    # Slider values may be reported as floats; top-k and token counts must be ints.
    top_k_for_fuser = int(blender_config['top_k_for_fuser'])
    # The remaining config entries are passed to blender.fuse as generation
    # kwargs; top_k and the ranker-side source length are excluded.
    fuse_params = {
        key: value for key, value in blender_config.items()
        if key not in ("top_k_for_fuser", "source_max_length")
    }
    fuse_params['max_new_tokens'] = int(fuse_params['max_new_tokens'])
    fuse_params['no_repeat_ngram_size'] = 3
    top_k_candidates = get_topk_candidates_from_ranks([ranks], [candidates], top_k=top_k_for_fuser)[0]
    fuser_outputs = blender.fuse(instructions=[inst], inputs=[input], candidates=[top_k_candidates], **fuse_params, batch_size=1)[0]
    return [fuser_outputs, fuser_outputs]


def display_fuser_output(fuser_output):
    """Identity pass-through for showing the fuser output."""
    return fuser_output


def _config_setter(key):
    """Return a Gradio callback that stores a slider value under ``key`` in the config."""
    def _set(value, config):
        config[key] = value
        return config
    return _set


def _comparison_outputs(logit, first_better_msg, second_better_msg, tie_msg):
    """Turn a PairRM comparison logit into ``[result text, confidence text]``.

    A positive logit favors the first item and a negative one the second. The
    absolute logit, rounded to 2 decimals, is shown as the confidence.
    """
    if logit > 0:
        return [first_better_msg, f"Confidence: {round(logit, 2)}"]
    if logit < 0:
        return [second_better_msg, f"Confidence: {round(abs(logit), 2)}"]
    return [tie_msg, "No confidence for tie"]


def compare_fn(inst, response1, response2):
    """Compare two responses to one instruction with PairRM.

    Raises:
        gr.Error: if the instruction or either response is empty.
    """
    if not inst:
        raise gr.Error("Please enter instruction")
    if not response1 or not response2:
        raise gr.Error("Please enter response 1 and response 2")
    comparison_results = blender.compare([inst], [response1], [response2], return_logits=True)
    return _comparison_outputs(
        comparison_results[0],
        "Response 1 is better than Response 2",
        "Response 2 is better than Response 1",
        "Response 1 and Response 2 are equally good",
    )


def append_message(message, chat_history):
    """Append ``message`` to a Chatbot history, alternating user and bot turns.

    If the last turn still shows the pending-bot placeholder, the message fills
    it in as the bot reply. Otherwise it starts a new user turn with the
    placeholder. Returns ``("", chat_history)`` to clear the input box.
    """
    if not message:
        return "", chat_history
    if chat_history and chat_history[-1][1] == PENDING_BOT_RESPONSE:
        chat_history[-1] = (chat_history[-1][0], message)
    else:
        chat_history.append((message, PENDING_BOT_RESPONSE))
    return "", chat_history


def _chat_history_to_messages(chat_history):
    """Flatten Chatbot (user, assistant) pairs into role/content dicts for PairRM."""
    messages = []
    for user_content, assistant_content in chat_history:
        messages.append({"role": "USER", "content": user_content})
        messages.append({"role": "ASSISTANT", "content": assistant_content})
    return messages


def compare_conv_fn(chat_history1, chat_history2):
    """Compare the assistant side of two multi-turn conversations with PairRM.

    Raises:
        gr.Error: if either history is empty or its last turn still lacks a
            bot reply.
    """
    if not chat_history1 or not chat_history2:
        raise gr.Error("Please enter chat history for both conversations")
    if chat_history1[-1][1] == PENDING_BOT_RESPONSE or chat_history2[-1][1] == PENDING_BOT_RESPONSE:
        raise gr.Error("Please complete chat history for both conversations")
    comparison_results = blender.compare_conversations(
        [_chat_history_to_messages(chat_history1)],
        [_chat_history_to_messages(chat_history2)],
        return_logits=True,
    )
    return _comparison_outputs(
        comparison_results[0],
        "Assistant's response in Conversation A is better than Conversation B",
        "Assistant's response in Conversation B is better than Conversation A",
        "Assistant's response in Conversation A and Conversation B are equally good",
    )


with gr.Blocks(theme='ParityError/Anime') as demo:
    with gr.Tab("LLM-Blender"):
        # llm-blender interface
        with gr.Row():
            gr.Markdown(descriptions.LLM_BLENDER_OVERALL_DESC)
            gr.Image("https://github.com/yuchenlin/LLM-Blender/blob/main/docs/llm_blender.png?raw=true", height=300)
        gr.Markdown("## Input and Base LLMs")
        with gr.Row():
            with gr.Column():
                inst_textbox = gr.Textbox(lines=1, label="Instruction", placeholder="Enter instruction here", show_label=True)
                input_textbox = gr.Textbox(lines=4, label="Input Context", placeholder="Enter input context here", show_label=True)
            with gr.Column():
                saved_llm_outputs = gr.State(value={})
                with gr.Group():
                    selected_base_llm_name_dropdown = gr.Dropdown(
                        label="Base LLM",
                        choices=[f"LLM-{i+1}" for i in range(MIN_BASE_LLM_NUM)],
                        value="LLM-1",
                        show_label=True,
                    )
                    selected_base_llm_output = gr.Textbox(
                        lines=4,
                        label="LLM-1 (Click Save to save current content)",
                        placeholder="Enter LLM-1 output here",
                        show_label=True,
                    )
                with gr.Row():
                    base_llm_outputs_save_button = gr.Button('Save', variant='primary')
                    base_llm_outputs_clear_single_button = gr.Button('Clear Single', variant='primary')
                    base_llm_outputs_clear_all_button = gr.Button('Clear All', variant='primary')
                base_llms_num = gr.Slider(
                    label='Number of base llms',
                    minimum=MIN_BASE_LLM_NUM,
                    maximum=MAX_BASE_LLM_NUM,
                    step=1,
                    value=MIN_BASE_LLM_NUM,
                )

        blender_state = gr.State(value={})
        saved_rank_outputs = gr.State(value=[])
        saved_fuse_outputs = gr.State(value=[])
        gr.Markdown("## Blender Outputs")
        with gr.Group():
            rank_outputs = gr.Textbox(lines=1, label="Ranking outputs", placeholder="Ranking outputs", show_label=True)
            fuser_outputs = gr.Textbox(lines=4, label="Fusing outputs", placeholder="Fusing outputs", show_label=True)
        with gr.Row():
            rank_button = gr.Button('Rank LLM Outputs', variant='primary')
            fuse_button = gr.Button('Fuse Top-K ranked outputs', variant='primary')
            clear_button = gr.Button('Clear Blender Outputs', variant='primary')
        blender_config = gr.State(value={
            "source_max_length": DEFAULT_SOURCE_MAX_LENGTH,
            "candidate_max_length": DEFAULT_CANDIDATE_MAX_LENGTH,
            "top_k_for_fuser": 3,
            "max_new_tokens": DEFAULT_FUSER_MAX_NEW_TOKENS,
            "temperature": 0.7,
            "top_p": 1.0,
        })

        with gr.Accordion(label='Advanced options', open=False):
            source_max_length = gr.Slider(
                label='Max length of Instruction + Input',
                minimum=1,
                maximum=SOURCE_MAX_LENGTH,
                step=1,
                value=DEFAULT_SOURCE_MAX_LENGTH,
            )
            candidate_max_length = gr.Slider(
                label='Max length of LLM-Output Candidate',
                minimum=1,
                maximum=CANDIDATE_MAX_LENGTH,
                step=1,
                value=DEFAULT_CANDIDATE_MAX_LENGTH,
            )
            top_k_for_fuser = gr.Slider(
                label='Top-k ranked candidates to fuse',
                minimum=1,
                maximum=3,
                step=1,
                value=3,
            )
            max_new_tokens = gr.Slider(
                label='Max new tokens fuser can generate',
                minimum=1,
                maximum=FUSER_MAX_NEW_TOKENS,
                step=1,
                value=DEFAULT_FUSER_MAX_NEW_TOKENS,
            )
            temperature = gr.Slider(
                label='Temperature of fuser generation',
                minimum=0.1,
                maximum=2.0,
                step=0.1,
                value=0.7,
            )
            top_p = gr.Slider(
                label='Top-p of fuser generation',
                minimum=0.05,
                maximum=1.0,
                step=0.05,
                value=1.0,
            )

        # Hidden textbox carrying the example key (instruction + input); its change
        # event loads that example's candidates into saved_llm_outputs.
        examples_dummy_textbox = gr.Textbox(lines=1, label="", placeholder="", show_label=False, visible=False)
        batch_examples = gr.Examples(
            examples=EXAMPLES,
            fn=get_preprocess_examples,
            cache_examples=True,
            examples_per_page=5,
            inputs=[inst_textbox, input_textbox],
            outputs=[inst_textbox, input_textbox, base_llms_num, examples_dummy_textbox],
        )

        base_llms_num.change(
            fn=update_base_llms_num,
            inputs=[base_llms_num, saved_llm_outputs],
            outputs=[selected_base_llm_name_dropdown, saved_llm_outputs],
        )

        examples_dummy_textbox.change(
            fn=update_base_llm_dropdown_along_examples,
            inputs=[examples_dummy_textbox],
            outputs=[saved_llm_outputs, rank_outputs, fuser_outputs],
        ).then(
            fn=display_llm_output,
            inputs=[saved_llm_outputs, selected_base_llm_name_dropdown],
            outputs=selected_base_llm_output,
        )

        selected_base_llm_name_dropdown.change(
            fn=display_llm_output,
            inputs=[saved_llm_outputs, selected_base_llm_name_dropdown],
            outputs=selected_base_llm_output,
        )

        base_llm_outputs_save_button.click(
            fn=save_llm_output,
            inputs=[selected_base_llm_name_dropdown, selected_base_llm_output, saved_llm_outputs],
            outputs=saved_llm_outputs,
        )
        base_llm_outputs_clear_all_button.click(
            fn=lambda: [{}, ""],
            inputs=[],
            outputs=[saved_llm_outputs, selected_base_llm_output],
        )
        base_llm_outputs_clear_single_button.click(
            fn=lambda: "",
            inputs=[],
            outputs=selected_base_llm_output,
        )

        # Validate and snapshot inputs first; rank only if validation succeeded.
        rank_button.click(
            fn=check_save_ranker_inputs,
            inputs=[inst_textbox, input_textbox, saved_llm_outputs, blender_config],
            outputs=blender_state,
        ).success(
            fn=llms_rank,
            inputs=[inst_textbox, input_textbox, saved_llm_outputs, blender_config],
            outputs=[saved_rank_outputs, rank_outputs],
        )

        fuse_button.click(
            fn=check_fuser_inputs,
            inputs=[blender_state, blender_config, saved_rank_outputs],
            outputs=[],
        ).success(
            fn=llms_fuse,
            inputs=[blender_state, blender_config, saved_rank_outputs],
            outputs=[saved_fuse_outputs, fuser_outputs],
        )

        clear_button.click(
            fn=lambda: ["", "", {}, []],
            inputs=[],
            outputs=[rank_outputs, fuser_outputs, blender_state, saved_rank_outputs],
        )

        # update blender config: each slider writes its value under its config key
        for config_key, config_slider in [
            ("source_max_length", source_max_length),
            ("candidate_max_length", candidate_max_length),
            ("top_k_for_fuser", top_k_for_fuser),
            ("max_new_tokens", max_new_tokens),
            ("temperature", temperature),
            ("top_p", top_p),
        ]:
            config_slider.change(
                fn=_config_setter(config_key),
                inputs=[config_slider, blender_config],
                outputs=blender_config,
            )

    with gr.Tab("PairRM"):
        # PairRM interface
        with gr.Row():
            gr.Markdown(descriptions.PairRM_OVERALL_DESC)
            gr.Image("https://yuchenlin.xyz/LLM-Blender/pairranker.png")

        with gr.Tab("Compare two responses"):
            instruction = gr.Textbox(lines=1, label="Instruction", placeholder="Enter instruction here", show_label=True)
            with gr.Row():
                response1 = gr.Textbox(lines=4, label="Response 1", placeholder="Enter response 1 here", show_label=True)
                response2 = gr.Textbox(lines=4, label="Response 2", placeholder="Enter response 2 here", show_label=True)
            with gr.Row():
                compare_button = gr.Button('Compare', variant='primary')
                clear_button = gr.Button('Clear', variant='primary')
            with gr.Row():
                compare_result = gr.Textbox(lines=1, label="Compare Result", placeholder="", show_label=True)
                compare_result_prob = gr.Textbox(lines=1, label="PairRM Confidence", placeholder="", show_label=True)

            compare_button.click(
                fn=compare_fn,
                inputs=[instruction, response1, response2],
                outputs=[compare_result, compare_result_prob],
            )
            clear_button.click(
                fn=lambda: ["", ""],
                inputs=[],
                outputs=[compare_result, compare_result_prob],
            )

            # Hidden columns so example rows can carry subset and gold label.
            hhh_dummy_textbox1 = gr.Textbox(lines=1, label="subset", placeholder="", show_label=False, visible=False)
            hhh_dummy_textbox2 = gr.Textbox(lines=1, label="Better Response", placeholder="", show_label=False, visible=False)
            gr.Markdown("## Examples from [HuggingFaceH4/hhh_alignment](https://huggingface.co/datasets/HuggingFaceH4/hhh_alignment)")
            gr.Examples(
                HHH_EXAMPLES,
                fn=get_hhh_examples,
                cache_examples=True,
                examples_per_page=5,
                inputs=[hhh_dummy_textbox1, instruction, response1, response2, hhh_dummy_textbox2],
                outputs=[instruction, response1, response2],
            )

        with gr.Tab("Compare assistant's response in two multi-turn conversations"):
            gr.Markdown("NOTE: Comparison of two conversations is based on that the user query in each turn is the same of two conversations.")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Conversation A")
                    chatbot1 = gr.Chatbot()
                    msg1 = gr.Textbox(lines=1, label="Enter Chat history for Conversation A", placeholder="Enter your message here", show_label=True)
                    clear1 = gr.ClearButton([msg1, chatbot1])
                    msg1.submit(append_message, [msg1, chatbot1], [msg1, chatbot1])
                with gr.Column():
                    gr.Markdown("### Conversation B")
                    chatbot2 = gr.Chatbot()
                    msg2 = gr.Textbox(lines=1, label="Enter Chat history for Conversation B", placeholder="Enter your message here", show_label=True)
                    clear2 = gr.ClearButton([msg2, chatbot2])
                    msg2.submit(append_message, [msg2, chatbot2], [msg2, chatbot2])
            with gr.Row():
                compare_button = gr.Button('Compare', variant='primary')
            with gr.Row():
                compare_result = gr.Textbox(lines=1, label="Compare Result", placeholder="", show_label=True)
                compare_result_prob = gr.Textbox(lines=1, label="PairRM Confidence", placeholder="", show_label=True)

            compare_button.click(
                fn=compare_conv_fn,
                inputs=[chatbot1, chatbot2],
                outputs=[compare_result, compare_result_prob],
            )

            # Hidden columns so example rows can carry model names, raw
            # conversations and the human-judged winner.
            model_a_dummy_textbox = gr.Textbox(lines=1, label="Model A", placeholder="", show_label=False, visible=False)
            model_b_dummy_textbox = gr.Textbox(lines=1, label="Model B", placeholder="", show_label=False, visible=False)
            winner_dummy_textbox = gr.Textbox(lines=1, label="Better Model in conversation", placeholder="", show_label=False, visible=False)
            chatbot1_dummy_textbox = gr.Textbox(lines=1, label="Conversation A", placeholder="", show_label=False, visible=False)
            chatbot2_dummy_textbox = gr.Textbox(lines=1, label="Conversation B", placeholder="", show_label=False, visible=False)
            gr.Markdown("## Examples from [lmsys/mt_bench_human_judgments](https://huggingface.co/datasets/lmsys/mt_bench_human_judgments)")
            gr.Examples(
                MT_BENCH_HUMAN_JUDGE_EXAMPLES,
                fn=get_mt_bench_human_judge_examples,
                cache_examples=True,
                examples_per_page=5,
                inputs=[model_a_dummy_textbox, model_b_dummy_textbox, chatbot1_dummy_textbox, chatbot2_dummy_textbox, winner_dummy_textbox],
                outputs=[chatbot1, chatbot2],
            )

    gr.Markdown(descriptions.CITATION)

demo.queue(max_size=20).launch()