import gradio as gr
import pandas as pd

from hub_utils import check_for_discussion, report_results
from model_utils import calculate_memory, get_model

# Most recently loaded model, kept at module level.
# NOTE(review): this single global is shared by every session of the app, so
# concurrent requests overwrite each other's value. Nothing in this file reads it
# after `get_results` returns (`report_results` is driven by the raw inputs), so
# confirm no other module depends on it before switching to per-session `gr.State`.
MODEL = None


def get_results(model_name: str, library: str, options: list, access_token: str):
    """Load a model and build the UI updates that display its memory usage.

    Args:
        model_name: Text from the "Model Name or URL" box (Hub model name or URL).
        library: Choice from the "Library" radio ("auto", "transformers" or "timm").
        options: Precisions ticked in the "Model Precision" checkbox group.
        access_token: Text from the optional "API Token" box (for gated models).
            NOTE(review): presumably an empty string when left blank — confirm
            that `get_model` treats that the same as "no token".

    Returns:
        A three-element list matching `outputs=[out_text, out, post_to_hub]`:
        the Markdown title, an update that reveals the results table filled with
        the rows returned by `calculate_memory`, and an update that shows the
        "report results" button only when `check_for_discussion` returns falsy.
    """
    global MODEL
    MODEL = get_model(model_name, library, access_token)
    has_discussion = check_for_discussion(model_name)
    title = f"## Memory usage for '{model_name}'"
    data = calculate_memory(MODEL, options)
    return [
        title,
        gr.update(visible=True, value=pd.DataFrame(data)),
        # Only offer to post results when the repo has no existing discussion.
        gr.update(visible=not has_discussion),
    ]


with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown(
            """

🤗 Model Memory Calculator

This tool will help you calculate how much vRAM is needed to train and perform big model inference on a model hosted on the 🤗 Hugging Face Hub. The minimum recommended vRAM needed for a model is denoted as the size of the "largest layer", and training of a model is roughly 4x its size (for Adam). These calculations are accurate within a few percent at most, such as `bert-base-cased` being 413.68 MB and the calculator estimating 413.18 MB. When performing inference, expect to add up to an additional 20% to this as found by [EleutherAI](https://blog.eleuther.ai/transformer-math/). More tests will be performed in the future to get a more accurate benchmark for each model. Currently this tool supports all models hosted that use `transformers` and `timm`. To use this tool pass in the URL or model name of the model you want to calculate the memory usage for, select which framework it originates from ("auto" will try and detect it from the model metadata), and what precisions you want to use."""
        )
        out_text = gr.Markdown()
        # Hidden until `get_results` fills it in.
        out = gr.DataFrame(
            headers=["dtype", "Largest Layer", "Total Size", "Training using Adam"],
            interactive=False,
            visible=False,
        )
        with gr.Row():
            inp = gr.Textbox(label="Model Name or URL", value="bert-base-cased")
        with gr.Row():
            library = gr.Radio(["auto", "transformers", "timm"], label="Library", value="auto")
            options = gr.CheckboxGroup(
                ["float32", "float16/bfloat16", "int8", "int4"],
                value="float32",
                label="Model Precision",
            )
            access_token = gr.Textbox(label="API Token", placeholder="Optional (for gated models)")
        with gr.Row():
            btn = gr.Button("Calculate Memory Usage")
            # Hidden until `get_results` decides whether reporting is useful.
            post_to_hub = gr.Button(
                value="Report results in this model repo's discussions!\n(Will open in a new tab)", visible=False
            )

    btn.click(
        get_results,
        inputs=[inp, library, options, access_token],
        outputs=[out_text, out, post_to_hub],
    )

    # After reporting, hide the button again so the same results are not posted twice.
    # `gr.update` (rather than the removed-in-Gradio-4 `gr.Button.update`) matches
    # `get_results` and works on both Gradio 3.x and 4.x.
    post_to_hub.click(report_results, inputs=[inp, library, access_token]).then(
        lambda: gr.update(visible=False), outputs=post_to_hub
    )

demo.launch()