PrateekJ17 commited on
Commit
13d87eb
1 Parent(s): 96c3b3a

Upload 11 files

Files changed (11)
  1. .gitattributes +3 -35
  2. .gitignore +4 -0
  3. LICENSE +21 -0
  4. README.md +223 -3
  5. bench.py +117 -0
  6. configurator.py +47 -0
  7. model.py +337 -0
  8. sample.py +89 -0
  9. scaling_laws.ipynb +0 -0
  10. train.py +332 -0
  11. transformer_sizing.ipynb +402 -0
.gitattributes CHANGED
@@ -1,35 +1,3 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ # Override jupyter in Github language stats for more accurate estimate of repo code languages
+ # reference: https://github.com/github/linguist/blob/master/docs/overrides.md#generated-code
+ *.ipynb linguist-generated
.gitignore ADDED
@@ -0,0 +1,4 @@
+ .DS_Store
+ .ipynb_checkpoints/
+ __pycache__/
+ *.pyc
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2022 Andrej Karpathy
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,223 @@
- ---
- license: cc-by-nc-sa-4.0
- ---
+
+ # nanoGPT
+
+ ![nanoGPT](assets/nanogpt.jpg)
+
+ The simplest, fastest repository for training/finetuning medium-sized GPTs. It is a rewrite of [minGPT](https://github.com/karpathy/minGPT) that prioritizes teeth over education. Still under active development, but currently the file `train.py` reproduces GPT-2 (124M) on OpenWebText, running on a single 8XA100 40GB node in about 4 days of training. The code itself is plain and readable: `train.py` is a ~300-line boilerplate training loop and `model.py` a ~300-line GPT model definition, which can optionally load the GPT-2 weights from OpenAI. That's it.
+
+ ![repro124m](assets/gpt2_124M_loss.png)
+
+ Because the code is so simple, it is very easy to hack to your needs, train new models from scratch, or finetune pretrained checkpoints (e.g. the biggest one currently available as a starting point would be the GPT-2 1.5B (gpt2-xl) model from OpenAI).
+
+ ## install
+
+ Dependencies:
+
+ - [pytorch](https://pytorch.org) <3
+ - [numpy](https://numpy.org/install/) <3
+ - `pip install transformers` for huggingface transformers <3 (to load GPT-2 checkpoints)
+ - `pip install datasets` for huggingface datasets <3 (if you want to download + preprocess OpenWebText)
+ - `pip install tiktoken` for OpenAI's fast BPE code <3
+ - `pip install wandb` for optional logging <3
+ - `pip install tqdm` <3
+
+ ## quick start
+
+ If you are not a deep learning professional and you just want to feel the magic and get your feet wet, the fastest way to get started is to train a character-level GPT on the works of Shakespeare. First, we download it as a single (1MB) file and turn it from raw text into one large stream of integers:
+
+ ```
+ $ python data/shakespeare_char/prepare.py
+ ```
+
+ This creates a `train.bin` and `val.bin` in that data directory. Now it is time to train your GPT. The size of it very much depends on the computational resources of your system:
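For intuition on what `prepare.py` produces, here is a rough sketch (not the actual script) of the character-level pipeline: build a vocabulary over the characters, encode the text into one stream of integer ids, and write a train/val split as raw uint16 binary files:

```python
import numpy as np

text = "First Citizen: Before we proceed any further, hear me speak."

# build the character-level vocabulary
chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}  # char -> integer id
itos = {i: ch for ch, i in stoi.items()}      # integer id -> char

# encode the full text into one stream of ids, 90/10 train/val split
ids = np.array([stoi[c] for c in text], dtype=np.uint16)
n = int(0.9 * len(ids))
ids[:n].tofile("train.bin")
ids[n:].tofile("val.bin")

# round-trip check: decode the training split back to text
decoded = "".join(itos[i] for i in np.fromfile("train.bin", dtype=np.uint16))
print(decoded == text[:n])  # -> True
```

The real script additionally pickles the `stoi`/`itos` mappings to a `meta.pkl` so that `sample.py` can decode generations later.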
33
+
34
+ **I have a GPU**. Great, we can quickly train a baby GPT with the settings provided in the [config/train_shakespeare_char.py](config/train_shakespeare_char.py) config file:
35
+
36
+ ```
37
+ $ python train.py config/train_shakespeare_char.py
38
+ ```
39
+
40
+ If you peek inside it, you'll see that we're training a GPT with a context size of up to 256 characters, 384 feature channels, and it is a 6-layer Transformer with 6 heads in each layer. On one A100 GPU this training run takes about 3 minutes and the best validation loss is 1.4697. Based on the configuration, the model checkpoints are being written into the `--out_dir` directory `out-shakespeare-char`. So once the training finishes we can sample from the best model by pointing the sampling script at this directory:
41
+
42
+ ```
43
+ $ python sample.py --out_dir=out-shakespeare-char
44
+ ```
45
+
46
+ This generates a few samples, for example:
47
+
48
+ ```
49
+ ANGELO:
50
+ And cowards it be strawn to my bed,
51
+ And thrust the gates of my threats,
52
+ Because he that ale away, and hang'd
53
+ An one with him.
54
+
55
+ DUKE VINCENTIO:
56
+ I thank your eyes against it.
57
+
58
+ DUKE VINCENTIO:
59
+ Then will answer him to save the malm:
60
+ And what have you tyrannous shall do this?
61
+
62
+ DUKE VINCENTIO:
63
+ If you have done evils of all disposition
64
+ To end his power, the day of thrust for a common men
65
+ That I leave, to fight with over-liking
66
+ Hasting in a roseman.
67
+ ```
68
+
69
+ lol `¯\_(ツ)_/¯`. Not bad for a character-level model after 3 minutes of training on a GPU. Better results are quite likely obtainable by instead finetuning a pretrained GPT-2 model on this dataset (see finetuning section later).
70
+
71
+ **I only have a macbook** (or other cheap computer). No worries, we can still train a GPT but we want to dial things down a notch. I recommend getting the bleeding edge PyTorch nightly ([select it here](https://pytorch.org/get-started/locally/) when installing) as it is currently quite likely to make your code more efficient. But even without it, a simple train run could look as follows:
72
+
73
+ ```
74
+ $ python train.py config/train_shakespeare_char.py --device=cpu --compile=False --eval_iters=20 --log_interval=1 --block_size=64 --batch_size=12 --n_layer=4 --n_head=4 --n_embd=128 --max_iters=2000 --lr_decay_iters=2000 --dropout=0.0
75
+ ```
76
+
77
+ Here, since we are running on CPU instead of GPU we must set both `--device=cpu` and also turn off PyTorch 2.0 compile with `--compile=False`. Then when we evaluate we get a bit more noisy but faster estimate (`--eval_iters=20`, down from 200), our context size is only 64 characters instead of 256, and the batch size only 12 examples per iteration, not 64. We'll also use a much smaller Transformer (4 layers, 4 heads, 128 embedding size), and decrease the number of iterations to 2000 (and correspondingly usually decay the learning rate to around max_iters with `--lr_decay_iters`). Because our network is so small we also ease down on regularization (`--dropout=0.0`). This still runs in about ~3 minutes, but gets us a loss of only 1.88 and therefore also worse samples, but it's still good fun:
78
+
79
+ ```
80
+ $ python sample.py --out_dir=out-shakespeare-char --device=cpu
81
+ ```
82
+ Generates samples like this:
83
+
84
+ ```
85
+ GLEORKEN VINGHARD III:
86
+ Whell's the couse, the came light gacks,
87
+ And the for mought you in Aut fries the not high shee
88
+ bot thou the sought bechive in that to doth groan you,
89
+ No relving thee post mose the wear
90
+ ```
91
+
92
+ Not bad for ~3 minutes on a CPU, for a hint of the right character gestalt. If you're willing to wait longer, feel free to tune the hyperparameters, increase the size of the network, the context length (`--block_size`), the length of training, etc.
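As a sanity check on the model sizes above, the standard transformer sizing formula gives a rough parameter count. This is a back-of-envelope sketch (weights only, ignoring biases and LayerNorm parameters, counting the tied token embedding once), assuming the character-level Shakespeare vocabulary of 65 symbols; the repo's own `transformer_sizing.ipynb` does this accounting in full detail:

```python
def gpt_param_count(n_layer, n_head, n_embd, block_size, vocab_size):
    """Rough GPT parameter count: weights only, biases/LayerNorms ignored."""
    attn = 4 * n_embd * n_embd   # q, k, v projections + output projection
    mlp = 8 * n_embd * n_embd    # two linears with a 4x hidden expansion
    blocks = n_layer * (attn + mlp)
    wte = vocab_size * n_embd    # token embedding (tied with lm_head)
    wpe = block_size * n_embd    # position embedding
    return blocks + wte + wpe

# the CPU-scale config above: 4 layers, 4 heads, 128-dim, block_size 64
n = gpt_param_count(n_layer=4, n_head=4, n_embd=128, block_size=64, vocab_size=65)
print(f"{n/1e6:.2f}M parameters")  # -> 0.80M parameters
```

Plugging in the GPT-2 (124M) shape (12 layers, 768-dim, block size 1024, 50257-token vocab) lands in the ~124M ballpark, which is a useful cross-check.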
+
+ ## reproducing GPT-2
+
+ A more serious deep learning professional may be more interested in reproducing GPT-2 results. So here we go - we first tokenize the dataset, in this case [OpenWebText](https://openwebtext2.readthedocs.io/en/latest/), an open reproduction of OpenAI's (private) WebText:
+
+ ```
+ $ python data/openwebtext/prepare.py
+ ```
+
+ This downloads and tokenizes the [OpenWebText](https://huggingface.co/datasets/openwebtext) dataset. It will create a `train.bin` and `val.bin` which hold the GPT-2 BPE token ids in one sequence, stored as raw uint16 bytes. Then we're ready to kick off training. To reproduce GPT-2 (124M) you'll want at least an 8X A100 40GB node; then run:
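The `.bin` files are flat arrays of uint16 token ids, so `train.py` and `bench.py` read them lazily with `np.memmap` rather than loading the whole dataset into RAM. A minimal sketch of that loading pattern, using a small synthetic file in place of the real `train.bin`:

```python
import numpy as np

# write a small synthetic token stream, same on-disk format as train.bin
tokens = np.arange(1000, dtype=np.uint16)
tokens.tofile("toy_train.bin")

# memory-map it: slices are read from disk on demand, not all at once
data = np.memmap("toy_train.bin", dtype=np.uint16, mode="r")

# sample a random batch of (x, y) context/target windows, shifted by one token
block_size, batch_size = 8, 4
ix = np.random.randint(0, len(data) - block_size, size=batch_size)
x = np.stack([data[i : i + block_size].astype(np.int64) for i in ix])
y = np.stack([data[i + 1 : i + 1 + block_size].astype(np.int64) for i in ix])
print(x.shape, y.shape)  # -> (4, 8) (4, 8)
```

The one-token shift between `x` and `y` is the whole language-modeling objective: at every position the model predicts the next token.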
+
+ ```
+ $ torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py
+ ```
+
+ This will run for about 4 days using PyTorch Distributed Data Parallel (DDP) and go down to a loss of ~2.85. Now, a GPT-2 model just evaluated on OWT gets a val loss of about 3.11, but if you finetune it, it will come down to ~2.85 territory (due to an apparent domain gap), making the two models ~match.
+
+ If you're in a cluster environment and you are blessed with multiple GPU nodes you can make GPU go brrrr, e.g. across 2 nodes like:
+
+ ```
+ Run on the first (master) node with example IP 123.456.123.456:
+ $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py
+ Run on the worker node:
+ $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py
+ ```
+
+ It is a good idea to benchmark your interconnect (e.g. with iperf3). In particular, if you don't have Infiniband then also prepend `NCCL_IB_DISABLE=1` to the above launches. Your multinode training will work, but most likely _crawl_. By default checkpoints are periodically written to the `--out_dir`. We can sample from the model by simply running `$ python sample.py`.
+
+ Finally, to train on a single GPU simply run `$ python train.py`. Have a look at all of its args; the script tries to be very readable, hackable and transparent. You'll most likely want to tune a number of those variables depending on your needs.
+
+ ## baselines
+
+ OpenAI GPT-2 checkpoints allow us to get some baselines in place for openwebtext. We can get the numbers as follows:
+
+ ```
+ $ python train.py eval_gpt2
+ $ python train.py eval_gpt2_medium
+ $ python train.py eval_gpt2_large
+ $ python train.py eval_gpt2_xl
+ ```
+
+ and observe the following losses on train and val:
+
+ | model | params | train loss | val loss |
+ | ------| ------ | ---------- | -------- |
+ | gpt2 | 124M | 3.11 | 3.12 |
+ | gpt2-medium | 350M | 2.85 | 2.84 |
+ | gpt2-large | 774M | 2.66 | 2.67 |
+ | gpt2-xl | 1558M | 2.56 | 2.54 |
+
+ However, we have to note that GPT-2 was trained on (closed, never released) WebText, while OpenWebText is just a best-effort open reproduction of this dataset. This means there is a dataset domain gap. Indeed, taking the GPT-2 (124M) checkpoint and finetuning on OWT directly for a while brings the loss down to ~2.85. This then becomes the more appropriate baseline w.r.t. reproduction.
+
+ ## finetuning
+
+ Finetuning is no different than training; we just make sure to initialize from a pretrained model and train with a smaller learning rate. For an example of how to finetune a GPT on new text, go to `data/shakespeare` and run `prepare.py` to download the tiny shakespeare dataset and render it into a `train.bin` and `val.bin`, using the OpenAI BPE tokenizer from GPT-2. Unlike OpenWebText this will run in seconds. Finetuning can take very little time, e.g. on a single GPU just a few minutes. Run an example finetuning like:
+
+ ```
+ $ python train.py config/finetune_shakespeare.py
+ ```
+
+ This will load the config parameter overrides in `config/finetune_shakespeare.py` (I didn't tune them much though). Basically, we initialize from a GPT-2 checkpoint with `init_from` and train as normal, except shorter and with a small learning rate. If you're running out of memory, try decreasing the model size (the options are `{'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}`) or possibly decreasing the `block_size` (context length). The best checkpoint (lowest validation loss) will be in the `out_dir` directory, e.g. in `out-shakespeare` by default, per the config file. You can then sample with `$ python sample.py --out_dir=out-shakespeare`:
+
+ ```
+ THEODORE:
+ Thou shalt sell me to the highest bidder: if I die,
+ I sell thee to the first; if I go mad,
+ I sell thee to the second; if I
+ lie, I sell thee to the third; if I slay,
+ I sell thee to the fourth: so buy or sell,
+ I tell thee again, thou shalt not sell my
+ possession.
+
+ JULIET:
+ And if thou steal, thou shalt not sell thyself.
+
+ THEODORE:
+ I do not steal; I sell the stolen goods.
+
+ THEODORE:
+ Thou know'st not what thou sell'st; thou, a woman,
+ Thou art ever a victim, a thing of no worth:
+ Thou hast no right, no right, but to be sold.
+ ```
+
+ Whoa there, GPT, entering some dark place over there. I didn't really tune the hyperparameters in the config too much, feel free to try!
+
+ ## sampling / inference
+
+ Use the script `sample.py` to sample either from pre-trained GPT-2 models released by OpenAI, or from a model you trained yourself. For example, here is a way to sample from the largest available `gpt2-xl` model:
+
+ ```
+ $ python sample.py \
+     --init_from=gpt2-xl \
+     --start="What is the answer to life, the universe, and everything?" \
+     --num_samples=5 --max_new_tokens=100
+ ```
+
+ If you'd like to sample from a model you trained, use `--out_dir` to point the code appropriately. You can also prompt the model with some text from a file, e.g. `$ python sample.py --start=FILE:prompt.txt`.
+
+ ## efficiency notes
+
+ For simple model benchmarking and profiling, `bench.py` might be useful. It's identical to what happens in the meat of the training loop of `train.py`, but omits much of the other complexities.
+
+ Note that the code by default uses [PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/). At the time of writing (Dec 29, 2022) this makes `torch.compile()` available in the nightly release. The improvement from the one line of code is noticeable, e.g. cutting down iteration time from ~250ms/iter to ~135ms/iter. Nice work PyTorch team!
+
+ ## todos
+
+ - Investigate and add FSDP instead of DDP
+ - Eval zero-shot perplexities on standard evals (e.g. LAMBADA? HELM? etc.)
+ - Finetune the finetuning script, I think the hyperparams are not great
+ - Schedule for linear batch size increase during training
+ - Incorporate other embeddings (rotary, alibi)
+ - Separate out the optim buffers from model params in checkpoints I think
+ - Additional logging around network health (e.g. gradient clip events, magnitudes)
+ - A few more investigations around better init etc.
+
+ ## troubleshooting
+
+ Note that by default this repo uses PyTorch 2.0 (i.e. `torch.compile`). This is fairly new and experimental, and not yet available on all platforms (e.g. Windows). If you're running into related error messages, try to disable it by adding the `--compile=False` flag. This will slow down the code but at least it will run.
+
+ For some context on this repository, GPT, and language modeling it might be helpful to watch my [Zero To Hero series](https://karpathy.ai/zero-to-hero.html). Specifically, the [GPT video](https://www.youtube.com/watch?v=kCc8FmEb1nY) is popular if you have some prior language modeling context.
+
+ For more questions/discussions feel free to stop by **#nanoGPT** on Discord:
+
+ [![](https://dcbadge.vercel.app/api/server/3zy8kqD9Cp?compact=true&style=flat)](https://discord.gg/3zy8kqD9Cp)
+
+ ## acknowledgements
+
+ All nanoGPT experiments are powered by GPUs on [Lambda labs](https://lambdalabs.com), my favorite Cloud GPU provider. Thank you Lambda labs for sponsoring nanoGPT!
bench.py ADDED
@@ -0,0 +1,117 @@
+ """
+ A much shorter version of train.py for benchmarking
+ """
+ import os
+ from contextlib import nullcontext
+ import numpy as np
+ import time
+ import torch
+ from model import GPTConfig, GPT
+
+ # -----------------------------------------------------------------------------
+ batch_size = 12
+ block_size = 1024
+ bias = False
+ real_data = True
+ seed = 1337
+ device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
+ dtype = 'bfloat16' # 'float32' or 'bfloat16' or 'float16'
+ compile = True # use PyTorch 2.0 to compile the model to be faster
+ profile = False # use pytorch profiler, or just simple benchmarking?
+ exec(open('configurator.py').read()) # overrides from command line or config file
+ # -----------------------------------------------------------------------------
+
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
+ torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
+ device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
+ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
+ ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
+
+ # data loading init
+ if real_data:
+     dataset = 'openwebtext'
+     data_dir = os.path.join('data', dataset)
+     train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
+     def get_batch(split):
+         data = train_data # note ignore split in benchmarking script
+         ix = torch.randint(len(data) - block_size, (batch_size,))
+         x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
+         y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
+         x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
+         return x, y
+ else:
+     # alternatively, if fixed data is desired to not care about data loading
+     x = torch.randint(50304, (batch_size, block_size), device=device)
+     y = torch.randint(50304, (batch_size, block_size), device=device)
+     get_batch = lambda split: (x, y)
+
+ # model init
+ gptconf = GPTConfig(
+     block_size = block_size, # how far back does the model look? i.e. context size
+     n_layer = 12, n_head = 12, n_embd = 768, # size of the model
+     dropout = 0, # for determinism
+     bias = bias,
+ )
+ model = GPT(gptconf)
+ model.to(device)
+
+ optimizer = model.configure_optimizers(weight_decay=1e-2, learning_rate=1e-4, betas=(0.9, 0.95), device_type=device_type)
+
+ if compile:
+     print("Compiling model...")
+     model = torch.compile(model) # pytorch 2.0
+
+ if profile:
+     # useful docs on pytorch profiler:
+     # - tutorial https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html
+     # - api https://pytorch.org/docs/stable/profiler.html#torch.profiler.profile
+     wait, warmup, active = 5, 5, 5
+     num_steps = wait + warmup + active
+     with torch.profiler.profile(
+         activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
+         schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1),
+         on_trace_ready=torch.profiler.tensorboard_trace_handler('./bench_log'),
+         record_shapes=False,
+         profile_memory=False,
+         with_stack=False, # incurs an additional overhead, disable if not needed
+         with_flops=True,
+         with_modules=False, # only for torchscript models atm
+     ) as prof:
+
+         X, Y = get_batch('train')
+         for k in range(num_steps):
+             with ctx:
+                 logits, loss = model(X, Y)
+             X, Y = get_batch('train')
+             optimizer.zero_grad(set_to_none=True)
+             loss.backward()
+             optimizer.step()
+             lossf = loss.item()
+             print(f"{k}/{num_steps} loss: {lossf:.4f}")
+
+             prof.step() # notify the profiler at end of each step
+
+ else:
+
+     # simple benchmarking
+     torch.cuda.synchronize()
+     for stage, num_steps in enumerate([10, 20]): # burnin, then benchmark
+         t0 = time.time()
+         X, Y = get_batch('train')
+         for k in range(num_steps):
+             with ctx:
+                 logits, loss = model(X, Y)
+             X, Y = get_batch('train')
+             optimizer.zero_grad(set_to_none=True)
+             loss.backward()
+             optimizer.step()
+             lossf = loss.item()
+             print(f"{k}/{num_steps} loss: {lossf:.4f}")
+         torch.cuda.synchronize()
+         t1 = time.time()
+         dt = t1-t0
+         mfu = model.estimate_mfu(batch_size * 1 * num_steps, dt)
+         if stage == 1:
+             print(f"time per iteration: {dt/num_steps*1000:.4f}ms, MFU: {mfu*100:.2f}%")
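The `estimate_mfu` call near the end is defined in `model.py` (outside the excerpt shown in this diff). For orientation, it follows the PaLM-paper style of accounting: approximate the forward+backward FLOPs per token as `6*N + 12*L*H*Q*T` (N = parameters, L = layers, H = heads, Q = head dim, T = sequence length), then compare the achieved FLOP rate against the A100's 312 TFLOPS bfloat16 peak. A hedged standalone sketch; the exact constants in `model.py` may differ slightly:

```python
def estimate_mfu_sketch(n_params, n_layer, n_head, head_dim, seq_len,
                        fwdbwd_per_iter, dt, peak_flops=312e12):
    """Model FLOPs Utilization, PaLM Appendix B style (an approximation)."""
    flops_per_token = 6 * n_params + 12 * n_layer * n_head * head_dim * seq_len
    flops_per_fwdbwd = flops_per_token * seq_len
    flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
    return (flops_per_iter / dt) / peak_flops  # fraction of hardware peak

# GPT-2 124M-ish shape: 12 layers, 12 heads, 64-dim heads, 1024 context,
# 12 sequences per iteration, ~135 ms per iteration (the PyTorch 2.0 timing)
mfu = estimate_mfu_sketch(124e6, 12, 12, 64, 1024, 12, 0.135)
print(f"MFU: {mfu*100:.1f}%")
```

MFU is a more portable efficiency metric than raw ms/iter because it normalizes away model size and hardware.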
configurator.py ADDED
@@ -0,0 +1,47 @@
+ """
+ Poor Man's Configurator. Probably a terrible idea. Example usage:
+ $ python train.py config/override_file.py --batch_size=32
+ this will first run config/override_file.py, then override batch_size to 32
+
+ The code in this file will be run as follows from e.g. train.py:
+ >>> exec(open('configurator.py').read())
+
+ So it's not a Python module, it's just shuttling this code away from train.py.
+ The code in this script then overrides the globals().
+
+ I know people are not going to love this, I just really dislike configuration
+ complexity and having to prepend config. to every single variable. If someone
+ comes up with a better simple Python solution I am all ears.
+ """
+
+ import sys
+ from ast import literal_eval
+
+ for arg in sys.argv[1:]:
+     if '=' not in arg:
+         # assume it's the name of a config file
+         assert not arg.startswith('--')
+         config_file = arg
+         print(f"Overriding config with {config_file}:")
+         with open(config_file) as f:
+             print(f.read())
+         exec(open(config_file).read())
+     else:
+         # assume it's a --key=value argument
+         assert arg.startswith('--')
+         key, val = arg.split('=')
+         key = key[2:]
+         if key in globals():
+             try:
+                 # attempt to eval it (e.g. if bool, number, etc.)
+                 attempt = literal_eval(val)
+             except (SyntaxError, ValueError):
+                 # if that goes wrong, just use the string
+                 attempt = val
+             # ensure the types match ok
+             assert type(attempt) == type(globals()[key])
+             # cross fingers
+             print(f"Overriding: {key} = {attempt}")
+             globals()[key] = attempt
+         else:
+             raise ValueError(f"Unknown config key: {key}")
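Since the script mutates the `globals()` of whichever file `exec`s it, the override logic can be demonstrated in isolation. Here is a toy reproduction over a plain dict (not the file itself), showing the `literal_eval` fallback and type check:

```python
from ast import literal_eval

# defaults, as they would appear at the top of train.py
config = {'batch_size': 12, 'learning_rate': 6e-4, 'compile': True}

def apply_overrides(argv, config):
    """Toy version of configurator.py's --key=value handling, over a dict."""
    for arg in argv:
        assert arg.startswith('--') and '=' in arg
        key, val = arg[2:].split('=')
        try:
            attempt = literal_eval(val)   # bools, numbers, tuples, ...
        except (SyntaxError, ValueError):
            attempt = val                 # fall back to the plain string
        # the override must have the same type as the default it replaces
        assert type(attempt) == type(config[key])
        config[key] = attempt
    return config

apply_overrides(['--batch_size=32', '--compile=False'], config)
print(config['batch_size'], config['compile'])  # -> 32 False
```

The type assertion is what lets `--compile=False` arrive as a real `bool` rather than the string `"False"`, which a naive `sys.argv` parse would produce.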
model.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Full definition of a GPT Language Model, all of it in this single file.
3
+ References:
4
+ 1) the official GPT-2 TensorFlow implementation released by OpenAI:
5
+ https://github.com/openai/gpt-2/blob/master/src/model.py
6
+ 2) huggingface/transformers PyTorch implementation:
7
+ https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
8
+ """
9
+
10
+ import math
11
+ import inspect
12
+ from dataclasses import dataclass
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn import functional as F
17
+
18
+ # @torch.jit.script # good to enable when not using torch.compile, disable when using (our default)
19
+ def new_gelu(x):
20
+ """
21
+ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
22
+ Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
23
+ """
24
+ return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
25
+
26
+ class LayerNorm(nn.Module):
27
+ """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
28
+
29
+ def __init__(self, ndim, bias):
30
+ super().__init__()
31
+ self.weight = nn.Parameter(torch.ones(ndim))
32
+ self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
33
+
34
+ def forward(self, input):
35
+ return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
36
+
37
+ class CausalSelfAttention(nn.Module):
38
+
39
+ def __init__(self, config):
40
+ super().__init__()
41
+ assert config.n_embd % config.n_head == 0
42
+ # key, query, value projections for all heads, but in a batch
43
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
44
+ # output projection
45
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
46
+ # regularization
47
+ self.attn_dropout = nn.Dropout(config.dropout)
48
+ self.resid_dropout = nn.Dropout(config.dropout)
49
+ self.n_head = config.n_head
50
+ self.n_embd = config.n_embd
51
+ self.dropout = config.dropout
52
+ # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
53
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
54
+ if not self.flash:
55
+ print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
56
+ # causal mask to ensure that attention is only applied to the left in the input sequence
57
+ self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
58
+ .view(1, 1, config.block_size, config.block_size))
59
+
60
+ def forward(self, x):
61
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
62
+
63
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
64
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
65
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
66
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
67
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
68
+
69
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
70
+ if self.flash:
71
+ # efficient attention using Flash Attention CUDA kernels
72
+ y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
73
+ else:
74
+ # manual implementation of attention
75
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
76
+ att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
77
+ att = F.softmax(att, dim=-1)
78
+ att = self.attn_dropout(att)
79
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
80
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
81
+
82
+ # output projection
83
+ y = self.resid_dropout(self.c_proj(y))
84
+ return y
85
+
86
+ class MLP(nn.Module):
87
+
88
+ def __init__(self, config):
89
+ super().__init__()
90
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
91
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
92
+ self.dropout = nn.Dropout(config.dropout)
93
+
94
+ def forward(self, x):
95
+ x = self.c_fc(x)
96
+ x = new_gelu(x)
97
+ x = self.c_proj(x)
98
+ x = self.dropout(x)
99
+ return x
100
+
101
+ class Block(nn.Module):
102
+
103
+ def __init__(self, config):
104
+ super().__init__()
105
+ self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
106
+ self.attn = CausalSelfAttention(config)
107
+ self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
108
+ self.mlp = MLP(config)
109
+
110
+ def forward(self, x):
111
+ x = x + self.attn(self.ln_1(x))
112
+ x = x + self.mlp(self.ln_2(x))
113
+ return x
114
+
115
+ @dataclass
116
+ class GPTConfig:
117
+ block_size: int = 1024
118
+ vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
119
+ n_layer: int = 12
120
+ n_head: int = 12
121
+ n_embd: int = 768
122
+ dropout: float = 0.0
123
+ bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
124
+
+class GPT(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        assert config.vocab_size is not None
+        assert config.block_size is not None
+        self.config = config
+
+        self.transformer = nn.ModuleDict(dict(
+            wte = nn.Embedding(config.vocab_size, config.n_embd),
+            wpe = nn.Embedding(config.block_size, config.n_embd),
+            drop = nn.Dropout(config.dropout),
+            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+            ln_f = LayerNorm(config.n_embd, bias=config.bias),
+        ))
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        # with weight tying when using torch.compile() some warnings get generated:
+        # "UserWarning: functional_call was passed multiple values for tied weights.
+        # This behavior is deprecated and will be an error in future versions"
+        # not 100% sure what this is, so far seems to be harmless. TODO investigate
+        self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
+
+        # init all weights
+        self.apply(self._init_weights)
+        # apply special scaled init to the residual projections, per GPT-2 paper
+        for pn, p in self.named_parameters():
+            if pn.endswith('c_proj.weight'):
+                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
+
+        # report number of parameters
+        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
+
+    def get_num_params(self, non_embedding=True):
+        """
+        Return the number of parameters in the model.
+        For non-embedding count (default), the position embeddings get subtracted.
+        The token embeddings would too, except due to the parameter sharing these
+        params are actually used as weights in the final layer, so we include them.
+        """
+        n_params = sum(p.numel() for p in self.parameters())
+        if non_embedding:
+            n_params -= self.transformer.wpe.weight.numel()
+        return n_params
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+    def forward(self, idx, targets=None):
+        device = idx.device
+        b, t = idx.size()
+        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
+
+        # forward the GPT model itself
+        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
+        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
+        x = self.transformer.drop(tok_emb + pos_emb)
+        for block in self.transformer.h:
+            x = block(x)
+        x = self.transformer.ln_f(x)
+
+        if targets is not None:
+            # if we are given some desired targets also calculate the loss
+            logits = self.lm_head(x)
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+        else:
+            # inference-time mini-optimization: only forward the lm_head on the very last position
+            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
+            loss = None
+
+        return logits, loss
+
+    def crop_block_size(self, block_size):
+        # model surgery to decrease the block size if necessary
+        # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
+        # but want to use a smaller block size for some smaller, simpler model
+        assert block_size <= self.config.block_size
+        self.config.block_size = block_size
+        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
+        for block in self.transformer.h:
+            if hasattr(block.attn, 'bias'):
+                block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
+
+    @classmethod
+    def from_pretrained(cls, model_type, override_args=None):
+        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
+        override_args = override_args or {} # default to empty dict
+        # only dropout can be overridden; see more notes below
+        assert all(k == 'dropout' for k in override_args)
+        from transformers import GPT2LMHeadModel
+        print("loading weights from pretrained gpt: %s" % model_type)
+
+        # n_layer, n_head and n_embd are determined from model_type
+        config_args = {
+            'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
+            'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
+            'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
+            'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
+        }[model_type]
+        print("forcing vocab_size=50257, block_size=1024, bias=True")
+        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
+        config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
+        config_args['bias'] = True # always True for GPT model checkpoints
+        # we can override the dropout rate, if desired
+        if 'dropout' in override_args:
+            print(f"overriding dropout rate to {override_args['dropout']}")
+            config_args['dropout'] = override_args['dropout']
+        # create a from-scratch initialized minGPT model
+        config = GPTConfig(**config_args)
+        model = GPT(config)
+        sd = model.state_dict()
+        sd_keys = sd.keys()
+        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
+
+        # init a huggingface/transformers model
+        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
+        sd_hf = model_hf.state_dict()
+
+        # copy while ensuring all of the parameters are aligned and match in names and shapes
+        sd_keys_hf = sd_hf.keys()
+        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
+        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
+        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
+        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
+        # this means that we have to transpose these weights when we import them
+        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
+        for k in sd_keys_hf:
+            if any(k.endswith(w) for w in transposed):
+                # special treatment for the Conv1D weights we need to transpose
+                assert sd_hf[k].shape[::-1] == sd[k].shape
+                with torch.no_grad():
+                    sd[k].copy_(sd_hf[k].t())
+            else:
+                # vanilla copy over the other parameters
+                assert sd_hf[k].shape == sd[k].shape
+                with torch.no_grad():
+                    sd[k].copy_(sd_hf[k])
+
+        return model
+
+    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
+        # start with all of the candidate parameters
+        param_dict = {pn: p for pn, p in self.named_parameters()}
+        # filter out those that do not require grad
+        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
+        # create optim groups. Any parameter that is 2D will be weight decayed, otherwise no.
+        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
+        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
+        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
+        optim_groups = [
+            {'params': decay_params, 'weight_decay': weight_decay},
+            {'params': nodecay_params, 'weight_decay': 0.0}
+        ]
+        num_decay_params = sum(p.numel() for p in decay_params)
+        num_nodecay_params = sum(p.numel() for p in nodecay_params)
+        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
+        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
+        # Create AdamW optimizer and use the fused version if it is available
+        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
+        use_fused = fused_available and device_type == 'cuda'
+        extra_args = dict(fused=True) if use_fused else dict()
+        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
+        print(f"using fused AdamW: {use_fused}")
+
+        return optimizer
+
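The decay/no-decay split in `configure_optimizers` keys purely on tensor dimensionality. A minimal plain-Python sketch of the same rule, using hypothetical `(name, shape)` pairs instead of real `nn.Parameter`s:

```python
# Illustrative only: mimic configure_optimizers' decay rule on (name, shape) pairs.
# The shapes below are hypothetical stand-ins for real model tensors.
params = {
    'transformer.wte.weight': (50304, 768),            # embedding -> decayed (2D)
    'transformer.h.0.ln_1.weight': (768,),             # layernorm -> not decayed (1D)
    'transformer.h.0.attn.c_attn.weight': (2304, 768), # matmul weight -> decayed (2D)
    'transformer.h.0.attn.c_attn.bias': (2304,),       # bias -> not decayed (1D)
}
decay = [n for n, shape in params.items() if len(shape) >= 2]
nodecay = [n for n, shape in params.items() if len(shape) < 2]
print(decay)    # the two 2D weight tensors
print(nodecay)  # the layernorm weight and the bias
```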
+    def estimate_mfu(self, fwdbwd_per_iter, dt):
+        """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
+        # first estimate the number of flops we do per iteration.
+        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
+        N = self.get_num_params()
+        cfg = self.config
+        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
+        flops_per_token = 6*N + 12*L*H*Q*T
+        flops_per_fwdbwd = flops_per_token * T
+        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
+        # express our flops throughput as ratio of A100 bfloat16 peak flops
+        flops_achieved = flops_per_iter * (1.0/dt) # per second
+        flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
+        mfu = flops_achieved / flops_promised
+        return mfu
+
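A worked instance of `estimate_mfu`'s arithmetic for the GPT-2 (124M) defaults, following the PaLM Appendix B formula it cites. Here `N` is the non-embedding parameter count (the 124,337,664 total minus the 786,432 position-embedding parameters, as `get_num_params` computes it for the unpadded 50257 vocab):

```python
# Worked FLOPs-per-token example for GPT-2 (124M), per PaLM Appendix B.
N = 124337664 - 786432          # non-embedding params (unpadded 50257 vocab)
L, H, Q, T = 12, 12, 768 // 12, 1024
flops_per_token = 6 * N + 12 * L * H * Q * T
flops_per_fwdbwd = flops_per_token * T  # one fwd+bwd over a full T-length sequence
print(flops_per_token)  # 854553600
```

Dividing `flops_per_iter / dt` by the 312e12 peak then gives the MFU fraction reported in the training logs.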
+    @torch.no_grad()
+    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
+        """
+        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
+        the sequence max_new_tokens times, feeding the predictions back into the model each time.
+        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
+        """
+        for _ in range(max_new_tokens):
+            # if the sequence context is growing too long we must crop it at block_size
+            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
+            # forward the model to get the logits for the index in the sequence
+            logits, _ = self(idx_cond)
+            # pluck the logits at the final step and scale by desired temperature
+            logits = logits[:, -1, :] / temperature
+            # optionally crop the logits to only the top k options
+            if top_k is not None:
+                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                logits[logits < v[:, [-1]]] = -float('Inf')
+            # apply softmax to convert logits to (normalized) probabilities
+            probs = F.softmax(logits, dim=-1)
+            # sample from the distribution
+            idx_next = torch.multinomial(probs, num_samples=1)
+            # append sampled index to the running sequence and continue
+            idx = torch.cat((idx, idx_next), dim=1)
+
+        return idx
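A standalone plain-Python sketch of the temperature + top-k truncation that `generate` applies before sampling (it does the same with `torch.topk` and `-inf` masking before the softmax); this is an illustration, not code from the repo:

```python
import math

def top_k_probs(logits, k, temperature=1.0):
    """Scale logits by temperature, keep the k largest, softmax the rest to 0."""
    logits = [l / temperature for l in logits]
    kth = sorted(logits, reverse=True)[k - 1]               # k-th largest value
    masked = [l if l >= kth else float('-inf') for l in logits]
    exps = [math.exp(l) for l in masked]                    # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]

probs = top_k_probs([1.0, 5.0, 3.0, 2.0], k=2)
# only the two largest logits (5.0 and 3.0) keep nonzero probability
```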
sample.py ADDED
@@ -0,0 +1,89 @@
+"""
+Sample from a trained model
+"""
+import os
+import pickle
+from contextlib import nullcontext
+import torch
+import tiktoken
+from model import GPTConfig, GPT
+
+# -----------------------------------------------------------------------------
+init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
+out_dir = 'out' # ignored if init_from is not 'resume'
+start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
+num_samples = 10 # number of samples to draw
+max_new_tokens = 500 # number of tokens generated in each sample
+temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
+top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
+seed = 1337
+device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
+dtype = 'bfloat16' # 'float32' or 'bfloat16' or 'float16'
+compile = False # use PyTorch 2.0 to compile the model to be faster
+exec(open('configurator.py').read()) # overrides from command line or config file
+# -----------------------------------------------------------------------------
+
+torch.manual_seed(seed)
+torch.cuda.manual_seed(seed)
+torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
+torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
+device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
+ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
+ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
+
+# model
+if init_from == 'resume':
+    # init from a model saved in a specific directory
+    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
+    checkpoint = torch.load(ckpt_path, map_location=device)
+    gptconf = GPTConfig(**checkpoint['model_args'])
+    model = GPT(gptconf)
+    state_dict = checkpoint['model']
+    unwanted_prefix = '_orig_mod.'
+    for k,v in list(state_dict.items()):
+        if k.startswith(unwanted_prefix):
+            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
+    model.load_state_dict(state_dict)
+elif init_from.startswith('gpt2'):
+    # init from a given GPT-2 model
+    model = GPT.from_pretrained(init_from, dict(dropout=0.0))
+
+model.eval()
+model.to(device)
+if compile:
+    model = torch.compile(model) # requires PyTorch 2.0 (optional)
+
+# look for the meta pickle in case it is available in the dataset folder
+load_meta = False
+if init_from == 'resume' and 'config' in checkpoint and 'dataset' in checkpoint['config']: # older checkpoints might not have these...
+    meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
+    load_meta = os.path.exists(meta_path)
+if load_meta:
+    print(f"Loading meta from {meta_path}...")
+    with open(meta_path, 'rb') as f:
+        meta = pickle.load(f)
+    # TODO want to make this more general to arbitrary encoder/decoder schemes
+    stoi, itos = meta['stoi'], meta['itos']
+    encode = lambda s: [stoi[c] for c in s]
+    decode = lambda l: ''.join([itos[i] for i in l])
+else:
+    # ok let's assume gpt-2 encodings by default
+    print("No meta.pkl found, assuming GPT-2 encodings...")
+    enc = tiktoken.get_encoding("gpt2")
+    encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
+    decode = lambda l: enc.decode(l)
+
+# encode the beginning of the prompt
+if start.startswith('FILE:'):
+    with open(start[5:], 'r', encoding='utf-8') as f:
+        start = f.read()
+start_ids = encode(start)
+x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
+
+# run generation
+with torch.no_grad():
+    with ctx:
+        for k in range(num_samples):
+            y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
+            print(decode(y[0].tolist()))
+            print('---------------')
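When a dataset ships a `meta.pkl`, sample.py builds its `encode`/`decode` lambdas from the `stoi`/`itos` tables. A self-contained sketch of that character-level scheme, with a toy vocabulary standing in for a real dataset's:

```python
# Character-level codec as built from meta.pkl's stoi/itos (toy vocabulary here).
chars = sorted(set("hello world"))
stoi = {ch: i for i, ch in enumerate(chars)}   # char -> int
itos = {i: ch for ch, i in stoi.items()}       # int -> char
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

ids = encode("hello")
assert decode(ids) == "hello"  # round-trips
```

The GPT-2 branch is analogous, with `tiktoken`'s BPE tables playing the role of `stoi`/`itos`.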
scaling_laws.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
train.py ADDED
@@ -0,0 +1,332 @@
+"""
+This training script can be run both on a single gpu in debug mode,
+and also in a larger training run with distributed data parallel (ddp).
+
+To run on a single GPU, example:
+$ python train.py --batch_size=32 --compile=False
+
+To run with DDP on 4 gpus on 1 node, example:
+$ torchrun --standalone --nproc_per_node=4 train.py
+
+To run with DDP on 4 gpus across 2 nodes, example:
+- Run on the first (master) node with example IP 123.456.123.456:
+$ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py
+- Run on the worker node:
+$ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py
+(If your cluster does not have Infiniband interconnect prepend NCCL_IB_DISABLE=1)
+"""
+
+import os
+import time
+import math
+import pickle
+from contextlib import nullcontext
+
+import numpy as np
+import torch
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.distributed import init_process_group, destroy_process_group
+
+from model import GPTConfig, GPT
+
+# -----------------------------------------------------------------------------
+# default config values designed to train a gpt2 (124M) on OpenWebText
+# I/O
+out_dir = 'out'
+eval_interval = 2000
+log_interval = 1
+eval_iters = 200
+eval_only = False # if True, script exits right after the first eval
+always_save_checkpoint = True # if True, always save a checkpoint after each eval
+init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
+# wandb logging
+wandb_log = False # disabled by default
+wandb_project = 'owt'
+wandb_run_name = 'gpt2' # 'run' + str(time.time())
+# data
+dataset = 'openwebtext'
+gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes
+batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
+block_size = 1024
+# model
+n_layer = 12
+n_head = 12
+n_embd = 768
+dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
+bias = False # do we use bias inside LayerNorm and Linear layers?
+# adamw optimizer
+learning_rate = 6e-4 # max learning rate
+max_iters = 600000 # total number of training iterations
+weight_decay = 1e-1
+beta1 = 0.9
+beta2 = 0.95
+grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
+# learning rate decay settings
+decay_lr = True # whether to decay the learning rate
+warmup_iters = 2000 # how many steps to warm up for
+lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
+min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
+# DDP settings
+backend = 'nccl' # 'nccl', 'gloo', etc.
+# system
+device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
+dtype = 'bfloat16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
+compile = True # use PyTorch 2.0 to compile the model to be faster
+# -----------------------------------------------------------------------------
+config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
+exec(open('configurator.py').read()) # overrides from command line or config file
+config = {k: globals()[k] for k in config_keys} # will be useful for logging
+# -----------------------------------------------------------------------------
+
+# various inits, derived attributes, I/O setup
+ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
+if ddp:
+    init_process_group(backend=backend)
+    ddp_rank = int(os.environ['RANK'])
+    ddp_local_rank = int(os.environ['LOCAL_RANK'])
+    ddp_world_size = int(os.environ['WORLD_SIZE'])
+    device = f'cuda:{ddp_local_rank}'
+    torch.cuda.set_device(device)
+    master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
+    seed_offset = ddp_rank # each process gets a different seed
+    assert gradient_accumulation_steps % torch.cuda.device_count() == 0
+    gradient_accumulation_steps //= torch.cuda.device_count()
+else:
+    # if not ddp, we are running on a single gpu, and one process
+    master_process = True
+    seed_offset = 0
+    ddp_world_size = 1
+tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
+print(f"tokens per iteration will be: {tokens_per_iter:,}")
+
+if master_process:
+    os.makedirs(out_dir, exist_ok=True)
+torch.manual_seed(1337 + seed_offset)
+torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
+torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
+device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
+# note: float16 data type will automatically use a GradScaler
+ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
+ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
+
+# poor man's data loader
+data_dir = os.path.join('data', dataset)
+train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
+val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
+def get_batch(split):
+    data = train_data if split == 'train' else val_data
+    ix = torch.randint(len(data) - block_size, (batch_size,))
+    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
+    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
+    if device_type == 'cuda':
+        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
+        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
+    else:
+        x, y = x.to(device), y.to(device)
+    return x, y
+
+# init these up here, can override if init_from='resume' (i.e. from a checkpoint)
+iter_num = 0
+best_val_loss = 1e9
+
+# attempt to derive vocab_size from the dataset
+meta_path = os.path.join(data_dir, 'meta.pkl')
+meta_vocab_size = None
+if os.path.exists(meta_path):
+    with open(meta_path, 'rb') as f:
+        meta = pickle.load(f)
+    meta_vocab_size = meta['vocab_size']
+    print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
+
+# model init
+model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
+                  bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line
+if init_from == 'scratch':
+    # init a new model from scratch
+    print("Initializing a new model from scratch")
+    # determine the vocab size we'll use for from-scratch training
+    if meta_vocab_size is None:
+        print("defaulting to GPT-2 vocab_size of 50304 (50257 rounded up for efficiency)")
+    model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
+    gptconf = GPTConfig(**model_args)
+    model = GPT(gptconf)
+elif init_from == 'resume':
+    print(f"Resuming training from {out_dir}")
+    # resume training from a checkpoint.
+    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
+    checkpoint = torch.load(ckpt_path, map_location=device)
+    checkpoint_model_args = checkpoint['model_args']
+    # force these config attributes to be equal otherwise we can't even resume training
+    # the rest of the attributes (e.g. dropout) can stay as desired from command line
+    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
+        model_args[k] = checkpoint_model_args[k]
+    # create the model
+    gptconf = GPTConfig(**model_args)
+    model = GPT(gptconf)
+    state_dict = checkpoint['model']
+    # fix the keys of the state dictionary :(
+    # honestly no idea how checkpoints sometimes get this prefix, have to debug more
+    unwanted_prefix = '_orig_mod.'
+    for k,v in list(state_dict.items()):
+        if k.startswith(unwanted_prefix):
+            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
+    model.load_state_dict(state_dict)
+    iter_num = checkpoint['iter_num']
+    best_val_loss = checkpoint['best_val_loss']
+elif init_from.startswith('gpt2'):
+    print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
+    # initialize from OpenAI GPT-2 weights
+    override_args = dict(dropout=dropout)
+    model = GPT.from_pretrained(init_from, override_args)
+    # read off the created config params, so we can store them into checkpoint correctly
+    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
+        model_args[k] = getattr(model.config, k)
+# crop down the model block size if desired, using model surgery
+if block_size < model.config.block_size:
+    model.crop_block_size(block_size)
+    model_args['block_size'] = block_size # so that the checkpoint will have the right value
+model.to(device)
+
+# initialize a GradScaler. If enabled=False scaler is a no-op
+scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
+
+# optimizer
+optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
+if init_from == 'resume':
+    optimizer.load_state_dict(checkpoint['optimizer'])
+checkpoint = None # free up memory
+
+# compile the model
+if compile:
+    print("compiling the model... (takes a ~minute)")
+    unoptimized_model = model
+    model = torch.compile(model) # requires PyTorch 2.0
+
+# wrap model into DDP container
+if ddp:
+    model = DDP(model, device_ids=[ddp_local_rank])
+
+# helps estimate an arbitrarily accurate loss over either split using many batches
+@torch.no_grad()
+def estimate_loss():
+    out = {}
+    model.eval()
+    for split in ['train', 'val']:
+        losses = torch.zeros(eval_iters)
+        for k in range(eval_iters):
+            X, Y = get_batch(split)
+            with ctx:
+                logits, loss = model(X, Y)
+            losses[k] = loss.item()
+        out[split] = losses.mean()
+    model.train()
+    return out
+
+# learning rate decay scheduler (cosine with warmup)
+def get_lr(it):
+    # 1) linear warmup for warmup_iters steps
+    if it < warmup_iters:
+        return learning_rate * it / warmup_iters
+    # 2) if it > lr_decay_iters, return min learning rate
+    if it > lr_decay_iters:
+        return min_lr
+    # 3) in between, use cosine decay down to min learning rate
+    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
+    assert 0 <= decay_ratio <= 1
+    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
+    return min_lr + coeff * (learning_rate - min_lr)
+
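A self-contained copy of this warmup-plus-cosine schedule (same constants as the defaults above) makes its endpoints easy to verify: the LR climbs linearly from 0 to `learning_rate` over the warmup, then follows a cosine down to `min_lr`:

```python
import math

learning_rate, min_lr = 6e-4, 6e-5
warmup_iters, lr_decay_iters = 2000, 600000

def get_lr(it):
    if it < warmup_iters:
        return learning_rate * it / warmup_iters           # 1) linear warmup
    if it > lr_decay_iters:
        return min_lr                                      # 2) floor after decay
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # 3) cosine, 1 -> 0
    return min_lr + coeff * (learning_rate - min_lr)

# endpoints: 0 at it=0, the max LR right at warmup's end, min_lr at decay's end
print(get_lr(0), get_lr(warmup_iters), get_lr(lr_decay_iters))
```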
+# logging
+if wandb_log and master_process:
+    import wandb
+    wandb.init(project=wandb_project, name=wandb_run_name, config=config)
+
+# training loop
+X, Y = get_batch('train') # fetch the very first batch
+t0 = time.time()
+local_iter_num = 0 # number of iterations in the lifetime of this process
+raw_model = model.module if ddp else model # unwrap DDP container if needed
+running_mfu = -1.0
+while True:
+
+    # determine and set the learning rate for this iteration
+    lr = get_lr(iter_num) if decay_lr else learning_rate
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+    # evaluate the loss on train/val sets and write checkpoints
+    if iter_num % eval_interval == 0 and master_process:
+        losses = estimate_loss()
+        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
+        if wandb_log:
+            wandb.log({
+                "iter": iter_num,
+                "train/loss": losses['train'],
+                "val/loss": losses['val'],
+                "lr": lr,
+                "mfu": running_mfu*100, # convert to percentage
+            })
+        if losses['val'] < best_val_loss or always_save_checkpoint:
+            best_val_loss = losses['val']
+            if iter_num > 0:
+                checkpoint = {
+                    'model': raw_model.state_dict(),
+                    'optimizer': optimizer.state_dict(),
+                    'model_args': model_args,
+                    'iter_num': iter_num,
+                    'best_val_loss': best_val_loss,
+                    'config': config,
+                }
+                print(f"saving checkpoint to {out_dir}")
+                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
+    if iter_num == 0 and eval_only:
+        break
+
+    # forward backward update, with optional gradient accumulation to simulate larger batch size
+    # and using the GradScaler if data type is float16
+    for micro_step in range(gradient_accumulation_steps):
+        if ddp:
+            # in DDP training we only need to sync gradients at the last micro step.
+            # the official way to do this is with model.no_sync() context manager, but
+            # I really dislike that this bloats the code and forces us to repeat code
+            # looking at the source of that context manager, it just toggles this variable
+            model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
+        with ctx:
+            logits, loss = model(X, Y)
+            loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
+        # immediately async prefetch next batch while model is doing the forward pass on the GPU
+        X, Y = get_batch('train')
+        # backward pass, with gradient scaling if training in fp16
+        scaler.scale(loss).backward()
+    # clip the gradient
+    if grad_clip != 0.0:
+        scaler.unscale_(optimizer)
+        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+    # step the optimizer and scaler if training in fp16
+    scaler.step(optimizer)
+    scaler.update()
+    # flush the gradients as soon as we can, no need for this memory anymore
+    optimizer.zero_grad(set_to_none=True)
+
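The `loss / gradient_accumulation_steps` scaling above makes the accumulated gradient equal the gradient of the mean loss over the whole effective batch. A tiny numeric check of that identity, using a hypothetical one-parameter model rather than the GPT:

```python
# Toy model: pred = w * x, loss = (w*x - y)^2, so d(loss)/dw = 2*x*(w*x - y).
w = 0.5
batch = [(1.0, 2.0), (2.0, 1.0), (3.0, 3.0), (4.0, 0.0)]  # (x, y) micro-batches
k = len(batch)

# accumulate gradients of loss/k across micro-batches, as the training loop does
g_accum = 0.0
for x, y in batch:
    g_accum += 2 * x * (w * x - y) / k

# gradient of the mean loss over the full batch, computed in one shot
g_full = sum(2 * x * (w * x - y) for x, y in batch) / k

assert abs(g_accum - g_full) < 1e-12  # identical up to float rounding
```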
+    # timing and logging
+    t1 = time.time()
+    dt = t1 - t0
+    t0 = t1
+    if iter_num % log_interval == 0 and master_process:
+        # get loss as float. note: this is a CPU-GPU sync point
+        # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
+        lossf = loss.item() * gradient_accumulation_steps
+        if local_iter_num >= 5: # let the training loop settle a bit
+            mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
+            running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
+        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
+    iter_num += 1
+    local_iter_num += 1
+
+    # termination conditions
+    if iter_num > max_iters:
+        break
+
+if ddp:
+    destroy_process_group()
transformer_sizing.ipynb ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ {
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Transformer Theoretical Model\n",
+ "\n",
+ "This notebook stores a bunch of analysis about a Transformer, e.g. estimates the number of FLOPs, parameters, peak memory footprint, checkpoint size, etc."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from collections import OrderedDict"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# config_args = {\n",
+ "# 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params\n",
+ "# 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params\n",
+ "# 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params\n",
+ "# 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params\n",
+ "# }[model_type]\n",
+ "\n",
+ "block_size = 1024\n",
+ "vocab_size = 50257\n",
+ "n_layer = 12\n",
+ "n_head = 12\n",
+ "n_embd = 768\n",
+ "bias = False\n",
+ "assert not bias, \"this notebook assumes bias=False just for simplicity\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "we see: 124337664, expected: 124337664, match: True\n",
+ "name params ratio (%) \n",
+ "emebedding/position 786432 0.6325\n",
+ "embedding/token 38597376 31.0424\n",
+ "embedding 39383808 31.6749\n",
+ "attention/ln 768 0.0006\n",
+ "attention/kqv 1769472 1.4231\n",
+ "attention/proj 589824 0.4744\n",
+ "attention 2360064 1.8981\n",
+ "mlp/ln 768 0.0006\n",
+ "mlp/ffw 2359296 1.8975\n",
+ "mlp/proj 2359296 1.8975\n",
+ "mlp 4719360 3.7956\n",
+ "block 7079424 5.6937\n",
+ "transformer 84953088 68.3245\n",
+ "ln_f 768 0.0006\n",
+ "dense 0 0.0000\n",
+ "total 124337664 100.0000\n"
+ ]
+ }
+ ],
+ "source": [
+ "def params():\n",
+ " \"\"\" estimates the number of parameters in the model\"\"\"\n",
+ " out = OrderedDict()\n",
+ "\n",
+ " # token and position embeddings\n",
+ " out['emebedding/position'] = n_embd * block_size\n",
+ " out['embedding/token'] = n_embd * vocab_size\n",
+ " out['embedding'] = out['emebedding/position'] + out['embedding/token']\n",
+ "\n",
+ " # attention blocks\n",
+ " out['attention/ln'] = n_embd # note, bias=False in our LN\n",
+ " out['attention/kqv'] = n_embd * 3*n_embd\n",
+ " out['attention/proj'] = n_embd**2\n",
+ " out['attention'] = out['attention/ln'] + out['attention/kqv'] + out['attention/proj']\n",
+ "\n",
+ " # MLP blocks\n",
+ " ffw_size = 4*n_embd # feed forward size\n",
+ " out['mlp/ln'] = n_embd\n",
+ " out['mlp/ffw'] = n_embd * ffw_size\n",
+ " out['mlp/proj'] = ffw_size * n_embd\n",
+ " out['mlp'] = out['mlp/ln'] + out['mlp/ffw'] + out['mlp/proj']\n",
+ " \n",
+ " # the transformer and the rest of it\n",
+ " out['block'] = out['attention'] + out['mlp']\n",
+ " out['transformer'] = n_layer * out['block']\n",
+ " out['ln_f'] = n_embd # final layernorm\n",
+ " out['dense'] = 0 # 0 because of parameter sharing. This layer uses the weights from the embedding layer\n",
+ "\n",
+ " # total\n",
+ " out['total'] = out['embedding'] + out['transformer'] + out['ln_f'] + out['dense']\n",
+ "\n",
+ " return out\n",
+ "\n",
+ "# compare our param count to that reported by PyTorch\n",
+ "p = params()\n",
+ "params_total = p['total']\n",
+ "print(f\"we see: {params_total}, expected: {124337664}, match: {params_total == 124337664}\")\n",
+ "# create a header\n",
+ "print(f\"{'name':20s} {'params':10s} {'ratio (%)':10s}\")\n",
+ "for k,v in p.items():\n",
+ " print(f\"{k:20s} {v:10d} {v/params_total*100:10.4f}\")\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "est checkpoint size: 1.49 GB\n",
+ "measured with wc -c ckpt.pt: 1542470366\n",
+ "fluff ratio: 103.38%\n"
+ ]
+ }
+ ],
+ "source": [
+ "# we can now calculate the size of each checkpoint\n",
+ "# params are stored in fp32, and the AdamW optimizer has 2 additional buffers per param for statistics\n",
+ "params_bytes = params_total*4\n",
+ "params_and_buffers_bytes = params_bytes + 2*params_bytes\n",
+ "print(f\"est checkpoint size: {params_and_buffers_bytes/1e9:.2f} GB\")\n",
+ "measured_bytes = 1542470366 # from wc -c ckpt.pt\n",
+ "print(f\"measured with wc -c ckpt.pt: {measured_bytes}\")\n",
+ "print(f\"fluff ratio: {measured_bytes/params_and_buffers_bytes*100:.2f}%\")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can also estimate the ratio of our GPU memory that will be taken up just by the weights and the buffers inside the AdamW optimizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "memory ratio taken up just for parameters: 3.73%\n"
+ ]
+ }
+ ],
+ "source": [
+ "gpu_memory = 40e9 # 40 GB A100 GPU, roughly\n",
+ "print(f\"memory ratio taken up just for parameters: {params_and_buffers_bytes / gpu_memory * 100:.2f}%\")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "i.e. not that much of the memory for this tiny model, most of the memory is activations (forward and backward). This of course changes dramatically for larger and larger models."
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's estimate FLOPs for a single forward pass."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "name flops ratio (%) \n",
+ "attention/kqv 3623878656 1.2426\n",
+ "attention/scores 1610612736 0.5522\n",
+ "attention/reduce 1610612736 0.5522\n",
+ "attention/proj 1207959552 0.4142\n",
+ "attention 8053063680 2.7612\n",
+ "mlp/ffw1 4831838208 1.6567\n",
+ "mlp/ffw2 4831838208 1.6567\n",
+ "mlp 9663676416 3.3135\n",
+ "block 17716740096 6.0747\n",
+ "transformer 212600881152 72.8963\n",
+ "dense 79047426048 27.1037\n",
+ "forward_total 291648307200 100.0000\n",
+ "backward_total 583296614400 200.0000\n",
+ "total 874944921600 300.0000\n"
+ ]
+ }
+ ],
+ "source": [
+ "def flops():\n",
+ " # we only count Weight FLOPs, all other layers (LayerNorm, Softmax, etc) are effectively irrelevant\n",
+ " # we count actual FLOPs, not MACs. Hence 2* all over the place\n",
+ " # basically for any matrix multiply A (BxC) @ B (CxD) -> (BxD) flops are 2*B*C*D\n",
+ "\n",
+ " out = OrderedDict()\n",
+ " head_size = n_embd // n_head\n",
+ "\n",
+ " # attention blocks\n",
+ " # 1) the projection to key, query, values\n",
+ " out['attention/kqv'] = 2 * block_size * (n_embd * 3*n_embd)\n",
+ " # 2) calculating the attention scores\n",
+ " out['attention/scores'] = 2 * block_size * block_size * n_embd\n",
+ " # 3) the reduction of the values (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)\n",
+ " out['attention/reduce'] = 2 * n_head * (block_size * block_size * head_size)\n",
+ " # 4) the final linear projection\n",
+ " out['attention/proj'] = 2 * block_size * (n_embd * n_embd)\n",
+ " out['attention'] = sum(out['attention/'+k] for k in ['kqv', 'scores', 'reduce', 'proj'])\n",
+ "\n",
+ " # MLP blocks\n",
+ " ffw_size = 4*n_embd # feed forward size\n",
+ " out['mlp/ffw1'] = 2 * block_size * (n_embd * ffw_size)\n",
+ " out['mlp/ffw2'] = 2 * block_size * (ffw_size * n_embd)\n",
+ " out['mlp'] = out['mlp/ffw1'] + out['mlp/ffw2']\n",
+ "\n",
+ " # the transformer and the rest of it\n",
+ " out['block'] = out['attention'] + out['mlp']\n",
+ " out['transformer'] = n_layer * out['block']\n",
+ " out['dense'] = 2 * block_size * (n_embd * vocab_size)\n",
+ "\n",
+ " # forward,backward,total\n",
+ " out['forward_total'] = out['transformer'] + out['dense']\n",
+ " out['backward_total'] = 2 * out['forward_total'] # use common estimate of bwd = 2*fwd\n",
+ " out['total'] = out['forward_total'] + out['backward_total']\n",
+ "\n",
+ " return out\n",
+ " \n",
+ "# run the flops estimate and print a breakdown\n",
+ "f = flops()\n",
+ "flops_total = f['forward_total']\n",
+ "print(f\"{'name':20s} {'flops':14s} {'ratio (%)':10s}\")\n",
+ "for k,v in f.items():\n",
+ " print(f\"{k:20s} {v:14d} {v/flops_total*100:10.4f}\")\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "palm_flops: 875062886400, flops: 874944921600, ratio: 1.0001\n"
+ ]
+ }
+ ],
+ "source": [
+ "# now here is an estimate copy pasted from the PaLM paper\n",
+ "# this formula is often used to calculate MFU (model flops utilization)\n",
+ "def palm_flops():\n",
+ " \"\"\"estimate of the model flops following PaLM paper formula\"\"\"\n",
+ " # non-embedding model parameters. note that we do not subtract the\n",
+ " # embedding/token params because those are tied and get used in the last layer.\n",
+ " N = params()['total'] - params()['emebedding/position']\n",
+ " L, H, Q, T = n_layer, n_head, n_embd//n_head, block_size\n",
+ " mf_per_token = 6*N + 12*L*H*Q*T\n",
+ " mf = mf_per_token * block_size\n",
+ " return mf\n",
+ "\n",
+ "print(f\"palm_flops: {palm_flops():d}, flops: {flops()['total']:d}, ratio: {palm_flops()/flops()['total']:.4f}\")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Ok they are quite similar, giving some confidence that my math in flops() function was ~ok. Now, A100 is cited at 312TFLOPS bfloat16 on tensor cores. So what is our model flops utilization (MFU)? I trained the model above with a batch_size of 20 and grad_accum of 5, which runs in about 755ms on a single A100 GPU. We get:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "fraction of A100 used: 37.14%\n"
+ ]
+ }
+ ],
+ "source": [
+ "# here is what we currently roughly measure\n",
+ "batch_size = 20 * 5 # 5 is grad_accum, so total batch size is 100\n",
+ "measured_time = 0.755 # in seconds per iteration\n",
+ "measured_throughput = batch_size / measured_time\n",
+ "flops_achieved = f['total'] * measured_throughput\n",
+ "\n",
+ "# A100 is cited to be 312 TFLOPS of bfloat16 running on tensor cores\n",
+ "a100_flops_promised = 312e12\n",
+ "\n",
+ "# the fraction of the A100 that we are using:\n",
+ "print(f\"fraction of A100 used: {flops_achieved / a100_flops_promised * 100:.2f}%\")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For reference, we'd prefer to be somewhere around 50%+, and not just for a single GPU but for an entire DDP run. So we still have some work to do, but at least we're within a factor of ~2X of what is achievable with this GPU."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time needed to train the model: 3.46 days\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Finally let's check out the 6ND approximation as total cost of training in FLOPs\n",
+ "model_size = params()['total'] # this is number of parameters, N\n",
+ "tokens_num = 300e9 # 300B tokens, this is dataset size in tokens, D\n",
+ "a100_flops = 312e12 # 312 TFLOPS\n",
+ "assumed_mfu = 0.3 # assume this model flops utilization (take the current 37% from above and add some DDP overhead)\n",
+ "flops_throughput = a100_flops * 8 * assumed_mfu # assume an 8XA100 node at 30% utilization\n",
+ "flops_needed = 6 * model_size * tokens_num # 6ND\n",
+ "time_needed_s = flops_needed / flops_throughput # in seconds\n",
+ "print(f\"time needed to train the model: {time_needed_s/3600/24:.2f} days\")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This is not a bad estimate at all. I trained this model and it converged in roughly 4 days. Btw as a good reference for where 6ND comes from and some intuition around it I recommend [Dzmitry's post](https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4)."
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, FLOPs are just one constraint, the other that we have to keep a close track of is the memory bandwidth. TODO estimate LOAD/STORE costs of our model later."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "pytorch2",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.4"
+ },
+ "orig_nbformat": 4,
+ "vscode": {
+ "interpreter": {
+ "hash": "7f5833218766b48e6e35e4452ee875aac0e2188d05bbe5298f2c62b79f08b222"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
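As a quick cross-check of the headline numbers the notebook prints, here is a hypothetical standalone sketch (not one of the uploaded files; variable names are illustrative) that re-derives the 124,337,664-parameter count, the PaLM-style FLOPs-per-iteration estimate, and the 6*N*D training-time estimate from the same GPT-2 config:

```python
# Config mirrors the notebook's GPT-2 (124M) settings, with bias=False.
block_size = 1024    # sequence length T
vocab_size = 50257
n_layer, n_head, n_embd = 12, 12, 768

# parameters: embeddings + n_layer blocks + final layernorm (lm head is weight-tied, 0 extra)
wpe = n_embd * block_size           # position embedding table
wte = n_embd * vocab_size           # token embedding table (shared with the lm head)
attn = n_embd + 4 * n_embd**2       # ln + kqv (3*n_embd^2) + out proj (n_embd^2)
mlp = n_embd + 8 * n_embd**2        # ln + ffw (4*n_embd^2) + proj (4*n_embd^2)
total = wpe + wte + n_layer * (attn + mlp) + n_embd

# PaLM-paper MFU formula: flops per token = 6*N + 12*L*H*Q*T,
# where N excludes only the position embeddings (token embeddings are tied)
N = total - wpe
L, H, Q, T = n_layer, n_head, n_embd // n_head, block_size
palm_flops_per_iter = (6 * N + 12 * L * H * Q * T) * T

# 6*N*D total training cost: 300B tokens on an 8xA100 node at an assumed 30% MFU
days = 6 * total * 300e9 / (312e12 * 8 * 0.3) / 86400

print(f"params: {total}, palm flops/iter: {palm_flops_per_iter}, est: {days:.2f} days")
```

These closed-form expressions reproduce the notebook's printed totals without needing the `params()`/`flops()` dictionaries, which makes them handy for sanity-checking a new config before training.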