import json
from typing import Dict, Optional, Union

import gradio as gr
from huggingface_hub import get_safetensors_metadata, hf_hub_download, login

# Dictionary mapping dtype strings to their byte sizes
bytes_per_dtype: Dict[str, float] = {
    "int4": 0.5,
    "int8": 1,
    "float8": 1,
    "float16": 2,
    "float32": 4,
}

def extract_keys(json_obj, keys_to_extract):
    """Recursively collect the requested keys from a nested dict/list structure."""
    extracted_values = {}

    def recursive_search(obj):
        if isinstance(obj, dict):
            for key, value in obj.items():
                if key in keys_to_extract:
                    extracted_values[key] = value
                recursive_search(value)
        elif isinstance(obj, list):
            for item in obj:
                recursive_search(item)

    recursive_search(json_obj)
    return extracted_values
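
# Illustrative usage (hypothetical input, not part of the app's flow):
#   extract_keys({"text_config": {"num_hidden_layers": 48}}, {"num_hidden_layers"})
#   -> {"num_hidden_layers": 48}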

def calculate_kv_cache_memory(context_size: int, model_id: str, dtype: str, token: Optional[str] = None):
    """Estimate the KV-cache memory (in GB) required for a given context length."""
    try:
        file_path = hf_hub_download(repo_id=model_id, filename="config.json", token=token)
        with open(file_path, "r") as f:
            config = json.load(f)
        keys_to_find = {"num_hidden_layers", "num_key_value_heads", "hidden_size", "num_attention_heads"}
        config = extract_keys(config, keys_to_find)
        num_layers = config["num_hidden_layers"]
        # Models with grouped-query attention cache fewer KV heads than attention heads
        num_att_heads = config.get("num_key_value_heads", config["num_attention_heads"])
        dim_att_head = config["hidden_size"] // config["num_attention_heads"]
        dtype_bytes = bytes_per_dtype[dtype]
        # Factor of 2 accounts for storing both a key and a value per head per layer
        memory_per_token = num_layers * num_att_heads * dim_att_head * dtype_bytes * 2
        context_size_memory_footprint_gb = (context_size * memory_per_token) / 1_000_000_000
        return context_size_memory_footprint_gb
    except Exception as e:
        return f"Error: {str(e)}"

def calculate_model_memory(parameters: float, dtype: str) -> float:
    # Weight memory (GB) for `parameters` given in billions, with ~18% overhead
    dtype_bytes = bytes_per_dtype[dtype]
    return round((parameters * 4) / (32 / (dtype_bytes * 8)) * 1.18, 2)
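
# Worked example (illustrative): a 27B-parameter model at float16 ->
# (27 * 4) / (32 / 16) * 1.18 = 63.72 GB of weight memory.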

def get_model_size(model_id: str, dtype: str, token: Optional[str] = None) -> Union[float, str]:
    """Read the parameter count from the repo's safetensors metadata and convert it to GB."""
    try:
        metadata = get_safetensors_metadata(model_id, token=token)
        if not metadata or not metadata.parameter_count:
            return "Error: Could not fetch metadata."
        # parameter_count maps dtype -> count; take the first entry and convert to billions
        model_parameters = int(list(metadata.parameter_count.values())[0]) / 1_000_000_000
        return calculate_model_memory(model_parameters, dtype)
    except Exception as e:
        return f"Error: {str(e)}"

def estimate_vram(model_id, dtype, context_size, hf_token):
    """Combine the weight and KV-cache estimates into a formatted report string."""
    if hf_token:
        login(token=hf_token)
    if dtype not in bytes_per_dtype:
        return "Error: Unsupported dtype"
    model_memory = get_model_size(model_id, dtype, hf_token)
    context_memory = calculate_kv_cache_memory(context_size, model_id, dtype, hf_token)
    # Either helper returns an error string on failure; surface the first one
    if isinstance(model_memory, str) or isinstance(context_memory, str):
        return model_memory if isinstance(model_memory, str) else context_memory
    total_memory = model_memory + context_memory
    return f"Model VRAM: {model_memory:.2f} GB\nContext VRAM: {context_memory:.2f} GB\nTotal VRAM: {total_memory:.2f} GB"

iface = gr.Interface(
    fn=estimate_vram,
    inputs=[
        gr.Textbox(label="Hugging Face Model ID", value="google/gemma-3-27b-it"),
        gr.Dropdown(choices=list(bytes_per_dtype.keys()), label="Data Type", value="float16"),
        gr.Number(label="Context Size", value=128000),
        gr.Textbox(label="Hugging Face Access Token", type="password", placeholder="Optional - Needed for gated models"),
    ],
    outputs=gr.Textbox(label="Estimated VRAM Usage"),
    title="LLM GPU VRAM Calculator",
    description="Estimate the VRAM requirements of a model and context size. Optionally provide a Hugging Face token for gated models.",
)

iface.launch()
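
# To run locally (assuming this file is saved as app.py):
#   pip install gradio huggingface_hub
#   python app.py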