# TimeScope / app.py
# (Hugging Face Space page residue removed: "orrzohar's picture /
#  Upload 5 files / 4dca8ec verified" — not Python code.)
"""
Gradio demo – visualise benchmark accuracy curves.
Required CSV files (place in the *same* folder as app.py):
├── aggregated_accuracy.csv
├── qa_accuracy.csv
├── ocr_accuracy.csv
└── temporal_accuracy.csv
Each file has the columns
Model,<context‑length‑1>,<context‑length‑2>,…
where the context‑length headers are strings such as `30min`, `60min`, `120min`, …
No further cleaning / renaming is done apart from three cosmetic replacements
(“gpt4.1” → “ChatGPT 4.1”, “Gemini 2.5 pro” → “Gemini 2.5 Pro”,
“LLaMA-3.2B-11B” → “LLaMA-3.2-11B-Vision”).
"""
from pathlib import Path
import pandas as pd
import plotly.graph_objects as go
import gradio as gr
import math
# --------------------------------------------------------------------- #
# Config                                                                #
# --------------------------------------------------------------------- #
# Benchmark key -> CSV filename (expected in the same folder as app.py).
FILES = {
    "aggregated": "aggregated_accuracy.csv",
    "qa": "qa_accuracy.csv",
    "ocr": "ocr_accuracy.csv",
    "temporal": "temporal_accuracy.csv",
}
# Mapping of internal benchmark keys to nicely formatted display labels
# shown in the "Type" dropdown and in chart titles.
DISPLAY_LABELS = {
    "aggregated": "Aggregated",
    "qa": "QA",
    "ocr": "OCR",
    "temporal": "Temporal",
}
# Optional: choose which models are selected by default for each benchmark.
# Use the *display names* exactly as they appear in the Models list
# (i.e. AFTER the RENAME substitutions are applied).
# If a benchmark is missing, it falls back to the first six models.
DEFAULT_MODELS: dict[str, list[str]] = {
    "aggregated": [
        "Gemini 2.5 Pro",
        "ChatGPT 4.1",
        "Qwen2.5-VL-7B",
        "InternVL2.5-8B",
        "LLaMA-3.2-11B-Vision",
    ],
}
RENAME = {
r"gpt4\.1": "ChatGPT 4.1",
r"Gemini\s2\.5\spro": "Gemini 2.5 Pro",
r"LLaMA-3\.2B-11B": "LLaMA-3.2-11B-Vision",
}
# --------------------------------------------------------------------- #
# Data loading #
# --------------------------------------------------------------------- #
def _read_csv(path: str | Path) -> pd.DataFrame:
df = pd.read_csv(path)
df["Model"] = df["Model"].replace(RENAME, regex=True).astype(str)
return df
# Per-benchmark dataframes, loaded once at import time from disk.
dfs: dict[str, pd.DataFrame] = {name: _read_csv(path) for name, path in FILES.items()}
# --------------------------------------------------------------------- #
# Colour palette and model metadata                                     #
# --------------------------------------------------------------------- #
# NOTE(review): mid-file import kept in place to leave the code unchanged;
# conventionally it belongs with the imports at the top of the file.
import plotly.express as px
SAFE_PALETTE = px.colors.qualitative.Safe  # colour-blind-safe qualitative palette (10 colours)
# Deterministic list of all unique model names to ensure consistent colour mapping
# (sorted union across every benchmark, so a model keeps its colour everywhere).
ALL_MODELS: list[str] = sorted({m for df in dfs.values() for m in df["Model"].unique()})
# Marker shapes cycled per model index so traces stay distinguishable
# even without colour (same modulo indexing as the palette).
MARKER_SYMBOLS = [
    "circle",
    "square",
    "triangle-up",
    "diamond",
    "cross",
    "triangle-down",
    "x",
    "triangle-right",
    "triangle-left",
    "pentagon",
]
# Context-length column headers (everything except "Model"); taken from the
# aggregated CSV — assumes the other three CSVs share the same columns,
# TODO confirm.
TIME_COLS = [c for c in dfs["aggregated"].columns if c.lower() != "model"]
def _pretty_time(label: str) -> str:
"""‘30min’ → ‘30min’; ‘120min’ → ‘2hr’; keeps original if no match."""
if label.endswith("min"):
minutes = int(label[:-3])
if minutes >= 60:
hours = minutes / 60
return f"{hours:.0f}hr" if hours.is_integer() else f"{hours:.1f}hr"
return label
# Raw column header -> pretty display label, e.g. "120min" -> "2hr".
TIME_LABELS = {c: _pretty_time(c) for c in TIME_COLS}
# --------------------------------------------------------------------- #
# Plotting                                                              #
# --------------------------------------------------------------------- #
def render_chart(
    benchmark: str,
    models: list[str],
    log_scale: bool,
) -> go.Figure:
    """Build the accuracy-over-duration line chart for one benchmark.

    Args:
        benchmark: Display label of the benchmark (e.g. "Aggregated");
            lower-cased to look up the dataframe in ``dfs``.
        models: Display names of the models to plot; names not present in
            the benchmark's table are skipped silently.
        log_scale: If True, use a log10 Y axis anchored at the smallest
            positive accuracy among the selected models.

    Returns:
        A dark-themed ``plotly.graph_objects.Figure``.
    """
    bench_key = benchmark.lower()
    df = dfs[bench_key]
    fig = go.Figure()
    palette = SAFE_PALETTE
    # Smallest positive accuracy across selected models; anchors the lower
    # bound of the log-scale Y range.
    min_y_val = None
    for idx, m in enumerate(models):
        row = df.loc[df["Model"] == m]
        if row.empty:
            continue
        y = row[TIME_COLS].values.flatten()
        # Map both 0 AND NaN to None so the trace shows a gap.  Fix: the
        # original kept NaN (NaN != 0 is True), so min(y_non_none) could
        # return NaN, which slips past the `<= 0` guard below and makes
        # math.floor(math.log10(nan)) raise ValueError under log scale.
        y = [val if pd.notna(val) and val != 0 else None for val in y]
        y_non_none = [val for val in y if val is not None]
        if y_non_none:
            cur_min = min(y_non_none)
            if min_y_val is None or cur_min < min_y_val:
                min_y_val = cur_min
        # Deterministic colour/marker per model name so a model looks the
        # same on every benchmark; fall back to the loop index otherwise.
        model_idx = ALL_MODELS.index(m) if m in ALL_MODELS else idx
        color = palette[model_idx % len(palette)]
        symbol = MARKER_SYMBOLS[model_idx % len(MARKER_SYMBOLS)]
        fig.add_trace(
            go.Scatter(
                x=[TIME_LABELS[c] for c in TIME_COLS],
                y=y,
                mode="lines+markers",
                name=m,
                line=dict(width=3, color=color),
                marker=dict(size=6, color=color, symbol=symbol),
                connectgaps=False,
            )
        )
    # Set Y-axis properties
    if log_scale:
        # Fallback to 0.1 if there are no valid (positive) points.
        if min_y_val is None or min_y_val <= 0:
            min_y_val = 0.1
        # Plotly expects log10 exponents for the range of a "log" axis.
        yaxis_range = [math.floor(math.log10(min_y_val)), 2]  # max at 10^2 = 100
        yaxis_type = "log"
    else:
        yaxis_range = [0, 100]
        yaxis_type = "linear"
    fig.update_layout(
        title=f"{DISPLAY_LABELS.get(bench_key, bench_key.capitalize())} Accuracy Over Time",
        xaxis_title="Video Duration",
        yaxis_title="Accuracy (%)",
        yaxis_type=yaxis_type,
        yaxis_range=yaxis_range,
        legend_title="Model",
        legend=dict(
            orientation="h",
            y=-0.25,
            x=0.5,
            xanchor="center",
            tracegroupgap=8,
            itemwidth=60,
        ),
        margin=dict(t=40, r=20, b=80, l=60),
        template="plotly_dark",
        font=dict(family="Inter,Helvetica,Arial,sans-serif", size=14),
        title_font=dict(size=20, family="Inter,Helvetica,Arial,sans-serif", color="white"),
        # NOTE: update_layout merges recursively, so these gridcolor dicts
        # combine with (rather than clobber) yaxis_type/yaxis_range above.
        xaxis=dict(gridcolor="rgba(255,255,255,0.15)"),
        yaxis=dict(gridcolor="rgba(255,255,255,0.15)"),
        hoverlabel=dict(bgcolor="#1e1e1e", font_color="#eeeeee", bordercolor="#888"),
    )
    return fig
# --------------------------------------------------------------------- #
# UI                                                                    #
# --------------------------------------------------------------------- #
# Custom stylesheet injected into the Gradio app: padded controls,
# scrollable model checklist, Inter font, and an accent colour for
# checked checkboxes.
CSS = """
#controls {
padding: 8px 12px;
}
.scrollbox {
max-height: 300px;
overflow-y: auto;
}
body, .gradio-container {
font-family: 'Inter', 'Helvetica', sans-serif;
}
.gradio-container h1, .gradio-container h2 {
font-weight: 600;
}
#controls, .scrollbox {
background: rgba(255,255,255,0.02);
border-radius: 6px;
}
input[type="checkbox"]:checked {
accent-color: #FF715E;
}
"""
def available_models(bench: str) -> list[str]:
    """Alphabetically sorted model names present in the given benchmark."""
    unique_names = dfs[bench]["Model"].unique()
    return sorted(unique_names)
def default_models(bench: str) -> list[str]:
    """Return the default-selected models for a benchmark.

    Uses the configured DEFAULT_MODELS entry, filtered to names that
    actually exist in the table; falls back to the first six available
    models when nothing configured survives the filter.
    """
    opts = available_models(bench)
    configured = [m for m in DEFAULT_MODELS.get(bench, []) if m in opts]
    return configured if configured else opts[:6]
# Build the Gradio UI: a benchmark dropdown and log-scale toggle on top,
# the plot, then a scrollable model checklist underneath.
with gr.Blocks(theme=gr.themes.Base(), css=CSS) as demo:
    gr.Markdown(
        """
        # 📈 TimeScope
        How long can your video model keep up?
        """
    )
    # ---- top controls row ---- #
    with gr.Row():
        benchmark_dd = gr.Dropdown(
            label="Type",
            choices=list(DISPLAY_LABELS.values()),
            value=DISPLAY_LABELS["aggregated"],
            scale=1,
        )
        log_cb = gr.Checkbox(
            label="Log-scale Y-axis",
            value=False,
            scale=1,
        )
    # ---- models list and plot ---- #
    # Initial render: aggregated benchmark, default model set, linear Y axis.
    plot_out = gr.Plot(
        render_chart("Aggregated", default_models("aggregated"), False)
    )
    models_cb = gr.CheckboxGroup(
        label="Models",
        choices=available_models("aggregated"),
        value=default_models("aggregated"),
        interactive=True,
        elem_classes=["scrollbox"],
    )
    # ---- dynamic callbacks ---- #
    def _update_models(bench: str):
        # Refresh the model checklist (choices + defaults) whenever the
        # benchmark type changes.
        bench_key = bench.lower()
        opts = available_models(bench_key)
        defaults = default_models(bench_key)
        # Use generic gr.update for compatibility across Gradio versions
        return gr.update(choices=opts, value=defaults)
    benchmark_dd.change(
        fn=_update_models,
        inputs=benchmark_dd,
        outputs=models_cb,
        queue=False,
    )
    # Re-render the chart when ANY control changes.  Note the benchmark
    # change fires both this handler and _update_models above.
    for ctrl in (benchmark_dd, models_cb, log_cb):
        ctrl.change(
            fn=render_chart,
            inputs=[benchmark_dd, models_cb, log_cb],
            outputs=plot_out,
            queue=False,
        )
# Make legend interaction clearer: click to toggle traces
# NOTE(review): launch() runs at import time; consider guarding it with
# `if __name__ == "__main__":` so the module can be imported safely.
demo.launch(share=True)