import gradio as gr
import random
from threading import Thread
from queue import Queue
# Local modules: config (models/datasets) and backend (inference)
import config
import backend
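# Interface this file assumes from the local modules, inferred from the call
# sites below (a sketch, not the actual module contents):
#
#   config.DATASET_CONFIG[domain] -> {"dataset": <indexable dataset>, plus the
#       column names "question_col"/"answer_col" (Math) or
#       "instruction_col"/"input_col"/"answer_col" (Bio)}
#   config.ALL_MODELS[domain]     -> {display name: model ID}
#   config.MY_AUTH_TOKEN          -> auth token, presumably read from the
#                                    ARENA_AUTH_TOKEN env var
#   backend.call_modal_api(model_id, prompt, lane) -> generator of text tokens
#   backend.router.get_routing_plan(id_1, id_2)    -> (lane_for_m1, lane_for_m2)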
# --- HELPER FUNCTIONS ---
def get_random_question(domain):
    """Pick a random (question, answer) pair from the dataset configured for `domain`."""
    data_conf = config.DATASET_CONFIG[domain]
    dataset = data_conf["dataset"]
    if not dataset:
        return "Failed to load dataset.", "N/A"
    random_index = random.randint(0, len(dataset) - 1)
    sample = dataset[random_index]
    if domain == "Math":
        question = sample[data_conf["question_col"]]
        answer = sample[data_conf["answer_col"]]
    elif domain == "Bio":
        instruction = sample[data_conf["instruction_col"]]
        bio_input = sample[data_conf["input_col"]]
        answer = sample[data_conf["answer_col"]]
        # Bio samples carry an optional input field; fold it into the prompt when present
        if bio_input and bio_input.strip():
            question = f"**Instruction:**\n{instruction}\n\n**Input:**\n{bio_input}"
        else:
            question = instruction
    return question, answer
def update_domain_settings(domain):
    """Refresh both model dropdowns and load a fresh question when the domain changes."""
    models = list(config.ALL_MODELS[domain].keys())
    def_base = next((m for m in models if "Base" in m), models[0])
    def_ft = next((m for m in models if "Finetuned" in m), models[0])
    q, a = get_random_question(domain)
    return [
        gr.Dropdown(choices=models, value=def_base),
        gr.Dropdown(choices=models, value=def_ft),
        gr.Textbox(value=q),
        a,
        gr.Markdown(visible=False),
    ]
def load_next_question(domain):
    q, a = get_random_question(domain)
    return [gr.Textbox(value=q), a, gr.Markdown(visible=False, value="")]

def reveal_answer(hidden_answer):
    return gr.Markdown(value=f"**Ground Truth Answer:**\n\n{hidden_answer}", visible=True)
# --- CORE LOGIC (PARALLEL STREAMING) ---
def stream_to_queue(model_id, prompt, lane, queue, key):
    """
    Worker that runs in its own thread: calls the streaming API and puts
    (key, token) pairs into the shared queue.
    """
    try:
        # backend.call_modal_api is a generator that yields tokens as they arrive
        for token in backend.call_modal_api(model_id, prompt, lane):
            queue.put((key, token))
    except Exception as e:
        queue.put((key, f"\n\nTHREAD ERROR: {e}"))
    finally:
        # Always signal end-of-stream with a `None` sentinel, even after an error
        queue.put((key, None))
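# A minimal sketch of how a single worker drains into a queue (illustrative
# only; the model ID, prompt, and lane below are placeholders, not real values):
#
#   q = Queue()
#   Thread(target=stream_to_queue,
#          args=("some-model-id", "2+2?", "lane-a", q, "m1")).start()
#   while True:
#       key, token = q.get()
#       if token is None:      # sentinel: the worker has finished
#           break
#       print(token, end="", flush=True)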
def run_comparison(domain, question, model_1_name, model_2_name):
    # 1. Resolve the display names to model IDs
    id_1 = config.ALL_MODELS[domain].get(model_1_name)
    id_2 = config.ALL_MODELS[domain].get(model_2_name)
    # 2. Ask the smart router which lane each model should run on
    lane_for_m1, lane_for_m2 = backend.router.get_routing_plan(id_1, id_2)
    # 3. Create the shared queue and one worker thread per model.
    #    Daemon threads so abandoned streams can't block interpreter shutdown.
    q = Queue()
    Thread(
        target=stream_to_queue,
        args=(id_1, question, lane_for_m1, q, 'm1'),
        daemon=True,
    ).start()
    Thread(
        target=stream_to_queue,
        args=(id_2, question, lane_for_m2, q, 'm2'),
        daemon=True,
    ).start()
    # 4. Drain the queue, interleaving tokens from both workers
    text1 = ""
    text2 = ""
    m1_done = False
    m2_done = False
    # Clear both output boxes and hide the answer before streaming starts
    yield "", "", gr.Markdown(visible=False)
    while not (m1_done and m2_done):
        # Block until the next token arrives from *either* worker
        key, token = q.get()
        if token is None:
            # `None` sentinel: that worker's stream is finished
            if key == 'm1':
                m1_done = True
            elif key == 'm2':
                m2_done = True
        else:
            # Append the new token to the matching transcript
            if key == 'm1':
                text1 += token
            elif key == 'm2':
                text2 += token
        # Yield the full accumulated text so Gradio re-renders both panes
        yield text1, text2, gr.Markdown(visible=False)
# --- UI BUILD ---
initial_question, initial_answer = get_random_question("Math")

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🔬 LLM Finetuning Arena
        ### Comparing Finetuned vs. Base Models on Specialized Tasks
        """
    )
    hidden_answer_state = gr.State(value=initial_answer)
    with gr.Row():
        domain_radio = gr.Radio(
            ["Math", "Bio"], label="1. Select Domain", value="Math"
        )
    with gr.Row():
        question_box = gr.Textbox(
            label="2. Question Prompt (Editable)",
            value=initial_question, lines=5, scale=4
        )
        next_btn = gr.Button("Load Random Question 🔄", scale=1, min_width=100)
    with gr.Row():
        model_1_dd = gr.Dropdown(
            label="3. Select Model 1 (Left)",
            choices=list(config.ALL_MODELS["Math"].keys()),
            value=next((m for m in config.ALL_MODELS["Math"] if "Base" in m),
                       list(config.ALL_MODELS["Math"])[0])
        )
        model_2_dd = gr.Dropdown(
            label="4. Select Model 2 (Right)",
            choices=list(config.ALL_MODELS["Math"].keys()),
            value=next((m for m in config.ALL_MODELS["Math"] if "Finetuned" in m),
                       list(config.ALL_MODELS["Math"])[0])
        )
    with gr.Row():
        run_btn = gr.Button("🚀 Run Comparison", variant="primary", scale=3)
        show_answer_btn = gr.Button("Show Ground Truth Answer", scale=1)
    answer_display_box = gr.Markdown(label="Ground Truth Answer", visible=False)
    gr.Markdown("---")
    with gr.Row():
        output_1_box = gr.Markdown(label="Output: Model 1")
        output_2_box = gr.Markdown(label="Output: Model 2")
    # --- EVENTS ---
    domain_radio.change(
        fn=update_domain_settings,
        inputs=[domain_radio],
        outputs=[model_1_dd, model_2_dd, question_box, hidden_answer_state, answer_display_box]
    )
    next_btn.click(
        fn=load_next_question,
        inputs=[domain_radio],
        outputs=[question_box, hidden_answer_state, answer_display_box]
    )
    show_answer_btn.click(
        fn=reveal_answer,
        inputs=[hidden_answer_state],
        outputs=[answer_display_box]
    )
    run_btn.click(
        fn=run_comparison,
        inputs=[domain_radio, question_box, model_1_dd, model_2_dd],
        outputs=[output_1_box, output_2_box, answer_display_box]
    )
if __name__ == "__main__":
    if not config.MY_AUTH_TOKEN:
        print("⚠️ WARNING: ARENA_AUTH_TOKEN is not set.")
    demo.launch()