Spaces:
Running
Running
thibaud frere
Move assets into content/assets; update imports; clean .gitattributes; fix LFS tracking
b8e1b6c
| import plotly.graph_objects as go | |
| import plotly.io as pio | |
| import numpy as np | |
| import os | |
| import uuid | |
| """ | |
| Interactive line chart example (Baseline / Improved / Target) with a live slider. | |
| Context: research-style training curves for multiple datasets (CIFAR-10, CIFAR-100, ImageNet-1K). | |
| The slider "Augmentation α" blends the Improved curve between the Baseline (α=0) | |
| and an augmented counterpart (α=1) via a simple mixing equation. | |
| Export remains responsive, with no zoom and no mode bar. | |
| """ | |
# Shared x grid: N evenly spaced samples covering [0, 1] (both ends included).
# The generated JavaScript rebuilds the same grid as i / (N - 1).
N = 240
x = np.linspace(0.0, 1.0, num=N)
def logistic(xv: np.ndarray, ymin: float, ymax: float, k: float, x0: float) -> np.ndarray:
    """Evaluate a logistic (sigmoid) curve used as a smooth learning curve.

    The curve rises from ``ymin`` towards ``ymax``; ``k`` sets the steepness
    and ``x0`` is the midpoint, where the value is ``(ymin + ymax) / 2``.

    Args:
        xv: Points at which to evaluate the curve.
        ymin: Lower asymptote.
        ymax: Upper asymptote.
        k: Growth rate (steepness) of the transition.
        x0: Location of the midpoint of the transition.

    Returns:
        Array of curve values, one per element of ``xv``.
    """
    span = ymax - ymin
    decay = np.exp(-k * (xv - x0))
    return ymin + span / (1.0 + decay)
# Per-dataset curve parameters. "base" and "aug" are keyword arguments for
# logistic() (baseline vs. augmented curve); "target" is the level of the
# constant target line. The JavaScript in the generated fragment carries its
# own copy of these numbers.
datasets_params = [
    dict(
        name="CIFAR-10",
        base=dict(ymin=0.10, ymax=0.90, k=10.0, x0=0.55),
        aug=dict(ymin=0.15, ymax=0.96, k=12.0, x0=0.40),
        target=0.97,
    ),
    dict(
        name="CIFAR-100",
        base=dict(ymin=0.05, ymax=0.70, k=9.5, x0=0.60),
        aug=dict(ymin=0.08, ymax=0.80, k=11.0, x0=0.45),
        target=0.85,
    ),
    dict(
        name="ImageNet-1K",
        base=dict(ymin=0.02, ymax=0.68, k=8.5, x0=0.65),
        aug=dict(ymin=0.04, ymax=0.75, k=9.5, x0=0.50),
        target=0.82,
    ),
]
# Initial state rendered into the static figure: the first dataset and a
# starting blend factor alpha0 (also used as the slider's initial value).
alpha0 = 0.7
ds0 = datasets_params[0]
base0 = logistic(x, **ds0["base"])
aug0 = logistic(x, **ds0["aug"])
target0 = np.full_like(x, ds0["target"], dtype=float)


def blend(l, e, a):
    """Linearly mix two curves.

    Args:
        l: Curve returned unchanged when ``a == 0``.
        e: Curve returned unchanged when ``a == 1``.
        a: Mixing weight applied to ``e``; ``l`` gets ``1 - a``.

    Returns:
        ``(1 - a) * l + a * e``.
    """
    return (1 - a) * l + a * e


# Traces: Baseline (fixed), Improved (blended by alpha), Target (constant goal)
y1 = base0
y2 = blend(base0, aug0, alpha0)
y3 = target0
# Line colors for the three traces.
color_base = "#64748b"  # slate-500
color_improved = "#F981D4"  # pink
color_target = "#4b5563"  # gray-600 (dash)

# Every trace shares the same hover text layout.
hover_tpl = "<b>%{fullData.name}</b><br>x=%{x:.2f}<br>y=%{y:.3f}<extra></extra>"

# (legend name, y values, line style) per trace. The order is significant:
# the fragment's JavaScript restyles traces by index (0 = Baseline,
# 1 = Improved, 2 = Target).
trace_specs = (
    ("Baseline", y1, dict(color=color_base, width=2, shape="spline", smoothing=0.6)),
    ("Improved", y2, dict(color=color_improved, width=2, shape="spline", smoothing=0.6)),
    ("Target", y3, dict(color=color_target, width=2, dash="dash")),
)

fig = go.Figure()
for trace_name, trace_y, trace_line in trace_specs:
    fig.add_trace(
        go.Scatter(
            x=x,
            y=trace_y,
            name=trace_name,
            mode="lines",
            line=trace_line,
            hovertemplate=hover_tpl,
            showlegend=True,
        )
    )
# Settings common to both axes: no grid or zero line, a thin visible axis
# line, outside ticks, no title, and fixedrange=True (axis range locked).
axis_style = dict(
    showgrid=False,
    zeroline=False,
    showline=True,
    linecolor="rgba(0,0,0,0.25)",
    linewidth=1,
    ticks="outside",
    ticklen=6,
    tickcolor="rgba(0,0,0,0.25)",
    tickfont=dict(size=12, color="rgba(0,0,0,0.55)"),
    title=None,
    automargin=True,
    fixedrange=True,
)

# Vertical legend anchored to the bottom-right corner, transparent background.
legend_style = dict(
    orientation="v",
    x=1,
    y=0,
    xanchor="right",
    yanchor="bottom",
    bgcolor="rgba(255,255,255,0)",
    borderwidth=0,
)

# White hover box with dark text; namelength=-1 shows trace names in full.
hoverlabel_style = dict(
    bgcolor="white",
    font=dict(color="#111827", size=12),
    bordercolor="rgba(0,0,0,0.15)",
    align="left",
    namelength=-1,
)

fig.update_layout(
    autosize=True,
    paper_bgcolor="rgba(0,0,0,0)",
    plot_bgcolor="rgba(0,0,0,0)",
    margin=dict(l=40, r=28, t=20, b=40),
    hovermode="x unified",
    legend=legend_style,
    hoverlabel=hoverlabel_style,
    xaxis=dict(axis_style),
    # The y axis adds two-decimal tick labels and a range that includes zero.
    yaxis=dict(axis_style, tickformat=".2f", rangemode="tozero"),
)
# Resolve <directory of this script>/fragments/line.html and make sure the
# "fragments" directory exists.
# NOTE(review): the write at the end of this script targets a separate
# hard-coded relative path rather than output_path; confirm which location
# is intended.
fragments_dir = os.path.join(os.path.dirname(__file__), "fragments")
output_path = os.path.join(fragments_dir, "line.html")
os.makedirs(fragments_dir, exist_ok=True)
# JavaScript handed to Plotly's HTML export as a post-render script. It sets
# rx/ry = 8 on the hover-label <rect> elements to round their corners, and
# re-applies that on plotly_hover / plotly_unhover / plotly_relayout.
# It attaches to every '.js-plotly-plot' element found in the document, and
# any error raised while rounding is deliberately ignored (empty catch).
post_script = """
(function(){
function attach(gd){
function round(){
try {
var root = gd && gd.parentNode ? gd.parentNode : document;
var rects = root.querySelectorAll('.hoverlayer .hovertext rect');
rects.forEach(function(r){ r.setAttribute('rx', 8); r.setAttribute('ry', 8); });
} catch(e) {}
}
if (gd && gd.on) {
gd.on('plotly_hover', round);
gd.on('plotly_unhover', round);
gd.on('plotly_relayout', round);
}
setTimeout(round, 0);
}
var plots = document.querySelectorAll('.js-plotly-plot');
plots.forEach(attach);
})();
"""

# Plot config: responsive, with the mode bar hidden and scroll-zoom and
# double-click interactions switched off.
plot_config = {
    "displayModeBar": False,
    "responsive": True,
    "scrollZoom": False,
    "doubleClick": False,
    "modeBarButtonsToRemove": [
        "zoom2d", "pan2d", "select2d", "lasso2d",
        "zoomIn2d", "zoomOut2d", "autoScale2d", "resetScale2d",
        "toggleSpikelines",
    ],
}

# Export only a <div> fragment (full_html=False) without embedding plotly.js
# (include_plotlyjs=False); the host page is expected to provide Plotly.
html_plot = pio.to_html(
    fig,
    include_plotlyjs=False,
    full_html=False,
    post_script=post_script,
    config=plot_config,
)
# Build a self-contained fragment with a live slider (the slider listens to
# 'input' events, so no mouseup is required) and a dataset selector.
# A short random suffix keeps element ids unique when several fragments are
# embedded in the same page.
uid = uuid.uuid4().hex[:8]
slider_id = f"line-ex-alpha-{uid}"
dataset_id = f"line-ex-dataset-{uid}"
container_id = f"line-ex-container-{uid}"

# HTML/JS template. Placeholders: __CID__ (container id), __SID__ (slider id),
# __DSID__ (dataset <select> id), __A0__ (initial alpha), __N__ (grid size),
# __PLOT__ (Plotly fragment). The JavaScript dataset table and <option> list
# duplicate datasets_params by hand and must be kept in sync with it.
slider_tpl = '''
<div id="__CID__">
__PLOT__
<div class="plotly_controls" style="margin-top:12px; display:flex; gap:16px; align-items:center;">
<label style="font-size:12px;color:rgba(0,0,0,.65); display:flex; align-items:center; gap:6px; white-space:nowrap; padding:6px 10px;">
Dataset
<select id="__DSID__" style="font-size:12px; padding:2px 6px;">
<option value="0">CIFAR-10</option>
<option value="1">CIFAR-100</option>
<option value="2">ImageNet-1K</option>
</select>
</label>
<label style="font-size:12px;color:rgba(0,0,0,.65);display:flex;align-items:center;gap:10px; flex:1; padding:6px 10px;">
Augmentation α
<input id="__SID__" type="range" min="0" max="1" step="0.01" value="__A0__" style="flex:1;">
<span class="alpha-value">__A0__</span>
</label>
</div>
</div>
<script>
(function(){
var container = document.getElementById('__CID__');
if(!container) return;
var gd = container.querySelector('.js-plotly-plot');
var slider = document.getElementById('__SID__');
var dsSelect = document.getElementById('__DSID__');
var valueEl = container.querySelector('.alpha-value');
var N = __N__;
var xs = Array.from({length: N}, function(_,i){ return i/(N-1); });
function logistic(x, ymin, ymax, k, x0){ return ymin + (ymax - ymin) / (1 + Math.exp(-k*(x - x0))); }
function blend(l,e,a){ return (1-a)*l + a*e; }
var datasets = [
{ name:'CIFAR-10', base:{ymin:0.10,ymax:0.90,k:10.0,x0:0.55}, aug:{ymin:0.15,ymax:0.96,k:12.0,x0:0.40}, target:0.97 },
{ name:'CIFAR-100', base:{ymin:0.05,ymax:0.70,k:9.5,x0:0.60}, aug:{ymin:0.08,ymax:0.80,k:11.0,x0:0.45}, target:0.85 },
{ name:'ImageNet-1K', base:{ymin:0.02,ymax:0.68,k:8.5,x0:0.65}, aug:{ymin:0.04,ymax:0.75,k:9.5,x0:0.50}, target:0.82 }
];
var dsi = 0;
var yb = xs.map(function(x){ return logistic(x, datasets[dsi].base.ymin, datasets[dsi].base.ymax, datasets[dsi].base.k, datasets[dsi].base.x0); });
var ya = xs.map(function(x){ return logistic(x, datasets[dsi].aug.ymin, datasets[dsi].aug.ymax, datasets[dsi].aug.k, datasets[dsi].aug.x0); });
var yt = xs.map(function(){ return datasets[dsi].target; });
function applyAlpha(a){
var yi = yb.map(function(v,i){ return blend(v, ya[i], a); });
Plotly.restyle(gd, {y:[yi]}, [1]); // only Improved changes with α
if(valueEl) valueEl.textContent = a.toFixed(2);
}
function applyDataset(){
var d = datasets[dsi];
yb = xs.map(function(x){ return logistic(x, d.base.ymin, d.base.ymax, d.base.k, d.base.x0); });
ya = xs.map(function(x){ return logistic(x, d.aug.ymin, d.aug.ymax, d.aug.k, d.aug.x0); });
yt = xs.map(function(){ return d.target; });
var a = parseFloat(slider.value)||0;
var yi = yb.map(function(v,i){ return blend(v, ya[i], a); });
Plotly.restyle(gd, {y:[yb]}, [0]); // Baseline
Plotly.restyle(gd, {y:[yi]}, [1]); // Improved (blended)
Plotly.restyle(gd, {y:[yt]}, [2]); // Target
}
var initA = parseFloat(slider.value)||0;
slider.addEventListener('input', function(e){ applyAlpha(parseFloat(e.target.value)||0); });
dsSelect.addEventListener('change', function(e){ dsi = parseInt(e.target.value)||0; applyDataset(); });
setTimeout(function(){ applyDataset(); applyAlpha(initA); }, 0);
})();
</script>
'''

# Fill in the placeholders. __DSID__ was previously left unsubstituted, so
# every generated fragment shared the literal id "__DSID__" for its <select>.
# __PLOT__ is substituted last so that text inside the Plotly output is never
# touched by the earlier replacements.
slider_html = (
    slider_tpl
    .replace('__CID__', container_id)
    .replace('__SID__', slider_id)
    .replace('__DSID__', dataset_id)
    .replace('__A0__', f"{alpha0:.2f}")
    .replace('__N__', str(N))
    .replace('__PLOT__', html_plot)
)
# Write the assembled fragment. The destination is relative to the current
# working directory; create its directory first so the write does not fail
# with FileNotFoundError on a fresh checkout.
fragment_path = "../../app/src/content/fragments/line.html"
os.makedirs(os.path.dirname(fragment_path), exist_ok=True)
with open(fragment_path, "w", encoding="utf-8") as f:
    f.write(slider_html)