import streamlit as st

# Bits per weight for each quantization / data type option
quantization_bit_sizes = {
    'float32': 32,
    'float16': 16,
    'Q2_K': 2,
    'Q3_K_L': 3,
    'Q3_K_M': 3,
    'Q3_K_S': 3,
    'Q4_0': 4,
    'Q4_1': 4,
    'Q4_K_M': 4,
    'Q4_K_S': 4,
    'Q5_0': 5,
    'Q5_1': 5,
    'Q5_K_M': 5,
    'Q5_K_S': 5,
    'Q6_K': 6,
    'Q8_0': 8,
}

# Bytes per activation element for each precision mode
precision_options = {
    'full': 4,
    'mixed': 6,
    'half': 2,
}

# Streamlit app
st.title("Memory Usage Calculator for Large Language Models")


# Activation formula taken from "Reducing Activation Recomputation in Large
# Transformer Models", https://arxiv.org/abs/2205.05198
def calculate_memory_usage(parameter_count, context_length, data_type, batch_size,
                           layers, attention_heads, precision):
    # Convert bits per weight to bytes per weight
    byte_size = quantization_bit_sizes[data_type] / 8

    # Memory usage for model parameters
    memory_params = parameter_count * byte_size

    # Memory usage for context (activations)
    activations = calculate_activations(parameter_count, context_length, batch_size,
                                        layers, attention_heads, precision)

    # Total memory usage, converted from bytes to gigabytes
    total_memory_usage = memory_params + activations
    total_memory_usage_gb = total_memory_usage / (1024 ** 3)

    return total_memory_usage_gb


def calculate_activations(parameter_count, context_length, batch_size,
                          layers, attention_heads, precision):
    # Rough estimate of the hidden size from the standard transformer
    # relationship parameter_count ≈ 12 * layers * hidden_size**2
    hidden_dimensions = int((parameter_count / (12 * layers)) ** 0.5)

    # Per-layer activation memory from the paper, in bytes at 16-bit precision:
    # s * b * h * (34 + 5 * a * s / h)
    activations_per_layer = context_length * batch_size * hidden_dimensions * (
        34 + ((5 * attention_heads * context_length) / hidden_dimensions))

    # The paper's figure assumes 2-byte (16-bit) activations, so dividing by 2
    # converts the byte count into an element count
    activations = layers * activations_per_layer / 2

    # Scale the element count by the bytes per element of the chosen precision
    bytes_per_param = precision_options[precision]
    total_activations = bytes_per_param * activations

    return total_activations


# User inputs
parameter_count = st.number_input("Parameter Count (in billions)", value=1, step=1) * 1e9
layers = st.number_input("Number of Layers", value=32, step=1)
attention_heads = st.number_input("Number of Attention Heads", value=32, step=1)
context_length = st.number_input("Context Length (number of tokens)", value=512, step=1)
data_type = st.selectbox("Data Type", options=list(quantization_bit_sizes.keys()))
batch_size = st.number_input("Batch Size", value=1, step=1)
vocab_size = st.number_input("Vocabulary Size", value=30000, step=1000)  # collected but not yet used in the estimate
precision = st.selectbox("Precision", options=list(precision_options.keys()))

# Calculate memory usage
if st.button("Calculate Memory Usage"):
    memory_usage = calculate_memory_usage(parameter_count, context_length, data_type,
                                          batch_size, layers, attention_heads, precision)
    st.write(f"Estimated Memory Usage for Inference: {memory_usage:.2f} GB")
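
# ---------------------------------------------------------------------------
# Worked example of the activation formula (illustrative assumptions only:
# hidden size 4096, 32 layers, 32 heads, 512-token context, batch size 1 are
# typical for a ~7B model and are not taken from any specific model card):
#
#   per-layer bytes at 16-bit = s * b * h * (34 + 5*a*s/h)
#                             = 512 * 1 * 4096 * (34 + 5*32*512/4096)
#                             = 2,097,152 * 54
#                             = 113,246,208 bytes
#   all 32 layers             = 3,623,878,656 bytes ≈ 3.4 GiB
#
# calculate_activations() divides this 16-bit byte count by 2 to get an
# element count, then rescales it by the bytes per element of the selected
# precision. Run the app with: streamlit run <this_file>.py
# ---------------------------------------------------------------------------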