# predict_memory / app.py
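"""Gradio app that estimates training-time GPU memory for a transformer model.

It parses a model/training configuration (pasted JSON or YAML, or entered manually
in the UI), displays the derived settings, plots a memory breakdown via
`utils.plot_memory_breakdown`, and reports a rough OOM prediction.
"""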
import gradio as gr
import matplotlib.pyplot as plt
import yaml
import json
from pathlib import Path
import io
from utils import calculate_memory_components, plot_memory_breakdown

def load_config_from_content(content):
    """Parse a pasted JSON (Hugging Face `config.json`-style) or structured YAML
    training config into the flat dict of keyword arguments expected by
    `plot_memory_breakdown`."""
    try:
        # Try parsing as JSON first
        try:
            config = json.loads(content)
            # Check if this is a multimodal model with a nested text_config
            if 'text_config' in config:
                # Use text_config for the language-model parameters
                text_config = config['text_config']
                return {
                    'hidden_size': text_config['hidden_size'],
                    'num_layers': text_config['num_hidden_layers'],
                    'vocab_size': config.get('vocab_size', 256000),  # Default for multimodal models
                    'intermediate_size': text_config['intermediate_size'],
                    'seq_len': 2048,  # Default value since not in config
                    'mbs': 1,  # Default value
                    'batch_accum': 1,  # Default value
                    'tp': 1,  # Default value
                    'pp': 1,  # Default value
                    'dp': 1,  # Default value
                    'zero_stage': 0,  # Default value
                    'tie_word_embeddings': config.get('tie_word_embeddings', True),
                    'num_attention_heads': text_config['num_attention_heads'],
                    'num_key_value_heads': text_config.get('num_key_value_heads', text_config['num_attention_heads']),
                    'full_checkpointing': False  # Default value
                }
            else:
                # Text-only models keep everything at the top level
                return {
                    'hidden_size': config['hidden_size'],
                    'num_layers': config['num_hidden_layers'],
                    'vocab_size': config['vocab_size'],
                    'intermediate_size': config['intermediate_size'],
                    'seq_len': 2048,  # Default value since not in config
                    'mbs': 1,  # Default value
                    'batch_accum': 1,  # Default value
                    'tp': 1,  # Default value
                    'pp': 1,  # Default value
                    'dp': 1,  # Default value
                    'zero_stage': 0,  # Default value
                    'tie_word_embeddings': config.get('tie_word_embeddings', True),
                    'num_attention_heads': config['num_attention_heads'],
                    'num_key_value_heads': config.get('num_key_value_heads', config['num_attention_heads']),
                    'full_checkpointing': False  # Default value
                }
        except json.JSONDecodeError:
            # If not JSON, try YAML
            config = yaml.safe_load(content)
            # Extract relevant parameters from the YAML config
            model_config = config['model']['model_config']
            parallelism = config['parallelism']
            tokens = config['tokens']
            optimizer = config['optimizer']
            return {
                'hidden_size': model_config['hidden_size'],
                'num_layers': model_config['num_hidden_layers'],
                'vocab_size': model_config['vocab_size'],
                'intermediate_size': model_config['intermediate_size'],
                'seq_len': tokens['sequence_length'],
                'mbs': tokens['micro_batch_size'],
                'batch_accum': tokens['batch_accumulation_per_replica'],
                'tp': parallelism['tp'],
                'pp': parallelism['pp'],
                'dp': parallelism['dp'],
                'zero_stage': optimizer['zero_stage'],
                'tie_word_embeddings': model_config['tie_word_embeddings'],
                'num_attention_heads': model_config['num_attention_heads'],
                'num_key_value_heads': model_config.get('num_key_value_heads', model_config['num_attention_heads']),
                'full_checkpointing': optimizer.get('full_checkpointing', False)  # Renamed from fsdp_checkpointing
            }
    except Exception as e:
        raise gr.Error(f"Error parsing configuration: {str(e)}")
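
# Illustrative only: a minimal JSON config in the shape the parser above accepts for a
# text-only model. The values are hypothetical and not tied to any particular checkpoint;
# training settings not present in the JSON (seq_len, batch sizes, parallelism) fall back
# to the defaults above. Usage: load_config_from_content(_EXAMPLE_JSON_CONFIG)
_EXAMPLE_JSON_CONFIG = """\
{
  "hidden_size": 4096,
  "num_hidden_layers": 32,
  "vocab_size": 32000,
  "intermediate_size": 11008,
  "num_attention_heads": 32,
  "num_key_value_heads": 8,
  "tie_word_embeddings": false
}
"""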

def load_config_from_yaml_file(yaml_path):
    """Read an uploaded config file and delegate to `load_config_from_content`.
    (Not currently wired to any UI component.)"""
    if not yaml_path:
        return None
    with open(yaml_path.name, 'r') as f:
        return load_config_from_content(f.read())
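
# Illustrative only: a YAML config with the nested layout the parser expects
# (model.model_config / parallelism / tokens / optimizer). Values are hypothetical.
_EXAMPLE_YAML_CONFIG = """\
model:
  model_config:
    hidden_size: 4096
    num_hidden_layers: 32
    vocab_size: 32000
    intermediate_size: 11008
    num_attention_heads: 32
    num_key_value_heads: 8
    tie_word_embeddings: false
parallelism:
  tp: 2
  pp: 1
  dp: 4
tokens:
  sequence_length: 4096
  micro_batch_size: 1
  batch_accumulation_per_replica: 8
optimizer:
  zero_stage: 1
  full_checkpointing: false
"""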

def format_config_display(config):
    """Render the loaded configuration as three side-by-side HTML columns."""
    if not config:
        return "No configuration loaded"
    # Approximate parameter count (embeddings + transformer blocks; norm weights and
    # biases are ignored, which changes the total only marginally)
    vocab_embeddings = config['vocab_size'] * config['hidden_size'] * (1 if config['tie_word_embeddings'] else 2)
    layer_params = (
        (config['hidden_size'] * config['hidden_size'] * (1 + 2 * config['num_key_value_heads'] / config['num_attention_heads']))  # qkv_proj
        + (config['hidden_size'] * config['hidden_size'])  # out_proj
        + (config['hidden_size'] * 2 * config['intermediate_size'])  # gate_up_proj
        + (config['intermediate_size'] * config['hidden_size'])  # down_proj
    )
    total_params = vocab_embeddings + config['num_layers'] * layer_params
    params_billions = total_params / 1_000_000_000
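    # Sanity check (hand-computed, assuming a LLaMA-7B-like config: hidden_size=4096,
    # 32 layers, vocab 32000, intermediate 11008, 32 attention = 32 KV heads, untied embeddings):
    #   embeddings: 2 * 32000 * 4096                              ≈ 0.26B
    #   per layer:  4096*4096*(3 + 1) + 4096*2*11008 + 11008*4096 ≈ 0.20B
    #   total:      0.26B + 32 * 0.20B                            ≈ 6.7B, matching the ~6.7B usually quoted.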
    sections = {
        "Model Architecture": [
            "hidden_size", "num_layers", "vocab_size",
            "intermediate_size", "tie_word_embeddings", "num_attention_heads", "num_key_value_heads",
            ("num_params", f"{params_billions:.2f}B")  # Show params in billions
        ],
        "Training Configuration": [
            "seq_len", "mbs", "batch_accum"
        ],
        "Parallelism": [
            "tp", "pp", "dp", "zero_stage", "full_checkpointing"
        ]
    }
    output = "<div style='display: flex;'>"
    for section_name, params in sections.items():
        output += f"<div style='flex: 1; padding-right: 20px;'><h3>{section_name}</h3>"
        for param in params:
            if isinstance(param, tuple):
                # Handle custom parameter display
                param_name, value = param
                output += f"<b>{param_name}</b>: {value}<br>"
            else:
                value = config.get(param, 'N/A')
                output += f"<b>{param}</b>: {value}<br>"
        output += "</div>"
    output += "</div>"
    return output

def process_yaml_and_plot(config):
    """Compute the memory plots, config display, and OOM verdict for a parsed config dict."""
    if not config:
        return None, None, "No configuration loaded", None
    fig1, fig2, memory_usage_peak_tbi = plot_memory_breakdown(**config)
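    # The 75000 cutoff below is assumed to be in MiB, i.e. roughly an 80GB accelerator
    # minus some headroom; adjust it if targeting different hardware.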
    oom_prediction = "OOM" if memory_usage_peak_tbi > 75000 else "No OOM"
    return fig1, fig2, format_config_display(config), oom_prediction

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Accordion("Configuration Input", open=True):
                config_text = gr.Textbox(
                    label="Paste YAML or JSON configuration",
                    placeholder="Paste your YAML or JSON configuration here...",
                    lines=10
                )
                config_submit = gr.Button("Calculate Memory from Config")
            with gr.Accordion("Manual Configuration", open=True):
                with gr.Accordion("Model Architecture", open=True):
                    with gr.Row():
                        hidden_size = gr.Number(4096, label="Hidden Size")
                        num_layers = gr.Number(32, label="Number of Layers")
                    with gr.Row():
                        vocab_size = gr.Number(50432, label="Vocabulary Size")
                        intermediate_size = gr.Number(11008, label="Intermediate Size")
                    with gr.Row():
                        num_attention_heads = gr.Number(32, label="Number of Attention Heads")
                        num_key_value_heads = gr.Number(32, label="Number of Key Value Heads")
                    tie_word_embeddings = gr.Checkbox(True, label="Tie Word Embeddings")
                with gr.Accordion("Training Configuration", open=True):
                    with gr.Row():
                        seq_len = gr.Number(2048, label="Sequence Length")
                        mbs = gr.Number(1, label="Micro Batch Size")
                        batch_accum = gr.Number(1, label="Gradient Accumulation Steps")
                with gr.Accordion("Parallelism", open=True):
                    with gr.Row():
                        tp = gr.Number(1, label="Tensor Parallelism")
                        pp = gr.Number(1, label="Pipeline Parallelism")
                        dp = gr.Number(1, label="Data Parallelism")
                    zero_stage = gr.Radio([0, 1, 2, 3], value=0, label="ZeRO Stage")
                    full_checkpointing = gr.Checkbox(False, label="Full Activation Checkpointing")
                manual_submit = gr.Button("Calculate Memory (Manual Input)")
        with gr.Column(scale=2):
            config_display = gr.Markdown(label="Configuration Values")
            oom_display = gr.Text(label="OOM Prediction")
            plot1 = gr.Plot(label="Memory Component Breakdown")
            plot2 = gr.Plot(label="Aggregate Memory Metrics")
    # Handle config text input. Note: process_yaml_and_update_ui is defined below;
    # the lambda only resolves the name when the button is clicked, so this is safe.
    config_submit.click(
        lambda x: process_yaml_and_update_ui(load_config_from_content(x) if x else None),
        inputs=[config_text],
        outputs=[
            plot1, plot2, config_display, oom_display,
            hidden_size, num_attention_heads, num_key_value_heads, num_layers,
            vocab_size, intermediate_size, seq_len, mbs, batch_accum,
            tp, pp, dp, zero_stage, tie_word_embeddings, full_checkpointing
        ]
    )
    def process_yaml_and_update_ui(config):
        if not config:
            # 4 display outputs plus one gr.update() per manual-input component (15 of them)
            return [None, None, "No configuration loaded", None] + [gr.update() for _ in range(15)]
        fig1, fig2, memory_usage_peak_tbi = plot_memory_breakdown(**config)
        oom_prediction = "OOM" if memory_usage_peak_tbi > 75000 else "No OOM"
        # Return values for all outputs including UI updates
        return [
            fig1, fig2,
            format_config_display(config),
            oom_prediction,
            # UI component updates
            config['hidden_size'],
            config['num_attention_heads'],
            config['num_key_value_heads'],
            config['num_layers'],
            config['vocab_size'],
            config['intermediate_size'],
            config['seq_len'],
            config['mbs'],
            config['batch_accum'],
            config['tp'],
            config['pp'],
            config['dp'],
            config['zero_stage'],
            config['tie_word_embeddings'],
            config['full_checkpointing']
        ]
    # Handle manual input
    def manual_input_to_config(*args):
        config = {
            'hidden_size': args[0],
            'num_layers': args[3],
            'vocab_size': args[4],
            'intermediate_size': args[5],
            'seq_len': args[6],
            'mbs': args[7],
            'batch_accum': args[8],
            'tp': args[9],
            'pp': args[10],
            'dp': args[11],
            'zero_stage': args[12],
            'tie_word_embeddings': args[13],
            'num_attention_heads': args[1],
            'num_key_value_heads': args[2],
            'full_checkpointing': args[14]  # Renamed from fsdp_checkpointing
        }
        # manual_submit only declares the four display outputs, so return exactly those
        # (process_yaml_and_update_ui would also emit the 15 component updates).
        return process_yaml_and_plot(config)
    manual_submit.click(
        manual_input_to_config,
        inputs=[
            hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
            seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
            tie_word_embeddings, full_checkpointing  # Renamed from fsdp_checkpointing
        ],
        outputs=[plot1, plot2, config_display, oom_display]
    )

if __name__ == "__main__":
    demo.launch()