# arena/app.py
"""
It provides a platform for comparing the responses of two LLMs.
"""

import enum
from uuid import uuid4

from firebase_admin import firestore
import gradio as gr
import lingua

import credentials
from credentials import set_credentials
from leaderboard import build_leaderboard
from leaderboard import db
from leaderboard import SUPPORTED_TRANSLATION_LANGUAGES
from model import check_models
from model import supported_models
import response
from response import get_responses
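
# Detects the language of model responses; used when recording
# summarization votes.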
detector = lingua.LanguageDetectorBuilder.from_all_languages().build()


class VoteOptions(enum.Enum):
  MODEL_A = "Model A is better"
  MODEL_B = "Model B is better"
  TIE = "Tie"


def vote(vote_button, response_a, response_b, model_a_name, model_b_name,
         prompt, instruction, category, source_lang, target_lang):
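  """Saves the vote result to Firestore and reveals the model names.

  The stored document contains the prompt, instruction, both responses and
  the winner. Summarization votes also record the detected response
  languages; translation votes record the selected source and target
  languages.
  """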
  doc_id = uuid4().hex
  winner = VoteOptions(vote_button).name.lower()

  deactivated_buttons = [gr.Button(interactive=False) for _ in range(3)]
  outputs = deactivated_buttons + [gr.Row(visible=True)]

  doc = {
      "id": doc_id,
      "prompt": prompt,
      "instruction": instruction,
      "model_a": model_a_name,
      "model_b": model_b_name,
      "model_a_response": response_a,
      "model_b_response": response_b,
      "winner": winner,
      "timestamp": firestore.SERVER_TIMESTAMP
  }
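
  # Summarization votes also record the detected language of each response.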
  if category == response.Category.SUMMARIZE.value:
    language_a = detector.detect_language_of(response_a)
    language_b = detector.detect_language_of(response_b)

    doc_ref = db.collection("arena-summarizations").document(doc_id)
    doc["model_a_response_language"] = language_a.name.lower()
    doc["model_b_response_language"] = language_b.name.lower()

    doc_ref.set(doc)
    return outputs
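
  # Translation votes also record the selected source and target languages.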
  if category == response.Category.TRANSLATE.value:
    if not source_lang or not target_lang:
      raise gr.Error("Please select source and target languages.")

    doc_ref = db.collection("arena-translations").document(doc_id)
    doc["source_language"] = source_lang.lower()
    doc["target_language"] = target_lang.lower()

    doc_ref.set(doc)
    return outputs

  raise gr.Error("Please select a response type.")


# Removes the persistent orange border from the leaderboard, which
# appears due to the 'generating' class when using the 'every' parameter.
css = """
.leaderboard .generating {
  border: none;
}
"""

with gr.Blocks(title="Arena", css=css) as app:
  with gr.Row():
    category_radio = gr.Radio(
        choices=[category.value for category in response.Category],
        value=response.Category.SUMMARIZE.value,
        label="Category",
        info="The chosen category determines the instruction sent to the LLMs.")

    source_language = gr.Dropdown(
        choices=SUPPORTED_TRANSLATION_LANGUAGES,
        value="English",
        label="Source language",
        info="Choose the source language for translation.",
        interactive=True,
        visible=False)
    target_language = gr.Dropdown(
        choices=SUPPORTED_TRANSLATION_LANGUAGES,
        value="Spanish",
        label="Target language",
        info="Choose the target language for translation.",
        interactive=True,
        visible=False)
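
  # The language dropdowns are only shown for the translation category.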
  def update_language_visibility(category):
    visible = category == response.Category.TRANSLATE.value
    return {
        source_language: gr.Dropdown(visible=visible),
        target_language: gr.Dropdown(visible=visible)
    }

  category_radio.change(update_language_visibility, category_radio,
                        [source_language, target_language])
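
  # Placeholders; the actual Textbox components are created in the layout
  # below and assigned into these lists.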
  model_names = [gr.State(None), gr.State(None)]
  response_boxes = [gr.State(None), gr.State(None)]

  prompt_textarea = gr.TextArea(label="Prompt", lines=4)
  submit = gr.Button()

  with gr.Group():
    with gr.Row():
      response_boxes[0] = gr.Textbox(label="Model A", interactive=False)
      response_boxes[1] = gr.Textbox(label="Model B", interactive=False)

    with gr.Row(visible=False) as model_name_row:
      model_names[0] = gr.Textbox(show_label=False)
      model_names[1] = gr.Textbox(show_label=False)

  with gr.Row(visible=False) as vote_row:
    option_a = gr.Button(VoteOptions.MODEL_A.value)
    option_b = gr.Button(VoteOptions.MODEL_B.value)
    tie = gr.Button(VoteOptions.TIE.value)

  instruction_state = gr.State("")

  # The following elements need to be reset when the user changes
  # the category, source language, or target language.
  ui_elements = [
      response_boxes[0], response_boxes[1], model_names[0], model_names[1],
      instruction_state, model_name_row, vote_row
  ]

  def reset_ui():
    return [gr.Textbox(value="") for _ in range(4)] + [
        gr.State(""),
        gr.Row(visible=False),
        gr.Row(visible=False)
    ]
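
  # Clear any previous results whenever the task configuration changes.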
  category_radio.change(fn=reset_ui, outputs=ui_elements)
  source_language.change(fn=reset_ui, outputs=ui_elements)
  target_language.change(fn=reset_ui, outputs=ui_elements)
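
  # Submitting a prompt disables the inputs, hides previous results, and
  # fetches fresh responses from the two models.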
  submit_event = submit.click(
      fn=lambda: [
          gr.Radio(interactive=False),
          gr.Dropdown(interactive=False),
          gr.Dropdown(interactive=False),
          gr.Button(interactive=False),
          gr.Row(visible=False),
          gr.Row(visible=False)
      ],
      outputs=[
          category_radio, source_language, target_language, submit, vote_row,
          model_name_row
      ]).then(fn=get_responses,
              inputs=[
                  prompt_textarea, category_radio, source_language,
                  target_language
              ],
              outputs=response_boxes + model_names + [instruction_state])
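
  # After the request finishes, show the vote buttons (only on success) and
  # re-enable the inputs.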
  submit_event.success(fn=lambda: gr.Row(visible=True), outputs=vote_row)

  submit_event.then(
      fn=lambda: [
          gr.Radio(interactive=True),
          gr.Dropdown(interactive=True),
          gr.Dropdown(interactive=True),
          gr.Button(interactive=True)
      ],
      outputs=[category_radio, source_language, target_language, submit])
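
  # Wire up the vote buttons. Each button passes its own label so the
  # handler knows which option was chosen.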
  common_inputs = response_boxes + model_names + [
      prompt_textarea, instruction_state, category_radio, source_language,
      target_language
  ]
  common_outputs = [option_a, option_b, tie, model_name_row]

  option_a.click(vote, [option_a] + common_inputs, common_outputs)
  option_b.click(vote, [option_b] + common_inputs, common_outputs)
  tie.click(vote, [tie] + common_inputs, common_outputs)

  build_leaderboard()

if __name__ == "__main__":
  set_credentials(credentials.CREDENTIALS, credentials.CREDENTIALS_PATH)
  check_models(supported_models)

  # We need to enable queue to use generators.
  app.queue()
  app.launch(debug=True)