"""Gradio app that estimates the KV-cache memory footprint of a transformer."""

import gradio as gr

# Number of bits in one binary gigabyte (8 bits/byte * 1024**3 bytes).
_BITS_PER_GB = 8 * 1024 * 1024 * 1024


def calculate_kv_cache_memory(
    context_length: float = 4096,
    n_layers: float = 32,
    model_dim: float = 4096,
    kv_heads: float = 8,
    attention_heads: float = 32,
    activation_bits: float = 16,
) -> float:
    """Calculate the KV-cache memory in GB.

    The result is computed as::

        context_length * n_layers
            * (model_dim * kv_heads // attention_heads)
            * 2 * activation_bits / (8 * 1024**3)

    Note that the divisor uses 1024**3 bytes per GB (binary gigabytes).

    Args:
        context_length: Number of tokens held in the cache.
        n_layers: Number of layers.
        model_dim: Model (hidden) dimension.
        kv_heads: Number of key/value heads.
        attention_heads: Number of attention heads; used as a divisor.
        activation_bits: Bits used to store each cached value.

    Returns:
        The cache size in GB, rounded to 4 decimal places.

    Raises:
        ZeroDivisionError: If ``attention_heads`` is 0.
    """
    # Floor division mirrors the original formula: model_dim is scaled by the
    # kv_heads / attention_heads ratio and truncated toward negative infinity.
    kv_dim = model_dim * kv_heads // attention_heads

    # The factor 2 is part of the given formula (presumably one key tensor and
    # one value tensor per layer -- confirm against the model architecture).
    total_bits = context_length * n_layers * kv_dim * 2 * activation_bits

    # Round to 4 decimal places for readability.
    return round(total_bits / _BITS_PER_GB, 4)


# Define the Gradio interface. `demo` stays a module-level name so that tools
# which import this file can still find the app object.
with gr.Blocks() as demo:
    gr.Markdown("# KV-Cache Memory Calculator")
    gr.Markdown(
        "This app calculates the KV-cache memory in GB based on the provided parameters."
    )

    with gr.Row():
        context_length = gr.Number(label="Context Length", value=4096)
        n_layers = gr.Number(label="Number of Layers", value=32)

    with gr.Row():
        model_dim = gr.Number(label="Model Dimension", value=4096)
        kv_heads = gr.Number(label="KV Heads", value=8)

    with gr.Row():
        attention_heads = gr.Number(label="Attention Heads", value=32)
        activation_bits = gr.Number(label="Activation Bits", value=16)

    calculate_button = gr.Button("Calculate KV-Cache Memory")
    output = gr.Textbox(label="KV-Cache Memory (GB)")

    # Bind the button to the calculation function; the input order must match
    # the positional parameter order of calculate_kv_cache_memory.
    calculate_button.click(
        fn=calculate_kv_cache_memory,
        inputs=[
            context_length,
            n_layers,
            model_dim,
            kv_heads,
            attention_heads,
            activation_bits,
        ],
        outputs=output,
    )

# Launch the Gradio app only when executed as a script, so that importing this
# module (e.g. to reuse calculate_kv_cache_memory) does not start a server.
if __name__ == "__main__":
    demo.launch()