import matplotlib.pyplot as plt
import numpy as np
import functools

@functools.lru_cache(maxsize=None)
def get_num_hidden_layers_in_pp(
    hidden_size, num_layers, vocab_size, intermediate_size, num_attention_heads, pp_size
):
    if pp_size == 1:
        return num_layers

    # Get list of pipeline blocks and their costs
    pipeline_blocks = []
    block_costs = []
    
    # Embedding layer (treated as zero cost by this splitting heuristic)
    pipeline_blocks.append("embedding")
    block_costs.append(0)
    
    # Decoder layers
    decoder_cost = (4 * num_attention_heads * (hidden_size//num_attention_heads) * hidden_size + 
                   3 * intermediate_size * hidden_size)
    for _ in range(num_layers):
        pipeline_blocks.append("decoder")
        block_costs.append(decoder_cost)
    
    # LM head
    pipeline_blocks.append("lm_head")
    block_costs.append(vocab_size * hidden_size)
    
    # Greedily assign blocks to rank 0 until it reaches an even share of the total cost
    total_cost = sum(block_costs)
    target_cost_per_rank = total_cost / pp_size
    
    blocks_in_rank0 = 0
    current_cost = 0
    
    for block_idx, block_cost in enumerate(block_costs):
        current_cost += block_cost
        blocks_in_rank0 += 1
        
        # Check if we should move to next rank
        remaining_ranks = pp_size - 1  # -1 because we're calculating for rank 0
        remaining_nonzero_blocks = sum(1 for c in block_costs[block_idx+1:] if c > 0)
        
        if (remaining_ranks > 0 and remaining_nonzero_blocks <= remaining_ranks) or (current_cost >= target_cost_per_rank):
            break
            
    num_hidden_layers_in_pp = blocks_in_rank0 - 1  # exclude the first block, which is the embedding layer
    return num_hidden_layers_in_pp
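
# Worked example (hypothetical Llama-8B-like hyperparameters; the numbers are
# illustrative assumptions, not from the original):
#   get_num_hidden_layers_in_pp(hidden_size=4096, num_layers=32, vocab_size=128256,
#                               intermediate_size=14336, num_attention_heads=32, pp_size=4)
# returns 9: slightly more than 32 / 4 = 8 decoder layers, because the LM head's
# cost is folded into the total and raises rank 0's target share.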

@functools.lru_cache(maxsize=None)
def calculate_memory_components(
    hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
    seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
    tie_word_embeddings, full_checkpointing=False
):
    # Pipeline-parallel split: how many decoder layers land on rank 0
    # (the helper itself returns num_layers when pp == 1)
    num_hidden_layers_in_pp = get_num_hidden_layers_in_pp(
        hidden_size, num_layers, vocab_size, intermediate_size, num_attention_heads, pp
    )
    
    # Model BF16 calculation; rank 0 holds the embedding, plus the LM head when
    # embeddings are untied and there is no pipeline parallelism
    vocab_embeddings = vocab_size * hidden_size * (2 if (not tie_word_embeddings and pp == 1) else 1)
    
    layer_params = (
        (hidden_size * hidden_size * (1 + 2*num_key_value_heads/num_attention_heads))  # qkv_proj
        + (hidden_size * hidden_size)     # out_proj
        + (hidden_size * 2 * intermediate_size)  # gate_up_proj
        + (intermediate_size * hidden_size)      # down_proj
    )
    
    model_bf16_full = (vocab_embeddings + num_hidden_layers_in_pp * layer_params) * (2 / 1024 / 1024) / tp

    # Total parameter count in billions (informational; not used in the memory math below)
    num_params_in_B = (vocab_embeddings + num_layers * layer_params) / 1e9

    # Adjust model components based on ZeRO stage
    if zero_stage == 3:
        # In ZeRO-3, model parameters are sharded across dp ranks
        model_bf16 = model_bf16_full / dp
        fp32_params = 2 * model_bf16
        fp32_grads = 2 * model_bf16
        optimstates = 4 * model_bf16
        # Additional communication buffers for ZeRO-3
        zero3_buffers = 2 * model_bf16  # For parameter gathering during forward/backward
    else:
        # For ZeRO-0/1/2: master params and optimizer states are sharded across
        # dp when zero_stage >= 1; ZeRO-2's gradient sharding is not modeled here
        dp_if_zero = 1 if zero_stage == 0 else dp
        model_bf16 = model_bf16_full
        fp32_params = 2 * model_bf16 / dp_if_zero
        fp32_grads = 2 * model_bf16
        optimstates = 4 * model_bf16 / dp_if_zero
        zero3_buffers = 0

    use_ddp = zero_stage == 0 and dp > 1
    ddp_grads_buffers = model_bf16 if use_ddp else 0
    overhead = 72 + 32 * mbs  # empirical fixed framework overhead in MiB

    # Activations calculation (FSDP full-checkpointing handled below).
    # Per-decoder-layer activation memory in MiB, bf16, sharded across tp.
    decoder_layer_mib = (seq_len * mbs * hidden_size / tp) * (2 / 1024 / 1024) * (
        4 * intermediate_size / hidden_size + 6 + 2 * num_key_value_heads / num_attention_heads + 2
    )
    
    if pp > 1:
        activs = min(pp, batch_accum) * num_hidden_layers_in_pp * decoder_layer_mib
    else:
        # Two fp32 logits buffers (seq_len * mbs * vocab_size elements each): the
        # bf16 -> fp32 cast and the sharded cross-entropy computation
        cast_to_fp32 = sharded_cross_entropy = seq_len * mbs * vocab_size * (2 / 1024 / 1024) * 2 / tp
        base_activs = num_layers * decoder_layer_mib + cast_to_fp32 + sharded_cross_entropy
        
        # Apply activation reduction for FSDP checkpointing in ZeRO-3
        if zero_stage == 3 and full_checkpointing:
            activs = base_activs / dp  # Activation memory is reduced by dp factor with checkpointing
        else:
            activs = base_activs

    # Calculate aggregate metrics
    memory_usage_after_optimstates = (
        model_bf16 + 
        fp32_params + 
        fp32_grads + 
        optimstates + 
        ddp_grads_buffers + 
        zero3_buffers +
        overhead
    )

    memory_usage_before_optimstates = (
        model_bf16 + 
        fp32_params + 
        fp32_grads + 
        ddp_grads_buffers +
        zero3_buffers
    )

    memory_usage_peak_tbi = (
        model_bf16 + 
        fp32_params + 
        fp32_grads + 
        optimstates + 
        ddp_grads_buffers + 
        zero3_buffers +
        overhead + 
        activs
    )

    return {
        "Components": {
            "Model BF16": model_bf16,
            "FP32 Parameters": fp32_params,
            "FP32 Gradients": fp32_grads,
            "Optimizer States": optimstates,
            "DDP Gradient Buffers": ddp_grads_buffers,
            "ZeRO-3 Buffers": zero3_buffers,
            "Overhead": overhead,
            "Activations": activs,
        },
        "Aggregates": {
            "Memory Before Optimizer States": memory_usage_before_optimstates,
            "Memory After Optimizer States": memory_usage_after_optimstates,
            "Peak Memory (TBI)": memory_usage_peak_tbi
        }
    }
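
# Example of reading the result (hypothetical small-model config; the values are
# illustrative assumptions, not from the original):
#   mem = calculate_memory_components(
#       hidden_size=2048, num_attention_heads=16, num_key_value_heads=16,
#       num_layers=24, vocab_size=32000, intermediate_size=8192,
#       seq_len=2048, mbs=1, batch_accum=1, tp=1, pp=1, dp=8, zero_stage=1,
#       tie_word_embeddings=True,
#   )
#   mem["Aggregates"]["Peak Memory (TBI)"]  # estimated per-GPU peak, in MiB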

def plot_memory_breakdown(
    hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
    seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
    tie_word_embeddings, full_checkpointing=False
):
    results = calculate_memory_components(
        hidden_size, num_attention_heads, num_key_value_heads, num_layers, vocab_size, intermediate_size,
        seq_len, mbs, batch_accum, tp, pp, dp, zero_stage,
        tie_word_embeddings, full_checkpointing
    )
    memory_usage_peak_tbi = results["Aggregates"]["Peak Memory (TBI)"]
    
    # Create figure for components plot
    plt.close('all')
    fig1 = plt.figure(figsize=(10, 5))
    ax1 = fig1.add_subplot(1, 1, 1)
    
    # Plot components
    components = results["Components"]
    names = list(components.keys())
    values = list(components.values())
    
    colors = plt.cm.Set3(np.linspace(0, 1, len(components)))
    color_map = dict(zip(names, colors))
    
    bars1 = ax1.bar(range(len(components)), values, color=colors)
    
    # Add value labels with better positioning
    for bar in bars1:
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width()/2., height,
                f'{height:.1f} MiB',
                ha='center', va='bottom',
                rotation=0)  # keep labels horizontal for readability
    
    # Customize the first plot
    ax1.set_xticks(range(len(components)))
    ax1.set_xticklabels(names, rotation=45, ha='right')
    ax1.set_ylabel('Memory (MiB)')
    ax1.set_title('Memory Component Breakdown', pad=20)
    
    plt.tight_layout()
    
    # Create figure for timeline plot
    fig2 = plt.figure(figsize=(10, 6))
    ax2 = fig2.add_subplot(1, 1, 1)
    
    # Define timeline steps and their components
    c = results["Components"]
    timeline_steps = {
        "Model Init": [
            ("Model BF16", c["Model BF16"]),
            ("DDP Gradient Buffers", c["DDP Gradient Buffers"]),
            ("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]),
        ],
        "Gradient Accumulator Init": [
            ("Model BF16", c["Model BF16"]),
            ("DDP Gradient Buffers", c["DDP Gradient Buffers"]),
            ("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]),
            ("FP32 Parameters", c["FP32 Parameters"]),
            ("FP32 Gradients", c["FP32 Gradients"])
        ],
        "Fwd-Bwd Peak": [
            ("Model BF16", c["Model BF16"]),
            ("DDP Gradient Buffers", c["DDP Gradient Buffers"]),
            ("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]),
            ("FP32 Parameters", c["FP32 Parameters"]),
            ("FP32 Gradients", c["FP32 Gradients"]),
            ("Activations", c["Activations"])
        ],
        "Optimizer Step": [
            ("Model BF16", c["Model BF16"]),
            ("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]),
            ("FP32 Parameters", c["FP32 Parameters"]),
            ("FP32 Gradients", c["FP32 Gradients"]),
            ("Optimizer States", c["Optimizer States"])
        ],
        "2nd Fwd-Bwd Peak": [
            ("Model BF16", c["Model BF16"]),
            ("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]),
            ("FP32 Parameters", c["FP32 Parameters"]),
            ("FP32 Gradients", c["FP32 Gradients"]),
            ("Optimizer States", c["Optimizer States"]),
            ("DDP Gradient Buffers", c["DDP Gradient Buffers"]),
            ("Activations", c["Activations"])
        ],
        "2nd Optimizer Step": [
            ("Model BF16", c["Model BF16"]),
            ("ZeRO-3 Buffers", c["ZeRO-3 Buffers"]),
            ("FP32 Parameters", c["FP32 Parameters"]),
            ("FP32 Gradients", c["FP32 Gradients"]),
            ("Optimizer States", c["Optimizer States"]),
            ("DDP Gradient Buffers", c["DDP Gradient Buffers"])
        ]
    }
    # Plot timeline
    x = range(len(timeline_steps))
    bottom = np.zeros(len(timeline_steps))
    
    
    for component in c.keys():
        heights = []
        for step_components in timeline_steps.values():
            height = 0
            for comp_name, comp_value in step_components:
                if comp_name == component:
                    height = comp_value
            heights.append(height)
        
        ax2.bar(x, heights, bottom=bottom, label=component, color=color_map[component])
        bottom += heights

    # Customize the timeline plot
    ax2.set_xticks(x)
    ax2.set_xticklabels(timeline_steps.keys(), rotation=45, ha='right')
    ax2.set_ylabel('Memory (MiB)')
    ax2.set_title('Memory Timeline', pad=20)

    
    # Add total memory labels on top of each bar
    for i, total in enumerate(bottom):
        ax2.text(i, total, f'{total:.1f} MiB', ha='center', va='bottom')
    
    # Adjust layout
    plt.tight_layout()
    # Set y-axis limit: floor at 80,000 MiB so different configs share a comparable scale
    max_y_value = max(bottom)
    ax2.set_ylim(0, max(80000, max_y_value))
    
    # Add legend below the plot
    ax2.legend(loc='lower center', bbox_to_anchor=(0.5, -1.5), ncol=3)
    return fig1, fig2, memory_usage_peak_tbi
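
if __name__ == "__main__":
    # Minimal usage sketch. The hyperparameters below are hypothetical
    # Llama-8B-like values chosen for illustration, not from the original.
    config = dict(
        hidden_size=4096, num_attention_heads=32, num_key_value_heads=8,
        num_layers=32, vocab_size=128256, intermediate_size=14336,
        seq_len=4096, mbs=1, batch_accum=1, tp=1, pp=1, dp=8, zero_stage=3,
        tie_word_embeddings=False,
    )

    results = calculate_memory_components(**config)
    for name, mib in results["Components"].items():
        print(f"{name:>22}: {mib:10,.1f} MiB")
    peak = results["Aggregates"]["Peak Memory (TBI)"]
    print(f"{'Peak Memory (TBI)':>22}: {peak:10,.1f} MiB")

    fig1, fig2, _ = plot_memory_breakdown(**config)
    fig1.savefig("memory_components.png", bbox_inches="tight")
    fig2.savefig("memory_timeline.png", bbox_inches="tight")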