import json
from typing import Dict, Union

import gradio as gr
from huggingface_hub import get_safetensors_metadata, hf_hub_download, login

# Bytes needed to store a single value (weight or KV-cache entry) for each data type.
bytes_per_dtype: Dict[str, float] = {
    "int4": 0.5,
    "int8": 1,
    "float8": 1,
    "float16": 2,
    "float32": 4,
}


def extract_keys(json_obj, keys_to_extract):
    """Recursively collect the requested keys from a (possibly nested) config structure."""
    extracted_values = {}

    def recursive_search(obj):
        if isinstance(obj, dict):
            for key, value in obj.items():
                if key in keys_to_extract:
                    extracted_values[key] = value
                recursive_search(value)
        elif isinstance(obj, list):
            for item in obj:
                recursive_search(item)

    recursive_search(json_obj)
    return extracted_values
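
# Illustrative example (hypothetical config, not taken from a real model): the recursion
# means nested configs work too, e.g.
#   extract_keys({"text_config": {"num_hidden_layers": 32}}, {"num_hidden_layers"})
# returns {"num_hidden_layers": 32} regardless of nesting depth.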


def calculate_kv_cache_memory(context_size: int, model_id: str, dtype: str, token: str = None) -> Union[float, str]:
    """Estimate the KV-cache footprint in GB for a given context size, using the model's config.json."""
    try:
        file_path = hf_hub_download(repo_id=model_id, filename="config.json", token=token)
        with open(file_path, "r") as f:
            config = json.load(f)

        keys_to_find = {"num_hidden_layers", "num_key_value_heads", "hidden_size", "num_attention_heads"}
        config = extract_keys(config, keys_to_find)

        num_layers = config["num_hidden_layers"]
        # Models with grouped-query attention cache only num_key_value_heads heads.
        num_att_heads = config.get("num_key_value_heads", config["num_attention_heads"])
        dim_att_head = config["hidden_size"] // config["num_attention_heads"]
        dtype_bytes = bytes_per_dtype[dtype]

        # Two cached tensors (K and V) per layer, per KV head, per token.
        memory_per_token = num_layers * num_att_heads * dim_att_head * dtype_bytes * 2
        context_size_memory_footprint_gb = (context_size * memory_per_token) / 1_000_000_000

        return context_size_memory_footprint_gb
    except Exception as e:
        return f"Error: {str(e)}"


def calculate_model_memory(parameters: float, dtype: str) -> float:
    """Estimate the memory for the model weights in GB, given the parameter count in billions."""
    dtype_bytes = bytes_per_dtype[dtype]
    # Equivalent to parameters (in billions) * bytes per parameter, with ~18% overhead.
    return round((parameters * 4) / (32 / (dtype_bytes * 8)) * 1.18, 2)
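
# Worked example of the formula above: a 7B-parameter model at float16 gives
# 7 * 2 * 1.18 ≈ 16.52 GB for the weights.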


def get_model_size(model_id: str, dtype: str, token: str = None) -> Union[float, str]:
    """Look up the model's parameter count from its safetensors metadata and convert it to GB."""
    try:
        metadata = get_safetensors_metadata(model_id, token=token)
        if not metadata or not metadata.parameter_count:
            return "Error: Could not fetch metadata."
        # parameter_count maps dtype names to parameter counts; sum them in case the
        # checkpoint mixes dtypes, then convert to billions.
        model_parameters = sum(metadata.parameter_count.values()) / 1_000_000_000
        return calculate_model_memory(model_parameters, dtype)
    except Exception as e:
        return f"Error: {str(e)}"


def estimate_vram(model_id, dtype, context_size, hf_token):
    """Combine the weight and KV-cache estimates into a single report string."""
    if hf_token:
        login(token=hf_token)

    if dtype not in bytes_per_dtype:
        return "Error: Unsupported dtype"

    model_memory = get_model_size(model_id, dtype, hf_token)
    context_memory = calculate_kv_cache_memory(context_size, model_id, dtype, hf_token)

    # Either helper returns an error string on failure; surface the first one.
    if isinstance(model_memory, str) or isinstance(context_memory, str):
        return model_memory if isinstance(model_memory, str) else context_memory

    total_memory = model_memory + context_memory
    return f"Model VRAM: {model_memory:.2f} GB\nContext VRAM: {context_memory:.2f} GB\nTotal VRAM: {total_memory:.2f} GB"


iface = gr.Interface(
    fn=estimate_vram,
    inputs=[
        gr.Textbox(label="Hugging Face Model ID", value="google/gemma-3-27b-it"),
        gr.Dropdown(choices=list(bytes_per_dtype.keys()), label="Data Type", value="float16"),
        gr.Number(label="Context Size", value=128000),
        gr.Textbox(label="Hugging Face Access Token", type="password", placeholder="Optional - Needed for gated models"),
    ],
    outputs=gr.Textbox(label="Estimated VRAM Usage"),
    title="LLM GPU VRAM Calculator",
    description="Estimate the VRAM requirements of a model and context size. Optionally provide a Hugging Face token for gated models.",
)

iface.launch()