Spaces:
Runtime error
Runtime error
from dataclasses import dataclass | |
import gradio as gr | |
from database import PostgreSQL, Entry | |
from models import ALL_SUPPORTED_MODELS | |
from utils import * | |
# Module-level database handle; submit_guess() inserts each recorded guess through it.
db = PostgreSQL()
@dataclass
class State:
    """Per-user state of the guessing game.

    Holds the selected dataset/topic/model, the sampled evaluation card and
    its summary, the current question with the model's answer, and whether
    the user has already submitted a guess for that question.

    Fields left as ``None`` are filled in by ``__post_init__``; this relies on
    the ``@dataclass`` decorator, without which ``__post_init__`` is never
    invoked and the card/question fields would stay ``None``.
    """

    dataset: str = "mmlu"
    topic: str = DEFAULT_TOPIC
    # NOTE: this default is evaluated once, when the class body runs, so every
    # State() created without an explicit model starts from the same model.
    model: str = random.choice(MODELS)
    card_path: str = None
    card: Card = None
    summarizer_model: str = DEFAULT_SUMMARIZER
    card_summary: str = None
    qa: str = None
    model_answer: str = None
    ground_truth: bool = None  # if the model correctly answers the question
    submitted: bool = False

    def __post_init__(self):
        """Sample a card, a question and a summary for anything not supplied.

        Raises:
            AssertionError: if only part of the card group
                (card_path, card, card_summary) or only part of the question
                group (qa, model_answer, ground_truth) was supplied.
        """
        # init card: the three card fields must be given together or not at all
        card_missing = [self.card_path is None, self.card is None, self.card_summary is None]
        if any(card_missing):
            assert all(card_missing), \
                "card_path, card and card_summary must be provided together or all left as None"
            self.card, self.card_path = sample_random_card(self.dataset, self.topic, self.model)

        # init qa: same all-or-nothing rule for the question fields
        qa_missing = [self.qa is None, self.model_answer is None, self.ground_truth is None]
        if any(qa_missing):
            assert all(qa_missing), \
                "qa, model_answer and ground_truth must be provided together or all left as None"
            self.qa, self.model_answer, self.ground_truth = sample_random_qa(self.dataset, self.topic, self.model)

        # The summary depends on both the card and the question, so build it last.
        if self.card_summary is None:
            self.card_summary = summarize_card(self.summarizer_model, self.card, self.qa)
def submit_guess(guess: str, reasoning: str, confidence: int, state: State):
    """Handler for submit_button.click(): record the user's guess.

    Mutates ``state`` (marks it as submitted) and inserts an ``Entry`` via ``db``.

    Args:
        guess: the selected radio choice ("Correct"/"Incorrect"), or None if
            nothing was selected.
        reasoning: the user's explanation text.
        confidence: the confidence slider value.
        state: the current user's State.

    Returns:
        A tuple of (result message, model answer to display, state).
    """
    # Guard clauses: nothing selected yet, or this question was already answered.
    if guess is None:
        return "Please make a guess and then submit!", "", state
    if state.submitted:
        return "You have already submitted your guess!", state.model_answer, state

    predicted_correct = guess == "Correct"
    if state.ground_truth == predicted_correct:
        verdict = "You are right!"
    else:
        verdict = "You are wrong!"

    # need to store: topic, model, card, question, guess, reasoning, confidence
    record = Entry(
        state.model,
        state.card_path,
        state.topic,
        state.qa,
        predicted_correct,
        state.ground_truth,
        reasoning,
        confidence,
    )
    db.insert(record)

    state.submitted = True
    return verdict, state.model_answer, state
def next_guess(state: State):
    """Handler for next_button.click(): move on to a new question.

    Mutates ``state``: samples a new question for the current
    dataset/topic/model, re-summarizes the current card for that question and
    clears the ``submitted`` flag.

    Returns:
        A tuple of (question, card summary, "", "", state); the two empty
        strings blank out the model-answer and result boxes.
    """
    question, answer, is_correct = sample_random_qa(state.dataset, state.topic, state.model)
    state.qa = question
    state.model_answer = answer
    state.ground_truth = is_correct

    state.card_summary = summarize_card(state.summarizer_model, state.card, state.qa)
    state.submitted = False

    cleared = ""
    return state.qa, state.card_summary, cleared, cleared, state
def re_summarize(state: State):
    """Handler for re_summarize_button.click(): regenerate the card summary.

    Mutates ``state`` by replacing ``card_summary`` with a new summary of the
    current card for the current question.

    Returns:
        A tuple of (new card summary, state).
    """
    fresh_summary = summarize_card(state.summarizer_model, state.card, state.qa)
    state.card_summary = fresh_summary
    return fresh_summary, state
def switch_card(state: State):
    """Handler for switch_card_button.click(): pick a new model and card.

    Mutates ``state``: chooses a random model, samples a card for it, then
    draws a new question and summary through ``next_guess``.

    Returns:
        A tuple of (previous model name, question, "", full card markdown,
        card summary, "", state); the empty strings blank out the result and
        model-answer boxes.
    """
    previous_model = state.model

    state.model = random.choice(MODELS)
    new_card, new_card_path = sample_random_card(state.dataset, state.topic, state.model)
    state.card = new_card
    state.card_path = new_card_path

    # Only the in-place state update is needed here; the return value is unused.
    next_guess(state)

    return (
        previous_model,
        state.qa,
        "",
        state.card.get_markdown_str(),
        state.card_summary,
        "",
        state,
    )
def init_app():
    """Build the Gradio UI, wire up the event handlers and launch the app.

    The layout has a header row (dataset/topic selection, card switching and
    instructions) and a main row (question/guess column and card column).
    """

    def _set_summarizer(summarizer: str, state: State):
        """Copy the summarizer chosen in the dropdown into the session state."""
        state.summarizer_model = summarizer
        return state

    theme = gr.themes.Default(
        primary_hue="orange",
        secondary_hue="blue",
        neutral_hue="gray",
        text_size=gr.themes.Size(
            name="text_custom",
            xxs="10px",
            xs="12px",
            sm="14px",
            md="16px",
            lg="20px",
            xl="24px",
            xxl="28px",
        ),
    )

    with gr.Blocks(theme=theme) as app:
        gr_state = gr.State(State())  # this state is only for the current user
        # The initial State supplies the values shown when the page first loads.
        s: State = gr_state.value

        with gr.Row():  # header
            with gr.Column(scale=1):
                with gr.Group():
                    # NOTE(review): the dataset and topic dropdowns have no event
                    # listener, so changing them does not affect the state.
                    gr_dataset = gr.Dropdown(label="Select a Dataset", choices=DATASETS, value=s.dataset)
                    gr_topic = gr.Dropdown(label="Select a Topic", choices=TOPICS[s.dataset], value=s.topic)
                    gr_switch_card_button = gr.Button("Switch Evaluation Card")
                    gr_previous_model = gr.Textbox(label="Previous Model (A model may have multiple cards.)", value="", interactive=False)
            with gr.Column(scale=2):
                gr_instruction = gr.Markdown(value=read_all("prompts/instructions.md"))

        with gr.Row():
            with gr.Column(scale=1):  # question/guess column
                gr_question = gr.Textbox(label="Question", value=s.qa, interactive=False, show_copy_button=True)
                with gr.Group():
                    gr_reasoning = gr.Textbox(label="Explanation for Your Guess", lines=1, placeholder="Reason your decision (optional)")
                    gr_guess = gr.Radio(label="I believe the model will answer this question", choices=["Correct", "Incorrect"])
                    gr_confidence = gr.Slider(label="Confidence", minimum=1, maximum=5, step=1, value=3)
                    gr_guess_result = gr.Textbox(label="Result", value="",
                                                 placeholder="We will show the result once you submit your guess! :>", interactive=False)
                    gr_submit_button = gr.Button("Submit")
                    gr_next_button = gr.Button("Next Entry (will not change the full card)")

            with gr.Column(scale=2):  # card column
                with gr.Accordion(label="Full Evaluation Card", open=False):
                    gr_full_card = gr.Markdown(label="Full Evaluation Card", value=s.card.get_markdown_str())
                with gr.Group():
                    with gr.Row():
                        gr_summarizer = gr.Dropdown(label="Select a Model as the Summarizer", choices=ALL_SUPPORTED_MODELS, value=s.summarizer_model, scale=2, interactive=True)
                        gr_re_summarize_button = gr.Button("Re-generate Summary", scale=1)
                    with gr.Accordion(label="Evaluation Card Summary", open=True):
                        gr_relevant_card = gr.Markdown(value=s.card_summary)
                gr_model_answer = gr.Textbox(label="Model's Answer", value="", interactive=False, show_copy_button=True)

        # Event wiring. Each handler receives the session state and returns it,
        # so mutations made inside the handler are stored back for this user.
        gr_submit_button.click(fn=submit_guess,
                               inputs=[gr_guess, gr_reasoning, gr_confidence, gr_state],
                               outputs=[gr_guess_result, gr_model_answer, gr_state])
        gr_next_button.click(fn=next_guess,
                             inputs=[gr_state],
                             outputs=[gr_question, gr_relevant_card, gr_model_answer, gr_guess_result, gr_state])
        # Without this listener the summarizer dropdown would be purely cosmetic:
        # the summarizing handlers read state.summarizer_model, not the widget.
        gr_summarizer.change(fn=_set_summarizer,
                             inputs=[gr_summarizer, gr_state],
                             outputs=[gr_state])
        gr_re_summarize_button.click(fn=re_summarize,
                                     inputs=[gr_state],
                                     outputs=[gr_relevant_card, gr_state])
        gr_switch_card_button.click(fn=switch_card,
                                    inputs=[gr_state],
                                    outputs=[gr_previous_model, gr_question, gr_guess_result, gr_full_card, gr_relevant_card, gr_model_answer, gr_state])

    app.launch()
# Build and launch the app only when this file is run as a script.
if __name__ == "__main__":
    init_app()