# model-explorer / app.py
# bruAristimunha's picture
# Redesign: IBM Plex typography + Okabe-Ito palette + restructured layout
# 03ad70c verified
"""Braindecode Model Explorer — interactive architecture browser.
Hugging Face Space that browses every EEG architecture in braindecode.
For each model: rendered docstring (figure, references, parameter list)
plus live instantiation to inspect param count and layer summary.
No pretrained weights are loaded — this is a pure architecture browser.
Aesthetic: editorial scientific instrument. IBM Plex (Sans / Serif /
Mono), Okabe-Ito colorblind-safe palette, warm-paper background. All
visual styling lives in GLOBAL_CSS below; the docstring renderer emits
structural HTML only.
"""
from __future__ import annotations
import inspect
from typing import Any
import gradio as gr
import torch
from torchinfo import summary
import braindecode.models as M
from braindecode.models.base import EEGModuleMixin
from docstring_renderer import (
get_signature_str,
get_source_link,
render_docstring_html,
)
# ---------------------------------------------------------------------------
# Catalog: discover every EEGModuleMixin subclass exported by braindecode.
# ---------------------------------------------------------------------------
def _discover_models() -> dict[str, type]:
    """Collect every concrete EEG model class exported by ``braindecode.models``.

    Names come from the module's ``__all__`` when it is present and
    non-empty, otherwise from ``dir(M)``. Underscore-prefixed names are
    ignored, and only strict subclasses of ``EEGModuleMixin`` are kept.

    Returns
    -------
    dict[str, type]
        Exported name -> class, inserted in sorted-name order.
    """
    exported = getattr(M, "__all__", []) or dir(M)

    def _is_model(candidate: object) -> bool:
        # The mixin itself is the base class, not a browsable architecture.
        return (
            inspect.isclass(candidate)
            and issubclass(candidate, EEGModuleMixin)
            and candidate is not EEGModuleMixin
        )

    public_members = (
        (attr, getattr(M, attr, None))
        for attr in sorted(exported)
        if not attr.startswith("_")
    )
    return {attr: member for attr, member in public_members if _is_model(member)}
# Catalog is built once at import time; the handlers below only read from it.
MODELS: dict[str, type] = _discover_models()
MODEL_NAMES: list[str] = sorted(MODELS)

# Land on EEGNetv4 when it is exported, otherwise on the first name in
# alphabetical order.
if "EEGNetv4" in MODELS:
    DEFAULT_MODEL = "EEGNetv4"
else:
    DEFAULT_MODEL = MODEL_NAMES[0]

# Version string displayed in the header band; "unknown" if it cannot be read.
try:
    import braindecode as _bd

    BD_VERSION = getattr(_bd, "__version__", "unknown")
except Exception:
    BD_VERSION = "unknown"
# ---------------------------------------------------------------------------
# Heuristic defaults for the signal-shape form. Different model families
# expect very different inputs (sleep stagers want 30 s @ 100 Hz; motor-
# imagery models want ~4 s @ 250 Hz).
# ---------------------------------------------------------------------------
# Signal-shape presets keyed by model family; looked up by _defaults_for().
DEFAULTS = {
    # 30 s windows sampled at 100 Hz.
    "sleep": {
        "n_chans": 2,
        "sfreq": 100,
        "input_window_seconds": 30.0,
        "n_outputs": 5,
    },
    "biot": {
        "n_chans": 16,
        "sfreq": 200,
        "input_window_seconds": 10.0,
        "n_outputs": 2,
    },
    "bendr": {
        "n_chans": 20,
        "sfreq": 256,
        "input_window_seconds": 4.0,
        "n_outputs": 2,
    },
    "labram": {
        "n_chans": 22,
        "sfreq": 200,
        "input_window_seconds": 4.0,
        "n_outputs": 2,
    },
    # Fallback preset: 4 s windows sampled at 250 Hz.
    "default": {
        "n_chans": 22,
        "sfreq": 250,
        "input_window_seconds": 4.0,
        "n_outputs": 4,
    },
}
def _defaults_for(name: str) -> dict[str, Any]:
    """Pick the signal-shape preset that matches a model name.

    Matching is by case-insensitive substring, checked in a fixed order:
    sleep, biot, bendr, then the labram group (which also covers
    "cbramod" and "eegpt"). Anything else gets the "default" preset.

    Returns
    -------
    dict[str, Any]
        The matching entry of ``DEFAULTS`` (the shared dict, not a copy).
    """
    lowered = name.lower()

    explicit_sleep_models = {"USleep", "AttnSleep", "DeepSleepNet"}
    if name in explicit_sleep_models or "sleep" in lowered:
        return DEFAULTS["sleep"]

    # (preset key, substrings that select it) — order matters, first hit wins.
    family_rules = (
        ("biot", ("biot",)),
        ("bendr", ("bendr",)),
        ("labram", ("labram", "cbramod", "eegpt")),
    )
    for preset_key, needles in family_rules:
        if any(needle in lowered for needle in needles):
            return DEFAULTS[preset_key]
    return DEFAULTS["default"]
# ---------------------------------------------------------------------------
# Global stylesheet — IBM Plex + Okabe-Ito + spatial system. Injected
# once via gr.Blocks(css=...).
# ---------------------------------------------------------------------------
# Single stylesheet for the whole app, passed to gr.Blocks(css=...) in
# build_app(). The `bd-*` class names styled here are the ones emitted by the
# HTML fragment helpers in this file (_info_card, _stat_tile, _header_band,
# _section_rule, _footer); per the section comments inside the sheet, the
# `.bd-doc` and `.bd-badge` rules target markup produced by docstring_renderer.
# The string is runtime data: edit it as CSS, not as Python.
GLOBAL_CSS = """
@import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Sans:wght@400;500;600;700&family=IBM+Plex+Serif:wght@500;600&family=IBM+Plex+Mono:wght@400;500;600&display=swap');
:root {
--bd-blue: #0072B2;
--bd-green: #009E73;
--bd-orange: #D55E00;
--bd-pink: #CC79A7;
--bd-yellow: #E69F00;
--bd-skyblue: #56B4E9;
--bd-paper: #FAFAF7;
--bd-paper-deep: #F1EFE8;
--bd-rule: #E5E2D9;
--bd-ink: #1a1a1a;
--bd-meta: #6b6b6b;
}
/* Container & background ---------------------------------------------- */
body, .gradio-container {
background: var(--bd-paper) !important;
font-family: 'IBM Plex Sans', system-ui, sans-serif !important;
color: var(--bd-ink);
}
.gradio-container { max-width: 1320px !important; padding: 0 24px !important; }
.gradio-container * { font-family: inherit; }
/* Header band --------------------------------------------------------- */
.bd-header {
display: flex;
align-items: baseline;
justify-content: space-between;
padding: 22px 0 18px 0;
border-bottom: 1px solid var(--bd-rule);
margin-bottom: 28px;
flex-wrap: wrap;
gap: 12px;
}
.bd-header-title {
font-family: 'IBM Plex Serif', serif;
font-size: 26px;
font-weight: 600;
color: var(--bd-ink);
letter-spacing: -0.015em;
}
.bd-header-title .bd-mark {
color: var(--bd-blue);
font-weight: 500;
font-style: italic;
}
.bd-header-meta {
font-family: 'IBM Plex Mono', monospace;
font-size: 12px;
color: var(--bd-meta);
letter-spacing: 0.04em;
text-transform: uppercase;
}
.bd-header-meta .bd-dot { color: var(--bd-blue); margin: 0 8px; }
/* Info card (model display) ------------------------------------------- */
.bd-info {
margin: 0 0 24px 0;
padding-bottom: 18px;
border-bottom: 2px solid var(--bd-blue);
}
.bd-display {
font-family: 'IBM Plex Serif', serif;
font-size: 36px;
font-weight: 600;
color: var(--bd-blue);
letter-spacing: -0.02em;
line-height: 1.1;
margin: 0 0 6px 0;
}
.bd-tagline {
font-family: 'IBM Plex Sans', sans-serif;
font-size: 14px;
color: var(--bd-meta);
margin-bottom: 14px;
font-style: italic;
}
.bd-sig {
font-family: 'IBM Plex Mono', monospace;
font-size: 13px;
line-height: 1.55;
white-space: pre;
overflow-x: auto;
padding: 10px 14px;
background: var(--bd-paper-deep);
border-left: 2px solid var(--bd-blue);
color: #2a2a2a;
margin: 0 0 10px 0;
}
.bd-sig::-webkit-scrollbar { height: 6px; }
.bd-sig::-webkit-scrollbar-thumb { background: var(--bd-rule); border-radius: 3px; }
.bd-sig::-webkit-scrollbar-thumb:hover { background: var(--bd-blue); }
.bd-source {
display: inline-block;
color: var(--bd-meta);
font-size: 13px;
text-decoration: none;
border-bottom: 1px solid transparent;
transition: all 0.15s ease;
}
.bd-source:hover { color: var(--bd-blue); border-bottom-color: var(--bd-blue); }
/* Stat tile (live param count) ---------------------------------------- */
.bd-stat-card {
background: var(--bd-paper-deep);
border: 1px solid var(--bd-rule);
border-radius: 4px;
padding: 14px 16px;
margin-top: 12px;
}
.bd-meta-label {
font-family: 'IBM Plex Sans', sans-serif;
font-size: 11px;
font-weight: 600;
letter-spacing: 0.1em;
text-transform: uppercase;
color: var(--bd-meta);
margin: 0 0 4px 0;
}
.bd-stat {
font-family: 'IBM Plex Mono', monospace;
font-size: 28px;
font-weight: 600;
font-variant-numeric: tabular-nums;
color: var(--bd-blue);
line-height: 1;
}
.bd-stat-sub {
font-family: 'IBM Plex Mono', monospace;
font-size: 11px;
color: var(--bd-meta);
margin-top: 6px;
letter-spacing: 0.02em;
}
/* Section heading separator ------------------------------------------- */
.bd-section-rule {
display: flex; align-items: center;
gap: 12px;
margin: 28px 0 14px 0;
}
.bd-section-rule::before, .bd-section-rule::after {
content: ""; flex: 1;
height: 1px; background: var(--bd-rule);
}
.bd-section-rule span {
font-family: 'IBM Plex Sans', sans-serif;
font-size: 11px;
font-weight: 600;
letter-spacing: 0.14em;
text-transform: uppercase;
color: var(--bd-meta);
}
/* Docstring rendering (consumed by render_docstring_html) ------------- */
.bd-doc {
font-family: 'IBM Plex Sans', sans-serif;
font-size: 16px;
line-height: 1.65;
color: var(--bd-ink);
}
.bd-doc p, .bd-doc li { font-size: 16px; margin: 8px 0; }
.bd-doc h1, .bd-doc h2, .bd-doc h3 {
font-family: 'IBM Plex Serif', serif;
color: var(--bd-blue);
margin-top: 1.4em;
margin-bottom: 0.45em;
letter-spacing: -0.01em;
}
.bd-doc h1 { font-size: 24px; font-weight: 600; }
.bd-doc h2 { font-size: 20px; font-weight: 600; }
.bd-doc h3 { font-size: 17px; font-weight: 600;
font-family: 'IBM Plex Sans', sans-serif; }
.bd-doc pre {
background: var(--bd-paper-deep);
padding: 12px 14px;
border-radius: 4px;
font-family: 'IBM Plex Mono', monospace;
font-size: 13px;
line-height: 1.55;
overflow-x: auto;
border-left: 2px solid var(--bd-blue);
}
.bd-doc code {
background: rgba(0, 114, 178, 0.08);
padding: 1px 6px;
border-radius: 3px;
font-family: 'IBM Plex Mono', monospace;
font-size: 14px;
color: #0a4d77;
}
.bd-doc pre code { background: transparent; padding: 0; color: inherit; font-size: inherit; }
.bd-doc img {
max-width: 100%;
display: block;
margin: 18px auto;
border-radius: 4px;
box-shadow: 0 2px 14px rgba(0, 114, 178, 0.10);
}
.bd-doc table {
border-collapse: collapse;
margin: 14px 0;
font-size: 14px;
font-variant-numeric: tabular-nums;
width: 100%;
}
.bd-doc th, .bd-doc td {
border: 1px solid var(--bd-rule);
padding: 7px 12px;
text-align: left;
vertical-align: top;
}
.bd-doc th { background: var(--bd-paper-deep); font-weight: 600; color: var(--bd-meta); font-size: 12px; letter-spacing: 0.04em; text-transform: uppercase; }
.bd-doc .admonition {
border-left: 3px solid var(--bd-blue);
background: rgba(0, 114, 178, 0.05);
padding: 10px 16px;
margin: 16px 0;
border-radius: 0 4px 4px 0;
font-size: 15px;
}
.bd-doc .admonition.important { border-color: var(--bd-orange); background: rgba(213, 94, 0, 0.05); }
.bd-doc .admonition.note { border-color: var(--bd-green); background: rgba(0, 158, 115, 0.05); }
.bd-doc .admonition-title { font-weight: 600; margin-bottom: 4px; }
.bd-doc dl.field-list {
display: grid; grid-template-columns: max-content auto;
gap: 6px 16px; font-size: 15px; margin: 12px 0;
}
.bd-doc dl.field-list dt { font-weight: 600; color: var(--bd-meta); font-size: 13px; letter-spacing: 0.03em; text-transform: uppercase; padding-top: 2px; }
.bd-doc a { color: var(--bd-blue); text-decoration: none; border-bottom: 1px solid rgba(0, 114, 178, 0.3); }
.bd-doc a:hover { border-bottom-color: var(--bd-blue); }
/* Inline badge produced by docstring_renderer ------------------------- */
.bd-badge {
display: inline-block;
padding: 3px 10px;
border-radius: 3px;
color: white;
font-family: 'IBM Plex Sans', sans-serif;
font-size: 12px;
font-weight: 600;
letter-spacing: 0.02em;
margin: 0 4px 4px 0;
}
/* Form labels --------------------------------------------------------- */
label > span, .gradio-container .label-wrap span {
font-family: 'IBM Plex Sans', sans-serif;
font-size: 12px !important;
font-weight: 600 !important;
letter-spacing: 0.06em !important;
text-transform: uppercase !important;
color: var(--bd-meta) !important;
}
input[type="number"], textarea, select {
font-family: 'IBM Plex Mono', monospace !important;
font-size: 14px !important;
}
/* Primary button ------------------------------------------------------ */
button.primary, button[variant="primary"], .bd-cta {
background: var(--bd-blue) !important;
color: white !important;
font-family: 'IBM Plex Sans', sans-serif !important;
font-size: 13px !important;
font-weight: 600 !important;
letter-spacing: 0.06em !important;
text-transform: uppercase !important;
border-radius: 4px !important;
padding: 11px 18px !important;
border: none !important;
transition: background 0.15s ease;
}
button.primary:hover, button[variant="primary"]:hover, .bd-cta:hover {
background: #005a8c !important;
}
/* Footer -------------------------------------------------------------- */
.bd-footer {
margin: 40px 0 20px 0;
padding-top: 18px;
border-top: 1px solid var(--bd-rule);
font-family: 'IBM Plex Sans', sans-serif;
font-size: 12px;
color: var(--bd-meta);
letter-spacing: 0.04em;
display: flex;
justify-content: space-between;
flex-wrap: wrap;
gap: 8px;
}
.bd-footer a { color: var(--bd-blue); text-decoration: none; }
.bd-footer a:hover { text-decoration: underline; }
"""
# ---------------------------------------------------------------------------
# HTML fragments
# ---------------------------------------------------------------------------
import html as _html
def _info_card(name: str) -> str:
    """Build the info-card HTML for one catalog entry.

    The card contains the display name, an optional tagline, the
    constructor signature in a scrollable ``<pre>``, and a link to the
    source. The tagline is the first line of the class docstring with rST
    citation markers such as ``[Foo2023]_`` removed, truncated to 200
    characters.

    Parameters
    ----------
    name : str
        Key into ``MODELS``.

    Returns
    -------
    str
        HTML fragment; every interpolated value is HTML-escaped.

    Raises
    ------
    KeyError
        If ``name`` is not in ``MODELS``.
    """
    import re as _re  # local import: `re` is not imported at module level

    cls = MODELS[name]
    sig = _html.escape(get_signature_str(cls))
    # The link lands inside an href="..." attribute, so escape it as an
    # attribute value (quotes included) rather than trusting it verbatim.
    link = _html.escape(get_source_link(cls) or "#", quote=True)

    doc_lines = (cls.__doc__ or "").strip().splitlines()
    # Build the optional tagline element up front. Nesting it as an f-string
    # with backslash-escaped quotes inside a replacement field is a
    # SyntaxError on Python < 3.12.
    tagline_html = ""
    if doc_lines:
        # Strip rST cite markers like [Foo2023]_
        first = _re.sub(r"\[\w+\]_", "", doc_lines[0].strip()).strip()
        if first:
            tagline_html = (
                f'<div class="bd-tagline">{_html.escape(first[:200])}</div>'
            )

    return (
        f'<div class="bd-info">'
        f' <div class="bd-display">{_html.escape(name)}</div>'
        f' {tagline_html}'
        f' <pre class="bd-sig">{sig}</pre>'
        # rel="noopener" keeps the opened tab from reaching window.opener.
        f' <a class="bd-source" href="{link}" target="_blank" rel="noopener">↗ Source on GitHub</a>'
        f'</div>'
    )
def _stat_tile(params: int | None = None, *, n_chans: int | None = None,
               n_times: int | None = None, out_shape=None) -> str:
    """Render the parameter-count tile shown under the build button.

    Parameters
    ----------
    params : int | None
        Total parameter count. ``None`` renders the "not built yet"
        placeholder and the remaining arguments are ignored.
    n_chans, n_times : int | None
        Input dimensions echoed in the sub-line as ``(b, n_chans, n_times)``.
    out_shape : tuple | Any
        Output shape. A tuple is rendered as ``(d0, d1, ...)``; anything
        else is rendered with ``str()`` (the caller passes an error message
        here when the forward pass fails).

    Returns
    -------
    str
        HTML fragment for the tile.
    """
    if params is None:
        return (
            '<div class="bd-stat-card">'
            '<div class="bd-meta-label">Parameters</div>'
            '<div class="bd-stat" style="color: var(--bd-meta);">—</div>'
            '<div class="bd-stat-sub">press build to instantiate</div>'
            '</div>'
        )

    if isinstance(out_shape, tuple):
        pretty_out = f"({', '.join(str(d) for d in out_shape)})"
    else:
        pretty_out = str(out_shape)
    # out_shape may be free-form exception text; escape it so characters such
    # as "<" or "&" are displayed literally instead of being parsed as markup.
    pretty_out = _html.escape(pretty_out, quote=False)

    return (
        '<div class="bd-stat-card">'
        '<div class="bd-meta-label">Parameters</div>'
        f'<div class="bd-stat">{params:,}</div>'
        f'<div class="bd-stat-sub">in (b, {n_chans}, {n_times}) → {pretty_out}</div>'
        '</div>'
    )
def _header_band() -> str:
    """Return the page header: app title on one side, a metadata strip on the other.

    The metadata strip shows the braindecode version, the number of
    catalogued architectures and a "no weights" note, separated by dots.
    """
    separator = '<span class="bd-dot">•</span>'
    title = (
        '<div class="bd-header-title">braindecode '
        '<span class="bd-mark">model explorer</span></div>'
    )
    meta_items = (f"v{BD_VERSION}", f"{len(MODELS)} architectures", "no weights")
    meta = f'<div class="bd-header-meta">{separator.join(meta_items)}</div>'
    return f'<div class="bd-header">{title}{meta}</div>'
def _section_rule(label: str) -> str:
    """Return a section divider whose ``<span>`` holds ``label``.

    ``label`` is inserted as-is (no HTML escaping).
    """
    template = '<div class="bd-section-rule"><span>{}</span></div>'
    return template.format(label)
def _footer() -> str:
    """Return the static page footer: a short blurb plus a repository link."""
    site = '<a href="https://braindecode.org">braindecode</a>'
    hub = '<a href="https://huggingface.co/braindecode">huggingface.co/braindecode</a>'
    repo = (
        '<a href="https://github.com/braindecode/braindecode">'
        'github.com/braindecode/braindecode</a>'
    )
    blurb = (
        f"<div>An architecture browser for {site}. "
        f"No pretrained weights served here — see {hub}.</div>"
    )
    return f'<div class="bd-footer">{blurb}<div>{repo}</div></div>'
# ---------------------------------------------------------------------------
# Event handlers
# ---------------------------------------------------------------------------
def show_model(name: str):
    """Dropdown change handler: refresh the right-hand panel and form defaults.

    Parameters
    ----------
    name : str
        Selected architecture name.

    Returns
    -------
    tuple
        Seven values, in the order the handler is wired to its outputs:
        info-card HTML, docstring HTML, a reset stat tile, and one update
        each for the n_chans, sfreq, window and n_outputs fields.
    """
    if name not in MODELS:
        # Clear the panels but leave the numeric fields untouched.
        # gr.update() with no arguments is a no-op update; a bare {} is not an
        # update object and would be handed to the Number components as a value.
        return ("", "", _stat_tile(), *(gr.update() for _ in range(4)))

    info = _info_card(name)
    doc_html = render_docstring_html(MODELS[name].__doc__)
    d = _defaults_for(name)
    return (
        info,
        doc_html,
        _stat_tile(),  # reset stat tile when switching models
        gr.update(value=d["n_chans"]),
        gr.update(value=d["sfreq"]),
        gr.update(value=d["input_window_seconds"]),
        gr.update(value=d["n_outputs"]),
    )
def instantiate(name, n_chans, sfreq, window_s, n_outputs):
    """Build the selected model and report its size.

    Parameters
    ----------
    name : str
        Architecture name; must be a key of ``MODELS``.
    n_chans, sfreq, window_s, n_outputs : number | None
        Values from the signal-configuration form. Any of them may be
        ``None`` when the corresponding field has been cleared.

    Returns
    -------
    tuple[str, str]
        ``(stat_html, summary_markdown)``. On any failure the stat tile is
        the empty placeholder and the markdown carries the error message;
        this handler does not raise for bad user input.
    """
    if name not in MODELS:
        return _stat_tile(), "Pick a model first."
    cls = MODELS[name]

    # Coerce the form values up front. A cleared field arrives as None, and
    # NaN/inf cannot be turned into a sample count; report these as a message
    # instead of letting the handler crash.
    try:
        n_chans = int(n_chans)
        n_outputs = int(n_outputs)
        sfreq = float(sfreq)
        window_s = float(window_s)
        n_times = int(round(window_s * sfreq))
    except (TypeError, ValueError, OverflowError):
        return _stat_tile(), (
            "❌ **Invalid signal configuration**: "
            "every field must be filled with a finite number."
        )
    if n_chans < 1 or n_outputs < 1 or sfreq <= 0 or window_s <= 0 or n_times < 1:
        return _stat_tile(), (
            "❌ **Invalid signal configuration**: `n_chans`, `sfreq`, `window` "
            "and `n_outputs` must be positive, and the window must span at "
            "least one sample."
        )

    kwargs = dict(
        n_chans=n_chans,
        sfreq=sfreq,
        input_window_seconds=window_s,
        n_outputs=n_outputs,
    )
    # Only pass the keywords this particular constructor declares.
    sig_params = set(inspect.signature(cls.__init__).parameters)
    kwargs = {k: v for k, v in kwargs.items() if k in sig_params}
    try:
        model = cls(**kwargs)
    except Exception as exc:  # noqa: BLE001
        err = f"❌ **Failed to instantiate `{name}`** with `{kwargs}`:\n```\n{exc}\n```"
        return _stat_tile(), err

    # This is a shape/size probe, not training: switch to inference mode so
    # dropout is disabled and normalisation statistics are not updated.
    model.eval()
    n_params = sum(p.numel() for p in model.parameters())

    # The layer summary is best-effort; a failure here must not hide the
    # parameter count.
    try:
        info = summary(
            model,
            input_size=(1, n_chans, n_times),
            depth=3,
            verbose=0,
            col_names=("output_size", "num_params"),
        )
        summary_str = str(info)
    except Exception as exc:  # noqa: BLE001
        summary_str = f"(torchinfo summary unavailable: {exc})"

    # Probe the output shape with a random batch of two.
    out_shape: Any = "?"
    try:
        x = torch.randn(2, n_chans, n_times)
        with torch.no_grad():
            y = model(x)
        out_shape = tuple(y.shape) if hasattr(y, "shape") else type(y).__name__
    except Exception as exc:  # noqa: BLE001
        out_shape = f"forward failed: {exc}"

    stat = _stat_tile(
        params=n_params, n_chans=n_chans, n_times=n_times, out_shape=out_shape
    )
    return stat, f"```\n{summary_str}\n```"
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
def build_app() -> gr.Blocks:
    """Assemble the Gradio UI and wire its two event handlers.

    Layout: header band; a two-column row (left: architecture dropdown,
    signal-configuration form, build button and stat tile; right: info
    card, rendered docstring and a collapsible layer summary); footer.

    Returns
    -------
    gr.Blocks
        The app, not yet launched.
    """
    theme = gr.themes.Soft(
        primary_hue=gr.themes.colors.blue,
        font=[gr.themes.GoogleFont("IBM Plex Sans"), "system-ui", "sans-serif"],
        font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "monospace"],
    )
    # Seed the form from the same heuristic show_model() uses, so the initial
    # values always match the landing model even when DEFAULT_MODEL falls
    # back to something other than EEGNetv4.
    initial = _defaults_for(DEFAULT_MODEL)

    with gr.Blocks(
        title="Braindecode Model Explorer",
        theme=theme,
        css=GLOBAL_CSS,
    ) as app:
        gr.HTML(_header_band())
        with gr.Row(equal_height=False):
            # ---------- LEFT: controls + stat tile ----------
            with gr.Column(scale=1, min_width=280):
                model_dd = gr.Dropdown(
                    choices=MODEL_NAMES,
                    value=DEFAULT_MODEL,
                    label="Architecture",
                    interactive=True,
                    filterable=True,
                )
                gr.HTML(_section_rule("Signal configuration"))
                with gr.Group():
                    n_chans = gr.Number(
                        value=initial["n_chans"], label="n_chans", precision=0
                    )
                    sfreq = gr.Number(value=initial["sfreq"], label="sfreq · Hz")
                    window_s = gr.Number(
                        value=initial["input_window_seconds"],
                        label="window · seconds",
                    )
                    n_outputs = gr.Number(
                        value=initial["n_outputs"], label="n_outputs", precision=0
                    )
                run_btn = gr.Button(
                    "Build network", variant="primary", elem_classes="bd-cta"
                )
                stat_html = gr.HTML(_stat_tile())
            # ---------- RIGHT: model info + docstring ----------
            with gr.Column(scale=3):
                info_html = gr.HTML(_info_card(DEFAULT_MODEL))
                gr.HTML(_section_rule("Architecture documentation"))
                doc_html = gr.HTML(
                    render_docstring_html(MODELS[DEFAULT_MODEL].__doc__)
                )
                with gr.Accordion("Layer summary (after build)", open=False):
                    summary_md = gr.Markdown(
                        "_Press **Build network** to populate the summary._"
                    )
        gr.HTML(_footer())
        # ---------- wiring ----------
        # Output order must match the tuple returned by show_model().
        model_dd.change(
            show_model,
            inputs=model_dd,
            outputs=[info_html, doc_html, stat_html, n_chans, sfreq, window_s, n_outputs],
        )
        # Input order must match instantiate()'s positional parameters.
        run_btn.click(
            instantiate,
            inputs=[model_dd, n_chans, sfreq, window_s, n_outputs],
            outputs=[stat_html, summary_md],
        )
    return app
if __name__ == "__main__":
    # Bind to every interface on port 7860: on HF Spaces the sandbox blocks
    # localhost-only binds, so the front-door proxy needs 0.0.0.0 to reach
    # the server. The same settings work for a local run.
    demo = build_app()
    demo.launch(server_name="0.0.0.0", server_port=7860)