import gradio as gr
import matplotlib.pyplot as plt
import yaml
import json
from pathlib import Path
import io

from utils import calculate_memory_components, plot_memory_breakdown


def load_config_from_content(content):
    """Parse a pasted model `config.json`-style JSON or a YAML training config into
    the flat keyword dict expected by plot_memory_breakdown."""
    try:
        # Try parsing as JSON first
        try:
            config = json.loads(content)
            # Check if this is a multimodal model with a nested text_config
            if 'text_config' in config:
                # Use text_config for the language-model parameters
                text_config = config['text_config']
                return {
                    'hidden_size': text_config['hidden_size'],
                    'num_layers': text_config['num_hidden_layers'],
                    'vocab_size': config.get('vocab_size', 256000),  # Default for multimodal models
                    'intermediate_size': text_config['intermediate_size'],
                    'seq_len': 2048,  # Default value since not in config
                    'mbs': 1,  # Default value
                    'batch_accum': 1,  # Default value
                    'tp': 1,  # Default value
                    'pp': 1,  # Default value
                    'dp': 1,  # Default value
                    'zero_stage': 0,  # Default value
                    'tie_word_embeddings': config.get('tie_word_embeddings', True),
                    'num_attention_heads': text_config['num_attention_heads'],
                    'num_key_value_heads': text_config.get('num_key_value_heads', text_config['num_attention_heads']),
                    'full_checkpointing': False  # Default value
                }
            else:
                # Plain (non-multimodal) model config
                return {
                    'hidden_size': config['hidden_size'],
                    'num_layers': config['num_hidden_layers'],
                    'vocab_size': config['vocab_size'],
                    'intermediate_size': config['intermediate_size'],
                    'seq_len': 2048,  # Default value since not in config
                    'mbs': 1,  # Default value
                    'batch_accum': 1,  # Default value
                    'tp': 1,  # Default value
                    'pp': 1,  # Default value
                    'dp': 1,  # Default value
                    'zero_stage': 0,  # Default value
                    'tie_word_embeddings': config.get('tie_word_embeddings', True),
                    'num_attention_heads': config['num_attention_heads'],
                    'num_key_value_heads': config.get('num_key_value_heads', config['num_attention_heads']),
                    'full_checkpointing': False  # Default value
                }
        except json.JSONDecodeError:
            # If not JSON, try YAML
            config = yaml.safe_load(content)
            # Extract the relevant parameters from the YAML training config
            model_config = config['model']['model_config']
            parallelism = config['parallelism']
            tokens = config['tokens']
            optimizer = config['optimizer']
            return {
                'hidden_size': model_config['hidden_size'],
                'num_layers': model_config['num_hidden_layers'],
                'vocab_size': model_config['vocab_size'],
                'intermediate_size': model_config['intermediate_size'],
                'seq_len': tokens['sequence_length'],
                'mbs': tokens['micro_batch_size'],
                'batch_accum': tokens['batch_accumulation_per_replica'],
                'tp': parallelism['tp'],
                'pp': parallelism['pp'],
                'dp': parallelism['dp'],
                'zero_stage': optimizer['zero_stage'],
                'tie_word_embeddings': model_config['tie_word_embeddings'],
                'num_attention_heads': model_config['num_attention_heads'],
                'num_key_value_heads': model_config.get('num_key_value_heads', model_config['num_attention_heads']),
                'full_checkpointing': optimizer.get('full_checkpointing', False)  # Renamed from fsdp_checkpointing
            }
    except Exception as e:
        raise gr.Error(f"Error parsing configuration: {str(e)}")
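

# Illustrative example (never executed by the app): a minimal JSON config using
# Hugging Face config.json-style keys. The values are placeholders, not a real model.
_EXAMPLE_JSON_CONFIG = """
{
    "hidden_size": 4096,
    "num_hidden_layers": 32,
    "vocab_size": 32000,
    "intermediate_size": 11008,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "tie_word_embeddings": false
}
"""
# load_config_from_content(_EXAMPLE_JSON_CONFIG) would return these architecture
# fields plus the defaults above for the training/parallelism settings
# (seq_len=2048, mbs=1, tp=pp=dp=1, zero_stage=0, full_checkpointing=False).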


def load_config_from_yaml_file(yaml_path):
    """Load a configuration from an uploaded YAML file object (uses its .name path)."""
    if not yaml_path:
        return None
    with open(yaml_path.name, 'r') as f:
        return load_config_from_content(f.read())


def format_config_display(config):
    """Render the parsed configuration, plus an estimated parameter count, as HTML."""
    if not config:
        return "No configuration loaded"
    # Estimate the number of parameters (embeddings + attention + MLP; norms and biases ignored)
    vocab_embeddings = config['vocab_size'] * config['hidden_size'] * (1 if config['tie_word_embeddings'] else 2)
    layer_params = (
        (config['hidden_size'] * config['hidden_size'] * (1 + 2 * config['num_key_value_heads'] / config['num_attention_heads']))  # qkv_proj
        + (config['hidden_size'] * config['hidden_size'])  # out_proj
        + (config['hidden_size'] * 2 * config['intermediate_size'])  # gate_up_proj
        + (config['intermediate_size'] * config['hidden_size'])  # down_proj
    )
    total_params = vocab_embeddings + config['num_layers'] * layer_params
    params_billions = total_params / 1_000_000_000
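    # Worked example with the default manual-input values (hidden_size=4096, num_layers=32,
    # vocab_size=50432, intermediate_size=11008, 32 attention heads, 32 KV heads, tied embeddings):
    #   vocab_embeddings = 50432 * 4096                                        ≈ 0.21B
    #   layer_params     = 4096*4096*3 + 4096*4096 + 4096*2*11008 + 11008*4096 ≈ 0.20B
    #   total_params     ≈ 0.21B + 32 * 0.20B                                  ≈ 6.68B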
    sections = {
        "Model Architecture": [
            "hidden_size", "num_layers", "vocab_size",
            "intermediate_size", "tie_word_embeddings", "num_attention_heads", "num_key_value_heads",
            ("num_params", f"{params_billions:.2f}B")  # Show params in billions
        ],
        "Training Configuration": [
            "seq_len", "mbs", "batch_accum"
        ],
        "Parallelism": [
            "tp", "pp", "dp", "zero_stage", "full_checkpointing"
        ]
    }
    output = "<div style='display: flex;'>"
    for section_name, params in sections.items():
        output += f"<div style='flex: 1; padding-right: 20px;'><h3>{section_name}</h3>"
        for param in params:
            if isinstance(param, tuple):
                # Handle custom parameter display
                param_name, value = param
                output += f"<b>{param_name}</b>: {value}<br>"
            else:
                value = config.get(param, 'N/A')
                output += f"<b>{param}</b>: {value}<br>"
        output += "</div>"
    output += "</div>"
    return output


def process_yaml_and_plot(config):
    """Compute the memory-breakdown figures, config summary and OOM verdict for a config dict."""
    if not config:
        return None, None, "No configuration loaded", None
    fig1, fig2, memory_usage_peak_tbi = plot_memory_breakdown(**config)
    oom_prediction = "OOM" if memory_usage_peak_tbi > 75000 else "No OOM"
    return fig1, fig2, format_config_display(config), oom_prediction
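
# Note: the 75000 OOM threshold used above (and in process_yaml_and_update_ui below)
# is an assumption about usable accelerator memory, roughly 75 GB of an 80 GB device
# with some headroom; it presumes plot_memory_breakdown reports its peak value in MiB.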


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Accordion("Configuration Input", open=True):
                config_text = gr.Textbox(
                    label="Paste YAML or JSON configuration",
                    placeholder="Paste your YAML or JSON configuration here...",
                    lines=10
                )
                config_submit = gr.Button("Calculate Memory from Config")
            with gr.Accordion("Manual Configuration", open=True):
                with gr.Accordion("Model Architecture", open=True):
                    with gr.Row():
                        hidden_size = gr.Number(4096, label="Hidden Size")
                        num_layers = gr.Number(32, label="Number of Layers")
                    with gr.Row():
                        vocab_size = gr.Number(50432, label="Vocabulary Size")
                        intermediate_size = gr.Number(11008, label="Intermediate Size")
                    with gr.Row():
                        num_attention_heads = gr.Number(32, label="Number of Attention Heads")
                        num_key_value_heads = gr.Number(32, label="Number of Key Value Heads")
                    tie_word_embeddings = gr.Checkbox(True, label="Tie Word Embeddings")
                with gr.Accordion("Training Configuration", open=True):
                    with gr.Row():
                        seq_len = gr.Number(2048, label="Sequence Length")
                        mbs = gr.Number(1, label="Micro Batch Size")
                        batch_accum = gr.Number(1, label="Gradient Accumulation Steps")
                with gr.Accordion("Parallelism", open=True):
                    with gr.Row():
                        tp = gr.Number(1, label="Tensor Parallelism")
                        pp = gr.Number(1, label="Pipeline Parallelism")
                        dp = gr.Number(1, label="Data Parallelism")
                    zero_stage = gr.Radio([0, 1, 2, 3], value=0, label="ZeRO Stage")
                    full_checkpointing = gr.Checkbox(False, label="Full Activation Checkpointing")
            manual_submit = gr.Button("Calculate Memory (Manual Input)")
        with gr.Column(scale=2):
            config_display = gr.Markdown(label="Configuration Values")
            oom_display = gr.Text(label="OOM Prediction")
            plot1 = gr.Plot(label="Memory Component Breakdown")
            plot2 = gr.Plot(label="Aggregate Memory Metrics")

    # Handle config text input
    config_submit.click(
        lambda x: process_yaml_and_update_ui(load_config_from_content(x) if x else None),
        inputs=[config_text],
        outputs=[
            plot1, plot2, config_display, oom_display,
            hidden_size, num_attention_heads, num_key_value_heads, num_layers,
            vocab_size, intermediate_size, seq_len, mbs, batch_accum,
            tp, pp, dp, zero_stage, tie_word_embeddings, full_checkpointing
        ]
    )
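    # The outputs list above pairs the four result components (two plots, config
    # summary, OOM verdict) with the fifteen manual-input components, so a parsed
    # config also fills in the manual form. process_yaml_and_update_ui (below)
    # must therefore return exactly 19 values.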

    def process_yaml_and_update_ui(config):
        """Like process_yaml_and_plot, but also returns values for the manual-input fields."""
        if not config:
            # One gr.update() per manual-input component (15 in total)
            return [None, None, "No configuration loaded", None] + [gr.update() for _ in range(15)]
        fig1, fig2, memory_usage_peak_tbi = plot_memory_breakdown(**config)
        oom_prediction = "OOM" if memory_usage_peak_tbi > 75000 else "No OOM"
        # Return values for all outputs, including the UI component updates
        return [
            fig1, fig2,
            format_config_display(config),
            oom_prediction,
            # UI component updates
            config['hidden_size'],
            config['num_attention_heads'],
            config['num_key_value_heads'],
            config['num_layers'],
            config['vocab_size'],
            config['intermediate_size'],
            config['seq_len'],
            config['mbs'],
            config['batch_accum'],
            config['tp'],
            config['pp'],
            config['dp'],
            config['zero_stage'],
            config['tie_word_embeddings'],
            config['full_checkpointing']
        ]

    # Handle manual input
    def manual_input_to_config(*args):
        """Assemble a config dict from the manual-input widgets (same order as the inputs list below)."""
        config = {
            'hidden_size': args[0],
            'num_layers': args[3],
            'vocab_size': args[4],
            'intermediate_size': args[5],
            'seq_len': args[6],
            'mbs': args[7],
            'batch_accum': args[8],
            'tp': args[9],
            'pp': args[10],
            'dp': args[11],
            'zero_stage': args[12],
            'tie_word_embeddings': args[13],
            'num_attention_heads': args[1],
            'num_key_value_heads': args[2],
            'full_checkpointing': args[14]  # Renamed from fsdp_checkpointing
        }
        # The manual path only drives the four result components registered below,
        # so return the four-value result rather than the 19-value UI-update list
        return process_yaml_and_plot(config)
    manual_submit.click(
        manual_input_to_config,
        inputs=[
            hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
            seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
            tie_word_embeddings, full_checkpointing  # Renamed from fsdp_checkpointing
        ],
        outputs=[plot1, plot2, config_display, oom_display]
    )


if __name__ == "__main__":
    demo.launch()
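
# Running this file directly (e.g. `python app.py`, assuming it is saved as app.py)
# starts the Gradio interface locally; on a Hugging Face Space the same entry point
# is launched automatically.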