import gradio as gr
from typing import Dict, Any, Optional, List, Tuple, Union
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from transformers import AutoConfig
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class PlotlyModelArchitectureVisualizer:
    def __init__(self, hf_token: Optional[str] = None):
        self.config = None
        self.hf_token = hf_token
        # Universal color scheme - consistent across all models
        self.universal_colors = {
            'embedding': '#f8f9fa',         # Light gray for embeddings
            'layer_norm': '#e9ecef',        # Light gray for layer norms
            'attention': '#495057',         # Dark gray for attention
            'output': '#f8f9fa',            # Light gray for output layers
            'text': '#212529',              # Dark text
            'container_outer': '#dee2e6',   # Outer container
            'moe_inner': '#d4edda',         # Green background for MoE models
            'dense_inner': '#f8d7da',       # Red background for dense models
            'feedforward_moe': '#28a745',   # Green for MoE FFN
            'feedforward_dense': '#dc3545', # Red for dense FFN
            'router': '#fd7e14',            # Orange for router
            'expert': '#20c997',            # Teal for experts
            'callout_bg': 'rgba(255,255,255,0.9)',
            'accent_blue': '#007bff',
            'accent_green': '#28a745',
            'accent_red': '#dc3545'
        }
    def get_model_config(self, model_name: str) -> Dict[str, Any]:
        """Fetch model configuration from Hugging Face"""
        try:
            logger.info(f"Fetching config for {model_name}")
            config = AutoConfig.from_pretrained(model_name, token=self.hf_token, trust_remote_code=True)
            return config.to_dict()
        except Exception as e:
            logger.error(f"Error fetching config for {model_name}: {e}")
            return {}
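
    # Illustrative usage sketch (requires network access to the Hugging Face
    # Hub; the model id is just one of this app's defaults):
    #   viz = PlotlyModelArchitectureVisualizer()
    #   cfg = viz.get_model_config("openai/gpt-oss-120b")
    #   # `cfg` is a plain dict, e.g. cfg.get("model_type"), cfg.get("hidden_size")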
    def extract_config_values(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Extract and normalize configuration values with architecture detection"""
        # Detect model architecture type
        model_type = config.get('model_type', 'unknown').lower()
        is_moe = any(key in config for key in [
            'num_experts', 'n_routed_experts', 'moe_intermediate_size',
            'num_experts_per_tok', 'router_aux_loss_coef'
        ])

        # Extract MoE-specific parameters
        moe_params = {}
        if is_moe:
            moe_params = {
                'num_experts': config.get('num_experts', config.get('n_routed_experts', 'N/A')),
                'experts_per_token': config.get('num_experts_per_tok', 'N/A'),
                'moe_intermediate_size': config.get('moe_intermediate_size', 'N/A'),
                'router_aux_loss': config.get('router_aux_loss_coef', config.get('aux_loss_alpha', 'N/A')),
                'shared_experts': config.get('n_shared_experts', 0)
            }

        # Calculate model size estimate (simplified)
        hidden_size = config.get('hidden_size', config.get('d_model', config.get('n_embd', 0)))
        num_layers = config.get('num_hidden_layers', config.get('n_layer', config.get('num_layers', 0)))
        vocab_size = config.get('vocab_size', 0)
        if isinstance(hidden_size, int) and isinstance(num_layers, int) and isinstance(vocab_size, int):
            # Very rough heuristic; the same formula is used for MoE and dense
            # models, so MoE totals are understated relative to their full size.
            estimated_params = (hidden_size * num_layers * vocab_size) // 1_000_000
            size_suffix = "B" if estimated_params > 1000 else "M"
            estimated_params = estimated_params // 1000 if estimated_params > 1000 else estimated_params
        else:
            estimated_params = "Unknown"
            size_suffix = ""

        return {
            'model_type': model_type,  # reuse the normalized value from above
            'hidden_size': hidden_size if hidden_size != 0 else 'N/A',
            'num_layers': num_layers if num_layers != 0 else 'N/A',
            'num_heads': config.get('num_attention_heads', config.get('n_head', config.get('num_heads', 'N/A'))),
            'vocab_size': vocab_size if vocab_size != 0 else 'N/A',
            'max_position': config.get('max_position_embeddings',
                                       config.get('n_positions', config.get('max_seq_len', 'N/A'))),
            'intermediate_size': config.get('intermediate_size',
                                            config.get('d_ff', hidden_size if hidden_size != 0 else 'N/A')),
            'is_moe': is_moe,
            'moe_params': moe_params,
            'estimated_size': f"{estimated_params}{size_suffix}" if estimated_params != "Unknown" else "Unknown",
            'kv_heads': config.get('num_key_value_heads', config.get('num_heads', 'N/A')),
            'head_dim': config.get('head_dim', config.get('qk_nope_head_dim', 'N/A')),
            'activation': config.get('hidden_act', config.get('activation_function', 'N/A'))
        }
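
    # Worked example of the size heuristic above (numbers are illustrative,
    # not fetched from any real config): hidden_size=4096, num_layers=32,
    # vocab_size=128256 gives (4096 * 32 * 128256) // 1_000_000 = 16810,
    # which exceeds 1000, so it is reported as "16B".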
    def add_container(self, fig: go.Figure, x: float, y: float, width: float, height: float,
                      color: str, line_width: int = 1, row: int = 1, col: int = 1) -> None:
        """Add a container/background box"""
        fig.add_shape(
            type="rect",
            x0=x, y0=y, x1=x + width, y1=y + height,
            fillcolor=color,
            line=dict(color='black', width=line_width),
            layer="below",
            row=row, col=col
        )
    def add_layer_box(self, fig: go.Figure, x: float, y: float, width: float, height: float,
                      text: str, color: str, hover_text: Optional[str] = None, row: int = 1, col: int = 1,
                      text_size: int = 7) -> None:
        """Add a rounded rectangle representing a layer"""
        # Add the box shape
        fig.add_shape(
            type="rect",
            x0=x, y0=y, x1=x + width, y1=y + height,
            fillcolor=color,
            line=dict(color='black', width=1),
            layer="below",
            row=row, col=col
        )
        # Add text label
        fig.add_annotation(
            x=x + width / 2,
            y=y + height / 2,
            text=text,
            showarrow=False,
            font=dict(size=text_size, color=self.universal_colors['text']),
            bgcolor=self.universal_colors['callout_bg'],
            bordercolor="black",
            borderwidth=1,
            row=row, col=col
        )
        # Add invisible scatter point for hover functionality
        if hover_text:
            fig.add_trace(go.Scatter(
                x=[x + width / 2],
                y=[y + height / 2],
                mode='markers',
                marker=dict(size=12, opacity=0),
                hovertemplate=f"<b>{text}</b><br>{hover_text}<extra></extra>",
                showlegend=False,
                name=text
            ), row=row, col=col)
    def add_moe_router_visualization(self, fig: go.Figure, x: float, y: float,
                                     config_values: Dict[str, Any], row: int = 1, col: int = 1) -> None:
        """Add MoE router and expert visualization with improved layout"""
        moe_params = config_values['moe_params']
        experts_per_token = moe_params['experts_per_token']

        # Router box - positioned more centrally
        router_width, router_height = 0.4, 0.12
        router_x = x + 0.2  # Center it better within the available space
        self.add_layer_box(
            fig, router_x, y, router_width, router_height,
            "Router", self.universal_colors['router'],
            f"{experts_per_token} experts activated<br>from {moe_params['num_experts']} total",
            row, col, 6
        )

        # Expert boxes - positioned with better spacing
        expert_y = y - 0.25  # Closer to router
        expert_width, expert_height = 0.18, 0.1
        experts_to_show = min(3, experts_per_token) if isinstance(experts_per_token, int) else 3

        # Center the experts under the router
        total_expert_width = experts_to_show * expert_width + (experts_to_show - 1) * 0.04
        experts_start_x = router_x + (router_width - total_expert_width) / 2
        for i in range(experts_to_show):
            expert_x = experts_start_x + i * (expert_width + 0.04)
            self.add_layer_box(
                fig, expert_x, expert_y, expert_width, expert_height,
                f"Expert<br>{i + 1}", self.universal_colors['expert'],  # Plotly annotations need <br>, not \n
                f"MoE intermediate size: {moe_params['moe_intermediate_size']}",
                row, col, 5
            )
            # Arrow from router to expert - pointing downward
            self.add_connection_arrow(
                fig, router_x + router_width / 2, y,
                expert_x + expert_width / 2, expert_y + expert_height, row, col
            )

        # Add "..." if more experts exist - positioned to the right
        if isinstance(experts_per_token, int) and experts_to_show < experts_per_token:
            fig.add_annotation(
                x=experts_start_x + experts_to_show * (expert_width + 0.04) + 0.05,
                y=expert_y + expert_height / 2,
                text="...",
                showarrow=False,
                font=dict(size=8, color=self.universal_colors['text']),
                row=row, col=col
            )
    def add_side_panel(self, fig: go.Figure, x: float, y: float, width: float, height: float,
                       title: str, components: List[str], config_values: Dict[str, Any],
                       row: int = 1, col: int = 1) -> None:
        """Add a side panel with component breakdown"""
        # Panel container with dashed border
        fig.add_shape(
            type="rect",
            x0=x, y0=y, x1=x + width, y1=y + height,
            fillcolor=self.universal_colors['callout_bg'],
            line=dict(color='gray', width=1, dash='dash'),
            layer="below",
            row=row, col=col
        )
        # Panel title
        fig.add_annotation(
            x=x + width / 2, y=y + height - 0.08,
            text=f"<b>{title}</b>",
            showarrow=False,
            font=dict(size=8, color=self.universal_colors['text']),
            row=row, col=col
        )
        # Component boxes
        component_height = 0.1
        start_y = y + height - 0.2
        for i, component in enumerate(components):
            comp_y = start_y - i * (component_height + 0.03)
            if "Linear" in component:
                color = self.universal_colors['output']
            elif "activation" in component.lower() or "SiLU" in component or "ReLU" in component:
                color = (self.universal_colors['feedforward_moe'] if config_values['is_moe']
                         else self.universal_colors['feedforward_dense'])
            else:
                color = self.universal_colors['embedding']
            self.add_layer_box(
                fig, x + 0.03, comp_y, width - 0.06, component_height,
                component, color, None, row, col, 6
            )
    def add_connection_arrow(self, fig: go.Figure, start_x: float, start_y: float,
                             end_x: float, end_y: float, row: int = 1, col: int = 1) -> None:
        """Add an arrow between layers"""
        # In this app's single-row make_subplots grid, subplot (1, c) uses axes
        # ('x{c}', 'y{c}'), so both x and y refs are indexed by `col` here.
        fig.add_annotation(
            x=end_x, y=end_y,
            ax=start_x, ay=start_y,
            xref=f'x{col}' if col > 1 else 'x',
            yref=f'y{col}' if col > 1 else 'y',
            axref=f'x{col}' if col > 1 else 'x',
            ayref=f'y{col}' if col > 1 else 'y',
            showarrow=True,
            arrowhead=2,
            arrowsize=1,
            arrowwidth=1.5,
            arrowcolor='black'
        )
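
    # Minimal sketch of the axis-naming convention assumed above (standalone,
    # illustrative): with make_subplots(rows=1, cols=3), Plotly names the
    # subplot axes 'x'/'y', 'x2'/'y2', 'x3'/'y3', which is why annotations
    # target f'x{col}' / f'y{col}' when col > 1.
    #
    #   from plotly.subplots import make_subplots
    #   _fig = make_subplots(rows=1, cols=3)
    #   _fig.add_annotation(x=1, y=1, text="hi", xref='x2', yref='y2')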
    def create_single_model_diagram(self, fig: go.Figure, model_name: str,
                                    config_values: Dict[str, Any], row: int = 1, col: int = 1) -> None:
        """Add a single model's architecture to the subplot with improved layout"""
        # Layout parameters
        base_x, base_y = 0.3, 0.2
        main_width, main_height = 2.2, 2.8
        layer_width, layer_height = 1.8, 0.2

        # Model title with size
        model_display_name = model_name.split('/')[-1] if '/' in model_name else model_name
        title_text = f"<b>{model_display_name}</b>"
        if config_values['estimated_size'] != "Unknown":
            title_text += f" ({config_values['estimated_size']})"
        fig.add_annotation(
            x=base_x + main_width / 2, y=base_y + main_height + 0.2,
            text=title_text,
            showarrow=False,
            font=dict(size=10, color=self.universal_colors['accent_blue']),
            row=row, col=col
        )

        # Outer container (main frame)
        self.add_container(
            fig, base_x - 0.1, base_y - 0.1, main_width + 0.2, main_height + 0.2,
            self.universal_colors['container_outer'], 2, row, col
        )

        # Inner container (colored by architecture type)
        inner_color = (self.universal_colors['moe_inner'] if config_values['is_moe']
                       else self.universal_colors['dense_inner'])
        self.add_container(
            fig, base_x + 0.1, base_y + 0.8, main_width - 0.2, main_height - 1.2,
            inner_color, 1, row, col
        )
        # Layer definitions with universal colors
        # Guard the thousands-separator format: 'vocab_size' may be the string
        # 'N/A' when the config omits it, and str does not support ':,'.
        vocab_size = config_values['vocab_size']
        vocab_display = f"{vocab_size:,}" if isinstance(vocab_size, int) else vocab_size
        layers = [
            ('Token Embedding', base_y + 0.3, self.universal_colors['embedding'],
             f"Vocab: {vocab_display}<br>Embedding dim: {config_values['hidden_size']}"),
            ('Layer Norm', base_y + 0.6, self.universal_colors['layer_norm'],
             'Input normalization'),
            (f'Multi-Head Attention<br>({config_values["num_heads"]} heads)',
             base_y + 0.9, self.universal_colors['attention'],
             f"Heads: {config_values['num_heads']}<br>Hidden: {config_values['hidden_size']}<br>KV Heads: {config_values['kv_heads']}"),
            ('Layer Norm', base_y + 1.2, self.universal_colors['layer_norm'],
             'Post-attention norm'),
        ]

        # Add MoE or Dense FFN layer
        if config_values['is_moe']:
            layers.append((
                'MoE Feed Forward',
                base_y + 1.5, self.universal_colors['feedforward_moe'],
                f"Experts: {config_values['moe_params']['num_experts']}<br>"
                f"Active per token: {config_values['moe_params']['experts_per_token']}<br>"
                f"MoE intermediate: {config_values['moe_params']['moe_intermediate_size']}"
            ))
        else:
            layers.append((
                'Feed Forward Network',
                base_y + 1.5, self.universal_colors['feedforward_dense'],
                f"Intermediate size: {config_values['intermediate_size']}<br>Activation: {config_values['activation']}"
            ))
        layers.extend([
            ('Layer Norm', base_y + 1.8, self.universal_colors['layer_norm'],
             'Post-FFN normalization'),
            ('Output Projection', base_y + 2.1, self.universal_colors['output'],
             f"Projects to vocab: {vocab_display}")
        ])
        # Add all layers
        layer_centers = []
        for layer_name, y_pos, color, hover_info in layers:
            layer_x = base_x + (main_width - layer_width) / 2
            self.add_layer_box(
                fig, layer_x, y_pos, layer_width, layer_height,
                layer_name, color, hover_info, row, col
            )
            layer_centers.append((layer_x + layer_width / 2, y_pos + layer_height / 2))

        # Add arrows between layers
        for i in range(len(layer_centers) - 1):
            start_x, start_y = layer_centers[i]
            end_x, end_y = layer_centers[i + 1]
            arrow_start_y = start_y + layer_height / 2
            arrow_end_y = end_y - layer_height / 2
            if arrow_end_y > arrow_start_y:
                self.add_connection_arrow(fig, start_x, arrow_start_y, end_x, arrow_end_y, row, col)

        # Add layer repetition indicator
        if isinstance(config_values['num_layers'], int) and config_values['num_layers'] > 1:
            fig.add_annotation(
                x=base_x - 0.05, y=base_y + 1.4,
                text=f"×{config_values['num_layers']}<br>layers",
                showarrow=False,
                font=dict(size=7, color=self.universal_colors['text']),
                bgcolor=self.universal_colors['callout_bg'],
                bordercolor="black", borderwidth=1,
                row=row, col=col
            )
        # Add side panel for component details
        panel_x = base_x + main_width + 0.3
        panel_y = base_y + 1.5  # Moved up to avoid MoE visualization
        panel_width = 0.7
        panel_height = 0.8
        if config_values['is_moe']:
            # MoE side panel
            components = [
                "Linear layer",
                f"{config_values['activation'].upper()} activation",
                "Linear layer",
                "Router",
                f"{config_values['moe_params']['experts_per_token']} active experts"
            ]
            panel_title = "MoE Module"
        else:
            # Dense FFN side panel
            components = [
                "Linear layer",
                f"{config_values['activation'].upper()} activation",
                "Linear layer"
            ]
            panel_title = "FeedForward Module"
        self.add_side_panel(fig, panel_x, panel_y, panel_width, panel_height,
                            panel_title, components, config_values, row, col)

        # Add MoE router visualization if applicable
        if config_values['is_moe']:
            # Position router visualization below side panel with better spacing
            router_x = panel_x + 0.05
            router_y = panel_y - 0.5
            self.add_moe_router_visualization(fig, router_x, router_y, config_values, row, col)
    def add_callout(self, fig: go.Figure, point_x: float, point_y: float,
                    text_x: float, text_y: float, text: str, row: int = 1, col: int = 1) -> None:
        """Add a callout with leader line - arrow points FROM point TO text"""
        fig.add_annotation(
            x=text_x, y=text_y,  # Text position
            ax=point_x, ay=point_y,  # Arrow start position (the component being referenced)
            text=text,
            showarrow=True,
            arrowhead=2, arrowsize=1, arrowwidth=1,
            arrowcolor='gray',
            font=dict(size=7),
            bgcolor=self.universal_colors['callout_bg'],
            bordercolor="gray", borderwidth=1,
            # Same single-row axis convention as add_connection_arrow: index by col.
            xref=f'x{col}' if col > 1 else 'x',
            yref=f'y{col}' if col > 1 else 'y',
            axref=f'x{col}' if col > 1 else 'x',
            ayref=f'y{col}' if col > 1 else 'y'
        )
    def create_comparison_diagram(self, models_data: List[Tuple[str, Dict[str, Any]]]) -> go.Figure:
        """Create comparison diagram for multiple models"""
        num_models = len(models_data)
        if num_models == 0:
            return go.Figure()

        # Create subplots - always a single-row layout (1-3 columns)
        fig = make_subplots(rows=1, cols=num_models,
                            subplot_titles=[model[0] for model in models_data])

        # Set up layout
        fig.update_layout(
            height=700,
            width=1200,
            showlegend=False,
            title_text="🧠 Model Architecture Comparison",
            title_x=0.5,
            title_font=dict(size=18)
        )

        # Add each model to its subplot
        for i, (model_name, config_values) in enumerate(models_data):
            row, col = 1, i + 1
            self.create_single_model_diagram(fig, model_name, config_values, row, col)

        # Update axes to hide ticks and labels - expanded range for callouts
        fig.update_xaxes(showgrid=False, showticklabels=False, zeroline=False, range=[0, 5.0])
        fig.update_yaxes(showgrid=False, showticklabels=False, zeroline=False, range=[-0.5, 3.5])
        return fig
    def generate_visualization(self, model_names: List[str]) -> Union[go.Figure, str]:
        """Generate visualization for given models"""
        # Filter out empty model names
        valid_models = [name.strip() for name in model_names if name and name.strip()]
        if not valid_models:
            return "Please enter at least one model name."

        models_data = []
        errors = []
        for model_name in valid_models:
            try:
                config = self.get_model_config(model_name)
                if config:
                    config_values = self.extract_config_values(config)
                    models_data.append((model_name, config_values))
                else:
                    errors.append(f"Could not load config for {model_name}")
            except Exception as e:
                errors.append(f"Error with {model_name}: {str(e)}")

        if not models_data:
            return f"❌ Could not load any models. Errors: {'; '.join(errors)}"
        if errors:
            logger.warning(f"Some models failed to load: {errors}")

        try:
            fig = self.create_comparison_diagram(models_data)
            return fig
        except Exception as e:
            return f"❌ Error generating diagram: {str(e)}"
def create_gradio_interface():
    """Create and configure the Gradio interface"""
    visualizer = PlotlyModelArchitectureVisualizer()

    def process_models(model1: str, model2: str = "", model3: str = "") -> Union[go.Figure, str]:
        """Process the model inputs and generate visualization"""
        models = [model1, model2, model3]
        return visualizer.generate_visualization(models)

    # Create the interface
    with gr.Blocks(
        title="🧠 Model Architecture Visualizer",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
            max-width: 1200px !important;
        }
        .model-input {
            font-family: monospace;
        }
        """
    ) as demo:
gr.Markdown(""" | |
# π§ Interactive Model Architecture Visualizer | |
Compare up to 3 Hugging Face transformer models side-by-side! | |
Enter model IDs below to see their architecture diagrams with interactive features. | |
### π How to Use | |
1. **Enter Model IDs**: Use Hugging Face model identifiers (e.g., `moonshotai/Kimi-K2-Base`, `openai/gpt-oss-120b`, `deepseek-ai/DeepSeek-R1-0528`) | |
2. **Compare Models**: Add up to 3 models to see them side-by-side | |
3. **Explore Interactively**: Hover over components to see detailed specifications | |
""") | |
        # Model inputs in a single row
        gr.Markdown("### Model Configuration")
        with gr.Row():
            model1 = gr.Textbox(
                label="Model 1 (Required)",
                placeholder="e.g., openai/gpt-oss-120b",
                value="openai/gpt-oss-120b",
                elem_classes=["model-input"]
            )
            model2 = gr.Textbox(
                label="Model 2 (Optional)",
                placeholder="e.g., moonshotai/Kimi-K2-Base",
                elem_classes=["model-input"]
            )
            model3 = gr.Textbox(
                label="Model 3 (Optional)",
                placeholder="e.g., deepseek-ai/DeepSeek-R1-0528",
                elem_classes=["model-input"]
            )
        with gr.Row():
            generate_btn = gr.Button("Generate Visualization", variant="primary", size="lg")
            clear_btn = gr.Button("🗑️ Clear", variant="secondary")

        # Visualization output - full width
        output_plot = gr.Plot(
            label="🧠 Architecture Visualization",
            show_label=True
        )
        # Event handlers
        generate_btn.click(
            fn=process_models,
            inputs=[model1, model2, model3],
            outputs=output_plot
        )
        clear_btn.click(
            fn=lambda: ("", "", "", None),
            outputs=[model1, model2, model3, output_plot]
        )

        # Auto-generate for default model
        demo.load(
            fn=lambda: process_models("openai/gpt-oss-120b"),
            outputs=output_plot
        )

        gr.Markdown("""Built with ❤️ using Plotly, Gradio, and Hugging Face Transformers""")

    return demo
if __name__ == "__main__":
    # Create and launch the app
    demo = create_gradio_interface()
    # For HuggingFace Spaces deployment
    demo.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )
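
# Presumed environment for this Space, inferred from the imports at the top
# of this file (exact pins are an assumption; list them in requirements.txt):
#   gradio
#   plotly
#   transformers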