|
import gradio as gr |
|
import pandas as pd |
|
from hub_utils import check_for_discussion, report_results |
|
from model_utils import calculate_memory, get_model |
|
from huggingface_hub.utils import HfHubHTTPError |
|
|
|
|
|
|
|
# Module-level cache holding the most recently loaded model; assigned by
# get_results() below so later callbacks can reuse it without reloading.
MODEL = None
|
|
|
|
|
def get_results(model_name: str, library: str, precision: list, training: list, access_token: str, zero_stage: int, num_nodes: int, num_gpus: int, offloading: list, zero_init: list, additional_buffer_factor: float):
    """Compute the DeepSpeed/ZeRO memory-usage table for a Hub model.

    Parameters mirror the Gradio inputs wired up in the UI below (model name,
    source library, precision/training choices, ZeRO stage 0-3, cluster shape,
    offload/init toggles, and the extra buffer multiplier).

    Returns a 3-element list matching the ``outputs`` of ``btn.click``:
    a markdown title, an update revealing the results DataFrame, and an
    update toggling the "report results" button.
    """
    global MODEL  # cached at module level so other callbacks can reuse the loaded model
    MODEL = get_model(model_name, library, access_token)
    # Hub lookups can fail (rate limits, auth, network); treat failure as
    # "a discussion already exists" so we never offer to post a duplicate.
    try:
        has_discussion = check_for_discussion(model_name)
    except HfHubHTTPError:
        has_discussion = True

    options = {
        "precision": precision,
        "zero_stage": zero_stage,
        # Membership tests are already booleans; no ternary needed.
        "cpu_offload": "Optimizer" in offloading,
        "cpu_offload_params": "Parameters" in offloading,
        "zero_init": "zero.Init" in zero_init,
        "num_nodes": num_nodes,
        "num_gpus_per_node": num_gpus,
        "training_regime": training,
        "additional_buffer_factor": additional_buffer_factor,
    }
    data = calculate_memory(MODEL, options)

    title = f"## Memory usage for '{model_name}'"
    return [title, gr.update(visible=True, value=pd.DataFrame(data)), gr.update(visible=not has_discussion)]
|
|
|
|
|
# UI definition. Component construction order determines on-screen layout,
# so nothing here should be reordered.
with gr.Blocks() as demo:
    with gr.Column():
        # Header image + usage instructions.
        gr.Markdown(
            """<img src="https://huggingface.co/spaces/andstor/deepspeed-model-memory-usage/resolve/main/measure_model_size_deepspeed.svg" style="float: left;" width="250" height="250"><h1>🤗 DeepSpeed Model Memory Calculator</h1>

This tool will help you calculate how much memory is required for the various Zero Redundancy Optimizer (ZeRO), given a model hosted on the 🤗 Hugging Face Hub and a hardware setup. The optimizer states assume that Adam is used.

To use this tool pass in the URL or model name of the model you want to calculate the memory usage for,
select which framework it originates from ("auto" will try and detect it from the model metadata), and
what precisions you want to use. Then select the select the desired ZeRO configuration."""
        )
        # Results area: filled in by get_results(); hidden until the first run.
        out_text = gr.Markdown()
        out = gr.DataFrame(
            headers=["dtype", "Largest Layer", "Total Size", "Training using Adam"],
            interactive=False,
            visible=False,
        )
        with gr.Row():
            inp = gr.Textbox(label="Model Name or URL", value="bert-base-cased")
        with gr.Row():
            library = gr.Radio(["auto", "transformers", "timm"], label="Library", value="auto")
            precision = gr.CheckboxGroup(
                ["float32", "float16/bfloat16"],
                value="float32",
                label="Model Precision",
            )
            training = gr.Radio(
                ["Mixed precision", "Single precision"],
                value="Mixed precision",
                label="Training Paradigm",
            )
            # Token is only needed for gated/private models.
            access_token = gr.Textbox(label="API Token", placeholder="Optional (for gated models)")
            num_gpus = gr.Number(label="GPUs per node", value=4, minimum=1, step=1)
            num_nodes = gr.Number(label="Nodes", value=1, minimum=1, step=1)
        with gr.Column(variant="panel"):
            with gr.Row(equal_height=True):
                # type="index" makes callbacks receive 0-3 rather than the label text.
                zero_stage = gr.Radio(["Stage 0", "Stage 1", "Stage 2", "Stage 3"], label="ZeRO Stage", value="Stage 3", type="index")
                # Read-only display of what each ZeRO stage partitions; kept in
                # sync by change_zero_description() below.
                zero_description = gr.CheckboxGroup(["Optimizer state", "Gradients", "Parameters"], label="Partitioning", value=["Optimizer state", "Gradients", "Parameters"], interactive=False)

            with gr.Row(equal_height=True):
                # Offload/init availability depends on the selected stage; see
                # change_zero_settings() and change_offloading() below.
                offloading = gr.CheckboxGroup(["Optimizer", "Parameters"], label="ZeRO-Offload", info="Offloading data and compute to CPU", value=["Optimizer", "Parameters"])
                zero_init = gr.CheckboxGroup(["zero.Init"], value=["zero.Init"], label="Initialization")

        additional_buffer_factor = gr.Number(label="Additional Buffer Factor", value=1.5, minimum=1, step=0.1)
        with gr.Row():
            btn = gr.Button("Calculate Memory Usage")
            # Revealed by get_results() only when no discussion exists yet.
            post_to_hub = gr.Button(
                value="Report results in this model repo's discussions!\n(Will open in a new tab)", visible=False
            )
|
def change_zero_settings(evt: gr.SelectData): |
|
if evt.index == 0: |
|
return [gr.update(visible = False), gr.update(visible = False)] |
|
if evt.index == 1 or evt.index == 2: |
|
return [gr.update(choices=["Optimizer"], visible=True), gr.update(visible = False)] |
|
if evt.index == 3: |
|
return [gr.update(choices=["Optimizer", "Parameters"], visible=True), gr.update(visible = True)] |
|
|
|
def change_zero_description(evt: gr.SelectData): |
|
if evt.index == 0: |
|
return gr.update(value=None) |
|
if evt.index == 1: |
|
return gr.update(value=["Optimizer state"]) |
|
if evt.index == 2: |
|
return gr.update(value=["Optimizer state", "Gradients"]) |
|
if evt.index == 3: |
|
return gr.update(value=["Optimizer state", "Gradients", "Parameters"]) |
|
|
|
def change_offloading(evt: gr.SelectData, zero_stage): |
|
|
|
if evt.value == "Optimizer" and evt.selected == False: |
|
return gr.CheckboxGroup.update(choices=["Optimizer"], value=[]) |
|
|
|
if evt.value == "Optimizer" and evt.selected == True: |
|
if zero_stage in [1, 2]: |
|
return gr.CheckboxGroup.update(choices=["Optimizer"], value=["Optimizer"]) |
|
elif zero_stage == 3: |
|
return gr.CheckboxGroup.update(choices=["Optimizer", "Parameters"], value=["Optimizer"]) |
|
|
|
if evt.value == "Parameters" and evt.selected == False: |
|
return gr.CheckboxGroup.update(value=["Optimizer"]) |
|
|
|
if evt.value == "Parameters" and evt.selected == True: |
|
|
|
return gr.CheckboxGroup.update(value=["Optimizer", "Parameters"]) |
|
|
|
|
|
|
|
zero_stage.select(change_zero_settings, None, [offloading, zero_init]) |
|
zero_stage.select(change_zero_description, None, zero_description) |
|
offloading.select(change_offloading, zero_stage, offloading) |
|
|
|
|
|
btn.click( |
|
get_results, |
|
inputs=[inp, library, precision, training, access_token, zero_stage, num_nodes, num_gpus, offloading, zero_init, additional_buffer_factor], |
|
outputs=[out_text, out, post_to_hub], |
|
) |
|
|
|
post_to_hub.click(lambda: gr.Button.update(visible=False), outputs=post_to_hub).then( |
|
report_results, inputs=[inp, library, access_token] |
|
) |
|
|
|
|
|
demo.launch() |
|
|