# predict_memory/utils.py
import functools

import matplotlib.pyplot as plt
import numpy as np


@functools.lru_cache(maxsize=None)
def get_num_hidden_layers_in_pp(hidden_size, num_layers, vocab_size, intermediate_size, num_attention_heads, pp_size):
    if pp_size == 1:
        return num_layers

    # Get list of pipeline blocks and their costs
    pipeline_blocks = []
    block_costs = []

    # Embedding layer (treated as zero cost in the original implementation)
    pipeline_blocks.append("embedding")
    block_costs.append(0)

    # Decoder layers
    decoder_cost = (4 * num_attention_heads * (hidden_size // num_attention_heads) * hidden_size +
                    3 * intermediate_size * hidden_size)
    for _ in range(num_layers):
        pipeline_blocks.append("decoder")
        block_costs.append(decoder_cost)

    # LM head
    pipeline_blocks.append("lm_head")
    block_costs.append(vocab_size * hidden_size)

    # Now follow the same logic as the original code
    total_cost = sum(block_costs)
    target_cost_per_rank = total_cost / pp_size

    blocks_in_rank0 = 0
    current_cost = 0
    for block_idx, block_cost in enumerate(block_costs):
        current_cost += block_cost
        blocks_in_rank0 += 1

        # Check if we should move to the next rank
        remaining_ranks = pp_size - 1  # -1 because we're calculating for rank 0
        remaining_nonzero_blocks = sum(1 for c in block_costs[block_idx + 1:] if c > 0)
        if (remaining_ranks > 0 and remaining_nonzero_blocks <= remaining_ranks) or (current_cost >= target_cost_per_rank):
            break
    num_hidden_layers_in_pp = blocks_in_rank0 - 1  # Exclude the first block, which is the embedding layer
    return num_hidden_layers_in_pp
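
# A minimal usage sketch (added for illustration, not part of the original Space):
# the config below is a hypothetical Llama-3-8B-style model split over two pipeline
# stages. The helper returns how many decoder layers the splitting heuristic places
# on pipeline rank 0 (the embedding block is counted as free, so rank 0 usually gets
# roughly num_layers / pp_size decoder layers, give or take a few).
#
#   get_num_hidden_layers_in_pp(
#       hidden_size=4096, num_layers=32, vocab_size=128256,
#       intermediate_size=14336, num_attention_heads=32, pp_size=2,
#   )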

@functools.lru_cache(maxsize=None)
def calculate_memory_components(
    hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
    seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
    tie_word_embeddings, full_checkpointing=False
):
    # Calculate base components first
    if pp == 1:
        num_hidden_layers_in_pp = num_layers
    else:
        num_hidden_layers_in_pp = get_num_hidden_layers_in_pp(hidden_size, num_layers, vocab_size, intermediate_size, num_attention_heads, pp)

    # Model BF16 calculation
    vocab_embeddings = vocab_size * hidden_size * (2 if (not tie_word_embeddings and pp == 1) else 1)
    layer_params = (
        (hidden_size * hidden_size * (1 + 2 * num_key_value_heads / num_attention_heads))  # qkv_proj
        + (hidden_size * hidden_size)  # out_proj
        + (hidden_size * 2 * intermediate_size)  # gate_up_proj
        + (intermediate_size * hidden_size)  # down_proj
    )
    model_bf16_full = (vocab_embeddings + num_hidden_layers_in_pp * layer_params) * (2 / 1024 / 1024) / tp

    # Calculate number of parameters in billions
    num_params_in_B = (vocab_embeddings + num_layers * layer_params) / 1e9

    # Adjust model components based on ZeRO stage
    if zero_stage == 3:
        # In ZeRO-3, model parameters are sharded across dp ranks
        model_bf16 = model_bf16_full / dp
        fp32_params = 2 * model_bf16
        fp32_grads = 2 * model_bf16
        optimstates = 4 * model_bf16
        # Additional communication buffers for ZeRO-3
        zero3_buffers = 2 * model_bf16  # For parameter gathering during forward/backward
    else:
        # For ZeRO-0/1/2
        dp_if_zero = 1 if zero_stage == 0 else dp
        model_bf16 = model_bf16_full
        fp32_params = 2 * model_bf16 / dp_if_zero
        fp32_grads = 2 * model_bf16
        optimstates = 4 * model_bf16 / dp_if_zero
        zero3_buffers = 0

    use_ddp = zero_stage == 0 and dp > 1
    ddp_grads_buffers = model_bf16 if use_ddp else 0
    overhead = 72 + 32 * mbs

    # Activations calculation with FSDP checkpointing support
    is_mha = num_key_value_heads == num_attention_heads
    decoder_layer_mib = (seq_len * mbs * hidden_size / tp) * (2 / 1024 / 1024) * (4 * intermediate_size / hidden_size + 6 + 2 * num_key_value_heads / num_attention_heads + 2)
    if pp > 1:
        activs = min(pp, batch_accum) * num_hidden_layers_in_pp * decoder_layer_mib
    else:
        cast_to_fp32 = sharded_cross_entropy = seq_len * mbs * vocab_size * (2 / 1024 / 1024) * 2 / tp
        base_activs = num_layers * decoder_layer_mib + cast_to_fp32 + sharded_cross_entropy

        # Apply activation reduction for FSDP checkpointing in ZeRO-3
        if zero_stage == 3 and full_checkpointing:
            activs = base_activs / dp  # Activation memory is reduced by the dp factor with checkpointing
        else:
            activs = base_activs

    # Calculate aggregate metrics
    memory_usage_after_optimstates = (
        model_bf16 +
        fp32_params +
        fp32_grads +
        optimstates +
        ddp_grads_buffers +
        zero3_buffers +
        overhead
    )
    memory_usage_before_optimstates = (
        model_bf16 +
        fp32_params +
        fp32_grads +
        ddp_grads_buffers +
        zero3_buffers
    )
    memory_usage_peak_tbi = (
        model_bf16 +
        fp32_params +
        fp32_grads +
        optimstates +
        ddp_grads_buffers +
        zero3_buffers +
        overhead +
        activs
    )

    return {
        "Components": {
            "Model BF16": model_bf16,
            "FP32 Parameters": fp32_params,
            "FP32 Gradients": fp32_grads,
            "Optimizer States": optimstates,
            "DDP Gradient Buffers": ddp_grads_buffers,
            "ZeRO-3 Buffers": zero3_buffers,
            "Overhead": overhead,
            "Activations": activs,
        },
        "Aggregates": {
            "Memory Before Optimizer States": memory_usage_before_optimstates,
            "Memory After Optimizer States": memory_usage_after_optimstates,
            "Peak Memory (TBI)": memory_usage_peak_tbi,
        },
    }
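
# A minimal usage sketch (added for illustration, not part of the original Space):
# all arguments are plain scalars (required by lru_cache), and the values in the
# returned "Components"/"Aggregates" dicts are per-GPU estimates in MiB. The config
# below is a hypothetical Llama-3-8B-style run with ZeRO-1 data parallelism.
#
#   results = calculate_memory_components(
#       hidden_size=4096, num_attention_heads=32, num_key_value_heads=8,
#       num_layers=32, vocab_size=128256, intermediate_size=14336,
#       seq_len=4096, mbs=1, batch_accum=4, tp=1, pp=1, dp=8, zero_stage=1,
#       tie_word_embeddings=False,
#   )
#   peak_mib = results["Aggregates"]["Peak Memory (TBI)"]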

def plot_memory_breakdown(
    hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
    seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
    tie_word_embeddings, full_checkpointing=False
):
    results = calculate_memory_components(
        hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
        seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
        tie_word_embeddings, full_checkpointing
    )
    memory_usage_peak_tbi = results["Aggregates"]["Peak Memory (TBI)"]

    # Create figure for the components plot
    plt.close('all')
    fig1 = plt.figure(figsize=(10, 5))
    ax1 = fig1.add_subplot(1, 1, 1)

    # Plot components
    components = results["Components"]
    names = list(components.keys())
    values = list(components.values())
    colors = plt.cm.Set3(np.linspace(0, 1, len(components)))
    color_map = dict(zip(names, colors))
    bars1 = ax1.bar(range(len(components)), values, color=colors)

    # Add value labels with better positioning
    for bar in bars1:
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width() / 2., height,
                 f'{height:.1f} MiB',
                 ha='center', va='bottom',
                 rotation=0)  # No rotation, for better readability

    # Customize the first plot
    ax1.set_xticks(range(len(components)))
    ax1.set_xticklabels(names, rotation=45, ha='right')
    ax1.set_ylabel('Memory (MiB)')
    ax1.set_title('Memory Component Breakdown', pad=20)
    plt.tight_layout()

    # Create figure for the timeline plot
    fig2 = plt.figure(figsize=(10, 6))
    ax2 = fig2.add_subplot(1, 1, 1)

    # Define timeline steps and their components
    c = results["Components"]
    timeline_steps = {
        "Model Init": [
            ("Model BF16", c["Model BF16"]),
            ("DDP Gradient Buffers", c["DDP Gradient Buffers"]),
            ("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]),
        ],
        "Gradient Accumulator Init": [
            ("Model BF16", c["Model BF16"]),
            ("DDP Gradient Buffers", c["DDP Gradient Buffers"]),
            ("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]),
            ("FP32 Parameters", c["FP32 Parameters"]),
            ("FP32 Gradients", c["FP32 Gradients"]),
        ],
        "Fwd-Bwd Peak": [
            ("Model BF16", c["Model BF16"]),
            ("DDP Gradient Buffers", c["DDP Gradient Buffers"]),
            ("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]),
            ("FP32 Parameters", c["FP32 Parameters"]),
            ("FP32 Gradients", c["FP32 Gradients"]),
            ("Activations", c["Activations"]),
        ],
        "Optimizer Step": [
            ("Model BF16", c["Model BF16"]),
            ("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]),
            ("FP32 Parameters", c["FP32 Parameters"]),
            ("FP32 Gradients", c["FP32 Gradients"]),
            ("Optimizer States", c["Optimizer States"]),
        ],
        "2nd Fwd-Bwd Peak": [
            ("Model BF16", c["Model BF16"]),
            ("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]),
            ("FP32 Parameters", c["FP32 Parameters"]),
            ("FP32 Gradients", c["FP32 Gradients"]),
            ("Optimizer States", c["Optimizer States"]),
            ("DDP Gradient Buffers", c["DDP Gradient Buffers"]),
            ("Activations", c["Activations"]),
        ],
        "2nd Optimizer Step": [
            ("Model BF16", c["Model BF16"]),
            ("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]),
            ("FP32 Parameters", c["FP32 Parameters"]),
            ("FP32 Gradients", c["FP32 Gradients"]),
            ("Optimizer States", c["Optimizer States"]),
            ("DDP Gradient Buffers", c["DDP Gradient Buffers"]),
        ],
    }

    # Plot timeline
    x = range(len(timeline_steps))
    bottom = np.zeros(len(timeline_steps))
    for component in c.keys():
        heights = []
        for step_components in timeline_steps.values():
            height = 0
            for comp_name, comp_value in step_components:
                if comp_name == component:
                    height = comp_value
            heights.append(height)
        ax2.bar(x, heights, bottom=bottom, label=component, color=color_map[component])
        bottom += heights

    # Customize the timeline plot
    ax2.set_xticks(x)
    ax2.set_xticklabels(timeline_steps.keys(), rotation=45, ha='right')
    ax2.set_ylabel('Memory (MiB)')
    ax2.set_title('Memory Timeline', pad=20)

    # Add total memory labels on top of each bar
    for i, total in enumerate(bottom):
        ax2.text(i, total, f'{total:.1f} MiB', ha='center', va='bottom')

    # Adjust layout
    plt.tight_layout()

    # Set y-axis limit
    max_y_value = max(bottom)
    ax2.set_ylim(0, max(80000, max_y_value))

    # Add legend below the plot
    # plt.subplots_adjust(bottom=0.8)
    ax2.legend(loc='lower center', bbox_to_anchor=(0.5, -1.5), ncol=3)

    return fig1, fig2, memory_usage_peak_tbi
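

if __name__ == "__main__":
    # A minimal smoke-test sketch (added for illustration, not part of the original
    # Space): render both figures for a hypothetical Llama-3-8B-style setup and save
    # them next to this script.
    fig1, fig2, peak_mib = plot_memory_breakdown(
        hidden_size=4096, num_attention_heads=32, num_key_value_heads=8,
        num_layers=32, vocab_size=128256, intermediate_size=14336,
        seq_len=4096, mbs=1, batch_accum=4, tp=2, pp=1, dp=4, zero_stage=1,
        tie_word_embeddings=False, full_checkpointing=False,
    )
    fig1.savefig("memory_components.png", bbox_inches="tight")
    fig2.savefig("memory_timeline.png", bbox_inches="tight")
    print(f"Estimated peak memory per GPU: {peak_mib:.1f} MiB")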