"""Gradio app that draws side-by-side Plotly architecture diagrams for Hugging Face models."""

import logging
from typing import Any, Dict, List, Optional, Tuple, Union

import gradio as gr
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from transformers import AutoConfig

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model pre-filled in the first text box and rendered when the page loads.
_DEFAULT_MODEL = "openai/gpt-oss-120b"

# A config containing any of these keys is treated as Mixture-of-Experts.
_MOE_MARKER_KEYS = (
    'num_experts',
    'n_routed_experts',
    'num_local_experts',
    'moe_intermediate_size',
    'num_experts_per_tok',
    'router_aux_loss_coef',
)

# Config keys that may hold the total expert count, in lookup order.
_EXPERT_COUNT_KEYS = ('num_experts', 'n_routed_experts', 'num_local_experts')

# Activations assumed to imply a gated (three-matrix) feed-forward block when
# estimating parameter counts. NOTE(review): heuristic, not read from the model.
_GATED_ACTIVATIONS = frozenset({'silu', 'swish', 'swiglu'})

_APP_CSS = """
.gradio-container {
    max-width: 1200px !important;
}
.model-input {
    font-family: monospace;
}
"""

_INTRO_MARKDOWN = """
# 🧠 Interactive Model Architecture Visualizer

Compare up to 3 Hugging Face transformer models side-by-side! Enter model IDs below to see their architecture diagrams with interactive features.

### 📋 How to Use

1. **Enter Model IDs**: Use Hugging Face model identifiers (e.g., `moonshotai/Kimi-K2-Base`, `openai/gpt-oss-120b`, `deepseek-ai/DeepSeek-R1-0528`)
2. **Compare Models**: Add up to 3 models to see them side-by-side
3. **Explore Interactively**: Hover over components to see detailed specifications
"""


def _first_present(config: Dict[str, Any], keys: Tuple[str, ...], default: Any) -> Any:
    """Return the value of the first key in ``keys`` that exists in ``config``, else ``default``."""
    for key in keys:
        if key in config:
            return config[key]
    return default


def _positive_int(value: Any) -> Optional[int]:
    """Return ``value`` when it is an int greater than zero (bools excluded), else ``None``."""
    if isinstance(value, int) and not isinstance(value, bool) and value > 0:
        return value
    return None


def _format_count(value: Any) -> str:
    """Format ints with thousands separators; render anything else (e.g. ``'N/A'``) via ``str``."""
    if isinstance(value, int) and not isinstance(value, bool):
        return f"{value:,}"
    return str(value)


def _format_param_count(count: int) -> str:
    """Render a parameter count as ``'<n>M'`` or ``'<n.n>B'``."""
    millions = count / 1_000_000
    if millions >= 999.5:
        return f"{millions / 1000:.1f}B"
    return f"{millions:.0f}M"


def _estimate_param_count(config: Dict[str, Any]) -> Optional[int]:
    """Roughly estimate the total parameter count described by ``config``.

    The estimate is ``layers * (attention + feed_forward) + embeddings`` where

    * attention counts the Q/O projections as ``hidden x hidden`` each and the
      K/V projections scaled by ``num_key_value_heads / num_attention_heads``;
    * feed_forward counts two matrices per block, or three when the activation
      is in ``_GATED_ACTIVATIONS``; for MoE configs it is multiplied by the
      number of routed plus shared experts and a router matrix is added;
    * embeddings are counted once when ``tie_word_embeddings`` is truthy and
      twice otherwise.

    Biases, normalisation weights, custom head dimensions, decoder stacks of
    encoder-decoder models and per-layer dense/MoE mixes are ignored, so treat
    the result as an order-of-magnitude figure.

    Returns:
        The estimated count, or ``None`` when hidden size, layer count or
        vocabulary size is missing or not a positive int.
    """
    hidden = _positive_int(_first_present(config, ('hidden_size', 'd_model', 'n_embd'), None))
    layers = _positive_int(_first_present(config, ('num_hidden_layers', 'n_layer', 'num_layers'), None))
    vocab = _positive_int(config.get('vocab_size'))
    if hidden is None or layers is None or vocab is None:
        return None

    heads = _positive_int(_first_present(config, ('num_attention_heads', 'n_head', 'num_heads'), None))
    kv_heads = _positive_int(config.get('num_key_value_heads'))
    kv_fraction = kv_heads / heads if heads and kv_heads and kv_heads <= heads else 1.0
    attention = 2 * hidden * hidden + int(2 * hidden * hidden * kv_fraction)

    dense_ff = (
        _positive_int(config.get('intermediate_size'))
        or _positive_int(config.get('d_ff'))
        or _positive_int(config.get('n_inner'))
        or 4 * hidden  # conventional width when the config does not state one
    )
    activation = _first_present(config, ('hidden_act', 'activation_function'), '')
    matrices = 3 if str(activation).lower() in _GATED_ACTIVATIONS else 2

    experts = _positive_int(_first_present(config, _EXPERT_COUNT_KEYS, None))
    if experts is None:
        feed_forward = matrices * hidden * dense_ff
    else:
        expert_ff = _positive_int(config.get('moe_intermediate_size')) or dense_ff
        shared = _positive_int(config.get('n_shared_experts')) or 0
        feed_forward = (experts + shared) * matrices * hidden * expert_ff + hidden * experts

    embeddings = vocab * hidden
    if not config.get('tie_word_embeddings', False):
        embeddings *= 2  # separate input embedding and output projection

    return layers * (attention + feed_forward) + embeddings


class PlotlyModelArchitectureVisualizer:
    """Draw schematic transformer-architecture diagrams from Hugging Face model configs."""

    def __init__(self, hf_token: Optional[str] = None):
        """Store the optional Hugging Face token and the shared colour palette."""
        self.config = None
        self.hf_token = hf_token

        # Universal color scheme - consistent across all models
        self.universal_colors = {
            'embedding': '#f8f9fa',  # Light gray for embeddings
            'layer_norm': '#e9ecef',  # Light gray for layer norms
            'attention': '#495057',  # Dark gray for attention
            'output': '#f8f9fa',  # Light gray for output layers
            'text': '#212529',  # Dark text
            'container_outer': '#dee2e6',  # Outer container
            'moe_inner': '#d4edda',  # Green background for MoE models
            'dense_inner': '#f8d7da',  # Red background for dense models
            'feedforward_moe': '#28a745',  # Green for MoE FFN
            'feedforward_dense': '#dc3545',  # Red for dense FFN
            'router': '#fd7e14',  # Orange for router
            'expert': '#20c997',  # Teal for experts
            'callout_bg': 'rgba(255,255,255,0.9)',
            'accent_blue': '#007bff',
            'accent_green': '#28a745',
            'accent_red': '#dc3545',
        }

    def get_model_config(self, model_name: str) -> Dict[str, Any]:
        """Fetch a model configuration from Hugging Face as a plain dict.

        Any exception raised while loading is logged and swallowed; in that
        case an empty dict is returned.
        """
        try:
            logger.info("Fetching config for %s", model_name)
            # SECURITY NOTE(review): trust_remote_code=True allows the repository
            # named by `model_name` to execute its own Python while the config is
            # loaded, and `model_name` comes straight from a UI text box. Kept
            # as-is because some of the suggested models may rely on it; consider
            # an allow-list or reading config.json directly.
            config = AutoConfig.from_pretrained(
                model_name, token=self.hf_token, trust_remote_code=True
            )
            return config.to_dict()
        except Exception as e:
            logger.error("Error fetching config for %s: %s", model_name, e)
            return {}

    def extract_config_values(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Extract and normalize configuration values with architecture detection.

        Args:
            config: Raw configuration dict, as returned by ``get_model_config``.

        Returns:
            A dict with the keys ``model_type``, ``hidden_size``, ``num_layers``,
            ``num_heads``, ``vocab_size``, ``max_position``, ``intermediate_size``,
            ``is_moe``, ``moe_params``, ``estimated_size``, ``kv_heads``,
            ``head_dim`` and ``activation``. Values that cannot be found are the
            string ``'N/A'`` (``'Unknown'`` for ``estimated_size``).
        """
        # Detect model architecture type
        is_moe = any(key in config for key in _MOE_MARKER_KEYS)

        # Extract MoE-specific parameters
        moe_params: Dict[str, Any] = {}
        if is_moe:
            moe_params = {
                'num_experts': _first_present(config, _EXPERT_COUNT_KEYS, 'N/A'),
                'experts_per_token': config.get('num_experts_per_tok', 'N/A'),
                'moe_intermediate_size': config.get('moe_intermediate_size', 'N/A'),
                'router_aux_loss': _first_present(
                    config, ('router_aux_loss_coef', 'aux_loss_alpha'), 'N/A'
                ),
                'shared_experts': config.get('n_shared_experts', 0),
            }

        hidden_size = _first_present(config, ('hidden_size', 'd_model', 'n_embd'), 0)
        num_layers = _first_present(config, ('num_hidden_layers', 'n_layer', 'num_layers'), 0)
        vocab_size = config.get('vocab_size', 0)
        num_heads = _first_present(config, ('num_attention_heads', 'n_head', 'num_heads'), 'N/A')

        # Without an explicit KV-head count, fall back to the resolved head count.
        kv_heads = config.get('num_key_value_heads')
        if kv_heads is None:
            kv_heads = num_heads

        param_count = _estimate_param_count(config)
        estimated_size = _format_param_count(param_count) if param_count is not None else "Unknown"

        return {
            'model_type': config.get('model_type', 'unknown'),
            'hidden_size': hidden_size if hidden_size != 0 else 'N/A',
            'num_layers': num_layers if num_layers != 0 else 'N/A',
            'num_heads': num_heads,
            'vocab_size': vocab_size if vocab_size != 0 else 'N/A',
            'max_position': _first_present(
                config, ('max_position_embeddings', 'n_positions', 'max_seq_len'), 'N/A'
            ),
            'intermediate_size': _first_present(
                config,
                ('intermediate_size', 'd_ff'),
                hidden_size if hidden_size != 0 else 'N/A',
            ),
            'is_moe': is_moe,
            'moe_params': moe_params,
            'estimated_size': estimated_size,
            'kv_heads': kv_heads,
            'head_dim': _first_present(config, ('head_dim', 'qk_nope_head_dim'), 'N/A'),
            'activation': _first_present(config, ('hidden_act', 'activation_function'), 'N/A'),
        }

    @staticmethod
    def _add_data_arrow(fig: go.Figure, row: int, col: int, **annotation_kwargs: Any) -> None:
        """Add an arrow annotation whose head and tail both use subplot (row, col) data coordinates.

        ``add_annotation(row=..., col=...)`` resolves ``xref``/``yref`` for the
        subplot; the same references are then copied to ``axref``/``ayref`` so
        that ``ax``/``ay`` are read as data coordinates on the same axes.
        """
        fig.add_annotation(row=row, col=col, **annotation_kwargs)
        arrow = fig.layout.annotations[-1]  # the annotation appended just above
        arrow.update(axref=arrow.xref, ayref=arrow.yref)

    def add_container(self, fig: go.Figure, x: float, y: float, width: float, height: float,
                      color: str, line_width: int = 1, row: int = 1, col: int = 1) -> None:
        """Add a container/background box with its lower-left corner at (x, y)."""
        fig.add_shape(
            type="rect",
            x0=x, y0=y, x1=x + width, y1=y + height,
            fillcolor=color,
            line=dict(color='black', width=line_width),
            layer="below",
            row=row, col=col,
        )

    def add_layer_box(self, fig: go.Figure, x: float, y: float, width: float, height: float,
                      text: str, color: str, hover_text: Optional[str] = None,
                      row: int = 1, col: int = 1, text_size: int = 7) -> None:
        """Add a labelled rectangle representing a layer.

        Draws the box, a centred text label and, when ``hover_text`` is given,
        an invisible marker at the box centre that carries the hover tooltip.
        """
        center_x = x + width / 2
        center_y = y + height / 2

        # Add the box shape
        fig.add_shape(
            type="rect",
            x0=x, y0=y, x1=x + width, y1=y + height,
            fillcolor=color,
            line=dict(color='black', width=1),
            layer="below",
            row=row, col=col,
        )

        # Add text label
        fig.add_annotation(
            x=center_x, y=center_y,
            text=text,
            showarrow=False,
            font=dict(size=text_size, color=self.universal_colors['text']),
            bgcolor=self.universal_colors['callout_bg'],
            bordercolor="black",
            borderwidth=1,
            row=row, col=col,
        )

        # Add invisible scatter point for hover functionality
        if hover_text:
            fig.add_trace(
                go.Scatter(
                    x=[center_x], y=[center_y],
                    mode='markers',
                    marker=dict(size=12, opacity=0),
                    # Plotly text uses <br> for line breaks.
                    hovertemplate=f"{text}<br>{hover_text}",
                    showlegend=False,
                    name=text,
                ),
                row=row, col=col,
            )

    def add_moe_router_visualization(self, fig: go.Figure, x: float, y: float,
                                     config_values: Dict[str, Any],
                                     row: int = 1, col: int = 1) -> None:
        """Add a router box with up to three expert boxes fanned out beneath it.

        Reads ``config_values['moe_params']``; an ellipsis is drawn to the right
        of the experts when more experts are active per token than are shown.
        """
        moe_params = config_values['moe_params']
        active = moe_params['experts_per_token']
        active_count = active if isinstance(active, int) else None

        # Router box - positioned more centrally
        router_width, router_height = 0.4, 0.12
        router_x = x + 0.2  # Center it better within the available space
        self.add_layer_box(
            fig, router_x, y, router_width, router_height,
            "Router", self.universal_colors['router'],
            f"{active} experts activated<br>from {moe_params['num_experts']} total",
            row, col, 6,
        )

        # Expert boxes - positioned with better spacing
        expert_y = y - 0.25  # Closer to router
        expert_width, expert_height = 0.18, 0.1
        expert_gap = 0.04
        experts_to_show = 3 if active_count is None else min(3, active_count)

        # Center the experts under the router
        total_expert_width = experts_to_show * expert_width + (experts_to_show - 1) * expert_gap
        experts_start_x = router_x + (router_width - total_expert_width) / 2

        for i in range(experts_to_show):
            expert_x = experts_start_x + i * (expert_width + expert_gap)
            self.add_layer_box(
                fig, expert_x, expert_y, expert_width, expert_height,
                f"Expert\n{i + 1}", self.universal_colors['expert'],
                f"MoE intermediate size: {moe_params['moe_intermediate_size']}",
                row, col, 5,
            )
            # Arrow from router to expert - pointing downward
            self.add_connection_arrow(
                fig,
                router_x + router_width / 2, y,
                expert_x + expert_width / 2, expert_y + expert_height,
                row, col,
            )

        # Add "..." if more experts exist - positioned to the right
        if active_count is not None and experts_to_show < active_count:
            fig.add_annotation(
                x=experts_start_x + experts_to_show * (expert_width + expert_gap) + 0.05,
                y=expert_y + expert_height / 2,
                text="...",
                showarrow=False,
                font=dict(size=8, color=self.universal_colors['text']),
                row=row, col=col,
            )

    def add_side_panel(self, fig: go.Figure, x: float, y: float, width: float, height: float,
                       title: str, components: List[str], config_values: Dict[str, Any],
                       row: int = 1, col: int = 1) -> None:
        """Add a dashed side panel listing ``components`` top to bottom under ``title``.

        Component colour is chosen from the label: labels containing "Linear"
        use the output colour, activation labels use the MoE or dense
        feed-forward colour depending on ``config_values['is_moe']``, and
        everything else uses the embedding colour.
        """
        # Panel container with dashed border
        fig.add_shape(
            type="rect",
            x0=x, y0=y, x1=x + width, y1=y + height,
            fillcolor=self.universal_colors['callout_bg'],
            line=dict(color='gray', width=1, dash='dash'),
            layer="below",
            row=row, col=col,
        )

        # Panel title
        fig.add_annotation(
            x=x + width / 2, y=y + height - 0.08,
            text=f"{title}",
            showarrow=False,
            font=dict(size=8, color=self.universal_colors['text']),
            row=row, col=col,
        )

        # Component boxes
        component_height = 0.1
        start_y = y + height - 0.2
        activation_color = (
            self.universal_colors['feedforward_moe'] if config_values['is_moe']
            else self.universal_colors['feedforward_dense']
        )

        for i, component in enumerate(components):
            comp_y = start_y - i * (component_height + 0.03)
            if "Linear" in component:
                color = self.universal_colors['output']
            elif "activation" in component.lower() or "SiLU" in component or "ReLU" in component:
                color = activation_color
            else:
                color = self.universal_colors['embedding']

            self.add_layer_box(
                fig, x + 0.03, comp_y, width - 0.06, component_height,
                component, color, None, row, col, 6,
            )

    def add_connection_arrow(self, fig: go.Figure, start_x: float, start_y: float,
                             end_x: float, end_y: float, row: int = 1, col: int = 1) -> None:
        """Add an arrow from (start_x, start_y) to (end_x, end_y) in subplot (row, col).

        Both ends are in that subplot's data coordinates; the arrow head is at
        the end point.
        """
        self._add_data_arrow(
            fig, row, col,
            x=end_x, y=end_y,
            ax=start_x, ay=start_y,
            showarrow=True,
            arrowhead=2,
            arrowsize=1,
            arrowwidth=1.5,
            arrowcolor='black',
        )

    def create_single_model_diagram(self, fig: go.Figure, model_name: str,
                                    config_values: Dict[str, Any],
                                    row: int = 1, col: int = 1) -> None:
        """Add a single model's architecture to subplot (row, col).

        Draws the title, the outer and inner containers, the layer stack with
        connecting arrows, the layer-repetition badge, the feed-forward side
        panel and, for MoE models, the router/expert visualization.

        Args:
            fig: Figure created with ``make_subplots``.
            model_name: Hugging Face model id; only the part after the last
                ``/`` is shown in the title.
            config_values: Dict produced by ``extract_config_values``.
            row: 1-based subplot row.
            col: 1-based subplot column.
        """
        colors = self.universal_colors
        is_moe = config_values['is_moe']
        vocab_label = _format_count(config_values['vocab_size'])
        activation_label = str(config_values['activation']).upper()

        # Layout parameters
        base_x, base_y = 0.3, 0.2
        main_width, main_height = 2.2, 2.8
        layer_width, layer_height = 1.8, 0.2

        # Model title with size
        title_text = model_name.split('/')[-1]
        if config_values['estimated_size'] != "Unknown":
            title_text += f" ({config_values['estimated_size']})"

        fig.add_annotation(
            x=base_x + main_width / 2, y=base_y + main_height + 0.2,
            text=title_text,
            showarrow=False,
            font=dict(size=10, color=colors['accent_blue']),
            row=row, col=col,
        )

        # Outer container (main frame)
        self.add_container(
            fig, base_x - 0.1, base_y - 0.1, main_width + 0.2, main_height + 0.2,
            colors['container_outer'], 2, row, col,
        )

        # Inner container (colored by architecture type)
        inner_color = colors['moe_inner'] if is_moe else colors['dense_inner']
        self.add_container(
            fig, base_x + 0.1, base_y + 0.8, main_width - 0.2, main_height - 1.2,
            inner_color, 1, row, col,
        )

        # Layer definitions with universal colors: (label, y, fill, hover text)
        layers = [
            ('Token Embedding', base_y + 0.3, colors['embedding'],
             f"Vocab: {vocab_label}<br>Embedding dim: {config_values['hidden_size']}"),
            ('Layer Norm', base_y + 0.6, colors['layer_norm'], 'Input normalization'),
            (f'Multi-Head Attention\n({config_values["num_heads"]} heads)', base_y + 0.9,
             colors['attention'],
             f"Heads: {config_values['num_heads']}<br>"
             f"Hidden: {config_values['hidden_size']}<br>"
             f"KV Heads: {config_values['kv_heads']}"),
            ('Layer Norm', base_y + 1.2, colors['layer_norm'], 'Post-attention norm'),
        ]

        # Add MoE or Dense FFN layer
        if is_moe:
            moe_params = config_values['moe_params']
            layers.append((
                'MoE Feed Forward', base_y + 1.5, colors['feedforward_moe'],
                f"Experts: {moe_params['num_experts']}<br>"
                f"Active per token: {moe_params['experts_per_token']}<br>"
                f"MoE intermediate: {moe_params['moe_intermediate_size']}",
            ))
        else:
            layers.append((
                'Feed Forward Network', base_y + 1.5, colors['feedforward_dense'],
                f"Intermediate size: {config_values['intermediate_size']}<br>"
                f"Activation: {config_values['activation']}",
            ))

        layers.extend([
            ('Layer Norm', base_y + 1.8, colors['layer_norm'], 'Post-FFN normalization'),
            ('Output Projection', base_y + 2.1, colors['output'],
             f"Projects to vocab: {vocab_label}"),
        ])

        # Add all layers
        layer_x = base_x + (main_width - layer_width) / 2
        layer_centers = []
        for layer_name, y_pos, color, hover_info in layers:
            self.add_layer_box(
                fig, layer_x, y_pos, layer_width, layer_height,
                layer_name, color, hover_info, row, col,
            )
            layer_centers.append((layer_x + layer_width / 2, y_pos + layer_height / 2))

        # Add arrows between layers (top edge of one box to bottom edge of the next)
        for (start_x, start_y), (end_x, end_y) in zip(layer_centers, layer_centers[1:]):
            arrow_start_y = start_y + layer_height / 2
            arrow_end_y = end_y - layer_height / 2
            if arrow_end_y > arrow_start_y:
                self.add_connection_arrow(fig, start_x, arrow_start_y, end_x, arrow_end_y, row, col)

        # Add layer repetition indicator
        if isinstance(config_values['num_layers'], int) and config_values['num_layers'] > 1:
            fig.add_annotation(
                x=base_x - 0.05, y=base_y + 1.4,
                text=f"×{config_values['num_layers']}<br>layers",
                showarrow=False,
                font=dict(size=7, color=colors['text']),
                bgcolor=colors['callout_bg'],
                bordercolor="black",
                borderwidth=1,
                row=row, col=col,
            )

        # Add side panel for component details
        panel_x = base_x + main_width + 0.3
        panel_y = base_y + 1.5  # Moved up to avoid MoE visualization
        panel_width = 0.7
        panel_height = 0.8

        components = [
            "Linear layer",
            f"{activation_label} activation",
            "Linear layer",
        ]
        if is_moe:
            # MoE side panel
            components.extend([
                "Router",
                f"{config_values['moe_params']['experts_per_token']} active experts",
            ])
            panel_title = "MoE Module"
        else:
            # Dense FFN side panel
            panel_title = "FeedForward Module"

        self.add_side_panel(fig, panel_x, panel_y, panel_width, panel_height,
                            panel_title, components, config_values, row, col)

        # Add MoE router visualization if applicable
        if is_moe:
            # Position router visualization below side panel with better spacing
            router_x = panel_x + 0.05
            router_y = panel_y - 0.5
            self.add_moe_router_visualization(fig, router_x, router_y, config_values, row, col)

    def add_callout(self, fig: go.Figure, point_x: float, point_y: float,
                    text_x: float, text_y: float, text: str,
                    row: int = 1, col: int = 1) -> None:
        """Add a callout with leader line - arrow points FROM point TO text.

        Both positions are in the data coordinates of subplot (row, col).
        """
        self._add_data_arrow(
            fig, row, col,
            x=text_x, y=text_y,  # Text position
            ax=point_x, ay=point_y,  # Arrow start position (the component being referenced)
            text=text,
            showarrow=True,
            arrowhead=2,
            arrowsize=1,
            arrowwidth=1,
            arrowcolor='gray',
            font=dict(size=7),
            bgcolor=self.universal_colors['callout_bg'],
            bordercolor="gray",
            borderwidth=1,
        )

    def create_comparison_diagram(self, models_data: List[Tuple[str, Dict[str, Any]]]) -> go.Figure:
        """Create a comparison diagram with one subplot column per model.

        Args:
            models_data: ``(model_name, config_values)`` pairs, where
                ``config_values`` comes from ``extract_config_values``.

        Returns:
            A single-row figure with one column per entry, or an empty
            ``go.Figure`` when ``models_data`` is empty.
        """
        num_models = len(models_data)
        if num_models == 0:
            return go.Figure()

        # Create subplots - always use single row layout, one column per model
        fig = make_subplots(
            rows=1, cols=num_models,
            subplot_titles=[model_name for model_name, _ in models_data],
        )

        # Set up layout
        fig.update_layout(
            height=700,
            width=1200,
            showlegend=False,
            title_text="🧠 Model Architecture Comparison",
            title_x=0.5,
            title_font=dict(size=18),
        )

        # Add each model to its subplot
        for col, (model_name, config_values) in enumerate(models_data, start=1):
            self.create_single_model_diagram(fig, model_name, config_values, 1, col)

        # Update axes to hide ticks and labels - expanded range for callouts
        fig.update_xaxes(showgrid=False, showticklabels=False, zeroline=False, range=[0, 5.0])
        fig.update_yaxes(showgrid=False, showticklabels=False, zeroline=False, range=[-0.5, 3.5])

        return fig

    def generate_visualization(self, model_names: List[str]) -> Union[go.Figure, str]:
        """Generate the comparison figure for the given model ids.

        Blank entries are ignored. Models whose config cannot be loaded are
        skipped with a logged warning as long as at least one model loads.

        Returns:
            The figure on success, otherwise a human-readable error message.
        """
        # Filter out empty model names
        valid_models = [name.strip() for name in model_names if name and name.strip()]

        if not valid_models:
            return "Please enter at least one model name."

        models_data = []
        errors = []

        for model_name in valid_models:
            try:
                config = self.get_model_config(model_name)
                if config:
                    models_data.append((model_name, self.extract_config_values(config)))
                else:
                    errors.append(f"Could not load config for {model_name}")
            except Exception as e:
                errors.append(f"Error with {model_name}: {str(e)}")

        if not models_data:
            return f"❌ Could not load any models. Errors: {'; '.join(errors)}"

        if errors:
            logger.warning("Some models failed to load: %s", errors)

        try:
            return self.create_comparison_diagram(models_data)
        except Exception as e:
            return f"❌ Error generating diagram: {str(e)}"


def create_gradio_interface():
    """Create and configure the Gradio interface.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    visualizer = PlotlyModelArchitectureVisualizer()

    def process_models(model1: str, model2: str = "", model3: str = "") -> go.Figure:
        """Process the model inputs and generate visualization.

        Raises:
            gr.Error: When ``generate_visualization`` returns an error message,
                because the output component is a plot and cannot show text.
        """
        result = visualizer.generate_visualization([model1, model2, model3])
        if isinstance(result, str):
            raise gr.Error(result)
        return result

    # Create the interface
    with gr.Blocks(
        title="🧠 Model Architecture Visualizer",
        theme=gr.themes.Soft(),
        css=_APP_CSS,
    ) as demo:
        gr.Markdown(_INTRO_MARKDOWN)

        # Model inputs in a single row
        gr.Markdown("### 📝 Model Configuration")
        with gr.Row():
            model1 = gr.Textbox(
                label="Model 1 (Required)",
                placeholder="e.g., openai/gpt-oss-120b",
                value=_DEFAULT_MODEL,
                elem_classes=["model-input"],
            )
            model2 = gr.Textbox(
                label="Model 2 (Optional)",
                placeholder="e.g., moonshotai/Kimi-K2-Base",
                elem_classes=["model-input"],
            )
            model3 = gr.Textbox(
                label="Model 3 (Optional)",
                placeholder="e.g., deepseek-ai/DeepSeek-R1-0528",
                elem_classes=["model-input"],
            )

        with gr.Row():
            generate_btn = gr.Button("🚀 Generate Visualization", variant="primary", size="lg")
            clear_btn = gr.Button("🗑️ Clear", variant="secondary")

        # Visualization output - full width
        output_plot = gr.Plot(
            label="🧠 Architecture Visualization",
            show_label=True,
        )

        # Event handlers
        generate_btn.click(
            fn=process_models,
            inputs=[model1, model2, model3],
            outputs=output_plot,
        )

        clear_btn.click(
            fn=lambda: ("", "", "", None),
            outputs=[model1, model2, model3, output_plot],
        )

        # Auto-generate for default model
        demo.load(
            fn=lambda: process_models(_DEFAULT_MODEL),
            outputs=output_plot,
        )

        gr.Markdown("""Built with ❤️ using Plotly, Gradio, and Hugging Face Transformers""")

    return demo


if __name__ == "__main__":
    # Create and launch the app
    demo = create_gradio_interface()

    # For HuggingFace Spaces deployment
    demo.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )