db committed on
Commit
b40998e
1 Parent(s): 9dfc7f9
Files changed (2)
  1. scaling_laws.ipynb +0 -0
  2. transformer_sizing.ipynb +402 -0
scaling_laws.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
transformer_sizing.ipynb ADDED
@@ -0,0 +1,402 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "attachments": {},
5
+ "cell_type": "markdown",
6
+ "metadata": {},
7
+ "source": [
8
+ "### Transformer Theoretical Model\n",
9
+ "\n",
10
+ "This notebook stores a bunch of analysis about a Transformer, e.g. estimates the number of FLOPs, parameters, peak memory footprint, checkpoint size, etc."
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 1,
16
+ "metadata": {},
17
+ "outputs": [],
18
+ "source": [
19
+ "from collections import OrderedDict"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": 2,
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "# config_args = {\n",
29
+ "# 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params\n",
30
+ "# 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params\n",
31
+ "# 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params\n",
32
+ "# 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params\n",
33
+ "# }[model_type]\n",
34
+ "\n",
35
+ "block_size = 1024\n",
36
+ "vocab_size = 50257\n",
37
+ "n_layer = 12\n",
38
+ "n_head = 12\n",
39
+ "n_embd = 768\n",
40
+ "bias = False\n",
41
+ "assert not bias, \"this notebook assumes bias=False just for simplicity\""
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": 3,
47
+ "metadata": {},
48
+ "outputs": [
49
+ {
50
+ "name": "stdout",
51
+ "output_type": "stream",
52
+ "text": [
53
+ "we see: 124337664, expected: 124337664, match: True\n",
54
+ "name params ratio (%) \n",
55
+ "emebedding/position 786432 0.6325\n",
56
+ "embedding/token 38597376 31.0424\n",
57
+ "embedding 39383808 31.6749\n",
58
+ "attention/ln 768 0.0006\n",
59
+ "attention/kqv 1769472 1.4231\n",
60
+ "attention/proj 589824 0.4744\n",
61
+ "attention 2360064 1.8981\n",
62
+ "mlp/ln 768 0.0006\n",
63
+ "mlp/ffw 2359296 1.8975\n",
64
+ "mlp/proj 2359296 1.8975\n",
65
+ "mlp 4719360 3.7956\n",
66
+ "block 7079424 5.6937\n",
67
+ "transformer 84953088 68.3245\n",
68
+ "ln_f 768 0.0006\n",
69
+ "dense 0 0.0000\n",
70
+ "total 124337664 100.0000\n"
71
+ ]
72
+ }
73
+ ],
74
+ "source": [
75
+ "def params():\n",
76
+ " \"\"\" estimates the number of parameters in the model\"\"\"\n",
77
+ " out = OrderedDict()\n",
78
+ "\n",
79
+ " # token and position embeddings\n",
80
+ " out['emebedding/position'] = n_embd * block_size\n",
81
+ " out['embedding/token'] = n_embd * vocab_size\n",
82
+ " out['embedding'] = out['emebedding/position'] + out['embedding/token']\n",
83
+ "\n",
84
+ " # attention blocks\n",
85
+ " out['attention/ln'] = n_embd # note, bias=False in our LN\n",
86
+ " out['attention/kqv'] = n_embd * 3*n_embd\n",
87
+ " out['attention/proj'] = n_embd**2\n",
88
+ " out['attention'] = out['attention/ln'] + out['attention/kqv'] + out['attention/proj']\n",
89
+ "\n",
90
+ " # MLP blocks\n",
91
+ " ffw_size = 4*n_embd # feed forward size\n",
92
+ " out['mlp/ln'] = n_embd\n",
93
+ " out['mlp/ffw'] = n_embd * ffw_size\n",
94
+ " out['mlp/proj'] = ffw_size * n_embd\n",
95
+ " out['mlp'] = out['mlp/ln'] + out['mlp/ffw'] + out['mlp/proj']\n",
96
+ " \n",
97
+ " # the transformer and the rest of it\n",
98
+ " out['block'] = out['attention'] + out['mlp']\n",
99
+ " out['transformer'] = n_layer * out['block']\n",
100
+ " out['ln_f'] = n_embd # final layernorm\n",
101
+ " out['dense'] = 0 # 0 because of parameter sharing. This layer uses the weights from the embedding layer\n",
102
+ "\n",
103
+ " # total\n",
104
+ " out['total'] = out['embedding'] + out['transformer'] + out['ln_f'] + out['dense']\n",
105
+ "\n",
106
+ " return out\n",
107
+ "\n",
108
+ "# compare our param count to that reported by PyTorch\n",
109
+ "p = params()\n",
110
+ "params_total = p['total']\n",
111
+ "print(f\"we see: {params_total}, expected: {124337664}, match: {params_total == 124337664}\")\n",
112
+ "# create a header\n",
113
+ "print(f\"{'name':20s} {'params':10s} {'ratio (%)':10s}\")\n",
114
+ "for k,v in p.items():\n",
115
+ " print(f\"{k:20s} {v:10d} {v/params_total*100:10.4f}\")\n",
116
+ " "
117
+ ]
118
+ },
119
+ {
120
+ "cell_type": "code",
121
+ "execution_count": 4,
122
+ "metadata": {},
123
+ "outputs": [
124
+ {
125
+ "name": "stdout",
126
+ "output_type": "stream",
127
+ "text": [
128
+ "est checkpoint size: 1.49 GB\n",
129
+ "measured with wc -c ckpt.pt: 1542470366\n",
130
+ "fluff ratio: 103.38%\n"
131
+ ]
132
+ }
133
+ ],
134
+ "source": [
135
+ "# we can now calculate the size of each checkpoint\n",
136
+ "# params are stored in fp32, and the AdamW optimizer has 2 additional buffers per param for statistics\n",
137
+ "params_bytes = params_total*4\n",
138
+ "params_and_buffers_bytes = params_bytes + 2*params_bytes\n",
139
+ "print(f\"est checkpoint size: {params_and_buffers_bytes/1e9:.2f} GB\")\n",
140
+ "measured_bytes = 1542470366 # from wc -c ckpt.pt\n",
141
+ "print(f\"measured with wc -c ckpt.pt: {measured_bytes}\")\n",
142
+ "print(f\"fluff ratio: {measured_bytes/params_and_buffers_bytes*100:.2f}%\")"
143
+ ]
144
+ },
145
+ {
146
+ "attachments": {},
147
+ "cell_type": "markdown",
148
+ "metadata": {},
149
+ "source": [
150
+ "We can also estimate the ratio of our GPU memory that will be taken up just by the weights and the buffers inside the AdamW optimizer"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": 5,
156
+ "metadata": {},
157
+ "outputs": [
158
+ {
159
+ "name": "stdout",
160
+ "output_type": "stream",
161
+ "text": [
162
+ "memory ratio taken up just for parameters: 3.73%\n"
163
+ ]
164
+ }
165
+ ],
166
+ "source": [
167
+ "gpu_memory = 40e9 # 40 GB A100 GPU, roughly\n",
168
+ "print(f\"memory ratio taken up just for parameters: {params_and_buffers_bytes / gpu_memory * 100:.2f}%\")"
169
+ ]
170
+ },
171
+ {
172
+ "attachments": {},
173
+ "cell_type": "markdown",
174
+ "metadata": {},
175
+ "source": [
176
+ "i.e. not that much of the memory for this tiny model, most of the memory is activations (forward and backward). This of course changes dramatically for larger and larger models."
177
+ ]
178
+ },
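+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A very rough back-of-envelope sketch of where that activation memory goes, under some simplifying assumptions (activations kept in fp32, only the inputs to the big matmuls and the TxT attention matrices counted, no recomputation or mixed precision):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# very rough sketch: count only the big activation tensors per layer\n",
+ "# (inputs to the matmuls and the TxT attention scores), assume fp32 storage\n",
+ "# and no recomputation. this is a crude lower bound, not an exact accounting\n",
+ "def act_memory_bytes(batch_size, bytes_per_elem=4):\n",
+ " T, C, H = block_size, n_embd, n_head\n",
+ " per_layer = 0\n",
+ " per_layer += T * C # ln1 output, input to the kqv projection\n",
+ " per_layer += T * 3*C # q, k, v\n",
+ " per_layer += H * T * T # attention scores, one TxT matrix per head\n",
+ " per_layer += T * C # attention output, input to the output projection\n",
+ " per_layer += T * C # ln2 output, input to the MLP\n",
+ " per_layer += T * 4*C # MLP hidden activations\n",
+ " elems = n_layer * per_layer + T * C # plus the final hidden states\n",
+ " return batch_size * elems * bytes_per_elem\n",
+ "\n",
+ "# at the micro batch size of 20 used below this lands around ~20 GB,\n",
+ "# i.e. far more than the ~1.5 GB of parameters + optimizer state above\n",
+ "print(f'rough activation memory at batch size 20: {act_memory_bytes(20)/1e9:.2f} GB')"
+ ]
+ },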
179
+ {
180
+ "attachments": {},
181
+ "cell_type": "markdown",
182
+ "metadata": {},
183
+ "source": [
184
+ "Let's estimate FLOPs for a single forward pass."
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "execution_count": 6,
190
+ "metadata": {},
191
+ "outputs": [
192
+ {
193
+ "name": "stdout",
194
+ "output_type": "stream",
195
+ "text": [
196
+ "name flops ratio (%) \n",
197
+ "attention/kqv 3623878656 1.2426\n",
198
+ "attention/scores 1610612736 0.5522\n",
199
+ "attention/reduce 1610612736 0.5522\n",
200
+ "attention/proj 1207959552 0.4142\n",
201
+ "attention 8053063680 2.7612\n",
202
+ "mlp/ffw1 4831838208 1.6567\n",
203
+ "mlp/ffw2 4831838208 1.6567\n",
204
+ "mlp 9663676416 3.3135\n",
205
+ "block 17716740096 6.0747\n",
206
+ "transformer 212600881152 72.8963\n",
207
+ "dense 79047426048 27.1037\n",
208
+ "forward_total 291648307200 100.0000\n",
209
+ "backward_total 583296614400 200.0000\n",
210
+ "total 874944921600 300.0000\n"
211
+ ]
212
+ }
213
+ ],
214
+ "source": [
215
+ "def flops():\n",
216
+ " # we only count Weight FLOPs, all other layers (LayerNorm, Softmax, etc) are effectively irrelevant\n",
217
+ " # we count actual FLOPs, not MACs. Hence 2* all over the place\n",
218
+ " # basically for any matrix multiply A (BxC) @ B (CxD) -> (BxD) flops are 2*B*C*D\n",
219
+ "\n",
220
+ " out = OrderedDict()\n",
221
+ " head_size = n_embd // n_head\n",
222
+ "\n",
223
+ " # attention blocks\n",
224
+ " # 1) the projection to key, query, values\n",
225
+ " out['attention/kqv'] = 2 * block_size * (n_embd * 3*n_embd)\n",
226
+ " # 2) calculating the attention scores\n",
227
+ " out['attention/scores'] = 2 * block_size * block_size * n_embd\n",
228
+ " # 3) the reduction of the values (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)\n",
229
+ " out['attention/reduce'] = 2 * n_head * (block_size * block_size * head_size)\n",
230
+ " # 4) the final linear projection\n",
231
+ " out['attention/proj'] = 2 * block_size * (n_embd * n_embd)\n",
232
+ " out['attention'] = sum(out['attention/'+k] for k in ['kqv', 'scores', 'reduce', 'proj'])\n",
233
+ "\n",
234
+ " # MLP blocks\n",
235
+ " ffw_size = 4*n_embd # feed forward size\n",
236
+ " out['mlp/ffw1'] = 2 * block_size * (n_embd * ffw_size)\n",
237
+ " out['mlp/ffw2'] = 2 * block_size * (ffw_size * n_embd)\n",
238
+ " out['mlp'] = out['mlp/ffw1'] + out['mlp/ffw2']\n",
239
+ "\n",
240
+ " # the transformer and the rest of it\n",
241
+ " out['block'] = out['attention'] + out['mlp']\n",
242
+ " out['transformer'] = n_layer * out['block']\n",
243
+ " out['dense'] = 2 * block_size * (n_embd * vocab_size)\n",
244
+ "\n",
245
+ " # forward,backward,total\n",
246
+ " out['forward_total'] = out['transformer'] + out['dense']\n",
247
+ " out['backward_total'] = 2 * out['forward_total'] # use common estimate of bwd = 2*fwd\n",
248
+ " out['total'] = out['forward_total'] + out['backward_total']\n",
249
+ "\n",
250
+ " return out\n",
251
+ " \n",
252
+ "# compare our param count to that reported by PyTorch\n",
253
+ "f = flops()\n",
254
+ "flops_total = f['forward_total']\n",
255
+ "print(f\"{'name':20s} {'flops':14s} {'ratio (%)':10s}\")\n",
256
+ "for k,v in f.items():\n",
257
+ " print(f\"{k:20s} {v:14d} {v/flops_total*100:10.4f}\")\n",
258
+ " "
259
+ ]
260
+ },
261
+ {
262
+ "cell_type": "code",
263
+ "execution_count": 7,
264
+ "metadata": {},
265
+ "outputs": [
266
+ {
267
+ "name": "stdout",
268
+ "output_type": "stream",
269
+ "text": [
270
+ "palm_flops: 875062886400, flops: 874944921600, ratio: 1.0001\n"
271
+ ]
272
+ }
273
+ ],
274
+ "source": [
275
+ "# now here is an estimate copy pasted from the PaLM paper\n",
276
+ "# this formula is often used to calculate MFU (model flops utilization)\n",
277
+ "def palm_flops():\n",
278
+ " \"\"\"estimate of the model flops following PaLM paper formula\"\"\"\n",
279
+ " # non-embedding model parameters. note that we do not subtract the\n",
280
+ " # embedding/token params because those are tied and get used in the last layer.\n",
281
+ " N = params()['total'] - params()['emebedding/position']\n",
282
+ " L, H, Q, T = n_layer, n_head, n_embd//n_head, block_size\n",
283
+ " mf_per_token = 6*N + 12*L*H*Q*T\n",
284
+ " mf = mf_per_token * block_size\n",
285
+ " return mf\n",
286
+ "\n",
287
+ "print(f\"palm_flops: {palm_flops():d}, flops: {flops()['total']:d}, ratio: {palm_flops()/flops()['total']:.4f}\")"
288
+ ]
289
+ },
290
+ {
291
+ "attachments": {},
292
+ "cell_type": "markdown",
293
+ "metadata": {},
294
+ "source": [
295
+ "Ok they are quite similar, giving some confidence that my math in flops() function was ~ok. Now, A100 is cited at 312TFLOPS bfloat16 on tensor cores. So what is our model flops utilization (MFU)? I trained the model above with a batch_size of 20 and grad_accum of 5, which runs in about 755ms on a single A100 GPU. We get:"
296
+ ]
297
+ },
298
+ {
299
+ "cell_type": "code",
300
+ "execution_count": 8,
301
+ "metadata": {},
302
+ "outputs": [
303
+ {
304
+ "name": "stdout",
305
+ "output_type": "stream",
306
+ "text": [
307
+ "fraction of A100 used: 37.14%\n"
308
+ ]
309
+ }
310
+ ],
311
+ "source": [
312
+ "# here is what we currently roughly measure\n",
313
+ "batch_size = 20 * 5 # 5 is grad_accum, so total batch size is 100\n",
314
+ "measured_time = 0.755 # in seconds per iteration\n",
315
+ "measured_throughput = batch_size / measured_time\n",
316
+ "flops_achieved = f['total'] * measured_throughput\n",
317
+ "\n",
318
+ "# A100 is cited to be 312 TFLOPS of bloat16 running on tensor cores\n",
319
+ "a100_flops_promised = 312e12\n",
320
+ "\n",
321
+ "# the fraction of the A100 that we are using:\n",
322
+ "print(f\"fraction of A100 used: {flops_achieved / a100_flops_promised * 100:.2f}%\")"
323
+ ]
324
+ },
325
+ {
326
+ "attachments": {},
327
+ "cell_type": "markdown",
328
+ "metadata": {},
329
+ "source": [
330
+ "For reference, we'd prefer to be somewhere around 50%+, and not just for a single GPU but for an entire DDP run. So we still have some work to do, but at least we're within a factor of ~2X of what is achievable with this GPU."
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "code",
335
+ "execution_count": 9,
336
+ "metadata": {},
337
+ "outputs": [
338
+ {
339
+ "name": "stdout",
340
+ "output_type": "stream",
341
+ "text": [
342
+ "time needed to train the model: 3.46 days\n"
343
+ ]
344
+ }
345
+ ],
346
+ "source": [
347
+ "# Finally let's check out the 6ND approximation as total cost of training in FLOPs\n",
348
+ "model_size = params()['total'] # this is number of parameters, N\n",
349
+ "tokens_num = 300e9 # 300B tokens, this is dataset size in tokens, D\n",
350
+ "a100_flops = 312e12 # 312 TFLOPS\n",
351
+ "assumed_mfu = 0.3 # assume this model flops utilization (take the current 37% from above and add some DDP overhead)\n",
352
+ "flops_throughput = a100_flops * 8 * assumed_mfu # assume an 8XA100 node at 30% utilization\n",
353
+ "flops_needed = 6 * model_size * tokens_num # 6ND\n",
354
+ "time_needed_s = flops_needed / flops_throughput # in seconds\n",
355
+ "print(f\"time needed to train the model: {time_needed_s/3600/24:.2f} days\")"
356
+ ]
357
+ },
358
+ {
359
+ "attachments": {},
360
+ "cell_type": "markdown",
361
+ "metadata": {},
362
+ "source": [
363
+ "This is not a bad estimate at all. I trained this model and it converged in roughly 4 days. Btw as a good reference for where 6ND comes from and some intuition around it I recommend [Dzmitry's post](https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4)."
364
+ ]
365
+ },
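+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a quick sanity check on the 6ND rule of thumb we can compare 6N per token against the per-token FLOPs from the flops() breakdown above; the two should land in the same ballpark:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# sanity check the 6N-per-token rule of thumb against our flops() count.\n",
+ "# the ~15% gap is mostly the attention score/reduce FLOPs, which the 6N\n",
+ "# approximation ignores because they involve no weight matrices\n",
+ "N = params()['total']\n",
+ "per_token_flops = flops()['total'] / block_size\n",
+ "print(f'6N per token: {6*N:.3e}, flops() per token: {per_token_flops:.3e}, ratio: {per_token_flops/(6*N):.3f}')"
+ ]
+ },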
366
+ {
367
+ "attachments": {},
368
+ "cell_type": "markdown",
369
+ "metadata": {},
370
+ "source": [
371
+ "Now, FLOPs are just one constraint, the other that we have to keep a close track of is the memory bandwidth. TODO estimate LOAD/STORE costs of our model later."
372
+ ]
373
+ }
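+ ,
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In the meantime, a very rough first cut (lots of hand-waving: only the parameter and optimizer-state traffic, ignoring activations, and assuming ~1.5 TB/s of HBM bandwidth on the 40GB A100):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# rough first cut at the bandwidth side: per optimizer step we must at least\n",
+ "# read and write the fp32 params and the two AdamW buffers. at ~1.5 TB/s of\n",
+ "# HBM bandwidth (40GB A100, roughly) that is only a couple of ms, i.e. tiny\n",
+ "# next to the ~755ms step time, so the real LOAD/STORE cost will be dominated\n",
+ "# by activation traffic, which scales with batch size\n",
+ "a100_bandwidth = 1.5e12 # bytes/s, roughly\n",
+ "weight_traffic = 2 * params_and_buffers_bytes # read + write params and optimizer state\n",
+ "print(f'weight traffic per step: {weight_traffic/1e9:.1f} GB, ~{weight_traffic/a100_bandwidth*1e3:.1f} ms at HBM bandwidth')"
+ ]
+ }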
374
+ ],
375
+ "metadata": {
376
+ "kernelspec": {
377
+ "display_name": "pytorch2",
378
+ "language": "python",
379
+ "name": "python3"
380
+ },
381
+ "language_info": {
382
+ "codemirror_mode": {
383
+ "name": "ipython",
384
+ "version": 3
385
+ },
386
+ "file_extension": ".py",
387
+ "mimetype": "text/x-python",
388
+ "name": "python",
389
+ "nbconvert_exporter": "python",
390
+ "pygments_lexer": "ipython3",
391
+ "version": "3.10.8"
392
+ },
393
+ "orig_nbformat": 4,
394
+ "vscode": {
395
+ "interpreter": {
396
+ "hash": "7f5833218766b48e6e35e4452ee875aac0e2188d05bbe5298f2c62b79f08b222"
397
+ }
398
+ }
399
+ },
400
+ "nbformat": 4,
401
+ "nbformat_minor": 2
402
+ }