# barrosoluqueroberto's picture
# Update model inspector
# 56bbc8d verified
import gradio as gr
from typing import Dict, Any, Optional, List, Tuple, Union
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from transformers import AutoConfig
import logging
# Set up logging: configure the root logger at INFO level when this module is
# imported, then create the module-level logger used by the code below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class PlotlyModelArchitectureVisualizer:
    """Render simplified transformer architecture diagrams with Plotly.

    A model configuration is fetched through ``AutoConfig``, normalized into a
    flat dictionary by :meth:`extract_config_values`, and drawn as a stack of
    labelled boxes (one subplot per model) by :meth:`create_comparison_diagram`.
    """

    # Config keys whose presence marks a mixture-of-experts (MoE) model.
    _MOE_MARKER_KEYS = (
        'num_experts', 'n_routed_experts', 'num_local_experts',
        'moe_intermediate_size', 'num_experts_per_tok', 'router_aux_loss_coef',
    )

    # Placeholder used for every value that is missing from a config.
    _MISSING = 'N/A'

    def __init__(self, hf_token: Optional[str] = None, trust_remote_code: bool = True):
        """Create a visualizer.

        Args:
            hf_token: Optional Hugging Face access token forwarded to
                ``AutoConfig.from_pretrained``.
            trust_remote_code: Forwarded to ``AutoConfig.from_pretrained``.
                Defaults to ``True`` (the historical behaviour).
                SECURITY: with ``True``, configuration code shipped in the
                requested repository may be executed locally; pass ``False``
                when model ids come from untrusted users.
        """
        self.config = None
        self.hf_token = hf_token
        self.trust_remote_code = trust_remote_code
        # Universal color scheme - consistent across all models
        self.universal_colors = {
            'embedding': '#f8f9fa',  # Light gray for embeddings
            'layer_norm': '#e9ecef',  # Light gray for layer norms
            'attention': '#495057',  # Dark gray for attention
            'output': '#f8f9fa',  # Light gray for output layers
            'text': '#212529',  # Dark text
            'container_outer': '#dee2e6',  # Outer container
            'moe_inner': '#d4edda',  # Green background for MoE models
            'dense_inner': '#f8d7da',  # Red background for dense models
            'feedforward_moe': '#28a745',  # Green for MoE FFN
            'feedforward_dense': '#dc3545',  # Red for dense FFN
            'router': '#fd7e14',  # Orange for router
            'expert': '#20c997',  # Teal for experts
            'callout_bg': 'rgba(255,255,255,0.9)',
            'accent_blue': '#007bff',
            'accent_green': '#28a745',
            'accent_red': '#dc3545',
        }

    # ------------------------------------------------------------------
    # Small pure helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _first_present(config: Dict[str, Any], *keys: str, default: Any = 'N/A') -> Any:
        """Return the value of the first of ``keys`` that exists in ``config``."""
        for key in keys:
            if key in config:
                return config[key]
        return default

    @staticmethod
    def _is_positive_int(value: Any) -> bool:
        """Tell whether ``value`` is a usable dimension (an int > 0, not a bool)."""
        return isinstance(value, int) and not isinstance(value, bool) and value > 0

    @staticmethod
    def _format_count(value: Any) -> str:
        """Format integers with thousands separators; pass anything else through.

        Config values may be the ``'N/A'`` placeholder, and applying the ``,``
        format spec to a string raises ``ValueError``.
        """
        if isinstance(value, int) and not isinstance(value, bool):
            return f"{value:,}"
        return str(value)

    @classmethod
    def _estimate_parameter_count(cls, hidden_size: Any, num_layers: Any, vocab_size: Any,
                                  is_moe: bool = False, num_experts: Any = None,
                                  expert_intermediate_size: Any = None) -> Optional[int]:
        """Roughly estimate the total number of parameters.

        Dense models use the common ``12 * L * h^2 + V * h`` approximation
        (attention ``4h^2`` plus a 4x-wide two-matrix FFN ``8h^2`` per layer,
        plus the token embedding). MoE models use ``4h^2`` for attention plus
        ``num_experts`` gated FFNs (three ``h x intermediate`` matrices each)
        per layer. Biases, norms, grouped-query attention, shared experts and
        an untied output head are ignored, so this is only an order-of-magnitude
        figure.

        Returns:
            The estimated parameter count, or ``None`` when a required
            dimension is missing or not a positive integer.
        """
        if not all(cls._is_positive_int(v) for v in (hidden_size, num_layers, vocab_size)):
            return None
        embedding = vocab_size * hidden_size
        if not is_moe:
            return 12 * num_layers * hidden_size ** 2 + embedding
        if not (cls._is_positive_int(num_experts) and cls._is_positive_int(expert_intermediate_size)):
            return None
        attention = 4 * hidden_size ** 2
        experts = num_experts * 3 * hidden_size * expert_intermediate_size
        return num_layers * (attention + experts) + embedding

    @staticmethod
    def _format_parameter_count(count: Optional[int]) -> str:
        """Render a parameter count as e.g. ``'124M'`` or ``'47.4B'``."""
        if count is None:
            return "Unknown"
        if count >= 1_000_000_000:
            return f"{count / 1_000_000_000:.1f}B"
        return f"{count / 1_000_000:.0f}M"

    @staticmethod
    def _axis_refs(row: int, col: int) -> Tuple[str, str]:
        """Map a subplot position to the ``(xref, yref)`` strings used for arrows.

        The x reference is derived from ``col`` and the y reference from
        ``row`` independently.
        NOTE(review): Plotly's ``make_subplots`` appears to number axes per
        subplot (x2/y2 for the second cell), so for ``row == 1, col > 1`` the
        y reference used here is the first subplot's axis. That seems harmless
        for the single-row grids built by ``create_comparison_diagram`` because
        every y axis gets the same range there - confirm before using this
        with multi-row layouts.
        """
        xref = f'x{col}' if col > 1 else 'x'
        yref = f'y{row}' if row > 1 else 'y'
        return xref, yref

    # ------------------------------------------------------------------
    # Configuration handling
    # ------------------------------------------------------------------
    def get_model_config(self, model_name: str) -> Dict[str, Any]:
        """Fetch a model configuration from Hugging Face.

        Args:
            model_name: Hub model identifier.

        Returns:
            The configuration as a dictionary, or an empty dictionary when
            loading fails for any reason (the error is logged).
        """
        try:
            logger.info(f"Fetching config for {model_name}")
            # SECURITY: trust_remote_code may run code from the requested repo;
            # see the constructor documentation.
            config = AutoConfig.from_pretrained(
                model_name, token=self.hf_token, trust_remote_code=self.trust_remote_code
            )
            return config.to_dict()
        except Exception as e:
            # Deliberate best-effort boundary: callers treat {} as "not loaded".
            logger.error(f"Error fetching config for {model_name}: {e}")
            return {}

    def extract_config_values(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Extract and normalize configuration values with architecture detection.

        Different model families name the same quantity differently
        (``hidden_size`` / ``d_model`` / ``n_embd`` ...); this resolves those
        aliases and substitutes ``'N/A'`` for anything that is missing.

        Args:
            config: Raw configuration dictionary.

        Returns:
            A dictionary with the keys ``model_type``, ``hidden_size``,
            ``num_layers``, ``num_heads``, ``vocab_size``, ``max_position``,
            ``intermediate_size``, ``is_moe``, ``moe_params``,
            ``estimated_size``, ``kv_heads``, ``head_dim`` and ``activation``.
        """
        pick = self._first_present
        missing = self._MISSING

        # A model counts as MoE as soon as any MoE-specific key is present.
        is_moe = any(key in config for key in self._MOE_MARKER_KEYS)

        # Extract MoE-specific parameters
        moe_params: Dict[str, Any] = {}
        if is_moe:
            moe_params = {
                'num_experts': pick(config, 'num_experts', 'n_routed_experts', 'num_local_experts'),
                'experts_per_token': config.get('num_experts_per_tok', missing),
                'moe_intermediate_size': config.get('moe_intermediate_size', missing),
                'router_aux_loss': pick(config, 'router_aux_loss_coef', 'aux_loss_alpha'),
                'shared_experts': config.get('n_shared_experts', 0),
            }

        hidden_size = pick(config, 'hidden_size', 'd_model', 'n_embd', default=0)
        num_layers = pick(config, 'num_hidden_layers', 'n_layer', 'num_layers', default=0)
        vocab_size = config.get('vocab_size', 0)
        num_heads = pick(config, 'num_attention_heads', 'n_head', 'num_heads')

        estimated = self._estimate_parameter_count(
            hidden_size, num_layers, vocab_size,
            is_moe=is_moe,
            num_experts=moe_params.get('num_experts'),
            # Some MoE configs only expose the expert width as intermediate_size.
            expert_intermediate_size=pick(config, 'moe_intermediate_size', 'intermediate_size'),
        )

        hidden_or_missing = hidden_size if hidden_size != 0 else missing
        return {
            'model_type': config.get('model_type', 'unknown'),
            'hidden_size': hidden_or_missing,
            'num_layers': num_layers if num_layers != 0 else missing,
            'num_heads': num_heads,
            'vocab_size': vocab_size if vocab_size != 0 else missing,
            'max_position': pick(config, 'max_position_embeddings', 'n_positions', 'max_seq_len'),
            'intermediate_size': pick(config, 'intermediate_size', 'd_ff', default=hidden_or_missing),
            'is_moe': is_moe,
            'moe_params': moe_params,
            'estimated_size': self._format_parameter_count(estimated),
            # Without a dedicated KV-head entry every attention head has its own KV head.
            'kv_heads': config.get('num_key_value_heads', num_heads),
            'head_dim': pick(config, 'head_dim', 'qk_nope_head_dim'),
            'activation': pick(config, 'hidden_act', 'activation_function'),
        }

    # ------------------------------------------------------------------
    # Drawing primitives
    # ------------------------------------------------------------------
    def add_container(self, fig: "go.Figure", x: float, y: float, width: float, height: float,
                      color: str, line_width: int = 1, row: int = 1, col: int = 1) -> None:
        """Add a filled background rectangle (drawn below traces) to a subplot.

        Args:
            fig: Figure to draw on.
            x, y: Lower-left corner in data coordinates.
            width, height: Size of the rectangle.
            color: Fill color.
            line_width: Width of the black outline.
            row, col: Target subplot.
        """
        fig.add_shape(
            type="rect",
            x0=x, y0=y, x1=x + width, y1=y + height,
            fillcolor=color,
            line=dict(color='black', width=line_width),
            layer="below",
            row=row, col=col
        )

    def add_layer_box(self, fig: "go.Figure", x: float, y: float, width: float, height: float,
                      text: str, color: str, hover_text: Optional[str] = None,
                      row: int = 1, col: int = 1, text_size: int = 7) -> None:
        """Add a labelled rectangle representing a layer.

        Draws the rectangle, a centered text annotation and, when
        ``hover_text`` is given, an invisible marker that carries the hover
        tooltip.

        Args:
            fig: Figure to draw on.
            x, y: Lower-left corner in data coordinates.
            width, height: Size of the box.
            text: Label (Plotly annotation markup, e.g. ``<br>`` for a line break).
            color: Fill color of the box.
            hover_text: Optional extra tooltip text.
            row, col: Target subplot.
            text_size: Font size of the label.
        """
        center_x = x + width / 2
        center_y = y + height / 2

        # Add the box shape
        fig.add_shape(
            type="rect",
            x0=x, y0=y, x1=x + width, y1=y + height,
            fillcolor=color,
            line=dict(color='black', width=1),
            layer="below",
            row=row, col=col
        )
        # Add text label
        fig.add_annotation(
            x=center_x,
            y=center_y,
            text=text,
            showarrow=False,
            font=dict(size=text_size, color=self.universal_colors['text']),
            bgcolor=self.universal_colors['callout_bg'],
            bordercolor="black",
            borderwidth=1,
            row=row, col=col
        )
        # Add invisible scatter point for hover functionality
        if hover_text:
            fig.add_trace(go.Scatter(
                x=[center_x],
                y=[center_y],
                mode='markers',
                marker=dict(size=12, opacity=0),
                hovertemplate=f"<b>{text}</b><br>{hover_text}<extra></extra>",
                showlegend=False,
                name=text
            ), row=row, col=col)

    def add_moe_router_visualization(self, fig: "go.Figure", x: float, y: float,
                                     config_values: Dict[str, Any], row: int = 1, col: int = 1) -> None:
        """Add the MoE router box with up to three expert boxes beneath it.

        Args:
            fig: Figure to draw on.
            x, y: Anchor position; the router box is drawn 0.2 to the right of ``x``.
            config_values: Output of :meth:`extract_config_values` for a MoE model.
            row, col: Target subplot.
        """
        moe_params = config_values['moe_params']
        experts_per_token = moe_params['experts_per_token']
        # Only a real integer can bound the number of boxes; otherwise show three.
        active_experts = experts_per_token if isinstance(experts_per_token, int) else None

        # Router box - positioned more centrally
        router_width, router_height = 0.4, 0.12
        router_x = x + 0.2  # Center it better within the available space
        self.add_layer_box(
            fig, router_x, y, router_width, router_height,
            "Router", self.universal_colors['router'],
            f"{experts_per_token} experts activated <br>from {moe_params['num_experts']} total",
            row, col, 6
        )

        # Expert boxes - positioned with better spacing
        expert_y = y - 0.25  # Closer to router
        expert_width, expert_height = 0.18, 0.1
        expert_gap = 0.04
        experts_to_show = 3 if active_experts is None else min(3, active_experts)

        # Center the experts under the router
        total_expert_width = experts_to_show * expert_width + (experts_to_show - 1) * expert_gap
        experts_start_x = router_x + (router_width - total_expert_width) / 2
        for i in range(experts_to_show):
            expert_x = experts_start_x + i * (expert_width + expert_gap)
            self.add_layer_box(
                fig, expert_x, expert_y, expert_width, expert_height,
                # Plotly annotations break lines on <br>, not on "\n".
                f"Expert<br>{i + 1}", self.universal_colors['expert'],
                f"MoE intermediate size: {moe_params['moe_intermediate_size']}",
                row, col, 5
            )
            # Arrow from router to expert - pointing downward
            self.add_connection_arrow(
                fig, router_x + router_width / 2, y,
                expert_x + expert_width / 2, expert_y + expert_height, row, col
            )

        # Add "..." if more experts are active than boxes drawn - positioned to the right
        if active_experts is not None and experts_to_show < active_experts:
            fig.add_annotation(
                x=experts_start_x + experts_to_show * (expert_width + expert_gap) + 0.05,
                y=expert_y + expert_height / 2,
                text="...",
                showarrow=False,
                font=dict(size=8, color=self.universal_colors['text']),
                row=row, col=col
            )

    def add_side_panel(self, fig: "go.Figure", x: float, y: float, width: float, height: float,
                       title: str, components: List[str], config_values: Dict[str, Any],
                       row: int = 1, col: int = 1) -> None:
        """Add a dashed side panel listing the components of a module.

        Args:
            fig: Figure to draw on.
            x, y: Lower-left corner of the panel.
            width, height: Size of the panel.
            title: Bold heading shown at the top of the panel.
            components: Component labels, drawn top to bottom.
            config_values: Output of :meth:`extract_config_values`; only
                ``is_moe`` is read, to pick the activation color.
            row, col: Target subplot.
        """
        colors = self.universal_colors

        # Panel container with dashed border
        fig.add_shape(
            type="rect",
            x0=x, y0=y, x1=x + width, y1=y + height,
            fillcolor=colors['callout_bg'],
            line=dict(color='gray', width=1, dash='dash'),
            layer="below",
            row=row, col=col
        )
        # Panel title
        fig.add_annotation(
            x=x + width / 2, y=y + height - 0.08,
            text=f"<b>{title}</b>",
            showarrow=False,
            font=dict(size=8, color=colors['text']),
            row=row, col=col
        )

        # Component boxes, stacked downwards from just below the title
        component_height = 0.1
        start_y = y + height - 0.2
        activation_color = colors['feedforward_moe'] if config_values['is_moe'] else colors['feedforward_dense']
        for i, component in enumerate(components):
            comp_y = start_y - i * (component_height + 0.03)
            if "Linear" in component:
                color = colors['output']
            elif "activation" in component.lower() or "SiLU" in component or "ReLU" in component:
                color = activation_color
            else:
                color = colors['embedding']
            self.add_layer_box(
                fig, x + 0.03, comp_y, width - 0.06, component_height,
                component, color, None, row, col, 6
            )

    def add_connection_arrow(self, fig: "go.Figure", start_x: float, start_y: float,
                             end_x: float, end_y: float, row: int = 1, col: int = 1) -> None:
        """Add an arrow from ``(start_x, start_y)`` to ``(end_x, end_y)``.

        The axis references are set explicitly for both the arrow head
        (``xref``/``yref``) and the arrow tail (``axref``/``ayref``) so that
        both ends are expressed in data coordinates; see :meth:`_axis_refs`
        for how ``row``/``col`` are mapped.
        """
        xref, yref = self._axis_refs(row, col)
        fig.add_annotation(
            x=end_x, y=end_y,
            ax=start_x, ay=start_y,
            xref=xref,
            yref=yref,
            axref=xref,
            ayref=yref,
            showarrow=True,
            arrowhead=2,
            arrowsize=1,
            arrowwidth=1.5,
            arrowcolor='black'
        )

    # ------------------------------------------------------------------
    # Diagram assembly
    # ------------------------------------------------------------------
    def create_single_model_diagram(self, fig: "go.Figure", model_name: str,
                                    config_values: Dict[str, Any], row: int = 1, col: int = 1) -> None:
        """Draw one model's architecture into the subplot at ``(row, col)``.

        Args:
            fig: Figure to draw on.
            model_name: Model identifier; only the part after the last ``/``
                is shown in the title.
            config_values: Output of :meth:`extract_config_values`.
            row, col: Target subplot.
        """
        colors = self.universal_colors
        is_moe = config_values['is_moe']
        moe_params = config_values['moe_params']
        # vocab_size may be the 'N/A' placeholder, so it cannot be formatted with ":,".
        vocab_text = self._format_count(config_values['vocab_size'])
        # The activation may be missing or not a string; never call .upper() on it directly.
        activation_label = f"{str(config_values['activation']).upper()} activation"

        # Layout parameters
        base_x, base_y = 0.3, 0.2
        main_width, main_height = 2.2, 2.8
        layer_width, layer_height = 1.8, 0.2

        # Model title with size
        model_display_name = model_name.split('/')[-1] if '/' in model_name else model_name
        title_text = f"<b>{model_display_name}</b>"
        if config_values['estimated_size'] != "Unknown":
            title_text += f" ({config_values['estimated_size']})"
        fig.add_annotation(
            x=base_x + main_width / 2, y=base_y + main_height + 0.2,
            text=title_text,
            showarrow=False,
            font=dict(size=10, color=colors['accent_blue']),
            row=row, col=col
        )

        # Outer container (main frame)
        self.add_container(
            fig, base_x - 0.1, base_y - 0.1, main_width + 0.2, main_height + 0.2,
            colors['container_outer'], 2, row, col
        )
        # Inner container (colored by architecture type)
        inner_color = colors['moe_inner'] if is_moe else colors['dense_inner']
        self.add_container(
            fig, base_x + 0.1, base_y + 0.8, main_width - 0.2, main_height - 1.2,
            inner_color, 1, row, col
        )

        # Feed-forward stage differs between MoE and dense models
        if is_moe:
            feed_forward = (
                'MoE Feed Forward',
                base_y + 1.5, colors['feedforward_moe'],
                f"Experts: {moe_params['num_experts']}<br>"
                f"Active per token: {moe_params['experts_per_token']}<br>"
                f"MoE intermediate: {moe_params['moe_intermediate_size']}"
            )
        else:
            feed_forward = (
                'Feed Forward Network',
                base_y + 1.5, colors['feedforward_dense'],
                f"Intermediate size: {config_values['intermediate_size']}<br>"
                f"Activation: {config_values['activation']}"
            )

        # Layer definitions (label, y position, color, hover text), bottom to top
        layers = [
            ('Token Embedding', base_y + 0.3, colors['embedding'],
             f"Vocab: {vocab_text}<br>Embedding dim: {config_values['hidden_size']}"),
            ('Layer Norm', base_y + 0.6, colors['layer_norm'],
             'Input normalization'),
            # Plotly annotations break lines on <br>, not on "\n".
            (f'Multi-Head Attention<br>({config_values["num_heads"]} heads)',
             base_y + 0.9, colors['attention'],
             f"Heads: {config_values['num_heads']}<br>Hidden: {config_values['hidden_size']}"
             f"<br>KV Heads: {config_values['kv_heads']}"),
            ('Layer Norm', base_y + 1.2, colors['layer_norm'],
             'Post-attention norm'),
            feed_forward,
            ('Layer Norm', base_y + 1.8, colors['layer_norm'],
             'Post-FFN normalization'),
            ('Output Projection', base_y + 2.1, colors['output'],
             f"Projects to vocab: {vocab_text}"),
        ]

        # Add all layers
        layer_x = base_x + (main_width - layer_width) / 2
        layer_centers = []
        for layer_name, y_pos, color, hover_info in layers:
            self.add_layer_box(
                fig, layer_x, y_pos, layer_width, layer_height,
                layer_name, color, hover_info, row, col
            )
            layer_centers.append((layer_x + layer_width / 2, y_pos + layer_height / 2))

        # Add arrows between consecutive layers (top edge of one to bottom edge of the next)
        for (start_x, start_y), (end_x, end_y) in zip(layer_centers, layer_centers[1:]):
            arrow_start_y = start_y + layer_height / 2
            arrow_end_y = end_y - layer_height / 2
            if arrow_end_y > arrow_start_y:
                self.add_connection_arrow(fig, start_x, arrow_start_y, end_x, arrow_end_y, row, col)

        # Add layer repetition indicator
        if isinstance(config_values['num_layers'], int) and config_values['num_layers'] > 1:
            fig.add_annotation(
                x=base_x - 0.05, y=base_y + 1.4,
                text=f"×{config_values['num_layers']}<br>layers",
                showarrow=False,
                font=dict(size=7, color=colors['text']),
                bgcolor=colors['callout_bg'],
                bordercolor="black", borderwidth=1,
                row=row, col=col
            )

        # Add side panel for component details
        panel_x = base_x + main_width + 0.3
        panel_y = base_y + 1.5  # Moved up to avoid MoE visualization
        panel_width = 0.7
        panel_height = 0.8
        components = ["Linear layer", activation_label, "Linear layer"]
        if is_moe:
            # MoE side panel
            components += ["Router", f"{moe_params['experts_per_token']} active experts"]
            panel_title = "MoE Module"
        else:
            # Dense FFN side panel
            panel_title = "FeedForward Module"
        self.add_side_panel(fig, panel_x, panel_y, panel_width, panel_height,
                            panel_title, components, config_values, row, col)

        # Add MoE router visualization if applicable
        if is_moe:
            # Position router visualization below side panel with better spacing
            router_x = panel_x + 0.05
            router_y = panel_y - 0.5
            self.add_moe_router_visualization(fig, router_x, router_y, config_values, row, col)

    def add_callout(self, fig: "go.Figure", point_x: float, point_y: float,
                    text_x: float, text_y: float, text: str, row: int = 1, col: int = 1) -> None:
        """Add a callout annotation with a leader arrow.

        The annotation is created with its arrow head at ``(text_x, text_y)``
        and its arrow tail at ``(point_x, point_y)``, i.e. the arrow runs from
        the referenced component towards the text coordinates.
        NOTE(review): Plotly's annotation reference describes the label as
        being drawn at the arrow tail, which would put the text at
        ``(point_x, point_y)`` rather than at ``(text_x, text_y)`` - confirm
        the intended placement before relying on this helper.
        """
        xref, yref = self._axis_refs(row, col)
        fig.add_annotation(
            x=text_x, y=text_y,  # Arrow head
            ax=point_x, ay=point_y,  # Arrow tail (the component being referenced)
            text=text,
            showarrow=True,
            arrowhead=2, arrowsize=1, arrowwidth=1,
            arrowcolor='gray',
            font=dict(size=7),
            bgcolor=self.universal_colors['callout_bg'],
            bordercolor="gray", borderwidth=1,
            xref=xref,
            yref=yref,
            axref=xref,
            ayref=yref
        )

    def create_comparison_diagram(self, models_data: List[Tuple[str, Dict[str, Any]]]) -> "go.Figure":
        """Create a single-row comparison diagram, one subplot per model.

        Args:
            models_data: ``(model_name, config_values)`` pairs, where
                ``config_values`` comes from :meth:`extract_config_values`.

        Returns:
            The assembled figure; an empty figure when ``models_data`` is empty.
        """
        num_models = len(models_data)
        if num_models == 0:
            return go.Figure()

        # One column per model. Sizing the grid from the data keeps the loop
        # below from addressing a column that does not exist.
        fig = make_subplots(
            rows=1, cols=num_models,
            subplot_titles=[model_name for model_name, _ in models_data]
        )

        # Set up layout (1200 px wide for up to three models, wider beyond that)
        fig.update_layout(
            height=700,
            width=max(1200, 400 * num_models),
            showlegend=False,
            title_text="🧠 Model Architecture Comparison",
            title_x=0.5,
            title_font=dict(size=18)
        )

        # Add each model to its subplot
        for col, (model_name, config_values) in enumerate(models_data, start=1):
            self.create_single_model_diagram(fig, model_name, config_values, 1, col)

        # Update axes to hide ticks and labels - expanded range for callouts
        fig.update_xaxes(showgrid=False, showticklabels=False, zeroline=False, range=[0, 5.0])
        fig.update_yaxes(showgrid=False, showticklabels=False, zeroline=False, range=[-0.5, 3.5])
        return fig

    def generate_visualization(self, model_names: List[str]) -> Union["go.Figure", str]:
        """Generate the comparison figure for the given model names.

        Blank or ``None`` entries are ignored. Models whose configuration
        cannot be loaded are skipped (and logged) as long as at least one
        model loads.

        Args:
            model_names: Hub model identifiers.

        Returns:
            The figure on success, otherwise a human-readable error message.
        """
        # Filter out empty model names
        valid_models = [name.strip() for name in model_names if name and name.strip()]
        if not valid_models:
            return "Please enter at least one model name."

        models_data = []
        errors = []
        for model_name in valid_models:
            try:
                config = self.get_model_config(model_name)
                if config:
                    models_data.append((model_name, self.extract_config_values(config)))
                else:
                    errors.append(f"Could not load config for {model_name}")
            except Exception as e:
                errors.append(f"Error with {model_name}: {str(e)}")

        if not models_data:
            return f"❌ Could not load any models. Errors: {'; '.join(errors)}"
        if errors:
            logger.warning(f"Some models failed to load: {errors}")

        try:
            return self.create_comparison_diagram(models_data)
        except Exception as e:
            return f"❌ Error generating diagram: {str(e)}"
def create_gradio_interface():
    """Create and configure the Gradio interface.

    Builds a ``gr.Blocks`` app with three model-id text boxes, a generate
    button, a clear button and a plot output. The default model is rendered
    once when the page loads.

    Returns:
        The configured ``gr.Blocks`` instance (not yet launched).
    """
    visualizer = PlotlyModelArchitectureVisualizer()
    # Single source of truth for the pre-filled / auto-rendered model.
    default_model = "openai/gpt-oss-120b"

    def process_models(model1: str, model2: str = "", model3: str = "") -> go.Figure:
        """Generate the figure for the given model ids.

        Raises:
            gr.Error: When the visualizer reports a failure. The visualizer
                signals failure by returning a message string, which the plot
                output cannot display, so it is surfaced as a UI error instead.
        """
        result = visualizer.generate_visualization([model1, model2, model3])
        if isinstance(result, str):
            raise gr.Error(result)
        return result

    # Create the interface
    with gr.Blocks(
        title="🧠 Model Architecture Visualizer",
        theme=gr.themes.Soft(),
        css="""
.gradio-container {
max-width: 1200px !important;
}
.model-input {
font-family: monospace;
}
"""
    ) as demo:
        gr.Markdown("""
# 🧠 Interactive Model Architecture Visualizer
Compare up to 3 Hugging Face transformer models side-by-side!
Enter model IDs below to see their architecture diagrams with interactive features.
### 📋 How to Use
1. **Enter Model IDs**: Use Hugging Face model identifiers (e.g., `moonshotai/Kimi-K2-Base`, `openai/gpt-oss-120b`, `deepseek-ai/DeepSeek-R1-0528`)
2. **Compare Models**: Add up to 3 models to see them side-by-side
3. **Explore Interactively**: Hover over components to see detailed specifications
""")
        # Model inputs in a single row
        gr.Markdown("### 📝 Model Configuration")
        with gr.Row():
            model1 = gr.Textbox(
                label="Model 1 (Required)",
                placeholder="e.g., openai/gpt-oss-120b",
                value=default_model,
                elem_classes=["model-input"]
            )
            model2 = gr.Textbox(
                label="Model 2 (Optional)",
                placeholder="e.g., moonshotai/Kimi-K2-Base",
                elem_classes=["model-input"]
            )
            model3 = gr.Textbox(
                label="Model 3 (Optional)",
                placeholder="e.g., deepseek-ai/DeepSeek-R1-0528",
                elem_classes=["model-input"]
            )
        with gr.Row():
            generate_btn = gr.Button("🚀 Generate Visualization", variant="primary", size="lg")
            clear_btn = gr.Button("🗑️ Clear", variant="secondary")

        # Visualization output - full width
        output_plot = gr.Plot(
            label="🧠 Architecture Visualization",
            show_label=True
        )

        # Event handlers
        generate_btn.click(
            fn=process_models,
            inputs=[model1, model2, model3],
            outputs=output_plot
        )
        # Reset the three text boxes and empty the plot.
        clear_btn.click(
            fn=lambda: ("", "", "", None),
            outputs=[model1, model2, model3, output_plot]
        )
        # Auto-generate for default model
        demo.load(
            fn=lambda: process_models(default_model),
            outputs=output_plot
        )
        gr.Markdown("""Built with ❤️ using Plotly, Gradio, and Hugging Face Transformers""")
    return demo
if __name__ == "__main__":
    # Build the Blocks app first; serving it is a separate step.
    demo = create_gradio_interface()

    # Launch settings for HuggingFace Spaces deployment: no public share
    # link, listen on all interfaces on port 7860, and show errors in the UI.
    launch_settings = {
        "share": False,
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_settings)