File size: 3,476 Bytes
addbb37
 
 
 
5ee5935
addbb37
 
f8eee5a
addbb37
 
 
 
ac8821a
db805e9
 
ac8821a
f8eee5a
 
 
ac8821a
addbb37
5ee5935
 
 
addbb37
ac8821a
f8eee5a
 
ac8821a
f8eee5a
 
 
addbb37
 
 
 
 
f8eee5a
 
 
 
 
 
addbb37
 
f8eee5a
ac8821a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
addbb37
f8eee5a
 
 
 
addbb37
 
 
 
f8eee5a
 
addbb37
 
 
 
5ee5935
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
addbb37
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import gradio as gr
import matplotlib.pyplot as plt
from matplotlib.figure import Figure


# Bytes per element for each precision offered in the UI.
_BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "bfloat16": 2}


def _estimate_memory_gb(num_param, precision, grad_ckpt, batch_size, seq_len):
    """Return the five memory components (in GB) charted by ``plot_forecast``.

    The tuple order is (parameters, optimizer states, activations, gradients,
    optimizer intermediates). Arguments are as documented on ``plot_forecast``.
    """
    # The UI takes the model size in billions of parameters.
    n = float(num_param) * 1e9
    nbytes = _BYTES_PER_ELEMENT[precision]

    params = n * nbytes / 1e9  # N x precision
    opt_states = 2 * n * nbytes / 1e9  # 2 x N x precision

    # Activations: B x seq_len x K x precision, with K a linear fit in N.
    # NOTE(review): the origin of these fit constants is not documented here.
    k = 4.6894e-4 * n + 1.8494e6
    activations = batch_size * seq_len * k * nbytes / 1e9
    if grad_ckpt:
        # Fixed heuristic: checkpointing is assumed to cut activations by 5x.
        activations /= 5

    grads = n * nbytes / 1e9  # N x precision
    opt_inter = n * nbytes / 1e9  # N x precision
    return params, opt_states, activations, grads, opt_inter


def plot_forecast(num_param, precision, grad_ckpt, batch_size, seq_len):
    """Build a stacked-bar chart of the estimated training memory footprint.

    Args:
        num_param: Model size in billions of parameters (anything ``float``
            accepts).
        precision: One of ``"float32"``, ``"float16"`` or ``"bfloat16"``.
        grad_ckpt: Whether gradient checkpointing is enabled; when true the
            activation estimate is divided by 5.
        batch_size: Training batch size.
        seq_len: Sequence length.

    Returns:
        A matplotlib ``Figure`` with a single stacked bar. Model parameters
        and optimizer states form the base. On top sit the activations (left
        half) beside the gradients plus optimizer intermediates (right half).
        The title shows the total, which counts only the larger of those two
        side-by-side stacks.

    Raises:
        KeyError: If ``precision`` is not a supported name.
        TypeError, ValueError: If ``num_param`` cannot be converted by
            ``float``.
    """
    params, opt_states, activations, grads, opt_inter = _estimate_memory_gb(
        num_param, precision, grad_ckpt, batch_size, seq_len
    )
    # Activations and (gradients + optimizer intermediates) are drawn side by
    # side and only the larger of the two counts toward the total.
    total_memory = params + opt_states + max(activations, grads + opt_inter)

    # Use a standalone Figure instead of pyplot. pyplot keeps every figure it
    # creates alive in a global registry until plt.close() is called, so one
    # pyplot figure per request would leak for the life of the server.
    fig = Figure(figsize=(4, 4))
    ax = fig.add_subplot(111)

    bar_width = 0.5
    left, right = -bar_width / 4, bar_width / 4
    base = params + opt_states
    # Each segment: (x centre, bottom, height, width, colour, label).
    segments = [
        (0, 0, params, bar_width, "r", f"Model Parameters ({params:.1f} GB)"),
        (0, params, opt_states, bar_width, "b", f"Optimizer States ({opt_states:.1f} GB)"),
        (left, base, activations, bar_width / 2, "g", f"Activations\n({activations:.1f} GB)"),
        (right, base, grads, bar_width / 2, "y", f"Gradients\n({grads:.1f} GB)"),
        (
            right,
            base + grads,
            opt_inter,
            bar_width / 2,
            "c",
            f"Optimizer\nintermediates\n({opt_inter:.1f} GB)",
        ),
    ]
    for x, bottom, height, width, color, label in segments:
        ax.bar(x, height, bottom=bottom, width=width, color=color)
        # Centre the label inside its segment.
        ax.text(
            x,
            bottom + height / 2,
            label,
            ha="center",
            va="center",
            color="white",
            fontweight="bold",
        )

    ax.set_title(f"Total Memory: {total_memory:.1f} GB", fontweight="bold")
    # The x position carries no meaning; hide the axis.
    ax.xaxis.set_visible(False)
    ax.set_ylabel("Memory (GB)")

    fig.tight_layout()
    return fig


with gr.Blocks() as demo:
    with gr.Row():
        # Left column: estimator inputs.
        with gr.Column():
            with gr.Accordion("Model"):
                num_param = gr.Number(3, label="Number of parameters (B)")
                precision = gr.Radio(["float32", "float16", "bfloat16"], value="float32", label="Precision")
            with gr.Accordion("Data"):
                batch_size = gr.Slider(1, 128, label="Batch size", step=1, value=8)
                seq_len = gr.Slider(1, 1000, label="Sequence Length", step=1, value=256)

            with gr.Accordion("Advanced", open=False):
                # Gradient checkpointing is a training option (in plot_forecast
                # it only scales the activation estimate), so it is grouped
                # under "Training" rather than a duplicate "Data" section.
                with gr.Accordion("Training"):
                    grad_ckpt = gr.Checkbox(False, label="Gradient Checkpointing")

            submit = gr.Button("Submit")

        # Right column: the rendered memory chart.
        with gr.Column():
            plot = gr.Plot(label="forecast", format="png")

    # The input list order must match plot_forecast's positional parameters.
    submit.click(plot_forecast, [num_param, precision, grad_ckpt, batch_size, seq_len], plot)

if __name__ == "__main__":
    demo.launch()