Spaces:
Running
Running
File size: 3,476 Bytes
addbb37 5ee5935 addbb37 f8eee5a addbb37 ac8821a db805e9 ac8821a f8eee5a ac8821a addbb37 5ee5935 addbb37 ac8821a f8eee5a ac8821a f8eee5a addbb37 f8eee5a addbb37 f8eee5a ac8821a addbb37 f8eee5a addbb37 f8eee5a addbb37 5ee5935 addbb37 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import gradio as gr
import matplotlib.pyplot as plt
def plot_forecast(num_param, precision, grad_ckpt, batch_size, seq_len):
    """Estimate training memory usage and draw it as a stacked bar chart.

    The estimate is split into five components, all expressed in GB
    (bytes / 1e9): model parameters, optimizer states, activations,
    gradients and optimizer intermediates. Activations on one side and
    gradients + optimizer intermediates on the other are drawn side by side
    on top of the parameters and optimizer states, and only the larger of
    the two contributes to the reported total.

    Args:
        num_param: Model size in billions of parameters.
        precision: One of "float32", "float16" or "bfloat16".
        grad_ckpt: When truthy, the activation estimate is divided by 5.
        batch_size: Batch size used in the activation estimate.
        seq_len: Sequence length used in the activation estimate.

    Returns:
        A matplotlib ``Figure`` containing the stacked bar chart, with the
        total memory shown in the title.

    Raises:
        gr.Error: If ``num_param`` is ``None`` (e.g. the input was left empty).
        KeyError: If ``precision`` is not one of the supported names.
    """
    # Build the figure without pyplot: figures created through
    # ``plt.figure`` stay registered in pyplot's global state until they are
    # explicitly closed, which leaks memory in a long-running server where
    # this handler is called once per click.
    from matplotlib.figure import Figure

    if num_param is None:
        # float(None) would otherwise surface as an opaque TypeError.
        raise gr.Error("Please enter the number of parameters.")

    # Convert number (input as B)
    num_param = float(num_param) * 1e9
    # Convert precision to bytes per value
    bytes_per_value = {"float32": 4, "float16": 2, "bfloat16": 2}[precision]

    # Model Parameters: N×precision
    y1 = num_param * bytes_per_value / 1e9
    # Optimizer States: 2×N×precision
    y2 = 2 * num_param * bytes_per_value / 1e9
    # Activations: B×Sequence Length×K×precision
    # NOTE(review): K looks like an empirical linear fit in the parameter
    # count — confirm where the coefficients come from.
    K = 4.6894e-4 * num_param + 1.8494e6
    y3 = batch_size * seq_len * K * bytes_per_value / 1e9
    if grad_ckpt:
        y3 /= 5
    # Gradients: N×precision
    y4 = num_param * bytes_per_value / 1e9
    # Optimizer intermediates: N×precision
    y5 = num_param * bytes_per_value / 1e9

    # Only the taller of the two side-by-side stacks counts towards the total.
    total_memory = y1 + y2 + max(y3, y4 + y5)

    fig = Figure(figsize=(4, 4))
    ax = fig.add_subplot(111)

    # Create stacked bars: two full-width bars at the bottom, then two
    # half-width stacks placed left and right of the centre.
    bar_width = 0.5
    half_width = bar_width / 2
    left_x = -bar_width / 4
    right_x = bar_width / 4
    base = y1 + y2
    ax.bar(0, y1, width=bar_width, color="r")
    ax.bar(0, y2, bottom=y1, width=bar_width, color="b")
    ax.bar(left_x, y3, bottom=base, width=half_width, color="g")
    ax.bar(right_x, y4, bottom=base, width=half_width, color="y")
    ax.bar(right_x, y5, bottom=base + y4, width=half_width, color="c")

    # Add text labels centred inside the bars
    label_style = {"ha": "center", "va": "center", "color": "white", "fontweight": "bold"}
    ax.text(0, y1 / 2, f"Model Parameters ({y1:.1f} GB)", **label_style)
    ax.text(0, y1 + y2 / 2, f"Optimizer States ({y2:.1f} GB)", **label_style)
    ax.text(left_x, base + y3 / 2, f"Activations\n({y3:.1f} GB)", **label_style)
    ax.text(right_x, base + y4 / 2, f"Gradients\n({y4:.1f} GB)", **label_style)
    ax.text(right_x, base + y4 + y5 / 2, f"Optimizer\nintermediates\n({y5:.1f} GB)", **label_style)

    # Show the total as the title
    ax.set_title(f"Total Memory: {total_memory:.1f} GB", fontweight="bold")

    # Remove x-axis (the bar positions carry no meaning)
    ax.xaxis.set_visible(False)

    # Set GB as the unit for the y-axis
    ax.set_ylabel("Memory (GB)")

    # Adjust layout
    fig.tight_layout()
    return fig
# Two-column layout: input controls on the left, resulting plot on the right.
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: every input of plot_forecast plus the trigger button.
        with gr.Column():
            with gr.Accordion(label="Model"):
                num_param = gr.Number(value=3, label="Number of parameters (B)")
                precision = gr.Radio(
                    choices=["float32", "float16", "bfloat16"],
                    value="float32",
                    label="Precision",
                )
            with gr.Accordion(label="Data"):
                batch_size = gr.Slider(minimum=1, maximum=128, value=8, step=1, label="Batch size")
                seq_len = gr.Slider(minimum=1, maximum=1000, value=256, step=1, label="Sequence Length")
            # Collapsed by default.
            with gr.Accordion(label="Advanced", open=False):
                with gr.Accordion(label="Data"):
                    grad_ckpt = gr.Checkbox(value=False, label="Gradient Checkpointing")
            submit = gr.Button(value="Submit")
        # Right column: the chart returned by plot_forecast.
        with gr.Column():
            plot = gr.Plot(label="forecast", format="png")

    # Input order must match the parameter order of plot_forecast.
    submit.click(
        fn=plot_forecast,
        inputs=[num_param, precision, grad_ckpt, batch_size, seq_len],
        outputs=plot,
    )

# Start the app only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()
|