import streamlit as st
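# Streamlit app that compares estimated incremental-decoding inference time and
# memory use for Multi-Head Attention (MHA) vs Multi-Query Attention (MQA),
# using a simple roofline-style model (see "Estimating execution time" below).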
st.sidebar.header("Transformer parameters")
bs = st.sidebar.number_input('Batch size', value=10)
h = st.sidebar.number_input('Num heads', value=16)
d = st.sidebar.number_input('Dimension', value=768)
l = st.sidebar.number_input('Num layers', value=24)
n_start = st.sidebar.number_input('Start seq', value=1)
n = st.sidebar.number_input('End seq', value=1024)
st.sidebar.header("GPU parameters")
GPU = st.sidebar.selectbox('GPU', ('A100', 'V100'))
if GPU == 'A100':
    # A100: ~312 TFLOPS FP16 peak, ~1935 GB/s memory bandwidth
    TFLOPS = 312e12
    GB_S = 1935e9
elif GPU == 'V100':
    # V100: ~112 TFLOPS FP16 peak, ~900 GB/s memory bandwidth
    TFLOPS = 112e12
    GB_S = 900e9
else:
    raise ValueError('Unknown GPU')
# kernel launch overhead floor (in ms)
THREAD_OVERHEAD = st.sidebar.number_input('Kernel launch overhead (in ms)', format="%.3f", value=0.005)
GPU_EFFICIENCY = st.sidebar.number_input('GPU efficiency', format="%.3f", value=0.5)
TFLOPS = GPU_EFFICIENCY*TFLOPS  # achievable compute throughput
# Roofline-style estimate: a kernel's time is the max of its compute time and
# its memory access time, floored by the kernel launch overhead. Returns ms.
def calc_exec_time(comp_flop, mem_bytes, include_overhead=True):
    exec_time = max(comp_flop/TFLOPS, mem_bytes/GB_S)
    exec_time *= 1000  # seconds -> ms
    if include_overhead:
        exec_time = max(exec_time, THREAD_OVERHEAD)
    return exec_time
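# Note: with the defaults above (50% of 312 TFLOPS, 1935 GB/s), a kernel is
# compute bound only above roughly 156e12/1935e9 ≈ 81 FLOP per byte moved.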
def qkv_mha_exec(bs, h, n, d):
    # QKV projection for one new token (MHA): d -> 3d
    flop = 2*bs*1*d*3*d
    # read input, read weights, write Q/K/V
    nbytes = 2*bs*1*d + 2*3*d*d + 2*bs*1*3*d
    exec_time = calc_exec_time(flop, nbytes)
    return flop, nbytes, exec_time
def qkv_mqa_exec(bs, h, n, d):
    # QKV projection for one new token (MQA): d -> (1 + 2/h)*d
    flop = 2*bs*1*d*(1+2/h)*d
    # read input, read weights, write Q/K/V (weights and outputs sized to
    # match the (1 + 2/h) factor in the flop count above)
    nbytes = 2*bs*1*d + 2*(1+2/h)*d*d + 2*bs*1*(1+2/h)*d
    exec_time = calc_exec_time(flop, nbytes)
    return flop, nbytes, exec_time
def att1_mha_exec(bs, h, n, d):
    # Q @ K^T for one query token against n cached keys, per head (MHA)
    flop = 2*bs*h*(d/h)*n
    # read Q, read per-head K cache, write attention scores
    nbytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n
    exec_time = calc_exec_time(flop, nbytes)
    return flop, nbytes, exec_time
def att1_mqa_exec(bs, h, n, d):
    # Q @ K^T with a single shared K head (MQA): same flops, h-times-smaller K read
    flop = 2*bs*h*(d/h)*n
    nbytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n
    exec_time = calc_exec_time(flop, nbytes)
    return flop, nbytes, exec_time
def att2_mha_exec(bs, h, n, d):
    # attention-weights @ V for one query token (MHA)
    flop = 2*bs*h*n*(d/h)
    # read scores, read per-head V cache, write output
    nbytes = 2*bs*h*n + 2*bs*h*n*(d/h) + 2*bs*h*(d/h)
    exec_time = calc_exec_time(flop, nbytes)
    return flop, nbytes, exec_time
def att2_mqa_exec(bs, h, n, d):
    # attention-weights @ V with a single shared V head (MQA)
    flop = 2*bs*h*n*(d/h)
    # read scores (still one per head, as in MHA), read shared V cache, write output
    nbytes = 2*bs*h*n + 2*bs*n*(d/h) + 2*bs*h*(d/h)
    exec_time = calc_exec_time(flop, nbytes)
    return flop, nbytes, exec_time
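# The attention GEMMs cost the same FLOPs in MHA and MQA; the MQA win is that
# the K/V cache reads above shrink by a factor of h, which dominates at large n.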
def out_exec(bs, h, n, d):
    # output projection for one token: d -> d (identical for MHA and MQA)
    flop = 2*bs*1*d*d
    nbytes = 2*bs*1*d + 2*d*d + 2*bs*1*d
    exec_time = calc_exec_time(flop, nbytes)
    return flop, nbytes, exec_time
def softmax_exec(bs, h, n, d):
    # softmax over the n attention scores: treated as purely memory bound
    flop = 0
    nbytes = 2*bs*h*n + 2*bs*h*n  # read scores, write probabilities
    exec_time = calc_exec_time(flop, nbytes)
    return flop, nbytes, exec_time
def ln_exec(bs, h, n, d):
    # layer norm / residual / GeLU on one token: treated as purely memory bound
    flop = 0
    nbytes = 2*bs*1*d + 2*bs*1*d  # read activation, write result
    exec_time = calc_exec_time(flop, nbytes)
    return flop, nbytes, exec_time
def mlp_exec(bs, h, n, d):
    # one MLP linear layer for one token: d -> 4d (the 4d -> d layer costs the same)
    flop = 2*bs*1*d*4*d
    nbytes = 2*bs*1*d + 2*d*4*d + 2*bs*1*4*d
    exec_time = calc_exec_time(flop, nbytes)
    return flop, nbytes, exec_time
def print_kernel_execution(flop, nbytes):
    # display flops, bytes moved, and the (overhead-free) roofline time for one kernel
    c1, c2 = st.columns([2, 3])
    exec_time = calc_exec_time(flop, nbytes, include_overhead=False)
    flop = round(flop/1e9, 2)
    nbytes = round(nbytes/1e6, 2)
    c1.write("GFLOP:")
    c2.write(str(flop))
    c1.write("MB:")
    c2.write(str(nbytes))
    c1.write("Time (ms):")
    c2.write(str(round(exec_time, 4)))
    c1.write("Overhead (ms):")
    c2.write(str(THREAD_OVERHEAD))
st.title("Inference time MHA vs MQA")
st.write("This space approximates the inference time for Multi-Query Attention and Multi-Head Attention transformers. You can change the hyperparameters in sidebar.")
mqa_total_time = 0.
mha_total_time = 0.
for i in range(n_start, n):
    # ops shared by MHA and MQA: output projection, softmax, the two MLP linear
    # layers, and the element-wise ops (layer norms, residuals, GeLU) costed via ln_exec
    shared_time = out_exec(bs, h, i, d)[2] + softmax_exec(bs, h, i, d)[2] + 2*ln_exec(bs, h, i, d)[2] \
        + 2*mlp_exec(bs, h, i, d)[2] + 3*ln_exec(bs, h, i, d)[2]
    mha_time = shared_time + qkv_mha_exec(bs, h, i, d)[2] + att1_mha_exec(bs, h, i, d)[2] + att2_mha_exec(bs, h, i, d)[2]
    mha_total_time += l*mha_time
    mqa_time = shared_time + qkv_mqa_exec(bs, h, i, d)[2] + att1_mqa_exec(bs, h, i, d)[2] + att2_mqa_exec(bs, h, i, d)[2]
    mqa_total_time += l*mqa_time
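# range(n_start, n) sums the per-token decoding cost at context lengths
# n_start .. n-1, i.e. generating n - n_start tokens, across all l layers.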
c1, c2 = st.columns([2, 4])
c1.write("Multi-Head Attention (ms):")
c2.write(str(round(mha_total_time, 2)))
c1.write("Multi-Query Attention (ms):")
c2.write(str(round(mqa_total_time, 2)))
c1.write("Speed-up MQA over MHA:")
c2.write(str(round(mha_total_time/mqa_total_time, 2)))
st.subheader("Memory consumption")
st.caption("Multi-Head Attention")
c1, c2 = st.columns([2, 4])
num_params = 12*l*d*d
c1.write("Num parameters (B):")
c2.write(str(round(num_params/1e9, 3)))
c1.write("Stored parameters (GB):")
c2.write(str(round(2*num_params/1e9, 3)))
c1.write("Cached keys and values (GB):")
acts = round(2*bs*l*(d/h)*h*2*n/1e9, 2)
c2.write(str(acts))
st.caption("Multi-Query Attention")
c1, c2 = st.columns([2, 4])
num_params = (10+2/h)*l*d*d
c1.write("Num parameters (B):")
c2.write(str(round(num_params/1e9, 3)))
c1.write("Stored parameters (GB):")
c2.write(str(round(2*num_params/1e9, 3)))
c1.write("Cached keys and values (GB):")
acts = round(2*bs*l*(d/h)*2*n/1e9, 2)
c2.write(str(acts))
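# Cache sizes above: float16 K and V for every layer and position, i.e.
# 2 bytes * 2 tensors * bs * l * n * d for MHA; MQA stores d/h per position
# instead of d, cutting the cache by a factor of h.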
st.subheader("Estimating execution time")
st.markdown("We use the [following crude approximation](https://docs.nvidia.com/deeplearning/performance/dl-performance-gpu-background/index.html#understand-perf) to estimate the execution time for each matrix multiplication.")
st.latex("C = A \cdot B")
st.latex("A \in \mathbb{R}^{MxK}, B \in R^{KxN}, C \in \mathbb{R}^{MxN}")
st.markdown('''
To execute this operation on the GPU, we need to:
1. Read A, B from memory
2. Perform matrix multiplication
3. Write C to memory
''')
st.markdown("For float16 operations (2 bytes), we can estimate the memory access time of A as follows:")
st.latex("T_{mem}(A) = 2*M*K / BW_{mem}")
st.markdown("where BW_mem is the memory bandwidth of the GPU (e.g. 1935 GB/s for an A100 GPU)")
st.markdown("The total time on memory access is T_mem = T_mem(A) + T_mem(B) + T_mem(C)")
st.markdown("We can estimate the compute time for the math operations as follows:")
st.latex("T_{math}(A \cdot B) = 2*M*K*N / BW_{math}")
st.markdown("where BW_math is the number of floating point operations per second (e.g. 312 TFLOPS for an A100 GPU)")
st.markdown("If we assume we can *perfectly* overlap memory access with math operations, then the estimated execution time for the operation is:")
st.latex("max(T_{math}, T_{mem})")
st.markdown("Note that there is a minimum time to execute the operation due to [kernel launch overhead](https://forums.developer.nvidia.com/t/any-way-to-measure-the-latency-of-a-kernel-launch/221413/2)")
st.subheader("Inference time for Transformer operations")
st.markdown("We can now estimate the execution for each of the operations in the transformer model. I suggest you inspect the code for details on the calculations. ")
st.subheader('Attention layer')
st.markdown('**QKV projection**')
st.caption("Multi-Head Attention")
flop, nbytes, exec_time = qkv_mha_exec(bs, h, n, d)
print_kernel_execution(flop, nbytes)
st.caption("Multi-Query Attention")
flop, nbytes, exec_time = qkv_mqa_exec(bs, h, n, d)
print_kernel_execution(flop, nbytes)
st.markdown('**QK gemm**')
st.write("Showing calculation for the maximum sequence length (n)")
st.caption("Multi-Head Attention")
flop, nbytes, exec_time = att1_mha_exec(bs, h, n, d)
print_kernel_execution(flop, nbytes)
st.caption("Multi-Query Attention")
flop, nbytes, exec_time = att1_mqa_exec(bs, h, n, d)
print_kernel_execution(flop, nbytes)
st.markdown('**Attention-value gemm**')
st.write("Showing calculation for the maximum sequence length (n)")
st.caption("Multi-Head Attention")
flop, nbytes, exec_time = att2_mha_exec(bs, h, n, d)
print_kernel_execution(flop, nbytes)
st.caption("Multi-Query Attention")
flop, nbytes, exec_time = att2_mqa_exec(bs, h, n, d)
print_kernel_execution(flop, nbytes)
st.markdown('**Output projection**')
flop, nbytes, exec_time = out_exec(bs, h, n, d)
print_kernel_execution(flop, nbytes)
st.markdown('**Element-wise ops**')
st.write("We also need to take into the softmax layer, layer norm, and residual connection. We assume that these operations are memory bound. ")
st.caption("Softmax")
flop, nbytes, exec_time = softmax_exec(bs, h, n, d)
print_kernel_execution(flop, nbytes)
st.caption("Layer norm/residual connection")
flop, nbytes, exec_time = ln_exec(bs, h, n, d)
print_kernel_execution(flop, nbytes)
st.subheader('MLP layer')
st.markdown('**First and second linear layers**')
st.write("Costs shown are for a single linear layer; the MLP block contains two of them.")
flop, nbytes, exec_time = mlp_exec(bs, h, n, d)
print_kernel_execution(flop, nbytes)
st.markdown('**Element-wise ops**')
st.write("We also need to take into the GeLU, layer norm, and residual connection. We assume that these operations are memory bound. ")
flop, nbytes, exec_time = ln_exec(bs, h, n, d)
print_kernel_execution(flop, nbytes)