import json
from typing import Dict, Union

import gradio as gr
from huggingface_hub import get_safetensors_metadata, hf_hub_download, login

# Dictionary mapping dtype strings to their byte sizes
bytes_per_dtype: Dict[str, float] = {
    "int4": 0.5,
    "int8": 1,
    "float8": 1,
    "float16": 2,
    "float32": 4,
}


def extract_keys(json_obj, keys_to_extract):
    """Recursively collect the values of the given keys from a nested dict/list structure."""
    extracted_values = {}

    def recursive_search(obj):
        if isinstance(obj, dict):
            for key, value in obj.items():
                if key in keys_to_extract:
                    extracted_values[key] = value
                recursive_search(value)
        elif isinstance(obj, list):
            for item in obj:
                recursive_search(item)

    recursive_search(json_obj)
    return extracted_values


def calculate_kv_cache_memory(context_size: int, model_id: str, dtype: str, token: str = None) -> Union[float, str]:
    """Estimate the KV-cache size in GB for a given context length, based on the model's config.json."""
    try:
        file_path = hf_hub_download(repo_id=model_id, filename="config.json", token=token)
        with open(file_path, "r") as f:
            config = json.load(f)

        keys_to_find = {"num_hidden_layers", "num_key_value_heads", "hidden_size", "num_attention_heads"}
        config = extract_keys(config, keys_to_find)

        num_layers = config["num_hidden_layers"]
        # Models with grouped-query attention report num_key_value_heads; fall back to num_attention_heads otherwise.
        num_kv_heads = config.get("num_key_value_heads", config["num_attention_heads"])
        dim_att_head = config["hidden_size"] // config["num_attention_heads"]
        dtype_bytes = bytes_per_dtype[dtype]

        # Each layer stores one key and one value vector per KV head for every token (hence the factor of 2).
        memory_per_token = num_layers * num_kv_heads * dim_att_head * dtype_bytes * 2
        context_size_memory_footprint_gb = (context_size * memory_per_token) / 1_000_000_000
        return context_size_memory_footprint_gb
    except Exception as e:
        return f"Error: {str(e)}"


def calculate_model_memory(parameters: float, dtype: str) -> float:
    """Estimate weight memory in GB from the parameter count (in billions) and dtype, with an extra ~18% overhead factor."""
    dtype_bytes = bytes_per_dtype[dtype]
    return round((parameters * 4) / (32 / (dtype_bytes * 8)) * 1.18, 2)


def get_model_size(model_id: str, dtype: str, token: str = None) -> Union[float, str]:
    """Read the parameter count from the model's safetensors metadata and convert it to a memory estimate."""
    try:
        metadata = get_safetensors_metadata(model_id, token=token)
        if not metadata or not metadata.parameter_count:
            return "Error: Could not fetch metadata."
        model_parameters = int(list(metadata.parameter_count.values())[0]) / 1_000_000_000
        return calculate_model_memory(model_parameters, dtype)
    except Exception as e:
        return f"Error: {str(e)}"


def estimate_vram(model_id, dtype, context_size, hf_token):
    """Gradio callback: combine the weight and KV-cache estimates into a single report."""
    if hf_token:
        login(token=hf_token)
    if dtype not in bytes_per_dtype:
        return "Error: Unsupported dtype"

    model_memory = get_model_size(model_id, dtype, hf_token)
    context_memory = calculate_kv_cache_memory(context_size, model_id, dtype, hf_token)

    # Either helper returns an error string on failure; surface the first one.
    if isinstance(model_memory, str) or isinstance(context_memory, str):
        return model_memory if isinstance(model_memory, str) else context_memory

    total_memory = model_memory + context_memory
    return (
        f"Model VRAM: {model_memory:.2f} GB\n"
        f"Context VRAM: {context_memory:.2f} GB\n"
        f"Total VRAM: {total_memory:.2f} GB"
    )


iface = gr.Interface(
    fn=estimate_vram,
    inputs=[
        gr.Textbox(label="Hugging Face Model ID", value="google/gemma-3-27b-it"),
        gr.Dropdown(choices=list(bytes_per_dtype.keys()), label="Data Type", value="float16"),
        gr.Number(label="Context Size", value=128000),
        gr.Textbox(
            label="Hugging Face Access Token",
            type="password",
            placeholder="Optional - Needed for gated models",
        ),
    ],
    outputs=gr.Textbox(label="Estimated VRAM Usage"),
    title="LLM GPU VRAM Calculator",
    description="Estimate the VRAM requirements of a model and context size. Optionally provide a Hugging Face token for gated models.",
)

iface.launch()
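
# Worked example of the two estimates above (hypothetical numbers, not taken from
# any particular model's config.json):
#
# Weights: for a 27B-parameter model in float16, calculate_model_memory evaluates
#   (27 * 4) / (32 / (2 * 8)) * 1.18 = 27 * 2 * 1.18 ≈ 63.72 GB.
#
# KV cache: assuming a config with 32 hidden layers, 8 key/value heads, and a head
# dimension of 128 in float16, calculate_kv_cache_memory stores
#   32 * 8 * 128 * 2 bytes * 2 (keys + values) = 131,072 bytes per token,
# so a 128,000-token context needs roughly 128_000 * 131_072 / 1e9 ≈ 16.8 GB.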