Spaces:
Running
Running
File size: 5,178 Bytes
0dd5c06 cf196b3 f796553 cf196b3 73e8b86 cf196b3 f796553 cf196b3 67812d2 73e8b86 a19f11e 3c495cc a19f11e f796553 000d4f2 f796553 cf196b3 8ee349a cf196b3 000d4f2 3c495cc 47db0c3 cf196b3 a089fa0 3c495cc 47db0c3 3c495cc a089fa0 cf196b3 3c495cc a089fa0 47db0c3 a089fa0 3c495cc 73e8b86 a089fa0 73e8b86 300b938 cf196b3 3c495cc cf196b3 3c495cc cf196b3 871741c cf196b3 3c495cc cf196b3 6b89337 000d4f2 6b89337 73e8b86 000d4f2 73e8b86 71d0339 cf196b3 a089fa0 cf196b3 71d0339 73e8b86 a089fa0 3c495cc a089fa0 000d4f2 3c495cc 000d4f2 a089fa0 73e8b86 a19f11e 73e8b86 6b89337 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
"""
Arena: a Gradio app for comparing the responses of two LLMs side by side and voting on the better one.
"""
import enum
import json
import os
from uuid import uuid4
import firebase_admin
from firebase_admin import credentials
from firebase_admin import firestore
import gradio as gr
from leaderboard import build_leaderboard
import response
from response import get_responses
# Local development: filesystem path to a Firebase service-account file.
CREDENTIALS_PATH = os.getenv("CREDENTIALS_PATH")
# Deployment: the credentials JSON itself, supplied through the environment
# so the file never has to be committed.
CREDENTIALS = os.getenv("CREDENTIALS")
def get_credentials():
  """Return Firebase credentials for the current environment.

  A credentials file at ``CREDENTIALS_PATH`` is preferred (local
  development). Otherwise the credentials JSON is read from the
  ``CREDENTIALS`` environment variable (deployment), because the
  credentials file must not be made public.

  Returns:
    A ``firebase_admin.credentials.Certificate``.

  Raises:
    ValueError: If no existing file is configured and ``CREDENTIALS`` is
      unset or empty, or if ``CREDENTIALS`` is not valid JSON
      (``json.JSONDecodeError`` is a ``ValueError``).
  """
  # Prefer a credentials file when one is present locally.
  if CREDENTIALS_PATH and os.path.exists(CREDENTIALS_PATH):
    return credentials.Certificate(CREDENTIALS_PATH)

  # Fall back to credentials passed through the environment. Without this
  # guard, json.loads(None) fails with an unhelpful TypeError.
  if not CREDENTIALS:
    raise ValueError(
        "Firebase credentials not found: set CREDENTIALS_PATH to an existing "
        "file or CREDENTIALS to the credentials JSON.")
  json_cred = json.loads(CREDENTIALS)
  return credentials.Certificate(json_cred)
# Initialize the default Firebase app only if it does not exist yet.
# firebase_admin.initialize_app raises ValueError when the default app
# already exists, which happens when this module is re-executed in the same
# process (e.g. on auto-reload). firebase_admin.get_app raises ValueError
# when the default app does not exist, so that case triggers initialization.
try:
  firebase_admin.get_app()
except ValueError:
  firebase_admin.initialize_app(get_credentials())
db = firestore.client()
# Choices offered in both the source and target language dropdowns.
SUPPORTED_TRANSLATION_LANGUAGES = [
    "Korean",
    "English",
    "Chinese",
    "Japanese",
    "Spanish",
    "French",
]
class VoteOptions(enum.Enum):
  """The three possible votes.

  Each member's value doubles as the label of its vote button, so the
  clicked button's label can be turned back into a member with
  ``VoteOptions(label)``.
  """

  # The left-hand response was preferred.
  MODEL_A = "Model A is better"
  # The right-hand response was preferred.
  MODEL_B = "Model B is better"
  # Neither response was preferred.
  TIE = "Tie"
def vote(vote_button, response_a, response_b, model_a_name, model_b_name,
         user_prompt, instruction, category, source_lang, target_lang):
  """Store a vote in Firestore and return updates disabling the vote buttons.

  Args:
    vote_button: Label of the clicked button; must be a ``VoteOptions`` value.
    response_a: Response shown as Model A.
    response_b: Response shown as Model B.
    model_a_name: Name of the model behind Model A.
    model_b_name: Name of the model behind Model B.
    user_prompt: The prompt the user submitted.
    instruction: The instruction produced alongside the responses.
    category: The selected ``response.Category`` value.
    source_lang: Source language; required for translation votes.
    target_lang: Target language; required for translation votes.

  Returns:
    A list of three non-interactive ``gr.Button`` updates.

  Raises:
    ValueError: If ``vote_button`` is not a ``VoteOptions`` value.
    gr.Error: If the category is not a supported one, or a translation vote
      is missing its source or target language.
  """
  record_id = uuid4().hex
  winning_side = VoteOptions(vote_button).name.lower()
  disabled_buttons = [gr.Button(interactive=False) for _slot in range(3)]

  record = {
      "id": record_id,
      "prompt": user_prompt,
      "instruction": instruction,
      "model_a": model_a_name,
      "model_b": model_b_name,
      "model_a_response": response_a,
      "model_b_response": response_b,
      "winner": winning_side,
      "timestamp": firestore.SERVER_TIMESTAMP,
  }

  # Choose the target collection (and validate category-specific fields)
  # before touching Firestore.
  if category == response.Category.SUMMARIZE.value:
    collection_name = "arena-summarizations"
  elif category == response.Category.TRANSLATE.value:
    if not (source_lang and target_lang):
      raise gr.Error("Please select source and target languages.")
    collection_name = "arena-translations"
    record["source_language"] = source_lang.lower()
    record["target_language"] = target_lang.lower()
  else:
    raise gr.Error("Please select a response type.")

  db.collection(collection_name).document(record_id).set(record)
  return disabled_buttons
with gr.Blocks(title="Arena") as app:
  with gr.Row():
    category_radio = gr.Radio(
        [category.value for category in response.Category],
        label="Category",
        info="The chosen category determines the instruction sent to the LLMs.")

    # The language pickers start hidden; update_language_visibility shows
    # them only while the translation category is selected.
    source_language = gr.Dropdown(
        choices=SUPPORTED_TRANSLATION_LANGUAGES,
        label="Source language",
        info="Choose the source language for translation.",
        interactive=True,
        visible=False)
    target_language = gr.Dropdown(
        choices=SUPPORTED_TRANSLATION_LANGUAGES,
        label="Target language",
        info="Choose the target language for translation.",
        interactive=True,
        visible=False)

    def update_language_visibility(category):
      """Show the language dropdowns only for the translation category."""
      visible = category == response.Category.TRANSLATE.value
      return {
          source_language: gr.Dropdown(visible=visible),
          target_language: gr.Dropdown(visible=visible)
      }

    category_radio.change(update_language_visibility, category_radio,
                          [source_language, target_language])

  prompt = gr.TextArea(label="Prompt", lines=4)
  submit = gr.Button()

  # Build the output components directly. Pre-filling these lists with
  # gr.State placeholders that are immediately overwritten leaves unused
  # State components registered in the Blocks.
  with gr.Row():
    response_boxes = [
        gr.Textbox(label="Model A", interactive=False),
        gr.Textbox(label="Model B", interactive=False),
    ]

  # TODO(#5): Display it only after the user submits the prompt.
  with gr.Row():
    option_a = gr.Button(VoteOptions.MODEL_A.value)
    option_b = gr.Button(VoteOptions.MODEL_B.value)
    tie = gr.Button(VoteOptions.TIE.value)

  # TODO(#7): Hide it until the user votes.
  with gr.Accordion("Show models", open=False):
    with gr.Row():
      model_names = [
          gr.Textbox(label="Model A", interactive=False),
          gr.Textbox(label="Model B", interactive=False),
      ]

  vote_buttons = [option_a, option_b, tie]
  # Receives the instruction output of get_responses so vote() can store it.
  instruction_state = gr.State("")

  submit.click(
      get_responses, [prompt, category_radio, source_language, target_language],
      response_boxes + model_names + vote_buttons + [instruction_state])

  common_inputs = response_boxes + model_names + [
      prompt, instruction_state, category_radio, source_language,
      target_language
  ]
  # Each button passes itself as the first input so vote() receives the
  # clicked button's label.
  for clicked_button in vote_buttons:
    clicked_button.click(vote, [clicked_button] + common_inputs, vote_buttons)

  build_leaderboard(db)
if __name__ == "__main__":
  # Generator-based event handlers require the queue, so enable it before
  # launching the app in debug mode.
  app.queue()
  app.launch(debug=True)
|