""" | |
Gradio demo – visualise benchmark accuracy curves. | |
Required CSV files (place in the *same* folder as app.py): | |
├── aggregated_accuracy.csv | |
├── qa_accuracy.csv | |
├── ocr_accuracy.csv | |
└── temporal_accuracy.csv | |
Each file has the columns | |
Model,<context‑length‑1>,<context‑length‑2>,… | |
where the context‑length headers are strings such as `30min`, `60min`, `120min`, … | |
No further cleaning / renaming is done apart from two cosmetic replacements | |
(“gpt4.1” → “ChatGPT 4.1”, “gemini2.5pro” → “Gemini 2.5 Pro”). | |
""" | |
import math
from pathlib import Path

import gradio as gr
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
# --------------------------------------------------------------------- #
#  Config                                                                #
# --------------------------------------------------------------------- #
# Every benchmark's CSV follows the "<benchmark>_accuracy.csv" pattern.
FILES = {
    bench: f"{bench}_accuracy.csv"
    for bench in ("aggregated", "qa", "ocr", "temporal")
}
# Mapping of internal benchmark keys to nicely formatted display labels
DISPLAY_LABELS = dict(
    aggregated="Aggregated",
    qa="QA",
    ocr="OCR",
    temporal="Temporal",
)
# Optional: choose which models are selected by default for each benchmark.
# Use the *display names* exactly as they appear in the Models list.
# If a benchmark is missing, it falls back to the first six models.
DEFAULT_MODELS: dict[str, list[str]] = {
    "aggregated": [
        "Gemini 2.5 Pro",
        "ChatGPT 4.1",
        "Qwen2.5-VL-7B",
        "InternVL2.5-8B",
        "LLaMA-3.2-11B-Vision",
    ],
}
# Regex pattern -> display name, applied to the "Model" column.
# NOTE: the module docstring promises "gemini2.5pro" -> "Gemini 2.5 Pro",
# but the old pattern `Gemini\s2\.5\spro` required literal spaces and an
# exact letter case, so that input could never match.  The pattern below is
# case-insensitive and tolerates zero-or-more whitespace, matching both
# "gemini2.5pro" and "Gemini 2.5 pro".
RENAME = {
    r"gpt4\.1": "ChatGPT 4.1",
    r"(?i)gemini\s*2\.5\s*pro": "Gemini 2.5 Pro",
    r"LLaMA-3\.2B-11B": "LLaMA-3.2-11B-Vision",
}
# --------------------------------------------------------------------- #
#  Data loading                                                          #
# --------------------------------------------------------------------- #
def _read_csv(path: str | Path) -> pd.DataFrame:
    """Load one benchmark CSV and apply the cosmetic RENAME substitutions."""
    frame = pd.read_csv(path)
    frame["Model"] = frame["Model"].replace(RENAME, regex=True).astype(str)
    return frame


# One DataFrame per benchmark, keyed by the benchmark name in FILES.
dfs: dict[str, pd.DataFrame] = {
    bench: _read_csv(csv_name) for bench, csv_name in FILES.items()
}
# --------------------------------------------------------------------- #
#  Colour palette and model metadata                                     #
# --------------------------------------------------------------------- #
# Colour-blind-safe qualitative palette (10 colours).  `plotly.express`
# is imported at the top of the file (PEP 8) rather than mid-module.
SAFE_PALETTE = px.colors.qualitative.Safe
# Deterministic list of all unique model names to ensure consistent colour
# mapping: the same model keeps the same colour/marker across benchmarks.
ALL_MODELS: list[str] = sorted({m for df in dfs.values() for m in df["Model"].unique()})
# Marker shapes cycled per model, paired with the palette above.
MARKER_SYMBOLS = [
    "circle",
    "square",
    "triangle-up",
    "diamond",
    "cross",
    "triangle-down",
    "x",
    "triangle-right",
    "triangle-left",
    "pentagon",
]
# Every column other than "Model" is a context-length header such as "30min".
TIME_COLS = [c for c in dfs["aggregated"].columns if c.lower() != "model"]
def _pretty_time(label: str) -> str: | |
"""‘30min’ → ‘30min’; ‘120min’ → ‘2hr’; keeps original if no match.""" | |
if label.endswith("min"): | |
minutes = int(label[:-3]) | |
if minutes >= 60: | |
hours = minutes / 60 | |
return f"{hours:.0f}hr" if hours.is_integer() else f"{hours:.1f}hr" | |
return label | |
# Raw CSV header -> pretty display label, computed once at import time.
TIME_LABELS = {col: _pretty_time(col) for col in TIME_COLS}
# --------------------------------------------------------------------- #
#  Plotting                                                              #
# --------------------------------------------------------------------- #
def render_chart(
    benchmark: str,
    models: list[str],
    log_scale: bool,
) -> go.Figure:
    """Build the accuracy-vs-duration line chart for one benchmark.

    Parameters
    ----------
    benchmark : display label of the benchmark (e.g. "Aggregated"); it is
        lower-cased to index into ``dfs``.
    models : display names of the models to plot; names absent from the
        benchmark's DataFrame are silently skipped.
    log_scale : if True, use a log10 Y axis whose lower bound is derived
        from the smallest non-zero accuracy among the selected models.

    Returns
    -------
    A ``plotly.graph_objects.Figure`` with one line+marker trace per model.
    """
    bench_key = benchmark.lower()
    df = dfs[bench_key]
    fig = go.Figure()
    # X labels are identical for every trace — compute once, outside the loop.
    x_labels = [TIME_LABELS[c] for c in TIME_COLS]
    # Smallest non-zero accuracy across selected models (used for log scaling).
    min_y_val = None
    for idx, m in enumerate(models):
        row = df.loc[df["Model"] == m]
        if row.empty:
            continue  # model not present in this benchmark
        y = row[TIME_COLS].values.flatten()
        y = [val if val != 0 else None for val in y]  # show gaps for 0 / missing
        # Track minimum non-zero accuracy
        y_non_none = [val for val in y if val is not None]
        if y_non_none:
            cur_min = min(y_non_none)
            if min_y_val is None or cur_min < min_y_val:
                min_y_val = cur_min
        # Deterministic colour/marker per model so charts stay comparable;
        # fall back to the loop index for names not in ALL_MODELS.
        model_idx = ALL_MODELS.index(m) if m in ALL_MODELS else idx
        color = SAFE_PALETTE[model_idx % len(SAFE_PALETTE)]
        symbol = MARKER_SYMBOLS[model_idx % len(MARKER_SYMBOLS)]
        fig.add_trace(
            go.Scatter(
                x=x_labels,
                y=y,
                mode="lines+markers",
                name=m,
                line=dict(width=3, color=color),
                marker=dict(size=6, color=color, symbol=symbol),
                connectgaps=False,
            )
        )
    # Set Y-axis properties.  Plotly expects log10 values for `range`
    # when the axis type is "log".
    if log_scale:
        if min_y_val is None or min_y_val <= 0:
            min_y_val = 0.1  # fallback when there are no valid points
        yaxis_range = [math.floor(math.log10(min_y_val)), 2]  # max at 10^2 = 100
        yaxis_type = "log"
    else:
        yaxis_range = [0, 100]
        yaxis_type = "linear"
    fig.update_layout(
        title=f"{DISPLAY_LABELS.get(bench_key, bench_key.capitalize())} Accuracy Over Time",
        xaxis_title="Video Duration",
        yaxis_title="Accuracy (%)",
        legend_title="Model",
        legend=dict(
            orientation="h",
            y=-0.25,
            x=0.5,
            xanchor="center",
            tracegroupgap=8,
            itemwidth=60,
        ),
        margin=dict(t=40, r=20, b=80, l=60),
        template="plotly_dark",
        font=dict(family="Inter,Helvetica,Arial,sans-serif", size=14),
        title_font=dict(size=20, family="Inter,Helvetica,Arial,sans-serif", color="white"),
        xaxis=dict(gridcolor="rgba(255,255,255,0.15)"),
        # Single yaxis spec replaces the original's split between
        # `yaxis_type`/`yaxis_range` kwargs and a separate `yaxis=dict(...)`,
        # which depended on update-merge ordering inside update_layout.
        yaxis=dict(
            type=yaxis_type,
            range=yaxis_range,
            gridcolor="rgba(255,255,255,0.15)",
        ),
        hoverlabel=dict(bgcolor="#1e1e1e", font_color="#eeeeee", bordercolor="#888"),
    )
    return fig
# --------------------------------------------------------------------- # | |
# UI # | |
# --------------------------------------------------------------------- # | |
CSS = """ | |
#controls { | |
padding: 8px 12px; | |
} | |
.scrollbox { | |
max-height: 300px; | |
overflow-y: auto; | |
} | |
body, .gradio-container { | |
font-family: 'Inter', 'Helvetica', sans-serif; | |
} | |
.gradio-container h1, .gradio-container h2 { | |
font-weight: 600; | |
} | |
#controls, .scrollbox { | |
background: rgba(255,255,255,0.02); | |
border-radius: 6px; | |
} | |
input[type="checkbox"]:checked { | |
accent-color: #FF715E; | |
} | |
""" | |
def available_models(bench: str) -> list[str]:
    """Alphabetically sorted unique model names for the given benchmark key."""
    return sorted(set(dfs[bench]["Model"]))
def default_models(bench: str) -> list[str]:
    """Return the list of default-selected models for a benchmark.

    Configured defaults (DEFAULT_MODELS) are filtered against the models
    actually present; if none survive, the first six available models
    are used instead.
    """
    opts = available_models(bench)
    chosen = [name for name in DEFAULT_MODELS.get(bench, []) if name in opts]
    return chosen if chosen else opts[:6]
# Build the Gradio UI: a benchmark selector and log-scale toggle on top,
# the plot, and a scrollable checklist of models underneath.
with gr.Blocks(theme=gr.themes.Base(), css=CSS) as demo:
    gr.Markdown(
        """
        # 📈 TimeScope
        How long can your video model keep up?
        """
    )
    # ---- top controls row ---- #
    with gr.Row():
        benchmark_dd = gr.Dropdown(
            label="Type",
            choices=list(DISPLAY_LABELS.values()),
            value=DISPLAY_LABELS["aggregated"],
            scale=1,
        )
        log_cb = gr.Checkbox(
            label="Log-scale Y-axis",
            value=False,
            scale=1,
        )
    # ---- models list and plot ---- #
    # Initial render so the page is populated before any interaction.
    plot_out = gr.Plot(
        render_chart("Aggregated", default_models("aggregated"), False)
    )
    models_cb = gr.CheckboxGroup(
        label="Models",
        choices=available_models("aggregated"),
        value=default_models("aggregated"),
        interactive=True,
        elem_classes=["scrollbox"],
    )
    # ---- dynamic callbacks ---- #
    def _update_models(bench: str):
        """Refresh the model checklist when the benchmark selection changes."""
        bench_key = bench.lower()
        opts = available_models(bench_key)
        defaults = default_models(bench_key)
        # Use generic gr.update for compatibility across Gradio versions
        return gr.update(choices=opts, value=defaults)
    benchmark_dd.change(
        fn=_update_models,
        inputs=benchmark_dd,
        outputs=models_cb,
        queue=False,
    )
    # NOTE(review): benchmark_dd is wired to two callbacks — the model-list
    # refresh above and the chart render below — so switching benchmarks may
    # first render the chart with the previous model selection before
    # models_cb's own change event re-renders it.  Confirm this double fire
    # is acceptable for the Gradio version in use.
    for ctrl in (benchmark_dd, models_cb, log_cb):
        ctrl.change(
            fn=render_chart,
            inputs=[benchmark_dd, models_cb, log_cb],
            outputs=plot_out,
            queue=False,
        )
# Make legend interaction clearer: click to toggle traces
demo.launch(share=True)