import gradio as gr
import random
from threading import Thread
from queue import Queue

# Import our new modules
import config
import backend
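
# Expected interfaces, inferred from usage below: `config` exposes
# DATASET_CONFIG, ALL_MODELS, and MY_AUTH_TOKEN; `backend` exposes the
# call_modal_api() streaming generator and a `router` with get_routing_plan().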

# --- HELPER FUNCTIONS (Unchanged) ---
def get_random_question(domain):
    data_conf = config.DATASET_CONFIG[domain]
    dataset = data_conf["dataset"]
    if not dataset:
        return "Failed to load dataset.", "N/A"
    random_index = random.randint(0, len(dataset) - 1)
    sample = dataset[random_index]
    if domain == "Math":
        question = sample[data_conf["question_col"]]
        answer = sample[data_conf["answer_col"]]
    elif domain == "Bio":
        instruction = sample[data_conf["instruction_col"]]
        bio_input = sample[data_conf["input_col"]]
        answer = sample[data_conf["answer_col"]]
        if bio_input and bio_input.strip():
            question = f"**Instruction:**\n{instruction}\n\n**Input:**\n{bio_input}"
        else:
            question = instruction
    else:
        # Guard against an unrecognized domain so question/answer are never unbound
        return "Unknown domain.", "N/A"
    return question, answer

def update_domain_settings(domain):
    models = list(config.ALL_MODELS[domain].keys())
    def_base = next((m for m in models if "Base" in m), models[0])
    def_ft = next((m for m in models if "Finetuned" in m), models[0])
    q, a = get_random_question(domain)
    return [
        gr.Dropdown(choices=models, value=def_base),
        gr.Dropdown(choices=models, value=def_ft),
        gr.Textbox(value=q),
        a,
        gr.Markdown(visible=False)
    ]


def load_next_question(domain):
    q, a = get_random_question(domain)
    return [gr.Textbox(value=q), a, gr.Markdown(visible=False, value="")]

def reveal_answer(hidden_answer):
    return gr.Markdown(value=f"**Ground Truth Answer:**\n\n{hidden_answer}", visible=True)

# --- CORE LOGIC (REBUILT FOR TRUE PARALLEL STREAMING) ---
def stream_to_queue(model_id, prompt, lane, queue, key):
    """
    A worker function that runs in a thread.
    It calls the streaming API and puts tokens into the queue.
    """
    try:
        # call_modal_api is a generator
        for token in backend.call_modal_api(model_id, prompt, lane):
            queue.put((key, token))
    except Exception as e:
        queue.put((key, f"\n\nTHREAD ERROR: {e}"))
    finally:
        # When the stream is done, put a 'None' sentinel
        queue.put((key, None))
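
# Both workers share a single queue of (key, token) pairs, so the consumer can
# interleave whichever model emits output first instead of draining the two
# streams in a fixed order; the (key, None) sentinel marks end-of-stream.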

def run_comparison(domain, question, model_1_name, model_2_name):
    # 1. Get IDs
    id_1 = config.ALL_MODELS[domain].get(model_1_name)
    id_2 = config.ALL_MODELS[domain].get(model_2_name)

    # 2. Ask the Smart Router
    lane_for_m1, lane_for_m2 = backend.router.get_routing_plan(id_1, id_2)
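    # get_routing_plan is assumed to map each model to a serving "lane" (e.g.
    # separate endpoints) so the two requests can stream in parallel; its
    # internals live in backend and are not shown here.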

    # 3. Create the Queue and Threads. daemon=True so a stalled worker cannot
    # keep the interpreter alive after the app exits.
    q = Queue()
    Thread(
        target=stream_to_queue,
        args=(id_1, question, lane_for_m1, q, 'm1'),
        daemon=True
    ).start()
    Thread(
        target=stream_to_queue,
        args=(id_2, question, lane_for_m2, q, 'm2'),
        daemon=True
    ).start()

    # 4. Listen to the Queue
    text1 = ""
    text2 = ""
    m1_done = False
    m2_done = False

    # Clear boxes and start
    yield "", "", gr.Markdown(visible=False)

    while not (m1_done and m2_done):
        # q.get() blocks until the next token arrives from *either* thread
        key, token = q.get()

        # Check for the 'None' sentinel
        if token is None:
            if key == 'm1':
                m1_done = True
            elif key == 'm2':
                m2_done = True
        else:
            # Append the new token
            if key == 'm1':
                text1 += token
            elif key == 'm2':
                text2 += token
            # Yield the updated full text
            yield text1, text2, gr.Markdown(visible=False)
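
# Because run_comparison is a generator, Gradio streams each yield to the UI,
# so both output panes fill in incrementally as tokens arrive.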

# --- UI BUILD (Unchanged) ---
initial_question, initial_answer = get_random_question("Math")

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🔬 LLM Finetuning Arena
        ### Comparing Finetuned vs. Base Models on Specialized Tasks
        """
    )
    hidden_answer_state = gr.State(value=initial_answer)

    with gr.Row():
        domain_radio = gr.Radio(
            ["Math", "Bio"], label="1. Select Domain", value="Math"
        )
    with gr.Row():
        question_box = gr.Textbox(
            label="2. Question Prompt (Editable)",
            value=initial_question, lines=5, scale=4
        )
        next_btn = gr.Button("Load Random Question 🔄", scale=1, min_width=100)
    with gr.Row():
        math_models = list(config.ALL_MODELS["Math"].keys())
        model_1_dd = gr.Dropdown(
            label="3. Select Model 1 (Left)",
            choices=math_models,
            # Default to the first model if no name matches, so next() cannot
            # raise StopIteration
            value=next((m for m in math_models if "Base" in m), math_models[0])
        )
        model_2_dd = gr.Dropdown(
            label="4. Select Model 2 (Right)",
            choices=math_models,
            value=next((m for m in math_models if "Finetuned" in m), math_models[0])
        )
    with gr.Row():
        run_btn = gr.Button("🚀 Run Comparison", variant="primary", scale=3)
        show_answer_btn = gr.Button("Show Ground Truth Answer", scale=1)

    answer_display_box = gr.Markdown(label="Ground Truth Answer", visible=False)

    gr.Markdown("---")

    with gr.Row():
        output_1_box = gr.Markdown(label="Output: Model 1")
        output_2_box = gr.Markdown(label="Output: Model 2")

    # --- EVENTS (Unchanged) ---
    domain_radio.change(
        fn=update_domain_settings,
        inputs=[domain_radio],
        outputs=[model_1_dd, model_2_dd, question_box, hidden_answer_state, answer_display_box]
    )
    next_btn.click(
        fn=load_next_question,
        inputs=[domain_radio],
        outputs=[question_box, hidden_answer_state, answer_display_box]
    )
    show_answer_btn.click(
        fn=reveal_answer,
        inputs=[hidden_answer_state],
        outputs=[answer_display_box]
    )
    run_btn.click(
        fn=run_comparison,
        inputs=[domain_radio, question_box, model_1_dd, model_2_dd],
        outputs=[output_1_box, output_2_box, answer_display_box]
    )
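
    # run_comparison's third output re-hides answer_display_box on every run,
    # which is why it appears in the outputs list above.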

if __name__ == "__main__":
    if not config.MY_AUTH_TOKEN:
        print("⚠️ WARNING: ARENA_AUTH_TOKEN is not set.")
    demo.launch()
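
# To run locally (assumed setup): `pip install gradio`, keep config.py and
# backend.py next to this file, set ARENA_AUTH_TOKEN, and run this script
# with python.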