# app.py
"""
Hugging Face Space: MoE Expert Routing Visualizer (Gradio)
----------------------------------------------------------
This Space lets a user:
- Choose a model (from a dropdown or a free-text box)
- Enter a user prompt, and optionally an assistant prompt
- Call a backend function that returns 4 routing percentages (Language, Logic, Social, World)
- See a bar plot + table of the percentages
🧩 Plug your real routing function in router_backend.py -> get_expert_routing().
By default, a deterministic "mock mode" produces stable pseudo-random percentages from the prompt.
"""
import os
import hashlib
from typing import Dict, List, Tuple, Union
import gradio as gr
import plotly
import plotly.express as px
import pandas as pd
from router_backend import get_expert_routing
# ---- Expected backend adapter ------------------------------------------------
# Implement your real function in router_backend.py with the following signature
# (matching the call in route_and_plot below):
#   def get_expert_routing(model_id: str, hf_token: str, prompt, ablations) -> Tuple[routing, generation]
# `routing` MUST hold 4 values that sum to ~100 (percentages) in the fixed order
#   ["Language", "Logic", "Social", "World"]
# (a list/tuple, or a mapping with those keys); `generation` is the generated
# text, or None to echo the assistant prompt back.
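# A minimal, hypothetical sketch of such an adapter (illustration only, not the
# real MiCRo backend; the hash-based scores below are made-up stand-ins for the
# model's actual router statistics):
#
#   import hashlib
#
#   def get_expert_routing(model_id, hf_token, prompt, ablations):
#       experts = ["Language", "Logic", "Social", "World"]
#       active = [e for e in experts if e.lower() not in ablations]
#       h = hashlib.sha256(f"{model_id}||{prompt}".encode()).digest()
#       scores = {e: h[i] + 1 for i, e in enumerate(active)}   # ablated experts get no mass
#       total = sum(scores.values())
#       routing = {e: 100.0 * scores.get(e, 0) / total for e in experts}
#       return routing, None  # (percentages per expert, generated text)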
EXPERTS = ["Language", "Logic", "Social", "World"]
DEFAULT_MODELS = [
"micro-smollm2-135m",
"micro-smollm2-360m",
"micro-llama-1b",
"micro-llama-3b",
"micro-llama-1b-dpo",
"micro-moe-smollm2-135m",
"micro-moe-smollm2-360m",
"micro-moe-llama-1b",
]
def _mock_routing(model_id: str, prompt: str, seed: int = 0) -> List[float]:
"""
Deterministic mock routing percentages based on model_id + prompt + seed.
Returns a list of 4 percentages summing to 100.0
"""
h = hashlib.sha256(f"{model_id}||{prompt}||{seed}".encode()).digest()
# split into 4 positive numbers
vals = [int.from_bytes(h[i*8:(i+1)*8], "little") % 10_000 + 1 for i in range(4)]
s = sum(vals)
return [100.0 * v / s for v in vals]
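# Example usage (hypothetical prompt; output is deterministic and sums to ~100):
#   a = _mock_routing("micro-llama-1b", "hello world", seed=0)
#   b = _mock_routing("micro-llama-1b", "hello world", seed=0)
#   assert a == b and abs(sum(a) - 100.0) < 1e-6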
def _normalize_output(r: Union[List[float], Tuple[float, float, float, float], Dict[str, float]]) -> List[float]:
"""
Normalize different return types into a 4-length list ordered as EXPERTS.
"""
if isinstance(r, dict):
vals = [float(r.get(k, 0.0)) for k in EXPERTS]
else:
vals = [float(x) for x in list(r)]
if len(vals) != 4:
raise ValueError(f"Expected 4 values, got {len(vals)}.")
# renormalize to 100 if needed
s = sum(vals)
if s <= 0:
raise ValueError("Sum of routing percentages is non-positive.")
vals = [100.0 * v / s for v in vals]
return vals
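# Examples (made-up numbers, just to show the accepted shapes):
#   _normalize_output({"Language": 2.0, "Logic": 1.0, "Social": 1.0, "World": 0.0})
#       -> [50.0, 25.0, 25.0, 0.0]
#   _normalize_output([10, 10, 10, 10]) -> [25.0, 25.0, 25.0, 25.0]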
def _compose_prompt(user_prompt: str, assistant_prompt: str) -> Union[str, List[Dict[str, str]]]:
    """
    Return the bare user prompt, or a chat-format message list when an
    assistant prompt is also provided.
    """
    user_prompt = (user_prompt or "").strip()
    assistant_prompt = (assistant_prompt or "").strip()
    if assistant_prompt:
        return [{"role": "user", "content": user_prompt}, {"role": "assistant", "content": assistant_prompt}]
    return user_prompt
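# Example return forms (illustrative):
#   _compose_prompt("Hi", "")      -> "Hi"
#   _compose_prompt("Hi", "Hello") -> [{"role": "user", "content": "Hi"},
#                                      {"role": "assistant", "content": "Hello"}]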
def route_and_plot(
model_choice: str,
user_prompt: str,
assistant_prompt: str,
ablate_language: bool,
ablate_logic: bool,
ablate_social: bool,
ablate_world: bool,
) -> Tuple[str, pd.DataFrame, "plotly.graph_objs._figure.Figure", str]:
    """
    Main pipeline:
    - Compose the prompt (user + optional assistant)
    - Call the backend (real or mock)
    - Return the generated text, a table, a bar plot, and a status message
    """
hf_token = os.getenv("HF_TOKEN")
ablations = []
if ablate_language:
ablations.append("language")
if ablate_logic:
ablations.append("logic")
if ablate_social:
ablations.append("social")
if ablate_world:
ablations.append("world")
    seed = 42
    use_mock = False

    model_id = model_choice.strip()
    if not model_id:
        raise gr.Error("Please select a model or enter a custom model id.")
    prompt = _compose_prompt(user_prompt, assistant_prompt)
    if not prompt:
        raise gr.Error("Please enter a prompt.")

    if len(ablations) == 4:
        # Ablating all four experts leaves nothing to route to; fall back to mock data.
        msg = "Error: you can't ablate all four experts.<br>Falling back to mock data."
        generation = None
        vals = _mock_routing(model_id, prompt, seed=seed)
    elif use_mock:
        msg = "Using mock data."
        vals = _mock_routing(model_id, prompt, seed=seed)
        generation = None
    else:
        try:
            raw, generation = get_expert_routing(model_id, hf_token, prompt, ablations)  # <-- your real function
            vals = _normalize_output(raw)
            msg = "Routed with real backend."
        except Exception as e:
            # Fall back to mock data on error, but surface the message.
            msg = f"Backend error: {e}<br>Falling back to mock data."
            vals = _mock_routing(model_id, prompt, seed=seed)
            generation = None
df = pd.DataFrame({"Expert": EXPERTS, "Percent": vals})
colors = ["#97D077", "#4285F4", "#FFAB40", "#A64D79"]
fig = px.bar(df, x="Expert", y="Percent", title="Token Routing by Expert (%)", text="Percent")
    fig.update_traces(marker_color=colors, texttemplate="%{text:.2f}%", textposition="outside")
fig.update_layout(yaxis_range=[0, max(100, max(vals) * 1.25)], bargap=0.35)
status = f"Model: {model_id}<br>{msg}"
if generation is None:
generation = assistant_prompt
return generation, df, fig, status
with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
gr.Markdown(
"""
# 🧠 Mixture of Cognitive Reasoner (MiCRo) Expert Routing Visualizer
## Enter a prompt (and optionally an assistant reply), pick a model, and visualize how tokens were routed across experts.
Paper: [Mixture of Cognitive Reasoners: Modular Reasoning with Brain-Like Specialization](https://arxiv.org/abs/2506.13331)
----
This demo visualizes how modular language models allocate computation across specialized experts—Language, Logic, Social, and World—when processing a given prompt.
Each expert corresponds to a cognitive domain inspired by human brain networks. Enter a prompt to see how tokens are dynamically routed across modules, revealing the model's internal reasoning structure.
""".strip()
)
with gr.Row():
model_choice = gr.Dropdown(choices=DEFAULT_MODELS, label="Select a model", value=DEFAULT_MODELS[0])
# hf_token = gr.Textbox(label="Huggingface token for authentication", placeholder="Required for Llama-based models", lines=1)
with gr.Column():
with gr.Row():
gr.Markdown(
"""
#### Ablate Experts
(Check to disable an expert; the routing percentages will be redistributed among the remaining experts)
""", label="Ablate Experts"
)
with gr.Row():
ablate_language = gr.Checkbox(value=False, label="Language Expert")
ablate_logic = gr.Checkbox(value=False, label="Logic Expert")
ablate_social = gr.Checkbox(value=False, label="Social Expert")
ablate_world = gr.Checkbox(value=False, label="World Expert")
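            # Illustration (assuming routing mass is renormalized proportionally over
            # the remaining experts): if the un-ablated split were [40, 30, 20, 10] and
            # Language is ablated, the other three would rescale to 50.0, 33.3 and 16.7.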
with gr.Row():
user_prompt = gr.Textbox(lines=6, label="User prompt", placeholder="Type the user message here...")
assistant_prompt = gr.Textbox(lines=6, label="Assistant prompt (optional)", placeholder="Type the assistant message here (optional)...")
# with gr.Row():
# use_mock = gr.Checkbox(value=True, label="Use mock data (uncheck to call your backend)")
# seed = gr.Slider(value=0, minimum=0, maximum=10_000, step=1, label="Mock seed")
run = gr.Button("Run Routing", variant="primary")
generation_output = gr.Textbox(lines=4, label="Generated Response", placeholder="Generated text will appear here...", interactive=False)
with gr.Row():
table = gr.Dataframe(label="Routing Percentages", interactive=False)
plot = gr.Plot(label="Bar Plot")
status = gr.Markdown("", label="System Message")
run.click(
route_and_plot,
inputs=[model_choice, user_prompt, assistant_prompt, ablate_language, ablate_logic, ablate_social, ablate_world],
outputs=[generation_output, table, plot, status],
)
# example prompts
examples = [
[
"micro-llama-1b", # dropdown model
"Correct the grammar: \"She go to the park every morning.\"", # user prompt
"She goes to the park every morning.", # assistant prompt (empty)
False, False, False, False # no ablations
],
[
"micro-llama-1b", # dropdown model
"What is 27 multiplied by 14?", # user prompt
"First, break it down: 27 * 10 = 270. Then 27 * 4 = 108. Add them together: 270 + 108 = 378. So the answer is 378.", # assistant prompt (empty)
False, False, False, False # no ablations
],
[
"micro-llama-1b", # dropdown model
"Why did Sarah look away when John asked if she was okay?", # user prompt
"Because she didn't want him to see that she was upset.", # assistant prompt (empty)
False, False, False, False # no ablations
],
[
"micro-llama-1b", # dropdown model
"Why do people usually eat breakfast in the morning?", # user prompt
"Because after sleeping, the body needs energy to start the day.", # assistant prompt (empty)
False, False, False, False # no ablations
],
]
gr.Examples(
examples=examples,
inputs=[model_choice, user_prompt, assistant_prompt, ablate_language, ablate_logic, ablate_social, ablate_world],
label="Try these examples:",
cache_examples=True,
fn=route_and_plot,
outputs=[generation_output, table, plot, status],
)
if __name__ == "__main__":
demo.launch()