File size: 3,808 Bytes
d16240b
 
76ba794
d16240b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76ba794
 
 
 
 
 
 
 
d16240b
76ba794
 
 
 
 
 
 
 
 
 
 
d16240b
76ba794
 
 
d16240b
76ba794
d16240b
76ba794
d16240b
76ba794
 
 
d16240b
76ba794
d16240b
76ba794
 
 
 
d16240b
 
 
76ba794
 
d16240b
 
 
 
 
 
 
 
 
 
 
 
76ba794
 
d16240b
 
 
76ba794
d16240b
 
76ba794
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import gradio as gr
from typing import Dict, Union
from huggingface_hub import get_safetensors_metadata, hf_hub_download, login
import json

# Bytes needed to store one element (weight or KV-cache entry) for each
# selectable data type. Insertion order is significant: the UI dropdown lists
# the choices in this order.
bytes_per_dtype: Dict[str, float] = dict(
    int4=0.5,
    int8=1,
    float8=1,
    float16=2,
    float32=4,
)

def extract_keys(json_obj, keys_to_extract):
    """Collect the values of selected keys from an arbitrarily nested JSON value.

    The value is walked breadth-first through nested dicts and lists. When a
    key occurs more than once, the shallowest occurrence wins; among
    occurrences at the same depth, the first one encountered wins. This makes
    top-level settings take precedence over values repeated deeper inside
    nested sub-configs, and earlier sibling sub-configs (e.g. ``text_config``)
    take precedence over later ones (e.g. ``vision_config``).

    Args:
        json_obj: Parsed JSON value (dict, list or scalar) to search.
        keys_to_extract: Container of key names to look for.

    Returns:
        A dict mapping each found key to its selected value. Keys that occur
        nowhere are absent from the result.
    """
    extracted_values = {}
    level = [json_obj]
    # Iterative level-by-level walk: no recursion-depth limit, and depth order
    # decides which duplicate wins.
    while level:
        next_level = []
        for obj in level:
            if isinstance(obj, dict):
                for key, value in obj.items():
                    if key in keys_to_extract and key not in extracted_values:
                        extracted_values[key] = value
                    if isinstance(value, (dict, list)):
                        next_level.append(value)
            elif isinstance(obj, list):
                next_level.extend(item for item in obj if isinstance(item, (dict, list)))
        level = next_level
    return extracted_values

def calculate_kv_cache_memory(context_size: int, model_id: str, dtype: str, token: str = None):
    """Estimate the KV-cache memory needed for ``context_size`` tokens, in GB.

    Downloads the model's ``config.json`` from the Hugging Face Hub and applies,
    per token, ``2 (K and V) * layers * kv_heads * head_dim * bytes_per_element``.
    This assumes every layer caches K/V for the full context (no sliding-window
    or cross-layer sharing is modelled).

    For multimodal configs that nest the language model under ``text_config``,
    values found there take precedence over the rest of the config, since the
    KV cache belongs to the text decoder.

    Args:
        context_size: Number of tokens kept in the cache.
        model_id: Hub repository id, e.g. ``"google/gemma-3-27b-it"``.
        dtype: Key of ``bytes_per_dtype`` used for cache elements.
        token: Optional Hub access token for gated/private repositories.

    Returns:
        The estimate in GB (10**9 bytes) as a float, or a string starting with
        ``"Error: "`` if the config cannot be fetched or lacks a required key.
    """
    try:
        file_path = hf_hub_download(repo_id=model_id, filename="config.json", token=token)
        with open(file_path, "r", encoding="utf-8") as f:
            raw_config = json.load(f)

        keys_to_find = {
            "num_hidden_layers",
            "num_key_value_heads",
            "hidden_size",
            "num_attention_heads",
            "head_dim",
        }
        config = extract_keys(raw_config, keys_to_find)
        text_config = raw_config.get("text_config")
        if isinstance(text_config, dict):
            # Decoder settings override anything picked up elsewhere
            # (e.g. from a vision encoder's sub-config).
            config.update(extract_keys(text_config, keys_to_find))

        num_layers = config["num_hidden_layers"]
        num_attention_heads = config["num_attention_heads"]
        # Grouped-query attention caches fewer K/V heads than query heads;
        # a missing or null entry means plain multi-head attention.
        num_kv_heads = config.get("num_key_value_heads") or num_attention_heads
        # Some models declare head_dim explicitly and it need not equal
        # hidden_size / num_attention_heads; only derive it when absent.
        head_dim = config.get("head_dim") or config["hidden_size"] // num_attention_heads
        dtype_bytes = bytes_per_dtype[dtype]

        memory_per_token = num_layers * num_kv_heads * head_dim * dtype_bytes * 2
        return (context_size * memory_per_token) / 1_000_000_000
    except Exception as e:
        return f"Error: {str(e)}"

def calculate_model_memory(parameters: float, dtype: str, overhead: float = 1.18) -> float:
    """Estimate the memory needed to hold a model's weights, in GB.

    Args:
        parameters: Parameter count, in billions.
        dtype: Key of ``bytes_per_dtype`` giving the storage size per weight.
        overhead: Multiplicative factor applied on top of the raw weight size.
            The default 1.18 is the value this estimator has always used
            (its derivation is not documented here).

    Returns:
        ``parameters * bytes_per_weight * overhead``, rounded to 2 decimals.
    """
    dtype_bytes = bytes_per_dtype[dtype]
    # Billions of parameters times bytes per parameter is directly GB (10**9 B).
    return round(parameters * dtype_bytes * overhead, 2)

def get_model_size(model_id: str, dtype: str, token: str = None) -> Union[float, str]:
    """Estimate a Hub model's weight memory (GB) from its safetensors metadata.

    Args:
        model_id: Hub repository id.
        dtype: Key of ``bytes_per_dtype`` the weights would be loaded in.
        token: Optional Hub access token for gated/private repositories.

    Returns:
        The ``calculate_model_memory`` estimate as a float, or a string
        starting with ``"Error: "`` on failure.
    """
    try:
        metadata = get_safetensors_metadata(model_id, token=token)
        if not metadata or not metadata.parameter_count:
            return "Error: Could not fetch metadata."
        # parameter_count maps each stored dtype (e.g. "BF16", "F32") to the
        # number of parameters saved in it. Checkpoints that mix dtypes have
        # several entries, so sum them all rather than taking the first.
        total_parameters = sum(metadata.parameter_count.values())
        model_parameters = int(total_parameters) / 1_000_000_000
        return calculate_model_memory(model_parameters, dtype)
    except Exception as e:
        return f"Error: {str(e)}"

def estimate_vram(model_id, dtype, context_size, hf_token):
    """Return a human-readable VRAM estimate for a model plus its KV cache.

    Args:
        model_id: Hub repository id.
        dtype: Key of ``bytes_per_dtype`` used for both weights and KV cache.
        context_size: Number of context tokens to budget for.
        hf_token: Optional Hub access token; blank or whitespace means none.

    Returns:
        A three-line summary (model, context, total VRAM in GB), or a string
        starting with ``"Error: "`` describing the first failure.
    """
    # The token is passed explicitly to every Hub call below. Calling login()
    # here would store it process-wide (and on disk), so in a shared app one
    # visitor's token would silently be reused for later visitors.
    token = hf_token.strip() if isinstance(hf_token, str) else hf_token
    token = token or None  # an empty text box means "no token"

    if dtype not in bytes_per_dtype:
        return "Error: Unsupported dtype"
    if context_size is None or context_size < 0:
        return "Error: Context size must be a non-negative number"

    model_memory = get_model_size(model_id, dtype, token)
    context_memory = calculate_kv_cache_memory(context_size, model_id, dtype, token)

    # Both helpers report failures as strings; surface the first one.
    for result in (model_memory, context_memory):
        if isinstance(result, str):
            return result

    total_memory = model_memory + context_memory
    return f"Model VRAM: {model_memory:.2f} GB\nContext VRAM: {context_memory:.2f} GB\nTotal VRAM: {total_memory:.2f} GB"

# Gradio UI: one input component per estimate_vram parameter, in call order.
iface = gr.Interface(
    fn=estimate_vram,
    inputs=[
        gr.Textbox(label="Hugging Face Model ID", value="google/gemma-3-27b-it"),
        gr.Dropdown(choices=list(bytes_per_dtype.keys()), label="Data Type", value="float16"),
        # precision=0 makes Gradio round and deliver the token count as an int.
        gr.Number(label="Context Size", value=128000, precision=0),
        gr.Textbox(label="Hugging Face Access Token", type="password", placeholder="Optional - Needed for gated models")
    ],
    outputs=gr.Textbox(label="Estimated VRAM Usage"),
    title="LLM GPU VRAM Calculator",
    description="Estimate the VRAM requirements of a model and context size. Optionally provide a Hugging Face token for gated models."
)

# Start the server only when executed as a script, so importing this module
# (e.g. from tests or another app) does not launch it as a side effect.
if __name__ == "__main__":
    iface.launch()