| import copy |
| import hashlib |
| import json |
| import tempfile |
| import threading |
| import time |
| import traceback |
| from collections import OrderedDict |
| from dataclasses import dataclass |
| from pathlib import Path |
| from urllib.parse import urlparse |
| from uuid import uuid4 |
|
|
| import accelerate |
| import gradio as gr |
| import huggingface_hub |
|
|
| try: |
| from gradio_huggingfacehub_search import HuggingfaceHubSearch |
| HAS_HF_HUB_SEARCH = True |
| except Exception: |
| HuggingfaceHubSearch = None |
| HAS_HF_HUB_SEARCH = False |
| import pandas as pd |
| import timm |
| import transformers |
| from accelerate.utils import convert_bytes |
|
|
| from model_utils import ( |
| calculate_memory, |
| get_model_normalized, |
| normalize_model_name, |
| preflight_model_access_normalized, |
| ) |
|
|
|
|
# Default UI state; these are also what the Reset button restores.
DEFAULT_MODEL = "bert-base-cased"
DEFAULT_LIBRARY = "auto"  # "auto" lets the backend detect transformers vs timm
DEFAULT_OPTIONS = ["float32"]
# Max entries kept in the process-wide LRU result cache (see ResultCache).
RESULTS_CACHE_SIZE = 128
# Generated download files older than this many seconds are deleted on cleanup.
DOWNLOAD_RETENTION_SECONDS = 60 * 60
# Hard cap on how many files may remain in the download directory after cleanup.
DOWNLOAD_CLEANUP_MAX_FILES = 256
|
|
|
|
def log_startup_versions():
    """Print the versions of the key runtime dependencies to stdout.

    Emitted once at import time so the versions show up in Space logs.
    """
    version_parts = [
        f"gradio={gr.__version__}",
        f"accelerate={accelerate.__version__}",
        f"transformers={transformers.__version__}",
        f"huggingface_hub={huggingface_hub.__version__}",
        f"timm={timm.__version__}",
    ]
    print("[startup] versions " + " ".join(version_parts))
|
|
|
|
# Log dependency versions once at import time to aid debugging from logs.
log_startup_versions()
|
|
|
|
| @dataclass(frozen=True) |
| class EstimateRequest: |
| original_model_name: str |
| normalized_model_name: str |
| library: str |
| options: tuple[str, ...] |
| access_token: str | None |
| auth_mode: str |
|
|
| @property |
| def cache_key(self): |
| token_key = "anonymous" |
| if self.access_token is not None: |
| token_key = hashlib.sha256(self.access_token.encode("utf-8")).hexdigest() |
| return ( |
| self.normalized_model_name, |
| self.library, |
| self.options, |
| token_key, |
| ) |
|
|
|
|
@dataclass
class EstimatePayload:
    """Formatted estimate data, intermediate between raw rows and the view."""

    # Rows for the summary table; the peak-vRAM dict in each raw row has been
    # replaced by a human-readable string ("N/A" or a byte string).
    display_rows: list[dict]
    # Untouched copy of the raw rows; serialized into the JSON download.
    raw_rows: list[dict]
    # Markdown intro for the Adam breakdown table ("" when no data available).
    explanation: str
    # Per-dtype Adam training-stage breakdown; may be empty.
    breakdown_df: pd.DataFrame
|
|
|
|
@dataclass
class EstimateViewModel:
    """Everything the results area needs to render one estimate attempt."""

    title: str
    auth_message: str
    summary_df: pd.DataFrame
    explanation: str
    breakdown_df: pd.DataFrame
    error_summary: str = ""
    error_details: str = ""
    summary_path: str | None = None
    breakdown_path: str | None = None
    json_path: str | None = None

    def to_updates(self):
        """Translate the fields into the ordered list of Gradio updates.

        The order matches the ``outputs`` wiring of the calculate handler:
        title, auth note, summary table, explanation, breakdown table,
        error summary, error details, then the three download files.
        Empty/absent content renders its component hidden.
        """
        def markdown_update(text):
            # Markdown sections show only when they carry text.
            return gr.update(visible=text != "", value=text)

        def table_update(frame):
            # DataFrames show only when non-empty.
            return gr.update(visible=not frame.empty, value=frame)

        def file_update(path):
            # Download slots show only when a file was written.
            return gr.update(visible=path is not None, value=path)

        return [
            self.title,
            gr.update(value=self.auth_message, visible=True),
            table_update(self.summary_df),
            markdown_update(self.explanation),
            table_update(self.breakdown_df),
            markdown_update(self.error_summary),
            markdown_update(self.error_details),
            file_update(self.summary_path),
            file_update(self.breakdown_path),
            file_update(self.json_path),
        ]
|
|
|
|
@dataclass
class ResetViewModel:
    """View model producing the updates that restore the UI to defaults.

    ``options`` uses ``None`` as a sentinel (a dataclass field cannot take a
    mutable list default directly); ``__post_init__`` replaces it with a
    fresh copy of ``DEFAULT_OPTIONS``.
    """

    model_name: str = DEFAULT_MODEL
    library: str = DEFAULT_LIBRARY
    # Bug fix: the annotation previously omitted None even though None is the
    # declared default; include it so the type matches the sentinel.
    options: list[str] | tuple[str, ...] | None = None
    access_token: str = ""
    title: str = ""

    def __post_init__(self):
        # Copy DEFAULT_OPTIONS so instances never alias (or mutate) the
        # module-level default list.
        if self.options is None:
            self.options = list(DEFAULT_OPTIONS)

    def to_updates(self):
        """Return updates in the order of the reset button's ``outputs``:

        model input, library, options, token, title, then the nine output
        components, all hidden and cleared.
        """
        return [
            self.model_name,
            self.library,
            list(self.options),
            self.access_token,
            self.title,
            gr.update(visible=False, value=""),              # run_auth_status
            gr.update(visible=False, value=pd.DataFrame()),  # summary table
            gr.update(visible=False, value=""),              # explanation
            gr.update(visible=False, value=pd.DataFrame()),  # Adam breakdown
            gr.update(visible=False, value=""),              # error summary
            gr.update(visible=False, value=""),              # error details
            gr.update(visible=False, value=None),            # summary CSV
            gr.update(visible=False, value=None),            # breakdown CSV
            gr.update(visible=False, value=None),            # JSON dump
        ]
|
|
|
|
@dataclass
class _InflightEntry:
    """Bookkeeping for one in-progress cache computation (see ResultCache)."""

    # Set by the owning thread once `data` or `error` has been stored.
    event: threading.Event
    # Result of the computation, if it succeeded.
    data: list[dict] | None = None
    # Exception raised by the computation, if it failed.
    error: Exception | None = None
|
|
|
|
class ResultCache:
    """LRU result cache with single-flight deduplication of computations.

    Concurrent requests for the same cache key share one computation: the
    first caller ("owner") runs ``compute_fn`` while the others block on an
    event and then receive a deep copy of the owner's result (or its
    exception). Deep copies are handed out so callers may mutate results.
    """

    def __init__(self, max_size: int):
        # max_size: evict the least-recently-used entry beyond this count.
        self.max_size = max_size
        self._values = OrderedDict()  # cache_key -> result, in LRU order
        self._lock = threading.Lock()  # guards _values and _inflight
        self._inflight: dict[tuple, _InflightEntry] = {}  # keys being computed

    def get_or_compute(self, request: EstimateRequest, compute_fn):
        """Return the cached result for *request*, computing it on a miss.

        Re-raises whatever ``compute_fn`` raised; note the same exception
        object is raised in every waiting thread.
        """
        cache_key = request.cache_key

        with self._lock:
            # Fast path: cache hit. move_to_end refreshes the LRU ordering.
            if cache_key in self._values:
                self._values.move_to_end(cache_key)
                return copy.deepcopy(self._values[cache_key])

            # Miss: join an in-flight computation if one exists, otherwise
            # register ourselves as the owner — all decided under the lock.
            entry = self._inflight.get(cache_key)
            if entry is None:
                entry = _InflightEntry(event=threading.Event())
                self._inflight[cache_key] = entry
                is_owner = True
            else:
                is_owner = False

        if not is_owner:
            # Wait (outside the lock) for the owner, then share its outcome.
            entry.event.wait()
            if entry.error is not None:
                raise entry.error
            return copy.deepcopy(entry.data)

        try:
            data = compute_fn()
            with self._lock:
                self._values[cache_key] = copy.deepcopy(data)
                if len(self._values) > self.max_size:
                    # Evict the least recently used entry.
                    self._values.popitem(last=False)
            # Publish to waiters before the event is set in `finally`.
            entry.data = copy.deepcopy(data)
            return copy.deepcopy(data)
        except Exception as error:
            entry.error = error
            raise
        finally:
            # Wake all waiters and retire the in-flight slot, on success and
            # failure alike, so a later call can retry after an error.
            entry.event.set()
            with self._lock:
                self._inflight.pop(cache_key, None)
|
|
|
|
# Process-wide cache shared by all Gradio sessions served by this process.
RESULT_CACHE = ResultCache(max_size=RESULTS_CACHE_SIZE)
|
|
|
|
def get_auth_status(oauth_profile: gr.OAuthProfile | None):
    """Describe the sign-in state shown above the token field.

    Gradio injects ``oauth_profile`` automatically (None when signed out).
    """
    if oauth_profile is None:
        return "Not signed in. You can still paste an API token for gated models."

    # Prefer the Hub username; fall back to the display name, then a generic label.
    username = getattr(oauth_profile, "preferred_username", None) or getattr(oauth_profile, "name", None)
    if username is None:
        username = "Hugging Face user"

    return (
        f"Signed in as `{username}`. "
        "If the API Token field is blank, this session token will be used for gated models."
    )
|
|
|
|
| def use_hub_search(repo_id: str | None): |
| return (repo_id or "").strip() |
|
|
|
|
def get_hub_search_status():
    """Explain whether the optional Hub Search component could be loaded."""
    if not HAS_HF_HUB_SEARCH:
        return "Hub Search component is unavailable in this runtime. Manual model input still works."
    return "Search Hugging Face Hub to fill the model field automatically."
|
|
|
|
def validate_model_name(model_name: str):
    """Validate the raw model-name input and return it stripped.

    Raises ``gr.Error`` when the field is empty, or when the input is a
    full URL pointing at a non-Hugging-Face host. Unparseable input is
    treated as a plain model name.
    """
    candidate = model_name.strip()
    if not candidate:
        raise gr.Error("Enter a model name or a Hugging Face model URL.")

    try:
        parsed = urlparse(candidate)
    except Exception:
        # Best-effort: if urlparse chokes, fall through to plain-name handling.
        return candidate

    # Only reject inputs that are clearly URLs with a foreign host.
    if parsed.scheme and parsed.netloc:
        if parsed.netloc not in {"huggingface.co", "www.huggingface.co"}:
            raise gr.Error("Only Hugging Face model URLs are supported here.")

    return candidate
|
|
|
|
def validate_options(options: list):
    """Ensure at least one precision checkbox is selected."""
    if options:
        return
    raise gr.Error("Select at least one precision.")
|
|
|
|
def validate_access_token(access_token: str):
    """Reject tokens containing whitespace — usually a paste mistake."""
    if not access_token:
        return
    for character in access_token:
        if character.isspace():
            raise gr.Error("API tokens should not contain whitespace.")
|
|
|
|
def resolve_access_token(access_token: str, oauth_token: gr.OAuthToken | None):
    """Pick the effective credential for Hub calls.

    Precedence: manually pasted token, then OAuth session token, then none.
    Returns ``(token_or_None, mode)`` where mode is one of "manual",
    "oauth", "anonymous".
    """
    if access_token:
        return access_token, "manual"
    if oauth_token is not None:
        return oauth_token.token, "oauth"
    return None, "anonymous"
|
|
|
|
def build_estimate_request(
    model_name: str,
    library: str,
    options: list,
    access_token: str,
    oauth_token: gr.OAuthToken | None,
):
    """Validate all form inputs and assemble an immutable EstimateRequest.

    Raises ``gr.Error`` (surfaced as a Gradio popup) for invalid input.
    """
    cleaned_name = validate_model_name(model_name)
    validate_options(options)
    validate_access_token(access_token)

    resolved_token, auth_mode = resolve_access_token(access_token, oauth_token)

    return EstimateRequest(
        original_model_name=cleaned_name,
        normalized_model_name=normalize_model_name(cleaned_name),
        library=library,
        options=tuple(options),
        access_token=resolved_token,
        auth_mode=auth_mode,
    )
|
|
|
|
def get_auth_message(auth_mode: str):
    """Human-readable note about which credentials a run used."""
    known_modes = {
        "manual": "Using the manually provided API token for this estimate.",
        "oauth": "Using your Hugging Face OAuth session for this estimate.",
    }
    return known_modes.get(
        auth_mode,
        "Running anonymously. Gated models will require a token or a signed-in Hugging Face session.",
    )
|
|
|
|
def get_download_dir():
    """Return the shared directory for generated files, creating it if needed."""
    download_dir = Path(tempfile.gettempdir(), "model_memory_usage")
    download_dir.mkdir(parents=True, exist_ok=True)
    return download_dir
|
|
|
|
def cleanup_old_download_files(temp_dir: Path):
    """Best-effort pruning of stale download artifacts in *temp_dir*.

    First deletes files older than DOWNLOAD_RETENTION_SECONDS, then trims
    the directory to the newest DOWNLOAD_CLEANUP_MAX_FILES files. All
    filesystem errors are swallowed so cleanup never blocks an estimate.
    """
    expiry = time.time() - DOWNLOAD_RETENTION_SECONDS

    try:
        candidates = [entry for entry in temp_dir.iterdir() if entry.is_file()]
    except FileNotFoundError:
        return

    # Pass 1: age-based deletion.
    for candidate in candidates:
        try:
            if candidate.stat().st_mtime < expiry:
                candidate.unlink(missing_ok=True)
        except OSError:
            continue

    # Pass 2: cap the total file count, keeping the newest files.
    try:
        survivors = sorted(
            (entry for entry in temp_dir.iterdir() if entry.is_file()),
            key=lambda entry: entry.stat().st_mtime,
            reverse=True,
        )
    except FileNotFoundError:
        return

    for overflow in survivors[DOWNLOAD_CLEANUP_MAX_FILES:]:
        try:
            overflow.unlink(missing_ok=True)
        except OSError:
            continue
|
|
|
|
def make_download_files(model_name: str, summary_df: pd.DataFrame, breakdown_df: pd.DataFrame, raw_data: list):
    """Write the downloadable artifacts for one estimate.

    Produces a summary CSV, an Adam-breakdown CSV (only when the breakdown
    has rows), and a JSON dump of the raw data. Returns string paths as
    ``(summary, breakdown_or_None, json)``.
    """
    temp_dir = get_download_dir()
    cleanup_old_download_files(temp_dir)

    # Model names may contain "/", which is not filename-safe.
    safe_name = model_name.replace("/", "__") or "model"
    prefix = f"{safe_name}_{uuid4().hex}"

    summary_path = temp_dir / f"{prefix}_summary.csv"
    summary_df.to_csv(summary_path, index=False)

    breakdown_path = None
    if not breakdown_df.empty:
        breakdown_path = temp_dir / f"{prefix}_adam_breakdown.csv"
        breakdown_df.to_csv(breakdown_path, index=False)

    json_path = temp_dir / f"{prefix}_estimate.json"
    json_path.write_text(
        json.dumps({"model_name": model_name, "estimates": raw_data}, indent=2),
        encoding="utf-8",
    )

    return (
        str(summary_path),
        None if breakdown_path is None else str(breakdown_path),
        str(json_path),
    )
|
|
|
|
def fetch_raw_estimate_data(request: EstimateRequest):
    """Fetch the raw per-precision memory rows, going through the shared cache."""

    def _compute():
        # Hub access was already verified by the preflight step upstream,
        # so the auth check is skipped here.
        skeleton = get_model_normalized(
            request.normalized_model_name,
            request.library,
            request.access_token,
            skip_auth_check=True,
        )
        return calculate_memory(skeleton, list(request.options))

    return RESULT_CACHE.get_or_compute(request, _compute)
|
|
|
|
def build_estimate_payload(raw_rows: list[dict], options: tuple[str, ...]):
    """Format raw estimate rows into display rows plus an Adam breakdown.

    Each raw row carries a per-stage dict under the peak-vRAM column; that
    dict is collected into per-stage series and then replaced in the display
    copy by a readable string ("N/A" when the stage data is unavailable,
    i.e. -1).
    """
    adam_column = "Training using Adam (Peak vRAM)"
    display_rows = copy.deepcopy(raw_rows)
    stage_series = {name: [] for name in ("model", "gradients", "optimizer", "step")}

    for row in display_rows:
        stage_values = row[adam_column]
        for stage_name, series in stage_series.items():
            series.append(stage_values[stage_name])
        peak = max(stage_values.values())
        row[adam_column] = "N/A" if peak == -1 else convert_bytes(peak)

    explanation = ""
    breakdown_df = pd.DataFrame(
        columns=["dtype", "Model", "Gradient calculation", "Backward pass", "Optimizer step"]
    )

    # Only explain the breakdown when at least one dtype produced real data.
    if any(model_bytes != -1 for model_bytes in stage_series["model"]):
        explanation = (
            "## Training using Adam explained:\n"
            "When training on a batch size of 1, each stage of the training process is expected "
            "to have near the following memory results for each precision you selected:\n"
        )

    for position, dtype in enumerate(options):
        if stage_series["model"][position] == -1:
            continue
        breakdown_df.loc[len(breakdown_df.index)] = [
            dtype,
            convert_bytes(stage_series["model"][position]),
            convert_bytes(stage_series["gradients"][position]),
            convert_bytes(stage_series["optimizer"][position]),
            convert_bytes(stage_series["step"][position]),
        ]

    return EstimatePayload(
        display_rows=display_rows,
        raw_rows=copy.deepcopy(raw_rows),
        explanation=explanation,
        breakdown_df=breakdown_df,
    )
|
|
|
|
def build_success_view_model(request: EstimateRequest, payload: EstimatePayload):
    """Assemble the view model for a successful estimate, writing downloads."""
    summary_df = pd.DataFrame(payload.display_rows)
    summary_csv, breakdown_csv, json_dump = make_download_files(
        request.normalized_model_name,
        summary_df,
        payload.breakdown_df,
        payload.raw_rows,
    )
    return EstimateViewModel(
        title=f"## Static memory estimate for `{request.normalized_model_name}`",
        auth_message=get_auth_message(request.auth_mode),
        summary_df=summary_df,
        explanation=payload.explanation,
        breakdown_df=payload.breakdown_df,
        summary_path=summary_csv,
        breakdown_path=breakdown_csv,
        json_path=json_dump,
    )
|
|
|
|
def build_error_view_model(request: EstimateRequest, error: Exception):
    """Assemble the view model shown when an estimate fails.

    Must be called from inside the ``except`` block so the traceback of
    *error* is still the active exception for ``traceback.format_exc``.
    """
    # Fall back to the exception class name for message-less exceptions.
    summary = str(error).strip() or error.__class__.__name__
    return EstimateViewModel(
        title=f"## Unable to estimate memory for `{request.normalized_model_name}`",
        auth_message=get_auth_message(request.auth_mode),
        summary_df=pd.DataFrame(),
        explanation="",
        breakdown_df=pd.DataFrame(),
        error_summary=(
            f"{summary}\n\n"
            "Check the **Details** section below for the full traceback."
        ),
        error_details=traceback.format_exc().strip(),
    )
|
|
|
|
def reset_app():
    """Restore every UI control to its default state."""
    defaults = ResetViewModel()
    return defaults.to_updates()
|
|
|
|
def get_results(
    model_name: str,
    library: str,
    options: list,
    access_token: str,
    oauth_token: gr.OAuthToken | None,
    progress=gr.Progress(track_tqdm=False),
):
    """Top-level click handler: validate inputs, estimate, render results.

    Validation problems (``gr.Error``) are raised before the ``try`` so
    Gradio shows them as popups; any failure afterwards is rendered as an
    in-page error view so the full traceback remains visible.
    """
    progress(0.05, desc="Checking inputs")
    request = build_estimate_request(model_name, library, options, access_token, oauth_token)

    try:
        progress(0.12, desc="Checking Hub access")
        preflight_model_access_normalized(request.normalized_model_name, request.access_token)
        progress(0.3, desc="Building model skeleton")
        raw_rows = fetch_raw_estimate_data(request)
        progress(0.75, desc="Formatting results")
        payload = build_estimate_payload(raw_rows, request.options)
        progress(0.95, desc="Writing downloads")
        success_view = build_success_view_model(request, payload)
        progress(1.0, desc="Done")
        return success_view.to_updates()
    except Exception as failure:
        progress(1.0, desc="Failed")
        return build_error_view_model(request, failure).to_updates()
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI definition and event wiring.
# delete_cache purges Gradio's own temp cache on the same schedule as our
# download retention window.
# ---------------------------------------------------------------------------
with gr.Blocks(delete_cache=(3600, DOWNLOAD_RETENTION_SECONDS)) as demo:
    with gr.Column():
        # Static header / explainer copy.
        gr.HTML(
            """<img src="https://huggingface.co/spaces/hf-accelerate/model-memory-usage/resolve/main/measure_model_size.png" style="float: left;" width="250" height="250"><h1>🤗 Model Memory Calculator</h1>
<p>This tool provides a static memory estimate for the vRAM needed to load and train Hub models.</p>
<p>The minimum recommended vRAM needed to load a model is denoted as the size of the "largest layer", and training of a model is roughly 4x its size (for Adam).</p>
<p>These calculations are accurate within a few percent at most, such as <code>bert-base-cased</code> being 413.68 MB and the calculator estimating 413.18 MB.</p>
<p>When performing inference, expect to add up to an additional 20% to this as found by <a href="https://blog.eleuther.ai/transformer-math/" target="_blank">EleutherAI</a>.</p>
<p>More tests will be performed in the future to get a more accurate benchmark for each model.</p>
<p>Currently this tool supports all models hosted that use <code>transformers</code> and <code>timm</code>.</p>
<p>To use this tool pass in the URL or model name of the model you want to calculate the memory usage for, select which framework it originates from (<code>auto</code> will try and detect it from the model metadata), and what precisions you want to use.</p>"""
        )

        # --- Inputs ---
        with gr.Group():
            with gr.Row(equal_height=True):
                inp = gr.Textbox(label="Model Name or URL", value=DEFAULT_MODEL)

                with gr.Column():
                    if HAS_HF_HUB_SEARCH:
                        # NOTE(review): "sumbit_on_select" appears to mirror the
                        # upstream component's (misspelled) keyword name — verify
                        # against the installed gradio_huggingfacehub_search version.
                        hub_search = HuggingfaceHubSearch(
                            label="Search Hugging Face Hub",
                            placeholder="Search for models on Hugging Face",
                            search_type="model",
                            sumbit_on_select=True,
                        )
                        hub_search_status = gr.Markdown(get_hub_search_status())
                    else:
                        # Search component missing: keep a None handle so the
                        # wiring below can test for it, and show a hint instead.
                        hub_search = None
                        hub_search_status = gr.Markdown(get_hub_search_status())

            with gr.Row(equal_height=True):
                library = gr.Radio(["auto", "transformers", "timm"], label="Library", value=DEFAULT_LIBRARY)
                options = gr.CheckboxGroup(
                    ["float32", "float16/bfloat16", "int8", "int4"],
                    value=DEFAULT_OPTIONS,
                    label="Model Precision",
                )

            # --- Authentication ---
            with gr.Column():
                gr.LoginButton()
                access_token = gr.Textbox(
                    label="API Token",
                    placeholder="Optional. If blank, your Sign in with HF session will be used for gated models.",
                )
                # Initial text matches get_auth_status(None); replaced on load.
                auth_status = gr.Markdown("Not signed in. You can still paste an API token for gated models.")
                run_auth_status = gr.Markdown(visible=False)

        # --- Actions ---
        with gr.Group():
            with gr.Row(equal_height=True):
                btn = gr.Button("Calculate Memory Usage")
                reset_btn = gr.Button("Reset")

        # --- Outputs (hidden until a run completes) ---
        out_text = gr.Markdown()
        error_text = gr.Markdown(visible=False)
        out = gr.DataFrame(
            headers=["dtype", "Largest Layer", "Total Size", "Training using Adam (Peak vRAM)"],
            interactive=False,
            visible=False,
        )
        out_explain = gr.Markdown(visible=False)
        memory_values = gr.DataFrame(
            headers=["dtype", "Model", "Gradient calculation", "Backward pass", "Optimizer step"],
            interactive=False,
            visible=False,
        )

        with gr.Accordion("Downloads", open=False):
            summary_file = gr.File(label="Summary CSV", visible=False)
            breakdown_file = gr.File(label="Adam Breakdown CSV", visible=False)
            json_file = gr.File(label="Full JSON", visible=False)

        with gr.Accordion("Details", open=False):
            error_details = gr.Textbox(
                label="Error Details",
                lines=12,
                interactive=False,
                visible=False,
            )

    # Refresh the sign-in status line on page load.
    demo.load(
        get_auth_status,
        inputs=None,
        outputs=auth_status,
        api_name=False,
        queue=False,
    )

    if HAS_HF_HUB_SEARCH:
        # Selecting a search result copies the repo id into the model textbox.
        gr.on(
            triggers=[hub_search.submit],
            fn=use_hub_search,
            inputs=[hub_search],
            outputs=[inp],
            api_name=False,
            show_progress="hidden",
            queue=False,
        )

    # Main action. get_results also receives the OAuth token implicitly via
    # its gr.OAuthToken-annotated parameter; the outputs order must match
    # EstimateViewModel.to_updates().
    gr.on(
        triggers=[btn.click, inp.submit],
        fn=get_results,
        inputs=[inp, library, options, access_token],
        outputs=[
            out_text,
            run_auth_status,
            out,
            out_explain,
            memory_values,
            error_text,
            error_details,
            summary_file,
            breakdown_file,
            json_file,
        ],
        show_api=False,
        show_progress="minimal",
        concurrency_limit=1,
        concurrency_id="memory-estimate",
    )

    # Reset restores defaults; outputs order must match ResetViewModel.to_updates().
    reset_btn.click(
        reset_app,
        inputs=None,
        outputs=[
            inp,
            library,
            options,
            access_token,
            out_text,
            run_auth_status,
            out,
            out_explain,
            memory_values,
            error_text,
            error_details,
            summary_file,
            breakdown_file,
            json_file,
        ],
        api_name=False,
        show_progress="hidden",
        queue=False,
    )




# One worker keeps the heavyweight estimate jobs serialized; up to 24 queued.
demo.queue(default_concurrency_limit=1, max_size=24)
demo.launch()
|
|