import matplotlib.pyplot as plt import numpy as np import functools @functools.lru_cache(maxsize=None) def get_num_hidden_layers_in_pp(hidden_size, num_layers, vocab_size, intermediate_size, num_attention_heads, pp_size): if pp_size == 1: return num_layers # Get list of pipeline blocks and their costs pipeline_blocks = [] block_costs = [] # Embedding layer (treated as zero cost in the original implementation) pipeline_blocks.append("embedding") block_costs.append(0) # Decoder layers decoder_cost = (4 * num_attention_heads * (hidden_size//num_attention_heads) * hidden_size + 3 * intermediate_size * hidden_size) for _ in range(num_layers): pipeline_blocks.append("decoder") block_costs.append(decoder_cost) # LM head pipeline_blocks.append("lm_head") block_costs.append(vocab_size * hidden_size) # Now follow the same logic as the original code total_cost = sum(block_costs) target_cost_per_rank = total_cost / pp_size blocks_in_rank0 = 0 current_cost = 0 for block_idx, block_cost in enumerate(block_costs): current_cost += block_cost blocks_in_rank0 += 1 # Check if we should move to next rank remaining_ranks = pp_size - 1 # -1 because we're calculating for rank 0 remaining_nonzero_blocks = sum(1 for c in block_costs[block_idx+1:] if c > 0) if (remaining_ranks > 0 and remaining_nonzero_blocks <= remaining_ranks) or (current_cost >= target_cost_per_rank): break num_hidden_layers_in_pp = blocks_in_rank0 - 1 # We exclude first rank as it's the embedding layer return num_hidden_layers_in_pp @functools.lru_cache(maxsize=None) def calculate_memory_components( hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size, seq_len, mbs, batch_accum, tp, pp, dp, zero_stage, tie_word_embeddings, full_checkpointing=False ): # Calculate base components first if pp == 1: num_hidden_layers_in_pp = num_layers else: num_hidden_layers_in_pp = get_num_hidden_layers_in_pp(hidden_size, num_layers, vocab_size, intermediate_size, num_attention_heads, pp) # Model BF16 calculation vocab_embeddings = vocab_size * hidden_size * (2 if (not tie_word_embeddings and pp==1) else 1) layer_params = ( (hidden_size * hidden_size * (1 + 2*num_key_value_heads/num_attention_heads)) # qkv_proj + (hidden_size * hidden_size) # out_proj + (hidden_size * 2 * intermediate_size) # gate_up_proj + (intermediate_size * hidden_size) # down_proj ) model_bf16_full = (vocab_embeddings + num_hidden_layers_in_pp * layer_params) * (2 / 1024 / 1024) / tp # Calculate number of parameters in billions num_params_in_B = (vocab_embeddings + num_layers*layer_params) / 1e9 # Adjust model components based on ZeRO stage if zero_stage == 3: # In ZeRO-3, model parameters are sharded across dp ranks model_bf16 = model_bf16_full / dp fp32_params = 2 * model_bf16 fp32_grads = 2 * model_bf16 optimstates = 4 * model_bf16 # Additional communication buffers for ZeRO-3 zero3_buffers = 2 * model_bf16 # For parameter gathering during forward/backward else: # For ZeRO-0/1/2 dp_if_zero = 1 if zero_stage == 0 else dp model_bf16 = model_bf16_full fp32_params = 2 * model_bf16 / dp_if_zero fp32_grads = 2 * model_bf16 optimstates = 4 * model_bf16 / dp_if_zero zero3_buffers = 0 use_ddp = zero_stage == 0 and dp > 1 ddp_grads_buffers = model_bf16 if use_ddp else 0 overhead = 72 + 32 * mbs # Activations calculation with FSDP checkpointing support is_mha = num_key_value_heads == num_attention_heads decoder_layer_mib = (seq_len * mbs * hidden_size/tp) * (2/1024/1024) * (4*intermediate_size/hidden_size + 6 + 2*num_key_value_heads/num_attention_heads + 2) if pp > 1: activs = min(pp, batch_accum) * num_hidden_layers_in_pp * decoder_layer_mib else: cast_to_fp32 = sharded_cross_entropy = seq_len * mbs * vocab_size * (2 / 1024 / 1024) * 2 / tp base_activs = num_layers * decoder_layer_mib + cast_to_fp32 + sharded_cross_entropy # Apply activation reduction for FSDP checkpointing in ZeRO-3 if zero_stage == 3 and full_checkpointing: activs = base_activs / dp # Activation memory is reduced by dp factor with checkpointing else: activs = base_activs # Calculate aggregate metrics memory_usage_after_optimstates = ( model_bf16 + fp32_params + fp32_grads + optimstates + ddp_grads_buffers + zero3_buffers + overhead ) memory_usage_before_optimstates = ( model_bf16 + fp32_params + fp32_grads + ddp_grads_buffers + zero3_buffers ) memory_usage_peak_tbi = ( model_bf16 + fp32_params + fp32_grads + optimstates + ddp_grads_buffers + zero3_buffers + overhead + activs ) return { "Components": { "Model BF16": model_bf16, "FP32 Parameters": fp32_params, "FP32 Gradients": fp32_grads, "Optimizer States": optimstates, "DDP Gradient Buffers": ddp_grads_buffers, "ZeRO-3 Buffers": zero3_buffers, "Overhead": overhead, "Activations": activs, }, "Aggregates": { "Memory Before Optimizer States": memory_usage_before_optimstates, "Memory After Optimizer States": memory_usage_after_optimstates, "Peak Memory (TBI)": memory_usage_peak_tbi } } def plot_memory_breakdown( hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size, seq_len, mbs, batch_accum, tp, pp, dp, zero_stage, tie_word_embeddings, full_checkpointing=False ): results = calculate_memory_components( hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size, seq_len, mbs, batch_accum, tp, pp, dp, zero_stage, tie_word_embeddings, full_checkpointing ) memory_usage_peak_tbi = results["Aggregates"]["Peak Memory (TBI)"] # Create figure for components plot plt.close('all') fig1 = plt.figure(figsize=(10, 5)) ax1 = fig1.add_subplot(1, 1, 1) # Plot components components = results["Components"] names = list(components.keys()) values = list(components.values()) colors = plt.cm.Set3(np.linspace(0, 1, len(components))) color_map = dict(zip(names, colors)) bars1 = ax1.bar(range(len(components)), values, color=colors) # Add value labels with better positioning for bar in bars1: height = bar.get_height() ax1.text(bar.get_x() + bar.get_width()/2., height, f'{height:.1f} MiB', ha='center', va='bottom', rotation=0) # Remove rotation for better readability # Customize the first plot ax1.set_xticks(range(len(components))) ax1.set_xticklabels(names, rotation=45, ha='right') ax1.set_ylabel('Memory (MiB)') ax1.set_title('Memory Component Breakdown', pad=20) plt.tight_layout() # Create figure for timeline plot fig2 = plt.figure(figsize=(10, 6)) ax2 = fig2.add_subplot(1, 1, 1) # Define timeline steps and their components c = results["Components"] timeline_steps = { "Model Init": [ ("Model BF16", c["Model BF16"]), ("DDP Gradient Buffers", c["DDP Gradient Buffers"]), ("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]), ], "Gradient Accumulator Init": [ ("Model BF16", c["Model BF16"]), ("DDP Gradient Buffers", c["DDP Gradient Buffers"]), ("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]), ("FP32 Parameters", c["FP32 Parameters"]), ("FP32 Gradients", c["FP32 Gradients"]) ], "Fwd-Bwd Peak": [ ("Model BF16", c["Model BF16"]), ("DDP Gradient Buffers", c["DDP Gradient Buffers"]), ("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]), ("FP32 Parameters", c["FP32 Parameters"]), ("FP32 Gradients", c["FP32 Gradients"]), ("Activations", c["Activations"]) ], "Optimizer Step": [ ("Model BF16", c["Model BF16"]), ("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]), ("FP32 Parameters", c["FP32 Parameters"]), ("FP32 Gradients", c["FP32 Gradients"]), ("Optimizer States", c["Optimizer States"]) ], "2nd Fwd-Bwd Peak": [ ("Model BF16", c["Model BF16"]), ("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]), ("FP32 Parameters", c["FP32 Parameters"]), ("FP32 Gradients", c["FP32 Gradients"]), ("Optimizer States", c["Optimizer States"]), ("DDP Gradient Buffers", c["DDP Gradient Buffers"]), ("Activations", c["Activations"]) ], "2nd Optimizer Step": [ ("Model BF16", c["Model BF16"]), ("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]), ("FP32 Parameters", c["FP32 Parameters"]), ("FP32 Gradients", c["FP32 Gradients"]), ("Optimizer States", c["Optimizer States"]), ("DDP Gradient Buffers", c["DDP Gradient Buffers"]) ] } # Plot timeline x = range(len(timeline_steps)) bottom = np.zeros(len(timeline_steps)) for component in c.keys(): heights = [] for step_components in timeline_steps.values(): height = 0 for comp_name, comp_value in step_components: if comp_name == component: height = comp_value heights.append(height) ax2.bar(x, heights, bottom=bottom, label=component, color=color_map[component]) bottom += heights # Customize the timeline plot ax2.set_xticks(x) ax2.set_xticklabels(timeline_steps.keys(), rotation=45, ha='right') ax2.set_ylabel('Memory (MiB)') ax2.set_title('Memory Timeline', pad=20) # Add total memory labels on top of each bar for i, total in enumerate(bottom): ax2.text(i, total, f'{total:.1f} MiB', ha='center', va='bottom') # Adjust layout plt.tight_layout() # Set y-axis limit max_y_value = max(bottom) ax2.set_ylim(0, max(80000, max_y_value)) # Add legend below the plot # plt.subplots_adjust(bottom=0.8) ax2.legend(loc='lower center', bbox_to_anchor=(0.5, -1.5), ncol=3) return fig1, fig2, memory_usage_peak_tbi