File size: 1,710 Bytes
9878518
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import gradio as gr

def calculate_kv_cache_memory(context_length=4096, n_layers=32, model_dim=4096, kv_heads=8, attention_heads=32, activation_bits=16):
    """
    Calculate the KV-cache memory in GB using the given formula.

    Formula:
        context_length * n_layers * (model_dim * kv_heads // attention_heads)
        * 2 * activation_bits / (8 * 1024 * 1024 * 1024)

    Args:
        context_length: Number of cached token positions.
        n_layers: Number of layers.
        model_dim: Model (hidden) dimension.
        kv_heads: Number of key/value heads.
        attention_heads: Number of attention heads; must be greater than zero
            because it is used as a divisor.
        activation_bits: Bits used to store each cached value.

    Returns:
        The cache size in units of 1024**3 bytes, rounded to 4 decimal places.

    Raises:
        ValueError: If any argument is None or negative, or if
            attention_heads is zero.
    """
    arguments = {
        "context_length": context_length,
        "n_layers": n_layers,
        "model_dim": model_dim,
        "kv_heads": kv_heads,
        "attention_heads": attention_heads,
        "activation_bits": activation_bits,
    }
    for name, value in arguments.items():
        # A cleared numeric UI field may arrive as None; report that clearly
        # instead of failing later with an opaque TypeError in the arithmetic.
        if value is None:
            raise ValueError(f"{name} is required")
        if value < 0:
            raise ValueError(f"{name} must be non-negative, got {value}")
    if attention_heads == 0:
        raise ValueError("attention_heads must be greater than zero")

    # Per-position width of the cached projection: the model dimension scaled
    # by the ratio of KV heads to attention heads (floor division, as in the
    # formula above).
    kv_dim = model_dim * kv_heads // attention_heads
    # The factor of 2 is presumably one tensor for keys and one for values.
    total_bits = context_length * n_layers * kv_dim * 2 * activation_bits
    # 8 bits per byte, 1024**3 bytes per GB.
    kv_cache_memory = total_bits / (8 * 1024 * 1024 * 1024)
    return round(kv_cache_memory, 4)  # Round to 4 decimal places for readability

# Build the calculator UI: a title, a short description, three rows holding
# two numeric inputs each, a button, and a text box for the result.
with gr.Blocks() as demo:
    gr.Markdown("# KV-Cache Memory Calculator")
    gr.Markdown("This app calculates the KV-cache memory in GB based on the provided parameters.")

    # Row 1: context length and layer count.
    with gr.Row():
        context_length = gr.Number(value=4096, label="Context Length")
        n_layers = gr.Number(value=32, label="Number of Layers")

    # Row 2: model width and KV head count.
    with gr.Row():
        model_dim = gr.Number(value=4096, label="Model Dimension")
        kv_heads = gr.Number(value=8, label="KV Heads")

    # Row 3: attention head count and storage precision.
    with gr.Row():
        attention_heads = gr.Number(value=32, label="Attention Heads")
        activation_bits = gr.Number(value=16, label="Activation Bits")

    calculate_button = gr.Button("Calculate KV-Cache Memory")
    output = gr.Textbox(label="KV-Cache Memory (GB)")

    # The list order must match the parameter order of
    # calculate_kv_cache_memory, since the values are passed positionally.
    calculator_inputs = [
        context_length,
        n_layers,
        model_dim,
        kv_heads,
        attention_heads,
        activation_bits,
    ]

    # Clicking the button runs the calculation and shows the result.
    calculate_button.click(calculate_kv_cache_memory, calculator_inputs, output)

# Start the app
demo.launch()