import json
from pathlib import Path
import gradio as gr
from uuid import uuid4
from datasets import load_dataset
from collections import Counter
import numpy as np
from configs import configs
from clients import backend, logger
from backend.helpers import get_random_session_samples
# Test split of the stressEval dataset, loaded once at import time with the
# configured Hugging Face token. Rows are accessed by integer index below.
dataset = load_dataset("iyosha-huji/stressEval", token=configs.HF_API_TOKEN)["test"]

# Instruction text rendered at the top of the evaluation tab.
INSTRUCTIONS = """
You are given an audio sample and a question with 2 answer options.\n\nListen to the audio and select the correct answer from the options below.\n\nNote: The question is the same for all samples, but the audio and the corresponding answers change.
"""

# Mapping of stage name ("stage1", "stage2", "stage3") to the dataset indices
# belonging to that stage. JSON is UTF-8 by specification, so the encoding is
# stated explicitly instead of relying on the platform default.
with open(
    Path(__file__).parent / "data/stage_indices.json", encoding="utf-8"
) as f:
    STAGE_SPLITS = json.load(f)
def human_eval_tab():
    """Build the "Evaluation" tab: login form, question pages and final page.

    Must be called inside an active ``gr.Blocks`` context. The tab keeps all
    per-session data in ``gr.State`` components:

    * ``i`` - position inside the session's sample list, ``-1`` while no
      question is being shown (before "Start" and after the last answer).
    * ``selected_answer`` - the answer currently chosen in the radio group.
    * ``answers_dict`` - answers given so far, keyed by position ``i``.
    * ``session_sample_indices`` - dataset indices assigned to this session.

    Every submitted answer is written to ``backend`` (update if a row already
    exists for this sample/session, otherwise insert).
    """
    with gr.Tab(label="Evaluation"):
        # ==== State =====
        i = gr.State(-1)
        selected_answer = gr.State(None)
        answers_dict = gr.State({})
        logged_in = gr.State(False)
        session_id = gr.State(None)
        user_name = gr.State(None)
        session_sample_indices = gr.State([])

        # === Login UI ===
        with gr.Group(visible=True) as login_group:
            gr.Markdown("### 🔐 Login to Continue")
            with gr.Row():
                username = gr.Text(label="Username", placeholder="Enter username")
                password = gr.Text(
                    label="Password", type="password", placeholder="Enter password"
                )
            login_error = gr.Markdown(
                "\u274c Incorrect login, try again. Enter username and password.",
                visible=False,
            )
            login_btn = gr.Button("Login")

        def login(usr, p):
            """Validate credentials and, on success, open a new session.

            Returns values for: logged_in, login_group, login_error,
            session_id, session_sample_indices, user_name.
            """
            # Normalise once so the same value is validated, used to pick the
            # samples, and stored with the answers.
            clean_name = (usr or "").strip()
            if p != configs.USER_PASSWORD or not clean_name:
                return (
                    False,
                    gr.update(visible=True),
                    gr.update(visible=True),
                    None,
                    [],
                    None,
                )

            new_session_id = str(uuid4())
            sample_indices, stage = get_random_session_samples(
                backend, dataset, STAGE_SPLITS, clean_name, num_samples=15
            )
            logger.info(f"Session ID: {new_session_id}, Stage: {stage}")
            return (
                True,
                gr.update(visible=False),
                gr.update(visible=False),
                new_session_id,
                sample_indices,
                clean_name,
            )

        # === Login Button ===
        login_btn.click(
            fn=login,
            inputs=[username, password],
            outputs=[
                logged_in,
                login_group,
                login_error,
                session_id,
                session_sample_indices,
                user_name,
            ],
        )

        # === UI Elements ===
        next_btn = gr.Button("Start", visible=False)
        prev_btn = gr.Button("Previous Sample", visible=False)
        warning_msg = gr.Markdown(
            "\u26a0\ufe0f Please select an answer before continuing.",
            visible=False,
        )

        with gr.Group(visible=False) as app_group:
            with gr.Group():
                gr.Markdown("Instructions\n")
                gr.Markdown(INSTRUCTIONS)

            with gr.Group(visible=False) as question_group:
                with gr.Row(show_progress=True):
                    with gr.Column(variant="compact"):
                        sample_info = gr.Markdown()
                        gr.Markdown("**Question:**")
                        question_md = gr.Markdown()
                        radio = gr.Radio(label="Answer:", interactive=True)
                    with gr.Column(variant="compact"):
                        audio_output = gr.Audio(
                            interactive=False, type="numpy", label="Audio:"
                        )

            # Final page, not visible until the end
            with gr.Group(visible=False, elem_id="final_page") as final_group:
                gr.Markdown(
                    "\n"
                    "# 🎉 Thanks for your help!\n"
                    "You helped moving science forward 🤓\n"
                    "Your responses have been recorded.\n"
                    "You may now close this tab.\n"
                )

        # === Logic ===
        def update_ui(i, answers, session_sample_indices):
            """Render the question at position ``i`` (or hide it for ``-1``).

            Returns values for: question_group, sample_info, question_md,
            audio_output, radio, selected_answer.
            """
            if i == -1:  # We haven't started yet (or we are already done)
                return (
                    gr.update(visible=False),
                    "",
                    "",
                    gr.update(visible=False),
                    gr.update(visible=False),
                    None,
                )

            # show the question
            true_index = session_sample_indices[i]
            sample = dataset[true_index]
            audio_data = (sample["audio"]["sampling_rate"], sample["audio"]["array"])
            # Restore the earlier choice when the user navigates back.
            previous_answer = answers.get(i, None)
            return (
                gr.update(visible=True),
                f"Sample {i+1} out of {len(session_sample_indices)}\n",
                "Out of the following answers, according to the speaker's stressed words, what is most likely the underlying intention of the speaker?",
                gr.update(value=audio_data),
                gr.update(
                    choices=sample["possible_answers"],
                    value=previous_answer,
                ),
                previous_answer,
            )

        def update_next_index(
            i, answer, answers, session_id, session_sample_indices, user_name
        ):
            """Persist the current answer and advance to the next page.

            Returns values for: i, warning_msg, next_btn, answers_dict,
            final_group, prev_btn.
            """
            on_question = i != -1

            if on_question and answer is None:  # if no answer is selected
                # show warning message and stay on the same question
                return (
                    gr.update(),
                    gr.update(visible=True),
                    gr.update(),
                    answers,
                    gr.update(visible=False),
                    gr.update(visible=True),
                )

            # Only save while a question is displayed: with i == -1 the index
            # lookup below would silently address the last sample.
            if on_question and answer:
                # save the answer to the backend
                answers[i] = answer
                true_index = session_sample_indices[i]
                sample = dataset[true_index]
                interp_id = sample["interpretation_id"]
                trans_id = sample["transcription_id"]
                user_id = session_id
                user_name_str = user_name or "anonymous"
                logger.info(
                    "saving answer to backend",
                    context={
                        "i": true_index,
                        "interp_id": interp_id,
                        "answer": answer,
                        "user_id": user_id,
                    },
                )
                # Insert a new row only when there was nothing to update.
                if not backend.update_row(true_index, interp_id, user_id, answer):
                    backend.add_row(
                        true_index, interp_id, trans_id, user_id, answer, user_name_str
                    )

            if i + 1 == len(session_sample_indices):  # Last question just answered
                return (
                    -1,  # reset i to stop showing question
                    gr.update(visible=False),
                    gr.update(visible=False),
                    answers,
                    gr.update(visible=True),  # show final page
                    gr.update(visible=False),  # hide previous button
                )

            # go to the next question
            return (
                i + 1,
                gr.update(visible=False),
                gr.update(value="Submit answer and go to Next"),
                answers,
                gr.update(visible=False),
                gr.update(visible=True),
            )

        def update_prev_index(i):
            """Step back one question; returns values for i and warning_msg."""
            # prevent going back from the first question or the start page
            new_i = i - 1 if i > 0 else i
            return new_i, gr.update(visible=False)

        def answer_change_callback(answer, i, answers):
            """Record the radio selection for position ``i``."""
            answers[i] = answer
            return answer, answers

        def login_callback(logged_in):
            """Show the app after login.

            Returns values for: app_group, next_btn, prev_btn, warning_msg.
            The last two are always hidden at this point.
            """
            shown = bool(logged_in)
            return (
                gr.update(visible=shown),
                gr.update(visible=shown),
                gr.update(visible=False),
                gr.update(visible=False),
            )

        # === Events ===
        next_btn.click(
            update_next_index,
            [
                i,
                selected_answer,
                answers_dict,
                session_id,
                session_sample_indices,
                user_name,
            ],
            [i, warning_msg, next_btn, answers_dict, final_group, prev_btn],
        )
        prev_btn.click(update_prev_index, i, [i, warning_msg])
        i.change(
            update_ui,
            [i, answers_dict, session_sample_indices],
            [
                question_group,
                sample_info,
                question_md,
                audio_output,
                radio,
                selected_answer,
            ],
        )
        radio.change(
            answer_change_callback,
            [radio, i, answers_dict],
            [selected_answer, answers_dict],
        )
        logged_in.change(
            login_callback, logged_in, [app_group, next_btn, prev_btn, warning_msg]
        )
def compute_random_sampled_accuracy(df, dataset, n_rounds=100, seed=42):
    """Estimate accuracy by repeatedly sampling one answer per interpretation.

    Only interpretations answered by at least 3 distinct users are used. In
    every round one answer row is drawn at random for each such
    interpretation and compared with the ground-truth answer of the dataset
    row it points to; the per-round accuracies are then aggregated.

    Args:
        df: DataFrame with the columns ``interpretation_id``, ``user_id``,
            ``answer`` and ``index_in_dataset``.
        dataset: Indexable collection; ``dataset[idx]`` must provide
            ``"possible_answers"`` and ``"label"``.
        n_rounds: Number of sampling rounds.
        seed: Seed for the random generator, making the result reproducible.

    Returns:
        ``(mean_accuracy, std_accuracy, mean_samples_per_round, ci_95)``, or
        ``(None, None, 0, None)`` when no round produced any sample.
    """
    rng = np.random.default_rng(seed)

    # Filter to interpretation_ids with at least 3 user answers
    counts = df.groupby("interpretation_id")["user_id"].nunique()
    eligible_ids = set(counts[counts >= 3].index)

    # Group answers by interpretation_id
    grouped = df[df["interpretation_id"].isin(eligible_ids)].groupby(
        "interpretation_id"
    )

    # The ground truth of a dataset row does not change between rounds, so
    # each row is read from the dataset at most once.
    gt_by_index = {}

    def ground_truth(idx):
        if idx not in gt_by_index:
            sample = dataset[idx]
            gt_by_index[idx] = sample["possible_answers"][sample["label"]]
        return gt_by_index[idx]

    all_scores = []
    total_answered_per_round = []
    for _ in range(n_rounds):
        correct = 0
        total = 0
        for _interp_id, group in grouped:
            if group.empty:
                continue
            # Randomly pick one row
            row = group.sample(1, random_state=rng.integers(1_000_000)).iloc[0]
            total += 1
            if row["answer"] == ground_truth(int(row["index_in_dataset"])):
                correct += 1
        if total > 0:
            all_scores.append(correct / total)
            total_answered_per_round.append(total)

    if not all_scores:
        return None, None, 0, None

    mean_acc = np.mean(all_scores)
    mean_total = int(np.mean(total_answered_per_round))
    std_acc = np.std(all_scores, ddof=1)  # sample std
    ci_95 = 1.96 * std_acc / np.sqrt(n_rounds)
    return mean_acc, std_acc, mean_total, ci_95
def _cumulative_majority_lines(stage_acc):
    """Format cumulative majority-vote accuracy lines for completed stages.

    ``stage_acc`` maps "stage1"/"stage2"/"stage3" to ``(acc, correct, total)``
    or ``None`` when the stage is not complete. Stages are accumulated in
    order and the listing stops at the first incomplete stage, because a
    cumulative figure cannot be computed past a gap.
    """
    labels = (
        ("stage1", "Stage 1"),
        ("stage2", "Stage 1+2"),
        ("stage3", "All Stages"),
    )
    lines = []
    correct_sum = 0
    total_sum = 0
    for stage, label in labels:
        result = stage_acc.get(stage)
        if not result:
            break
        _, correct, total = result
        correct_sum += correct
        total_sum += total
        # Guard against stages that contain no samples at all.
        ratio = correct_sum / total_sum if total_sum else 0
        lines.append(f"- **{label}:** {ratio:.2%} ({correct_sum}/{total_sum})")
    return lines


def get_admin_tab():
    """Build the password-protected "Admin Console" tab.

    Must be called inside an active ``gr.Blocks`` context. After the admin
    password is entered, all rows are read from ``backend`` and a Markdown
    report is shown with overall accuracy, majority-vote accuracy per
    completed stage, random-sampled accuracy and per-stage progress.
    """
    # Number of distinct users that must answer a sample before it counts as
    # complete.
    required_answers = 3
    stage_names = ("stage1", "stage2", "stage3")

    with gr.Tab("Admin Console"):
        admin_password = gr.Text(label="Enter Admin Password", type="password")
        check_btn = gr.Button("Enter")
        error_box = gr.Markdown("", visible=False)
        output_box = gr.Markdown("", visible=False)

        def calculate_majority_vote_accuracy(pw):
            """Return updates for (error_box, output_box) for password ``pw``."""
            if pw != configs.ADMIN_PASSWORD:
                return (
                    gr.update(visible=True, value="❌ Incorrect password."),
                    gr.update(visible=False),
                )

            df = backend.get_all_rows()
            if df.empty:
                return (
                    gr.update(visible=True, value="No data available."),
                    gr.update(visible=False),
                )

            # Each dataset row is loaded at most once per report; the same
            # rows are needed for the stage statistics and overall accuracy.
            sample_cache = {}

            def sample_facts(idx):
                """Return (interpretation_id, ground-truth answer) for ``idx``."""
                if idx not in sample_cache:
                    sample = dataset[idx]
                    sample_cache[idx] = (
                        sample["interpretation_id"],
                        sample["possible_answers"][sample["label"]],
                    )
                return sample_cache[idx]

            # Majority vote per interpretation_id
            majority_answers = {}
            for interp_id, group in df.groupby("interpretation_id"):
                answer_counts = Counter(group["answer"])
                if answer_counts:
                    majority_answers[interp_id] = answer_counts.most_common(1)[0][0]

            # Distinct users per interpretation_id
            counts = df.groupby("interpretation_id")["user_id"].nunique().to_dict()
            total_answers = len(df)
            users_count = df["user_id"].nunique()

            stage_acc = {}
            stage_completes = {}
            stage_counts = {}
            stage_remaining = {}
            for stage in stage_names:
                indices = STAGE_SPLITS[stage]
                correct = total = complete = remaining = 0
                for idx in indices:
                    interp_id, gt = sample_facts(idx)
                    n = counts.get(interp_id, 0)
                    if n >= required_answers:
                        complete += 1
                    # Clamp per sample so extra answers on one sample do not
                    # hide answers still missing on another.
                    remaining += max(0, required_answers - n)
                    if interp_id in majority_answers:
                        total += 1
                        if majority_answers[interp_id] == gt:
                            correct += 1

                stage_counts[stage] = len(indices)
                stage_completes[stage] = complete
                stage_remaining[stage] = remaining
                if complete == len(indices):
                    acc = correct / total if total > 0 else 0
                    stage_acc[stage] = (acc, correct, total)
                else:
                    stage_acc[stage] = None  # not shown yet

            # Determine active stage
            if stage_completes["stage1"] < stage_counts["stage1"]:
                current_stage = "Stage 1"
            elif stage_completes["stage2"] < stage_counts["stage2"]:
                current_stage = "Stage 2"
            else:
                current_stage = "Stage 3"

            # Majority Vote Accuracy Section
            agg_lines = _cumulative_majority_lines(stage_acc)
            agg_msg = "\n".join(agg_lines) if agg_lines else "No completed stages yet."

            # Compute random-sampled accuracy
            n_rounds = 100
            rand_acc, rand_std, rand_total, rand_ci = compute_random_sampled_accuracy(
                df, dataset, n_rounds=n_rounds
            )
            if rand_acc is not None:
                rand_acc_msg = (
                    f"**Accuracy:** {rand_acc:.2%} ± {rand_ci:.2%} (95% CI)\n\n"
                    f"Standard deviation: {rand_std:.2%}\n\n"
                    f"Samples used: {rand_total} × {n_rounds} rounds"
                )
            else:
                rand_acc_msg = "Random sampling failed (no data)."

            # Overall accuracy over every stored answer
            correct = 0
            total = 0
            for raw_idx, answer in zip(df["index_in_dataset"], df["answer"]):
                idx = int(raw_idx)
                if idx >= len(dataset):
                    continue  # skip out-of-range
                _, gt_answer = sample_facts(idx)
                if answer == gt_answer:
                    correct += 1
                total += 1
            if total > 0:
                overall_acc_msg = (
                    f"Overall Accuracy: {correct / total:.2%} ({correct}/{total})"
                )
            else:
                overall_acc_msg = "No data available."

            progress_block = "\n".join(
                [
                    f"- **Total answers submitted:** {total_answers}",
                    f"- **Answers to go (global):** {max(0, required_answers * len(dataset) - total_answers)}",
                    f"- **Unique users:** {users_count}",
                ]
            )
            table_lines = [
                "| Stage | Completed | Total | Remaining Answers |",
                "|-------|-----------|--------|-------------------|",
            ]
            for number, stage in enumerate(stage_names, start=1):
                table_lines.append(
                    f"| {number} | {stage_completes[stage]} / {stage_counts[stage]} "
                    f"| {stage_counts[stage]} | {stage_remaining[stage]} |"
                )
            table_block = "\n".join(table_lines)

            # Final message. Lines carry no indentation, and blocks are
            # separated by blank lines: a "---" directly under a text line is
            # a setext heading underline rather than a rule, and a line
            # directly under a table is absorbed into the table.
            msg = "\n\n".join(
                [
                    "## ✅ Accuracy Summary",
                    "### Overall Accuracy",
                    overall_acc_msg,
                    "---",
                    "### Majority Vote",
                    agg_msg,
                    "---",
                    "### Random-Sampled Accuracy",
                    rand_acc_msg,
                    "---",
                    "## 📊 Answer Progress",
                    progress_block,
                    "---",
                    "## 🧱 Stage Breakdown",
                    table_block,
                    f"**➡️ Current Active Stage:** {current_stage}",
                ]
            )
            return gr.update(visible=False), gr.update(visible=True, value=msg)

        check_btn.click(
            fn=calculate_majority_vote_accuracy,
            inputs=admin_password,
            outputs=[error_box, output_box],
        )
# Assemble the application: one Blocks container holding both tabs.
demo = gr.Blocks()
with demo:
    human_eval_tab()
    get_admin_tab()

# Start serving the app.
demo.launch()