Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import json | |
from pathlib import Path | |
import gradio as gr | |
from uuid import uuid4 | |
from datasets import load_dataset | |
from collections import Counter | |
import numpy as np | |
from configs import configs | |
from clients import backend, logger | |
from backend.helpers import get_random_session_samples | |
# Test split of the stress-evaluation dataset, loaded once at import time and
# shared by both tabs.
dataset = load_dataset("iyosha-huji/stressEval", token=configs.HF_API_TOKEN)["test"]

# Instruction banner rendered above every question.
INSTRUCTIONS = """<div align='center'>You are given an audio sample and a question with 2 answer options.\n\nListen to the audio and select the correct answer from the options below.\n\n<b>Note:</b> The question is the same for all samples, but the audio and the corresponding answers change.</div>"""

# Mapping of stage name ("stage1", "stage2", "stage3") to the dataset indices
# that belong to that stage, stored next to this file.
_stage_file = Path(__file__).parent / "data/stage_indices.json"
STAGE_SPLITS = json.loads(_stage_file.read_text())
def human_eval_tab():
    """Build the "Evaluation" tab: a login form followed by a paged questionnaire.

    Must be called inside an active ``gr.Blocks`` context. All per-user data
    (current position, answers, session id, sample list) is held in
    ``gr.State`` components created here.

    Flow: after a successful login the "Start" button appears; each click on
    it stores the selected answer through ``backend`` and advances to the next
    sample; after the last sample the final "thanks" page is shown.
    """
    with gr.Tab(label="Evaluation"):
        # ==== State =====
        # Position inside session_sample_indices; -1 means "no question shown"
        # (before the first click on Start and again after the last answer).
        i = gr.State(-1)
        selected_answer = gr.State(None)
        answers_dict = gr.State({})
        logged_in = gr.State(False)
        session_id = gr.State(None)
        user_name = gr.State(None)
        session_sample_indices = gr.State([])

        # === Login UI ===
        with gr.Group(visible=True) as login_group:
            gr.Markdown("### 🔐 Login to Continue")
            with gr.Row():
                username = gr.Text(label="Username", placeholder="Enter username")
                password = gr.Text(
                    label="Password", type="password", placeholder="Enter password"
                )
            login_error = gr.Markdown(
                "\u274c Incorrect login, try again. Enter username and password.",
                visible=False,
            )
            login_btn = gr.Button("Login")

        def login(usr, p):
            """Check the credentials and, on success, pick this session's samples.

            Returns values for (logged_in, login_group, login_error,
            session_id, session_sample_indices, user_name).
            """
            # Strip once so the emptiness check, the sample selection and the
            # stored user name all see the same value ("bob" == "bob ").
            name = (usr or "").strip()
            if not name or p != configs.USER_PASSWORD:
                return (
                    False,
                    gr.update(visible=True),
                    gr.update(visible=True),
                    None,
                    [],
                    None,
                )
            new_session_id = str(uuid4())
            sample_indices, stage = get_random_session_samples(
                backend, dataset, STAGE_SPLITS, name, num_samples=15
            )
            logger.info(f"Session ID: {new_session_id}, Stage: {stage}")
            return (
                True,
                gr.update(visible=False),
                gr.update(visible=False),
                new_session_id,
                sample_indices,
                name,
            )

        # === Login Button ===
        login_btn.click(
            fn=login,
            inputs=[username, password],
            outputs=[
                logged_in,
                login_group,
                login_error,
                session_id,
                session_sample_indices,
                user_name,
            ],
        )

        # === UI Elements ===
        next_btn = gr.Button("Start", visible=False)
        prev_btn = gr.Button("Previous Sample", visible=False)
        warning_msg = gr.Markdown(
            "<span style='color:red;'>\u26a0\ufe0f Please select an answer before continuing.</span>",
            visible=False,
        )
        with gr.Group(visible=False) as app_group:
            with gr.Group():
                gr.Markdown("<div align='center'><big><b>Instructions</b></big></div>")
                gr.Markdown(INSTRUCTIONS)
            with gr.Group(visible=False) as question_group:
                with gr.Row(show_progress=True):
                    with gr.Column(variant="compact"):
                        sample_info = gr.Markdown()
                        gr.Markdown("**Question:**")
                        question_md = gr.Markdown()
                        radio = gr.Radio(label="Answer:", interactive=True)
                    with gr.Column(variant="compact"):
                        audio_output = gr.Audio(
                            interactive=False, type="numpy", label="Audio:"
                        )
            with gr.Group(
                visible=False, elem_id="final_page"
            ) as final_group:  # Final page, not visible until the end
                gr.Markdown(
                    "\n"
                    "# 🎉 Thanks for your help!\n"
                    "You helped moving science forward 🤓\n"
                    "Your responses have been recorded.\n"
                    "You may now close this tab.\n"
                )

        # === Logic ===
        def update_ui(i, answers, session_sample_indices):
            """Render question ``i``, or hide the question area when ``i == -1``.

            Returns values for (question_group, sample_info, question_md,
            audio_output, radio, selected_answer).
            """
            if i == -1:  # We haven't started yet (or we are already done)
                return (
                    gr.update(visible=False),
                    "",
                    "",
                    gr.update(visible=False),
                    gr.update(visible=False),
                    None,
                )
            # show the question
            true_index = session_sample_indices[i]
            sample = dataset[true_index]
            audio_data = (sample["audio"]["sampling_rate"], sample["audio"]["array"])
            # Restore an earlier choice when the user navigates back.
            previous_answer = answers.get(i, None)
            return (
                gr.update(visible=True),
                f"<div align='center'>Sample <b>{i+1}</b> out of <b>{len(session_sample_indices)}</b></div>",
                "Out of the following answers, according to the speaker's stressed words, what is most likely the underlying intention of the speaker?",
                gr.update(value=audio_data),
                gr.update(
                    choices=sample["possible_answers"],
                    value=previous_answer,
                ),
                previous_answer,
            )

        def update_next_index(
            i, answer, answers, session_id, session_sample_indices, user_name
        ):
            """Store the current answer and move on to the next sample.

            Returns values for (i, warning_msg, next_btn, answers_dict,
            final_group, prev_btn).
            """
            on_question = i >= 0
            if on_question and answer is None:  # if no answer is selected
                # show warning message and stay on the same question
                return (
                    gr.update(),
                    gr.update(visible=True),
                    gr.update(),
                    answers,
                    gr.update(visible=False),
                    gr.update(visible=True),
                )
            # Only persist while a real question is on screen: at i == -1 a
            # stale answer would otherwise be written to the *last* sample
            # through negative indexing.
            if on_question and answer:
                # save the answer to the backend
                answers[i] = answer
                true_index = session_sample_indices[i]
                sample = dataset[true_index]
                interp_id = sample["interpretation_id"]
                trans_id = sample["transcription_id"]
                user_id = session_id
                user_name_str = user_name or "anonymous"
                logger.info(
                    "saving answer to backend",
                    context={
                        "i": true_index,
                        "interp_id": interp_id,
                        "answer": answer,
                        "user_id": user_id,
                    },
                )
                # Update the existing row if there is one, otherwise insert.
                if not backend.update_row(true_index, interp_id, user_id, answer):
                    backend.add_row(
                        true_index, interp_id, trans_id, user_id, answer, user_name_str
                    )
            next_i = i + 1
            if next_i >= len(session_sample_indices):  # Last question just answered
                return (
                    -1,  # reset i to stop showing question
                    gr.update(visible=False),
                    gr.update(visible=False),
                    answers,
                    gr.update(visible=True),  # show final page
                    gr.update(visible=False),  # hide previous button
                )
            # go to the next question
            return (
                next_i,
                gr.update(visible=False),
                gr.update(value="Submit answer and go to Next"),
                answers,
                gr.update(visible=False),
                gr.update(visible=True),
            )

        def update_prev_index(i):
            """Step back one question; a no-op on the first question or start page."""
            new_i = i - 1 if i > 0 else i
            return new_i, gr.update(visible=False)

        def answer_change_callback(answer, i, answers):
            """Mirror the radio selection into the per-session answer state."""
            answers[i] = answer
            return answer, answers

        def login_callback(logged_in):
            """Reveal the app area and Start button once the user is logged in.

            Returns values for (app_group, next_btn, prev_btn, warning_msg);
            the last two stay hidden in both cases.
            """
            shown = bool(logged_in)
            return (
                gr.update(visible=shown),
                gr.update(visible=shown),
                gr.update(visible=False),
                gr.update(visible=False),
            )

        # === Events ===
        next_btn.click(
            update_next_index,
            [
                i,
                selected_answer,
                answers_dict,
                session_id,
                session_sample_indices,
                user_name,
            ],
            [i, warning_msg, next_btn, answers_dict, final_group, prev_btn],
        )
        prev_btn.click(update_prev_index, i, [i, warning_msg])
        i.change(
            update_ui,
            [i, answers_dict, session_sample_indices],
            [
                question_group,
                sample_info,
                question_md,
                audio_output,
                radio,
                selected_answer,
            ],
        )
        radio.change(
            answer_change_callback,
            [radio, i, answers_dict],
            [selected_answer, answers_dict],
        )
        logged_in.change(
            login_callback, logged_in, [app_group, next_btn, prev_btn, warning_msg]
        )
def compute_random_sampled_accuracy(df, dataset, n_rounds=100, seed=42):
    """Estimate accuracy by repeatedly sampling one answer per item.

    Only items (``interpretation_id``) answered by at least 3 distinct users
    take part. In each round one answer row is drawn at random for every such
    item and compared with the ground truth
    ``dataset[idx]["possible_answers"][dataset[idx]["label"]]``.

    Args:
        df: DataFrame with columns ``interpretation_id``, ``user_id``,
            ``answer`` and ``index_in_dataset``.
        dataset: Indexable collection supporting ``len()`` whose items expose
            ``possible_answers`` and ``label``.
        n_rounds: Number of sampling rounds.
        seed: Seed for the random generator, making the result reproducible.

    Returns:
        ``(mean_accuracy, std_accuracy, mean_items_per_round, ci_95)``, or
        ``(None, None, 0, None)`` when no round could score any item.
    """
    rng = np.random.default_rng(seed)

    # Filter to interpretation_ids with at least 3 user answers
    users_per_item = df.groupby("interpretation_id")["user_id"].nunique()
    eligible_ids = set(users_per_item[users_per_item >= 3].index)
    eligible_rows = df[df["interpretation_id"].isin(eligible_ids)]
    # The groups are identical in every round, so split the frame only once.
    groups = [group for _, group in eligible_rows.groupby("interpretation_id")]

    n_samples = len(dataset)
    gt_cache = {}

    def ground_truth(idx):
        # Read each dataset record at most once instead of once per round.
        if idx not in gt_cache:
            sample = dataset[idx]
            gt_cache[idx] = sample["possible_answers"][sample["label"]]
        return gt_cache[idx]

    all_scores = []
    total_answered_per_round = []
    for _ in range(n_rounds):
        correct = 0
        total = 0
        for group in groups:
            # Randomly pick one row
            row = group.sample(1, random_state=rng.integers(1_000_000)).iloc[0]
            idx = int(row["index_in_dataset"])
            if not 0 <= idx < n_samples:
                continue  # skip rows that point outside the dataset
            total += 1
            if row["answer"] == ground_truth(idx):
                correct += 1
        if total > 0:
            all_scores.append(correct / total)
            total_answered_per_round.append(total)

    if not all_scores:
        return None, None, 0, None

    n_scored = len(all_scores)
    mean_acc = np.mean(all_scores)
    mean_total = int(np.mean(total_answered_per_round))
    # The sample std (ddof=1) is undefined for a single round; report 0
    # rather than NaN.
    std_acc = np.std(all_scores, ddof=1) if n_scored > 1 else 0.0
    ci_95 = 1.96 * std_acc / np.sqrt(n_scored)
    return mean_acc, std_acc, mean_total, ci_95
def get_admin_tab():
    """Build the password-protected "Admin Console" tab.

    Must be called inside an active ``gr.Blocks`` context. After the admin
    password is entered the tab renders a markdown report with overall
    accuracy, majority-vote accuracy per completed stage, random-sampled
    accuracy, and answer-collection progress per stage.
    """
    with gr.Tab("Admin Console"):
        admin_password = gr.Text(label="Enter Admin Password", type="password")
        check_btn = gr.Button("Enter")
        error_box = gr.Markdown("", visible=False)
        output_box = gr.Markdown("", visible=False)

        def calculate_majority_vote_accuracy(pw):
            """Return updates for (error_box, output_box) for the given password."""
            if pw != configs.ADMIN_PASSWORD:
                return gr.update(
                    visible=True, value="❌ Incorrect password."
                ), gr.update(visible=False)
            df = backend.get_all_rows()
            if df.empty:
                return gr.update(visible=True, value="No data available."), gr.update(
                    visible=False
                )

            # Records also hold the audio (see the evaluation tab), so reads
            # are presumably not cheap; fetch each index at most once.
            sample_cache = {}

            def sample_facts(idx):
                """Return (interpretation_id, ground-truth answer) for ``idx``."""
                if idx not in sample_cache:
                    sample = dataset[idx]
                    sample_cache[idx] = (
                        sample["interpretation_id"],
                        sample["possible_answers"][sample["label"]],
                    )
                return sample_cache[idx]

            # Majority vote per interpretation_id
            majority_answers = {
                interp_id: Counter(group["answer"]).most_common(1)[0][0]
                for interp_id, group in df.groupby("interpretation_id")
            }
            # Number of distinct users that answered each item.
            counts = df.groupby("interpretation_id")["user_id"].nunique().to_dict()
            total_answers = len(df)
            users_count = df["user_id"].nunique()

            stage_names = ("stage1", "stage2", "stage3")
            stage_acc = {}
            stage_completes = {}
            stage_counts = {}
            stage_remaining = {}
            for stage in stage_names:
                indices = STAGE_SPLITS[stage]
                correct = total = complete = answered = 0
                for idx in indices:
                    interp_id, gt = sample_facts(idx)
                    n = counts.get(interp_id, 0)
                    answered += n
                    if n >= 3:  # an item is complete once 3 users answered it
                        complete += 1
                    if interp_id in majority_answers:
                        total += 1
                        if majority_answers[interp_id] == gt:
                            correct += 1
                stage_counts[stage] = len(indices)
                stage_completes[stage] = complete
                stage_remaining[stage] = 3 * len(indices) - answered
                if complete == len(indices):
                    acc = correct / total if total > 0 else 0
                    stage_acc[stage] = (acc, correct, total)
                else:
                    stage_acc[stage] = None  # not shown yet

            # Determine active stage
            if stage_completes["stage1"] < stage_counts["stage1"]:
                current_stage = "Stage 1"
            elif stage_completes["stage2"] < stage_counts["stage2"]:
                current_stage = "Stage 2"
            else:
                current_stage = "Stage 3"

            # Majority Vote Accuracy Section.
            # The lines are cumulative ("Stage 1", "Stage 1+2", "All Stages"),
            # so stop at the first incomplete stage: a later line is only
            # meaningful when every earlier stage contributed its counts.
            cumulative_labels = {
                "stage1": "Stage 1",
                "stage2": "Stage 1+2",
                "stage3": "All Stages",
            }
            agg_lines = []
            cum_correct = 0
            cum_total = 0
            for stage in stage_names:
                result = stage_acc[stage]
                if result is None:
                    break
                _, stage_correct, stage_total = result
                cum_correct += stage_correct
                cum_total += stage_total
                ratio = cum_correct / cum_total if cum_total > 0 else 0
                agg_lines.append(
                    f"- **{cumulative_labels[stage]}:** {ratio:.2%} ({cum_correct}/{cum_total})"
                )
            agg_msg = "\n".join(agg_lines) if agg_lines else "No completed stages yet."

            # Compute random-sampled accuracy
            n_rounds = 100
            rand_acc, rand_std, rand_total, rand_ci = compute_random_sampled_accuracy(
                df, dataset, n_rounds=n_rounds
            )
            # Random-sampled Accuracy
            if rand_acc is not None:
                rand_acc_msg = (
                    f"**Accuracy:** {rand_acc:.2%} ± {rand_ci:.2%} (95% CI)\n\n"
                    f"Standard deviation: {rand_std:.2%}\n\n"
                    f"Samples used: {rand_total} × {n_rounds} rounds"
                )
            else:
                rand_acc_msg = "Random sampling failed (no data)."

            # Overall accuracy over every submitted answer.
            correct = 0
            total = 0
            for raw_idx, given in zip(df["index_in_dataset"], df["answer"]):
                idx = int(raw_idx)
                if idx >= len(dataset):
                    continue  # skip out-of-range
                _, gt_answer = sample_facts(idx)
                if given == gt_answer:
                    correct += 1
                total += 1
            overall_acc = correct / total if total > 0 else None
            if overall_acc is not None:
                overall_acc_msg = (
                    f"Overall Accuracy: {overall_acc:.2%} ({correct}/{total})"
                )
            else:
                overall_acc_msg = "No data available."

            # Final message (no indentation!). Blank lines separate the
            # blocks: without them a paragraph directly above '---' is parsed
            # as a setext heading and the line after the table joins the table.
            msg = f"""
## ✅ Accuracy Summary

### Overall Accuracy

{overall_acc_msg}

---

### Majority Vote

{agg_msg}

---

### Random-Sampled Accuracy

{rand_acc_msg}

---

## 📊 Answer Progress

- **Total answers submitted:** {total_answers}
- **Answers to go (global):** {3 * len(dataset) - total_answers}
- **Unique users:** {users_count}

---

## 🧱 Stage Breakdown

| Stage | Completed | Total | Remaining Answers |
|-------|-----------|--------|-------------------|
| 1 | {stage_completes['stage1']} / {stage_counts['stage1']} | {stage_counts['stage1']} | {stage_remaining['stage1']} |
| 2 | {stage_completes['stage2']} / {stage_counts['stage2']} | {stage_counts['stage2']} | {stage_remaining['stage2']} |
| 3 | {stage_completes['stage3']} / {stage_counts['stage3']} | {stage_counts['stage3']} | {stage_remaining['stage3']} |

**➡️ Current Active Stage:** {current_stage}
"""
            return gr.update(visible=False), gr.update(visible=True, value=msg)

        check_btn.click(
            fn=calculate_majority_vote_accuracy,
            inputs=admin_password,
            outputs=[error_box, output_box],
        )
# App UI: a single Blocks container holding the evaluation tab and the
# admin tab, in that order.
demo = gr.Blocks()
with demo:
    human_eval_tab()
    get_admin_tab()

# Launch app
demo.launch()