HZjiangyi
init
9878518
import gradio as gr
def calculate_kv_cache_memory(
    context_length=4096,
    n_layers=32,
    model_dim=4096,
    kv_heads=8,
    attention_heads=32,
    activation_bits=16,
    batch_size=1,
):
    """
    Calculate the KV-cache memory in GB using the given formula.

    "GB" here is binary gigabytes (GiB): the bit total is divided by
    8 * 1024**3.

    Formula:
        batch_size * context_length * n_layers
        * (model_dim * kv_heads // attention_heads)
        * 2 * activation_bits / (8 * 1024 * 1024 * 1024)

    Args:
        context_length: Number of token positions held in the cache.
        n_layers: Number of layers.
        model_dim: Model dimension.
        kv_heads: Number of key/value heads.
        attention_heads: Number of attention heads; must be greater than 0.
        activation_bits: Bits per cached element.
        batch_size: Number of sequences cached at once. The default of 1
            reproduces the original single-sequence result exactly.

    Returns:
        The KV-cache size in GiB, rounded to 4 decimal places.

    Raises:
        ValueError: If any argument is None or negative, or if
            attention_heads is 0.
    """
    params = {
        "context_length": context_length,
        "n_layers": n_layers,
        "model_dim": model_dim,
        "kv_heads": kv_heads,
        "attention_heads": attention_heads,
        "activation_bits": activation_bits,
        "batch_size": batch_size,
    }
    for name, value in params.items():
        # An emptied UI number field may arrive as None; report it clearly
        # instead of failing later with an opaque TypeError.
        if value is None:
            raise ValueError(f"{name} is required")
        if value < 0:
            raise ValueError(f"{name} must be non-negative, got {value}")
    if attention_heads == 0:
        raise ValueError("attention_heads must be greater than 0")

    # Per-layer cached width: (model_dim / attention_heads) * kv_heads,
    # computed with floor division exactly as the original formula does.
    kv_width = model_dim * kv_heads // attention_heads
    # Factor 2: one cached tensor for keys and one for values.
    # batch_size is multiplied first so that batch_size == 1 leaves the
    # original evaluation order (and float result) unchanged.
    total_bits = (
        batch_size * context_length * n_layers * kv_width * 2 * activation_bits
    )
    kv_cache_memory = total_bits / (8 * 1024 * 1024 * 1024)
    return round(kv_cache_memory, 4)  # Round to 4 decimal places for readability
def _calculate_for_ui(*values):
    """Run calculate_kv_cache_memory and show bad input as a readable UI error."""
    try:
        return calculate_kv_cache_memory(*values)
    except (TypeError, ValueError, ZeroDivisionError) as err:
        # An emptied Number field may arrive as None, and 0 attention heads
        # divides by zero; report the problem in the UI instead of a generic
        # failure.
        raise gr.Error(f"Invalid input: {err}") from err


# Define the Gradio interface. `demo` stays at module level so tools that
# look up the Blocks object by name can still find it.
with gr.Blocks() as demo:
    gr.Markdown("# KV-Cache Memory Calculator")
    gr.Markdown("This app calculates the KV-cache memory in GB based on the provided parameters.")
    with gr.Row():
        context_length = gr.Number(label="Context Length", value=4096)
        n_layers = gr.Number(label="Number of Layers", value=32)
    with gr.Row():
        model_dim = gr.Number(label="Model Dimension", value=4096)
        kv_heads = gr.Number(label="KV Heads", value=8)
    with gr.Row():
        attention_heads = gr.Number(label="Attention Heads", value=32)
        activation_bits = gr.Number(label="Activation Bits", value=16)
    calculate_button = gr.Button("Calculate KV-Cache Memory")
    # The calculator divides by 8 * 1024**3, so "GB" in this label is GiB.
    output = gr.Textbox(label="KV-Cache Memory (GB)")
    # Bind the button to the calculation function; inputs are passed
    # positionally in the calculator's parameter order.
    calculate_button.click(
        fn=_calculate_for_ui,
        inputs=[context_length, n_layers, model_dim, kv_heads, attention_heads, activation_bits],
        outputs=output,
    )

# Launch the Gradio app only when run as a script, so importing this module
# (e.g. in tests) does not start a server.
if __name__ == "__main__":
    demo.launch()