henryholloway
committed on
Commit
·
c4f69f6
1
Parent(s):
475bc5f
Updated calculations, sources cited
Browse files
app.py
CHANGED
@@ -26,7 +26,7 @@ precision_options = {
|
|
26 |
'mixed': 6,
|
27 |
'half': 2
|
28 |
}
|
29 |
-
|
30 |
def calculate_memory_usage(parameter_count, context_length, data_type, batch_size, vocab_size, precision):
|
31 |
# Convert bit size to byte size
|
32 |
byte_size = quantization_bit_sizes[data_type] / 8
|
@@ -37,11 +37,8 @@ def calculate_memory_usage(parameter_count, context_length, data_type, batch_siz
|
|
37 |
# Memory usage for context (activations)
|
38 |
activations = calculate_activations(parameter_count, context_length, batch_size, vocab_size, precision)
|
39 |
|
40 |
-
# Outputs memory usage
|
41 |
-
outputs = 4 * batch_size * context_length * vocab_size
|
42 |
-
|
43 |
# Total memory usage
|
44 |
-
total_memory_usage = memory_params + activations
|
45 |
|
46 |
# Convert bytes to gigabytes
|
47 |
total_memory_usage_gb = total_memory_usage / (1024 ** 3)
|
@@ -49,33 +46,20 @@ def calculate_memory_usage(parameter_count, context_length, data_type, batch_siz
|
|
49 |
return total_memory_usage_gb
|
50 |
|
51 |
def calculate_activations(parameter_count, context_length, batch_size, vocab_size, precision):
|
52 |
-
#
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
bytes_per_param = precision_options[precision] / 8
|
58 |
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
softmax_output = bytes_per_param * batch_size * num_attention_heads * (context_length ** 2)
|
63 |
-
v = bytes_per_param * batch_size * context_length * (hidden_size / num_attention_heads) * num_attention_heads
|
64 |
-
out_proj_input = bytes_per_param * batch_size * context_length * hidden_size
|
65 |
-
attention_block = attention_input + q + k + softmax_output + v + out_proj_input
|
66 |
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
mlp_block = mlp_input + activation_input + down_proj_input
|
71 |
-
|
72 |
-
layer_norms = bytes_per_param * batch_size * context_length * hidden_size * 2
|
73 |
-
|
74 |
-
layer = attention_block + mlp_block + layer_norms
|
75 |
-
|
76 |
-
activations = layer # assuming 12 layers for simplicity
|
77 |
|
78 |
-
return
|
79 |
|
80 |
# Streamlit app
|
81 |
st.title("Memory Usage Calculator for Large Language Models")
|
|
|
26 |
'mixed': 6,
|
27 |
'half': 2
|
28 |
}
|
29 |
+
# Taken from "Reducing Activation Recomputation in Large Transformer Models" https://arxiv.org/abs/2205.05198
|
30 |
def calculate_memory_usage(parameter_count, context_length, data_type, batch_size, vocab_size, precision):
|
31 |
# Convert bit size to byte size
|
32 |
byte_size = quantization_bit_sizes[data_type] / 8
|
|
|
37 |
# Memory usage for context (activations)
|
38 |
activations = calculate_activations(parameter_count, context_length, batch_size, vocab_size, precision)
|
39 |
|
|
|
|
|
|
|
40 |
# Total memory usage
|
41 |
+
total_memory_usage = memory_params + activations
|
42 |
|
43 |
# Convert bytes to gigabytes
|
44 |
total_memory_usage_gb = total_memory_usage / (1024 ** 3)
|
|
|
46 |
return total_memory_usage_gb
|
47 |
|
48 |
def calculate_activations(parameter_count, context_length, batch_size, vocab_size, precision):
    """Estimate activation memory (in bytes) for a transformer forward pass.

    Based on the per-layer activation formula from "Reducing Activation
    Recomputation in Large Transformer Models" (arXiv:2205.05198):
    s*b*h*(34 + 5*a*s/h) bytes at 16-bit precision, where s is the sequence
    length, b the batch size, h the hidden size and a the attention-head count.

    Args:
        parameter_count: total model parameter count; hidden size is
            approximated as its square root.
        context_length: sequence length s.
        batch_size: batch size b.
        vocab_size: unused here; kept for interface compatibility.
        precision: key into the module-level ``precision_options`` table.

    Returns:
        Estimated activation memory in bytes (float).
    """
    # Model-shape assumptions: the paper's formula needs layer and head
    # counts, which the UI does not collect, so fixed values are used.
    num_layers = 32
    num_heads = 32
    # NOTE(review): hidden size ~ sqrt(parameter_count) is a rough heuristic,
    # not from the paper — verify against real architectures.
    hidden_size = int(parameter_count ** 0.5)

    # Per-layer activation footprint in bytes, assuming 16-bit precision.
    per_layer = context_length * batch_size * hidden_size * (
        34 + (5 * num_heads * context_length) / hidden_size
    )
    # Halve to turn the paper's 16-bit byte count into a raw element count
    # before re-applying the requested precision below.
    element_count = num_layers * per_layer / 2

    # NOTE(review): precision_options appears to store bytes-per-parameter
    # ('half': 2), so dividing by 8 here may double-discount — confirm
    # against the full table definition at the top of the file.
    bytes_per_param = precision_options[precision] / 8
    return bytes_per_param * element_count
|
63 |
|
64 |
# Streamlit app
|
65 |
st.title("Memory Usage Calculator for Large Language Models")
|