import gradio as gr
import matplotlib.pyplot as plt
import yaml
import json
from pathlib import Path
import io

from utils import calculate_memory_components, plot_memory_breakdown


def load_config_from_content(content):
    try:
        # Try parsing as JSON first
        try:
            config = json.loads(content)

            # Check if this is a multimodal model with text_config
            if 'text_config' in config:
                # Use text_config for the model parameters
                text_config = config['text_config']
                return {
                    'hidden_size': text_config['hidden_size'],
                    'num_layers': text_config['num_hidden_layers'],
                    'vocab_size': config.get('vocab_size', 256000),  # Default for multimodal models
                    'intermediate_size': text_config['intermediate_size'],
                    'seq_len': 2048,   # Default value since not in config
                    'mbs': 1,          # Default value
                    'batch_accum': 1,  # Default value
                    'tp': 1,           # Default value
                    'pp': 1,           # Default value
                    'dp': 1,           # Default value
                    'zero_stage': 0,   # Default value
                    'tie_word_embeddings': config.get('tie_word_embeddings', True),
                    'num_attention_heads': text_config['num_attention_heads'],
                    'num_key_value_heads': text_config.get('num_key_value_heads', text_config['num_attention_heads']),
                    'full_checkpointing': False  # Default value
                }
            else:
                # Standard (text-only) model config
                return {
                    'hidden_size': config['hidden_size'],
                    'num_layers': config['num_hidden_layers'],
                    'vocab_size': config['vocab_size'],
                    'intermediate_size': config['intermediate_size'],
                    'seq_len': 2048,   # Default value since not in config
                    'mbs': 1,          # Default value
                    'batch_accum': 1,  # Default value
                    'tp': 1,           # Default value
                    'pp': 1,           # Default value
                    'dp': 1,           # Default value
                    'zero_stage': 0,   # Default value
                    'tie_word_embeddings': config.get('tie_word_embeddings', True),
                    'num_attention_heads': config['num_attention_heads'],
                    'num_key_value_heads': config.get('num_key_value_heads', config['num_attention_heads']),
                    'full_checkpointing': False  # Default value
                }
        except json.JSONDecodeError:
            # Not JSON: fall back to a YAML training config
            config = yaml.safe_load(content)

            # Extract the relevant parameters from the YAML config
            model_config = config['model']['model_config']
            parallelism = config['parallelism']
            tokens = config['tokens']
            optimizer = config['optimizer']

            return {
                'hidden_size': model_config['hidden_size'],
                'num_layers': model_config['num_hidden_layers'],
                'vocab_size': model_config['vocab_size'],
                'intermediate_size': model_config['intermediate_size'],
                'seq_len': tokens['sequence_length'],
                'mbs': tokens['micro_batch_size'],
                'batch_accum': tokens['batch_accumulation_per_replica'],
                'tp': parallelism['tp'],
                'pp': parallelism['pp'],
                'dp': parallelism['dp'],
                'zero_stage': optimizer['zero_stage'],
                'tie_word_embeddings': model_config['tie_word_embeddings'],
                'num_attention_heads': model_config['num_attention_heads'],
                'num_key_value_heads': model_config.get('num_key_value_heads', model_config['num_attention_heads']),
                'full_checkpointing': optimizer.get('full_checkpointing', False)  # Renamed from fsdp_checkpointing
            }
    except Exception as e:
        raise gr.Error(f"Error parsing configuration: {str(e)}")


def load_config_from_yaml_file(yaml_path):
    if not yaml_path:
        return None
    with open(yaml_path.name, 'r') as f:
        return load_config_from_content(f.read())

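# Illustrative only: the YAML layout that load_config_from_content expects on the YAML
# path. The key names come from the lookups above; the values are placeholders, not
# recommendations, and the EXAMPLE_YAML_CONFIG constant itself is not used by the app.
# A plain JSON model config works too, as long as it carries hidden_size,
# num_hidden_layers, vocab_size, intermediate_size, num_attention_heads and
# (optionally) num_key_value_heads / tie_word_embeddings.
EXAMPLE_YAML_CONFIG = """\
model:
  model_config:
    hidden_size: 4096
    num_hidden_layers: 32
    vocab_size: 50432
    intermediate_size: 11008
    num_attention_heads: 32
    num_key_value_heads: 8
    tie_word_embeddings: true
parallelism:
  tp: 1
  pp: 1
  dp: 8
tokens:
  sequence_length: 2048
  micro_batch_size: 1
  batch_accumulation_per_replica: 1
optimizer:
  zero_stage: 1
  full_checkpointing: false
"""
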
" for section_name, params in sections.items(): output += f"

{section_name}

" for param in params: if isinstance(param, tuple): # Handle custom parameter display param_name, value = param output += f"{param_name}: {value}
" else: value = config.get(param, 'N/A') output += f"{param}: {value}
" output += "
" output += "
" return output def process_yaml_and_plot(config): if not config: return None, None, "No configuration loaded", None fig1, fig2, memory_usage_peak_tbi = plot_memory_breakdown(**config) oom_prediction = "OOM" if memory_usage_peak_tbi > 75000 else "No OOM" return fig1, fig2, format_config_display(config), oom_prediction with gr.Blocks() as demo: with gr.Row(): with gr.Column(scale=1): with gr.Accordion("Configuration Input", open=True): config_text = gr.Textbox( label="Paste YAML or JSON configuration", placeholder="Paste your YAML or JSON configuration here...", lines=10 ) config_submit = gr.Button("Calculate Memory from Config") with gr.Accordion("Manual Configuration", open=True): with gr.Accordion("Model Architecture", open=True): with gr.Row(): hidden_size = gr.Number(4096, label="Hidden Size") num_layers = gr.Number(32, label="Number of Layers") with gr.Row(): vocab_size = gr.Number(50432, label="Vocabulary Size") intermediate_size = gr.Number(11008, label="Intermediate Size") with gr.Row(): num_attention_heads = gr.Number(32, label="Number of Attention Heads") num_key_value_heads = gr.Number(32, label="Number of Key Value Heads") tie_word_embeddings = gr.Checkbox(True, label="Tie Word Embeddings") with gr.Accordion("Training Configuration", open=True): with gr.Row(): seq_len = gr.Number(2048, label="Sequence Length") mbs = gr.Number(1, label="Micro Batch Size") batch_accum = gr.Number(1, label="Gradient Accumulation Steps") with gr.Accordion("Parallelism", open=True): with gr.Row(): tp = gr.Number(1, label="Tensor Parallelism") pp = gr.Number(1, label="Pipeline Parallelism") dp = gr.Number(1, label="Data Parallelism") zero_stage = gr.Radio([0, 1, 2, 3], value=0, label="ZeRO Stage") full_checkpointing = gr.Checkbox(False, label="Full Activation Checkpointing") manual_submit = gr.Button("Calculate Memory (Manual Input)") with gr.Column(scale=2): config_display = gr.Markdown(label="Configuration Values") oom_display = gr.Text(label="OOM Prediction") plot1 = gr.Plot(label="Memory Component Breakdown") plot2 = gr.Plot(label="Aggregate Memory Metrics") # Handle config text input config_submit.click( lambda x: process_yaml_and_update_ui(load_config_from_content(x) if x else None), inputs=[config_text], outputs=[ plot1, plot2, config_display, oom_display, hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size, seq_len, mbs, batch_accum, tp, pp, dp, zero_stage, tie_word_embeddings, full_checkpointing ] ) def process_yaml_and_update_ui(config): if not config: return [None, None, "No configuration loaded", None] + [gr.update() for _ in range(14)] fig1, fig2, memory_usage_peak_tbi = plot_memory_breakdown(**config) oom_prediction = "OOM" if memory_usage_peak_tbi > 75000 else "No OOM" # Return values for all outputs including UI updates return [ fig1, fig2, format_config_display(config), oom_prediction, # UI component updates config['hidden_size'], config['num_attention_heads'], config['num_key_value_heads'], config['num_layers'], config['vocab_size'], config['intermediate_size'], config['seq_len'], config['mbs'], config['batch_accum'], config['tp'], config['pp'], config['dp'], config['zero_stage'], config['tie_word_embeddings'], config['full_checkpointing'] ] # Handle manual input def manual_input_to_config(*args): config = { 'hidden_size': args[0], 'num_layers': args[3], 'vocab_size': args[4], 'intermediate_size': args[5], 'seq_len': args[6], 'mbs': args[7], 'batch_accum': args[8], 'tp': args[9], 'pp': args[10], 'dp': args[11], 'zero_stage': args[12], 
    # Handle manual input; args follow the order of the inputs list below
    def manual_input_to_config(*args):
        config = {
            'hidden_size': args[0],
            'num_layers': args[3],
            'vocab_size': args[4],
            'intermediate_size': args[5],
            'seq_len': args[6],
            'mbs': args[7],
            'batch_accum': args[8],
            'tp': args[9],
            'pp': args[10],
            'dp': args[11],
            'zero_stage': args[12],
            'tie_word_embeddings': args[13],
            'num_attention_heads': args[1],
            'num_key_value_heads': args[2],
            'full_checkpointing': args[14]  # Renamed from fsdp_checkpointing
        }
        # Only the first four returned values (plots, config text, OOM flag) map to this
        # event's outputs; the manual inputs already hold the submitted values.
        return process_yaml_and_update_ui(config)[:4]

    manual_submit.click(
        manual_input_to_config,
        inputs=[
            hidden_size, num_attention_heads, num_key_value_heads, num_layers,
            vocab_size, intermediate_size, seq_len, mbs, batch_accum,
            tp, pp, dp, zero_stage, tie_word_embeddings,
            full_checkpointing  # Renamed from fsdp_checkpointing
        ],
        outputs=[plot1, plot2, config_display, oom_display]
    )

if __name__ == "__main__":
    demo.launch()