Spaces:
Sleeping
Sleeping
| # app.py | |
| import json | |
| import threading | |
| import time | |
| from pathlib import Path | |
| import solara | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| # for robust hover/click from the browser | |
| import anywidget | |
| import traitlets as t | |
| import html # for escaping token text in the HTML label | |
# ---------- Model ----------
# Hugging Face checkpoint id (same as the working HF Space).
MODEL_ID = "Qwen/Qwen3-0.6B"

# Module-level tokenizer/model pair, created once when this file is imported.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
# ---------- Theme & layout (light blue / white / black accents) ----------
# Stylesheet injected by Page() via solara.Style.
# NOTE: --mono must be an *unquoted* comma-separated font list. Wrapping the
# whole stack in quotes turns it into a single (nonexistent) family name, so
# `font-family: var(--mono)` would silently fall back to the default font.
theme_css = """
:root{
  --primary:#38bdf8; --bg:#ffffff; --text:#0b0f14; --muted:#6b7280; --border:#e5e7eb;
  --mono:ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;
}
/* Base */
body{ background:var(--bg); color:var(--text); margin:0;}
h1{ margin:6px 0 8px; }
/* Two-column layout */
.app-row { display:flex; align-items:flex-start; gap:16px; } /* was 24px */
.predictions-panel { flex:0 0 320px; position:relative; z-index:10;}/* was 360px */
.plot-panel { flex:1 1 auto; position:relative; z-index:1; overflow:hidden; }
/* Prediction rows (tighter) */
.rowbtn{
  width:100%;
  padding:6px 10px; /* was 10px 12px */
  border-radius:10px; /* was 12px */
  border:1px solid var(--border);
  background:#fff; color:var(--text);
  display:flex; justify-content:flex-start; align-items:center;
  text-align:left; cursor:pointer; user-select:none;
  font-family: var(--mono);
  font-size:13px; /* was default ~14–16 */
  line-height:1.15;
  letter-spacing:.2px;
  margin-bottom:6px; /* explicit, keeps list consistent */
}
.rowbtn:hover{ background:#f7fbff; border-color:#c3e8fb; }
/* New: 4-column grid inside each row button */
.rowbtn-grid{
  display:grid;
  grid-template-columns: 28px 72px 72px 1fr; /* # | probs | tokenID | token */
  column-gap:8px;
  align-items:center;
  width:100%;
  font-family: var(--mono);
  font-size:13px;
  line-height:1.15;
}
/* Neighbor chips (smaller) */
.badge{
  display:inline-block; padding:2px 6px; /* was 2px 8px */
  border:1px solid var(--border); border-radius:999px; margin:2px;
  font-size:12px; line-height:1.15;
}
"""
# ---------- Reactive state ----------
# Prompt text; bound to the input box and extended by append_token().
text_rx = solara.reactive("Twinkle, twinkle, little ")

# Latest prediction table (columns: probs, id, tok); written by on_predict().
preds_rx = solara.reactive(pd.DataFrame(columns=["probs", "id", "tok"]))

# Token id currently highlighted on the map (None until the first preview).
selected_token_id_rx = solara.reactive(None)

# (token text, similarity) pairs rendered as neighbor chips.
neighbor_list_rx = solara.reactive([])

# Last id passed to preview_token(); used there to skip repeated previews.
last_hovered_id_rx = solara.reactive(None)

# When False, AutoPredictWatcher does not schedule predictions.
auto_running_rx = solara.reactive(True)

# Message shown when no neighborhood is available.
neigh_msg_rx = solara.reactive("")
# ---------- Embedding assets ----------
# Optional precomputed files: 2-D coordinates and nearest-neighbor lists.
ASSETS = Path("assets/embeddings")
COORDS_PATH = ASSETS / "pca_top5k_coords.json"
NEIGH_PATH = ASSETS / "neighbors_top5k_k40.json"

# Defaults used when the assets are absent; the app still runs without a map.
coords = {}      # token id (as str) -> coordinate pair
neighbors = {}   # parsed neighbor file; lists live under its "neighbors" key
ids_set = set()  # integer token ids that have coordinates

if COORDS_PATH.exists() and NEIGH_PATH.exists():
    coords = json.loads(COORDS_PATH.read_text("utf-8"))
    neighbors = json.loads(NEIGH_PATH.read_text("utf-8"))
    ids_set = set(map(int, coords))
else:
    # The original code called `notice_rx.set(...)`, but no `notice_rx` exists
    # in this module, so a missing asset raised NameError at import time.
    # Route the notice through the existing neighborhood-message reactive.
    neigh_msg_rx.set("Embedding files not found — add assets/embeddings/*.json to enable the map.")
| # ---------- Helpers ---------- | |
def display_token_from_id(tid: int) -> str:
    """Return a short, printable label for token id *tid*.

    The raw vocabulary piece is looked up through the module tokenizer, a
    leading ``▁`` / ``Ġ`` marker is removed, newlines are shown as ``↵``,
    and a label that is empty or whitespace-only is rendered as ``␠``.
    """
    pieces = tokenizer.convert_ids_to_tokens([int(tid)], skip_special_tokens=True)
    label = pieces[0] if pieces else ""

    # Drop each leading marker in turn, if present.
    for marker in ("▁", "Ġ"):
        label = label[len(marker):] if label.startswith(marker) else label

    label = label.replace("\n", "↵")
    if label.strip():
        return label
    return "␠"
def fmt_row(idx: int, prob: str, tid: int, tok_disp: str) -> str:
    """Format one prediction as a single line of left-aligned text columns.

    Columns, separated by one space: index (width 2), probability (width 7),
    token id (width 6), then the token text unpadded.
    """
    cells = (
        format(idx, "<2"),
        format(prob, "<7"),
        format(tid, "<6"),
        format(tok_disp),
    )
    return " ".join(cells)
| # ---------- Prediction ---------- | |
def predict_top10(prompt: str) -> pd.DataFrame:
    """Return the ten most probable next tokens for *prompt*.

    The result has columns ``probs`` (percentage strings such as ``"4.12%"``),
    ``id`` (token ids) and ``tok`` (decoded token text). An empty prompt
    yields an empty frame with the same columns.
    """
    if not prompt:
        return pd.DataFrame(columns=["probs", "id", "tok"])

    encoded = tokenizer(prompt, return_tensors="pt", padding=False)
    # Generate exactly one token and keep its scores.
    generation = model.generate(
        **encoded,
        max_new_tokens=1,
        output_scores=True,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=False,  # greedy; temp/top_k are ignored (by design)
    )

    distribution = torch.softmax(generation.scores[0], dim=-1)
    best = torch.topk(distribution, 10)
    # Row 0 is the only sequence in the batch.
    ids = best.indices[0].tolist()
    probs = best.values[0].tolist()
    toks = [tokenizer.decode([token_id]) for token_id in ids]  # for append

    table = pd.DataFrame({"probs": probs, "id": ids, "tok": toks})
    table["probs"] = table["probs"].map("{:.2%}".format)
    return table
def on_predict():
    """Recompute predictions for the current text and refresh the map.

    Publishes the new table to ``preds_rx``. If nothing has been selected yet
    the top prediction is previewed; otherwise the figure is redrawn for the
    existing selection.
    """
    table = predict_top10(text_rx.value)
    preds_rx.set(table)
    if not len(table):
        return

    current = selected_token_id_rx.value
    if current is not None:
        # Preserve the user's selection across re-predictions.
        fig_rx.set(highlight(int(current)))
        return

    # Only the first time: preview the most likely token.
    preview_token(int(table.iloc[0]["id"]))
| # ---------- Plot / neighborhood ---------- | |
def base_scatter():
    """Build the background figure: every mapped token as a faint dot.

    Without loaded coordinates the figure has no traces but still gets the
    shared layout (fixed height, white background, hidden axes, no legend).
    """
    figure = go.Figure()
    if coords:
        # Each value is an (x, y) pair; transpose into two sequences.
        xs, ys = zip(*coords.values())
        cloud = go.Scattergl(
            x=xs,
            y=ys,
            mode="markers",
            marker=dict(size=3, opacity=1.0, color="rgba(56,189,248,0.15)"),
            hoverinfo="skip",
        )
        figure.add_trace(cloud)

    figure.update_layout(
        height=380,
        margin=dict(l=6, r=6, t=6, b=6),
        paper_bgcolor="white",
        plot_bgcolor="white",
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        showlegend=False,
    )
    return figure


# Figure currently shown in the plot panel; starts as the plain background.
fig_rx = solara.reactive(base_scatter())
def get_neighbor_list(token_id: int, k: int = 20):
    """Return up to *k* stored neighbor entries for *token_id*.

    Yields an empty list when the token has no coordinates or no entry in
    the neighbor table.
    """
    known = ids_set and token_id in ids_set
    if not known:
        return []
    table = neighbors.get("neighbors", {})
    entries = table.get(str(token_id), [])
    return entries[:k]
def highlight(token_id: int):
    """Return a figure with *token_id* and its neighbors emphasized.

    Side effects: updates ``neighbor_list_rx`` (chips) and ``neigh_msg_rx``
    (explanatory message). When the token is not on the map, the plain
    background figure is returned and a message explains why.
    """
    figure = base_scatter()

    # Not in map (or missing map) -> clear chips and show message.
    on_map = coords and token_id in ids_set
    if not on_map:
        neighbor_list_rx.set([])
        neigh_msg_rx.set(
            "Neighborhood unavailable for this token (not in the top-5k set)."
            if coords
            else "Embedding map unavailable – add `assets/embeddings/*.json`."
        )
        return figure

    # In map -> clear message and draw neighbors/target.
    neigh_msg_rx.set("")
    nearby = get_neighbor_list(token_id, k=20)
    chips = []
    if nearby:
        points = [coords[str(nid)] for nid, _ in nearby]
        figure.add_trace(go.Scattergl(
            x=[point[0] for point in points],
            y=[point[1] for point in points],
            mode="markers",
            marker=dict(size=6, color="rgba(56,189,248,0.75)"),
            hoverinfo="skip",
        ))
        chips = [(display_token_from_id(int(nid)), float(sim)) for nid, sim in nearby]
    neighbor_list_rx.set(chips)

    # The selected token itself, drawn last.
    target_x, target_y = coords[str(token_id)]
    figure.add_trace(go.Scattergl(
        x=[target_x],
        y=[target_y],
        mode="markers",
        marker=dict(size=10, color="rgba(34,211,238,1.0)", line=dict(width=1)),
        hoverinfo="skip",
    ))
    return figure
def preview_token(token_id: int):
    """Select *token_id* and redraw the map for it.

    Does nothing when the id equals the last previewed one.
    """
    tid = int(token_id)
    if last_hovered_id_rx.value != tid:
        last_hovered_id_rx.set(tid)
        selected_token_id_rx.set(tid)
        fig_rx.set(highlight(tid))
def append_token(token_id: int):
    """Append the decoded token to the prompt, preview it, and re-predict."""
    tid = int(token_id)
    piece = tokenizer.decode([tid])
    text_rx.set(text_rx.value + piece)
    preview_token(tid)
    on_predict()
| # ---------- Debounced auto-predict ---------- | |
@solara.component
def AutoPredictWatcher():
    """Invisible component that re-runs prediction shortly after typing stops.

    Declared as a component because it calls the ``use_effect`` hook and
    returns an element. Reading ``text_rx`` and ``auto_running_rx`` here
    subscribes this component to both, so the effect is re-armed on every
    change.
    """
    text = text_rx.value
    auto = auto_running_rx.value

    def effect():
        if not auto:
            return None

        cancelled = False
        snap = text  # the text this run of the effect was scheduled for

        def worker():
            # Debounce: wait, then predict only if this effect was not
            # cleaned up and the text has not changed in the meantime.
            time.sleep(0.25)
            if not cancelled and snap == text_rx.value:
                on_predict()

        threading.Thread(target=worker, daemon=True).start()

        def cleanup():
            nonlocal cancelled
            cancelled = True

        return cleanup

    solara.use_effect(effect, [text, auto])
    return solara.Text("", style={"display": "none"})
| # ---------- Hover-enabled list (browser) ---------- | |
class HoverList(anywidget.AnyWidget):
    """
    Renders the prediction rows in the browser and streams hover/click events
    back to Python via synced traitlets. Supports HTML row labels via `label_html`.

    Traits:
        items: rows to draw, ``[{tid:int, label?:str, label_html?:str}, ...]``.
        hovered_id: ``tid`` of the row most recently hovered or focused.
        clicked_id: ``tid`` of the row most recently clicked.
    """

    _esm = """
    export function render({ model, el }) {
      const renderList = () => {
        const items = model.get('items') || [];
        el.innerHTML = "";
        const wrap = document.createElement('div');
        wrap.style.display = 'flex';
        wrap.style.flexDirection = 'column';
        items.forEach((item) => {
          const { tid, label, label_html } = item;
          const btn = document.createElement('button');
          btn.className = 'rowbtn';
          btn.setAttribute('type', 'button');
          btn.setAttribute('role', 'button');
          btn.setAttribute('tabindex', '0');
          // Prefer HTML layout if provided; fall back to plain text
          if (label_html) { btn.innerHTML = label_html; }
          else { btn.textContent = label || ""; }
          // Hover → preview (bind several events for reliability)
          const preview = () => {
            model.set('hovered_id', tid);
            model.save_changes();
          };
          btn.addEventListener('mouseenter', preview);
          btn.addEventListener('mouseover', preview);
          btn.addEventListener('mousemove', preview);
          btn.addEventListener('focus', preview);
          // Click → append
          btn.addEventListener('click', () => {
            model.set('clicked_id', tid);
            model.save_changes();
          });
          wrap.appendChild(btn);
        });
        el.appendChild(wrap);
      };
      renderList();
      model.on('change:items', renderList);
      // Cleanup: unsubscribe so a removed view stops reacting to item changes.
      return () => model.off('change:items', renderList);
    }
    """

    # [{tid:int, label?:str, label_html?:str}, ...]
    items = t.List(trait=t.Dict()).tag(sync=True)

    # Explicit None default: a bare Int() starts at 0 even with allow_none=True,
    # and 0 can be a valid token id, so a first hover/click on id 0 would not
    # register as a change and the Python observers would never fire.
    hovered_id = t.Int(None, allow_none=True).tag(sync=True)
    clicked_id = t.Int(None, allow_none=True).tag(sync=True)
| # ---------- Predictions list (uses HoverList) ---------- | |
@solara.component
def PredictionsList():
    """Render the prediction table as a hoverable, clickable list.

    Declared as a component so that reading ``preds_rx`` subscribes this
    list (rather than its caller) to prediction updates. Hovering a row
    previews that token on the map; clicking a row appends it to the prompt.
    """
    df = preds_rx.value
    with solara.Column(gap="6px", style={"maxWidth": "720px"}):
        solara.Markdown("### Prediction")
        solara.Text(
            " # probs tokenID next predicted",
            style={
                "color": "var(--muted)",
                "fontFamily": 'ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace',
            },
        )

        # Build items for the browser widget.
        items = []
        for i, row in df.iterrows():
            tid = int(row["id"])
            prob = row["probs"]  # already a formatted string like "4.12%"
            tok_safe = html.escape(display_token_from_id(tid))  # protect the HTML label
            label_html = (
                f'<div class="rowbtn-grid">'
                f' <span class="c0">{i}</span>'
                f' <span class="c1">{prob}</span>'
                f' <span class="c2">{tid}</span>'
                f' <span class="c3">{tok_safe}</span>'
                f'</div>'
            )
            items.append({"tid": tid, "label_html": label_html})

        # NOTE(review): a fresh widget is constructed on every render of this
        # component; confirm that is acceptable for this app's lifetime.
        w = HoverList()
        w.items = items

        # Hover → preview (updates plot + neighbor chips)
        def _on_hover(change):
            tid = change["new"]
            if tid is not None:
                preview_token(int(tid))

        w.observe(_on_hover, names="hovered_id")

        # Click → append
        def _on_click(change):
            tid = change["new"]
            if tid is not None:
                append_token(int(tid))

        w.observe(_on_click, names="clicked_id")
        solara.display(w)
| # ---------- Page ---------- | |
@solara.component
def Page():
    """Top-level page: prompt input, prediction list, and neighborhood map.

    Layout is two columns (predictions | plot). Below the plot, either the
    nearest-neighbor chips or an explanatory message is shown. The hidden
    AutoPredictWatcher keeps predictions in sync with the text box.
    """
    solara.Style(theme_css)
    with solara.Column(margin=8, gap="10px"):
        solara.Markdown("# Next-Token Predictor + Semantic Neighborhood")
        solara.Markdown(
            "Type text to see AI's top predictions for the next token. "
            "Click a token to append it to your text. "
            "Hover over a token to preview its **semantic neighborhood**."
        )
        solara.InputText("Enter text", value=text_rx, continuous_update=True, style={"minWidth": "520px"})
        with solara.Row(classes=["app-row"]):
            with solara.Column(classes=["predictions-panel"]):
                PredictionsList()
            with solara.Column(classes=["plot-panel"]):
                solara.Markdown("### Semantic Neighborhood")
                if not coords:
                    solara.Markdown("> Embedding map unavailable – add `assets/embeddings/*.json`.")
                else:
                    solara.FigurePlotly(fig_rx.value)
                if neighbor_list_rx.value:
                    solara.Markdown("**Nearest neighbors:**")
                    with solara.Row(style={"flex-wrap": "wrap"}):
                        for tok, sim in neighbor_list_rx.value:
                            # Token text is arbitrary vocabulary content (may
                            # contain <, >, &), so escape it before it is
                            # injected as raw HTML.
                            solara.HTML(
                                tag="span",
                                unsafe_innerHTML=f'<span class="badge">{html.escape(tok)} {(sim*100):.1f}%</span>'
                            )
                elif neigh_msg_rx.value:
                    solara.Text(neigh_msg_rx.value, style={"color": "var(--muted)"})
        AutoPredictWatcher()
# ---------- Kickoff ----------
# Run one prediction at import so preds_rx is populated for the initial text.
on_predict()
# Invoke the page once at module level; the return value is not kept.
Page()